Spring AI Retry is a utility library that provides retry mechanisms for AI API interactions, including error handling and classification of failures as transient or non-transient.
This document demonstrates real-world integration patterns for Spring AI Retry.
Simple service with auto-configured retry:
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.client.ResponseErrorHandler;
/**
 * Chat service that posts prompts to a remote endpoint and retries failed calls
 * through the injected {@link RetryTemplate}.
 */
@Service
public class AiChatService {

    private final RetryTemplate retryTemplate;
    private final RestTemplate restTemplate;

    /**
     * Creates the service with a private {@link RestTemplate} that uses the given error handler.
     *
     * @param retryTemplate        template that decides whether and when failed calls are retried
     * @param responseErrorHandler handler that turns HTTP error responses into exceptions
     */
    @Autowired
    public AiChatService(
            RetryTemplate retryTemplate,
            ResponseErrorHandler responseErrorHandler) {
        this.retryTemplate = retryTemplate;
        RestTemplate client = new RestTemplate();
        client.setErrorHandler(responseErrorHandler);
        this.restTemplate = client;
    }

    /**
     * Sends {@code prompt} to the chat endpoint; every attempt builds a fresh request object.
     *
     * @param prompt the user prompt
     * @return the response body, read as a String
     */
    public String chat(String prompt) {
        return retryTemplate.execute(context -> restTemplate.postForObject(
                "https://api.example.com/chat", new ChatRequest(prompt), String.class));
    }
}
Provide fallback behavior when all retries are exhausted:
import org.springframework.ai.retry.TransientAiException;
import org.springframework.retry.RecoveryCallback;
import org.springframework.retry.RetryCallback;
/**
 * AI service that falls back to a cached answer when transient failures persist
 * after the retry template gives up.
 */
@Service
public class ResilientAiService {

    private final RetryTemplate retryTemplate;
    private final CacheService cacheService;
    private final AiClient aiClient;

    /**
     * @param retryTemplate template controlling retries of the AI call
     * @param cacheService  source of last-known-good responses used as the fallback
     * @param aiClient      client that performs the actual AI call
     */
    public ResilientAiService(
            RetryTemplate retryTemplate,
            CacheService cacheService,
            AiClient aiClient) {
        this.retryTemplate = retryTemplate;
        this.cacheService = cacheService;
        this.aiClient = aiClient;
    }

    /**
     * Generates a response for {@code prompt}, returning the cached last-known-good
     * response if the final failure is a {@link TransientAiException}.
     *
     * <p>{@code logger} is assumed to be an SLF4J logger field omitted from this snippet.
     *
     * @param prompt the prompt to send
     * @return the AI response, or the cached response after transient failures
     * @throws RuntimeException the last failure, rethrown unchanged, when it was not transient
     */
    public String generateWithFallback(String prompt) {
        return retryTemplate.execute(
                // Main operation. Declaring RuntimeException (not Exception) keeps
                // execute(...) from throwing a checked exception this method cannot declare.
                (RetryCallback<String, RuntimeException>) context -> {
                    // getRetryCount() is the number of failed attempts so far; the maximum
                    // depends on the injected policy, so it is not hard-coded here.
                    logger.debug("Attempt {}", context.getRetryCount() + 1);
                    return aiClient.generate(prompt);
                },
                // Recovery callback - invoked once the template stops retrying; depending on
                // the retry policy it may also see failures that were never retried.
                (RecoveryCallback<String>) context -> {
                    Throwable lastException = context.getLastThrowable();
                    if (lastException instanceof TransientAiException) {
                        // Service is down - use cache
                        logger.warn("Using cached response after retry exhaustion");
                        return cacheService.getLastKnownGood(prompt);
                    }
                    // Non-transient error - propagate as-is. Check the type first: a blind
                    // (RuntimeException) cast would throw ClassCastException for an Error
                    // or a checked exception and hide the real failure.
                    if (lastException instanceof RuntimeException) {
                        throw (RuntimeException) lastException;
                    }
                    if (lastException instanceof Error) {
                        throw (Error) lastException;
                    }
                    throw new IllegalStateException("AI call failed", lastException);
                });
    }
}
Try multiple AI providers in sequence:
import org.springframework.ai.retry.TransientAiException;
import org.springframework.ai.retry.NonTransientAiException;
/**
 * AI service that tries each configured provider in order, with retries per provider,
 * until one succeeds.
 *
 * <p>{@code logger} is assumed to be an SLF4J logger field omitted from this snippet.
 */
@Service
public class MultiProviderAiService {

    private final List<AiProvider> providers;
    private final RetryTemplate retryTemplate;

    /**
     * @param providers     providers in failover order; copied, so later changes by the
     *                      caller do not affect this service
     * @param retryTemplate template applied to each provider's call
     */
    public MultiProviderAiService(
            List<AiProvider> providers,
            RetryTemplate retryTemplate) {
        this.providers = List.copyOf(providers);
        this.retryTemplate = retryTemplate;
    }

    /**
     * Returns the first successful provider response for {@code prompt}.
     *
     * <p>If the last provider fails transiently, a {@code ServiceUnavailableException} is
     * thrown. If it fails non-transiently, its {@link NonTransientAiException} is rethrown.
     * Either way, the failures of earlier providers are attached as suppressed exceptions
     * so they are not lost.
     *
     * @param prompt the prompt to send
     * @return the first successful response
     */
    public String generate(String prompt) {
        // Failures of providers already tried, kept for diagnostics on the final exception.
        List<RuntimeException> failures = new java.util.ArrayList<>();
        for (int i = 0; i < providers.size(); i++) {
            AiProvider provider = providers.get(i);
            boolean isLastProvider = (i == providers.size() - 1);
            try {
                return retryTemplate.execute(context -> provider.generate(prompt));
            } catch (TransientAiException e) {
                if (isLastProvider) {
                    // All providers exhausted
                    ServiceUnavailableException unavailable = new ServiceUnavailableException(
                            "All AI providers unavailable", e);
                    failures.forEach(unavailable::addSuppressed);
                    throw unavailable;
                }
                // Try next provider
                logger.warn("Provider {} failed, trying next",
                        provider.getName());
                failures.add(e);
            } catch (NonTransientAiException e) {
                // Configuration issue with this provider
                logger.error("Provider {} misconfigured: {}",
                        provider.getName(), e.getMessage());
                if (isLastProvider) {
                    failures.forEach(e::addSuppressed);
                    throw e;
                }
                failures.add(e);
            }
        }
        // Only reachable when no providers are configured.
        throw new ServiceUnavailableException("No providers available");
    }
}
Combine retry with circuit breaker pattern:
import org.springframework.ai.retry.TransientAiException;
import org.springframework.retry.support.RetryTemplate;
import io.github.resilience4j.circuitbreaker.CircuitBreaker;
import io.github.resilience4j.circuitbreaker.CircuitBreakerRegistry;
/**
 * AI service whose retry loop runs inside a Resilience4j circuit breaker named "aiService".
 */
@Service
public class CircuitBreakerAiService {

    private final RetryTemplate retryTemplate;
    private final CircuitBreaker circuitBreaker;
    private final AiClient aiClient;

    /**
     * @param retryTemplate          template that retries individual AI calls
     * @param circuitBreakerRegistry registry from which the "aiService" breaker is obtained
     * @param aiClient               client that performs the actual AI call
     */
    public CircuitBreakerAiService(
            RetryTemplate retryTemplate,
            CircuitBreakerRegistry circuitBreakerRegistry,
            AiClient aiClient) {
        this.retryTemplate = retryTemplate;
        this.circuitBreaker = circuitBreakerRegistry
                .circuitBreaker("aiService");
        this.aiClient = aiClient;
    }

    /**
     * Generates a response for {@code prompt}, with the circuit breaker wrapping the retries.
     *
     * <p>{@code executeSupplier} acquires permission from the breaker before running the
     * supplier. An OPEN circuit therefore rejects the call with Resilience4j's
     * {@code CallNotPermittedException} and never reaches the retry template. A separate
     * state check inside the supplier would be dead code, and in a race it would record a
     * spurious failure. The breaker sees one outcome per call: the final result of the
     * whole retry sequence.
     *
     * @param prompt the prompt to send
     * @return the AI response
     */
    public String generate(String prompt) {
        return circuitBreaker.executeSupplier(
                () -> retryTemplate.execute(context -> aiClient.generate(prompt)));
    }
}
Handle rate limits with dynamic backoff:
import org.springframework.ai.retry.TransientAiException;
import org.springframework.retry.RetryListener;
import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryCallback;
/**
 * AI service that adds an extra wait when a transient failure looks like a rate limit.
 */
@Service
public class RateLimitAwareService {

    private final RetryTemplate retryTemplate;

    /**
     * Stores the template and registers a rate-limit listener on it.
     *
     * <p>NOTE(review): {@code registerListener} mutates the injected template. If that template
     * is a shared singleton bean, the listener also affects every other user of it, and
     * another listener is added each time this service is constructed.
     *
     * @param retryTemplate template used for AI calls; modified in place
     */
    public RateLimitAwareService(RetryTemplate retryTemplate) {
        this.retryTemplate = retryTemplate;
        // Add listener to handle rate limits
        retryTemplate.registerListener(new RetryListener() {
            @Override
            public <T, E extends Throwable> void onError(
                    RetryContext context,
                    RetryCallback<T, E> callback,
                    Throwable throwable) {
                if (!(throwable instanceof TransientAiException)) {
                    return;
                }
                String message = throwable.getMessage();
                // Exceptions may have no message; calling contains() on null would throw a
                // NullPointerException from inside the retry listener.
                if (message == null || !isRateLimit(message)) {
                    return;
                }
                // Extra wait on top of whatever back-off the template itself applies.
                long backoffMs = extractRetryAfter(message);
                if (backoffMs > 0) {
                    try {
                        Thread.sleep(backoffMs);
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    }
                }
            }

            /** Heuristic rate-limit detection from the exception message (case-insensitive). */
            private boolean isRateLimit(String message) {
                return message.contains("429")
                        || message.toLowerCase(java.util.Locale.ROOT).contains("rate limit");
            }

            private long extractRetryAfter(String message) {
                // Placeholder: parse a Retry-After value here. 0 means "no extra wait".
                return 0;
            }
        });
    }

    /**
     * Generates a response for {@code prompt} using the rate-limit-aware template.
     * {@code aiClient} is assumed to be an injected field omitted from this snippet.
     *
     * @param prompt the prompt to send
     * @return the AI response
     */
    public String generate(String prompt) {
        return retryTemplate.execute(context ->
                aiClient.generate(prompt)
        );
    }
}
Integrate with Micrometer for observability:
import org.springframework.retry.support.RetryTemplate;
import org.springframework.retry.RetryListener;
import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryCallback;
import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Counter;
import io.micrometer.core.instrument.Timer;
@Configuration
public class MonitoredRetryConfiguration {
@Bean
public RetryTemplate monitoredRetryTemplate(MeterRegistry meterRegistry) {
RetryTemplate template = RetryUtils.DEFAULT_RETRY_TEMPLATE;
// Add metrics listener
template.registerListener(new RetryListener() {
private final Counter retryCounter = meterRegistry.counter(
"ai.retry.attempts", "type", "retry");
private final Counter successAfterRetryCounter = meterRegistry.counter(
"ai.retry.success", "type", "after_retry");
private final Counter exhaustedCounter = meterRegistry.counter(
"ai.retry.exhausted", "type", "failed");
@Override
public <T, E extends Throwable> void onError(
RetryContext context,
RetryCallback<T, E> callback,
Throwable throwable) {
retryCounter.increment();
// Tag by exception type
String exceptionType = throwable.getClass().getSimpleName();
meterRegistry.counter(
"ai.retry.by_exception",
"exception", exceptionType
).increment();
}
@Override
public <T, E extends Throwable> void close(
RetryContext context,
RetryCallback<T, E> callback,
Throwable throwable) {
if (throwable == null && context.getRetryCount() > 0) {
// Succeeded after retry
successAfterRetryCounter.increment();
} else if (throwable != null) {
// All retries exhausted
exhaustedCounter.increment();
}
}
});
return template;
}
}
/**
 * AI service that times each request with the Micrometer timer "ai.request.duration".
 *
 * <p>The timer wraps the whole {@code retryTemplate.execute} call, so the recorded
 * duration covers every attempt and any back-off between attempts.
 */
@Service
public class MonitoredAiService {

    private final RetryTemplate retryTemplate;
    private final Timer requestTimer;

    /**
     * @param retryTemplate template used to retry AI calls
     * @param meterRegistry registry that provides the request timer
     */
    public MonitoredAiService(
            RetryTemplate retryTemplate,
            MeterRegistry meterRegistry) {
        this.requestTimer = meterRegistry.timer("ai.request.duration");
        this.retryTemplate = retryTemplate;
    }

    /**
     * Generates a response for {@code prompt}, recording the total time taken.
     * {@code aiClient} is assumed to be an injected field omitted from this snippet.
     *
     * @param prompt the prompt to send
     * @return the AI response
     */
    public String generate(String prompt) {
        return requestTimer.record(() -> {
            String reply = retryTemplate.execute(context -> aiClient.generate(prompt));
            return reply;
        });
    }
}
Integrate with Spring WebFlux:
import org.springframework.web.reactive.function.client.WebClient;
import org.springframework.ai.retry.TransientAiException;
import org.springframework.ai.retry.NonTransientAiException;
import reactor.util.retry.Retry;
import java.time.Duration;
@Service
public class ReactiveAiService {
private final WebClient webClient;
public ReactiveAiService(WebClient.Builder webClientBuilder) {
this.webClient = webClientBuilder
.baseUrl("https://api.example.com")
.filter((request, next) -> next.exchange(request)
.flatMap(response -> {
if (response.statusCode().isError()) {
return response.bodyToMono(String.class)
.flatMap(body -> {
if (response.statusCode().is4xxClientError()) {
return Mono.error(new NonTransientAiException(
"HTTP " + response.statusCode() + ": " + body));
} else {
return Mono.error(new TransientAiException(
"HTTP " + response.statusCode() + ": " + body));
}
});
}
return Mono.just(response);
})
)
.build();
}
public Mono<String> generate(String prompt) {
return webClient
.post()
.uri("/chat")
.bodyValue(new ChatRequest(prompt))
.retrieve()
.bodyToMono(String.class)
.retryWhen(Retry.backoff(10, Duration.ofSeconds(2))
.maxBackoff(Duration.ofMinutes(3))
.filter(e -> e instanceof TransientAiException));
}
}Create a custom error handler for specific needs:
import org.springframework.ai.retry.TransientAiException;
import org.springframework.ai.retry.NonTransientAiException;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.web.client.ResponseErrorHandler;
import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
public class CustomAiErrorHandler implements ResponseErrorHandler {
@Override
public boolean hasError(ClientHttpResponse response) throws IOException {
return response.getStatusCode().isError();
}
@Override
public void handleError(ClientHttpResponse response) throws IOException {
String body = new String(
response.getBody().readAllBytes(),
StandardCharsets.UTF_8
);
int statusCode = response.getStatusCode().value();
// Custom logic: Treat 429 as transient
if (statusCode == 429) {
throw new TransientAiException(
"Rate limit exceeded (429): " + body);
}
// Custom logic: Treat 402 (Payment Required) as non-transient
if (statusCode == 402) {
throw new NonTransientAiException(
"Payment required (402): " + body);
}
// Default behavior
if (statusCode >= 400 && statusCode < 500) {
throw new NonTransientAiException(
"Client error (" + statusCode + "): " + body);
} else {
throw new TransientAiException(
"Server error (" + statusCode + "): " + body);
}
}
@Override
public void handleError(URI url, HttpMethod method, ClientHttpResponse response)
throws IOException {
handleError(response);
}
}
// Expose CustomAiErrorHandler as the application's ResponseErrorHandler bean.
@Configuration
public class CustomErrorHandlerConfig {

    /**
     * Provides the custom error handler.
     *
     * <p>NOTE(review): presumably this replaces Spring AI's auto-configured handler when that
     * one is conditional on a missing bean; verify against the auto-configuration in use.
     *
     * @return a new {@link CustomAiErrorHandler}
     */
    @Bean
    public ResponseErrorHandler responseErrorHandler() {
        ResponseErrorHandler handler = new CustomAiErrorHandler();
        return handler;
    }
}
Configure retry without Spring Boot auto-configuration:
import org.springframework.ai.retry.RetryUtils;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.RestTemplate;
/**
 * Wires retry support by hand, using the constants from {@link RetryUtils}.
 */
@Configuration
public class ManualRetryConfiguration {

    /**
     * Exposes Spring AI's default retry template as a bean.
     *
     * <p>NOTE(review): this returns the shared static {@code RetryUtils.DEFAULT_RETRY_TEMPLATE}
     * instance, so any listener registered on this bean is also registered on that constant.
     */
    @Bean
    public RetryTemplate aiRetryTemplate() {
        RetryTemplate shared = RetryUtils.DEFAULT_RETRY_TEMPLATE;
        return shared;
    }

    /** RestTemplate whose error responses are mapped by Spring AI's default error handler. */
    @Bean
    public RestTemplate aiRestTemplate() {
        RestTemplate client = new RestTemplate();
        client.setErrorHandler(RetryUtils.DEFAULT_RESPONSE_ERROR_HANDLER);
        return client;
    }

    /** Application service built from the two beans above. */
    @Bean
    public AiService aiService(RetryTemplate aiRetryTemplate, RestTemplate aiRestTemplate) {
        return new AiService(aiRetryTemplate, aiRestTemplate);
    }
}
Test retry behavior with mocked services:
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.retry.TransientAiException;
import org.springframework.ai.retry.NonTransientAiException;
import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Checks how {@code RetryUtils.SHORT_RETRY_TEMPLATE} treats transient and
 * non-transient failures from a mocked client.
 */
class AiServiceTest {

    /** Two transient failures followed by a success: the third attempt wins. */
    @Test
    void testRetryOnTransientFailure() {
        RetryTemplate template = RetryUtils.SHORT_RETRY_TEMPLATE;
        AiClient client = Mockito.mock(AiClient.class);
        AtomicInteger calls = new AtomicInteger();

        Mockito.when(client.generate(Mockito.any())).thenAnswer(invocation -> {
            if (calls.incrementAndGet() >= 3) {
                return "Success";
            }
            throw new TransientAiException("Temporary failure");
        });

        String outcome = template.execute(context -> client.generate("test"));

        assertEquals("Success", outcome);
        assertEquals(3, calls.get());
    }

    /** A non-transient failure is propagated after a single attempt. */
    @Test
    void testNoRetryOnNonTransientFailure() {
        RetryTemplate template = RetryUtils.SHORT_RETRY_TEMPLATE;
        AiClient client = Mockito.mock(AiClient.class);
        Mockito.when(client.generate(Mockito.any()))
                .thenThrow(new NonTransientAiException("Invalid API key"));

        assertThrows(NonTransientAiException.class,
                () -> template.execute(context -> client.generate("test")));

        // Exactly one invocation: the failure was not retried.
        Mockito.verify(client, Mockito.times(1)).generate(Mockito.any());
    }
}
Install with Tessl CLI
npx tessl i tessl/maven-org-springframework-ai--spring-ai-retry@1.1.0