CtrlK
Community · Documentation · Log in · Get started
Tessl Logo

tessl/maven-org-springframework-ai--spring-ai-transformers

ONNX-based Transformer models for text embeddings within the Spring AI framework

Overview
Eval results
Files

docs/examples/real-world-scenarios.md

Real-World Scenarios

Practical examples of using Spring AI Transformers in real applications.

Semantic Search

Build a semantic search engine that finds documents by meaning, not just keywords.

import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;

/**
 * In-memory semantic search over a set of documents.
 *
 * <p>Each document is embedded once when indexed; a query is embedded and compared against every
 * stored embedding with cosine similarity (a brute-force scan, so best suited to small corpora).
 * The backing maps are concurrent because, as a Spring singleton, the service may be indexed and
 * searched from different request threads at the same time.
 */
@Service
public class SemanticSearchService {

    private final TransformersEmbeddingModel model;
    private final Map<String, float[]> documentEmbeddings;
    private final Map<String, String> documents;

    public SemanticSearchService(TransformersEmbeddingModel model) {
        this.model = model;
        this.documentEmbeddings = new java.util.concurrent.ConcurrentHashMap<>();
        this.documents = new java.util.concurrent.ConcurrentHashMap<>();
    }

    /**
     * Embeds and stores documents, replacing any previously indexed document with the same id.
     *
     * @param docs document id to document text; an empty or null map is a no-op
     */
    public void indexDocuments(Map<String, String> docs) {
        if (docs == null || docs.isEmpty()) {
            return; // nothing to embed, so skip the model call entirely
        }

        // Snapshot the entries once so ids, texts and embeddings are guaranteed to line up by index.
        List<Map.Entry<String, String>> entries = new ArrayList<>(docs.entrySet());
        List<String> texts = new ArrayList<>(entries.size());
        for (Map.Entry<String, String> entry : entries) {
            texts.add(entry.getValue());
        }
        List<float[]> embeddings = model.embed(texts);

        for (int i = 0; i < entries.size(); i++) {
            String id = entries.get(i).getKey();
            // Store the text before the embedding: search() iterates embeddings and then looks up
            // the text, so this order never exposes an embedding without its document.
            documents.put(id, entries.get(i).getValue());
            documentEmbeddings.put(id, embeddings.get(i));
        }
    }

    /**
     * Returns the indexed documents most similar to {@code query}, best match first.
     *
     * @param query free-text query
     * @param topK maximum number of results; 0 yields an empty list
     * @return up to {@code topK} results ordered by descending cosine similarity
     * @throws IllegalArgumentException if {@code topK} is negative
     */
    public List<SearchResult> search(String query, int topK) {
        if (topK < 0) {
            throw new IllegalArgumentException("topK must be non-negative but was " + topK);
        }
        if (topK == 0 || documentEmbeddings.isEmpty()) {
            return new ArrayList<>(); // nothing can be returned, so don't embed the query
        }

        float[] queryEmbedding = model.embed(query);

        List<SearchResult> results = new ArrayList<>();
        for (Map.Entry<String, float[]> entry : documentEmbeddings.entrySet()) {
            String id = entry.getKey();
            float similarity = cosineSimilarity(queryEmbedding, entry.getValue());
            results.add(new SearchResult(id, documents.get(id), similarity));
        }

        return results.stream()
            .sorted(Comparator.comparing(SearchResult::similarity).reversed())
            .limit(topK)
            .collect(Collectors.toList());
    }

    /**
     * Cosine similarity of two equal-length vectors (nominally in [-1, 1]).
     *
     * <p>Returns 0 when either vector has zero magnitude. Without this guard the result would be
     * NaN, and {@code Float.compareTo} ranks NaN above every real score, so a degenerate vector
     * would sort to the top of the results.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    private float cosineSimilarity(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                "Embedding dimensions differ: " + a.length + " vs " + b.length);
        }

        double dotProduct = 0;
        double normA = 0;
        double normB = 0;
        for (int i = 0; i < a.length; i++) {
            dotProduct += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }

        double denominator = Math.sqrt(normA) * Math.sqrt(normB);
        return denominator == 0 ? 0f : (float) (dotProduct / denominator);
    }

    /** A single search hit: document id, its text, and its cosine similarity to the query. */
    public record SearchResult(String id, String text, float similarity) {}
}

Usage — exposing the service through a REST controller:

@RestController
@RequestMapping("/api/search")
public class SearchController {
    
    private final SemanticSearchService searchService;
    
    public SearchController(SemanticSearchService searchService) {
        this.searchService = searchService;
    }
    
    @PostMapping("/index")
    public void index(@RequestBody Map<String, String> documents) {
        searchService.indexDocuments(documents);
    }
    
    @GetMapping
    public List<SearchResult> search(
            @RequestParam String query,
            @RequestParam(defaultValue = "10") int topK) {
        return searchService.search(query, topK);
    }
}

Text Clustering

Group similar texts together using K-means clustering on embeddings.

import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.stereotype.Service;
import java.util.*;

/**
 * Clusters texts by running K-means (Euclidean distance) over their embeddings.
 */
@Service
public class TextClusteringService {

    /** Safety cap on assignment/update passes in case assignments keep changing. */
    private static final int MAX_ITERATIONS = 100;

    private final TransformersEmbeddingModel model;

    public TextClusteringService(TransformersEmbeddingModel model) {
        this.model = model;
    }

    /**
     * Groups {@code texts} into {@code numClusters} clusters with K-means over their embeddings.
     * Initial centroids are randomly chosen texts, so results can differ between calls.
     *
     * @param texts texts to cluster; must contain at least {@code numClusters} entries
     * @param numClusters number of clusters, at least 1
     * @return cluster index (0..numClusters-1) to the texts assigned to it; a cluster may be empty
     * @throws IllegalArgumentException if {@code numClusters} is below 1 or exceeds {@code texts.size()}
     */
    public Map<Integer, List<String>> cluster(List<String> texts, int numClusters) {
        if (numClusters < 1 || numClusters > texts.size()) {
            // More clusters than texts would make centroid seeding loop forever, and with no
            // clusters there is nowhere to put the texts.
            throw new IllegalArgumentException(
                "numClusters must be between 1 and " + texts.size() + " but was " + numClusters);
        }

        List<float[]> embeddings = model.embed(texts);

        List<float[]> centroids = initializeCentroids(embeddings, numClusters);
        // Every text starts in cluster 0; the first pass moves it to its nearest centroid.
        int[] assignments = new int[texts.size()];

        for (int iter = 0; iter < MAX_ITERATIONS; iter++) {
            boolean changed = false;

            // Assignment step: attach each text to its nearest centroid.
            for (int i = 0; i < embeddings.size(); i++) {
                int nearest = findNearestCentroid(embeddings.get(i), centroids);
                if (assignments[i] != nearest) {
                    assignments[i] = nearest;
                    changed = true;
                }
            }

            if (!changed) {
                break; // converged: no text switched clusters
            }

            // Update step: move each centroid to the mean of its members.
            updateCentroids(embeddings, assignments, centroids);
        }

        Map<Integer, List<String>> clusters = new HashMap<>();
        for (int k = 0; k < numClusters; k++) {
            clusters.put(k, new ArrayList<>());
        }
        for (int i = 0; i < texts.size(); i++) {
            clusters.get(assignments[i]).add(texts.get(i));
        }

        return clusters;
    }

    /**
     * Picks {@code k} distinct embeddings at random as starting centroids (copied so the update
     * step can mutate them). Requires {@code k <= embeddings.size()}, which the caller validates.
     */
    private List<float[]> initializeCentroids(List<float[]> embeddings, int k) {
        Random random = new Random();
        List<float[]> centroids = new ArrayList<>(k);
        Set<Integer> selected = new HashSet<>();

        while (centroids.size() < k) {
            int idx = random.nextInt(embeddings.size());
            if (selected.add(idx)) {
                centroids.add(embeddings.get(idx).clone());
            }
        }

        return centroids;
    }

    /** Returns the index of the centroid closest to {@code embedding} (first one on ties). */
    private int findNearestCentroid(float[] embedding, List<float[]> centroids) {
        int nearest = 0;
        float minDist = Float.MAX_VALUE;

        for (int i = 0; i < centroids.size(); i++) {
            float dist = euclideanDistance(embedding, centroids.get(i));
            if (dist < minDist) {
                minDist = dist;
                nearest = i;
            }
        }

        return nearest;
    }

    /**
     * Recomputes each centroid in place as the mean of the embeddings assigned to it, using a
     * single pass over the embeddings.
     */
    private void updateCentroids(
            List<float[]> embeddings,
            int[] assignments,
            List<float[]> centroids) {

        int dims = embeddings.get(0).length;
        float[][] sums = new float[centroids.size()][dims];
        int[] counts = new int[centroids.size()];

        for (int i = 0; i < embeddings.size(); i++) {
            int k = assignments[i];
            float[] embedding = embeddings.get(i);
            for (int d = 0; d < dims; d++) {
                sums[k][d] += embedding[d];
            }
            counts[k]++;
        }

        for (int k = 0; k < centroids.size(); k++) {
            // An empty cluster keeps its previous centroid rather than collapsing to the origin.
            if (counts[k] > 0) {
                float[] centroid = centroids.get(k);
                for (int d = 0; d < dims; d++) {
                    centroid[d] = sums[k][d] / counts[k];
                }
            }
        }
    }

    /** Euclidean (L2) distance between two vectors of equal length. */
    private float euclideanDistance(float[] a, float[] b) {
        float sum = 0;
        for (int i = 0; i < a.length; i++) {
            float diff = a[i] - b[i];
            sum += diff * diff;
        }
        return (float) Math.sqrt(sum);
    }
}

Duplicate Detection

Find and remove duplicate or near-duplicate texts.

import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.stereotype.Service;
import java.util.*;

/**
 * Detects near-duplicate texts by comparing the cosine similarity of their embeddings against a
 * fixed threshold.
 */
@Service
public class DeduplicationService {

    private final TransformersEmbeddingModel model;
    private final float similarityThreshold;

    public DeduplicationService(TransformersEmbeddingModel model) {
        this.model = model;
        // Cosine similarity at or above this value counts as a duplicate.
        this.similarityThreshold = 0.95f;
    }

    /**
     * Returns the texts with near-duplicates removed, keeping the first occurrence of each
     * group in input order.
     *
     * @param texts texts to deduplicate
     * @return a new list of the retained texts (empty for empty input)
     */
    public List<String> deduplicate(List<String> texts) {
        if (texts.isEmpty()) {
            return new ArrayList<>();
        }

        List<float[]> embeddings = model.embed(texts);

        List<String> unique = new ArrayList<>();
        List<float[]> uniqueEmbeddings = new ArrayList<>();

        for (int i = 0; i < texts.size(); i++) {
            String text = texts.get(i);
            float[] embedding = embeddings.get(i);

            // A text is a duplicate if it is close enough to any text already kept.
            boolean isDuplicate = false;
            for (float[] uniqueEmbedding : uniqueEmbeddings) {
                if (cosineSimilarity(embedding, uniqueEmbedding) >= similarityThreshold) {
                    isDuplicate = true;
                    break;
                }
            }

            if (!isDuplicate) {
                unique.add(text);
                uniqueEmbeddings.add(embedding);
            }
        }

        return unique;
    }

    /**
     * Groups near-duplicate texts.
     *
     * <p>Each group is led by the earliest unprocessed text. Later texts similar to that leader
     * join its group. Texts are compared with the leader only, not with every group member.
     *
     * @param texts texts to group
     * @return leader text to all members of its group (leader included); only groups with at
     *     least two members are returned, and empty input yields an empty map
     */
    public Map<String, List<String>> findDuplicateGroups(List<String> texts) {
        if (texts.isEmpty()) {
            return new HashMap<>(); // same guard as deduplicate(): skip the model call
        }

        List<float[]> embeddings = model.embed(texts);

        Map<String, List<String>> groups = new HashMap<>();
        boolean[] processed = new boolean[texts.size()];

        for (int i = 0; i < texts.size(); i++) {
            if (processed[i]) {
                continue;
            }

            List<String> group = new ArrayList<>();
            group.add(texts.get(i));
            processed[i] = true;

            for (int j = i + 1; j < texts.size(); j++) {
                if (processed[j]) {
                    continue;
                }
                if (cosineSimilarity(embeddings.get(i), embeddings.get(j)) >= similarityThreshold) {
                    group.add(texts.get(j));
                    processed[j] = true;
                }
            }

            if (group.size() > 1) {
                groups.put(texts.get(i), group);
            }
        }

        return groups;
    }

    /**
     * Cosine similarity of two equal-length vectors (nominally in [-1, 1]).
     *
     * <p>Returns 0 instead of NaN when either vector has zero magnitude, so a degenerate
     * embedding is never treated as a duplicate.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    private float cosineSimilarity(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                "Embedding dimensions differ: " + a.length + " vs " + b.length);
        }

        double dotProduct = 0;
        double normA = 0;
        double normB = 0;
        for (int i = 0; i < a.length; i++) {
            dotProduct += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }

        double denominator = Math.sqrt(normA) * Math.sqrt(normB);
        return denominator == 0 ? 0f : (float) (dotProduct / denominator);
    }
}

Content Recommendation

Recommend content similar to the items a user has already liked.

import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Content-based recommender.
 *
 * <p>Scores items the user has not liked by the cosine similarity between each item's
 * description embedding and the combined embedding of the items the user has liked.
 */
@Service
public class RecommendationService {

    private final TransformersEmbeddingModel model;
    private final Map<String, float[]> contentEmbeddings;
    private final Map<String, ContentItem> contentItems;

    public RecommendationService(TransformersEmbeddingModel model) {
        this.model = model;
        this.contentEmbeddings = new HashMap<>();
        this.contentItems = new HashMap<>();
    }

    /**
     * Embeds each item's description and stores it under the item's id, replacing any earlier
     * entry with the same id.
     *
     * @param items items to index; an empty or null list is a no-op
     */
    public void indexContent(List<ContentItem> items) {
        if (items == null || items.isEmpty()) {
            return; // nothing to embed, so skip the model call
        }

        List<String> texts = items.stream()
            .map(ContentItem::description)
            .collect(Collectors.toList());

        List<float[]> embeddings = model.embed(texts);

        for (int i = 0; i < items.size(); i++) {
            ContentItem item = items.get(i);
            contentItems.put(item.id(), item);
            contentEmbeddings.put(item.id(), embeddings.get(i));
        }
    }

    /**
     * Recommends up to {@code topK} indexed items the user has not liked yet, most similar first.
     *
     * <p>Falls back to {@link #getPopularItems(int)} in two cases: the user has no likes, or none
     * of the liked items is indexed. In either case there is no preference vector to compare
     * against.
     */
    public List<ContentItem> recommend(String userId, int topK) {
        List<String> likedItems = getUserLikedItems(userId);
        if (likedItems.isEmpty()) {
            return getPopularItems(topK);
        }

        Optional<float[]> userVector = computeUserVector(likedItems);
        if (userVector.isEmpty()) {
            // A zero vector would make every cosine score NaN, so ranking would be meaningless.
            return getPopularItems(topK);
        }
        float[] preference = userVector.get();

        // Hash set so the per-item "already liked" check below is O(1) instead of a list scan.
        Set<String> likedIds = new HashSet<>(likedItems);

        List<ScoredItem> scored = new ArrayList<>();
        for (Map.Entry<String, float[]> entry : contentEmbeddings.entrySet()) {
            String itemId = entry.getKey();
            if (likedIds.contains(itemId)) {
                continue; // don't recommend what the user already liked
            }
            scored.add(new ScoredItem(itemId, cosineSimilarity(preference, entry.getValue())));
        }

        return scored.stream()
            .sorted(Comparator.comparing(ScoredItem::score).reversed())
            .limit(topK)
            .map(s -> contentItems.get(s.itemId()))
            .collect(Collectors.toList());
    }

    /**
     * Builds the user's preference vector.
     *
     * <p>The embeddings of the indexed liked items are summed, and the sum is scaled to unit
     * length. Up to scale, this is the same as averaging them.
     *
     * @return the preference vector, or empty if none of the liked items are indexed
     */
    private Optional<float[]> computeUserVector(List<String> likedItems) {
        float[] userVector = null;

        for (String itemId : likedItems) {
            float[] embedding = contentEmbeddings.get(itemId);
            if (embedding == null) {
                continue; // liked item was never indexed
            }
            if (userVector == null) {
                // Size from the stored embeddings rather than model.dimensions(). That call may
                // have to run the model to find the size (TODO confirm for this model), and it
                // could disagree with what was actually indexed.
                userVector = new float[embedding.length];
            }
            for (int i = 0; i < userVector.length; i++) {
                userVector[i] += embedding[i];
            }
        }

        if (userVector == null) {
            return Optional.empty();
        }

        float norm = 0;
        for (float v : userVector) {
            norm += v * v;
        }
        norm = (float) Math.sqrt(norm);

        if (norm > 0) {
            for (int i = 0; i < userVector.length; i++) {
                userVector[i] /= norm;
            }
        }

        return Optional.of(userVector);
    }

    /**
     * Cosine similarity of two equal-length vectors (nominally in [-1, 1]).
     *
     * <p>Returns 0 when either vector has zero magnitude. Without this guard the result would be
     * NaN, and {@code Float.compareTo} ranks NaN above every real score, so that item would be
     * recommended first.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    private float cosineSimilarity(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                "Embedding dimensions differ: " + a.length + " vs " + b.length);
        }

        double dotProduct = 0;
        double normA = 0;
        double normB = 0;
        for (int i = 0; i < a.length; i++) {
            dotProduct += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }

        double denominator = Math.sqrt(normA) * Math.sqrt(normB);
        return denominator == 0 ? 0f : (float) (dotProduct / denominator);
    }

    private List<String> getUserLikedItems(String userId) {
        // Implementation depends on your data store
        return new ArrayList<>();
    }

    private List<ContentItem> getPopularItems(int topK) {
        // Implementation depends on your data store
        return new ArrayList<>();
    }

    /** An item that can be recommended; {@code description} is the text that gets embedded. */
    public record ContentItem(String id, String title, String description) {}

    /** An item id with its similarity to the user's preference vector. */
    private record ScoredItem(String itemId, float score) {}
}

Question Answering

Find relevant passages for question answering.

import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;

/**
 * Retrieves passages relevant to a question.
 *
 * <p>Documents are split into fixed-size groups of sentences, and each passage is embedded.
 * Passages are ranked by cosine similarity to the question's embedding.
 */
@Service
public class QuestionAnsweringService {

    /** Number of consecutive sentences grouped into one passage. */
    private static final int SENTENCES_PER_PASSAGE = 3;

    /**
     * Sentence boundary: whitespace preceded by '.', '!' or '?'. The lookbehind keeps the
     * punctuation attached to its sentence, so nothing has to be re-appended afterwards.
     */
    private static final String SENTENCE_BOUNDARY = "(?<=[.!?])\\s+";

    private final TransformersEmbeddingModel model;
    private final Map<String, float[]> passageEmbeddings;
    private final Map<String, String> passages;
    /** Passage ids last indexed for each document id; used to drop stale passages on re-index. */
    private final Map<String, List<String>> passageIdsByDocument;

    public QuestionAnsweringService(TransformersEmbeddingModel model) {
        this.model = model;
        this.passageEmbeddings = new HashMap<>();
        this.passages = new HashMap<>();
        this.passageIdsByDocument = new HashMap<>();
    }

    /**
     * Splits each document into passages with ids {@code <docId>_p<index>}, embeds them, and
     * stores them.
     *
     * <p>Re-indexing a document replaces all of its previous passages.
     *
     * @param docs document id to document text; an empty or null map is a no-op
     */
    public void indexPassages(Map<String, String> docs) {
        if (docs == null || docs.isEmpty()) {
            return;
        }

        // Split all documents first, keeping ids and texts in parallel lists so they line up
        // with the embeddings by index.
        List<String> newIds = new ArrayList<>();
        List<String> newTexts = new ArrayList<>();
        Map<String, List<String>> newIdsByDocument = new HashMap<>();
        for (Map.Entry<String, String> entry : docs.entrySet()) {
            String docId = entry.getKey();
            List<String> docPassages = splitIntoPassages(entry.getValue());
            List<String> docPassageIds = new ArrayList<>(docPassages.size());

            for (int i = 0; i < docPassages.size(); i++) {
                String passageId = docId + "_p" + i;
                newIds.add(passageId);
                newTexts.add(docPassages.get(i));
                docPassageIds.add(passageId);
            }
            newIdsByDocument.put(docId, docPassageIds);
        }

        // Embed before touching the index so a model failure leaves existing passages intact.
        List<float[]> embeddings = newTexts.isEmpty() ? List.of() : model.embed(newTexts);

        // Drop each re-indexed document's old passages. Otherwise a shorter new version would
        // leave stale trailing passages (e.g. doc_p5) searchable.
        for (String docId : newIdsByDocument.keySet()) {
            List<String> stale = passageIdsByDocument.remove(docId);
            if (stale != null) {
                for (String staleId : stale) {
                    passages.remove(staleId);
                    passageEmbeddings.remove(staleId);
                }
            }
        }
        passageIdsByDocument.putAll(newIdsByDocument);

        for (int i = 0; i < newIds.size(); i++) {
            passages.put(newIds.get(i), newTexts.get(i));
            passageEmbeddings.put(newIds.get(i), embeddings.get(i));
        }
    }

    /**
     * Returns the passages most similar to {@code question}, best match first.
     *
     * @param question the question text
     * @param topK maximum number of passages to return
     * @return up to {@code topK} passages ordered by descending cosine similarity
     */
    public List<RelevantPassage> findRelevantPassages(String question, int topK) {
        if (passageEmbeddings.isEmpty()) {
            return new ArrayList<>(); // nothing indexed, so skip embedding the question
        }

        float[] questionEmbedding = model.embed(question);

        List<RelevantPassage> results = new ArrayList<>();
        for (Map.Entry<String, float[]> entry : passageEmbeddings.entrySet()) {
            String id = entry.getKey();
            float similarity = cosineSimilarity(questionEmbedding, entry.getValue());
            results.add(new RelevantPassage(id, passages.get(id), similarity));
        }

        return results.stream()
            .sorted(Comparator.comparing(RelevantPassage::similarity).reversed())
            .limit(topK)
            .collect(Collectors.toList());
    }

    /**
     * Splits text into passages of {@link #SENTENCES_PER_PASSAGE} consecutive sentences. The
     * last passage may be shorter.
     *
     * <p>Sentences keep their original punctuation; no terminator is invented or doubled. Blank
     * or null text yields no passages.
     */
    private List<String> splitIntoPassages(String text) {
        List<String> result = new ArrayList<>();
        if (text == null || text.isBlank()) {
            return result;
        }

        List<String> window = new ArrayList<>(SENTENCES_PER_PASSAGE);
        for (String sentence : text.strip().split(SENTENCE_BOUNDARY)) {
            window.add(sentence);
            if (window.size() == SENTENCES_PER_PASSAGE) {
                result.add(String.join(" ", window));
                window.clear();
            }
        }

        if (!window.isEmpty()) {
            result.add(String.join(" ", window));
        }

        return result;
    }

    /**
     * Cosine similarity of two equal-length vectors (nominally in [-1, 1]).
     *
     * <p>Returns 0 when either vector has zero magnitude. Without this guard the result would be
     * NaN, and {@code Float.compareTo} ranks NaN above every real score, so that passage would
     * sort first.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    private float cosineSimilarity(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                "Embedding dimensions differ: " + a.length + " vs " + b.length);
        }

        double dotProduct = 0;
        double normA = 0;
        double normB = 0;
        for (int i = 0; i < a.length; i++) {
            dotProduct += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }

        double denominator = Math.sqrt(normA) * Math.sqrt(normB);
        return denominator == 0 ? 0f : (float) (dotProduct / denominator);
    }

    /** A retrieved passage: its id, its text, and its cosine similarity to the question. */
    public record RelevantPassage(String id, String text, float similarity) {}
}

Document Classification

Classify documents using embedding similarity to category prototypes.

import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.stereotype.Service;
import java.util.*;

/**
 * Nearest-prototype classifier.
 *
 * <p>Each category is represented by the mean embedding of its training examples. A text is
 * assigned to the category whose prototype has the highest cosine similarity to the text's
 * embedding.
 */
@Service
public class DocumentClassificationService {

    private final TransformersEmbeddingModel model;
    private final Map<String, float[]> categoryPrototypes;

    public DocumentClassificationService(TransformersEmbeddingModel model) {
        this.model = model;
        this.categoryPrototypes = new HashMap<>();
    }

    /**
     * Computes (or replaces) a prototype for each category from its example texts.
     *
     * @param trainingData category name to example texts
     * @throws IllegalArgumentException if any category has no examples; this is checked before
     *     any prototype is stored, so training is not left half-applied
     */
    public void trainCategories(Map<String, List<String>> trainingData) {
        for (Map.Entry<String, List<String>> entry : trainingData.entrySet()) {
            if (entry.getValue() == null || entry.getValue().isEmpty()) {
                throw new IllegalArgumentException(
                    "Category '" + entry.getKey() + "' has no training examples");
            }
        }

        for (Map.Entry<String, List<String>> entry : trainingData.entrySet()) {
            List<float[]> embeddings = model.embed(entry.getValue());
            categoryPrototypes.put(entry.getKey(), computeCentroid(embeddings));
        }
    }

    /**
     * Returns the category whose prototype is most similar to {@code text}.
     *
     * @return the best category, or {@code null} if no categories have been trained
     */
    public String classify(String text) {
        if (categoryPrototypes.isEmpty()) {
            return null; // nothing to compare against, so skip embedding the text
        }

        float[] embedding = model.embed(text);

        String bestCategory = null;
        float bestSimilarity = Float.NEGATIVE_INFINITY;

        for (Map.Entry<String, float[]> entry : categoryPrototypes.entrySet()) {
            float similarity = cosineSimilarity(embedding, entry.getValue());
            if (similarity > bestSimilarity) {
                bestSimilarity = similarity;
                bestCategory = entry.getKey();
            }
        }

        return bestCategory;
    }

    /**
     * Returns the cosine similarity between {@code text} and every trained category prototype.
     *
     * @return category name to similarity; empty if no categories have been trained
     */
    public Map<String, Float> classifyWithScores(String text) {
        Map<String, Float> scores = new HashMap<>();
        if (categoryPrototypes.isEmpty()) {
            return scores; // nothing to score, so skip embedding the text
        }

        float[] embedding = model.embed(text);
        for (Map.Entry<String, float[]> entry : categoryPrototypes.entrySet()) {
            scores.put(entry.getKey(), cosineSimilarity(embedding, entry.getValue()));
        }

        return scores;
    }

    /** Element-wise mean of a non-empty list of equal-length embeddings. */
    private float[] computeCentroid(List<float[]> embeddings) {
        int dims = embeddings.get(0).length;
        float[] centroid = new float[dims];

        for (float[] embedding : embeddings) {
            for (int i = 0; i < dims; i++) {
                centroid[i] += embedding[i];
            }
        }

        for (int i = 0; i < dims; i++) {
            centroid[i] /= embeddings.size();
        }

        return centroid;
    }

    /**
     * Cosine similarity of two equal-length vectors (nominally in [-1, 1]).
     *
     * <p>Returns 0 instead of NaN when either vector has zero magnitude. A NaN score never
     * compares greater than the running best, so {@link #classify(String)} could otherwise
     * return {@code null} despite trained categories.
     *
     * @throws IllegalArgumentException if the vectors have different lengths
     */
    private float cosineSimilarity(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException(
                "Embedding dimensions differ: " + a.length + " vs " + b.length);
        }

        double dotProduct = 0;
        double normA = 0;
        double normB = 0;
        for (int i = 0; i < a.length; i++) {
            dotProduct += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }

        double denominator = Math.sqrt(normA) * Math.sqrt(normB);
        return denominator == 0 ? 0f : (float) (dotProduct / denominator);
    }
}

See Also

  • Edge Cases - Advanced scenarios and corner cases
  • Quick Start Guide - Getting started
  • API Reference - Complete API documentation
tessl i tessl/maven-org-springframework-ai--spring-ai-transformers@1.1.1

docs

examples

edge-cases.md

real-world-scenarios.md

index.md

tile.json