OpenAI models support for Spring AI, providing comprehensive integration for chat completion, embeddings, image generation, audio transcription, text-to-speech, and content moderation capabilities within Spring Boot applications.
Production-ready patterns and use cases for Spring AI OpenAI integration.
/**
 * REST endpoints exposing chat completion and embedding generation
 * backed by Spring AI's OpenAI models.
 */
@RestController
@RequestMapping("/api/ai")
public class AiController {

    private final OpenAiChatModel chatModel;
    private final OpenAiEmbeddingModel embeddingModel;

    /** Constructor injection of the Spring-AI-managed OpenAI models. */
    public AiController(OpenAiChatModel chatModel, OpenAiEmbeddingModel embeddingModel) {
        this.chatModel = chatModel;
        this.embeddingModel = embeddingModel;
    }

    /**
     * Runs a single-turn chat completion and reports the model output
     * together with the total token count billed for the call.
     */
    @PostMapping("/chat")
    public ChatResponseDto chat(@RequestBody ChatRequestDto request) {
        var chatResponse = chatModel.call(new Prompt(request.getMessage()));
        var content = chatResponse.getResult().getOutput().getContent();
        var totalTokens = chatResponse.getMetadata().getUsage().getTotalTokens();
        return new ChatResponseDto(content, totalTokens);
    }

    /** Embeds the raw request body and wraps the resulting vector in a DTO. */
    @PostMapping("/embed")
    public EmbeddingResponseDto embed(@RequestBody String text) {
        var vector = embeddingModel.embed(new Document(text));
        return new EmbeddingResponseDto(vector);
    }
}
import org.springframework.scheduling.annotation.Async;
import java.util.concurrent.CompletableFuture;
@Service
public class AsyncAiService {

    private final OpenAiChatModel chatModel;

    // Fix: the final field was never initialized, which does not compile.
    // Constructor injection lets Spring wire the model automatically.
    public AsyncAiService(OpenAiChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /**
     * Generates a completion on Spring's async executor.
     * NOTE(review): @Async only applies when invoked through the Spring proxy
     * (i.e. from another bean) and requires @EnableAsync — confirm both.
     */
    @Async
    public CompletableFuture<String> generateAsync(String prompt) {
        var response = chatModel.call(new Prompt(prompt));
        return CompletableFuture.completedFuture(
                response.getResult().getOutput().getContent());
    }

    /**
     * Generates completions for every prompt. The prompts are processed
     * sequentially inside this single async task, not in parallel.
     */
    @Async
    public CompletableFuture<List<String>> batchGenerate(List<String> prompts) {
        return CompletableFuture.completedFuture(
                prompts.stream()
                        .map(prompt -> chatModel.call(new Prompt(prompt)))
                        .map(response -> response.getResult().getOutput().getContent())
                        .toList());
    }
}
import reactor.core.publisher.Flux;
@RestController
public class StreamingController {

    private final OpenAiChatModel chatModel;

    // Fix: the final field was never initialized, which does not compile.
    public StreamingController(OpenAiChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /**
     * Streams chat completion tokens to the client as server-sent events,
     * dropping null/empty chunks (e.g. terminal chunks with no content).
     */
    @GetMapping(value = "/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public Flux<String> stream(@RequestParam String message) {
        return chatModel.stream(new Prompt(message))
                .map(response -> response.getResult().getOutput().getContent())
                .filter(content -> content != null && !content.isEmpty());
    }
}
@Service
public class BatchProcessor {

    /** Maximum texts per embedding request (OpenAI API batch limit). */
    private static final int BATCH_SIZE = 100;

    private final OpenAiEmbeddingModel embeddingModel;

    // Fix: the final field was never initialized, which does not compile.
    public BatchProcessor(OpenAiEmbeddingModel embeddingModel) {
        this.embeddingModel = embeddingModel;
    }

    /**
     * Embeds every text, issuing requests in chunks of {@link #BATCH_SIZE}.
     *
     * @return map from input text to its embedding vector; duplicate input
     *         texts collapse to a single entry (last one wins)
     */
    public Map<String, float[]> batchEmbed(List<String> texts) {
        Map<String, float[]> results = new HashMap<>();
        for (int i = 0; i < texts.size(); i += BATCH_SIZE) {
            List<String> batch = texts.subList(i, Math.min(i + BATCH_SIZE, texts.size()));
            var request = new EmbeddingRequest(
                    batch,
                    OpenAiEmbeddingOptions.builder()
                            .model(OpenAiApi.EmbeddingModel.TEXT_EMBEDDING_3_SMALL.getValue())
                            .build());
            var response = embeddingModel.call(request);
            // Response entries are positionally aligned with the request batch.
            for (int j = 0; j < batch.size(); j++) {
                results.put(batch.get(j), response.getData().get(j).getEmbedding());
            }
        }
        return results;
    }
}
import org.springframework.ai.openai.api.common.OpenAiApiClientErrorException;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.HttpServerErrorException;
@Service
public class RobustAiService {

    private final OpenAiChatModel chatModel;

    // Fix: the final field was never initialized, which does not compile.
    public RobustAiService(OpenAiChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /**
     * Calls the chat model and maps every failure mode to a user-facing
     * message instead of propagating the exception.
     */
    public String generateWithErrorHandling(String prompt) {
        try {
            var response = chatModel.call(new Prompt(prompt));
            return response.getResult().getOutput().getContent();
        } catch (OpenAiApiClientErrorException e) {
            return handleApiError(e);
        } catch (HttpServerErrorException e) {
            return handleServerError(e);
        } catch (Exception e) {
            return handleUnexpectedError(e);
        }
    }

    /** Maps 4xx client errors to actionable messages. */
    // NOTE(review): assumes getStatusCode() yields an int suitable for this
    // switch; if it returns HttpStatusCode, switch on .value() — confirm.
    private String handleApiError(OpenAiApiClientErrorException e) {
        return switch (e.getStatusCode()) {
            case 401 -> "Authentication failed. Check API key.";
            case 429 -> "Rate limit exceeded. Please try again later.";
            case 400 -> "Invalid request: " + e.getMessage();
            case 404 -> "Resource not found.";
            default -> "API error: " + e.getMessage();
        };
    }

    // NOTE(review): `log` is not declared here; assumes Lombok @Slf4j or a
    // similar logger field exists in the real class — confirm.
    private String handleServerError(HttpServerErrorException e) {
        log.error("OpenAI service error", e);
        return "Service temporarily unavailable. Please try again.";
    }

    private String handleUnexpectedError(Exception e) {
        log.error("Unexpected error", e);
        return "An unexpected error occurred.";
    }
}
import org.springframework.retry.support.RetryTemplate;
import org.springframework.retry.backoff.ExponentialBackOffPolicy;
import org.springframework.retry.policy.SimpleRetryPolicy;
/**
 * Retry wiring for OpenAI calls: exponential backoff starting at 1s,
 * doubling up to 10s, with at most three attempts overall.
 */
@Configuration
public class RetryConfig {

    @Bean
    public RetryTemplate retryTemplate() {
        var template = new RetryTemplate();

        var backoff = new ExponentialBackOffPolicy();
        backoff.setInitialInterval(1000);   // first wait: 1 second
        backoff.setMultiplier(2.0);         // 1s -> 2s -> 4s ...
        backoff.setMaxInterval(10000);      // capped at 10 seconds
        template.setBackOffPolicy(backoff);

        var attempts = new SimpleRetryPolicy();
        attempts.setMaxAttempts(3);
        template.setRetryPolicy(attempts);

        return template;
    }

    /** Chat model bean that retries transient failures via the template above. */
    @Bean
    public OpenAiChatModel chatModel(OpenAiApi openAiApi, RetryTemplate retryTemplate) {
        return OpenAiChatModel.builder()
                .openAiApi(openAiApi)
                .retryTemplate(retryTemplate)
                .build();
    }
}
import io.github.resilience4j.circuitbreaker.annotation.CircuitBreaker;
@Service
public class ResilientAiService {

    private final OpenAiChatModel chatModel;

    // Fix: the final field was never initialized, which does not compile.
    public ResilientAiService(OpenAiChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /**
     * Chat call guarded by the "openai" circuit breaker; when the breaker
     * is open (or the call fails), fallbackGenerate() is invoked instead.
     */
    @CircuitBreaker(name = "openai", fallbackMethod = "fallbackGenerate")
    public String generate(String prompt) {
        return chatModel.call(new Prompt(prompt))
                .getResult()
                .getOutput()
                .getContent();
    }

    /** Fallback — must match generate()'s signature plus a trailing exception. */
    // NOTE(review): `log` is not declared here; assumes Lombok @Slf4j or a
    // similar logger field exists in the real class — confirm.
    public String fallbackGenerate(String prompt, Exception e) {
        log.warn("Circuit breaker activated, using fallback", e);
        return "Service temporarily unavailable. Please try again later.";
    }
}
# application.properties
spring.ai.openai.api-key=${OPENAI_API_KEY}
spring.ai.openai.base-url=https://api.openai.com
spring.ai.openai.organization-id=${OPENAI_ORG_ID:}
spring.ai.openai.project-id=${OPENAI_PROJECT_ID:}
# Chat model configuration
spring.ai.openai.chat.enabled=true
spring.ai.openai.chat.options.model=gpt-4o
spring.ai.openai.chat.options.temperature=0.7
spring.ai.openai.chat.options.max-tokens=1000
# Embedding model configuration
spring.ai.openai.embedding.enabled=true
spring.ai.openai.embedding.options.model=text-embedding-3-small
@Configuration
public class AiConfig {

    /** Low-level OpenAI API client configured from the injected API key. */
    @Bean
    public OpenAiApi openAiApi(@Value("${openai.api.key}") String apiKey) {
        return OpenAiApi.builder()
                .apiKey(apiKey)
                .build();
    }

    /**
     * Chat model with default options (GPT-4o, temperature 0.7, 1000 max
     * tokens), retry support, and Micrometer observation instrumentation.
     */
    @Bean
    public OpenAiChatModel chatModel(
            OpenAiApi openAiApi,
            RetryTemplate retryTemplate,
            ObservationRegistry observationRegistry
    ) {
        var defaults = OpenAiChatOptions.builder()
                .model(OpenAiApi.ChatModel.GPT_4_O.getValue())
                .temperature(0.7)
                .maxTokens(1000)
                .build();
        return OpenAiChatModel.builder()
                .openAiApi(openAiApi)
                .defaultOptions(defaults)
                .retryTemplate(retryTemplate)
                .observationRegistry(observationRegistry)
                .build();
    }
}
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.core.instrument.MeterRegistry;
@Configuration
public class ObservabilityConfig {

    /**
     * Chat model wired to a Micrometer ObservationRegistry so that each
     * call is instrumented automatically.
     */
    @Bean
    public OpenAiChatModel observableChatModel(
            OpenAiApi openAiApi,
            ObservationRegistry observationRegistry
    ) {
        var builder = OpenAiChatModel.builder();
        builder.openAiApi(openAiApi);
        builder.observationRegistry(observationRegistry);
        return builder.build();
    }
}
// Metrics will be automatically collected:
// - gen_ai_client_token_usage
// - gen_ai_client_operation_duration
// - gen_ai_client_operation_count
import org.springframework.web.client.RestClient;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.apache.hc.client5.http.impl.classic.HttpClients;
import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManager;
@Configuration
public class PerformanceConfig {

    /**
     * OpenAI API client backed by a pooled Apache HttpClient 5:
     * up to 100 connections overall, 20 per route, 10s connect / 60s read.
     */
    @Bean
    public OpenAiApi optimizedOpenAiApi(@Value("${openai.api.key}") String apiKey) {
        // Pooling keeps TLS handshakes off the hot path.
        var pool = new PoolingHttpClientConnectionManager();
        pool.setMaxTotal(100);
        pool.setDefaultMaxPerRoute(20);

        var client = HttpClients.custom()
                .setConnectionManager(pool)
                .build();

        // NOTE(review): Duration-based setConnectTimeout/setReadTimeout exist
        // only in recent Spring Framework versions — confirm against the
        // project's Spring version.
        var factory = new HttpComponentsClientHttpRequestFactory(client);
        factory.setConnectTimeout(Duration.ofSeconds(10));
        factory.setReadTimeout(Duration.ofSeconds(60));

        var pooledRestClient = RestClient.builder()
                .requestFactory(factory)
                .build();

        return OpenAiApi.builder()
                .apiKey(apiKey)
                .restClient(pooledRestClient)
                .build();
    }
}
import org.springframework.cache.annotation.Cacheable;
@Service
public class CachedAiService {

    private final OpenAiEmbeddingModel embeddingModel;
    // Fix: chatModel was used below but never declared; both dependencies
    // are now declared and constructor-injected.
    private final OpenAiChatModel chatModel;

    public CachedAiService(OpenAiEmbeddingModel embeddingModel, OpenAiChatModel chatModel) {
        this.embeddingModel = embeddingModel;
        this.chatModel = chatModel;
    }

    /** Embedding keyed by the exact input text; repeat calls hit the cache. */
    @Cacheable(value = "embeddings", key = "#text")
    public float[] getEmbedding(String text) {
        return embeddingModel.embed(new Document(text));
    }

    /**
     * Cached chat completion — only useful for deterministic/templated
     * prompts, since identical prompts return the first cached answer.
     */
    @Cacheable(value = "chat-responses", key = "#prompt")
    public String getChatResponse(String prompt) {
        return chatModel.call(new Prompt(prompt))
                .getResult()
                .getOutput()
                .getContent();
    }
}
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
@Service
public class ParallelAiService {

    private final OpenAiChatModel chatModel;
    private final Executor taskExecutor;

    // Fix: the final fields were never initialized, which does not compile.
    public ParallelAiService(OpenAiChatModel chatModel, Executor taskExecutor) {
        this.chatModel = chatModel;
        this.taskExecutor = taskExecutor;
    }

    /**
     * Fans the prompts out over the injected executor and joins the results
     * in input order; join() rethrows any task failure as CompletionException.
     */
    public List<String> processInParallel(List<String> prompts) {
        List<CompletableFuture<String>> futures = prompts.stream()
                .map(prompt -> CompletableFuture.supplyAsync(
                        () -> chatModel.call(new Prompt(prompt))
                                .getResult()
                                .getOutput()
                                .getContent(),
                        taskExecutor))
                .toList();
        return futures.stream()
                .map(CompletableFuture::join)
                .toList();
    }
}
@Service
public class CostTrackingService {

    // Per-token prices: $0.01 / 1K prompt tokens, $0.03 / 1K completion tokens.
    private static final double PROMPT_TOKEN_PRICE = 0.00001; // $0.01 per 1K tokens
    private static final double COMPLETION_TOKEN_PRICE = 0.00003; // $0.03 per 1K tokens

    private final OpenAiChatModel chatModel;
    private final MeterRegistry meterRegistry;

    // Fix: the final fields were never initialized, which does not compile.
    public CostTrackingService(OpenAiChatModel chatModel, MeterRegistry meterRegistry) {
        this.chatModel = chatModel;
        this.meterRegistry = meterRegistry;
    }

    /**
     * Runs a chat completion, records prompt/completion token counters and
     * an estimated dollar cost in Micrometer, then returns the content.
     */
    public String generateWithCostTracking(String prompt) {
        var response = chatModel.call(new Prompt(prompt));
        var usage = response.getMetadata().getUsage();

        meterRegistry.counter("ai.tokens.prompt", "model", "gpt-4o")
                .increment(usage.getPromptTokens());
        meterRegistry.counter("ai.tokens.completion", "model", "gpt-4o")
                .increment(usage.getCompletionTokens());

        double promptCost = usage.getPromptTokens() * PROMPT_TOKEN_PRICE;
        double completionCost = usage.getCompletionTokens() * COMPLETION_TOKEN_PRICE;
        double totalCost = promptCost + completionCost;
        meterRegistry.counter("ai.cost.total", "model", "gpt-4o")
                .increment(totalCost);

        // NOTE(review): `log` is not declared here; assumes Lombok @Slf4j or a
        // similar logger field exists in the real class — confirm.
        log.info("Request cost: ${} (prompt: {}, completion: {})",
                totalCost, usage.getPromptTokens(), usage.getCompletionTokens());
        return response.getResult().getOutput().getContent();
    }
}
@Service
public class RateLimitMonitor {

    private final OpenAiApi openAiApi;

    // Fix: the final field was never initialized, which does not compile.
    public RateLimitMonitor(OpenAiApi openAiApi) {
        this.openAiApi = openAiApi;
    }

    /**
     * Performs a chat completion via the low-level API and logs OpenAI's
     * rate-limit response headers, warning when fewer than 10 requests
     * remain in the current window.
     */
    // NOTE(review): `log` is not declared here; assumes Lombok @Slf4j or a
    // similar logger field exists in the real class — confirm.
    public ChatResponse callWithRateLimitCheck(ChatCompletionRequest request) {
        var response = openAiApi.chatCompletionEntity(request);

        var headers = response.getHeaders();
        var requestsLimit = headers.getFirst("x-ratelimit-limit-requests");
        var requestsRemaining = headers.getFirst("x-ratelimit-remaining-requests");
        var tokensLimit = headers.getFirst("x-ratelimit-limit-tokens");
        var tokensRemaining = headers.getFirst("x-ratelimit-remaining-tokens");
        log.info("Rate limits - Requests: {}/{}, Tokens: {}/{}",
                requestsRemaining, requestsLimit, tokensRemaining, tokensLimit);

        // Fix: guard against an absent header — Integer.parseInt(null) would
        // throw NumberFormatException and mask the successful response.
        if (requestsRemaining != null && Integer.parseInt(requestsRemaining) < 10) {
            log.warn("Approaching request rate limit!");
        }
        return response.getBody();
    }
}
@Service
public class BudgetControlService {

    /** Hard daily cap: 1M tokens per day. */
    private static final long DAILY_TOKEN_LIMIT = 1_000_000;

    private final AtomicLong dailyTokens = new AtomicLong(0);

    // Fix: chatModel was used below but never declared; declared and injected.
    private final OpenAiChatModel chatModel;

    public BudgetControlService(OpenAiChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /**
     * Runs a completion only if the estimated usage fits the remaining daily
     * budget; returns empty when the budget would be exceeded.
     */
    // NOTE(review): `log` is not declared here; assumes Lombok @Slf4j or a
    // similar logger field exists in the real class — confirm.
    public Optional<String> generateWithBudgetCheck(String prompt) {
        var estimatedTokens = estimateTokens(prompt);
        if (dailyTokens.get() + estimatedTokens > DAILY_TOKEN_LIMIT) {
            log.warn("Daily token budget exceeded");
            return Optional.empty();
        }
        var response = chatModel.call(new Prompt(prompt));
        var usage = response.getMetadata().getUsage();
        // Record actual (not estimated) usage once the response arrives.
        dailyTokens.addAndGet(usage.getTotalTokens());
        return Optional.of(response.getResult().getOutput().getContent());
    }

    /** Rough estimate: 1 token ≈ 4 characters of English text. */
    private long estimateTokens(String text) {
        return text.length() / 4;
    }

    /** Resets the counter at midnight (server time); requires @EnableScheduling. */
    @Scheduled(cron = "0 0 0 * * *")
    public void resetDailyBudget() {
        dailyTokens.set(0);
        log.info("Daily token budget reset");
    }
}
// Good - Environment variable
@Value("${OPENAI_API_KEY}")
private String apiKey;
// Good - Secrets manager
// Fetches the OpenAI API key from an external secrets store at call time,
// so the key never lives in source control or plain configuration files.
@Service
public class SecretsService {
// NOTE(review): `secretsManager` is not declared in this snippet; the real
// class must inject a secrets-manager client field exposing a
// getSecret(String) method — confirm the concrete type.
public String getOpenAiApiKey() {
// Retrieve from AWS Secrets Manager, Azure Key Vault, etc.
return secretsManager.getSecret("openai-api-key");
}
}
// Bad - Hardcoded (NEVER do this)
private String apiKey = "sk-..."; // Security risk!
@Service
public class SecureAiService {

    private final OpenAiChatModel chatModel;

    // Fix: the final field was never initialized, which does not compile.
    public SecureAiService(OpenAiChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /**
     * Sanitizes and length-checks user input before sending it to the model.
     *
     * @throws IllegalArgumentException if the sanitized input exceeds 10,000 chars
     */
    public String generateSafely(String userInput) {
        String sanitized = sanitizeInput(userInput);
        if (sanitized.length() > 10000) {
            throw new IllegalArgumentException("Input too long");
        }
        return chatModel.call(new Prompt(sanitized))
                .getResult()
                .getOutput()
                .getContent();
    }

    /**
     * Strips angle brackets and surrounding whitespace.
     * NOTE(review): this is a minimal example, not complete prompt-injection
     * protection — pair it with moderation/allow-listing in production.
     */
    private String sanitizeInput(String input) {
        return input.replaceAll("[<>]", "")
                .trim();
    }
}
@Service
public class ModeratedChatService {

    private final OpenAiChatModel chatModel;
    private final OpenAiModerationModel moderationModel;

    // Fix: the final fields were never initialized, which does not compile.
    public ModeratedChatService(OpenAiChatModel chatModel,
                                OpenAiModerationModel moderationModel) {
        this.chatModel = chatModel;
        this.moderationModel = moderationModel;
    }

    /**
     * Moderates both the user input and the model output; rejects flagged
     * input with an exception and replaces flagged output with a refusal.
     *
     * @throws IllegalArgumentException if the input violates content policy
     */
    public String generateWithModeration(String userInput) {
        var inputModeration = moderationModel.call(new ModerationPrompt(userInput));
        if (inputModeration.getResult().isFlagged()) {
            throw new IllegalArgumentException("Input violates content policy");
        }

        var response = chatModel.call(new Prompt(userInput));
        String output = response.getResult().getOutput().getContent();

        var outputModeration = moderationModel.call(new ModerationPrompt(output));
        if (outputModeration.getResult().isFlagged()) {
            // NOTE(review): `log` is not declared here; assumes Lombok @Slf4j
            // or a similar logger field exists in the real class — confirm.
            log.warn("Generated content flagged by moderation");
            return "I cannot provide that information.";
        }
        return output;
    }
}
@Service
public class AuditedAiService {

    private final OpenAiChatModel chatModel;

    // Fix: the final field was never initialized, which does not compile.
    public AuditedAiService(OpenAiChatModel chatModel) {
        this.chatModel = chatModel;
    }

    /**
     * Wraps a chat call with audit logging: who asked, how long the call
     * took, and how many tokens it consumed. Only the prompt length is
     * logged, never the prompt text itself.
     */
    // NOTE(review): `log` is not declared here; assumes Lombok @Slf4j or a
    // similar logger field exists in the real class — confirm.
    public String generateWithAudit(String userId, String prompt) {
        log.info("AI request - User: {}, Prompt length: {}", userId, prompt.length());
        var startTime = System.currentTimeMillis();
        var response = chatModel.call(new Prompt(prompt));
        var duration = System.currentTimeMillis() - startTime;
        var usage = response.getMetadata().getUsage();
        log.info("AI response - User: {}, Duration: {}ms, Tokens: {}",
                userId, duration, usage.getTotalTokens());
        return response.getResult().getOutput().getContent();
    }
}
Install with Tessl CLI
npx tessl i tessl/maven-org-springframework-ai--spring-ai-openai@1.1.0