ONNX-based Transformer models for text embeddings within the Spring AI framework
Complete guide to integrating Spring AI Transformers with Spring Boot.
Spring Boot automatically configures TransformersEmbeddingModel when:
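- the spring.ai.embedding.model property is transformers or is not set
- no other TransformersEmbeddingModel bean has been defined (the auto-configured bean backs off when the application declares its own)
<dependency>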
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-transformers</artifactId>
<version>1.1.2</version>
</dependency>
@Service
/**
 * Minimal service facade that turns one piece of text into an embedding vector
 * by delegating to the injected {@link TransformersEmbeddingModel}.
 */
public class EmbeddingService {

    private final TransformersEmbeddingModel model;

    /**
     * @param model the embedding model to delegate to
     */
    public EmbeddingService(TransformersEmbeddingModel model) {
        this.model = model;
    }

    /**
     * Embeds a single text.
     *
     * @param text the input text
     * @return the vector produced by the model
     */
    public float[] embed(String text) {
        float[] vector = this.model.embed(text);
        return vector;
    }
}
# Enable auto-configuration
spring.ai.embedding.model=transformers
# Configure model
spring.ai.embedding.transformer.onnx.model-uri=classpath:/models/model.onnx
spring.ai.embedding.transformer.onnx.gpu-device-id=0
# Configure cache
spring.ai.embedding.transformer.cache.directory=/var/cache/spring-ai
# application-dev.properties
spring.ai.embedding.transformer.onnx.gpu-device-id=-1
spring.ai.embedding.transformer.cache.directory=./cache
# application-prod.properties
spring.ai.embedding.transformer.onnx.gpu-device-id=0
spring.ai.embedding.transformer.cache.directory=/var/cache/spring-ai
@Configuration
/**
 * Declares a hand-configured {@link TransformersEmbeddingModel} instead of relying on
 * auto-configuration.
 */
public class CustomEmbeddingConfig {

    /**
     * Builds an embedding model that targets GPU device 0 and uses
     * {@code /custom/cache} as its resource cache directory.
     *
     * <p>The model is deliberately returned without calling {@code afterPropertiesSet()}.
     * TransformersEmbeddingModel is an {@code InitializingBean}, so the Spring container
     * invokes {@code afterPropertiesSet()} itself right after this factory method returns.
     * Calling it here as well would initialize the model twice, loading the ONNX model
     * and tokenizer a second time and discarding the first session.
     *
     * @return the configured model; the container completes its initialization
     */
    @Bean
    public TransformersEmbeddingModel embeddingModel() {
        TransformersEmbeddingModel model = new TransformersEmbeddingModel();
        model.setGpuDeviceId(0);
        model.setResourceCacheDirectory("/custom/cache");
        return model;
    }
}
@Configuration
/**
 * Registers two embedding models side by side: one on CPU and one on GPU device 0.
 * Each is selected at injection points via {@code @Qualifier("cpuModel")} or
 * {@code @Qualifier("gpuModel")}.
 *
 * <p>NOTE: with two beans of the same type registered, injection points that ask for a
 * {@code TransformersEmbeddingModel} without a qualifier become ambiguous. Mark one bean
 * {@code @Primary} if such injection points exist.
 */
public class MultiModelConfig {

    /**
     * @return a model pinned to the CPU (GPU device id -1)
     */
    @Bean
    @Qualifier("cpuModel")
    public TransformersEmbeddingModel cpuModel() {
        return newModel(-1);
    }

    /**
     * @return a model running on GPU device 0
     */
    @Bean
    @Qualifier("gpuModel")
    public TransformersEmbeddingModel gpuModel() {
        return newModel(0);
    }

    /**
     * Creates a model for the given device without initializing it.
     *
     * <p>TransformersEmbeddingModel is an {@code InitializingBean}. The container calls
     * {@code afterPropertiesSet()} on each returned bean, so calling it here too would
     * load the ONNX model twice per bean.
     */
    private static TransformersEmbeddingModel newModel(int gpuDeviceId) {
        TransformersEmbeddingModel model = new TransformersEmbeddingModel();
        model.setGpuDeviceId(gpuDeviceId);
        return model;
    }
}
/**
 * Holds both the CPU-backed and the GPU-backed embedding model. The two beans of the
 * same type are told apart by their qualifier names.
 */
@Service
public class DualModelService {

    /** The bean registered under the "cpuModel" qualifier. */
    @Autowired @Qualifier("cpuModel") private TransformersEmbeddingModel cpuModel;

    /** The bean registered under the "gpuModel" qualifier. */
    @Autowired @Qualifier("gpuModel") private TransformersEmbeddingModel gpuModel;
}
@RestController
@RequestMapping("/api/embeddings")
/**
 * REST endpoints that expose the embedding model for single texts and for batches.
 */
public class EmbeddingController {

    private final TransformersEmbeddingModel model;

    public EmbeddingController(TransformersEmbeddingModel model) {
        this.model = model;
    }

    /**
     * POST /api/embeddings/single: the raw request body is embedded as one text.
     */
    @PostMapping("/single")
    public float[] embedSingle(@RequestBody String text) {
        float[] result = this.model.embed(text);
        return result;
    }

    /**
     * POST /api/embeddings/batch: the request body is bound to a list of strings,
     * and the whole list is passed to the model in a single call.
     */
    @PostMapping("/batch")
    public List<float[]> embedBatch(@RequestBody List<String> texts) {
        List<float[]> results = this.model.embed(texts);
        return results;
    }
}
@RestController
@RequestMapping("/api/embeddings")
/**
 * Embedding endpoint with input validation and error mapping.
 */
public class RobustEmbeddingController {

    private final TransformersEmbeddingModel model;

    public RobustEmbeddingController(TransformersEmbeddingModel model) {
        this.model = model;
    }

    /**
     * POST /api/embeddings/embed.
     *
     * @param request the request body carrying the text to embed
     * @return 400 with {"error": "Text cannot be empty"} when the body or its text is
     *     missing or blank; 200 with {"embedding": ..., "dimensions": n} on success;
     *     500 with {"error": ...} when the model throws
     */
    @PostMapping("/embed")
    public ResponseEntity<?> embed(@RequestBody EmbedRequest request) {
        // EmbedRequest is a record, so its accessor is text(), not getText().
        String text = (request == null) ? null : request.text();
        // Whitespace-only input carries no content to embed, so reject it like empty input.
        if (text == null || text.isBlank()) {
            return ResponseEntity.badRequest()
                    .body(Map.of("error", "Text cannot be empty"));
        }
        try {
            float[] embedding = model.embed(text);
            return ResponseEntity.ok(Map.of(
                    "embedding", embedding,
                    "dimensions", embedding.length));
        } catch (Exception e) {
            // Map.of rejects null values, and many exceptions have no message. Fall back to
            // the exception type so the error path itself cannot throw.
            // NOTE(review): this echoes internal error details to the client; consider
            // logging them server-side and returning a generic message instead.
            String message = (e.getMessage() != null) ? e.getMessage() : e.getClass().getName();
            return ResponseEntity.status(500)
                    .body(Map.of("error", message));
        }
    }

    /** Request body: the text to embed. */
    record EmbedRequest(String text) {}
}
@Service
/**
 * Thin service layer over the embedding model: single texts, batches, and the
 * model's vector size.
 */
public class EmbeddingService {

    private final TransformersEmbeddingModel model;

    public EmbeddingService(TransformersEmbeddingModel model) {
        this.model = model;
    }

    /**
     * @param text the text to embed
     * @return the model's embedding of {@code text}
     */
    public float[] generateEmbedding(String text) {
        final float[] vector = this.model.embed(text);
        return vector;
    }

    /**
     * Passes the whole list to the model in a single call.
     *
     * @param texts the texts to embed
     * @return the embeddings returned by the model
     */
    public List<float[]> generateBatchEmbeddings(List<String> texts) {
        final List<float[]> vectors = this.model.embed(texts);
        return vectors;
    }

    /**
     * @return the embedding dimensionality reported by the model
     */
    public int getDimensions() {
        final int size = this.model.dimensions();
        return size;
    }
}
@Service
/**
 * Indexes documents together with their embeddings and searches them by query embedding.
 */
public class DocumentService {

    private final TransformersEmbeddingModel model;
    private final DocumentRepository repository;

    public DocumentService(TransformersEmbeddingModel model, DocumentRepository repository) {
        this.model = model;
        this.repository = repository;
    }

    /**
     * Embeds {@code content} and saves it, with its id and embedding, as a DocumentEntity.
     */
    public void indexDocument(String id, String content) {
        float[] vector = model.embed(content);
        repository.save(new DocumentEntity(id, content, vector));
    }

    /**
     * Embeds {@code query} and delegates to {@code repository.findSimilar} with that
     * vector and {@code topK}. Presumably this returns the topK most similar documents;
     * the ranking is defined by the repository.
     */
    public List<DocumentEntity> search(String query, int topK) {
        float[] probe = model.embed(query);
        return repository.findSimilar(probe, topK);
    }
}
@Service
/**
 * Asynchronous variants of the embedding calls. Through Spring's {@code @Async} proxy,
 * each method body runs on the async executor. The body computes the result and wraps
 * it in an already-completed future.
 */
public class AsyncEmbeddingService {

    private final TransformersEmbeddingModel model;

    public AsyncEmbeddingService(TransformersEmbeddingModel model) {
        this.model = model;
    }

    /**
     * @param text the text to embed
     * @return a future holding the embedding
     */
    @Async
    public CompletableFuture<float[]> embedAsync(String text) {
        float[] vector = model.embed(text);
        return CompletableFuture.completedFuture(vector);
    }

    /**
     * @param texts the texts to embed in one model call
     * @return a future holding the embeddings
     */
    @Async
    public CompletableFuture<List<float[]>> embedBatchAsync(List<String> texts) {
        List<float[]> vectors = model.embed(texts);
        return CompletableFuture.completedFuture(vectors);
    }
}
@Configuration
@EnableAsync
/**
 * Enables {@code @Async} processing and supplies the executor it runs on.
 */
public class AsyncConfig {

    /**
     * Thread pool for {@code @Async} methods. It has 4 core threads, at most 8 threads,
     * a queue of 100 tasks, and thread names prefixed with "embedding-".
     *
     * <p>{@code initialize()} is intentionally not called here. ThreadPoolTaskExecutor is an
     * {@code InitializingBean}, so the container calls {@code afterPropertiesSet()} (which
     * calls {@code initialize()}) once this method returns. Calling it here as well builds
     * the underlying ThreadPoolExecutor twice and silently discards the first one.
     *
     * <p>NOTE: with the default abort policy, tasks submitted while all 8 threads are busy
     * and 100 tasks are already queued are rejected.
     *
     * @return the executor; the container completes its initialization
     */
    @Bean
    public Executor taskExecutor() {
        ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor();
        executor.setCorePoolSize(4);
        executor.setMaxPoolSize(8);
        executor.setQueueCapacity(100);
        executor.setThreadNamePrefix("embedding-");
        return executor;
    }
}
@Configuration
/**
 * Declares an embedding model wired to the application's ObservationRegistry, so
 * embedding calls can be observed.
 */
public class ObservationConfig {

    /**
     * @param observationRegistry the registry passed to the model
     * @return a model using {@code MetadataMode.NONE}; the container initializes it.
     *     TransformersEmbeddingModel is an {@code InitializingBean}, so calling
     *     {@code afterPropertiesSet()} here as well would load the ONNX model twice.
     */
    @Bean
    public TransformersEmbeddingModel embeddingModel(ObservationRegistry observationRegistry) {
        return new TransformersEmbeddingModel(MetadataMode.NONE, observationRegistry);
    }
}
@Service
/**
 * Embedding service that publishes Micrometer metrics for every call.
 */
public class MetricsEmbeddingService {

    private final TransformersEmbeddingModel model;
    /** Counts embeddings that completed successfully ("embeddings.generated"). */
    private final Counter embeddingCounter;
    /** Times each embed(String) call ("embeddings.duration"). */
    private final Timer embeddingTimer;

    /**
     * @param model the model to delegate to
     * @param meterRegistry registry the two meters are registered with. It is only needed
     *     during construction, so it is not kept as a field.
     */
    public MetricsEmbeddingService(
            TransformersEmbeddingModel model,
            MeterRegistry meterRegistry) {
        this.model = model;
        this.embeddingCounter = Counter.builder("embeddings.generated")
                .description("Number of embeddings generated")
                .register(meterRegistry);
        this.embeddingTimer = Timer.builder("embeddings.duration")
                .description("Time to generate embeddings")
                .register(meterRegistry);
    }

    /**
     * Embeds {@code text} while timing the call. The counter is incremented only after
     * the model returns, so failed calls are timed but not counted.
     */
    public float[] embed(String text) {
        return embeddingTimer.record(() -> {
            float[] embedding = model.embed(text);
            embeddingCounter.increment();
            return embedding;
        });
    }
}
@Component
/**
 * Actuator health check that verifies the embedding model can embed a probe string.
 *
 * <p>NOTE: this runs a real inference on every health probe. Consider the cost if the
 * health endpoint is polled frequently.
 */
public class EmbeddingHealthIndicator implements HealthIndicator {

    private final TransformersEmbeddingModel model;

    public EmbeddingHealthIndicator(TransformersEmbeddingModel model) {
        this.model = model;
    }

    /**
     * @return UP with the probe's vector size when embedding succeeds; DOWN when the model
     *     returns an empty vector or throws
     */
    @Override
    public Health health() {
        try {
            float[] probe = model.embed("health check");
            if (probe == null || probe.length == 0) {
                return Health.down()
                        .withDetail("error", "model returned an empty embedding")
                        .build();
            }
            // Report the size of the vector actually produced, not a separately computed value.
            return Health.up()
                    .withDetail("dimensions", probe.length)
                    .withDetail("status", "operational")
                    .build();
        } catch (Exception e) {
            // Health.Builder#withDetail rejects null values, and e.getMessage() may be null.
            // Health.down(Exception) stores "<exception class>: <message>" under "error" safely.
            return Health.down(e).build();
        }
    }
}
@Service
@CacheConfig(cacheNames = "embeddings")
/**
 * Embedding service whose results are memoized in the "embeddings" cache, keyed by the
 * input text.
 *
 * <p>NOTE(review): the cached value is a mutable float[]. Unless the cache manager copies
 * values, callers receive a shared array and must not modify it.
 */
public class CachedEmbeddingService {

    private final TransformersEmbeddingModel model;

    public CachedEmbeddingService(TransformersEmbeddingModel model) {
        this.model = model;
    }

    /**
     * Returns the cached embedding for {@code text}, computing it on a cache miss.
     * The "#text" key expression relies on parameter names being available at runtime
     * (the -parameters compiler flag).
     */
    @Cacheable(key = "#text")
    public float[] embed(String text) {
        float[] computed = this.model.embed(text);
        return computed;
    }

    /**
     * Drops every cached embedding. The body is intentionally empty; the
     * {@code @CacheEvict} advice performs the eviction.
     */
    @CacheEvict(allEntries = true)
    public void clearCache() {
    }
}
@Configuration
@EnableCaching
/**
 * Enables Spring's caching abstraction with an in-memory cache named "embeddings".
 */
public class CacheConfig {

    /**
     * In-memory cache manager for the "embeddings" cache.
     *
     * <p>Store-by-value is enabled because the cached values are mutable float[] arrays.
     * In the default store-by-reference mode, every cache hit hands out the same array
     * instance, so a caller that modifies its result silently corrupts the embedding later
     * callers receive. With store-by-value, entries are serialized on put and each get
     * returns a fresh copy.
     *
     * <p>NOTE: ConcurrentMapCacheManager caches are unbounded and never evict on their own.
     * Use a bounded cache if the set of distinct texts can grow without limit.
     */
    @Bean
    public CacheManager cacheManager() {
        ConcurrentMapCacheManager cacheManager = new ConcurrentMapCacheManager("embeddings");
        cacheManager.setStoreByValue(true);
        return cacheManager;
    }
}
@SpringBootTest
/**
 * Boots the application context and checks that the injected embedding model produces
 * vectors of the expected size.
 */
public class EmbeddingIntegrationTest {

    @Autowired
    private TransformersEmbeddingModel model;

    /**
     * 384 is the expected output size of the configured model. Presumably this is the
     * default all-MiniLM-L6-v2; update it if a different model is configured.
     */
    @Test
    public void testEmbedding() {
        float[] vector = model.embed("test");
        assertNotNull(vector);
        assertEquals(384, vector.length);
    }
}
@TestConfiguration
/**
 * Test configuration that overrides the embedding model with a CPU-only instance.
 */
public class TestEmbeddingConfig {

    /**
     * Builds the {@code @Primary} model used in tests. It runs on the CPU (GPU device id -1)
     * and uses {@code <java.io.tmpdir>/test-cache} as its resource cache directory.
     *
     * <p>The model is returned without calling {@code afterPropertiesSet()}.
     * TransformersEmbeddingModel is an {@code InitializingBean}, so the container initializes
     * it; calling the method here too would load the ONNX model twice.
     *
     * @return the configured test model
     */
    @Bean
    @Primary
    public TransformersEmbeddingModel testEmbeddingModel() {
        TransformersEmbeddingModel model = new TransformersEmbeddingModel();
        model.setGpuDeviceId(-1); // CPU for tests
        // Path.of joins with the platform separator and copes with a tmpdir value that
        // may already end in a separator.
        model.setResourceCacheDirectory(
                java.nio.file.Path.of(System.getProperty("java.io.tmpdir"), "test-cache").toString()
        );
        return model;
    }
}
spring.autoconfigure.exclude=org.springframework.ai.model.transformers.autoconfigure.TransformersEmbeddingModelAutoConfiguration
/**
 * Alternative to the property above: excludes the Transformers embedding
 * auto-configuration directly on the application class.
 */
@SpringBootApplication(exclude = {
    TransformersEmbeddingModelAutoConfiguration.class
})
public class Application {
    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }
}
tessl i tessl/maven-org-springframework-ai--spring-ai-transformers@1.1.1