Core model interfaces and abstractions for the Spring AI framework, providing a portable API for chat, embeddings, images, audio, and tool calling across multiple AI providers.
Production-ready recommendations for using Spring AI Model effectively.
// Wraps a chat call in broad error handling so a provider failure degrades to a
// canned reply instead of propagating to the caller.
// NOTE(review): truncate() and chatModel are assumed to be defined on the
// enclosing class — not visible in this snippet.
public String robustAiCall(String input) {
try {
return chatModel.call(input);
} catch (Exception e) {
// Truncated input keeps log size (and any PII exposure) bounded; stack trace kept.
log.error("AI call failed for input: {}", truncate(input), e);
// Return fallback or rethrow based on your requirements
return "I apologize, but I'm having trouble processing your request.";
}
}@Service
// Accumulates total token usage across calls for cost monitoring.
// NOTE(review): no chatModel field is declared in this snippet — assumed injected.
public class TokenTrackingService {
// Thread-safe running total of tokens consumed by this service instance.
private final AtomicLong totalTokens = new AtomicLong(0);
public String trackedChat(String message) {
ChatResponse response = chatModel.call(new Prompt(message));
// Usage metadata may be absent depending on the provider — guard before reading.
Usage usage = response.getMetadata().getUsage();
if (usage != null && usage.getTotalTokens() != null) {
totalTokens.addAndGet(usage.getTotalTokens());
// Alert if threshold exceeded
// NOTE(review): fires on every call once past 1M tokens — consider latching the alert.
if (totalTokens.get() > 1_000_000) {
alertHighUsage();
}
}
return response.getResult().getOutput().getText();
}
}@ConfigurationProperties(prefix = "app.ai")
// Externalized AI defaults bound from application properties under "app.ai".
public class AiProperties {
private String defaultModel = "gpt-4";
private double defaultTemperature = 0.7;
private int maxTokens = 2000;
private int maxRetries = 3;
// Getters/setters are required for @ConfigurationProperties binding (and are
// called elsewhere, e.g. getDefaultModel()); the original left only a placeholder
// comment, so the class neither compiled against its callers nor bound properties.
public String getDefaultModel() { return defaultModel; }
public void setDefaultModel(String defaultModel) { this.defaultModel = defaultModel; }
public double getDefaultTemperature() { return defaultTemperature; }
public void setDefaultTemperature(double defaultTemperature) { this.defaultTemperature = defaultTemperature; }
public int getMaxTokens() { return maxTokens; }
public void setMaxTokens(int maxTokens) { this.maxTokens = maxTokens; }
public int getMaxRetries() { return maxRetries; }
public void setMaxRetries(int maxRetries) { this.maxRetries = maxRetries; }
}
@Service
// Chat service whose model, temperature, and token limit come from AiProperties.
public class ConfiguredAiService {
private final ChatModel chatModel;
private final AiProperties properties;
public String chat(String message) {
// Assemble per-request options from the externalized configuration.
ChatOptions requestOptions = ChatOptions.builder()
.model(properties.getDefaultModel())
.temperature(properties.getDefaultTemperature())
.maxTokens(properties.getMaxTokens())
.build();
Prompt prompt = new Prompt(message, requestOptions);
ChatResponse response = chatModel.call(prompt);
return response.getResult().getOutput().getText();
}
}@Service
// Throttles chat calls with a client-side token bucket.
public class RateLimitedService {
private final ChatModel chatModel;
// Guava limiter: at most 10 permits issued per second. // 10 requests/second
private final RateLimiter limiter = RateLimiter.create(10.0);
public String rateLimitedChat(String message) {
limiter.acquire(); // blocks the calling thread until a permit is free
return chatModel.call(message);
}
}// Prefer streaming for better user experience
// Streams the model's answer incrementally instead of waiting for the full reply.
public Flux<String> streamForLongContent(String message) {
Prompt prompt = new Prompt(message);
return chatModel.stream(prompt)
.map(chunk -> chunk.getResult().getOutput().getText());
}
// vs synchronous call that blocks
// Synchronous variant: the calling thread waits for the entire completion.
public String blockingLongContent(String message) {
return chatModel.call(message); // User waits for complete response
}// Good: Clear system instruction
// Example: a focused system message constraining tone and behavior.
SystemMessage system = new SystemMessage(
"You are a professional customer support agent. " +
"Always be polite, clear, and solution-oriented. " +
"If you don't know something, admit it rather than guessing."
);
// Avoid: Vague or overly long system messages
// Good: Clear structure
// Template with labeled sections the model can follow reliably.
String prompt = """
Context: [relevant background]
Task: [what you want the AI to do]
Constraints: [any limitations or requirements]
Input: [the actual user input]
""";
// vs unclear, unstructured promptspublic String validatedChat(String input) {
// Validate
if (input == null || input.isBlank()) {
throw new IllegalArgumentException("Input cannot be empty");
}
if (input.length() > 10000) {
throw new IllegalArgumentException("Input too long");
}
// Sanitize if needed
String sanitized = input.replaceAll("[\\p{Cntrl}&&[^\\r\\n\\t]]", "");
return chatModel.call(sanitized);
}// For short conversations (customer support sessions)
// Windowed memory keeps only the most recent N messages per conversation.
ChatMemory shortTerm = new MessageWindowChatMemory(repository, 20);
// For long conversations (ongoing projects)
ChatMemory longTerm = new MessageWindowChatMemory(repository, 100);
// For summarization-based memory (unlimited history)
// NOTE(review): SummarizingChatMemory is not a stock Spring AI class — verify it exists here.
ChatMemory summarized = new SummarizingChatMemory(repository, chatModel, 50);@Service
// Manages per-user chat memory lifecycle (explicit end + scheduled cleanup).
public class SessionAwareChat {
private final ChatMemory chatMemory;
// Drops a user's conversation history once the session is finished.
public void endSession(String userId) {
// Clear memory when conversation ends
chatMemory.clear(userId);
log.info("Cleared chat memory for user: {}", userId);
}
// Spring cron (sec min hour day month weekday): fires at minute 0 of every hour.
@Scheduled(cron = "0 0 * * * *") // Every hour
public void cleanupStaleMemory() {
// Cleanup logic for old conversations
}
}// Good: Batch processing
// One batched request embeds the whole list — a single network round trip.
List<float[]> embeddings = embeddingModel.embed(allTexts);
// Avoid: Individual calls in loop
for (String text : allTexts) {
float[] embedding = embeddingModel.embed(text); // Inefficient
}public float[] embedWithNormalization(String text) {
// Normalize: lowercase, trim, remove extra whitespace
// Locale.ROOT makes case-folding locale-independent (e.g. avoids the Turkish
// dotless-i mapping), so embeddings are stable regardless of the JVM's default locale.
String normalized = text.toLowerCase(java.util.Locale.ROOT)
.trim()
.replaceAll("\\s+", " ");
return embeddingModel.embed(normalized);
}@Service
// Memoizes embeddings so repeated texts don't trigger repeated API calls.
public class EmbeddingCacheService {
// NOTE(review): unbounded cache — entries are never evicted, so memory grows
// with the number of distinct texts; consider a bounded cache (e.g. Caffeine).
private final ConcurrentHashMap<String, float[]> cache = new ConcurrentHashMap<>();
private final EmbeddingModel embeddingModel;
// Computes the embedding once per distinct text; later calls hit the cache.
public float[] getCachedEmbedding(String text) {
return cache.computeIfAbsent(text, embeddingModel::embed);
}
}// Good: Detailed, clear description
// Rich descriptions help the model decide when and how to invoke the tool.
@Tool(description = "Get the current weather forecast for a specific city. " +
"Returns temperature, conditions, and humidity.")
public String getWeather(@ToolParam(description = "City name, e.g., 'London' or 'New York'") String city) {
// ...
}
// Avoid: Vague descriptions
// NOTE(review): illustrative only — two getWeather(String) overloads would not compile together.
@Tool(description = "Get weather")
public String getWeather(String city) { // No parameter description
// ...
}@Tool(description = "Calculate sum of two numbers")
public String add(
@ToolParam(description = "First number") double a,
@ToolParam(description = "Second number") double b
) {
// Validation
// Reject NaN first, then infinities, so the error message matches the cause.
if (Double.isNaN(a) || Double.isNaN(b)) {
throw new ToolExecutionException("Invalid numbers provided");
}
if (Double.isInfinite(a) || Double.isInfinite(b)) {
throw new ToolExecutionException("Numbers too large");
}
// Return a small JSON object so the model can parse the result reliably.
return "{\"result\": " + (a + b) + "}";
}// Good: Structured JSON
@Tool(description = "Get user information")
public String getUser(@ToolParam(description = "User ID") String userId) {
User user = userRepository.findById(userId);
// Fix: writeValueAsString throws the checked JsonProcessingException, which the
// original neither caught nor declared — the snippet did not compile.
// NOTE(review): consider reusing a single shared ObjectMapper (thread-safe,
// expensive to construct) instead of creating one per call.
try {
return new ObjectMapper().writeValueAsString(Map.of(
"id", user.getId(),
"name", user.getName(),
"email", user.getEmail(),
"status", "success"
));
} catch (JsonProcessingException e) {
// Preserve the cause so serialization failures are diagnosable.
throw new IllegalStateException("Failed to serialize user " + userId, e);
}
}
// Avoid: Plain text that's hard to parse
// NOTE(review): illustrative counter-example — ignores userId and returns a hardcoded string.
public String getUser(String userId) {
return "User: John, Email: john@example.com";
}// For factual/deterministic tasks: low temperature
// Low temperature biases sampling toward the highest-probability tokens.
ChatOptions factual = ChatOptions.builder()
.temperature(0.1)
.build();
// For creative tasks: higher temperature
ChatOptions creative = ChatOptions.builder()
.temperature(0.9)
.build();// For short responses (summaries, classifications)
// maxTokens caps the completion length (and therefore cost) per request.
ChatOptions shortResponse = ChatOptions.builder()
.maxTokens(100)
.build();
// For detailed responses
ChatOptions detailedResponse = ChatOptions.builder()
.maxTokens(2000)
.build();@Configuration
// RestTemplate with explicit timeouts so hung AI calls fail fast instead of
// tying up request threads indefinitely.
public class AiConnectionConfig {
@Bean
public RestTemplate aiRestTemplate() {
HttpComponentsClientHttpRequestFactory factory =
new HttpComponentsClientHttpRequestFactory();
// Millisecond values: 5s to lease/establish a connection, 30s to read the response.
// NOTE(review): the int-based setters (incl. setReadTimeout) are deprecated in
// recent Spring versions — verify against the Spring version in use.
factory.setConnectionRequestTimeout(5000);
factory.setConnectTimeout(5000);
factory.setReadTimeout(30000);
return new RestTemplate(factory);
}
}public String secureChat(String message) {
// Log only metadata (lengths) — never the payload — to avoid leaking PII.
log.info("Processing AI request (length: {})", message.length());
try {
String reply = chatModel.call(message);
log.info("AI response generated (length: {})", reply.length());
return reply;
} catch (Exception e) {
log.error("AI call failed", e); // Don't include message in error log
throw e;
}
}public String sanitizedChat(String userInput) {
// Remove control characters
// Keeps CR, LF, and TAB; strips all other control characters.
String sanitized = userInput.replaceAll("[\\p{Cntrl}&&[^\\r\\n\\t]]", "");
// Limit length
// Hard cap at 10k characters to bound token usage.
sanitized = sanitized.length() > 10000 ? sanitized.substring(0, 10000) : sanitized;
return chatModel.call(sanitized);
}# application.properties
spring.ai.openai.api-key=${OPENAI_API_KEY}
# Never hardcode:
# spring.ai.openai.api-key=sk-proj-abc123... # DON'T DO THIS
@SpringBootTest
// Unit test: ChatModel is mocked so the test never hits a real (billable) API.
public class AiServiceTest {
@MockBean
private ChatModel chatModel;
@Autowired
private AiService aiService;
@Test
public void testChat() {
// Mock response
// NOTE(review): verify these ChatResponse/Generation constructors against the
// Spring AI version in use — their signatures have changed across releases.
ChatResponse mockResponse = new ChatResponse(
List.of(new Generation(new AssistantMessage("Test response")))
);
when(chatModel.call(any(Prompt.class))).thenReturn(mockResponse);
// Test
String result = aiService.chat("Test message");
assertEquals("Test response", result);
}
}@SpringBootTest
@Tag("integration")
@Tag("expensive")
// Hits the real provider API — tagged so CI can exclude it by default.
public class RealAiIntegrationTest {
@Autowired
private ChatModel chatModel;
@Test
// NOTE(review): JUnit 5's @EnabledIf takes a condition-method name, not SpEL —
// confirm this is Spring's @EnabledIf (which does accept SpEL) before relying on it.
@EnabledIf("#{systemProperties['run.expensive.tests'] == 'true'}")
public void testRealApiCall() {
String response = chatModel.call("Test message");
assertNotNull(response);
}
}@Service
// Wraps a chat call with Micrometer request/latency/token metrics.
// NOTE(review): no chatModel field is declared in this snippet — assumed injected.
public class MetricsTrackingService {
private final MeterRegistry registry;
public String trackAndCall(String message) {
Counter requests = registry.counter("ai.requests", "model", "chat");
requests.increment();
Timer timer = registry.timer("ai.duration", "model", "chat");
// timer.record(Supplier) times the lambda and returns its value.
return timer.record(() -> {
ChatResponse response = chatModel.call(new Prompt(message));
// Track tokens
Usage usage = response.getMetadata().getUsage();
if (usage != null) {
registry.counter("ai.tokens", "type", "total")
.increment(usage.getTotalTokens());
}
return response.getResult().getOutput().getText();
});
}
}@Component
// Periodically inspects AI meters and pages the team on high error rate or cost.
public class AiAlertingService {
private final MeterRegistry registry;
@Scheduled(fixedRate = 60000) // Check every minute
public void checkMetrics() {
// Check error rate
// Fix: Micrometer meters are identified by name AND tags. The tracking code
// registers "ai.requests" with tags ("model","chat") and "ai.tokens" with
// ("type","total"); the original looked both up untagged, creating separate,
// always-zero meters and silently disabling these alerts.
// NOTE(review): no code in this file increments "ai.errors" — confirm who owns it.
Counter errors = registry.counter("ai.errors");
Counter requests = registry.counter("ai.requests", "model", "chat");
double errorRate = errors.count() / Math.max(1, requests.count());
if (errorRate > 0.1) { // > 10% error rate
alertTeam("High AI error rate: " + errorRate);
}
// Check cost
Counter tokens = registry.counter("ai.tokens", "type", "total");
double estimatedCost = (tokens.count() / 1000.0) * 0.03; // Example rate
if (estimatedCost > 100) { // $100 threshold
alertTeam("High AI usage cost: $" + estimatedCost);
}
}
}public String processAudioFile(MultipartFile file) {
// Transcribes an uploaded audio file via a temp file on disk.
// Fix: NIO Files.createTempFile creates the file with owner-only permissions
// (File.createTempFile does not), and cleanup uses deleteIfExists instead of
// silently discarding File.delete()'s boolean result.
java.nio.file.Path tempFile = null;
try {
tempFile = java.nio.file.Files.createTempFile("audio", ".mp3");
// MultipartFile.transferTo(Path) requires Spring 5.1+.
file.transferTo(tempFile);
Resource resource = new FileSystemResource(tempFile.toFile());
return transcriptionModel.transcribe(resource);
} catch (Exception e) {
throw new RuntimeException("Processing failed", e);
} finally {
if (tempFile != null) {
try {
java.nio.file.Files.deleteIfExists(tempFile);
} catch (java.io.IOException ignored) {
// Best-effort cleanup; the OS eventually purges the temp directory.
}
}
}
}@Configuration
// Decorates the ChatModel bean with a hard timeout so callers cannot block forever.
public class AiTimeoutConfig {
@Bean
public ChatModel configuredChatModel(ChatModel delegate) {
// Wrap with timeout logic
// NOTE(review): TimeoutChatModelWrapper is a custom class, not part of Spring AI.
return new TimeoutChatModelWrapper(delegate, Duration.ofSeconds(30));
}
}@Service
// Serves identical prompts from an in-memory Caffeine cache for up to an hour.
// NOTE(review): no chatModel field is declared in this snippet — assumed injected.
public class CachedChatService {
// Bounded, time-expiring loading cache; the loader runs once per missing key.
private final LoadingCache<String, String> cache = Caffeine.newBuilder()
.maximumSize(1000)
.expireAfterWrite(1, TimeUnit.HOURS)
.build(this::computeResponse);
private String computeResponse(String message) {
return chatModel.call(message);
}
public String getCachedResponse(String message) {
return cache.get(message);
}
}@Service
// Routes simple tasks to the cheaper model and complex ones to the advanced model.
public class OptimizedModelSelection {
// Fix: constructor injection replaces @Autowired field injection — it allows
// final fields, fails fast on missing beans, and keeps the class testable
// without a Spring context.
private final ChatModel advancedModel;
private final ChatModel basicModel;
public OptimizedModelSelection(@Qualifier("gpt-4") ChatModel advancedModel,
@Qualifier("gpt-3.5-turbo") ChatModel basicModel) {
this.advancedModel = advancedModel;
this.basicModel = basicModel;
}
public String smartChat(String message, TaskComplexity complexity) {
ChatModel model = switch (complexity) {
case SIMPLE -> basicModel; // Use cheaper model
case COMPLEX -> advancedModel; // Use advanced model when needed
};
return model.call(message);
}
enum TaskComplexity { SIMPLE, COMPLEX }
}// Good: Single batch call
// One batched request instead of N single requests.
// NOTE(review): illustrative only — 'embeddings' is declared twice in this snippet.
List<float[]> embeddings = embeddingModel.embed(allTexts);
// Avoid: Multiple individual calls
List<float[]> embeddings = new ArrayList<>();
for (String text : allTexts) {
embeddings.add(embeddingModel.embed(text)); // Inefficient, multiple API calls
}// Good: Clear, focused model
// Flat, strongly-typed record — gives structured-output converters a clear schema.
record ProductInfo(
String name,
double price,
String category
) {}
// Avoid: Overly complex or vague models
// Nested generic maps and Object fields give the model no usable schema to target.
record ComplexData(
Map<String, List<Map<String, Object>>> data,
Object misc
) {}public <T> T safeConvert(String response, Class<T> clazz) {
// Strips reasoning tags, markdown code fences, and stray whitespace before
// converting the model's raw text into an instance of clazz.
CompositeResponseTextCleaner cleaner = CompositeResponseTextCleaner.builder()
.add(new ThinkingTagCleaner())
.add(new MarkdownCodeBlockCleaner())
.add(new WhitespaceCleaner())
.build();
String cleaned = cleaner.clean(response);
return new BeanOutputConverter<>(clazz).convert(cleaned);
}public String observableChat(String message) {
// Chat call enriched with MDC context so every log line carries model/usage info.
MDC.put("ai.model", "gpt-4");
MDC.put("ai.operation", "chat");
try {
ChatResponse response = chatModel.call(new Prompt(message));
// Fix: usage metadata can be absent (provider-dependent); the original
// dereferenced getUsage().getTotalTokens() unconditionally and could NPE —
// other code in this file null-checks both the Usage and its token count.
Usage usage = response.getMetadata().getUsage();
if (usage != null && usage.getTotalTokens() != null) {
MDC.put("ai.tokens", String.valueOf(usage.getTotalTokens()));
}
MDC.put("ai.finishReason", response.getResult().getMetadata().getFinishReason());
log.info("AI chat completed successfully");
return response.getResult().getOutput().getText();
} catch (Exception e) {
log.error("AI chat failed", e);
throw e;
} finally {
// Always clear MDC so context never leaks onto this thread's next request.
MDC.clear();
}
}@Service
// Records business, quality, and cost metrics for one completed chat response.
public class BusinessMetricsService {
private final MeterRegistry registry;
public void trackConversation(String userId, ChatResponse response) {
// Business metrics
// NOTE(review): tagging by userId creates one meter per user — high-cardinality risk.
registry.counter("conversations.total", "user", userId).increment();
// Quality metrics
String finishReason = response.getResult().getMetadata().getFinishReason();
registry.counter("conversations.finish_reason",
"reason", finishReason).increment();
// Cost metrics
// Usage may be absent depending on provider — guarded here.
Usage usage = response.getMetadata().getUsage();
if (usage != null) {
double cost = estimateCost(usage);
registry.counter("ai.cost.usd").increment(cost);
}
}
}@Component
// Actuator health check that probes the model with a trivial prompt.
public class AiHealthIndicator implements HealthIndicator {
private final ChatModel chatModel;
@Override
public Health health() {
try {
// Simple health check
// NOTE(review): every probe issues a real (billable) model call — consider
// caching the result or rate-limiting how often this runs.
String response = chatModel.call("ping");
return Health.up()
.withDetail("status", "AI model responding")
.withDetail("response", response.substring(0, Math.min(50, response.length())))
.build();
} catch (Exception e) {
return Health.down()
.withDetail("error", e.getMessage())
.build();
}
}
}@Service
// Gates AI functionality behind runtime feature flags.
public class FeatureFlaggedService {
// Kill switch for all AI functionality (defaults to enabled).
@Value("${features.ai.enabled:true}")
private boolean aiEnabled;
// Separate flag gating tool/function calling (defaults to disabled).
@Value("${features.ai.tools.enabled:false}")
private boolean toolsEnabled;
public String conditionalChat(String message) {
if (!aiEnabled) {
return getFallbackResponse(message);
}
ChatOptions.Builder optionsBuilder = ChatOptions.builder()
.temperature(0.7);
if (toolsEnabled) {
optionsBuilder.toolCallbacks(getAvailableTools());
}
return chatModel.call(new Prompt(message, optionsBuilder.build()))
.getResult().getOutput().getText();
}
}@Service
// Cache-first chat with a static fallback reply when the model is unavailable.
public class GracefulDegradationService {
private final ChatModel primaryModel;
private final Cache<String, String> responseCache;
public String resilientChat(String message) {
// Try cache first
String cached = responseCache.getIfPresent(message);
if (cached != null) {
return cached;
}
try {
// Try primary model
String response = primaryModel.call(message);
responseCache.put(message, response);
return response;
} catch (Exception e) {
log.warn("Primary model unavailable, using fallback", e);
// Fallback to simpler response
return generateFallbackResponse(message);
}
}
// Static apology used when the primary model call fails; never cached.
private String generateFallbackResponse(String message) {
return "I'm currently experiencing technical difficulties. " +
"Please try again in a few moments.";
}
}Key takeaways for production usage:
For more details, see the Spring AI reference documentation.