ONNX-based Transformer models for text embeddings within the Spring AI framework:
handling special cases, edge conditions, and advanced usage patterns.
// Edge cases: empty, whitespace-only, null, and over-length inputs.
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.afterPropertiesSet();
// Empty string - returns valid embedding
float[] emptyEmbedding = model.embed("");
// Default model produces 384-dimensional vectors (asserted here)
assert emptyEmbedding.length == 384;
// Whitespace-only text
float[] whitespaceEmbedding = model.embed(" \n\t ");
// Valid embedding, likely similar to empty string
// Empty list
List<float[]> emptyResults = model.embed(List.of());
assert emptyResults.isEmpty();
// Null text - throws IllegalArgumentException
try {
model.embed((String) null);
fail("Should throw exception");
} catch (IllegalArgumentException e) {
// Expected
}
// Null list - throws IllegalArgumentException
try {
model.embed((List<String>) null);
fail("Should throw exception");
} catch (IllegalArgumentException e) {
// Expected
}
// List with null elements - filter before embedding
List<String> textsWithNull = Arrays.asList("Valid", null, "Another");
List<String> filtered = textsWithNull.stream()
.filter(Objects::nonNull)
.collect(Collectors.toList());
List<float[]> embeddings = model.embed(filtered);
// Default model truncates at 512 tokens
String veryLongText = "word ".repeat(1000); // 1000 words
float[] embedding = model.embed(veryLongText);
// Automatically truncated, no error thrown
// Check if text will be truncated
// NOTE(review): whitespace word count only approximates token count -
// subword tokenizers usually emit more tokens than words; confirm if an
// exact limit check is needed.
int estimatedTokens = veryLongText.split("\\s+").length;
if (estimatedTokens > 512) {
System.out.println("Text will be truncated");
}

/**
 * Embeds a document longer than the model's token limit by sliding a
 * word-based window across it and embedding each window separately.
 *
 * @param longText   full document text (split on whitespace)
 * @param windowSize number of words per window; must be positive
 * @param stride     number of words to advance between windows; must be
 *                   positive (a stride smaller than windowSize produces
 *                   overlapping windows)
 * @return one embedding per window, in document order
 * @throws IllegalArgumentException if windowSize or stride is not positive
 *                                  (stride <= 0 would otherwise loop forever)
 */
public List<float[]> embedLongDocument(String longText, int windowSize, int stride) {
    if (windowSize <= 0 || stride <= 0) {
        throw new IllegalArgumentException(
            "windowSize and stride must be positive (windowSize=" + windowSize
                + ", stride=" + stride + ")");
    }
    String[] words = longText.split("\\s+");
    List<float[]> embeddings = new ArrayList<>();
    for (int i = 0; i < words.length; i += stride) {
        int end = Math.min(i + windowSize, words.length);
        String window = String.join(" ", Arrays.copyOfRange(words, i, end));
        embeddings.add(model.embed(window));
        if (end >= words.length) break; // last window reached the end of the text
    }
    return embeddings;
}
// Usage
List<float[]> windowEmbeddings = embedLongDocument(longText, 400, 200);
// Average or concatenate embeddings as needed

// Unicode handling: the tokenizer accepts arbitrary UTF-8 text.
// Emojis
float[] emojiEmbedding = model.embed("Hello 🌍 😊");
// Correctly tokenized and embedded
// Non-Latin scripts
float[] chineseEmbedding = model.embed("你好世界");
float[] arabicEmbedding = model.embed("مرحبا بالعالم");
float[] cyrillicEmbedding = model.embed("Привет мир");
// All handled correctly
// Mixed scripts
float[] mixedEmbedding = model.embed("Hello 世界 🌍");
// Newlines and tabs
float[] multilineEmbedding = model.embed("Line 1\nLine 2\tTabbed");
// Control characters tokenized as-is
// Special punctuation
float[] punctuationEmbedding = model.embed("Hello! How are you? I'm fine.");
// Punctuation preserved in tokenization

// Thread safety: a fully initialized model can be shared across threads.
// (This declaration was accidentally fused into the comment above; without
// it the afterPropertiesSet() call below has no receiver.)
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.afterPropertiesSet();
// Safe to use from multiple threads
ExecutorService executor = Executors.newFixedThreadPool(10);
List<Future<float[]>> futures = new ArrayList<>();
for (int i = 0; i < 100; i++) {
    final String text = "Text " + i;
    futures.add(executor.submit(() -> model.embed(text)));
}
// Collect results; Future.get() declares checked exceptions that must be handled
for (Future<float[]> future : futures) {
    try {
        float[] embedding = future.get();
        // Process embedding
    } catch (InterruptedException e) {
        Thread.currentThread().interrupt(); // preserve interrupt status
        throw new RuntimeException("Interrupted while collecting embeddings", e);
    } catch (ExecutionException e) {
        // Unwrap the task's failure as the cause
        throw new RuntimeException("Embedding task failed", e.getCause());
    }
}
executor.shutdown();
// DON'T: Initialize from multiple threads
// BAD PATTERN
ExecutorService executor = Executors.newFixedThreadPool(2);
executor.submit(() -> {
TransformersEmbeddingModel model1 = new TransformersEmbeddingModel();
model1.afterPropertiesSet(); // May conflict with other thread
});
executor.submit(() -> {
TransformersEmbeddingModel model2 = new TransformersEmbeddingModel();
model2.afterPropertiesSet(); // May conflict with other thread
});
// DO: Initialize once, share instance
// One initialized instance serves concurrent embed() calls safely.
TransformersEmbeddingModel sharedModel = new TransformersEmbeddingModel();
sharedModel.afterPropertiesSet();
executor.submit(() -> sharedModel.embed("text1"));
executor.submit(() -> sharedModel.embed("text2"));
// Process large datasets in chunks to avoid OOM
/**
 * Embeds a large corpus in fixed-size batches so only one batch of
 * intermediate results is in flight at a time.
 *
 * @param texts     all texts to embed
 * @param batchSize number of texts per model call; must be positive
 * @return embeddings for every input text, in order
 * @throws IllegalArgumentException if batchSize is not positive
 *                                  (batchSize <= 0 would loop forever)
 */
public List<float[]> embedLargeDataset(List<String> texts, int batchSize) {
    if (batchSize <= 0) {
        throw new IllegalArgumentException("batchSize must be positive: " + batchSize);
    }
    // Presize: the final size is known up front
    List<float[]> allEmbeddings = new ArrayList<>(texts.size());
    for (int i = 0; i < texts.size(); i += batchSize) {
        int end = Math.min(i + batchSize, texts.size());
        List<String> batch = texts.subList(i, end);
        allEmbeddings.addAll(model.embed(batch));
        // Intentionally no System.gc() / manual null-out here: the batch
        // results become unreachable each iteration anyway, and explicit GC
        // calls stall throughput without freeing anything extra.
    }
    return allEmbeddings;
}

/**
 * Embeds each text from the stream and stores the result immediately, so at
 * most one embedding is retained at a time and each becomes collectible as
 * soon as it is stored.
 */
public void processStream(Stream<String> textStream) {
    textStream.forEach(text -> {
        // Embed and hand off right away; no intermediate list is built.
        storeEmbedding(model.embed(text));
    });
}

/**
 * Creates an embedding model on GPU device 0, falling back to CPU when GPU
 * initialization fails (e.g. missing CUDA runtime).
 *
 * @return an initialized model on GPU if available, otherwise on CPU
 * @throws Exception if initialization fails for a non-GPU-related reason,
 *                   or if the CPU fallback also fails
 */
public TransformersEmbeddingModel createModelWithFallback() throws Exception {
    TransformersEmbeddingModel model = new TransformersEmbeddingModel();
    model.setGpuDeviceId(0);
    try {
        model.afterPropertiesSet();
        System.out.println("Using GPU");
        return model;
    } catch (Exception e) {
        // getMessage() may be null (e.g. bare NPE); String.valueOf guards the
        // contains() checks against a secondary NullPointerException.
        String message = String.valueOf(e.getMessage());
        if (message.contains("CUDA") || message.contains("GPU")) {
            System.out.println("GPU failed, falling back to CPU");
            model = new TransformersEmbeddingModel();
            model.setGpuDeviceId(-1); // -1 selects CPU execution
            model.afterPropertiesSet();
            return model;
        }
        throw e; // unrelated failure: propagate unchanged
    }
}

/**
 * Distributes embedding requests round-robin across one model instance per
 * GPU device.
 */
public class MultiGPUEmbeddingService {
    private final List<TransformersEmbeddingModel> models;
    private final AtomicInteger counter = new AtomicInteger(0);

    /**
     * @param numGPUs number of GPU devices to use; must be positive
     * @throws IllegalArgumentException if numGPUs is not positive (would
     *                                  otherwise divide by zero in embed)
     * @throws Exception if any model fails to initialize
     */
    public MultiGPUEmbeddingService(int numGPUs) throws Exception {
        if (numGPUs <= 0) {
            throw new IllegalArgumentException("numGPUs must be positive: " + numGPUs);
        }
        models = new ArrayList<>(numGPUs);
        for (int i = 0; i < numGPUs; i++) {
            TransformersEmbeddingModel model = new TransformersEmbeddingModel();
            model.setGpuDeviceId(i);
            model.afterPropertiesSet();
            models.add(model);
        }
    }

    public float[] embed(String text) {
        // Round-robin across GPUs. floorMod keeps the index non-negative even
        // after the int counter wraps past Integer.MAX_VALUE, where a plain
        // `%` would yield a negative index and throw IndexOutOfBoundsException.
        int idx = Math.floorMod(counter.getAndIncrement(), models.size());
        return models.get(idx).embed(text);
    }
}@Component
// Runs once at application startup so the first-call cost (lazy session
// setup, JIT warm-up) is paid before real traffic arrives.
public class CacheWarmer implements ApplicationRunner {
// Shared, Spring-injected embedding model (already initialized by the container).
private final TransformersEmbeddingModel model;
public CacheWarmer(TransformersEmbeddingModel model) {
this.model = model;
}
@Override
public void run(ApplicationArguments args) {
// Warm up model with dummy embedding
model.embed("warmup");
System.out.println("Model cache warmed");
}
}@Scheduled(cron = "0 0 2 * * *") // 2 AM daily
// Evicts the on-disk model cache once it grows past 10 GB; subsequent model
// initialization re-downloads whatever it needs.
// NOTE(review): assumes the cache lives at /var/cache/spring-ai — confirm
// against the configured cache directory.
public void cleanupCache() {
File cacheDir = new File("/var/cache/spring-ai");
long size = calculateDirectorySize(cacheDir);
long maxSize = 10L * 1024 * 1024 * 1024; // 10 GB
if (size > maxSize) {
System.out.println("Cache exceeds " + (maxSize / 1024 / 1024 / 1024) + " GB, cleaning up");
ResourceCacheService cache = new ResourceCacheService(cacheDir.getAbsolutePath());
cache.deleteCacheFolder();
}
}
// Retries transient embedding failures up to maxRetries times, sleeping with
// a linearly increasing delay between attempts. Rethrows the last failure
// (as the cause) if all attempts are exhausted.
public float[] embedWithRetry(String text, int maxRetries) {
int attempt = 0;
Exception lastException = null;
while (attempt < maxRetries) {
try {
return model.embed(text);
} catch (RuntimeException e) {
lastException = e;
attempt++;
if (attempt < maxRetries) {
try {
Thread.sleep(1000 * attempt); // Linear backoff: 1s, 2s, 3s, ... (not exponential)
} catch (InterruptedException ie) {
// Preserve the interrupt flag for callers before aborting the retry loop
Thread.currentThread().interrupt();
throw new RuntimeException("Interrupted during retry", ie);
}
}
}
}
throw new RuntimeException("Failed after " + maxRetries + " attempts", lastException);
}
// Fails fast after `threshold` consecutive embedding failures: the circuit
// stays open for `resetTimeout` ms, during which calls are rejected without
// touching the model, then closes again for a fresh attempt.
//
// NOTE(review): the open/close transitions are check-then-act on volatile
// fields, so concurrent callers can race (e.g. two threads both resetting,
// or a failure overlapping a reset). Acceptable for an example; prefer a
// library circuit breaker (e.g. Resilience4j) in production.
public class CircuitBreakerEmbeddingService {
private final TransformersEmbeddingModel model;
// Consecutive-failure counter; reset to zero on any success
private final AtomicInteger failureCount = new AtomicInteger(0);
private final int threshold = 5;
private volatile boolean circuitOpen = false;
private volatile long circuitOpenTime = 0;
private final long resetTimeout = 60000; // 1 minute
public CircuitBreakerEmbeddingService(TransformersEmbeddingModel model) {
this.model = model;
}
public float[] embed(String text) {
if (circuitOpen) {
if (System.currentTimeMillis() - circuitOpenTime > resetTimeout) {
// Timeout elapsed: close the circuit and let this call try the model
circuitOpen = false;
failureCount.set(0);
} else {
throw new RuntimeException("Circuit breaker is open");
}
}
try {
float[] embedding = model.embed(text);
failureCount.set(0); // Reset on success
return embedding;
} catch (Exception e) {
int failures = failureCount.incrementAndGet();
if (failures >= threshold) {
circuitOpen = true;
circuitOpenTime = System.currentTimeMillis();
}
throw e;
}
}
}

// Handle models with different dimensions
/**
 * Hosts several embedding models side by side, keyed by name, together with
 * each model's output dimensionality.
 */
public class MultiModelService {
    private final Map<String, TransformersEmbeddingModel> models;
    private final Map<String, Integer> dimensions;

    public MultiModelService() throws Exception {
        models = new HashMap<>();
        dimensions = new HashMap<>();
        // Model 1: 384 dimensions
        TransformersEmbeddingModel model1 = new TransformersEmbeddingModel();
        model1.afterPropertiesSet();
        models.put("small", model1);
        dimensions.put("small", model1.dimensions());
        // Model 2: Custom dimensions
        TransformersEmbeddingModel model2 = new TransformersEmbeddingModel();
        model2.setModelResource("classpath:/models/large-model.onnx");
        model2.afterPropertiesSet();
        models.put("large", model2);
        dimensions.put("large", model2.dimensions());
    }

    /**
     * @throws IllegalArgumentException for an unknown model name, instead of
     *                                  the bare NullPointerException the raw
     *                                  map lookup would produce
     */
    public float[] embed(String text, String modelName) {
        TransformersEmbeddingModel selected = models.get(modelName);
        if (selected == null) {
            throw new IllegalArgumentException(
                "Unknown model: " + modelName + " (available: " + models.keySet() + ")");
        }
        return selected.embed(text);
    }

    /** @throws IllegalArgumentException for an unknown model name */
    public int getDimensions(String modelName) {
        Integer dim = dimensions.get(modelName);
        if (dim == null) {
            throw new IllegalArgumentException(
                "Unknown model: " + modelName + " (available: " + dimensions.keySet() + ")");
        }
        return dim;
    }
}
// Re-embeds a corpus with both an old and a new model so the two embedding
// versions can be stored side by side for comparison or gradual migration.
public class ModelMigrationService {
public void migrateEmbeddings(
String oldModelUri,
String newModelUri,
List<String> texts) throws Exception {
// Generate embeddings with old model
TransformersEmbeddingModel oldModel = new TransformersEmbeddingModel();
oldModel.setModelResource(oldModelUri);
oldModel.afterPropertiesSet();
List<float[]> oldEmbeddings = oldModel.embed(texts);
// Generate embeddings with new model
TransformersEmbeddingModel newModel = new TransformersEmbeddingModel();
newModel.setModelResource(newModelUri);
newModel.afterPropertiesSet();
List<float[]> newEmbeddings = newModel.embed(texts);
// Store both for comparison or gradual migration; embed(List) preserves
// input order, so index i lines up across texts and both embedding lists.
for (int i = 0; i < texts.size(); i++) {
storeEmbedding(texts.get(i), "old", oldEmbeddings.get(i));
storeEmbedding(texts.get(i), "new", newEmbeddings.get(i));
}
}
// Persistence stub: store one (text, model-version, embedding) triple.
private void storeEmbedding(String text, String version, float[] embedding) {
// Store in database or vector store
}
}

/**
 * Caches embeddings in an in-memory LRU map capped at 10,000 entries.
 *
 * A Spring singleton service is called from many request threads, and an
 * access-ordered LinkedHashMap is not thread-safe (every get() mutates the
 * access order), so all map access is synchronized. Note the lock is held
 * across model.embed() on a miss, serializing concurrent misses.
 */
@Service
public class CachedEmbeddingService {
    private final TransformersEmbeddingModel model;
    private final Map<String, float[]> cache;
    private final int maxCacheSize = 10000;

    public CachedEmbeddingService(TransformersEmbeddingModel model) {
        this.model = model;
        // accessOrder=true gives LRU iteration order; removeEldestEntry
        // evicts the least recently used entry once the cap is exceeded.
        this.cache = new LinkedHashMap<>(maxCacheSize, 0.75f, true) {
            @Override
            protected boolean removeEldestEntry(Map.Entry<String, float[]> eldest) {
                return size() > maxCacheSize;
            }
        };
    }

    /** Returns the cached embedding for text, computing and caching it on a miss. */
    public float[] embed(String text) {
        synchronized (cache) {
            return cache.computeIfAbsent(text, model::embed);
        }
    }

    /** Drops every cached embedding. */
    public void clearCache() {
        synchronized (cache) {
            cache.clear();
        }
    }
}

/**
 * Benchmarks several batch sizes against a synthetic 1000-text corpus and
 * returns the fastest one.
 *
 * A warm-up pass runs first so JIT compilation and model warm-up costs do
 * not penalize whichever batch size happens to be measured first, and
 * timings use the monotonic System.nanoTime() clock (currentTimeMillis can
 * jump with wall-clock adjustments).
 *
 * @return the batch size with the lowest measured total embedding time
 */
public int findOptimalBatchSize() {
    List<String> testTexts = new ArrayList<>(1000);
    for (int i = 0; i < 1000; i++) {
        testTexts.add("Test text " + i);
    }
    // Warm-up pass: pay one-time initialization/JIT costs before timing
    model.embed(testTexts.subList(0, 100));
    int[] batchSizes = {10, 50, 100, 200, 500};
    int optimalSize = 10;
    long bestTime = Long.MAX_VALUE;
    for (int batchSize : batchSizes) {
        long start = System.nanoTime();
        for (int i = 0; i < testTexts.size(); i += batchSize) {
            int end = Math.min(i + batchSize, testTexts.size());
            model.embed(testTexts.subList(i, end));
        }
        long elapsed = System.nanoTime() - start;
        if (elapsed < bestTime) {
            bestTime = elapsed;
            optimalSize = batchSize;
        }
    }
    return optimalSize;
}
tessl i tessl/maven-org-springframework-ai--spring-ai-transformers@1.1.1