ONNX-based Transformer models for text embeddings within the Spring AI framework
Practical examples of using Spring AI Transformers in real applications.
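All of the examples below assume a TransformersEmbeddingModel bean is available for injection. A minimal sketch of declaring one by hand follows (assumptions: Spring AI 1.x, where TransformersEmbeddingModel implements InitializingBean and falls back to a default ONNX sentence-transformer model downloaded on first use; with the Spring Boot starter, auto-configuration can supply this bean instead):
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class EmbeddingConfig {
    // Spring invokes afterPropertiesSet() on this InitializingBean, which loads the
    // default ONNX model and tokenizer the first time the application starts.
    @Bean
    public TransformersEmbeddingModel transformersEmbeddingModel() {
        return new TransformersEmbeddingModel();
    }
}
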
Build a semantic search engine that finds documents by meaning, not just keywords.
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;
@Service
public class SemanticSearchService {
private final TransformersEmbeddingModel model;
private final Map<String, float[]> documentEmbeddings;
private final Map<String, String> documents;
public SemanticSearchService(TransformersEmbeddingModel model) {
this.model = model;
this.documentEmbeddings = new HashMap<>();
this.documents = new HashMap<>();
}
// Index documents
public void indexDocuments(Map<String, String> docs) {
List<String> texts = new ArrayList<>(docs.values());
List<float[]> embeddings = model.embed(texts);
int i = 0;
for (Map.Entry<String, String> entry : docs.entrySet()) {
String id = entry.getKey();
documents.put(id, entry.getValue());
documentEmbeddings.put(id, embeddings.get(i++));
}
}
// Search by query
public List<SearchResult> search(String query, int topK) {
float[] queryEmbedding = model.embed(query);
List<SearchResult> results = new ArrayList<>();
for (Map.Entry<String, float[]> entry : documentEmbeddings.entrySet()) {
String id = entry.getKey();
float similarity = cosineSimilarity(queryEmbedding, entry.getValue());
results.add(new SearchResult(id, documents.get(id), similarity));
}
return results.stream()
.sorted(Comparator.comparing(SearchResult::similarity).reversed())
.limit(topK)
.collect(Collectors.toList());
}
private float cosineSimilarity(float[] a, float[] b) {
float dotProduct = 0f;
float normA = 0f;
float normB = 0f;
for (int i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
return dotProduct / (float)(Math.sqrt(normA) * Math.sqrt(normB));
}
public record SearchResult(String id, String text, float similarity) {}
}

Usage:
import org.springframework.web.bind.annotation.*;
import java.util.List;
import java.util.Map;
@RestController
@RequestMapping("/api/search")
public class SearchController {
private final SemanticSearchService searchService;
public SearchController(SemanticSearchService searchService) {
this.searchService = searchService;
}
@PostMapping("/index")
public void index(@RequestBody Map<String, String> documents) {
searchService.indexDocuments(documents);
}
@GetMapping
public List<SemanticSearchService.SearchResult> search(
@RequestParam String query,
@RequestParam(defaultValue = "10") int topK) {
return searchService.search(query, topK);
}
}

Group similar texts together using K-means clustering on embeddings.
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.stereotype.Service;
import java.util.*;
@Service
public class TextClusteringService {
private final TransformersEmbeddingModel model;
public TextClusteringService(TransformersEmbeddingModel model) {
this.model = model;
}
public Map<Integer, List<String>> cluster(List<String> texts, int numClusters) {
// Generate embeddings
List<float[]> embeddings = model.embed(texts);
// K-means clustering
List<float[]> centroids = initializeCentroids(embeddings, numClusters);
List<Integer> assignments = new ArrayList<>(Collections.nCopies(texts.size(), 0));
// Iterate until convergence
for (int iter = 0; iter < 100; iter++) {
boolean changed = false;
// Assign to nearest centroid
for (int i = 0; i < embeddings.size(); i++) {
int nearest = findNearestCentroid(embeddings.get(i), centroids);
if (assignments.get(i) != nearest) {
assignments.set(i, nearest);
changed = true;
}
}
if (!changed) break;
// Update centroids
updateCentroids(embeddings, assignments, centroids);
}
// Group texts by cluster
Map<Integer, List<String>> clusters = new HashMap<>();
for (int i = 0; i < numClusters; i++) {
clusters.put(i, new ArrayList<>());
}
for (int i = 0; i < texts.size(); i++) {
clusters.get(assignments.get(i)).add(texts.get(i));
}
return clusters;
}
private List<float[]> initializeCentroids(List<float[]> embeddings, int k) {
Random random = new Random();
List<float[]> centroids = new ArrayList<>();
Set<Integer> selected = new HashSet<>();
while (centroids.size() < k) {
int idx = random.nextInt(embeddings.size());
if (selected.add(idx)) {
centroids.add(embeddings.get(idx).clone());
}
}
return centroids;
}
private int findNearestCentroid(float[] embedding, List<float[]> centroids) {
int nearest = 0;
float minDist = Float.MAX_VALUE;
for (int i = 0; i < centroids.size(); i++) {
float dist = euclideanDistance(embedding, centroids.get(i));
if (dist < minDist) {
minDist = dist;
nearest = i;
}
}
return nearest;
}
private void updateCentroids(
List<float[]> embeddings,
List<Integer> assignments,
List<float[]> centroids) {
int dims = embeddings.get(0).length;
for (int k = 0; k < centroids.size(); k++) {
float[] sum = new float[dims];
int count = 0;
for (int i = 0; i < embeddings.size(); i++) {
if (assignments.get(i) == k) {
for (int d = 0; d < dims; d++) {
sum[d] += embeddings.get(i)[d];
}
count++;
}
}
if (count > 0) {
for (int d = 0; d < dims; d++) {
centroids.get(k)[d] = sum[d] / count;
}
}
}
}
private float euclideanDistance(float[] a, float[] b) {
float sum = 0;
for (int i = 0; i < a.length; i++) {
float diff = a[i] - b[i];
sum += diff * diff;
}
return (float) Math.sqrt(sum);
}
}
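
A minimal usage sketch of the clustering service (the injected clusteringService field and the sample strings are hypothetical):
// Hypothetical caller: group a handful of support tickets into two clusters
List<String> tickets = List.of(
        "Cannot log in to my account",
        "Password reset email never arrives",
        "Invoice total looks wrong",
        "Charged twice for the same order");
Map<Integer, List<String>> clusters = clusteringService.cluster(tickets, 2);
clusters.forEach((id, members) -> System.out.println("Cluster " + id + ": " + members));
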
Find and remove duplicate or near-duplicate texts.
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.stereotype.Service;
import java.util.*;
@Service
public class DeduplicationService {
private final TransformersEmbeddingModel model;
private final float similarityThreshold;
public DeduplicationService(TransformersEmbeddingModel model) {
this.model = model;
this.similarityThreshold = 0.95f; // 95% similarity threshold
}
public List<String> deduplicate(List<String> texts) {
if (texts.isEmpty()) return new ArrayList<>();
// Generate embeddings
List<float[]> embeddings = model.embed(texts);
// Find unique texts
List<String> unique = new ArrayList<>();
List<float[]> uniqueEmbeddings = new ArrayList<>();
for (int i = 0; i < texts.size(); i++) {
String text = texts.get(i);
float[] embedding = embeddings.get(i);
boolean isDuplicate = false;
for (float[] uniqueEmbedding : uniqueEmbeddings) {
float similarity = cosineSimilarity(embedding, uniqueEmbedding);
if (similarity >= similarityThreshold) {
isDuplicate = true;
break;
}
}
if (!isDuplicate) {
unique.add(text);
uniqueEmbeddings.add(embedding);
}
}
return unique;
}
public Map<String, List<String>> findDuplicateGroups(List<String> texts) {
List<float[]> embeddings = model.embed(texts);
Map<String, List<String>> groups = new HashMap<>();
boolean[] processed = new boolean[texts.size()];
for (int i = 0; i < texts.size(); i++) {
if (processed[i]) continue;
List<String> group = new ArrayList<>();
group.add(texts.get(i));
processed[i] = true;
for (int j = i + 1; j < texts.size(); j++) {
if (processed[j]) continue;
float similarity = cosineSimilarity(embeddings.get(i), embeddings.get(j));
if (similarity >= similarityThreshold) {
group.add(texts.get(j));
processed[j] = true;
}
}
if (group.size() > 1) {
groups.put(texts.get(i), group);
}
}
return groups;
}
private float cosineSimilarity(float[] a, float[] b) {
float dotProduct = 0f;
float normA = 0f;
float normB = 0f;
for (int i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
return dotProduct / (float)(Math.sqrt(normA) * Math.sqrt(normB));
}
}
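
A minimal usage sketch of the deduplication service (the injected deduplicationService field and the sample strings are hypothetical):
// Hypothetical caller: collapse near-identical FAQ entries; the first occurrence wins
List<String> faqs = List.of(
        "How do I reset my password?",
        "How can I reset my password?",
        "Where can I download invoices?");
List<String> unique = deduplicationService.deduplicate(faqs);
Map<String, List<String>> groups = deduplicationService.findDuplicateGroups(faqs);
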
Recommend similar content based on user preferences.
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;
@Service
public class RecommendationService {
private final TransformersEmbeddingModel model;
private final Map<String, float[]> contentEmbeddings;
private final Map<String, ContentItem> contentItems;
public RecommendationService(TransformersEmbeddingModel model) {
this.model = model;
this.contentEmbeddings = new HashMap<>();
this.contentItems = new HashMap<>();
}
public void indexContent(List<ContentItem> items) {
List<String> texts = items.stream()
.map(ContentItem::description)
.collect(Collectors.toList());
List<float[]> embeddings = model.embed(texts);
for (int i = 0; i < items.size(); i++) {
ContentItem item = items.get(i);
contentItems.put(item.id(), item);
contentEmbeddings.put(item.id(), embeddings.get(i));
}
}
public List<ContentItem> recommend(String userId, int topK) {
// Get user's interaction history
List<String> likedItems = getUserLikedItems(userId);
if (likedItems.isEmpty()) {
return getPopularItems(topK);
}
// Compute user preference vector (average of liked items)
float[] userVector = computeUserVector(likedItems);
// Find similar items
List<ScoredItem> scored = new ArrayList<>();
for (Map.Entry<String, float[]> entry : contentEmbeddings.entrySet()) {
String itemId = entry.getKey();
// Skip already liked items
if (likedItems.contains(itemId)) continue;
float similarity = cosineSimilarity(userVector, entry.getValue());
scored.add(new ScoredItem(itemId, similarity));
}
return scored.stream()
.sorted(Comparator.comparing(ScoredItem::score).reversed())
.limit(topK)
.map(s -> contentItems.get(s.itemId()))
.collect(Collectors.toList());
}
private float[] computeUserVector(List<String> likedItems) {
int dims = model.dimensions();
float[] userVector = new float[dims];
for (String itemId : likedItems) {
float[] embedding = contentEmbeddings.get(itemId);
if (embedding != null) {
for (int i = 0; i < dims; i++) {
userVector[i] += embedding[i];
}
}
}
// Normalize
float norm = 0;
for (float v : userVector) {
norm += v * v;
}
norm = (float) Math.sqrt(norm);
if (norm > 0) {
for (int i = 0; i < dims; i++) {
userVector[i] /= norm;
}
}
return userVector;
}
private float cosineSimilarity(float[] a, float[] b) {
float dotProduct = 0f;
float normA = 0f;
float normB = 0f;
for (int i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
return dotProduct / (float)(Math.sqrt(normA) * Math.sqrt(normB));
}
private List<String> getUserLikedItems(String userId) {
// Implementation depends on your data store
return new ArrayList<>();
}
private List<ContentItem> getPopularItems(int topK) {
// Implementation depends on your data store
return new ArrayList<>();
}
public record ContentItem(String id, String title, String description) {}
private record ScoredItem(String itemId, float score) {}
}
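
A minimal usage sketch of the recommendation service (the injected recommendationService field and the catalog items are hypothetical; recommend() only produces content-based picks once getUserLikedItems is backed by a real data store):
// Hypothetical caller: index a small catalog, then recommend two items for a user
recommendationService.indexContent(List.of(
        new RecommendationService.ContentItem("1", "Intro to Java", "A beginner's guide to the Java language"),
        new RecommendationService.ContentItem("2", "Spring in Action", "Building web applications with the Spring Framework"),
        new RecommendationService.ContentItem("3", "Sourdough Basics", "Baking bread at home with a natural starter")));
List<RecommendationService.ContentItem> picks = recommendationService.recommend("user-42", 2);
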
Find relevant passages for question answering.
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.stereotype.Service;
import java.util.*;
import java.util.stream.Collectors;
@Service
public class QuestionAnsweringService {
private final TransformersEmbeddingModel model;
private final Map<String, float[]> passageEmbeddings;
private final Map<String, String> passages;
public QuestionAnsweringService(TransformersEmbeddingModel model) {
this.model = model;
this.passageEmbeddings = new HashMap<>();
this.passages = new HashMap<>();
}
public void indexPassages(Map<String, String> docs) {
// Split documents into passages
Map<String, String> allPassages = new HashMap<>();
for (Map.Entry<String, String> entry : docs.entrySet()) {
String docId = entry.getKey();
List<String> docPassages = splitIntoPassages(entry.getValue());
for (int i = 0; i < docPassages.size(); i++) {
String passageId = docId + "_p" + i;
allPassages.put(passageId, docPassages.get(i));
}
}
// Generate embeddings
List<String> passageTexts = new ArrayList<>(allPassages.values());
List<float[]> embeddings = model.embed(passageTexts);
int i = 0;
for (Map.Entry<String, String> entry : allPassages.entrySet()) {
passages.put(entry.getKey(), entry.getValue());
passageEmbeddings.put(entry.getKey(), embeddings.get(i++));
}
}
public List<RelevantPassage> findRelevantPassages(String question, int topK) {
float[] questionEmbedding = model.embed(question);
List<RelevantPassage> results = new ArrayList<>();
for (Map.Entry<String, float[]> entry : passageEmbeddings.entrySet()) {
String id = entry.getKey();
float similarity = cosineSimilarity(questionEmbedding, entry.getValue());
results.add(new RelevantPassage(id, passages.get(id), similarity));
}
return results.stream()
.sorted(Comparator.comparing(RelevantPassage::similarity).reversed())
.limit(topK)
.collect(Collectors.toList());
}
private List<String> splitIntoPassages(String text) {
// Simple sentence-based splitting
String[] sentences = text.split("\\. ");
List<String> passages = new ArrayList<>();
StringBuilder current = new StringBuilder();
int sentenceCount = 0;
for (String sentence : sentences) {
current.append(sentence).append(". ");
sentenceCount++;
// Create a passage every 3 sentences
if (sentenceCount >= 3) {
passages.add(current.toString().trim());
current = new StringBuilder();
sentenceCount = 0;
}
}
if (current.length() > 0) {
passages.add(current.toString().trim());
}
return passages;
}
private float cosineSimilarity(float[] a, float[] b) {
float dotProduct = 0f;
float normA = 0f;
float normB = 0f;
for (int i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
return dotProduct / (float)(Math.sqrt(normA) * Math.sqrt(normB));
}
public record RelevantPassage(String id, String text, float similarity) {}
}
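
A minimal usage sketch of the question answering service (the injected questionAnsweringService field and the document text are hypothetical):
// Hypothetical caller: index one document, then retrieve the passages most relevant to a question
questionAnsweringService.indexPassages(Map.of(
        "handbook", "Employees accrue 20 vacation days per year. Unused days expire in March. "
                + "Remote work is allowed two days per week. Equipment requests go through IT."));
List<QuestionAnsweringService.RelevantPassage> hits =
        questionAnsweringService.findRelevantPassages("How many vacation days do I get?", 3);
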
Classify documents using embedding similarity to category prototypes.
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.stereotype.Service;
import java.util.*;
@Service
public class DocumentClassificationService {
private final TransformersEmbeddingModel model;
private final Map<String, float[]> categoryPrototypes;
public DocumentClassificationService(TransformersEmbeddingModel model) {
this.model = model;
this.categoryPrototypes = new HashMap<>();
}
public void trainCategories(Map<String, List<String>> trainingData) {
for (Map.Entry<String, List<String>> entry : trainingData.entrySet()) {
String category = entry.getKey();
List<String> examples = entry.getValue();
// Generate embeddings for examples
List<float[]> embeddings = model.embed(examples);
// Compute prototype (centroid)
float[] prototype = computeCentroid(embeddings);
categoryPrototypes.put(category, prototype);
}
}
public String classify(String text) {
float[] embedding = model.embed(text);
String bestCategory = null;
float bestSimilarity = Float.NEGATIVE_INFINITY;
for (Map.Entry<String, float[]> entry : categoryPrototypes.entrySet()) {
float similarity = cosineSimilarity(embedding, entry.getValue());
if (similarity > bestSimilarity) {
bestSimilarity = similarity;
bestCategory = entry.getKey();
}
}
return bestCategory;
}
public Map<String, Float> classifyWithScores(String text) {
float[] embedding = model.embed(text);
Map<String, Float> scores = new HashMap<>();
for (Map.Entry<String, float[]> entry : categoryPrototypes.entrySet()) {
float similarity = cosineSimilarity(embedding, entry.getValue());
scores.put(entry.getKey(), similarity);
}
return scores;
}
private float[] computeCentroid(List<float[]> embeddings) {
int dims = embeddings.get(0).length;
float[] centroid = new float[dims];
for (float[] embedding : embeddings) {
for (int i = 0; i < dims; i++) {
centroid[i] += embedding[i];
}
}
for (int i = 0; i < dims; i++) {
centroid[i] /= embeddings.size();
}
return centroid;
}
private float cosineSimilarity(float[] a, float[] b) {
float dotProduct = 0f;
float normA = 0f;
float normB = 0f;
for (int i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
return dotProduct / (float)(Math.sqrt(normA) * Math.sqrt(normB));
}
}
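
A minimal usage sketch of the classification service (the injected classificationService field and the training examples are hypothetical):
// Hypothetical caller: build category prototypes from a few labelled examples, then classify new text
classificationService.trainCategories(Map.of(
        "billing", List.of("I was charged twice", "Please refund my last payment"),
        "technical", List.of("The app crashes on startup", "I get a 500 error when saving")));
String category = classificationService.classify("Why is there an extra charge on my card?");
Map<String, Float> scores = classificationService.classifyWithScores("The page will not load");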