Python client for Replicate
—
Comprehensive exception classes and error handling patterns for robust Replicate API integration.
The Replicate client provides a structured exception hierarchy for different types of errors.
class ReplicateException(Exception):
    """A base class for all Replicate exceptions.

    Catching this type handles every error class derived from it in one
    ``except`` clause.
    """
class ModelError(ReplicateException):
    """An error from user's code in a model."""

    # Quoted so the forward reference is not evaluated when the class body
    # runs; the Prediction type does not have to be imported here.
    prediction: "Prediction"

    def __init__(self, prediction: "Prediction") -> None:
        """
        Initialize with the failed prediction.

        Parameters:
        - prediction: The prediction that encountered an error
        """
        ...
class ReplicateError(ReplicateException):
    """
    An error from Replicate's API.

    This class represents a problem details response as defined in RFC 7807.
    Every field is optional because a response may omit any of them.
    """

    type: Optional[str]
    """A URI that identifies the error type."""

    title: Optional[str]
    """A short, human-readable summary of the error."""

    status: Optional[int]
    """The HTTP status code."""

    detail: Optional[str]
    """A human-readable explanation specific to this occurrence of the error."""

    instance: Optional[str]
    """A URI that identifies the specific occurrence of the error."""

    def __init__(
        self,
        type: Optional[str] = None,
        title: Optional[str] = None,
        status: Optional[int] = None,
        detail: Optional[str] = None,
        instance: Optional[str] = None,
    ) -> None: ...

    @classmethod
    def from_response(cls, response: "httpx.Response") -> "ReplicateError":
        """Create a ReplicateError from an HTTP response."""
        ...

    def to_dict(self) -> dict:
        """Get a dictionary representation of the error."""
        ...


# Specific exceptions for webhook validation and processing.
class WebhookValidationError(ValueError):
    """Base webhook validation error.

    Derives from ``ValueError``, so handlers written for ``ValueError`` also
    catch every webhook validation failure.
    """


class MissingWebhookHeaderError(WebhookValidationError):
    """Missing required webhook header."""


class InvalidSecretKeyError(WebhookValidationError):
    """Invalid webhook secret key."""


class MissingWebhookBodyError(WebhookValidationError):
    """Missing webhook request body."""


class InvalidTimestampError(WebhookValidationError):
    """Invalid or expired timestamp."""
class InvalidSignatureError(WebhookValidationError):
    """Invalid webhook signature."""


import replicate
from replicate.exceptions import ModelError, ReplicateError, ReplicateException
from replicate.webhook import (
WebhookValidationError,
MissingWebhookHeaderError,
InvalidSecretKeyError,
MissingWebhookBodyError,
InvalidTimestampError,
InvalidSignatureError
)
# Handlers are ordered from most to least specific: ModelError and
# ReplicateError both derive from ReplicateException, so the base-class
# handler has to come after them.
try:
    output = replicate.run(
        "stability-ai/stable-diffusion-3",
        input={"prompt": "An astronaut riding a rainbow unicorn"},
    )
    with open("output.png", "wb") as f:
        f.write(output.read())
except ModelError as e:
    # Model execution failed
    print(f"Model execution failed for prediction {e.prediction.id}")
    print(f"Error: {e.prediction.error}")
    print(f"Status: {e.prediction.status}")
    # Check logs for debugging
    if e.prediction.logs:
        print(f"Logs:\n{e.prediction.logs}")
except ReplicateError as e:
    # API error
    print(f"API Error: {e.title}")
    print(f"Status: {e.status}")
    print(f"Detail: {e.detail}")
    # Check error type for specific handling
    if e.status == 429:
        print("Rate limited - consider adding delays between requests")
    elif e.status == 402:
        print("Payment required - check your account balance")
    elif e.status == 404:
        print("Model not found - verify the model name and version")
except ReplicateException as e:
    # Generic Replicate error
    print(f"Replicate error: {e}")
except Exception as e:
    # Unexpected error
    print(f"Unexpected error: {e}")


import replicate
from replicate.exceptions import ModelError
try:
    # This might fail due to invalid input
    output = replicate.run(
        "some-model/that-might-fail",
        input={"invalid_parameter": "bad_value"},
    )
except ModelError as e:
    prediction = e.prediction
    print("Model Error Details:")
    print(f"Prediction ID: {prediction.id}")
    print(f"Status: {prediction.status}")
    print(f"Error: {prediction.error}")
    # Analyze logs for common issues
    if prediction.logs:
        # Lower-case once so the keyword checks are case-insensitive.
        logs = prediction.logs.lower()
        if "out of memory" in logs:
            print("Suggestion: Try reducing batch size or image resolution")
        elif "invalid input" in logs:
            print("Suggestion: Check input parameters against model schema")
        elif "timeout" in logs:
            print("Suggestion: Model might need more time, try again")
        # Print recent logs (only the last ten lines)
        log_lines = prediction.logs.split('\n')
        print("Recent logs:\n" + '\n'.join(log_lines[-10:]))
    # Check if prediction can be retried
    if prediction.status == "failed":
        print("Prediction failed permanently")
    else:
        print("Prediction might still be retryable")


import replicate
from replicate.exceptions import ReplicateError
import time
def robust_prediction(model_name, input_params, max_retries=3):
    """Create prediction with retry logic for transient errors.

    Rate limiting (429) is retried with exponential backoff, service
    unavailability (503) with a linearly growing delay, and unclassified API
    errors after a fixed two-second pause. Payment (402), not-found (404)
    and validation (422) errors are re-raised immediately.

    Parameters:
    - model_name: Model identifier passed to ``replicate.predictions.create``
    - input_params: Input passed to ``replicate.predictions.create``
    - max_retries: Maximum number of creation attempts

    Returns:
    Whatever ``replicate.predictions.create`` returns on the first success.

    Raises:
    - ReplicateError: For 402/404/422 responses, or when an unclassified API
      error occurs on the final attempt
    - Exception: When the attempts are exhausted by 429/503 responses (or
      ``max_retries`` is 0); the last API error, if any, is attached as the
      exception's ``__cause__``
    """
    last_error = None
    for attempt in range(max_retries):
        final_attempt = attempt == max_retries - 1
        try:
            return replicate.predictions.create(
                model=model_name,
                input=input_params,
            )
        except ReplicateError as e:
            last_error = e
            if e.status == 429:  # Rate limited
                reason = "Rate limited"
                wait_time = 2 ** attempt  # Exponential backoff
            elif e.status == 503:  # Service unavailable
                reason = "Service unavailable"
                wait_time = 5 * (attempt + 1)
            elif e.status == 402:  # Payment required
                print("Payment required - check account balance")
                raise
            elif e.status == 404:  # Not found
                print(f"Model not found: {model_name}")
                raise
            elif e.status == 422:  # Validation error
                print(f"Invalid input parameters: {e.detail}")
                print("Available parameters might have changed")
                raise
            else:
                # Other API errors
                print(f"API Error: {e.title} (Status: {e.status})")
                print(f"Detail: {e.detail}")
                if final_attempt:
                    raise
                time.sleep(2)
                continue
            if final_attempt:
                # No attempt follows, so sleeping would only delay the
                # failure report.
                print(f"{reason}, no retries left")
                break
            print(f"{reason}, waiting {wait_time} seconds...")
            time.sleep(wait_time)
    # Chain the last API error so the root cause is not lost.
    raise Exception(f"Failed after {max_retries} attempts") from last_error
# Usage
try:
    prediction = robust_prediction(
        "stability-ai/stable-diffusion-3",
        {"prompt": "a beautiful landscape"},
    )
    print(f"Prediction created: {prediction.id}")
except Exception as e:
    # Top-level boundary of the example: report whatever escaped the retries.
    print(f"Failed to create prediction: {e}")


import replicate
import json

from replicate.webhook import (
    InvalidSecretKeyError,
    InvalidSignatureError,
    InvalidTimestampError,
    MissingWebhookBodyError,
    MissingWebhookHeaderError,
    WebhookValidationError,
)
def handle_webhook_request(request):
    """Handle webhook with comprehensive error handling.

    Parameters:
    - request: Incoming request object; ``request.headers.get(name)`` and
      ``request.body`` are read (assumes the body is bytes or str, confirm
      against the web framework in use)

    Returns:
    A ``(response_body, status_code)`` tuple:
    - ``({"success": True}, 200)`` when the webhook validated and was handled
    - 400 for missing headers, a missing body, a bad timestamp, any other
      validation failure, or invalid JSON
    - 401 for an invalid signature
    - 500 for any other error
    """
    try:
        # Extract required headers. The names match the keys that
        # replicate.webhooks.validate receives below; the webhook id must be
        # the real one from the request or the signature cannot match.
        headers = {}
        for name in ("webhook-id", "webhook-timestamp", "webhook-signature"):
            value = request.headers.get(name)
            if not value:
                raise MissingWebhookHeaderError(f"Missing {name} header")
            headers[name] = value

        # Get request body
        body = request.body
        if not body:
            raise MissingWebhookBodyError("Empty request body")
        # Decode once and reuse the text for validation and JSON parsing.
        body_text = body.decode('utf-8') if isinstance(body, bytes) else body

        # Get webhook secret and validate
        secret_obj = replicate.webhooks.default.secret()
        replicate.webhooks.validate(
            headers=headers,
            body=body_text,
            secret=secret_obj,
            tolerance=300,
        )

        # Process webhook payload
        payload = json.loads(body_text)

        # Handle different event types
        status = payload.get('status')
        if status == 'succeeded':
            handle_success(payload)
        elif status == 'failed':
            handle_failure(payload)
        return {"success": True}, 200
    except MissingWebhookHeaderError as e:
        print(f"Missing webhook header: {e}")
        return {"error": "Missing required headers"}, 400
    except InvalidSignatureError as e:
        print(f"Invalid webhook signature: {e}")
        return {"error": "Invalid signature"}, 401
    except InvalidTimestampError as e:
        print(f"Invalid timestamp: {e}")
        return {"error": "Request too old"}, 400
    except WebhookValidationError as e:
        # Catches the remaining validation errors, e.g. a missing body.
        print(f"Webhook validation failed: {e}")
        return {"error": "Validation failed"}, 400
    except json.JSONDecodeError as e:
        # json is imported at module level: a function-local import inside
        # the try block would leave this name unbound for early failures.
        print(f"Invalid JSON payload: {e}")
        return {"error": "Invalid JSON"}, 400
    except Exception as e:
        print(f"Unexpected webhook error: {e}")
        return {"error": "Internal server error"}, 500
def handle_success(payload):
    """Handle successful prediction webhook.

    Prints the prediction id and then one line per output entry.

    Parameters:
    - payload: Decoded webhook JSON; the ``id`` and ``output`` keys are read
    """
    prediction_id = payload.get('id')
    output = payload.get('output')
    print(f"Prediction {prediction_id} succeeded")
    if not output:
        return
    # A lone string is treated as a single entry; iterating it directly
    # would print one line per character.
    entries = [output] if isinstance(output, str) else output
    # Process output files
    for i, url in enumerate(entries):
        print(f"Output {i}: {url}")
def handle_failure(payload):
    """Handle failed prediction webhook.

    Prints the prediction id with its error, followed by the logs when the
    payload carries any.

    Parameters:
    - payload: Decoded webhook JSON; the ``id``, ``error`` and ``logs`` keys
      are read
    """
    prediction_id = payload.get('id')
    error = payload.get('error')
    logs = payload.get('logs')
    print(f"Prediction {prediction_id} failed: {error}")
    if logs:
        # Log failure details for debugging
        print(f"Error logs:\n{logs}")


import replicate
from replicate.exceptions import ModelError, ReplicateError
import time
import logging
# Configure the root logger at INFO level so the retry messages logged with
# logger.info below are not filtered out.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ReplicateClient:
    """Wrapper class with error recovery and retry logic."""

    def __init__(self, max_retries=3, base_delay=1):
        """
        Store the retry configuration.

        Parameters:
        - max_retries: Maximum number of attempts made by ``run_with_retry``
        - base_delay: Base delay in seconds; backoff waits are
          ``base_delay * 2 ** attempt``
        """
        self.max_retries = max_retries
        self.base_delay = base_delay

    def run_with_retry(self, model, input_params, **kwargs):
        """Run model with automatic retry and error recovery.

        Parameters:
        - model: Model identifier passed to ``replicate.predictions.create``
        - input_params: Model input; replaced by a reduced copy after an
          out-of-memory model error
        - **kwargs: Forwarded unchanged to ``replicate.predictions.create``

        Returns:
        ``prediction.output`` of the first prediction that succeeds.

        Raises:
        - ModelError: When the model error is neither an out-of-memory nor a
          timeout error, or when the attempts are exhausted
        - ReplicateError: For API errors other than 429/500/502/503/504
        - Exception: The last unexpected error on the final attempt, or a
          "Failed after N attempts" error (with the last API error attached
          as ``__cause__``) when retryable API errors exhaust the attempts
        """
        last_error = None
        for attempt in range(self.max_retries):
            final_attempt = attempt == self.max_retries - 1
            # Exponential backoff shared by the model-error and API-error paths.
            delay = self.base_delay * (2 ** attempt)
            try:
                # Attempt to run model
                prediction = replicate.predictions.create(
                    model=model,
                    input=input_params,
                    **kwargs,
                )
                # Wait for completion
                prediction.wait()
                if prediction.status == "succeeded":
                    return prediction.output
                if prediction.status == "failed":
                    raise ModelError(prediction)
                raise Exception(f"Unexpected status: {prediction.status}")
            except ModelError as e:
                logger.error(
                    "Model error (attempt %d): %s", attempt + 1, e.prediction.error
                )
                # Check if error is retryable
                error_msg = e.prediction.error.lower() if e.prediction.error else ""
                if "out of memory" in error_msg:
                    # Try to reduce input complexity
                    input_params = self._reduce_complexity(input_params)
                    logger.info("Reduced input complexity for retry")
                elif "timeout" in error_msg:
                    # Just retry - might be transient
                    logger.info("Timeout error, retrying...")
                else:
                    # Non-retryable error
                    logger.error("Non-retryable model error")
                    raise
                if final_attempt:
                    raise
                logger.info("Waiting %s seconds before retry...", delay)
                time.sleep(delay)
            except ReplicateError as e:
                last_error = e
                logger.error("API error (attempt %d): %s", attempt + 1, e.title)
                if e.status == 429:  # Rate limited
                    reason = "Rate limited"
                elif e.status in (500, 502, 503, 504):  # Server errors
                    reason = "Server error"
                else:
                    # Non-retryable API error
                    raise
                if final_attempt:
                    # No attempt follows, so sleeping would only delay the
                    # failure report.
                    break
                logger.info("%s, waiting %s seconds...", reason, delay)
                time.sleep(delay)
            except Exception as e:
                logger.error("Unexpected error (attempt %d): %s", attempt + 1, e)
                if final_attempt:
                    raise
                time.sleep(self.base_delay)
        # Chain the last API error so the root cause is not lost.
        raise Exception(f"Failed after {self.max_retries} attempts") from last_error

    def _reduce_complexity(self, input_params):
        """Reduce input complexity to avoid memory issues.

        Works on a shallow copy, so the caller's dictionary is left untouched.

        Parameters:
        - input_params: Model input dictionary

        Returns:
        A copy with ``num_inference_steps`` capped at 20 and
        ``guidance_scale`` at 7.5 when present, and with ``width`` and
        ``height`` capped at 512 when both are present.
        """
        params = input_params.copy()
        # Reduce common parameters that might cause OOM
        if 'num_inference_steps' in params:
            params['num_inference_steps'] = min(params['num_inference_steps'], 20)
        if 'guidance_scale' in params:
            params['guidance_scale'] = min(params['guidance_scale'], 7.5)
        if 'width' in params and 'height' in params:
            # Reduce resolution
            params['width'] = min(params['width'], 512)
            params['height'] = min(params['height'], 512)
        return params
# Usage
client = ReplicateClient(max_retries=3)
try:
    output = client.run_with_retry(
        "stability-ai/stable-diffusion-3",
        {
            "prompt": "a beautiful landscape",
            "width": 1024,
            "height": 1024,
            "num_inference_steps": 50,
        },
    )
    # Save output
    # NOTE(review): this assumes the returned output is a file-like object
    # with read(); confirm what prediction.output holds for this model.
    with open("output.png", "wb") as f:
        f.write(output.read())
    print("Generation completed successfully")
except Exception as e:
    # Top-level boundary of the example: log whatever survived the retries.
    logger.error("Final failure: %s", e)


# Install with Tessl CLI
npx tessl i tessl/pypi-replicate