# Spring AI Transformers

ONNX-based Transformer models for text embeddings within the Spring AI framework. Provides local, CPU- and GPU-accelerated embedding generation using ONNX Runtime, with zero external API dependencies for inference.
Maven coordinates: org.springframework.ai:spring-ai-transformers

<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-transformers</artifactId>
<version>1.1.2</version>
</dependency>

Gradle:
implementation 'org.springframework.ai:spring-ai-transformers:1.1.2'

Example usage in a Spring service:
@Service
public class MyService {
private final TransformersEmbeddingModel embeddingModel;
public MyService(TransformersEmbeddingModel embeddingModel) {
this.embeddingModel = embeddingModel;
}
public float[] embed(String text) {
return embeddingModel.embed(text);
}
}

Programmatic (non-Spring) usage:
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.afterPropertiesSet(); // Required before use
float[] embedding = model.embed("Hello world");

EmbeddingModel interface methods:
// Single text
float[] embed(String text);
// Document with metadata
float[] embed(Document document);
// Batch processing
List<float[]> embed(List<String> texts);
// With response metadata
EmbeddingResponse embedForResponse(List<String> texts);
// Full request
EmbeddingResponse call(EmbeddingRequest request);
// Get dimensions
int dimensions();

Configuration setters:
// Model and tokenizer
void setModelResource(String modelResourceUri);
void setTokenizerResource(String tokenizerResourceUri);
void setTokenizerOptions(Map<String, String> tokenizerOptions);
// Hardware
void setGpuDeviceId(int gpuDeviceId);
// Caching
void setDisableCaching(boolean disableCaching);
void setResourceCacheDirectory(String resourceCacheDir);
// Model output
void setModelOutputName(String modelOutputName);
// Initialization (required)
void afterPropertiesSet() throws Exception;

Application properties:
# Model configuration
spring.ai.embedding.transformer.onnx.model-uri=classpath:/models/model.onnx
spring.ai.embedding.transformer.onnx.gpu-device-id=0
spring.ai.embedding.transformer.onnx.model-output-name=last_hidden_state
# Tokenizer configuration
spring.ai.embedding.transformer.tokenizer.uri=classpath:/tokenizers/tokenizer.json
spring.ai.embedding.transformer.tokenizer.options.modelMaxLength=512
# Cache configuration
spring.ai.embedding.transformer.cache.enabled=true
spring.ai.embedding.transformer.cache.directory=/var/cache/spring-ai
# Metadata mode
spring.ai.embedding.transformer.metadata-mode=NONE

Imports:
import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.ai.transformers.ResourceCacheService;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;

Defaults:
| Setting | Default Value |
|---|---|
| Model | all-MiniLM-L6-v2 |
| Dimensions | 384 |
| Hardware | CPU (GPU via setGpuDeviceId(0)) |
| Cache Location | {java.io.tmpdir}/spring-ai-onnx-generative |
| Metadata Mode | NONE |
| Caching | Enabled |
Bean configuration example:
@Configuration
public class EmbeddingConfig {
@Bean
public TransformersEmbeddingModel embeddingModel() throws Exception {
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.afterPropertiesSet();
return model;
}
}

GPU usage with CPU fallback:
TransformersEmbeddingModel model = new TransformersEmbeddingModel();
model.setGpuDeviceId(0);
try {
model.afterPropertiesSet();
} catch (Exception e) {
if (e.getMessage().contains("GPU")) {
model.setGpuDeviceId(-1); // Fallback to CPU
model.afterPropertiesSet();
}
}

Batch embedding:
List<String> texts = List.of("text1", "text2", "text3");
List<float[]> embeddings = model.embed(texts); // More efficient than a per-item loop

Key public types:
// Main model class
public class TransformersEmbeddingModel extends AbstractEmbeddingModel implements InitializingBean
// Resource caching
public class ResourceCacheService
// Metadata handling
public enum MetadataMode { NONE, EMBED, ALL }
// Response types
public class EmbeddingResponse
public class Embedding
public class EmbeddingRequest

Error handling:
// Initialization errors
try {
model.afterPropertiesSet();
} catch (Exception e) {
// Handle: model loading, GPU, cache, network errors
}
// Runtime errors
try {
float[] embedding = model.embed(text);
} catch (IllegalArgumentException e) {
// Null parameters
} catch (RuntimeException e) {
// ONNX inference errors
}

Troubleshooting — enable debug logging:
# Enable debug logging
debug=true
logging.level.org.springframework.boot.autoconfigure=DEBUG

GPU problems — fall back to CPU:
# Fallback to CPU
spring.ai.embedding.transformer.onnx.gpu-device-id=-1

Offline/air-gapped environments — use local resources:
# Use local resources
spring.ai.embedding.transformer.onnx.model-uri=classpath:/models/model.onnx
spring.ai.embedding.transformer.cache.directory=/persistent/cache

Install via tessl (version aligned with the 1.1.2 dependency declared above):
tessl i tessl/maven-org-springframework-ai--spring-ai-transformers@1.1.2