This module provides common classes used across Spring AI, covering document processing, text transformation, embedding utilities, observability support, and tokenization capabilities for AI application development.
Embedding optimization provides strategies for batching documents to maximize efficiency when calling embedding APIs.
The embedding optimization layer consists of the BatchingStrategy interface and its TokenCountBatchingStrategy implementation. These components help optimize embedding API calls by grouping documents into batches that fit within model token limits while maximizing throughput.
Interface for batching Document objects to optimize embedding calls.
package org.springframework.ai.embedding;
import org.springframework.ai.document.Document;
import java.util.List;
/**
 * Strategy for splitting a list of documents into batches sized to fit an
 * embedding model's token limit.
 */
interface BatchingStrategy {
/**
 * Batch documents for optimal embedding API calls.
 * @param documents documents to batch
 * @return list of document batches
 */
List<List<Document>> batch(List<Document> documents);
}import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.document.Document;
import java.util.List;
// Create batching strategy
// Defaults: CL100K_BASE encoding, 8191-token limit, 10% reserve.
BatchingStrategy batchingStrategy = new TokenCountBatchingStrategy();
// Documents to embed
List<Document> documents = List.of(
new Document("First document content"),
new Document("Second document content"),
new Document("Third document content")
// ... potentially hundreds or thousands of documents
);
// Batch documents
List<List<Document>> batches = batchingStrategy.batch(documents);
System.out.println("Total documents: " + documents.size());
System.out.println("Number of batches: " + batches.size());
// Process each batch
// Each batch fits within the strategy's effective token budget.
for (int i = 0; i < batches.size(); i++) {
List<Document> batch = batches.get(i);
System.out.println("Batch " + (i + 1) + " has " + batch.size() + " documents");
// Send batch to embedding API
// embedDocuments(batch);
}Token count based batching strategy with configurable limits and reserve buffer.
package org.springframework.ai.embedding;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.ContentFormatter;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import com.knuddels.jtokkit.api.EncodingType;
import java.util.List;
/**
 * Token-count based BatchingStrategy (API signature listing; bodies omitted).
 * Effective per-batch budget = maxInputTokenCount * (1 - reservePercentage).
 */
class TokenCountBatchingStrategy implements BatchingStrategy {
/**
 * Create with defaults.
 * - Encoding: CL100K_BASE (GPT-3.5/GPT-4)
 * - Max tokens: 8191
 * - Reserve: 10%
 */
TokenCountBatchingStrategy();
/**
 * Create with custom configuration.
 * @param encodingType JTokkit encoding type
 * @param maxInputTokenCount maximum tokens per batch
 * @param reservePercentage percentage to reserve (0.0 to 1.0)
 */
TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double reservePercentage);
/**
 * Create with formatter and metadata mode.
 * @param encodingType JTokkit encoding type
 * @param maxInputTokenCount maximum tokens per batch
 * @param reservePercentage percentage to reserve (0.0 to 1.0)
 * @param contentFormatter formatter for document content
 * @param metadataMode metadata mode for formatting
 */
TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double reservePercentage,
ContentFormatter contentFormatter, MetadataMode metadataMode);
/**
 * Create with custom token estimator.
 * @param tokenCountEstimator estimator for counting tokens
 * @param maxInputTokenCount maximum tokens per batch
 * @param reservePercentage percentage to reserve (0.0 to 1.0)
 * @param contentFormatter formatter for document content
 * @param metadataMode metadata mode for formatting
 */
TokenCountBatchingStrategy(TokenCountEstimator tokenCountEstimator, int maxInputTokenCount,
double reservePercentage, ContentFormatter contentFormatter,
MetadataMode metadataMode);
/**
 * Batch documents based on token count.
 * @param documents documents to batch
 * @return list of document batches
 */
List<List<Document>> batch(List<Document> documents);
}import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.document.Document;
import com.knuddels.jtokkit.api.EncodingType;
import java.util.List;
// Constructor configuration examples; effective budget = maxTokens * (1 - reserve).
// Default configuration (8191 tokens, 10% reserve)
TokenCountBatchingStrategy defaultStrategy = new TokenCountBatchingStrategy();
List<Document> documents = List.of(
new Document("Document 1 with some content"),
new Document("Document 2 with more content"),
new Document("Document 3 with even more content")
);
List<List<Document>> batches = defaultStrategy.batch(documents);
// Custom token limit
// Example: OpenAI text-embedding-3-small has 8191 token limit
TokenCountBatchingStrategy openAIStrategy = new TokenCountBatchingStrategy(
EncodingType.CL100K_BASE, // OpenAI encoding
8191, // Max tokens
0.1 // 10% reserve (uses ~7372 tokens per batch)
);
List<List<Document>> openAIBatches = openAIStrategy.batch(documents);
// Lower token limit for smaller models
// Example: Some models have 512 token limits
TokenCountBatchingStrategy smallModelStrategy = new TokenCountBatchingStrategy(
EncodingType.CL100K_BASE,
512, // Smaller limit
0.05 // 5% reserve
);
List<List<Document>> smallBatches = smallModelStrategy.batch(documents);
// Conservative batching with higher reserve
// A larger reserve trades throughput for a lower risk of exceeding the limit.
TokenCountBatchingStrategy conservativeStrategy = new TokenCountBatchingStrategy(
EncodingType.CL100K_BASE,
8191,
0.2 // 20% reserve for safety
);
List<List<Document>> conservativeBatches = conservativeStrategy.batch(documents);import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DefaultContentFormatter;
import org.springframework.ai.document.MetadataMode;
import com.knuddels.jtokkit.api.EncodingType;
import java.util.List;
// Create documents with metadata
// Excluding metadata keys from embedding changes the per-document token count.
List<Document> docs = List.of(
Document.builder()
.text("Content 1")
.metadata("category", "tech")
.metadata("internal_id", "123")
.build(),
Document.builder()
.text("Content 2")
.metadata("category", "science")
.metadata("internal_id", "456")
.build()
);
// Configure formatter to exclude internal metadata
DefaultContentFormatter formatter = DefaultContentFormatter.builder()
.withExcludedEmbedMetadataKeys("internal_id")
.build();
// Create batching strategy with formatter
TokenCountBatchingStrategy strategy = new TokenCountBatchingStrategy(
EncodingType.CL100K_BASE,
8191,
0.1,
formatter,
MetadataMode.EMBED // Use EMBED mode for embedding
);
List<List<Document>> batches = strategy.batch(docs);
// Token count is calculated on formatted content (without internal_id)import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DefaultContentFormatter;
import org.springframework.ai.document.MetadataMode;
import com.knuddels.jtokkit.api.EncodingType;
import java.util.List;
// Create custom token estimator
// Supplying the estimator directly lets one instance be reused across strategies.
TokenCountEstimator estimator = new JTokkitTokenCountEstimator(
EncodingType.CL100K_BASE
);
// Create batching strategy with custom estimator
TokenCountBatchingStrategy strategy = new TokenCountBatchingStrategy(
estimator,
8191,
0.1,
DefaultContentFormatter.defaultConfig(),
MetadataMode.EMBED
);
List<Document> documents = List.of(/* documents */);
List<List<Document>> batches = strategy.batch(documents);import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentReader;
import org.springframework.ai.reader.TextReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.document.DefaultContentFormatter;
import org.springframework.ai.document.MetadataMode;
import org.springframework.core.io.ClassPathResource;
import com.knuddels.jtokkit.api.EncodingType;
import java.util.List;
/**
 * Complete RAG embedding pipeline with batching.
 */
class RagEmbeddingPipeline {
/**
 * Reads a knowledge base, splits it into token-sized chunks, and batches
 * the chunks for embedding. Steps are numbered inline; the actual
 * embedding/storage calls are shown commented out.
 */
public void processDocuments() {
// 1. Read documents
DocumentReader reader = new TextReader(new ClassPathResource("knowledge-base.txt"));
List<Document> documents = reader.get();
// 2. Split into chunks
TokenTextSplitter splitter = TokenTextSplitter.builder()
.withChunkSize(500) // 500 token chunks
.build();
List<Document> chunks = splitter.apply(documents);
System.out.println("Created " + chunks.size() + " chunks");
// 3. Configure formatter for embedding
DefaultContentFormatter formatter = DefaultContentFormatter.builder()
.withExcludedEmbedMetadataKeys("internal_id", "timestamp")
.build();
// 4. Create batching strategy
TokenCountBatchingStrategy batchingStrategy = new TokenCountBatchingStrategy(
EncodingType.CL100K_BASE,
8191, // OpenAI embedding limit
0.1, // 10% reserve
formatter,
MetadataMode.EMBED
);
// 5. Batch chunks
List<List<Document>> batches = batchingStrategy.batch(chunks);
System.out.println("Created " + batches.size() + " batches");
// 6. Process each batch
for (int i = 0; i < batches.size(); i++) {
List<Document> batch = batches.get(i);
System.out.println("Batch " + (i + 1) + ": " + batch.size() + " documents");
// Call embedding API
// List<float[]> embeddings = embeddingClient.embed(batch);
// Store in vector database
// vectorStore.add(batch, embeddings);
}
}
}import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.document.Document;
import com.knuddels.jtokkit.api.EncodingType;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
/**
 * Runs embedding batches concurrently on a fixed worker pool to
 * reduce total embedding latency.
 */
class ParallelBatchProcessor {

    private final TokenCountBatchingStrategy strategy;

    private final ExecutorService workers;

    public ParallelBatchProcessor() {
        // Same defaults as the sequential examples: OpenAI encoding,
        // 8191-token limit with a 10% reserve.
        this.strategy = new TokenCountBatchingStrategy(EncodingType.CL100K_BASE, 8191, 0.1);
        // Four workers; tune to the embedding provider's rate limits.
        this.workers = Executors.newFixedThreadPool(4);
    }

    /**
     * Batches the documents and embeds every batch in parallel,
     * blocking until all batches have completed.
     * @param documents documents to embed
     */
    public void processAsync(List<Document> documents) {
        List<List<Document>> batches = strategy.batch(documents);
        List<CompletableFuture<Void>> pending = batches.stream()
                .map(this::submit)
                .toList();
        CompletableFuture.allOf(pending.toArray(new CompletableFuture[0])).join();
        System.out.println("All batches processed");
    }

    /** Schedules one batch on the worker pool. */
    private CompletableFuture<Void> submit(List<Document> batch) {
        return CompletableFuture.runAsync(() -> processBatch(batch), workers);
    }

    /** Embeds a single batch (placeholder for the real API call). */
    private void processBatch(List<Document> batch) {
        System.out.println("Processing batch of " + batch.size() + " documents");
        // embedDocuments(batch);
    }

    /** Stops accepting new work; already-queued batches finish normally. */
    public void shutdown() {
        workers.shutdown();
    }
}
// Usage
ParallelBatchProcessor processor = new ParallelBatchProcessor();
List<Document> documents = List.of(/* large document list */);
// Blocks until every batch has been processed.
processor.processAsync(documents);
// Release the worker pool once no more work is expected.
processor.shutdown();import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.document.Document;
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import com.knuddels.jtokkit.api.EncodingType;
import java.util.List;
/**
 * Dynamically adjust batch size based on document characteristics.
 * The reserve percentage is chosen from the average token count per
 * document, since larger documents carry more estimation risk per item.
 */
class DynamicBatchStrategy {

    private final TokenCountEstimator estimator;

    public DynamicBatchStrategy() {
        this.estimator = new JTokkitTokenCountEstimator(EncodingType.CL100K_BASE);
    }

    /**
     * Batches documents with a reserve adapted to their average size.
     * @param documents documents to batch; may be null or empty
     * @return list of batches (empty for null/empty input)
     */
    public List<List<Document>> batchAdaptively(List<Document> documents) {
        // FIX: the original divided by documents.size() unconditionally,
        // throwing ArithmeticException for an empty list (and NPE for null).
        if (documents == null || documents.isEmpty()) {
            return List.of();
        }
        // Estimate average document size.
        int totalTokens = 0;
        for (Document doc : documents) {
            totalTokens += estimator.estimate(doc.getText());
        }
        int avgTokensPerDoc = totalTokens / documents.size();
        // Adjust reserve based on average document size: bigger documents
        // get a bigger safety buffer against token-estimation error.
        double reserve;
        if (avgTokensPerDoc < 100) {
            reserve = 0.05; // Small docs - less reserve needed
        } else if (avgTokensPerDoc < 500) {
            reserve = 0.1; // Medium docs - standard reserve
        } else {
            reserve = 0.2; // Large docs - more reserve for safety
        }
        // Default OpenAI embedding token limit.
        int tokenLimit = 8191;
        TokenCountBatchingStrategy strategy = new TokenCountBatchingStrategy(
                EncodingType.CL100K_BASE,
                tokenLimit,
                reserve);
        return strategy.batch(documents);
    }
}
// Usage
DynamicBatchStrategy dynamicStrategy = new DynamicBatchStrategy();
List<Document> mixedSizeDocs = List.of(/* documents of varying sizes */);
// The reserve percentage is derived from the average document token count.
List<List<Document>> adaptiveBatches = dynamicStrategy.batchAdaptively(mixedSizeDocs);import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DefaultContentFormatter;
import org.springframework.ai.document.MetadataMode;
import com.knuddels.jtokkit.api.EncodingType;
import java.util.List;
/**
 * Model-specific batching strategies.
 * Static utility holder; not instantiable.
 */
final class ModelBatchingStrategies {

    // FIX: utility class was instantiable; private constructor prevents that.
    private ModelBatchingStrategies() {
    }

    /**
     * OpenAI text-embedding-3-small/large.
     * Limit: 8191 tokens.
     * @return strategy sized for OpenAI embedding models
     */
    public static TokenCountBatchingStrategy openAIStrategy() {
        return new TokenCountBatchingStrategy(
                EncodingType.CL100K_BASE,
                8191,
                0.1,
                DefaultContentFormatter.defaultConfig(),
                MetadataMode.EMBED);
    }

    /**
     * Cohere embed models.
     * Limit: 512 tokens.
     * @return strategy sized for Cohere embedding models
     */
    public static TokenCountBatchingStrategy cohereStrategy() {
        return new TokenCountBatchingStrategy(
                EncodingType.CL100K_BASE,
                512,
                0.05,
                DefaultContentFormatter.defaultConfig(),
                MetadataMode.EMBED);
    }

    /**
     * Custom model with specific limits.
     * @param tokenLimit the model's documented maximum input tokens
     * @return strategy with the given limit and a 10% reserve
     */
    public static TokenCountBatchingStrategy customModelStrategy(int tokenLimit) {
        return new TokenCountBatchingStrategy(
                EncodingType.CL100K_BASE,
                tokenLimit,
                0.1,
                DefaultContentFormatter.defaultConfig(),
                MetadataMode.EMBED);
    }
}
// Usage
List<Document> documents = List.of(/* documents */);
// OpenAI batching
List<List<Document>> openAIBatches = ModelBatchingStrategies.openAIStrategy()
.batch(documents);
// Cohere batching (smaller batches)
List<List<Document>> cohereBatches = ModelBatchingStrategies.cohereStrategy()
.batch(documents);
// Custom model
// Pass the target model's documented token limit.
List<List<Document>> customBatches = ModelBatchingStrategies.customModelStrategy(2048)
.batch(documents);import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.document.Document;
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import com.knuddels.jtokkit.api.EncodingType;
import java.util.List;
/**
 * Track batching statistics for monitoring.
 */
class BatchingMonitor {

    private final TokenCountBatchingStrategy strategy;

    private final TokenCountEstimator estimator;

    public BatchingMonitor() {
        // Defaults: CL100K_BASE encoding, 8191-token limit, 10% reserve.
        this.strategy = new TokenCountBatchingStrategy();
        this.estimator = new JTokkitTokenCountEstimator(EncodingType.CL100K_BASE);
    }

    /**
     * Batches the documents while collecting token and batch-size statistics.
     * @param documents documents to batch; may be empty
     * @return statistics for the run (all-zero counts for an empty input)
     */
    public BatchStatistics batchWithStats(List<Document> documents) {
        long startTime = System.currentTimeMillis();
        int totalDocs = documents.size();
        // FIX: the original computed totalTokens / totalDocs and
        // totalDocs / numBatches unconditionally, throwing
        // ArithmeticException for an empty document list (and it would
        // have reported minTokens as Integer.MAX_VALUE).
        if (totalDocs == 0) {
            return new BatchStatistics(0, 0, 0, 0, 0, 0, 0, 0, 0.0,
                    System.currentTimeMillis() - startTime);
        }
        // Calculate pre-batch statistics.
        int totalTokens = 0;
        int minTokens = Integer.MAX_VALUE;
        int maxTokens = 0;
        for (Document doc : documents) {
            int tokens = estimator.estimate(doc.getText());
            totalTokens += tokens;
            minTokens = Math.min(minTokens, tokens);
            maxTokens = Math.max(maxTokens, tokens);
        }
        // Batch documents.
        List<List<Document>> batches = strategy.batch(documents);
        long endTime = System.currentTimeMillis();
        // Calculate batch statistics. numBatches is assumed >= 1 for a
        // non-empty input — confirm against TokenCountBatchingStrategy.batch.
        int numBatches = batches.size();
        int minBatchSize = batches.stream().mapToInt(List::size).min().orElse(0);
        int maxBatchSize = batches.stream().mapToInt(List::size).max().orElse(0);
        double avgBatchSize = (double) totalDocs / numBatches;
        return new BatchStatistics(
                totalDocs,
                numBatches,
                totalTokens,
                totalTokens / totalDocs,
                minTokens,
                maxTokens,
                minBatchSize,
                maxBatchSize,
                avgBatchSize,
                endTime - startTime);
    }

    /** Immutable snapshot of one batching run. */
    record BatchStatistics(
            int totalDocuments,
            int numBatches,
            int totalTokens,
            int avgTokensPerDoc,
            int minTokensPerDoc,
            int maxTokensPerDoc,
            int minBatchSize,
            int maxBatchSize,
            double avgBatchSize,
            long batchingTimeMs) {

        @Override
        public String toString() {
            return String.format("""
                    Batch Statistics:
                    - Total documents: %d
                    - Number of batches: %d
                    - Total tokens: %d
                    - Avg tokens/doc: %d (min: %d, max: %d)
                    - Batch sizes: %.1f avg (min: %d, max: %d)
                    - Batching time: %d ms
                    """,
                    totalDocuments, numBatches, totalTokens,
                    avgTokensPerDoc, minTokensPerDoc, maxTokensPerDoc,
                    avgBatchSize, minBatchSize, maxBatchSize,
                    batchingTimeMs);
        }
    }
}
// Usage
BatchingMonitor monitor = new BatchingMonitor();
List<Document> documents = List.of(/* documents */);
BatchingMonitor.BatchStatistics stats = monitor.batchWithStats(documents);
// Prints the multi-line summary from BatchStatistics.toString().
System.out.println(stats);import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.document.Document;
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
import com.knuddels.jtokkit.api.EncodingType;
import java.util.List;
/**
 * Estimates the embedding API cost of a document set before committing
 * to the calls.
 */
class CostOptimizer {

    private final TokenCountBatchingStrategy strategy;

    private final JTokkitTokenCountEstimator estimator;

    // Example rate: OpenAI bills per 1000 input tokens.
    private static final double COST_PER_1K_TOKENS = 0.0001;

    public CostOptimizer(int tokenLimit) {
        // A minimal 5% reserve keeps batches as full as possible,
        // reducing per-request overhead.
        this.strategy = new TokenCountBatchingStrategy(EncodingType.CL100K_BASE, tokenLimit, 0.05);
        this.estimator = new JTokkitTokenCountEstimator(EncodingType.CL100K_BASE);
    }

    /**
     * Computes the token total, batch count, and projected USD cost.
     * @param documents documents that would be embedded
     * @return summary of the projected embedding spend
     */
    public EmbeddingCostEstimate estimateCost(List<Document> documents) {
        int tokenTotal = documents.stream()
                .mapToInt(doc -> estimator.estimate(doc.getText()))
                .sum();
        List<List<Document>> batches = strategy.batch(documents);
        double projectedCost = (tokenTotal / 1000.0) * COST_PER_1K_TOKENS;
        return new EmbeddingCostEstimate(
                documents.size(),
                batches.size(),
                tokenTotal,
                projectedCost);
    }

    /** Immutable summary of one cost-estimation run. */
    record EmbeddingCostEstimate(
            int documents,
            int batches,
            int tokens,
            double estimatedCostUSD) {

        @Override
        public String toString() {
            return String.format(
                    "Documents: %d | Batches: %d | Tokens: %,d | Est. Cost: $%.4f",
                    documents, batches, tokens, estimatedCostUSD);
        }
    }
}
// Usage
CostOptimizer optimizer = new CostOptimizer(8191);
List<Document> documents = List.of(/* large document set */);
CostOptimizer.EmbeddingCostEstimate estimate = optimizer.estimateCost(documents);
// Prints a one-line summary (see the sample output below).
System.out.println(estimate);
// Sample output: Documents: 10,000 | Batches: 50 | Tokens: 1,500,000 | Est. Cost: $0.1500

The reserve percentage creates a buffer below the maximum token limit to account for differences between estimated and actual token counts, and for any extra tokens introduced by content formatting.
// Example with 8191 token limit and 10% reserve:
// Effective limit: 8191 * 0.9 = 7372 tokens per batch
// The reserve absorbs token-estimation error so real batches stay under the hard limit.
TokenCountBatchingStrategy strategy = new TokenCountBatchingStrategy(
EncodingType.CL100K_BASE,
8191,
0.1 // 10% reserve
);
// The strategy will ensure batches stay under 7372 tokens

Reserve percentage recommendations: use a smaller reserve (around 5%) for small, uniform documents, the default 10% for medium documents, and a larger reserve (up to 20%) for large or highly variable documents.
Thread Safety:
TokenCountBatchingStrategy is thread-safe and can be reused across threads and calls.

Performance:
Common Exceptions:
IllegalArgumentException: if maxInputTokenCount <= 0, or if reservePercentage is less than 0 or greater than 1.
NullPointerException: if the documents list is null.
RuntimeException: token counting errors (encoding issues).

Edge Cases:
// Edge-case behavior of TokenCountBatchingStrategy.
// Empty document list
List<Document> empty = List.of();
List<List<Document>> batches = strategy.batch(empty); // Returns empty list
// Single large document exceeding limit
Document huge = new Document("...10000 tokens...");
// FIX: renamed from "batches" — the original redeclared the variable in the same scope.
List<List<Document>> hugeBatches = strategy.batch(List.of(huge));
// Returns single batch with single document (may exceed limit)
// Documents with no text
Document noText = Document.builder().media(media).build();
// Token count is 0, batches normally
// Reserve percentage edge cases
new TokenCountBatchingStrategy(EncodingType.CL100K_BASE, 1000, 0.0); // No reserve
new TokenCountBatchingStrategy(EncodingType.CL100K_BASE, 1000, 1.0); // 100% reserve (effective limit = 0)
new TokenCountBatchingStrategy(EncodingType.CL100K_BASE, 1000, 0.5); // 50% reserve (effective limit = 500)

Install with Tessl CLI:
npx tessl i tessl/maven-org-springframework-ai--spring-ai-commons