Core model interfaces and abstractions for the Spring AI framework, providing a portable API for chat, embeddings, images, audio, and tool calling across multiple AI providers.
This document provides comprehensive real-world usage examples for Spring AI Model capabilities.
@Service
public class CustomerSupportService {
private final ChatModel chatModel;
private final ChatMemory chatMemory;
/**
 * Answers a customer query using the conversation stored for that customer as context.
 *
 * <p>The system instruction is prepended to every prompt. It is never written to
 * {@code chatMemory}, so adding it only when the history is empty would drop it from
 * every turn after the first.
 *
 * <p>The list returned by {@code chatMemory.get} is copied before it is extended, so the
 * memory's own list is never mutated and an unmodifiable list does not cause a failure.
 *
 * @param customerId conversation key used to read and update {@code chatMemory}
 * @param query the customer's latest message
 * @return the text of the assistant's reply
 */
public String handleCustomerQuery(String customerId, String query) {
    List<Message> stored = chatMemory.get(customerId);

    // Build the prompt on a private copy: system instruction, prior turns, new query.
    List<Message> messages = new ArrayList<>(stored.size() + 2);
    messages.add(new SystemMessage(
            "You are a helpful customer support agent for TechCorp. " +
            "Be professional, empathetic, and solution-oriented."
    ));
    messages.addAll(stored);

    UserMessage userMessage = new UserMessage(query);
    messages.add(userMessage);

    ChatOptions options = ChatOptions.builder()
            .temperature(0.7)
            .maxTokens(500)
            .build();

    ChatResponse response = chatModel.call(new Prompt(messages, options));
    AssistantMessage assistant = response.getResult().getOutput();

    // Persist the exchange only after the model call succeeded, so a failed call
    // does not leave an unanswered user message in memory.
    chatMemory.add(customerId, userMessage);
    chatMemory.add(customerId, assistant);
    return assistant.getText();
}
}@Service
public class DocumentQAService {
private final ChatModel chatModel;
private final EmbeddingModel embeddingModel;
private final List<Document> documentStore;
/**
 * Answers a question from the content of the three most similar stored documents.
 *
 * @param question the user's question; it is embedded and also placed in the prompt
 * @return the model's answer text
 */
public String answerQuestion(String question) {
    // Embed the question and keep the three closest documents.
    List<Document> matches = findSimilarDocuments(embeddingModel.embed(question), 3);

    // Join the document texts with a blank line between them.
    String context = String.join(
            "\n\n",
            matches.stream().map(Document::getText).toList());

    String promptText = """
            Context information:
            %s
            Question: %s
            Answer the question based only on the context above.
            """.formatted(context, question);

    return chatModel.call(promptText);
}
/**
 * Ranks every document in {@code documentStore} by cosine similarity to the query
 * embedding and returns the best matches, highest score first.
 *
 * @param queryEmbedding embedding of the query
 * @param topK maximum number of documents to return
 * @return at most {@code topK} documents in descending score order
 */
private List<Document> findSimilarDocuments(float[] queryEmbedding, int topK) {
    List<ScoredDoc> ranked = documentStore.stream()
            .map(candidate -> new ScoredDoc(
                    candidate,
                    cosineSimilarity(queryEmbedding, candidate.getEmbedding())))
            .sorted((left, right) -> Double.compare(right.score(), left.score()))
            .limit(topK)
            .toList();

    return ranked.stream()
            .map(ScoredDoc::doc)
            .toList();
}
/**
 * Computes the cosine similarity of two vectors.
 *
 * <p>A zero-magnitude vector yields {@code 0.0} rather than {@code NaN}. This matters
 * because a descending {@code Double.compare} sort places {@code NaN} scores first.
 * Components are widened to {@code double} before multiplying so the products do not
 * lose precision in {@code float} arithmetic.
 *
 * @param a first vector
 * @param b second vector; must have the same length as {@code a}
 * @return the similarity, or {@code 0.0} if either vector has zero magnitude
 * @throws IllegalArgumentException if the vectors have different lengths
 */
private double cosineSimilarity(float[] a, float[] b) {
    if (a.length != b.length) {
        throw new IllegalArgumentException(
                "Embedding dimensions differ: " + a.length + " vs " + b.length);
    }
    double dotProduct = 0.0;
    double normA = 0.0;
    double normB = 0.0;
    for (int i = 0; i < a.length; i++) {
        double x = a[i];
        double y = b[i];
        dotProduct += x * y;
        normA += x * x;
        normB += y * y;
    }
    double denominator = Math.sqrt(normA) * Math.sqrt(normB);
    return denominator == 0.0 ? 0.0 : dotProduct / denominator;
}
/** Pairs a document with its similarity score so candidates can be sorted and trimmed. */
record ScoredDoc(Document doc, double score) {}
}@Service
public class DataExtractionService {
private final ChatModel chatModel;
/**
 * Extracts structured invoice fields from free-form invoice text.
 *
 * <p>The prompt carries the invoice text followed by the converter's format
 * instructions; the model's reply is then converted into an {@link InvoiceData}.
 *
 * @param invoiceText raw invoice text
 * @return the invoice data produced by the converter from the model's reply
 */
public InvoiceData extractInvoiceData(String invoiceText) {
    BeanOutputConverter<InvoiceData> converter = new BeanOutputConverter<>(InvoiceData.class);

    String promptText = """
            Extract invoice information from the following text:
            %s
            %s
            """.formatted(invoiceText, converter.getFormat());

    String rawReply = chatModel.call(promptText);
    return converter.convert(rawReply);
}
/**
 * Structured invoice fields; the target type of the {@code BeanOutputConverter}
 * used in {@code extractInvoiceData}.
 *
 * NOTE(review): {@code date} is kept as a plain string and {@code amount} as a
 * {@code double}; if exact monetary arithmetic is needed downstream, consider
 * {@code BigDecimal} — confirm with consumers before changing the shape.
 */
record InvoiceData(
String invoiceNumber,
String date,
String vendor,
double amount,
String currency,
List<LineItem> items
) {}
/** One entry of {@code InvoiceData.items}: a description, a quantity and a price. */
record LineItem(String description, int quantity, double price) {}
}@Service
public class AgentService {
private final ChatModel chatModel;
private final ToolCallbackResolver toolResolver;
/**
 * Runs a tool-calling loop: the model is called repeatedly, any tool calls it requests
 * are executed, and their results are appended to the conversation. The loop ends when
 * the model answers without requesting tools or after ten rounds.
 *
 * @param task the task description sent as the first user message
 * @return the model's final text, or {@code "Task exceeded maximum iterations"} when
 *         the model is still requesting tools after the last round
 */
public String executeAgentTask(String task) {
    final int maxIterations = 10;

    List<Message> conversation = new ArrayList<>();
    conversation.add(new UserMessage(task));

    for (int iteration = 0; iteration < maxIterations; iteration++) {
        ChatResponse response = chatModel.call(new Prompt(conversation));
        AssistantMessage assistant = response.getResult().getOutput();
        conversation.add(assistant);

        if (!assistant.hasToolCalls()) {
            // No tools requested: this is the final answer.
            return assistant.getText();
        }

        // Every requested call gets exactly one response entry, including calls to
        // tools that cannot be resolved.
        List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
        for (AssistantMessage.ToolCall toolCall : assistant.getToolCalls()) {
            toolResponses.add(new ToolResponseMessage.ToolResponse(
                    toolCall.id(),
                    toolCall.name(),
                    runTool(toolCall)
            ));
        }
        conversation.add(new ToolResponseMessage(toolResponses));
    }
    return "Task exceeded maximum iterations";
}

/**
 * Resolves and invokes the tool named by {@code toolCall}. An unknown tool name is
 * reported back to the model as an error string instead of failing with a
 * {@code NullPointerException}.
 */
private String runTool(AssistantMessage.ToolCall toolCall) {
    ToolCallback tool = toolResolver.resolve(toolCall.name());
    if (tool == null) {
        return "Error: unknown tool '" + toolCall.name() + "'";
    }
    return tool.call(toolCall.arguments());
}
}
@Component
public class AgentTools {
/**
 * Placeholder web-search tool: echoes the query back in a fixed result string.
 *
 * @param query the search query
 * @return {@code "Search results for: "} followed by the query
 */
@Tool(description = "Search the web for information")
public String webSearch(@ToolParam(description = "Search query") String query) {
    // Stub implementation; no search is performed here.
    return new StringBuilder("Search results for: ").append(query).toString();
}
/**
 * Tool that evaluates an expression through {@code evaluateExpression} and formats
 * the outcome.
 *
 * @param expression the expression to evaluate
 * @return {@code "Result: "} followed by the evaluated value
 */
@Tool(description = "Calculate mathematical expressions")
public String calculate(@ToolParam(description = "Expression") String expression) {
    var outcome = evaluateExpression(expression);
    return "Result: " + outcome;
}
/**
 * Placeholder database tool: ignores the query and returns a fixed string.
 *
 * @param query the SQL text supplied by the caller (unused by this stub)
 * @return the constant {@code "Database results"}
 */
@Tool(description = "Query database")
public String queryDatabase(@ToolParam(description = "SQL query") String query) {
    // Stub implementation; the query is not executed.
    final String placeholder = "Database results";
    return placeholder;
}
}@Service
public class ImageAnalysisService {
private final ChatModel chatModel;
/**
 * Asks the chat model a question about a single image.
 *
 * @param imageUrl location of the image; attached as JPEG media
 * @param question the question to ask about the image
 * @return the model's answer text
 */
public String analyzeImage(String imageUrl, String question) {
    // Attach the image to the same user message that carries the question.
    List<Media> attachments = List.of(new Media(MimeType.IMAGE_JPEG, imageUrl));
    UserMessage userMessage = new UserMessage(question, attachments);

    ChatResponse response = chatModel.call(new Prompt(userMessage));
    return response.getResult().getOutput().getText();
}
/**
 * Asks the chat model to describe the key differences between two images.
 *
 * @param imageUrl1 location of the first image; attached as JPEG media
 * @param imageUrl2 location of the second image; attached as JPEG media
 * @return the model's comparison text
 */
public String compareImages(String imageUrl1, String imageUrl2) {
    Media first = new Media(MimeType.IMAGE_JPEG, imageUrl1);
    Media second = new Media(MimeType.IMAGE_JPEG, imageUrl2);

    UserMessage message = new UserMessage(
            "Compare these two images and describe the key differences.",
            List.of(first, second));

    ChatResponse response = chatModel.call(new Prompt(message));
    return response.getResult().getOutput().getText();
}
}@Service
public class ContentModerationPipeline {
private final ModerationModel moderationModel;
private final ChatModel chatModel;
/**
 * Moderates user content and, if it is not flagged, summarizes it with the chat model.
 *
 * <p>Flagged content is rejected with a status listing the matched categories. Only
 * hate, violence and sexual content are checked by name, so content flagged for any
 * other reason gets a generic reason instead of an empty list.
 *
 * @param content the user-supplied text
 * @return a blocked result ({@code approved == false}, {@code processed == null}) or an
 *         approved result carrying the summary
 */
public ProcessedContent processUserContent(String content) {
    // Step 1: check content safety.
    ModerationResponse modResponse = moderationModel.call(new ModerationPrompt(content));
    Moderation moderation = modResponse.getResult().getOutput();

    if (moderation.isFlagged()) {
        Categories categories = moderation.getCategories();
        List<String> violations = new ArrayList<>();
        if (categories.isHate()) {
            violations.add("hate speech");
        }
        if (categories.isViolence()) {
            violations.add("violence");
        }
        if (categories.isSexual()) {
            violations.add("sexual content");
        }
        if (violations.isEmpty()) {
            // Flagged under a category this method does not name individually.
            violations.add("policy violation");
        }
        return new ProcessedContent(
                content,
                false,
                "Content blocked: " + String.join(", ", violations),
                null
        );
    }

    // Step 2: summarize content that passed moderation.
    String processed = chatModel.call("Summarize this user content: " + content);
    return new ProcessedContent(content, true, "Approved", processed);
}
/**
 * Outcome of {@code processUserContent}: the original text, whether it was approved,
 * a status message, and the processed text ({@code null} when the content was blocked).
 */
record ProcessedContent(
String original,
boolean approved,
String status,
String processed
) {}
}@Service
public class SemanticSearchService {
private final EmbeddingModel embeddingModel;
private final List<IndexedDocument> index = new ArrayList<>();
/**
 * Embeds the given texts and appends them to the in-memory index.
 *
 * <p>Ids continue from the current index size, so repeated calls do not hand out the
 * same {@code doc_N} id twice. The first call still produces {@code doc_0},
 * {@code doc_1}, and so on.
 *
 * @param documents texts to index; a {@code null} or empty list is a no-op
 * @throws IllegalStateException if the embedding model returns a different number of
 *         vectors than there are documents
 */
public void indexDocuments(List<String> documents) {
    if (documents == null || documents.isEmpty()) {
        return;
    }

    List<float[]> embeddings = embeddingModel.embed(documents);
    if (embeddings.size() != documents.size()) {
        throw new IllegalStateException(
                "Expected " + documents.size() + " embeddings but received " + embeddings.size());
    }

    int firstId = index.size();
    for (int i = 0; i < documents.size(); i++) {
        index.add(new IndexedDocument(
                "doc_" + (firstId + i),
                documents.get(i),
                embeddings.get(i)
        ));
    }
}
/**
 * Returns the indexed documents most similar to the query, highest score first.
 *
 * @param query the search text; it is embedded with {@code embeddingModel}
 * @param topK maximum number of results to return
 * @return at most {@code topK} results in descending score order
 */
public List<SearchResult> search(String query, int topK) {
    float[] queryEmbedding = embeddingModel.embed(query);

    return index.stream()
            .map(doc -> new SearchResult(
                    doc.id(),
                    doc.content(),
                    cosineSimilarity(queryEmbedding, doc.embedding())
            ))
            .sorted((a, b) -> Double.compare(b.score(), a.score()))
            .limit(topK)
            .toList();
}

/**
 * Computes the cosine similarity of two vectors. This class called the helper without
 * defining it. A zero-magnitude vector yields {@code 0.0} rather than {@code NaN},
 * which a descending sort would otherwise rank first.
 *
 * @throws IllegalArgumentException if the vectors have different lengths
 */
private double cosineSimilarity(float[] a, float[] b) {
    if (a.length != b.length) {
        throw new IllegalArgumentException(
                "Embedding dimensions differ: " + a.length + " vs " + b.length);
    }
    double dotProduct = 0.0;
    double normA = 0.0;
    double normB = 0.0;
    for (int i = 0; i < a.length; i++) {
        double x = a[i];
        double y = b[i];
        dotProduct += x * y;
        normA += x * x;
        normB += y * y;
    }
    double denominator = Math.sqrt(normA) * Math.sqrt(normB);
    return denominator == 0.0 ? 0.0 : dotProduct / denominator;
}
/**
 * One index entry: generated id, original text, and its embedding vector.
 * Note that the generated {@code equals}/{@code hashCode} compare the array
 * component by reference, not by content.
 */
record IndexedDocument(String id, String content, float[] embedding) {}
/** One search hit: the document id, its text, and its similarity score for the query. */
record SearchResult(String id, String content, double score) {}
}@Service
public class ContentGenerationService {
private final ChatModel chatModel;
private final ImageModel imageModel;
private final TextToSpeechModel ttsModel;
/**
 * Produces an article, an illustration URL and an audio narration for a topic.
 *
 * <p>The three model calls run in sequence; the narration is generated from the
 * article text returned by the first call.
 *
 * @param topic subject of the generated content
 * @return the article text, the image URL and the narration bytes
 */
public MultiModalContent generateContent(String topic) {
    String article = chatModel.call("Write a 200-word article about: " + topic);

    String illustrationUrl = imageModel
            .call(new ImagePrompt("Create an illustration for: " + topic))
            .getResult()
            .getOutput()
            .getUrl();

    byte[] narration = ttsModel.call(article);

    return new MultiModalContent(article, illustrationUrl, narration);
}
/** Result of {@code generateContent}: article text, illustration URL and narration audio bytes. */
record MultiModalContent(String text, String imageUrl, byte[] audio) {}
}@Service
public class BatchProcessingService {
private final ChatModel chatModel;
private static final long RATE_LIMIT_DELAY_MS = 1000;
/**
 * Processes items one at a time, pausing between requests to respect rate limits.
 *
 * <p>Each item yields exactly one {@link ProcessedItem}: a success with the model
 * output and token count, or a failure carrying the exception message. The pause is
 * taken outside the per-item {@code try}, so an interrupt during the pause cannot add
 * a second, failed entry for an item that already succeeded.
 *
 * <p>If the thread is interrupted while pausing, the interrupt flag is restored and
 * the results gathered so far are returned; remaining items are not processed. No
 * pause is taken after the last item.
 *
 * @param items inputs to process
 * @return one result per processed item, in input order
 */
public List<ProcessedItem> processBatch(List<String> items) {
    List<ProcessedItem> results = new ArrayList<>(items.size());
    int remaining = items.size();

    for (String item : items) {
        long pauseMs = RATE_LIMIT_DELAY_MS;
        try {
            ChatResponse response = chatModel.call(new Prompt("Process: " + item));

            // Approaching the provider's limit: add the extra slowdown to this pause.
            RateLimit rateLimit = response.getMetadata().getRateLimit();
            if (rateLimit != null
                    && rateLimit.getRequestsRemaining() != null
                    && rateLimit.getRequestsRemaining() < 5) {
                pauseMs += RATE_LIMIT_DELAY_MS * 2;
            }

            results.add(new ProcessedItem(
                    item,
                    true,
                    response.getResult().getOutput().getText(),
                    response.getMetadata().getUsage().getTotalTokens()
            ));
        } catch (RuntimeException e) {
            results.add(new ProcessedItem(item, false, e.getMessage(), 0));
        }

        if (--remaining == 0) {
            break;
        }
        try {
            Thread.sleep(pauseMs);
        } catch (InterruptedException e) {
            // Preserve the interrupt for the caller and stop issuing requests.
            Thread.currentThread().interrupt();
            break;
        }
    }
    return results;
}
/**
 * Result for one batch item. On failure {@code output} holds the exception message
 * and {@code tokensUsed} is 0.
 */
record ProcessedItem(String input, boolean success, String output, int tokensUsed) {}
}@Service
public class AudioProcessingService {
private final TranscriptionModel transcriptionModel;
private final ChatModel chatModel;
private final TextToSpeechModel ttsModel;
/**
 * Transcribes an audio resource, then derives an analysis, a one-sentence summary and
 * a spoken version of that summary.
 *
 * @param audioFile the audio to transcribe
 * @return transcription, analysis, summary and the synthesized summary audio
 */
public AudioProcessingResult processAudio(Resource audioFile) {
    String transcript = transcriptionModel.transcribe(audioFile);

    // Both prompts are built from the raw transcript.
    String insights = chatModel.call(
            "Analyze this transcription and provide key insights: " + transcript);
    String oneLiner = chatModel.call("Summarize in one sentence: " + transcript);

    byte[] spokenSummary = ttsModel.call(oneLiner);

    return new AudioProcessingResult(transcript, insights, oneLiner, spokenSummary);
}
/**
 * Result of {@code processAudio}: the transcription, its analysis, a one-sentence
 * summary, and the audio bytes synthesized from that summary.
 */
record AudioProcessingResult(
String transcription,
String analysis,
String summary,
byte[] audioSummary
) {}
}@Service
public class DocumentEnrichmentPipeline {
private final ChatModel chatModel;
private final EmbeddingModel embeddingModel;
/**
 * Turns raw texts into documents enriched with keywords, a summary and an embedding.
 *
 * <p>Stages run in order: keyword extraction (five keywords), summary generation for
 * the current section, then embedding of each document's text.
 *
 * @param rawDocuments the texts to enrich
 * @return one enriched record per input text; keywords and summary are read from the
 *         metadata keys {@code excerpt_keywords} and {@code section_summary}
 */
public List<EnrichedDocument> enrichDocuments(List<String> rawDocuments) {
    List<Document> docs = rawDocuments.stream()
            .map(Document::new)
            .toList();

    // Stage 1: keyword metadata.
    docs = KeywordMetadataEnricher.builder(chatModel)
            .keywordCount(5)
            .build()
            .apply(docs);

    // Stage 2: summary metadata.
    docs = new SummaryMetadataEnricher(chatModel, List.of(SummaryType.CURRENT)).apply(docs);

    // Stage 3: embeddings, matched to documents by position.
    List<String> texts = docs.stream().map(Document::getText).toList();
    List<float[]> vectors = embeddingModel.embed(texts);
    int position = 0;
    for (Document doc : docs) {
        doc.setEmbedding(vectors.get(position));
        position++;
    }

    return docs.stream()
            .map(doc -> new EnrichedDocument(
                    doc.getId(),
                    doc.getText(),
                    (String) doc.getMetadata().get("excerpt_keywords"),
                    (String) doc.getMetadata().get("section_summary"),
                    doc.getEmbedding()))
            .toList();
}
/**
 * Flattened view of an enriched document: id, text, keyword and summary metadata
 * values, and the embedding vector.
 */
record EnrichedDocument(
String id,
String content,
String keywords,
String summary,
float[] embedding
) {}
}@RestController
@RequestMapping("/api/chat")
public class StreamingChatController {
private final ChatModel chatModel;
private final ChatMemory chatMemory;
/**
 * Streams the assistant's reply as server-sent events and records the exchange in
 * chat memory.
 *
 * <p>Each text chunk is emitted as a {@code message} event. After the model stream
 * completes, the concatenated reply is stored as the assistant message and a final
 * {@code done} event is emitted.
 *
 * <p>Chunks without a result or without text are skipped, because a {@code null}
 * returned from {@code map} would terminate the stream with an error. The reply
 * accumulator is created per subscription so repeated subscriptions cannot append
 * into each other's text.
 *
 * @param userId conversation key for chat memory
 * @param message the user's message
 * @return the event stream
 */
@GetMapping(value = "/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public Flux<ServerSentEvent<String>> streamChat(
        @RequestParam String userId,
        @RequestParam String message
) {
    // Copy the history so the memory's own list is not modified.
    List<Message> history = new ArrayList<>(chatMemory.get(userId));
    UserMessage userMsg = new UserMessage(message);
    history.add(userMsg);

    chatMemory.add(userId, userMsg);

    Prompt prompt = new Prompt(history);

    return Flux.defer(() -> {
        StringBuilder fullResponse = new StringBuilder();
        return chatModel.stream(prompt)
                .mapNotNull(response -> response.getResult() == null
                        ? null
                        : response.getResult().getOutput().getText())
                .map(text -> {
                    fullResponse.append(text);
                    return ServerSentEvent.builder(text)
                            .event("message")
                            .build();
                })
                .concatWith(Mono.fromCallable(() -> {
                    // Runs only after the model stream completed successfully.
                    chatMemory.add(userId, new AssistantMessage(fullResponse.toString()));
                    return ServerSentEvent.<String>builder()
                            .event("done")
                            .build();
                }));
    });
}
}@Service
public class ImageGenerationService {
private final ImageModel imageModel;
private final StorageService storageService;
/**
 * Generates one image for the request, downloads it and hands the bytes to the
 * storage service under the name {@code image.png}.
 *
 * @param request description, dimensions and style of the image to generate
 * @return the stored path, the URL the image was downloaded from, and the description
 */
public GeneratedImage generateAndStore(ImageRequest request) {
    ImageOptions options = ImageOptions.builder()
            .width(request.width())
            .height(request.height())
            .style(request.style())
            .n(1)
            .build();

    ImageResponse response = imageModel.call(new ImagePrompt(request.description(), options));
    Image generated = response.getResult().getOutput();
    String sourceUrl = generated.getUrl();

    // Fetch the bytes from the provider URL before storing them.
    byte[] imageBytes = downloadImage(sourceUrl);
    String storedPath = storageService.store(imageBytes, "image.png");

    return new GeneratedImage(storedPath, sourceUrl, request.description());
}
/**
 * Downloads the bytes behind an HTTP(S) URL.
 *
 * <p>Only {@code http} and {@code https} are accepted, so a {@code file:} or other
 * scheme cannot be used to read local resources. Connect and read timeouts keep a
 * stalled server from blocking the caller indefinitely.
 *
 * @param url absolute image URL
 * @return the downloaded bytes
 * @throws RuntimeException with message {@code "Failed to download image"} if the URL
 *         is missing, malformed, uses an unsupported scheme, or the transfer fails
 */
private byte[] downloadImage(String url) {
    final int connectTimeoutMs = 10_000;
    final int readTimeoutMs = 30_000;
    try {
        if (url == null || url.isBlank()) {
            throw new IllegalArgumentException("Image URL must not be null or blank");
        }
        java.net.URI uri = java.net.URI.create(url);
        String scheme = uri.getScheme();
        if (!"http".equalsIgnoreCase(scheme) && !"https".equalsIgnoreCase(scheme)) {
            throw new IllegalArgumentException("Unsupported image URL scheme: " + scheme);
        }

        java.net.URLConnection connection = uri.toURL().openConnection();
        connection.setConnectTimeout(connectTimeoutMs);
        connection.setReadTimeout(readTimeoutMs);
        try (InputStream in = connection.getInputStream()) {
            return in.readAllBytes();
        }
    } catch (java.io.IOException | IllegalArgumentException e) {
        throw new RuntimeException("Failed to download image", e);
    }
}
/** Input to {@code generateAndStore}: prompt text, pixel dimensions and a style name. */
record ImageRequest(String description, int width, int height, String style) {}
/** Output of {@code generateAndStore}: where the image was stored, its source URL, and the prompt used. */
record GeneratedImage(String storedPath, String originalUrl, String prompt) {}
}@Service
public class ContextualToolService {
private final ChatModel chatModel;
/**
 * Calls the chat model with user and database tools enabled and a per-request tool
 * context carrying the user id, session id and current timestamp.
 *
 * <p>The two callback lists are merged into one list before being passed to the
 * options builder, which accepts a single list (or varargs of callbacks), not
 * several lists.
 *
 * @param userId id placed in the tool context; must not be {@code null}
 *        ({@code Map.of} rejects null values)
 * @param sessionId session id placed in the tool context; must not be {@code null}
 * @param request the user's request text
 * @return the model's reply text
 */
public String processWithUserContext(String userId, String sessionId, String request) {
    Map<String, Object> toolContext = Map.of(
            "userId", userId,
            "sessionId", sessionId,
            "timestamp", System.currentTimeMillis()
    );

    List<ToolCallback> callbacks = new ArrayList<>(getUserTools());
    callbacks.addAll(getDatabaseTools());

    ToolCallingChatOptions options = ToolCallingChatOptions.builder()
            .toolCallbacks(callbacks)
            .toolContext(toolContext)
            .internalToolExecutionEnabled(true)
            .temperature(0.7)
            .build();

    ChatResponse response = chatModel.call(new Prompt(request, options));
    return response.getResult().getOutput().getText();
}
/**
 * Builds tool callbacks for the {@code @Tool} methods of a new {@link UserTools}.
 *
 * <p>The provider is created through its builder, and the callback array it returns
 * is wrapped in an immutable list to match this method's return type.
 *
 * @return the user tool callbacks
 */
private List<ToolCallback> getUserTools() {
    ToolCallback[] callbacks = MethodToolCallbackProvider.builder()
            .toolObjects(new UserTools())
            .build()
            .getToolCallbacks();
    return List.of(callbacks);
}
/**
 * Builds tool callbacks for the {@code @Tool} methods of a new {@code DatabaseTools}.
 *
 * <p>The provider is created through its builder, and the callback array it returns
 * is wrapped in an immutable list to match this method's return type.
 *
 * @return the database tool callbacks
 */
private List<ToolCallback> getDatabaseTools() {
    ToolCallback[] callbacks = MethodToolCallbackProvider.builder()
            .toolObjects(new DatabaseTools())
            .build()
            .getToolCallbacks();
    return List.of(callbacks);
}
}
@Component
class UserTools {
/**
 * Placeholder profile tool: returns a fixed JSON document regardless of the id.
 *
 * @param userId id of the requested user (unused by this stub)
 * @return a JSON object with {@code name} and {@code email} fields
 */
@Tool(description = "Get user profile information")
public String getUserProfile(@ToolParam(description = "User ID") String userId) {
    // Stub payload; no lookup is performed.
    return """
            {"name": "John Doe", "email": "john@example.com"}""";
}
}@Service
public class MonitoredAiService {
private final ChatModel chatModel;
private final MeterRegistry meterRegistry;
private final ObservationRegistry observationRegistry;
/**
 * Calls the chat model inside an observation while timing the call and counting the
 * tokens it used.
 *
 * <p>The observation context is passed through a supplier, which is what
 * {@code Observation.createNotStarted(name, contextSupplier, registry)} expects. The
 * token counter is only incremented when the response reports a total, so a missing
 * usage value does not fail an otherwise successful call.
 *
 * @param message the user message
 * @return the model's reply text
 */
public String monitoredChat(String message) {
    Timer.Sample sample = Timer.start(meterRegistry);
    try {
        Prompt prompt = new Prompt(message);
        ChatModelObservationContext context =
                new ChatModelObservationContext(prompt, chatModel);

        return Observation.createNotStarted("ai.chat.call", () -> context, observationRegistry)
                .observe(() -> {
                    ChatResponse response = chatModel.call(prompt);
                    context.setResponse(response);

                    Usage usage = response.getMetadata().getUsage();
                    Number totalTokens = usage == null ? null : usage.getTotalTokens();
                    if (totalTokens != null) {
                        meterRegistry.counter("ai.tokens.used", "type", "total")
                                .increment(totalTokens.doubleValue());
                    }
                    return response.getResult().getOutput().getText();
                });
    } finally {
        // The duration is recorded for failed calls as well.
        sample.stop(Timer.builder("ai.chat.duration")
                .tag("model", "chat")
                .register(meterRegistry));
    }
}
}@Service
public class ResilientAiService {
private final ChatModel chatModel;
/**
 * Calls the chat model with up to three attempts and exponential backoff
 * (1 s, then 2 s) between them.
 *
 * <p>The proactive slowdown that follows a successful call is handled separately from
 * the retry logic. An interrupt during that pause restores the interrupt flag and
 * still returns the response that was already obtained, instead of discarding it and
 * calling the model again.
 *
 * @param message the user message
 * @return the model's reply text
 * @throws RuntimeException if every attempt fails ({@code "Max retries exceeded"},
 *         with the last failure as cause) or the thread is interrupted while backing
 *         off ({@code "Interrupted during retry"})
 */
public String chatWithRetry(String message) {
    final int maxRetries = 3;
    long delay = 1000; // Start with 1 second

    for (int attempt = 1; attempt <= maxRetries; attempt++) {
        try {
            ChatResponse response = chatModel.call(new Prompt(message));
            pauseIfNearRateLimit(response);
            return response.getResult().getOutput().getText();
        } catch (RuntimeException e) {
            if (attempt == maxRetries) {
                throw new RuntimeException("Max retries exceeded", e);
            }
        }

        try {
            Thread.sleep(delay);
            delay *= 2; // Exponential backoff
        } catch (InterruptedException ie) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("Interrupted during retry", ie);
        }
    }
    throw new RuntimeException("Unexpected retry loop exit");
}

/**
 * Sleeps for two seconds when the response reports fewer than ten remaining requests.
 * An interrupt ends the pause early and is re-asserted on the current thread.
 */
private void pauseIfNearRateLimit(ChatResponse response) {
    RateLimit rateLimit = response.getMetadata().getRateLimit();
    if (rateLimit == null || rateLimit.getRequestsRemaining() == null) {
        return;
    }
    if (rateLimit.getRequestsRemaining() < 10) {
        try {
            Thread.sleep(2000); // Proactive slowdown
        } catch (InterruptedException ie) {
            Thread.currentThread().interrupt();
        }
    }
}
}

For more scenarios including edge cases, error handling patterns, and advanced configurations, see: