Common classes used across Spring AI, providing document processing, text transformation, embedding utilities, observability support, and tokenization capabilities for AI application development.
This document provides complete, production-ready examples for common use cases.
A complete retrieval-augmented generation (RAG) pipeline with document processing, chunking, and embedding.
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.document.DefaultContentFormatter;
import org.springframework.ai.reader.TextReader;
import org.springframework.ai.transformer.splitter.TokenTextSplitter;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.core.io.ClassPathResource;
import com.knuddels.jtokkit.api.EncodingType;
import java.util.List;
public class RAGPipeline {
private final TokenTextSplitter splitter;
private final TokenCountBatchingStrategy batchingStrategy;
private final DefaultContentFormatter formatter;
public RAGPipeline() {
// Configure text splitter for 512-token chunks
this.splitter = TokenTextSplitter.builder()
.withChunkSize(512)
.withMinChunkSizeChars(100)
.build();
// Configure formatter to exclude internal metadata from embeddings
this.formatter = DefaultContentFormatter.builder()
.withExcludedEmbedMetadataKeys("internal_id", "timestamp", "chunk_index")
.build();
// Configure batching for OpenAI embedding API
this.batchingStrategy = new TokenCountBatchingStrategy(
EncodingType.CL100K_BASE,
8191, // OpenAI text-embedding-3 limit
0.1 // 10% reserve
);
}
public void processKnowledgeBase(String resourcePath) {
try {
// 1. Read documents
TextReader reader = new TextReader(new ClassPathResource(resourcePath));
reader.getCustomMetadata().put("source", resourcePath);
reader.getCustomMetadata().put("import_date", System.currentTimeMillis());
List<Document> documents = reader.get();
System.out.println("Loaded " + documents.size() + " documents");
// 2. Split into chunks
List<Document> chunks = splitter.apply(documents);
System.out.println("Created " + chunks.size() + " chunks");
// 3. Attach the content formatter and chunk metadata
for (int i = 0; i < chunks.size(); i++) {
chunks.get(i).setContentFormatter(formatter);
chunks.get(i).getMetadata().put("chunk_index", i);
chunks.get(i).getMetadata().put("total_chunks", chunks.size());
}
// 4. Batch for embedding
List<List<Document>> batches = batchingStrategy.batch(chunks);
System.out.println("Organized into " + batches.size() + " batches");
// 5. Process each batch
for (int i = 0; i < batches.size(); i++) {
List<Document> batch = batches.get(i);
System.out.println("Processing batch " + (i + 1) + " with " + batch.size() + " documents");
for (Document doc : batch) {
// Get formatted content for embedding
String embedContent = doc.getFormattedContent(MetadataMode.EMBED);
// Send to embedding API
// float[] embedding = embeddingModel.embed(embedContent);
// Store in vector database
// vectorStore.add(doc.getId(), embedding, doc.getMetadata());
}
}
System.out.println("RAG pipeline complete");
} catch (Exception e) {
System.err.println("RAG pipeline failed: " + e.getMessage());
throw new RuntimeException("Failed to process knowledge base", e);
}
}
}
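
The embedding and vector-store calls above are left as comments because they rely on classes outside spring-ai-commons. Below is a minimal sketch, assuming Spring AI's EmbeddingModel and VectorStore interfaces from the companion modules are available and injected; the EmbeddingIndexer class and its field names are illustrative only.

import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
import java.util.List;

public class EmbeddingIndexer {
    private final EmbeddingModel embeddingModel;
    private final VectorStore vectorStore;

    public EmbeddingIndexer(EmbeddingModel embeddingModel, VectorStore vectorStore) {
        this.embeddingModel = embeddingModel;
        this.vectorStore = vectorStore;
    }

    public void index(List<Document> batch) {
        // Compute raw vectors from the EMBED-mode content when managing storage manually
        for (Document doc : batch) {
            float[] embedding = embeddingModel.embed(doc.getFormattedContent(MetadataMode.EMBED));
            System.out.println("Embedded " + doc.getId() + " into " + embedding.length + " dimensions");
        }
        // Alternatively, a Spring AI VectorStore embeds and persists documents in one call
        vectorStore.add(batch);
    }
}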
Load and process documents from multiple formats (JSON, text, etc.).

import org.springframework.ai.document.Document;
import org.springframework.ai.reader.JsonReader;
import org.springframework.ai.reader.TextReader;
import org.springframework.core.io.Resource;
import org.springframework.core.io.support.PathMatchingResourcePatternResolver;
import java.util.ArrayList;
import java.util.List;
public class MultiFormatDocumentLoader {
private final PathMatchingResourcePatternResolver resolver =
new PathMatchingResourcePatternResolver();
public List<Document> loadAllDocuments(String basePattern) {
List<Document> allDocuments = new ArrayList<>();
// Load JSON files
allDocuments.addAll(loadJsonDocuments(basePattern + "/**/*.json"));
// Load text files
allDocuments.addAll(loadTextDocuments(basePattern + "/**/*.txt"));
// Load markdown files
allDocuments.addAll(loadTextDocuments(basePattern + "/**/*.md"));
System.out.println("Loaded " + allDocuments.size() + " total documents");
return allDocuments;
}
private List<Document> loadJsonDocuments(String pattern) {
List<Document> documents = new ArrayList<>();
try {
Resource[] resources = resolver.getResources("classpath:" + pattern);
for (Resource resource : resources) {
JsonReader reader = new JsonReader(
resource,
"title", "content", "description"
);
List<Document> docs = reader.get();
// Add source metadata
for (Document doc : docs) {
doc.getMetadata().put("format", "json");
doc.getMetadata().put("source_file", resource.getFilename());
}
documents.addAll(docs);
}
System.out.println("Loaded " + documents.size() + " JSON documents");
} catch (Exception e) {
System.err.println("Failed to load JSON documents: " + e.getMessage());
}
return documents;
}
private List<Document> loadTextDocuments(String pattern) {
List<Document> documents = new ArrayList<>();
try {
Resource[] resources = resolver.getResources("classpath:" + pattern);
for (Resource resource : resources) {
TextReader reader = new TextReader(resource);
reader.getCustomMetadata().put("format", "text");
reader.getCustomMetadata().put("source_file", resource.getFilename());
documents.addAll(reader.get());
}
System.out.println("Loaded " + documents.size() + " text documents");
} catch (Exception e) {
System.err.println("Failed to load text documents: " + e.getMessage());
}
return documents;
}
}
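
A brief usage sketch for the loader above; the classpath base directory "docs" is a placeholder.

import org.springframework.ai.document.Document;
import java.util.List;

public class DocumentLoadingExample {
    public static void main(String[] args) {
        MultiFormatDocumentLoader loader = new MultiFormatDocumentLoader();
        List<Document> documents = loader.loadAllDocuments("docs");
        // Inspect a few loaded documents and their source metadata
        documents.stream()
            .limit(3)
            .forEach(doc -> System.out.println(doc.getMetadata().get("source_file") + " -> " + doc.getId()));
    }
}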
Optimize embedding costs by estimating and batching efficiently.

import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
import com.knuddels.jtokkit.api.EncodingType;
import java.util.List;
public class CostOptimizedEmbedding {
private final JTokkitTokenCountEstimator estimator;
private final TokenCountBatchingStrategy batchingStrategy;
// Example: OpenAI pricing (per 1M tokens)
private static final double EMBEDDING_COST_PER_1M_TOKENS = 0.13;
public CostOptimizedEmbedding() {
this.estimator = new JTokkitTokenCountEstimator(EncodingType.CL100K_BASE);
// Maximize batch size to minimize API calls
this.batchingStrategy = new TokenCountBatchingStrategy(
EncodingType.CL100K_BASE,
8191,
0.05 // Minimal reserve for cost optimization
);
}
public EmbeddingCostReport processDocuments(List<Document> documents) {
// Calculate total tokens
int totalTokens = 0;
for (Document doc : documents) {
totalTokens += estimator.estimate(doc.getText());
}
// Estimate cost
double estimatedCost = (totalTokens / 1_000_000.0) * EMBEDDING_COST_PER_1M_TOKENS;
// Batch documents
List<List<Document>> batches = batchingStrategy.batch(documents);
System.out.println("Cost Estimate:");
System.out.println(" Documents: " + documents.size());
System.out.println(" Total tokens: " + String.format("%,d", totalTokens));
System.out.println(" Batches: " + batches.size());
System.out.println(" Estimated cost: $" + String.format("%.4f", estimatedCost));
// Process batches
for (int i = 0; i < batches.size(); i++) {
List<Document> batch = batches.get(i);
System.out.println("Processing batch " + (i + 1) + "/" + batches.size() +
" (" + batch.size() + " documents)");
// Send to embedding API
// embedBatch(batch);
}
return new EmbeddingCostReport(
documents.size(),
batches.size(),
totalTokens,
estimatedCost
);
}
public record EmbeddingCostReport(
int documentCount,
int batchCount,
int totalTokens,
double estimatedCostUSD
) {}
}
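
An illustrative run of the estimator; the sample documents and the one-dollar review threshold are assumptions, not recommendations.

import org.springframework.ai.document.Document;
import java.util.List;

public class CostEstimateExample {
    public static void main(String[] args) {
        CostOptimizedEmbedding embedder = new CostOptimizedEmbedding();
        List<Document> documents = List.of(
            new Document("Spring AI provides portable abstractions for chat and embedding models."),
            new Document("TokenTextSplitter chunks documents along token boundaries.")
        );
        CostOptimizedEmbedding.EmbeddingCostReport report = embedder.processDocuments(documents);
        // Gate the real embedding run on the estimated spend
        if (report.estimatedCostUSD() > 1.00) {
            System.out.println("Estimated cost exceeds $1.00 - review before embedding");
        }
    }
}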
Automated evaluation system for AI responses.

import org.springframework.ai.evaluation.Evaluator;
import org.springframework.ai.evaluation.EvaluationRequest;
import org.springframework.ai.evaluation.EvaluationResponse;
import org.springframework.ai.document.Document;
import java.util.List;
import java.util.Map;
public class ResponseEvaluationSystem {
private final Evaluator evaluator;
public ResponseEvaluationSystem() {
this.evaluator = new RAGResponseEvaluator();
}
public EvaluationReport evaluateResponses(List<TestCase> testCases) {
int passed = 0;
int failed = 0;
double totalScore = 0.0;
for (TestCase testCase : testCases) {
EvaluationRequest request = new EvaluationRequest(
testCase.query(),
testCase.context(),
testCase.response()
);
EvaluationResponse result = evaluator.evaluate(request);
if (result.isPass()) {
passed++;
} else {
failed++;
}
totalScore += result.getScore();
System.out.println("Test: " + testCase.query());
System.out.println(" Result: " + (result.isPass() ? "PASS" : "FAIL"));
System.out.println(" Score: " + result.getScore());
System.out.println(" Feedback: " + result.getFeedback());
}
double averageScore = totalScore / testCases.size();
double passRate = (double) passed / testCases.size() * 100;
System.out.println("\nEvaluation Summary:");
System.out.println(" Total tests: " + testCases.size());
System.out.println(" Passed: " + passed);
System.out.println(" Failed: " + failed);
System.out.println(" Pass rate: " + String.format("%.1f%%", passRate));
System.out.println(" Average score: " + String.format("%.2f", averageScore));
return new EvaluationReport(testCases.size(), passed, failed, averageScore, passRate);
}
public record TestCase(String query, List<Document> context, String response) {}
public record EvaluationReport(
int totalTests,
int passed,
int failed,
double averageScore,
double passRate
) {}
// Simple RAG evaluator implementation
private static class RAGResponseEvaluator implements Evaluator {
@Override
public EvaluationResponse evaluate(EvaluationRequest request) {
String response = request.getResponseContent();
String query = request.getUserText();
// Simple relevance check
String[] queryWords = query.toLowerCase().split("\\s+");
String responseLower = response.toLowerCase();
long matchingWords = java.util.Arrays.stream(queryWords)
.filter(responseLower::contains)
.count();
float relevanceScore = (float) matchingWords / queryWords.length;
boolean pass = relevanceScore > 0.5f;
String feedback = pass
? "Response appears relevant to query"
: "Response may not adequately address query";
return new EvaluationResponse(
pass,
relevanceScore,
feedback,
Map.of("relevance_score", relevanceScore)
);
}
}
}
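
A usage sketch with a single hand-written test case; the query, context document, and response text are placeholders. Because evaluateResponses depends only on the Evaluator interface, the keyword-based evaluator above can be swapped for a model-backed implementation without changing this calling code.

import org.springframework.ai.document.Document;
import java.util.List;

public class EvaluationExample {
    public static void main(String[] args) {
        ResponseEvaluationSystem evaluation = new ResponseEvaluationSystem();
        List<ResponseEvaluationSystem.TestCase> testCases = List.of(
            new ResponseEvaluationSystem.TestCase(
                "What does TokenTextSplitter do",
                List.of(new Document("TokenTextSplitter splits documents into token-sized chunks.")),
                "What TokenTextSplitter does is split documents into token-sized chunks."
            )
        );
        ResponseEvaluationSystem.EvaluationReport report = evaluation.evaluateResponses(testCases);
        System.out.println("Overall pass rate: " + String.format("%.1f%%", report.passRate()));
    }
}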
Integrate observability for monitoring AI operations.

import org.springframework.ai.observation.conventions.*;
import io.micrometer.observation.Observation;
import io.micrometer.observation.ObservationRegistry;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
public class ObservableAIService {
private final ObservationRegistry observationRegistry;
private final AtomicInteger requestCount = new AtomicInteger(0);
private final AtomicLong totalTokens = new AtomicLong(0);
public ObservableAIService(ObservationRegistry observationRegistry) {
this.observationRegistry = observationRegistry;
}
public void processChatRequest(String model, String prompt, int estimatedResponseTokens) {
Observation observation = Observation.createNotStarted(
"ai.chat.request",
observationRegistry
);
observation
.lowCardinalityKeyValue(
AiObservationAttributes.AI_OPERATION_TYPE.value(),
AiOperationType.CHAT.value()
)
.lowCardinalityKeyValue(
AiObservationAttributes.AI_PROVIDER.value(),
AiProvider.OPENAI.value()
)
.lowCardinalityKeyValue(
AiObservationAttributes.REQUEST_MODEL.value(),
model
);
observation.observe(() -> {
// Simulate chat request
int inputTokens = prompt.length() / 4; // Rough estimate
int outputTokens = estimatedResponseTokens;
// Record metrics
requestCount.incrementAndGet();
totalTokens.addAndGet(inputTokens + outputTokens);
observation
.highCardinalityKeyValue(
AiObservationAttributes.USAGE_INPUT_TOKENS.value(),
String.valueOf(inputTokens)
)
.highCardinalityKeyValue(
AiObservationAttributes.USAGE_OUTPUT_TOKENS.value(),
String.valueOf(outputTokens)
);
System.out.println("Chat request processed:");
System.out.println(" Model: " + model);
System.out.println(" Input tokens: " + inputTokens);
System.out.println(" Output tokens: " + outputTokens);
});
}
public void printMetrics() {
System.out.println("\nService Metrics:");
System.out.println(" Total requests: " + requestCount.get());
System.out.println(" Total tokens: " + totalTokens.get());
System.out.println(" Avg tokens/request: " +
(requestCount.get() > 0 ? totalTokens.get() / requestCount.get() : 0));
}
}
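
An illustrative bootstrap using a plain Micrometer registry; in a Spring Boot application the ObservationRegistry would normally be injected and bound to a metrics backend. The model name, prompts, and token estimates are placeholders.

import io.micrometer.observation.ObservationRegistry;

public class ObservabilityExample {
    public static void main(String[] args) {
        ObservationRegistry registry = ObservationRegistry.create();
        ObservableAIService service = new ObservableAIService(registry);
        // Record a couple of simulated chat requests, then dump the aggregate counters
        service.processChatRequest("gpt-4o-mini", "Summarize the Spring AI commons module.", 150);
        service.processChatRequest("gpt-4o-mini", "List the main document transformers.", 80);
        service.printMetrics();
    }
}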
Install with Tessl CLI

npx tessl i tessl/maven-org-springframework-ai--spring-ai-commons@1.1.0