Spring AI Retry is a utility library that provides retry mechanisms for AI API interactions, with comprehensive error handling and exception classification.
This document demonstrates comprehensive error handling patterns for Spring AI Retry.
import org.springframework.ai.retry.*;
import org.springframework.retry.support.RetryTemplate;
@Service
public class AiService {
private final RetryTemplate retryTemplate;
/**
 * Generates a completion for the prompt through the retry template and maps
 * the two AI exception families onto service-level exceptions.
 *
 * @param prompt text handed to the AI client
 * @return the value produced by the AI client
 * @throws ServiceUnavailableException when a TransientAiException escapes the retry template
 * @throws BadRequestException when a NonTransientAiException escapes the retry template
 */
public String generate(String prompt) {
    final String completion;
    try {
        completion = retryTemplate.execute(retryContext -> aiClient.generate(prompt));
    } catch (TransientAiException retriesExhausted) {
        // The transient failure outlived the retry template: report the service as down.
        logger.error("AI service unavailable after retries: {}", retriesExhausted.getMessage());
        throw new ServiceUnavailableException("AI service temporarily down", retriesExhausted);
    } catch (NonTransientAiException rejected) {
        // Non-transient failure: the request itself is treated as invalid.
        logger.error("Invalid AI request: {}", rejected.getMessage());
        throw new BadRequestException("Unable to process request", rejected);
    }
    return completion;
}
}

/**
 * Generates a completion, falling back to the last known good cached response
 * when the call ultimately fails with a TransientAiException.
 *
 * <p>RetryTemplate hands the recovery callback whatever exception ended the
 * attempts, including ones that were never retried, so the callback checks the
 * failure type itself and propagates everything that is not transient.
 *
 * @param prompt text handed to the AI client
 * @return the AI client's result, or the cached response after a transient failure
 */
public String generateWithFallback(String prompt) {
    return retryTemplate.execute(
        // Main operation
        context -> aiClient.generate(prompt),
        // Recovery callback - serves the cache only for TransientAiException
        recoveryContext -> {
            Throwable lastError = recoveryContext.getLastThrowable();
            if (lastError instanceof TransientAiException) {
                logger.warn("Using cached response after retry exhaustion");
                return cacheService.getLastKnownGood(prompt);
            }
            // Anything else (e.g. an invalid request) must not be masked by cached data.
            if (lastError instanceof RuntimeException) {
                throw (RuntimeException) lastError;
            }
            if (lastError instanceof Error) {
                throw (Error) lastError;
            }
            throw new IllegalStateException("AI call failed with a non-runtime exception", lastError);
        }
    );
}

/**
 * Generates a completion and reports the outcome as a Response: success from
 * the primary client, degraded when the simplified model answers after a
 * transient failure, or error for a non-transient failure.
 *
 * @param prompt text handed to the AI clients
 * @return a Response describing which path produced the answer
 */
public Response generateWithPartialFallback(String prompt) {
    try {
        String primaryAnswer = retryTemplate.execute(retryContext -> aiClient.generate(prompt));
        return Response.success(primaryAnswer);
    } catch (TransientAiException primaryDown) {
        // Degrade to the simplified model instead of failing the request.
        logger.warn("Primary AI service down, using simplified model");
        return Response.degraded(simpleModelClient.generate(prompt), "Using simplified model");
    } catch (NonTransientAiException rejected) {
        // Nothing to fall back to: surface the failure message to the caller.
        logger.error("Request failed validation: {}", rejected.getMessage());
        return Response.error(rejected.getMessage());
    }
}

/**
 * Generates a completion; when the retry template gives up on a
 * TransientAiException, either serves a cached response or fails, depending
 * on {@code allowCache}. Any other failure is propagated.
 *
 * @param prompt     text handed to the AI client
 * @param allowCache whether a cached response may stand in after a transient failure
 * @return the AI client's result or the cached response
 * @throws ServiceUnavailableException on a transient failure when caching is not allowed
 */
public String generateWithConditionalRecovery(String prompt, boolean allowCache) {
    return retryTemplate.execute(
        // Main operation
        context -> {
            // NOTE(review): the total of 10 is hard-coded here - confirm it matches the
            // retry template's configured maximum.
            logger.debug("Attempt {} of {}", context.getRetryCount() + 1, 10);
            return aiClient.generate(prompt);
        },
        // Conditional recovery
        recoveryContext -> {
            Throwable error = recoveryContext.getLastThrowable();
            if (error instanceof TransientAiException) {
                if (!allowCache) {
                    throw new ServiceUnavailableException(
                        "AI service unavailable and cache disabled", error);
                }
                logger.info("Using cached response");
                return cacheService.get(prompt);
            }
            // Non-transient - propagate. A blind (RuntimeException) cast would replace a
            // checked exception or Error with a ClassCastException and hide the real cause.
            if (error instanceof RuntimeException) {
                throw (RuntimeException) error;
            }
            if (error instanceof Error) {
                throw (Error) error;
            }
            throw new IllegalStateException("AI call failed with a non-runtime exception", error);
        }
    );
}

/**
 * Generates a completion while logging the attempt number, the previous
 * failure on retries, and each failed attempt, then rethrows the failure as a
 * transient or non-transient AI exception according to {@code isTransient}.
 *
 * @param prompt text handed to the AI client
 * @return the value produced by the AI client
 */
public String generateWithContextLogging(String prompt) {
    return retryTemplate.execute(retryContext -> {
        final int currentAttempt = retryContext.getRetryCount() + 1;
        final int totalAttempts = 10;
        logger.info("AI request attempt {}/{}", currentAttempt, totalAttempts);

        boolean isRetry = retryContext.getRetryCount() > 0;
        if (isRetry) {
            Throwable previousFailure = retryContext.getLastThrowable();
            logger.warn("Retrying after error: {}", previousFailure.getMessage());
        }

        try {
            return aiClient.generate(prompt);
        } catch (Exception failure) {
            logger.error("Attempt {} failed: {}", currentAttempt, failure.getMessage());
            // Rethrow under the exception type that matches the classification.
            if (!isTransient(failure)) {
                throw new NonTransientAiException("Permanent failure", failure);
            }
            throw new TransientAiException("Transient failure", failure);
        }
    });
}

/**
 * Calls the AI client inside the retry template, translating the client's
 * own exceptions into the transient / non-transient AI exception types.
 *
 * @param prompt text handed to the AI client
 * @return the value produced by the AI client
 */
public String callAiService(String prompt) {
    return retryTemplate.execute(retryContext -> {
        try {
            return aiClient.generate(prompt);
        } catch (ServiceUnavailableException unavailable) {
            // Outage: reported as transient.
            throw new TransientAiException("AI service unavailable", unavailable);
        } catch (RateLimitException rateLimited) {
            // Rate limiting: reported as transient.
            throw new TransientAiException("Rate limit exceeded", rateLimited);
        } catch (AuthenticationException authFailure) {
            // Bad credentials: reported as non-transient.
            throw new NonTransientAiException("Invalid API key", authFailure);
        } catch (ValidationException invalidRequest) {
            // Rejected request: reported as non-transient.
            throw new NonTransientAiException("Invalid request", invalidRequest);
        }
    });
}

/**
 * Maps arbitrary exceptions onto TransientAiException or
 * NonTransientAiException.
 */
public class ErrorClassifier {

    /** HTTP status for "Too Many Requests"; a rate limit is treated as transient. */
    private static final int HTTP_TOO_MANY_REQUESTS = 429;

    /**
     * Classifies {@code e} as a transient or non-transient failure.
     *
     * <p>Exceptions that are already one of the two AI exception types are
     * returned unchanged, so an existing classification is never overwritten by
     * the "unknown error" default.
     *
     * @param e the failure to classify
     * @return a TransientAiException or NonTransientAiException; either {@code e}
     *         itself or a new exception carrying {@code e} as its cause
     */
    public static RuntimeException classify(Exception e) {
        // Already classified - keep the original decision and message.
        if (e instanceof TransientAiException || e instanceof NonTransientAiException) {
            return (RuntimeException) e;
        }
        // Network errors - transient
        if (e instanceof SocketTimeoutException
                || e instanceof ConnectException
                || e instanceof UnknownHostException) {
            return new TransientAiException("Network error", e);
        }
        // HTTP errors
        if (e instanceof HttpStatusCodeException) {
            HttpStatusCodeException httpError = (HttpStatusCodeException) e;
            int statusCode = httpError.getStatusCode().value();
            if (statusCode >= 500) {
                return new TransientAiException("Server error: " + statusCode, e);
            }
            if (statusCode == HTTP_TOO_MANY_REQUESTS) {
                // Rate limiting is a 4xx status but is treated as transient.
                return new TransientAiException("Rate limited: " + statusCode, e);
            }
            return new NonTransientAiException("Client error: " + statusCode, e);
        }
        // Validation errors - non-transient
        if (e instanceof IllegalArgumentException || e instanceof JsonParseException) {
            return new NonTransientAiException("Validation error", e);
        }
        // Unknown errors - treat as transient by default
        return new TransientAiException("Unknown error", e);
    }
}
// Usage
/**
 * Generates a completion inside the retry template, routing every exception
 * from the AI client through {@code ErrorClassifier.classify} before it is
 * rethrown.
 *
 * @param prompt text handed to the AI client
 * @return the value produced by the AI client
 */
public String generate(String prompt) {
    return retryTemplate.execute(retryContext -> {
        String completion;
        try {
            completion = aiClient.generate(prompt);
        } catch (Exception failure) {
            throw ErrorClassifier.classify(failure);
        }
        return completion;
    });
}

/**
 * Generates a completion inside the retry template while consulting and
 * updating a circuit breaker around every attempt.
 *
 * @param prompt text handed to the AI client
 * @return the value produced by the AI client
 */
public String generateWithCircuitBreaker(String prompt) {
    return retryTemplate.execute(retryContext -> {
        // An open breaker is reported as a transient failure without calling the client.
        boolean breakerOpen = circuitBreaker.isOpen();
        if (breakerOpen) {
            throw new TransientAiException("Circuit breaker open - service recovering");
        }

        try {
            String completion = aiClient.generate(prompt);
            circuitBreaker.recordSuccess();
            return completion;
        } catch (Exception failure) {
            // Every failure is recorded on the breaker before it is classified.
            circuitBreaker.recordFailure();
            if (!isTransient(failure)) {
                throw new NonTransientAiException("Permanent failure", failure);
            }
            throw new TransientAiException("Transient failure", failure);
        }
    });
}import org.slf4j.MDC;
/**
 * Generates a completion while exposing the request id, prompt length and
 * attempt number through the logging MDC for the duration of the call.
 *
 * <p>Only the MDC keys written by this method are removed afterwards, so
 * entries placed in the MDC by callers are left in place.
 *
 * @param prompt    text handed to the AI client
 * @param requestId identifier stored in the MDC under {@code requestId}
 * @return the value produced by the AI client
 * @throws ServiceUnavailableException when a TransientAiException escapes the retry template
 * @throws BadRequestException when a NonTransientAiException escapes the retry template
 */
public String generateWithStructuredLogging(String prompt, String requestId) {
    try {
        // Populate the MDC inside the try so the finally block cleans up even if a put fails.
        MDC.put("requestId", requestId);
        MDC.put("promptLength", String.valueOf(prompt.length()));
        return retryTemplate.execute(context -> {
            MDC.put("attemptNumber", String.valueOf(context.getRetryCount() + 1));
            try {
                logger.info("Calling AI service");
                String result = aiClient.generate(prompt);
                logger.info("AI service call successful");
                return result;
            } catch (Exception e) {
                logger.error("AI service call failed", e);
                throw ErrorClassifier.classify(e);
            }
        });
    } catch (TransientAiException e) {
        logger.error("AI service unavailable after all retries");
        throw new ServiceUnavailableException("Service down", e);
    } catch (NonTransientAiException e) {
        logger.error("AI request failed validation");
        throw new BadRequestException("Invalid request", e);
    } finally {
        // Remove only our own keys; MDC.clear() would also drop context set by callers.
        MDC.remove("requestId");
        MDC.remove("promptLength");
        MDC.remove("attemptNumber");
    }
}

/**
 * Generates a completion while reporting per-attempt and overall outcome
 * metrics (elapsed milliseconds since the call started, attempt count).
 *
 * @param prompt text handed to the AI client
 * @return the value produced by the AI client
 * @throws ServiceUnavailableException when a TransientAiException escapes the retry template
 * @throws BadRequestException when a NonTransientAiException escapes the retry template
 */
public String generateWithMetrics(String prompt) {
    final long startedAt = System.currentTimeMillis();
    final AtomicInteger attempts = new AtomicInteger(0);
    try {
        return retryTemplate.execute(retryContext -> {
            final int attemptNo = attempts.incrementAndGet();
            try {
                String completion = aiClient.generate(prompt);
                metricsService.recordSuccess(System.currentTimeMillis() - startedAt, attemptNo);
                return completion;
            } catch (Exception failure) {
                metricsService.recordAttemptFailure(attemptNo);
                throw ErrorClassifier.classify(failure);
            }
        });
    } catch (TransientAiException exhausted) {
        metricsService.recordExhaustion(System.currentTimeMillis() - startedAt, attempts.get());
        throw new ServiceUnavailableException("Service down", exhausted);
    } catch (NonTransientAiException rejected) {
        metricsService.recordValidationFailure(System.currentTimeMillis() - startedAt);
        throw new BadRequestException("Invalid request", rejected);
    }
}

/**
 * Validates the prompt and the API key up front, then generates a completion
 * through the retry template. Validation failures are thrown before the
 * retry template is entered.
 *
 * @param prompt text handed to the AI client
 * @return the value produced by the AI client
 * @throws NonTransientAiException if the prompt is null, empty or too long, or the API key is rejected
 */
public String generateWithValidation(String prompt) {
    // Fail fast: none of these checks goes through the retry template.
    boolean promptMissing = (prompt == null) || prompt.isEmpty();
    if (promptMissing) {
        throw new NonTransientAiException("Prompt cannot be empty");
    }

    boolean promptTooLong = prompt.length() > MAX_PROMPT_LENGTH;
    if (promptTooLong) {
        throw new NonTransientAiException(
            "Prompt exceeds maximum length of " + MAX_PROMPT_LENGTH);
    }

    boolean apiKeyAccepted = apiKeyValidator.isValid(apiKey);
    if (!apiKeyAccepted) {
        throw new NonTransientAiException(
            "Invalid API key - check configuration");
    }

    // Only the actual API call is wrapped in the retry template.
    return retryTemplate.execute(retryContext -> aiClient.generate(prompt));
}

/**
 * Builds a ChatRequest carrying the prompt, caller metadata, attempt number,
 * a fresh request id and a timestamp on every attempt, and sends it through
 * the retry template.
 *
 * @param prompt   text placed in the request
 * @param metadata caller-supplied metadata attached to the request
 * @return the value produced by the AI client
 */
public String generateWithEnrichment(String prompt, Map<String, String> metadata) {
    return retryTemplate.execute(retryContext -> {
        final int attemptNo = retryContext.getRetryCount() + 1;

        // The request is rebuilt for each attempt, so id and timestamp are per attempt.
        ChatRequest enriched = ChatRequest.builder()
            .prompt(prompt)
            .metadata(metadata)
            .attemptNumber(attemptNo)
            .requestId(UUID.randomUUID().toString())
            .timestamp(Instant.now())
            .build();

        try {
            return aiClient.generate(enriched);
        } catch (Exception failure) {
            logger.error("Request failed: {}", enriched, failure);
            throw ErrorClassifier.classify(failure);
        }
    });
}import java.util.concurrent.CompletableFuture;
/**
 * Runs the retried AI call on {@code executorService} and returns a future
 * for the result. A TransientAiException is answered from the cache; any
 * other failure completes the future exceptionally with the original cause.
 *
 * @param prompt text handed to the AI client
 * @return a future yielding the AI client's result or the cached response
 */
public CompletableFuture<String> generateAsync(String prompt) {
    return CompletableFuture.supplyAsync(() -> {
        return retryTemplate.execute(context ->
            aiClient.generate(prompt)
        );
    }, executorService).exceptionally(e -> {
        // The failure may arrive wrapped in a CompletionException or directly; unwrap once
        // so the type check sees the real exception either way.
        Throwable cause = e;
        if (e instanceof java.util.concurrent.CompletionException && e.getCause() != null) {
            cause = e.getCause();
        }
        if (cause instanceof TransientAiException) {
            logger.error("Service unavailable", cause);
            return cacheService.getLastKnownGood(prompt);
        }
        logger.error("Request failed", cause);
        // Rethrow the real failure instead of burying it under extra wrapper layers.
        if (cause instanceof RuntimeException) {
            throw (RuntimeException) cause;
        }
        if (cause instanceof Error) {
            throw (Error) cause;
        }
        throw new java.util.concurrent.CompletionException(cause);
    });
}

/**
 * Generates completions for a batch of prompts. Prompts whose call fails with
 * a transient or non-transient AI exception are logged and left out, so the
 * returned list can be shorter than the input.
 *
 * @param prompts prompts to process
 * @return the non-null results, collected into a list
 */
public List<String> generateBatch(List<String> prompts) {
    // NOTE(review): parallelStream runs the calls on the common fork-join pool - confirm
    // that is acceptable for these calls.
    return prompts.parallelStream()
        .map(prompt -> {
            String completion = null; // stays null when the call fails
            try {
                completion = retryTemplate.execute(retryContext -> aiClient.generate(prompt));
            } catch (TransientAiException unavailable) {
                logger.warn("Failed to generate for prompt: {}", prompt, unavailable);
            } catch (NonTransientAiException rejected) {
                logger.error("Invalid prompt: {}", prompt, rejected);
            }
            return completion;
        })
        .filter(Objects::nonNull)
        .collect(Collectors.toList());
}

/**
 * Generates a completion in up to three stages: primary client, backup
 * client, then cache. A TransientAiException moves on to the next stage; any
 * other exception propagates immediately.
 *
 * @param prompt text handed to the AI clients and used as the cache key
 * @return the first result obtained from the primary client, backup client or cache
 * @throws ServiceUnavailableException when both clients fail transiently and nothing is cached;
 *         its cause is the backup failure, with the primary failure attached as suppressed
 */
public String generateMultiStage(String prompt) {
    TransientAiException lastFailure;

    // Stage 1: Try primary model with retry
    try {
        return retryTemplate.execute(context ->
            primaryAiClient.generate(prompt)
        );
    } catch (TransientAiException e) {
        // Log the exception itself so the reason for the failover is not lost.
        logger.warn("Primary AI service unavailable, trying backup", e);
        lastFailure = e;
    }

    // Stage 2: Try backup model with retry
    try {
        return retryTemplate.execute(context ->
            backupAiClient.generate(prompt)
        );
    } catch (TransientAiException e) {
        logger.warn("Backup AI service unavailable, using cache", e);
        if (e != lastFailure) {
            // Keep the primary failure reachable from the final exception.
            e.addSuppressed(lastFailure);
        }
        lastFailure = e;
    }

    // Stage 3: Fall back to cache
    String cached = cacheService.get(prompt);
    if (cached != null) {
        return cached;
    }
    throw new ServiceUnavailableException("All AI services unavailable", lastFailure);
}

/**
 * Generates a completion inside the retry template, using the HTTP status
 * and response body of a failed call to decide whether the failure is
 * reported as transient or non-transient.
 *
 * @param prompt text handed to the AI client
 * @return the value produced by the AI client
 */
public String generateWithConditionalRetry(String prompt) {
    return retryTemplate.execute(retryContext -> {
        try {
            return aiClient.generate(prompt);
        } catch (HttpStatusCodeException httpFailure) {
            final String body = httpFailure.getResponseBodyAsString();
            final int status = httpFailure.getStatusCode().value();

            if (status == 429) {
                // A body mentioning "retry_after" marks a temporary limit; otherwise quota.
                if (!body.contains("retry_after")) {
                    throw new NonTransientAiException("Quota exceeded", httpFailure);
                }
                throw new TransientAiException("Temporary rate limit", httpFailure);
            }

            if (status == 503) {
                // A body mentioning "maintenance" is reported as non-transient.
                if (!body.contains("maintenance")) {
                    throw new TransientAiException("Service temporarily down", httpFailure);
                }
                throw new NonTransientAiException("Service in maintenance", httpFailure);
            }

            // Every other status goes through the shared classifier.
            throw ErrorClassifier.classify(httpFailure);
        }
    });
}
Install with the Tessl CLI:
npx tessl i tessl/maven-org-springframework-ai--spring-ai-retry@1.1.0