ONNX-based Transformer models for text embeddings within the Spring AI framework
This guide walks you through getting started with Spring AI Transformers.
Add the dependency to your pom.xml:
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-transformers</artifactId>
<version>1.1.2</version>
</dependency>

Add to your build.gradle:
implementation 'org.springframework.ai:spring-ai-transformers:1.1.2'

For GPU acceleration, add ONNX Runtime GPU:
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime-gpu</artifactId>
<version>1.16.0</version>
</dependency>

Create application.properties:
# Optional: Use GPU
spring.ai.embedding.transformer.onnx.gpu-device-id=0
# Optional: Custom cache location
spring.ai.embedding.transformer.cache.directory=/var/cache/spring-ai

import org.springframework.ai.transformers.TransformersEmbeddingModel;
import org.springframework.stereotype.Service;
/**
 * Service that forwards embedding requests to the injected
 * {@link TransformersEmbeddingModel}.
 */
@Service
public class EmbeddingService {

    private final TransformersEmbeddingModel embeddingModel;

    /**
     * Receives the embedding model through constructor injection.
     *
     * @param embeddingModel model used for every embedding request
     */
    public EmbeddingService(TransformersEmbeddingModel embeddingModel) {
        this.embeddingModel = embeddingModel;
    }

    /**
     * Embeds a single piece of text.
     *
     * @param text input passed to the model
     * @return the vector the model returns for {@code text}
     */
    public float[] generateEmbedding(String text) {
        float[] vector = embeddingModel.embed(text);
        return vector;
    }

    /**
     * Embeds several texts with one call to the model.
     *
     * @param texts inputs passed to the model
     * @return the list of vectors the model returns for {@code texts}
     */
    public List<float[]> generateBatchEmbeddings(List<String> texts) {
        List<float[]> vectors = embeddingModel.embed(texts);
        return vectors;
    }
}

@RestController
public class EmbeddingController {

    // Service that performs the actual embedding work.
    private final EmbeddingService embeddingService;

    /**
     * Receives the embedding service through constructor injection.
     *
     * @param embeddingService service the endpoint delegates to
     */
    public EmbeddingController(EmbeddingService embeddingService) {
        this.embeddingService = embeddingService;
    }

    /**
     * Handles {@code POST /embed}: embeds the request body.
     *
     * @param text request body to embed
     * @return the vector produced by the embedding service
     */
    @PostMapping("/embed")
    public float[] embed(@RequestBody String text) {
        float[] vector = embeddingService.generateEmbedding(text);
        return vector;
    }
}

import org.springframework.ai.transformers.TransformersEmbeddingModel;
public class ManualSetup {

    /**
     * Builds a model with default settings, initializes it, embeds one
     * sentence and prints the length of the resulting vector.
     *
     * @param args unused
     * @throws Exception propagated from model initialization
     */
    public static void main(String[] args) throws Exception {
        // Default-configured model
        TransformersEmbeddingModel model = new TransformersEmbeddingModel();

        // CRITICAL: the model must be initialized before the first embed call
        model.afterPropertiesSet();

        // Embed a sample sentence and report the vector length
        float[] embedding = model.embed("Hello, world!");
        int dimensionCount = embedding.length;
        System.out.println("Embedding dimensions: " + dimensionCount);
    }
}

import org.springframework.ai.transformers.TransformersEmbeddingModel;
import java.util.Map;
public class CustomSetup {

    /**
     * Creates a model with a GPU device id, a resource cache directory and
     * tokenizer options set, then initializes it.
     *
     * @return the initialized model
     * @throws Exception propagated from model initialization
     */
    public static TransformersEmbeddingModel createModel() throws Exception {
        // Tokenizer settings are passed as string key/value pairs
        Map<String, String> tokenizerOptions = Map.of(
                "modelMaxLength", "512",
                "truncation", "true");

        TransformersEmbeddingModel model = new TransformersEmbeddingModel();
        model.setGpuDeviceId(0);                               // GPU
        model.setResourceCacheDirectory("/var/cache/models");  // cache
        model.setTokenizerOptions(tokenizerOptions);           // tokenizer

        // Initialize after all setters have been applied
        model.afterPropertiesSet();
        return model;
    }
}

String text = "The quick brown fox jumps over the lazy dog";
float[] embedding = model.embed(text);
System.out.println("Generated " + embedding.length + "-dimensional embedding");

List<String> texts = List.of(
"First sentence",
"Second sentence",
"Third sentence"
);
List<float[]> embeddings = model.embed(texts);
System.out.println("Generated " + embeddings.size() + " embeddings");

import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import java.util.Map;
// Create model with metadata support
TransformersEmbeddingModel model = new TransformersEmbeddingModel(MetadataMode.EMBED);
model.afterPropertiesSet();
// Create document
Document doc = new Document(
"Document content",
Map.of("source", "file.txt", "author", "John")
);
// Embed with metadata
float[] embedding = model.embed(doc);

int dimensions = model.dimensions();
System.out.println("Model produces " + dimensions + "-dimensional embeddings");
// Default model: 384 dimensions

On first run, the model will download the required model files and cache them locally.
This may take 10-30 seconds depending on network speed.
Subsequent runs will use cached files and start in <1 second.
Test your setup:
@SpringBootTest
public class EmbeddingModelTest {
@Autowired
private TransformersEmbeddingModel model;
@Test
public void testEmbedding() {
float[] embedding = model.embed("test");
assertNotNull(embedding);
assertEquals(384, embedding.length); // Default model dimensions
// Verify non-zero values
boolean hasNonZero = false;
for (float v : embedding) {
if (v != 0.0f) {
hasNonZero = true;
break;
}
}
assertTrue(hasNonZero);
}
}Problem: Network timeout or connection error
Solution:
# Use local model
spring.ai.embedding.transformer.onnx.model-uri=classpath:/models/model.onnx

Problem: CUDA initialization fails
Solution:
# Use CPU
spring.ai.embedding.transformer.onnx.gpu-device-id=-1

Problem: Large batch causes OOM
Solution:
// Process in smaller batches
int batchSize = 100;
for (int i = 0; i < texts.size(); i += batchSize) {
List<String> batch = texts.subList(i, Math.min(i + batchSize, texts.size()));
List<float[]> embeddings = model.embed(batch);
// Process embeddings
}

tessl i tessl/maven-org-springframework-ai--spring-ai-transformers@1.1.1