# Collections & Training

Browse curated collections of models, and create and manage the training jobs that produce custom models.

## Capabilities

### Collection Management

Retrieve and list curated collections of related models.

```python { .api }
class Collections:
    """Access point for retrieving and listing model collections."""

    def get(self, name: str) -> Collection:
        """
        Get a collection by name.

        Parameters:
        - name: Collection name in format "owner/name".
          NOTE(review): `Collection` also exposes a URL `slug`; confirm
          whether lookups expect the slug rather than "owner/name".

        Returns:
        Collection object with metadata and model listings
        """

    def list(self, **params) -> Page[Collection]:
        """
        List collections.

        Parameters:
        - **params: Additional keyword arguments (not documented here)

        Returns:
        A Page of Collection objects; the items are in `.results`
        """
```

### Collection Objects

Collections represent curated groups of related models.

```python { .api }
class Collection:
    """A curated group of related models."""

    name: str
    """The name of the collection."""

    slug: str
    """The URL slug of the collection."""

    description: Optional[str]
    """The description of the collection, if any."""

    # NOTE(review): confirm whether collections returned by
    # `Collections.list()` populate `models`, or only `Collections.get()`.
    models: List[Model]
    """List of models in the collection."""
```

### Training Management

Create, retrieve, list, and cancel training jobs that produce custom models.

```python { .api }
class Trainings:
    """Create, retrieve, list, and cancel training jobs."""

    def create(
        self,
        model: str,
        version: str,
        input: Dict[str, Any],
        *,
        destination: Optional[str] = None,
        webhook: Optional[str] = None,
        webhook_events_filter: Optional[List[str]] = None,
        **params
    ) -> Training:
        """
        Create a new training job.

        Parameters:
        - model: Base model name in format "owner/name"
        - version: Base model version ID
        - input: Training input parameters and datasets
        - destination: Model that receives the trained version, in
          format "owner/name".
          NOTE(review): the Replicate HTTP API documents `destination`
          as required and as an existing model owned by the caller;
          confirm any default before omitting it.
        - webhook: Webhook URL that receives training event notifications
        - webhook_events_filter: Events that trigger the webhook,
          e.g. ["completed"]
        - **params: Additional keyword arguments (not documented here)

        Returns:
        Training object to monitor training progress
        """

    def get(self, id: str) -> Training:
        """
        Get a training by ID.

        Parameters:
        - id: Training ID

        Returns:
        Training object with current status and details
        """

    def list(self, **params) -> Page[Training]:
        """
        List training jobs.

        Parameters:
        - **params: Additional keyword arguments (not documented here)

        Returns:
        A Page of Training objects; the items are in `.results`
        """

    def cancel(self, id: str) -> Training:
        """
        Cancel a running training job.

        Parameters:
        - id: Training ID

        Returns:
        Updated Training object with canceled status
        """
```

### Training Objects

A `Training` object describes one training job, including its status, input, logs, metrics, and output.

```python { .api }
class Training:
    """A single training job and its current state."""

    id: str
    """The unique ID of the training."""

    model: str
    """Base model identifier in format `owner/name`."""

    version: str
    """Base model version identifier."""

    status: Literal["starting", "processing", "succeeded", "failed", "canceled"]
    """The status of the training; "starting" and "processing" mean it is still running."""

    input: Optional[Dict[str, Any]]
    """The input parameters for training."""

    output: Optional[Dict[str, Any]]
    """The output of the training (trained model info), if available."""

    logs: Optional[str]
    """The logs of the training."""

    error: Optional[str]
    """The error encountered during training, if any."""

    metrics: Optional[Dict[str, Any]]
    """Training metrics and statistics."""

    created_at: Optional[str]
    """When the training was created."""

    started_at: Optional[str]
    """When the training was started."""

    completed_at: Optional[str]
    """When the training was completed, if finished."""

    urls: Dict[str, str]
    """URLs associated with the training (get, cancel)."""

    def wait(self, **params) -> "Training":
        """Wait for the training to complete and return the Training."""

    def cancel(self) -> "Training":
        """Cancel the training and return the updated Training."""

    def reload(self) -> "Training":
        """Reload the training from the API and return the Training."""
```

### Usage Examples

#### Explore Collections

```python
import replicate

# Get a specific collection
collection = replicate.collections.get("replicate/image-upscaling")

print(f"Collection: {collection.name}")
print(f"Description: {collection.description}")
print(f"Models: {len(collection.models)}")

# List models in collection
for model in collection.models:
    print(f"- {model.owner}/{model.name}: {model.description}")

# List collections (one page of results). Named `page` rather than
# `collections` to avoid shadowing the standard-library module name.
page = replicate.collections.list()
for item in page.results:
    # Collections returned by list() may not include their models, so don't
    # assume `item.models` is populated; use replicate.collections.get(...)
    # when you need a collection's models.
    print(f"{item.name} ({item.slug}): {item.description}")
```

#### Create Training Job

```python
import replicate

# Dataset location and hyperparameters handed to the trainer
training_input = {
    "input_images": "https://example.com/training-images.zip",
    "class_name": "my-custom-style",
    "num_train_epochs": 1000,
    "learning_rate": 1e-6,
    "resolution": 512,
    "batch_size": 1,
}

# Start the training job; the result is pushed to `destination`
training = replicate.trainings.create(
    model="stability-ai/stable-diffusion",
    version="db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
    input=training_input,
    destination="myusername/my-custom-model",
)

print(f"Training ID: {training.id}")
print(f"Status: {training.status}")
```

#### Monitor Training Progress

```python
import time

import replicate

# Get existing training
training = replicate.trainings.get("training-id-here")

# Poll while the training is still "starting" or "processing"
while training.status in ("starting", "processing"):
    print(f"Status: {training.status}")

    # Show the last non-empty log line. Stripping first means a trailing
    # newline (or its absence) does not change which line is shown, and a
    # single-line log is still reported.
    log_lines = (training.logs or "").strip().splitlines()
    print(f"Latest log: {log_lines[-1] if log_lines else 'No logs yet'}")

    time.sleep(30)  # Wait 30 seconds between polls
    training.reload()

print(f"Final status: {training.status}")

if training.status == "succeeded":
    print(f"Trained model: {training.output}")
elif training.status == "failed":
    print(f"Training failed: {training.error}")
elif training.status == "canceled":
    print("Training was canceled")
```

#### Training with Webhooks

```python
import replicate

# Endpoint that should be called for the "completed" event only
WEBHOOK_URL = "https://myapp.com/training-webhook"

training = replicate.trainings.create(
    model="stability-ai/stable-diffusion",
    version="db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf",
    input={
        "input_images": "https://example.com/dataset.zip",
        "class_name": "my-style",
        "num_train_epochs": 500,
    },
    destination="myusername/my-trained-model",
    webhook=WEBHOOK_URL,
    webhook_events_filter=["completed"],
)

print(f"Training started: {training.id}")
```
267
268
#### List and Filter Trainings
269
270
```python
import replicate

# List trainings. `list()` returns a paginated result: `.results` holds only
# the current page, so the filtering below covers that page, not every
# training on the account.
trainings = replicate.trainings.list()

# Keep the trainings on this page that finished successfully
successful_trainings = [t for t in trainings.results if t.status == "succeeded"]

print(f"Successful trainings (this page): {len(successful_trainings)}")

# Show details for up to five of them
for training in successful_trainings[:5]:
    print(f"ID: {training.id}")
    print(f"Base model: {training.model}")
    print(f"Created: {training.created_at}")
    # These are timestamps, not a computed duration
    print(f"Started: {training.started_at}")
    print(f"Completed: {training.completed_at}")
    if training.output:
        print(f"Result: {training.output}")
    print("---")
```

#### Cancel Training

```python
import replicate

# Statuses in which a training is still running and can be canceled
ACTIVE_STATUSES = ("starting", "processing")

# Look up the training to cancel
training = replicate.trainings.get("training-id-here")

if training.status not in ACTIVE_STATUSES:
    print(f"Training is {training.status}, cannot cancel")
else:
    training.cancel()
    print(f"Training {training.id} canceled")
```

#### Advanced Training Configuration

```python
import replicate

# Fine-tune a text-to-image model with custom parameters
training = replicate.trainings.create(
    model="stability-ai/stable-diffusion",
    version="latest-version-id",  # placeholder: use a real version ID
    input={
        # Dataset configuration
        "input_images": "https://example.com/my-dataset.zip",
        "class_name": "myobjclass",

        # Training hyperparameters
        "num_train_epochs": 2000,
        "learning_rate": 5e-6,
        "lr_scheduler": "constant",
        "lr_warmup_steps": 100,

        # Model configuration
        "resolution": 768,
        "train_batch_size": 2,
        "gradient_accumulation_steps": 1,
        "mixed_precision": "fp16",

        # Output configuration
        "save_sample_prompt": "a photo of myobjclass",
        "save_sample_negative_prompt": "blurry, low quality",
        "num_validation_images": 4,
    },
    destination="myusername/my-fine-tuned-model",
)

# Wait for training completion
training.wait()

if training.status == "succeeded":
    print("Training completed successfully!")

    # This example assumes the trainer reports the new model under "model".
    # Look it up once so a missing key is reported instead of raising
    # KeyError on a later line.
    new_model = (training.output or {}).get("model")
    if new_model is None:
        print(f"Training output has no 'model' entry: {training.output}")
    else:
        print(f"New model available at: {new_model}")

        # Test the trained model
        output = replicate.run(
            new_model,
            input={"prompt": "a photo of myobjclass in a forest"},
        )

        # Assumes the model returns a single file-like output with .read()
        with open("test_output.png", "wb") as f:
            f.write(output.read())
else:
    print(f"Training did not succeed ({training.status}): {training.error}")
```