ONNX-based Transformer models for text embeddings within the Spring AI framework
Generate vector embeddings for text strings, documents, or batches of text using ONNX-based transformer models. The embedding model converts text into fixed-size numerical vectors that capture semantic meaning, enabling similarity comparisons, clustering, and retrieval-augmented generation (RAG) workflows.
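For example, similarity between two texts can be computed directly from their vectors. The following is a minimal sketch (the cosine-similarity computation is local to this snippet and is not part of the Spring AI API; the default 384-dimension model is assumed):
import org.springframework.ai.transformers.TransformersEmbeddingModel;
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.afterPropertiesSet();
// Embed two semantically related sentences
float[] a = model.embed("The cat sat on the mat");
float[] b = model.embed("A cat is sitting on a rug");
// Cosine similarity in [-1, 1]; higher means more semantically similar
float dot = 0f, normA = 0f, normB = 0f;
for (int i = 0; i < a.length; i++) {
dot += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
float similarity = dot / (float) (Math.sqrt(normA) * Math.sqrt(normB));
System.out.println("similarity = " + similarity);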
Embed a single text string into a vector. This is the most basic embedding operation, suitable for single-item processing or sequential workflows.
/**
* Embeds the given text into a vector.
*
* @param text the text to embed (must not be null)
* @return float array representing the embedding vector (384 dimensions for default model)
* @throws IllegalArgumentException if text is null
*/
float[] embed(String text);
Usage:
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.afterPropertiesSet();
float[] embedding = model.embed("Hello world");
// embedding.length == 384 (for default all-MiniLM-L6-v2 model)
// Verify embedding dimensions
assert embedding.length == model.dimensions();
Edge Cases:
// Empty string - returns valid embedding
float[] emptyEmbedding = model.embed("");
// emptyEmbedding.length == 384; the values are model-specific (not necessarily zeros)
// Whitespace-only text - treated as valid input
float[] whitespaceEmbedding = model.embed(" ");
// Valid embedding, likely close to empty string embedding
// Very long text - automatically truncated
String longText = "word ".repeat(1000); // 1000 words
float[] longEmbedding = model.embed(longText);
// Truncated to model's max length (512 tokens for all-MiniLM-L6-v2)
// No error thrown - silent truncation based on tokenizer config
// Special characters and Unicode
float[] unicodeEmbedding = model.embed("Hello 🌍 世界 \n\t");
// All Unicode handled correctly by tokenizer
// Null text - throws exception
try {
model.embed(null);
// Never reaches here
} catch (IllegalArgumentException e) {
// Exception: "text must not be null" or similar
}
Performance:
// Single embedding typical latency:
// - CPU: 5-20ms (after model warm-up)
// - GPU: 2-5ms
// - First call: +100-500ms (ONNX session initialization)
long start = System.nanoTime();
float[] embedding = model.embed("Performance test");
long duration = System.nanoTime() - start;
System.out.println("Embedding time: " + duration / 1_000_000.0 + "ms");Embed a Spring AI Document object, which contains text content and metadata. The metadata mode determines whether metadata is included in the embedding computation.
/**
* Embeds the given document's content into a vector.
* Metadata handling depends on the MetadataMode configured in the constructor.
*
* @param document the document to embed (must not be null)
* @return float array representing the embedding vector
* @throws IllegalArgumentException if document is null
*/
float[] embed(Document document);
Usage:
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import java.util.Map;
// With MetadataMode.NONE (default) - embeds text only
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.afterPropertiesSet();
Document doc = new Document("Document text content");
float[] embedding = model.embed(doc);
// With metadata inclusion
TransformersEmbeddingModel modelWithMeta =
new TransformersEmbeddingModel(MetadataMode.EMBED);
modelWithMeta.afterPropertiesSet();
Document docWithMeta = new Document(
"Document text",
Map.of("source", "file.txt", "author", "John", "date", "2024-01-01")
);
float[] embeddingWithMeta = modelWithMeta.embed(docWithMeta);
// Embedding includes formatted metadata in the text
// Format: "source: file.txt\nauthor: John\ndate: 2024-01-01\nDocument text"Metadata Modes:
// MetadataMode enum values
public enum MetadataMode {
NONE, // Embed text content only (default)
EMBED, // Include metadata marked for embedding
ALL // Include all metadata fields
}
Metadata Mode Comparison:
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import java.util.Arrays;
import java.util.Map;
Document doc = new Document(
"Main content",
Map.of("title", "Doc Title", "author", "Jane")
);
// Mode NONE: Only content
TransformersEmbeddingModel modeNone =
new TransformersEmbeddingModel(MetadataMode.NONE);
modeNone.afterPropertiesSet();
float[] embeddingNone = modeNone.embed(doc);
// Embeds: "Main content"
// Mode EMBED: Content + metadata
TransformersEmbeddingModel modeEmbed =
new TransformersEmbeddingModel(MetadataMode.EMBED);
modeEmbed.afterPropertiesSet();
float[] embeddingEmbed = modeEmbed.embed(doc);
// Embeds: "title: Doc Title\nauthor: Jane\nMain content"
// Mode ALL: Content + all metadata
TransformersEmbeddingModel modeAll =
new TransformersEmbeddingModel(MetadataMode.ALL);
modeAll.afterPropertiesSet();
float[] embeddingAll = modeAll.embed(doc);
// Embeds: All metadata fields + content
// Different metadata modes produce different embeddings
assert !Arrays.equals(embeddingNone, embeddingEmbed);
Edge Cases:
// Document with null text - depending on the Spring AI version, the Document
// constructor may reject null text; if accepted, it is treated as an empty string
Document nullTextDoc = new Document(null);
float[] embeddingNull = model.embed(nullTextDoc);
// Document with empty metadata
Document noMetaDoc = new Document("Text", Map.of());
float[] embeddingNoMeta = model.embed(noMetaDoc);
// Same as text-only document
// Document with large metadata
Map<String, Object> largeMeta = new HashMap<>();
for (int i = 0; i < 100; i++) {
largeMeta.put("key" + i, "value" + i);
}
Document largeMetaDoc = new Document("Text", largeMeta);
float[] embeddingLargeMeta = modelWithMeta.embed(largeMetaDoc);
// Metadata + text may be truncated if exceeds token limit
// Null document - throws exception
try {
model.embed((Document) null);
} catch (IllegalArgumentException e) {
// Exception thrown
}
Embed multiple text strings in a single call. This is more efficient than calling embed(String) repeatedly, as it batches the tokenization and inference operations.
/**
* Embeds a batch of texts into vectors.
* More efficient than calling embed(String) in a loop.
*
* @param texts list of texts to embed (must not be null, can be empty)
* @return list of float arrays, each representing an embedding vector
* @throws IllegalArgumentException if texts is null
*/
List<float[]> embed(List<String> texts);
Usage:
import java.util.List;
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.afterPropertiesSet();
List<String> texts = List.of(
"First sentence",
"Second sentence",
"Third sentence"
);
List<float[]> embeddings = model.embed(texts);
// embeddings.size() == 3
// embeddings.get(0).length == 384
// embeddings.get(1).length == 384
// embeddings.get(2).length == 384
// Order is preserved
assert Arrays.equals(embeddings.get(0), model.embed("First sentence"));
Performance Comparison:
import java.util.ArrayList;
import java.util.List;
List<String> texts = new ArrayList<>();
for (int i = 0; i < 100; i++) {
texts.add("Text number " + i);
}
// Inefficient: Loop with individual calls
long start = System.currentTimeMillis();
List<float[]> embeddingsLoop = new ArrayList<>();
for (String text : texts) {
embeddingsLoop.add(model.embed(text)); // 100 separate inference calls
}
long loopTime = System.currentTimeMillis() - start;
// Efficient: Single batch call
start = System.currentTimeMillis();
List<float[]> embeddingsBatch = model.embed(texts); // 1 batched inference call
long batchTime = System.currentTimeMillis() - start;
// Batch is typically 30-50% faster
System.out.println("Loop time: " + loopTime + "ms");
System.out.println("Batch time: " + batchTime + "ms");
System.out.println("Speedup: " + (loopTime / (double) batchTime) + "x");Edge Cases:
// Empty list - returns empty list
List<float[]> emptyResult = model.embed(List.of());
assert emptyResult.isEmpty();
// Single item - works correctly
List<float[]> singleResult = model.embed(List.of("Single text"));
assert singleResult.size() == 1;
// List with null elements - behavior depends on implementation
List<String> textsWithNull = new ArrayList<>();
textsWithNull.add("Valid text");
textsWithNull.add(null); // May throw exception or skip
textsWithNull.add("Another valid text");
// Safer: filter nulls before calling
List<String> filtered = textsWithNull.stream()
.filter(t -> t != null)
.collect(Collectors.toList());
List<float[]> safeEmbeddings = model.embed(filtered);
// Mixed empty and non-empty texts
List<String> mixedTexts = List.of("", "Non-empty", "", "Another");
List<float[]> mixedEmbeddings = model.embed(mixedTexts);
// All produce valid embeddings, order preserved
// Very large batch
List<String> largeBatch = new ArrayList<>();
for (int i = 0; i < 10000; i++) {
largeBatch.add("Text " + i);
}
// May need to split into smaller batches to avoid OOM
int batchSize = 1000;
List<float[]> allEmbeddings = new ArrayList<>();
for (int i = 0; i < largeBatch.size(); i += batchSize) {
int end = Math.min(i + batchSize, largeBatch.size());
List<String> subBatch = largeBatch.subList(i, end);
allEmbeddings.addAll(model.embed(subBatch));
}
Batch Size Optimization:
// Find optimal batch size for your environment
public int findOptimalBatchSize(TransformersEmbeddingModel model) {
List<String> testTexts = new ArrayList<>();
for (int i = 0; i < 1000; i++) {
testTexts.add("Test text " + i);
}
int[] batchSizes = {10, 50, 100, 200, 500, 1000};
int optimalSize = 10;
long bestTime = Long.MAX_VALUE;
for (int batchSize : batchSizes) {
long start = System.currentTimeMillis();
for (int i = 0; i < testTexts.size(); i += batchSize) {
int end = Math.min(i + batchSize, testTexts.size());
model.embed(testTexts.subList(i, end));
}
long elapsed = System.currentTimeMillis() - start;
if (elapsed < bestTime) {
bestTime = elapsed;
optimalSize = batchSize;
}
}
return optimalSize;
}
Embed multiple texts and receive a structured EmbeddingResponse containing results with indices and metadata. This provides additional context about the embedding operation.
/**
* Embeds a batch of texts into vectors and returns the EmbeddingResponse.
* Response includes embeddings with their indices and operation metadata.
*
* @param texts list of texts to embed (must not be null, can be empty)
* @return EmbeddingResponse containing embedding results with indices
* @throws IllegalArgumentException if texts is null
*/
EmbeddingResponse embedForResponse(List<String> texts);
Usage:
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
import org.springframework.ai.embedding.Embedding;
import java.util.Arrays;
import java.util.List;
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.afterPropertiesSet();
EmbeddingResponse response = model.embedForResponse(
List.of("Text one", "Text two", "Text three")
);
// Access individual embeddings with indices
List<Embedding> results = response.getResults();
for (Embedding emb : results) {
int index = emb.getIndex(); // 0, 1, 2, ...
float[] vector = emb.getOutput(); // 384-dimensional vector
System.out.println("Embedding " + index + ": " +
Arrays.toString(Arrays.copyOf(vector, 5)) + "...");
}
// Check metadata (typically empty for TransformersEmbeddingModel)
EmbeddingResponseMetadata metadata = response.getMetadata();
if (metadata != null && !metadata.isEmpty()) {
String modelName = metadata.getModel(); // May be null; avoid shadowing the 'model' variable above
// Other metadata fields if available
}
Accessing Results:
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.Embedding;
EmbeddingResponse response = model.embedForResponse(
List.of("First", "Second", "Third")
);
// Method 1: Iterate through results
for (Embedding embedding : response.getResults()) {
int idx = embedding.getIndex();
float[] vector = embedding.getOutput();
// Process embedding
}
// Method 2: Access by index
Embedding firstEmbedding = response.getResults().get(0);
assert firstEmbedding.getIndex() == 0;
// Method 3: Convert to simple float array list
List<float[]> simpleEmbeddings = response.getResults().stream()
.map(Embedding::getOutput)
.collect(Collectors.toList());
Comparison with embed(List):
List<String> texts = List.of("Text 1", "Text 2", "Text 3");
// Simple method: Just get embeddings
List<float[]> embeddings = model.embed(texts);
// Returns: List of float arrays
// Response method: Get embeddings with metadata
EmbeddingResponse response = model.embedForResponse(texts);
List<Embedding> embeddingsWithMeta = response.getResults();
// Returns: EmbeddingResponse with Embedding objects (includes indices)
// Embeddings are equivalent
for (int i = 0; i < embeddings.size(); i++) {
assert Arrays.equals(
embeddings.get(i),
embeddingsWithMeta.get(i).getOutput()
);
}
Edge Cases:
// Empty list
EmbeddingResponse emptyResponse = model.embedForResponse(List.of());
assert emptyResponse.getResults().isEmpty();
// Single text
EmbeddingResponse singleResponse = model.embedForResponse(List.of("Single"));
assert singleResponse.getResults().size() == 1;
assert singleResponse.getResults().get(0).getIndex() == 0;
// Null list - throws exception
try {
model.embedForResponse(null);
} catch (IllegalArgumentException e) {
// Exception thrown
}
Execute a complete embedding request with custom options. This is the most flexible method, accepting an EmbeddingRequest object that can include options and configuration.
/**
* Executes an embedding request and returns the response.
* Implements the Model interface's call method.
*
* @param request the embedding request containing texts and options
* @return EmbeddingResponse containing embedding results
* @throws IllegalArgumentException if request is null
*/
EmbeddingResponse call(EmbeddingRequest request);
Usage:
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingOptions;
import org.springframework.ai.embedding.EmbeddingResponse;
import java.util.List;
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.afterPropertiesSet();
// Create request with options
EmbeddingRequest request = new EmbeddingRequest(
List.of("Text to embed"),
EmbeddingOptions.builder().build()
);
EmbeddingResponse response = model.call(request);
// Access results
List<float[]> embeddings = response.getResults().stream()
.map(Embedding::getOutput)
.collect(Collectors.toList());
With Options (Future Extension Point):
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingOptions;
// Note: TransformersEmbeddingModel currently does not use custom options,
// but the API accepts them for interface compatibility and future extensions
EmbeddingOptions options = EmbeddingOptions.builder()
.model("custom-model-name") // Not used by TransformersEmbeddingModel
.dimensions(384) // Not used by TransformersEmbeddingModel
.build();
EmbeddingRequest request = new EmbeddingRequest(
List.of("Text 1", "Text 2"),
options
);
EmbeddingResponse response = model.call(request);
// Options are ignored but request is processed
// Useful for polymorphic code that works with multiple EmbeddingModel implementations
Relationship with Other Methods:
// These are equivalent:
List<String> texts = List.of("Text 1", "Text 2");
// Method 1: Using embedForResponse
EmbeddingResponse response1 = model.embedForResponse(texts);
// Method 2: Using call
EmbeddingRequest request = new EmbeddingRequest(texts, null);
EmbeddingResponse response2 = model.call(request);
// Results are identical
assert response1.getResults().size() == response2.getResults().size();
Edge Cases:
// Request with null texts - behavior depends on EmbeddingRequest validation
EmbeddingRequest nullTextsRequest = new EmbeddingRequest(null, null);
// May throw exception from EmbeddingRequest constructor
// Request with empty texts
EmbeddingRequest emptyRequest = new EmbeddingRequest(List.of(), null);
EmbeddingResponse emptyResponse = model.call(emptyRequest);
assert emptyResponse.getResults().isEmpty();
// Request with null options (valid)
EmbeddingRequest noOptionsRequest = new EmbeddingRequest(
List.of("Text"),
null // null options are fine
);
EmbeddingResponse response = model.call(noOptionsRequest);
// Works correctly
// Null request - throws exception
try {
model.call(null);
} catch (IllegalArgumentException e) {
// Exception thrown
}
Retrieve the dimensionality of the embedding vectors produced by the model. This value is cached after the first call.
/**
* Get the number of dimensions of the embedded vectors.
* Cached after first call to avoid repeated inference.
*
* @return the number of dimensions (384 for default all-MiniLM-L6-v2 model)
*/
int dimensions();
Usage:
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.afterPropertiesSet();
int dims = model.dimensions(); // 384 for default model
// Value is cached for subsequent calls
// Use dimensions for validation
float[] embedding = model.embed("Test text");
assert embedding.length == dims;
// Useful for database schema setup
System.out.println("Create table embeddings with vector of size " + dims);Caching Behavior:
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.afterPropertiesSet();
// First call: Performs inference with test text to determine dimensions
long start = System.currentTimeMillis();
int dims1 = model.dimensions();
long firstCallTime = System.currentTimeMillis() - start; // ~5-20ms
// Second call: Returns cached value
start = System.currentTimeMillis();
int dims2 = model.dimensions();
long secondCallTime = System.currentTimeMillis() - start; // <1ms
assert dims1 == dims2;
assert secondCallTime < firstCallTime;
Different Models:
// Default model: all-MiniLM-L6-v2 (384 dimensions)
TransformersEmbeddingModel defaultModel = new TransformersEmbeddingModel();
defaultModel.afterPropertiesSet();
assert defaultModel.dimensions() == 384;
// Custom model: may have different dimensions
TransformersEmbeddingModel customModel = new TransformersEmbeddingModel();
customModel.setModelResource("https://example.com/models/large-model.onnx");
customModel.setTokenizerResource("https://example.com/tokenizers/large-tokenizer.json");
customModel.afterPropertiesSet();
int customDims = customModel.dimensions(); // e.g., 768 or 1024
The embedding operations are instrumented with Micrometer observations for monitoring and metrics:
import io.micrometer.observation.ObservationRegistry;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention;
// Create model with observation registry
ObservationRegistry registry = ObservationRegistry.create();
TransformersEmbeddingModel model = new TransformersEmbeddingModel(
MetadataMode.NONE,
registry
);
model.afterPropertiesSet();
// All embedding calls are now observed
float[] embedding = model.embed("Monitored text");
// Observation recorded with timing, success/failure, etc.
Observation Details:
// Provider identifier for observations
public static final String PROVIDER_ONNX = "ONNX";
// Operation name for observations
public static final String OPERATION_EMBEDDING = "EMBEDDING_MODEL_OPERATION";
Custom Observation Convention:
import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention;
import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext;
import io.micrometer.common.KeyValues;
import io.micrometer.observation.Observation;
public class CustomEmbeddingObservationConvention
implements EmbeddingModelObservationConvention {
@Override
public String getName() {
return "custom.embedding.operation";
}
@Override
public KeyValues getLowCardinalityKeyValues(EmbeddingModelObservationContext context) {
return KeyValues.of(
"model.provider", "onnx",
"model.name", "all-MiniLM-L6-v2",
"operation.type", "embedding"
);
}
@Override
public KeyValues getHighCardinalityKeyValues(EmbeddingModelObservationContext context) {
return KeyValues.of(
"text.count", String.valueOf(context.getRequest().getInstructions().size()),
"embedding.dimensions", String.valueOf(context.getModel().dimensions())
);
}
}
// Use custom convention
TransformersEmbeddingModel model = new TransformersEmbeddingModel(
MetadataMode.NONE,
ObservationRegistry.create()
);
model.setObservationConvention(new CustomEmbeddingObservationConvention());
model.afterPropertiesSet();
Metrics Collected:
// Typical metrics available from observations:
// - embedding.operation.duration: Time to generate embeddings
// - embedding.operation.count: Number of embedding operations
// - embedding.operation.error.count: Number of failed operations
// - embedding.text.count: Number of texts embedded per operation
package org.springframework.ai.embedding;
public class EmbeddingRequest implements ModelRequest<List<String>> {
/**
* Create an embedding request.
*
* @param inputs List of text strings to embed
* @param options Optional configuration (can be null)
*/
public EmbeddingRequest(List<String> inputs, EmbeddingOptions options);
// From ModelRequest interface
public List<String> getInstructions();
public EmbeddingOptions getOptions();
}
Description: Request object for embedding operations. Contains a list of input texts and optional configuration. Implements the ModelRequest interface from Spring AI.
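A small illustration of the accessors listed above (the input strings are arbitrary):
EmbeddingRequest request = new EmbeddingRequest(
List.of("First text", "Second text"),
EmbeddingOptions.builder().build()
);
List<String> inputs = request.getInstructions(); // ["First text", "Second text"]
EmbeddingOptions options = request.getOptions(); // the options passed to the constructor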
package org.springframework.ai.embedding;
public class EmbeddingResponse implements ModelResponse<Embedding> {
/**
* Create an embedding response with results only.
*/
public EmbeddingResponse(List<Embedding> results);
/**
* Create an embedding response with results and metadata.
*/
public EmbeddingResponse(
List<Embedding> results,
EmbeddingResponseMetadata metadata
);
/**
* Get the embedding results.
*/
public List<Embedding> getResults();
/**
* Get response metadata (may be null or empty).
*/
public EmbeddingResponseMetadata getMetadata();
}
package org.springframework.ai.embedding;
public class Embedding {
/**
* Create an embedding with vector and index.
*
* @param output The embedding vector
* @param index The index in the batch (0-based)
*/
public Embedding(float[] output, int index);
/**
* Get the embedding vector.
*/
public float[] getOutput();
/**
* Get the index of this embedding in the batch.
*/
public int getIndex();
}
package org.springframework.ai.embedding;
public class EmbeddingResponseMetadata extends AbstractResponseMetadata {
public EmbeddingResponseMetadata();
public EmbeddingResponseMetadata(String model, Usage usage);
public EmbeddingResponseMetadata(
String model,
Usage usage,
Map<String, Object> metadata
);
public String getModel();
public void setModel(String model);
public Usage getUsage();
public void setUsage(Usage usage);
// Inherited from AbstractResponseMetadata
public <T> T get(String key);
public <T> T getRequired(Object key);
public boolean containsKey(Object key);
public <T> T getOrDefault(Object key, T defaultObject);
public Set<Map.Entry<String, Object>> entrySet();
public Set<String> keySet();
public boolean isEmpty();
}
Description: Metadata container for embedding responses. Provides information about the model used and token usage. Extends AbstractResponseMetadata which provides a Map<String, Object> for custom metadata access.
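A sketch of the map-style access inherited from AbstractResponseMetadata, using a response obtained from embedForResponse as shown earlier (the custom key is hypothetical):
EmbeddingResponseMetadata metadata = response.getMetadata();
if (metadata != null && !metadata.isEmpty()) {
String modelName = metadata.getModel(); // may be null for local ONNX models
Object custom = metadata.getOrDefault("my-custom-key", "n/a"); // hypothetical key
for (Map.Entry<String, Object> entry : metadata.entrySet()) {
System.out.println(entry.getKey() + " = " + entry.getValue());
}
}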
package org.springframework.ai.document;
public class Document {
/**
* Create a document with text only.
*/
public Document(String text);
/**
* Create a document with text and metadata.
*/
public Document(String text, Map<String, Object> metadata);
/**
* Get the document text.
*/
public String getText();
/**
* Get the document metadata.
*/
public Map<String, Object> getMetadata();
/**
* Get formatted content based on metadata mode.
* Used internally for embedding with metadata.
*
* @param metadataMode How to format metadata
* @return Formatted string with text and metadata
*/
public String getFormattedContent(MetadataMode metadataMode);
}
package org.springframework.ai.document;
public enum MetadataMode {
/**
* Text content only, no metadata.
*/
NONE,
/**
* Include metadata marked for embedding.
*/
EMBED,
/**
* Include all metadata fields.
*/
ALL
}
package org.springframework.ai.embedding;
public interface EmbeddingOptions extends ModelOptions {
/**
* Get the model name (may be null).
*/
String getModel();
/**
* Get desired dimensions (may be null).
*/
Integer getDimensions();
/**
* Create a builder for EmbeddingOptions.
*/
static Builder builder();
interface Builder {
Builder model(String model);
Builder dimensions(Integer dimensions);
EmbeddingOptions build();
}
}
Description: Configuration options for embedding requests. Allows specification of model name and desired embedding dimensions. Use the builder pattern to create instances.
package org.springframework.ai.chat.metadata;
public interface Usage {
/**
* Number of tokens in the prompt.
*/
Integer getPromptTokens();
/**
* Number of tokens in the completion/generation.
*/
Integer getCompletionTokens();
/**
* Total tokens (prompt + completion).
*/
default Integer getTotalTokens() {
Integer prompt = getPromptTokens();
Integer completion = getCompletionTokens();
return (prompt != null && completion != null)
? prompt + completion
: null;
}
/**
* Native usage object from the underlying API.
*/
Object getNativeUsage();
}
Description: Encapsulates token usage information from API requests. Tracks prompt tokens, completion tokens, and provides total count.
package org.springframework.ai.embedding;
public interface EmbeddingModel extends Model<EmbeddingRequest, EmbeddingResponse> {
// Main embedding operation
EmbeddingResponse call(EmbeddingRequest request);
// Embed single text string
default float[] embed(String text) {
EmbeddingResponse response = this.embedForResponse(List.of(text));
return response.getResults().get(0).getOutput();
}
// Embed document content
float[] embed(Document document);
// Embed batch of texts
default List<float[]> embed(List<String> texts) {
return this.embedForResponse(texts).getResults().stream()
.map(Embedding::getOutput)
.collect(Collectors.toList());
}
// Embed batch of documents with batching strategy (default implementation omitted here)
default List<float[]> embed(
List<Document> documents,
EmbeddingOptions options,
BatchingStrategy batchingStrategy
);
// Embed batch and return response object
default EmbeddingResponse embedForResponse(List<String> texts) {
EmbeddingRequest request = new EmbeddingRequest(texts, null);
return this.call(request);
}
// Get embedding vector dimensions
default int dimensions() {
return this.embed("Test").length;
}
}
Description: Core interface for embedding models in Spring AI. Defines all embedding operations. TransformersEmbeddingModel implements this interface.
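Because TransformersEmbeddingModel implements this interface, application code can be written against EmbeddingModel and later switched to a different provider; a minimal sketch:
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.transformers.TransformersEmbeddingModel;
TransformersEmbeddingModel onnxModel = new TransformersEmbeddingModel();
onnxModel.afterPropertiesSet();
// Downstream code depends only on the EmbeddingModel interface
EmbeddingModel embeddingModel = onnxModel;
float[] vector = embeddingModel.embed("Provider-agnostic embedding call");
int dims = embeddingModel.dimensions();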
package org.springframework.ai.embedding;
public interface BatchingStrategy {
/**
* Split documents into optimized sub-batches for embedding.
* Preserves document order for proper mapping to embeddings.
*
* @param documents Documents to batch
* @return List of sub-batches containing documents
*/
List<List<Document>> batch(List<Document> documents);
}
Description: Strategy for batching document embedding operations. Implementations can optimize based on token limits or other constraints.
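A minimal fixed-size implementation is sketched below purely to illustrate the contract; it ignores token limits (the framework's default strategy batches by token count), and the class name is hypothetical:
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.BatchingStrategy;
import java.util.ArrayList;
import java.util.List;
// Sketch: split documents into fixed-size batches, preserving order
public class FixedSizeBatchingStrategy implements BatchingStrategy {
private final int batchSize;
public FixedSizeBatchingStrategy(int batchSize) {
this.batchSize = batchSize;
}
@Override
public List<List<Document>> batch(List<Document> documents) {
List<List<Document>> batches = new ArrayList<>();
for (int i = 0; i < documents.size(); i += batchSize) {
int end = Math.min(i + batchSize, documents.size());
batches.add(new ArrayList<>(documents.subList(i, end)));
}
return batches;
}
}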
// Use batch processing for multiple texts
List<String> texts = List.of(/* many texts */);
// ❌ SLOW: Sequential processing
List<float[]> embeddings1 = new ArrayList<>();
for (String text : texts) {
embeddings1.add(model.embed(text));
}
// ✅ FAST: Batch processing
List<float[]> embeddings2 = model.embed(texts);
// Batch is 30-50% faster due to:
// - Single tokenization pass
// - Batched ONNX inference
// - Reduced overhead
// Enable GPU for large-scale embedding generation
TransformersEmbeddingModel gpuModel = new TransformersEmbeddingModel();
gpuModel.setGpuDeviceId(0); // Use first GPU
gpuModel.afterPropertiesSet();
// GPU typically 2-5x faster than CPU
// Especially beneficial for:
// - Large batches (100+ texts)
// - Real-time applications
// - High throughput requirements
// Model and tokenizer are cached automatically
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.setResourceCacheDirectory("/persistent/cache");
model.setDisableCaching(false); // Default
model.afterPropertiesSet();
// First run: Downloads and caches (~50MB for default model)
// Subsequent runs: Uses cached files (much faster startup)
// dimensions() result is cached
int dims1 = model.dimensions(); // First call: performs inference
int dims2 = model.dimensions(); // Subsequent calls: returns cached value
// Model instance can be reused across multiple calls after initialization
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.afterPropertiesSet();
// Thread-safe: Can be used from multiple threads
ExecutorService executor = Executors.newFixedThreadPool(10);
for (int i = 0; i < 100; i++) {
final String text = "Text " + i;
executor.submit(() -> {
float[] embedding = model.embed(text); // Thread-safe
});
}
executor.shutdown();
Embedding operations may throw:
// Null parameters
try {
model.embed((String) null);
} catch (IllegalArgumentException e) {
// "text must not be null" or similar
}
try {
model.embed((List<String>) null);
} catch (IllegalArgumentException e) {
// "texts must not be null" or similar
}
try {
model.embed((Document) null);
} catch (IllegalArgumentException e) {
// "document must not be null" or similar
}
// ONNX Runtime exceptions during inference (rare)
try {
float[] embedding = model.embed("text");
} catch (RuntimeException e) {
// Possible causes:
// - Corrupted model file
// - ONNX Runtime internal error
// - GPU memory exhaustion
// - Invalid model output format
}
// Initialization errors
try {
model.afterPropertiesSet();
} catch (Exception e) {
// Possible causes:
// - Model download failure
// - Invalid model format
// - CUDA/GPU initialization failure
// - Cache directory permission denied
}
import java.util.List;
import java.util.ArrayList;
import java.util.Comparator;
public class SemanticSearch {
private final TransformersEmbeddingModel model;
private final List<String> documents;
private final List<float[]> documentEmbeddings;
public SemanticSearch(List<String> documents) throws Exception {
this.model = new TransformersEmbeddingModel();
this.model.afterPropertiesSet();
this.documents = documents;
// Pre-compute document embeddings
this.documentEmbeddings = model.embed(documents);
}
public List<SearchResult> search(String query, int topK) {
// Embed query
float[] queryEmbedding = model.embed(query);
// Compute similarities
List<SearchResult> results = new ArrayList<>();
for (int i = 0; i < documents.size(); i++) {
float similarity = cosineSimilarity(
queryEmbedding,
documentEmbeddings.get(i)
);
results.add(new SearchResult(documents.get(i), similarity));
}
// Sort by similarity and return top K
results.sort(Comparator.comparing(SearchResult::similarity).reversed());
return results.subList(0, Math.min(topK, results.size()));
}
private float cosineSimilarity(float[] a, float[] b) {
float dotProduct = 0f;
float normA = 0f;
float normB = 0f;
for (int i = 0; i < a.length; i++) {
dotProduct += a[i] * b[i];
normA += a[i] * a[i];
normB += b[i] * b[i];
}
return dotProduct / (float)(Math.sqrt(normA) * Math.sqrt(normB));
}
record SearchResult(String document, float similarity) {}
}
import java.util.List;
import java.util.ArrayList;
import java.util.Random;
public class TextClusterer {
private final TransformersEmbeddingModel model;
public TextClusterer() throws Exception {
this.model = new TransformersEmbeddingModel();
this.model.afterPropertiesSet();
}
public List<List<String>> cluster(List<String> texts, int numClusters) {
// Generate embeddings
List<float[]> embeddings = model.embed(texts);
// K-means clustering
List<float[]> centroids = initializeCentroids(embeddings, numClusters);
List<Integer> assignments = new ArrayList<>();
// Run a fixed number of iterations (no explicit convergence check)
for (int iter = 0; iter < 100; iter++) {
// Assign to nearest centroid
assignments.clear();
for (float[] embedding : embeddings) {
int nearest = findNearestCentroid(embedding, centroids);
assignments.add(nearest);
}
// Update centroids
updateCentroids(embeddings, assignments, centroids);
}
// Group texts by cluster
List<List<String>> clusters = new ArrayList<>();
for (int i = 0; i < numClusters; i++) {
clusters.add(new ArrayList<>());
}
for (int i = 0; i < texts.size(); i++) {
clusters.get(assignments.get(i)).add(texts.get(i));
}
return clusters;
}
// Helper methods omitted for brevity
private List<float[]> initializeCentroids(List<float[]> embeddings, int k) { /* ... */ }
private int findNearestCentroid(float[] embedding, List<float[]> centroids) { /* ... */ }
private void updateCentroids(List<float[]> embeddings, List<Integer> assignments, List<float[]> centroids) { /* ... */ }
}
public class TextDeduplicator {
private final TransformersEmbeddingModel model;
private final float similarityThreshold;
public TextDeduplicator(float similarityThreshold) throws Exception {
this.model = new TransformersEmbeddingModel();
this.model.afterPropertiesSet();
this.similarityThreshold = similarityThreshold;
}
public List<String> deduplicate(List<String> texts) {
// Generate embeddings
List<float[]> embeddings = model.embed(texts);
// Find duplicates
List<String> unique = new ArrayList<>();
List<float[]> uniqueEmbeddings = new ArrayList<>();
for (int i = 0; i < texts.size(); i++) {
String text = texts.get(i);
float[] embedding = embeddings.get(i);
// Check similarity with existing unique texts
boolean isDuplicate = false;
for (float[] uniqueEmbedding : uniqueEmbeddings) {
float similarity = cosineSimilarity(embedding, uniqueEmbedding);
if (similarity >= similarityThreshold) {
isDuplicate = true;
break;
}
}
if (!isDuplicate) {
unique.add(text);
uniqueEmbeddings.add(embedding);
}
}
return unique;
}
private float cosineSimilarity(float[] a, float[] b) {
// Implementation as shown in SemanticSearch example
}
}