Python client for Replicate
—
Management of model collections and training operations for custom model development and organization.
Organize and manage collections of related models.
class Collections:
    """Organize and manage collections of related models."""

    def get(self, name: str) -> Collection:
        """
        Get a collection by name.

        Parameters:
        - name: Collection name in format "owner/name"

        Returns:
        Collection object with metadata and model listings
        """

    def list(self, **params) -> Page[Collection]:
        """
        List collections.

        Parameters:
        - **params: Extra keyword arguments; their meaning is not described
          in this reference

        Returns:
        Paginated list of Collection objects
        """


# Collections represent curated groups of related models.
class Collection:
    """A curated group of related models."""

    name: str
    """The name of the collection."""

    slug: str
    """The URL slug of the collection."""

    description: Optional[str]
    """The description of the collection."""

    models: List[Model]
    """List of models in the collection."""


# Create and manage custom model training jobs.
class Trainings:
    """Create and manage custom model training jobs."""

    def create(
        self,
        model: str,
        version: str,
        input: Dict[str, Any],
        *,
        destination: Optional[str] = None,
        webhook: Optional[str] = None,
        webhook_events_filter: Optional[List[str]] = None,
        **params
    ) -> Training:
        """
        Create a new training job.

        Parameters:
        - model: Base model name in format "owner/name"
        - version: Base model version ID
        - input: Training input parameters and datasets
        - destination: Destination for trained model (defaults to your account)
        - webhook: Webhook URL for completion notification
        - webhook_events_filter: Events to trigger webhook
        - **params: Extra keyword arguments; their meaning is not described
          in this reference

        Returns:
        Training object to monitor training progress
        """

    def get(self, id: str) -> Training:
        """
        Get a training by ID.

        Parameters:
        - id: Training ID

        Returns:
        Training object with current status and details
        """

    def list(self, **params) -> Page[Training]:
        """
        List training jobs.

        Parameters:
        - **params: Extra keyword arguments; their meaning is not described
          in this reference

        Returns:
        Paginated list of Training objects
        """

    def cancel(self, id: str) -> Training:
        """
        Cancel a running training job.

        Parameters:
        - id: Training ID

        Returns:
        Updated Training object with canceled status
        """


# Training jobs represent custom model training with status, logs, and output models.
class Training:
    """A custom model training job with status, logs, and output models."""

    id: str
    """The unique ID of the training."""

    model: str
    """Base model identifier in format `owner/name`."""

    version: str
    """Base model version identifier."""

    status: Literal["starting", "processing", "succeeded", "failed", "canceled"]
    """The status of the training."""

    input: Optional[Dict[str, Any]]
    """The input parameters for training."""

    output: Optional[Dict[str, Any]]
    """The output of the training (trained model info)."""

    logs: Optional[str]
    """The logs of the training."""

    error: Optional[str]
    """The error encountered during training, if any."""

    metrics: Optional[Dict[str, Any]]
    """Training metrics and statistics."""

    created_at: Optional[str]
    """When the training was created."""

    started_at: Optional[str]
    """When the training was started."""

    completed_at: Optional[str]
    """When the training was completed, if finished."""

    urls: Dict[str, str]
    """URLs associated with the training (get, cancel)."""

    def wait(self, **params) -> "Training":
        """Wait for the training to complete."""

    def cancel(self) -> "Training":
        """Cancel the training."""

    def reload(self) -> "Training":
        """Reload the training from the API."""


import replicate
# Look up a single collection and summarise it
collection = replicate.collections.get("replicate/image-upscaling")
print(f"Collection: {collection.name}")
print(f"Description: {collection.description}")
print(f"Models: {len(collection.models)}")

# One bullet per model that belongs to the collection
for model in collection.models:
    print(f"- {model.owner}/{model.name}: {model.description}")

# Walk the collections returned by the listing call
collection_page = replicate.collections.list()
for collection in collection_page.results:
    print(f"{collection.name}: {len(collection.models)} models")


import replicate
# Inputs handed to the training job
training_input = {
    "input_images": "https://example.com/training-images.zip",
    "class_name": "my-custom-style",
    "num_train_epochs": 1000,
    "learning_rate": 1e-6,
    "resolution": 512,
    "batch_size": 1,
}

# Start the training job against a pinned base-model version
training = replicate.trainings.create(
    model="stability-ai/stable-diffusion",
    version="db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
    input=training_input,
    destination="myusername/my-custom-model",
)

print(f"Training ID: {training.id}")
print(f"Status: {training.status}")


import replicate
import time

# Seconds to pause between status checks
POLL_INTERVAL_SECONDS = 30

# Get existing training
training = replicate.trainings.get("training-id-here")

# Monitor progress while the training is still in flight
while training.status in ("starting", "processing"):
    print(f"Status: {training.status}")

    # Show the most recent log line. `logs` is optional and may or may not
    # end with a newline, so strip and split first, then take the last entry
    # rather than assuming the final element is an empty string.
    log_lines = (training.logs or "").strip().splitlines()
    print(f"Latest log: {log_lines[-1] if log_lines else 'No logs yet'}")

    time.sleep(POLL_INTERVAL_SECONDS)
    training.reload()

print(f"Final status: {training.status}")
if training.status == "succeeded":
    print(f"Trained model: {training.output}")
elif training.status == "failed":
    print(f"Training failed: {training.error}")


import replicate
# Inputs handed to the training job
training_input = {
    "input_images": "https://example.com/dataset.zip",
    "class_name": "my-style",
    "num_train_epochs": 500,
}

# Start a training that reports to a webhook, restricted to the
# "completed" event
training = replicate.trainings.create(
    model="stability-ai/stable-diffusion",
    version="db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
    input=training_input,
    destination="myusername/my-trained-model",
    webhook="https://myapp.com/training-webhook",
    webhook_events_filter=["completed"],
)

print(f"Training started: {training.id}")


import replicate
# Fetch the paginated training listing
trainings = replicate.trainings.list()

# Keep only the trainings on this page that finished successfully
successful_trainings = [job for job in trainings.results if job.status == "succeeded"]
print(f"Successful trainings: {len(successful_trainings)}")

# Print details for at most the first five successful trainings
for training in successful_trainings[:5]:
    print(f"ID: {training.id}")
    print(f"Base model: {training.model}")
    print(f"Created: {training.created_at}")
    print(f"Duration: {training.started_at} - {training.completed_at}")
    if training.output:
        print(f"Result: {training.output}")
    print("---")


import replicate
# Look up the training that should be stopped
training = replicate.trainings.get("training-id-here")

if training.status not in ("starting", "processing"):
    # The example only cancels trainings that are still in flight
    print(f"Training is {training.status}, cannot cancel")
else:
    training.cancel()
    print(f"Training {training.id} canceled")


import replicate
# Custom parameters for fine-tuning a text-to-image model
training_input = {
    # Dataset configuration
    "input_images": "https://example.com/my-dataset.zip",
    "class_name": "myobjclass",
    # Training hyperparameters
    "num_train_epochs": 2000,
    "learning_rate": 5e-6,
    "lr_scheduler": "constant",
    "lr_warmup_steps": 100,
    # Model configuration
    "resolution": 768,
    "train_batch_size": 2,
    "gradient_accumulation_steps": 1,
    "mixed_precision": "fp16",
    # Output configuration
    "save_sample_prompt": "a photo of myobjclass",
    "save_sample_negative_prompt": "blurry, low quality",
    "num_validation_images": 4,
}

training = replicate.trainings.create(
    model="stability-ai/stable-diffusion",
    version="latest-version-id",
    input=training_input,
    destination="myusername/my-fine-tuned-model",
)

# Block until the training has finished
training.wait()

if training.status == "succeeded":
    print("Training completed successfully!")
    print(f"New model available at: {training.output.get('model')}")

    # Try the trained model on a sample prompt and save the result to disk
    sample_input = {"prompt": "a photo of myobjclass in a forest"}
    output = replicate.run(training.output['model'], input=sample_input)
    with open("test_output.png", "wb") as f:
        f.write(output.read())


# Install with Tessl CLI
npx tessl i tessl/pypi-replicate