0
# Model Fine-Tuning
1
2
Create and manage supervised fine-tuning jobs to customize models on your data (Vertex AI only). Fine-tuning allows you to adapt foundation models to your specific use case by training on your own datasets, improving model performance for domain-specific tasks.
3
4
## Capabilities
5
6
### Tune Model
7
8
Create a supervised fine-tuning job to customize a base model.
9
10
```python { .api }
11
class Tunings:
    """Synchronous tuning jobs API (Vertex AI only)."""

    def tune(
        self,
        *,
        base_model: str,
        training_dataset: TuningDataset,
        config: Optional[CreateTuningJobConfig] = None
    ) -> TuningJob:
        """
        Create a supervised fine-tuning job.

        Parameters:
            base_model (str): Base model to fine-tune (e.g., 'gemini-1.5-flash-002').
            training_dataset (TuningDataset): Training dataset specification,
                i.e. the GCS URI of a JSONL file of training examples.
            config (CreateTuningJobConfig, optional): Tuning configuration including:
                - tuned_model_display_name: Display name for the tuned model
                - epoch_count: Number of training epochs
                - learning_rate_multiplier: Learning rate multiplier
                - adapter_size: Adapter size for efficient tuning
                - validation_dataset: Optional validation dataset
                - labels: Optional job labels

        Returns:
            TuningJob: Created tuning job with name and initial status. Poll it
            with `get()` to follow its progress until it reaches a final state.

        Raises:
            ClientError: For client errors (4xx status codes)
            ServerError: For server errors (5xx status codes)
        """
        ...
43
44
class AsyncTunings:
    """Awaitable tuning jobs API (Vertex AI only)."""

    async def tune(
        self,
        *,
        base_model: str,
        training_dataset: TuningDataset,
        config: Optional[CreateTuningJobConfig] = None,
    ) -> TuningJob:
        """Awaitable form of the synchronous ``tune``.

        Takes the same keyword-only arguments and resolves to the created
        ``TuningJob``.
        """
        ...
56
```
57
58
**Usage Example:**
59
60
```python
61
import time

from google.genai import Client
from google.genai.types import (
    AdapterSize,
    CreateTuningJobConfig,
    JobState,
    TuningDataset,
)

client = Client(vertexai=True, project='PROJECT_ID', location='us-central1')

# Define training dataset
training_dataset = TuningDataset(
    gcs_uri='gs://my-bucket/training_data.jsonl'
)

# Configure tuning
config = CreateTuningJobConfig(
    tuned_model_display_name='my-custom-model',
    epoch_count=5,
    learning_rate_multiplier=1.0,
    adapter_size=AdapterSize.ADAPTER_SIZE_SIXTEEN,
    validation_dataset=TuningDataset(
        gcs_uri='gs://my-bucket/validation_data.jsonl'
    ),
)

# Start tuning
job = client.tunings.tune(
    base_model='gemini-1.5-flash-002',
    training_dataset=training_dataset,
    config=config,
)

print(f"Tuning job created: {job.name}")
print(f"State: {job.state}")

# Poll while the job is still in progress. Compare against JobState members
# (not raw strings) so the check holds regardless of how the enum is defined,
# and include QUEUED so a freshly queued job is not mistaken for a finished one.
active_states = (
    JobState.JOB_STATE_QUEUED,
    JobState.JOB_STATE_PENDING,
    JobState.JOB_STATE_RUNNING,
)
while job.state in active_states:
    time.sleep(60)
    job = client.tunings.get(name=job.name)
    print(f"State: {job.state}")

if job.state == JobState.JOB_STATE_SUCCEEDED:
    print(f"Tuned model: {job.tuned_model.model}")

    # Use tuned model (only available once the job has succeeded)
    response = client.models.generate_content(
        model=job.tuned_model.model,
        contents='Test the tuned model'
    )
    print(response.text)
else:
    print(f"Tuning did not succeed: {job.state}")
    if job.error:
        print(f"Error: {job.error.message}")
113
```
114
115
### Get Tuning Job
116
117
Retrieve information about a tuning job including status and progress.
118
119
```python { .api }
120
class Tunings:
    """Blocking tuning jobs API (Vertex AI only)."""

    def get(
        self,
        *,
        name: str,
        config: Optional[GetTuningJobConfig] = None,
    ) -> TuningJob:
        """
        Fetch the current record of one tuning job.

        Parameters:
            name (str): Full resource name of the job, shaped like
                'projects/*/locations/*/tuningJobs/*'.
            config (GetTuningJobConfig, optional): Request options.

        Returns:
            TuningJob: The job, including its state, progress, and tuned model.

        Raises:
            ClientError: For 4xx responses, e.g. 404 when the job does not exist
            ServerError: For 5xx responses
        """
        ...
144
145
class AsyncTunings:
    """Awaitable tuning jobs API (Vertex AI only)."""

    async def get(
        self,
        *,
        name: str,
        config: Optional[GetTuningJobConfig] = None,
    ) -> TuningJob:
        """Awaitable form of the synchronous ``get``; resolves to the ``TuningJob``."""
        ...
156
```
157
158
### Cancel Tuning Job
159
160
Cancel a running tuning job.
161
162
```python { .api }
163
class Tunings:
    """Blocking tuning jobs API (Vertex AI only)."""

    def cancel(self, *, name: str) -> None:
        """
        Stop a tuning job.

        Parameters:
            name (str): Full resource name of the job, shaped like
                'projects/*/locations/*/tuningJobs/*'.

        Raises:
            ClientError: For client errors (4xx status codes)
            ServerError: For server errors (5xx status codes)
        """
        ...
178
179
class AsyncTunings:
    """Awaitable tuning jobs API (Vertex AI only)."""

    async def cancel(self, *, name: str) -> None:
        """Awaitable form of the synchronous ``cancel``."""
        ...
185
```
186
187
### List Tuning Jobs
188
189
List all tuning jobs with optional filtering and pagination.
190
191
```python { .api }
192
class Tunings:
    """Blocking tuning jobs API (Vertex AI only)."""

    def list(
        self,
        *,
        config: Optional[ListTuningJobsConfig] = None,
    ) -> Union[Pager[TuningJob], Iterator[TuningJob]]:
        """
        Enumerate tuning jobs.

        Parameters:
            config (ListTuningJobsConfig, optional): Listing options:
                - page_size: How many jobs each page holds
                - page_token: Token selecting which page to fetch
                - filter: Expression restricting the listed jobs

        Returns:
            Union[Pager[TuningJob], Iterator[TuningJob]]: Paginated jobs;
            iterating it visits each job.

        Raises:
            ClientError: For client errors (4xx status codes)
            ServerError: For server errors (5xx status codes)
        """
        ...
217
218
class AsyncTunings:
    """Awaitable tuning jobs API (Vertex AI only)."""

    async def list(
        self,
        *,
        config: Optional[ListTuningJobsConfig] = None,
    ) -> Union[AsyncPager[TuningJob], AsyncIterator[TuningJob]]:
        """Awaitable form of the synchronous ``list``; resolves to an async pager."""
        ...
228
```
229
230
**Usage Example:**
231
232
```python
233
from google.genai import Client

client = Client(vertexai=True, project='PROJECT_ID', location='us-central1')

# Print every tuning job, plus the tuned model's resource name when present.
for tuning_job in client.tunings.list():
    print(f"{tuning_job.name}: {tuning_job.state}")
    tuned = tuning_job.tuned_model
    if tuned:
        print(f" Tuned model: {tuned.model}")
242
```
243
244
## Types
245
246
```python { .api }
247
from typing import Optional, Union, List, Iterator, AsyncIterator, Dict
248
from datetime import datetime
249
from enum import Enum
250
251
# Configuration types
252
class CreateTuningJobConfig:
    """
    Options accepted when creating a supervised tuning job.

    Every attribute defaults to None (unset); the "Default" values below are
    the ones this reference lists for an unset field.

    Attributes:
        tuned_model_display_name (str, optional): Human-readable name for the tuned model.
        epoch_count (int, optional): Number of training epochs (1-100). Default: 5.
        learning_rate_multiplier (float, optional): Learning rate multiplier
            (0.001-10.0). Default: 1.0.
        adapter_size (AdapterSize, optional): Adapter size for efficient tuning.
            Default: ADAPTER_SIZE_ONE.
        validation_dataset (TuningDataset, optional): Held-out validation dataset.
        labels (dict[str, str], optional): Key/value labels attached to the job.
    """
    # Naming and training schedule
    tuned_model_display_name: Optional[str] = None
    epoch_count: Optional[int] = None
    learning_rate_multiplier: Optional[float] = None
    # Adapter and data
    adapter_size: Optional[AdapterSize] = None
    validation_dataset: Optional[TuningDataset] = None
    # Bookkeeping
    labels: Optional[dict[str, str]] = None
272
273
class GetTuningJobConfig:
    """Options for fetching a single tuning job; it currently defines no fields."""
276
277
class ListTuningJobsConfig:
    """
    Options for listing tuning jobs; every field defaults to None (unset).

    Attributes:
        page_size (int, optional): How many jobs each page holds.
        page_token (str, optional): Token selecting which page to fetch.
        filter (str, optional): Expression restricting the listed jobs.
    """
    page_size: Optional[int] = None  # jobs per page
    page_token: Optional[str] = None  # pagination cursor
    filter: Optional[str] = None  # filter expression
289
290
class TuningDataset:
    """
    Location of a dataset used for training or validation.

    Attributes:
        gcs_uri (str, optional): Cloud Storage URI of a JSONL file, for example
            'gs://bucket/data.jsonl'. Each line should be a JSON object with a
            'contents' field. Defaults to None.
    """
    gcs_uri: Optional[str] = None
299
300
# Response types
301
class TuningJob:
    """
    Record of a tuning job as returned by the tuning API.

    Attributes:
        name (str): Resource name of the job.
        base_model (str): Model being fine-tuned.
        tuned_model (TunedModel, optional): Resulting tuned model, if any.
        state (JobState): Lifecycle state of the job.
        create_time (datetime): Creation timestamp.
        start_time (datetime, optional): Time the job started.
        end_time (datetime, optional): Time the job completed.
        update_time (datetime): Time of the most recent update.
        labels (dict[str, str], optional): Labels attached to the job.
        tuning_data_stats (TuningDataStats, optional): Training data statistics.
        error (JobError, optional): Error details if the job failed.
    """
    # Identity and output
    name: str
    base_model: str
    tuned_model: Optional[TunedModel] = None
    # Lifecycle and timestamps
    state: JobState
    create_time: datetime
    start_time: Optional[datetime] = None
    end_time: Optional[datetime] = None
    update_time: datetime
    # Metadata and diagnostics
    labels: Optional[dict[str, str]] = None
    tuning_data_stats: Optional[TuningDataStats] = None
    error: Optional[JobError] = None
329
330
class TunedModel:
    """
    Description of the model produced by a tuning job.

    Attributes:
        model (str): Resource name to pass as ``model`` in generation requests.
        endpoint (str, optional): Endpoint serving the tuned model.
        display_name (str, optional): Human-readable name.
    """
    model: str  # required; no class-level default
    endpoint: Optional[str] = None
    display_name: Optional[str] = None
342
343
class TuningDataStats:
    """
    Size and billing statistics for the data used by a tuning job.

    Attributes:
        training_dataset_size (int, optional): Count of training examples.
        validation_dataset_size (int, optional): Count of validation examples.
        total_tuning_character_count (int, optional): Characters across the training data.
        total_billable_character_count (int, optional): Characters that are billed.
    """
    # Example counts
    training_dataset_size: Optional[int] = None
    validation_dataset_size: Optional[int] = None
    # Character counts
    total_tuning_character_count: Optional[int] = None
    total_billable_character_count: Optional[int] = None
357
358
class JobState(str, Enum):
    """
    Tuning job states.

    Derives from ``str`` so that members compare equal to their string values
    (e.g. ``JobState.JOB_STATE_RUNNING == 'JOB_STATE_RUNNING'``), which keeps
    code that checks ``job.state`` against plain state strings correct.
    """
    JOB_STATE_UNSPECIFIED = 'JOB_STATE_UNSPECIFIED'
    JOB_STATE_QUEUED = 'JOB_STATE_QUEUED'
    JOB_STATE_PENDING = 'JOB_STATE_PENDING'
    JOB_STATE_RUNNING = 'JOB_STATE_RUNNING'
    JOB_STATE_SUCCEEDED = 'JOB_STATE_SUCCEEDED'
    JOB_STATE_FAILED = 'JOB_STATE_FAILED'
    JOB_STATE_CANCELLING = 'JOB_STATE_CANCELLING'
    JOB_STATE_CANCELLED = 'JOB_STATE_CANCELLED'
    JOB_STATE_PAUSED = 'JOB_STATE_PAUSED'
369
370
class AdapterSize(str, Enum):
    """
    Adapter sizes for efficient tuning.

    Derives from ``str`` (like ``JobState``) so that members compare equal to
    their string values, e.g. ``AdapterSize.ADAPTER_SIZE_ONE == 'ADAPTER_SIZE_ONE'``.
    """
    ADAPTER_SIZE_UNSPECIFIED = 'ADAPTER_SIZE_UNSPECIFIED'
    ADAPTER_SIZE_ONE = 'ADAPTER_SIZE_ONE'
    ADAPTER_SIZE_FOUR = 'ADAPTER_SIZE_FOUR'
    ADAPTER_SIZE_EIGHT = 'ADAPTER_SIZE_EIGHT'
    ADAPTER_SIZE_SIXTEEN = 'ADAPTER_SIZE_SIXTEEN'
    ADAPTER_SIZE_THIRTY_TWO = 'ADAPTER_SIZE_THIRTY_TWO'
378
379
class JobError:
    """
    Failure information attached to a tuning job.

    Attributes:
        code (int): Numeric error code.
        message (str): Human-readable error message.
        details (list[dict], optional): Additional structured error details.
    """
    code: int  # required
    message: str  # required
    details: Optional[list[dict]] = None
391
392
# Pager types
393
class Pager[T]:
394
"""Synchronous pager."""
395
page: list[T]
396
def next_page(self) -> None: ...
397
def __iter__(self) -> Iterator[T]: ...
398
399
class AsyncPager[T]:
400
"""Asynchronous pager."""
401
page: list[T]
402
async def next_page(self) -> None: ...
403
async def __aiter__(self) -> AsyncIterator[T]: ...
404
```
405