# Multimodal Machine Learning

Automated machine learning for heterogeneous data combining text, images, and tabular features. `MultiModalPredictor` supports classification, regression, object detection, named entity recognition, semantic matching, and feature extraction using pretrained foundation models.

## Capabilities

### MultiModalPredictor Class

Main predictor class for multimodal data. It detects the data modalities and task type automatically and needs minimal configuration.

```python { .api }
class MultiModalPredictor:
    def __init__(
        self,
        label: str = None,
        problem_type: str = None,
        query: str = None,
        response: str = None,
        match_label = None,
        presets: str = None,
        eval_metric = None,
        hyperparameters: dict = None,
        path: str = None,
        verbosity: int = 2,
        num_classes: int = None,
        classes: list = None,
        warn_if_exist: bool = True,
        enable_progress_bar: bool = None,
        pretrained: bool = True,
        validation_metric: str = None,
        sample_data_path: str = None,
        use_ensemble: bool = False,
        ensemble_size: int = 2,
        ensemble_mode: str = "one_shot"
    ):
        """
        Initialize MultiModalPredictor for automated multimodal machine learning.

        Parameters:
        - label: Name of target column to predict
        - problem_type: Problem type ('binary', 'multiclass', 'regression', 'object_detection',
                       'ner', 'text_similarity', 'image_similarity', 'image_text_similarity',
                       'feature_extraction', 'zero_shot_image_classification', 'few_shot_classification',
                       'semantic_segmentation')
        - query: Column name for query data in semantic matching tasks
        - response: Column name for response data in semantic matching tasks
        - match_label: Label value that marks a positive match in semantic matching
        - presets: Quality presets ('best_quality', 'high_quality', 'medium_quality')
        - eval_metric: Evaluation metric for model selection
        - hyperparameters: Custom hyperparameter configurations
        - path: Directory to save models and artifacts
        - verbosity: Logging verbosity level (0-4)
        - num_classes: Number of classes for object detection
        - classes: Class names for object detection
        - warn_if_exist: Whether to warn if save path exists
        - enable_progress_bar: Show training progress bars
        - pretrained: Use pretrained model weights
        - validation_metric: Metric for validation and early stopping
        - sample_data_path: Path to sample data used to infer num_classes or classes for object detection
        - use_ensemble: Enable ensemble learning
        - ensemble_size: Number of models in ensemble
        - ensemble_mode: Ensemble construction mode ('one_shot', 'sequential')
        """
```
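
The constructor takes different arguments depending on the task, so it can help to check them before building a predictor. The following is a minimal sketch: `build_predictor_kwargs` is a hypothetical helper, not part of the library, and the rule that matching tasks need both `query` and `response` is an assumption of this sketch rather than confirmed library behavior.

```python
# Problem types copied from the constructor docstring above.
SUPPORTED_PROBLEM_TYPES = {
    "binary", "multiclass", "regression", "object_detection", "ner",
    "text_similarity", "image_similarity", "image_text_similarity",
    "feature_extraction", "zero_shot_image_classification",
    "few_shot_classification", "semantic_segmentation",
}
MATCHING_PROBLEM_TYPES = {
    "text_similarity", "image_similarity", "image_text_similarity",
}


def build_predictor_kwargs(problem_type, label=None, query=None, response=None, **extra):
    """Hypothetical helper: validate arguments and return constructor kwargs."""
    if problem_type not in SUPPORTED_PROBLEM_TYPES:
        raise ValueError(f"Unsupported problem_type: {problem_type!r}")
    kwargs = {"problem_type": problem_type, **extra}
    if problem_type in MATCHING_PROBLEM_TYPES:
        # Assumption of this sketch: matching tasks name both columns.
        if query is None or response is None:
            raise ValueError("Matching tasks need both 'query' and 'response'")
        kwargs.update(query=query, response=response)
    if label is not None:
        kwargs["label"] = label
    return kwargs


kwargs = build_predictor_kwargs("multiclass", label="category", presets="high_quality")
print(kwargs)
```

The resulting dictionary could then be passed as `MultiModalPredictor(**kwargs)`.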

### Model Training

Train multimodal models on heterogeneous data with automatic preprocessing and model selection.

```python { .api }
def fit(
    self,
    train_data,
    presets: str = None,
    tuning_data = None,
    max_num_tuning_data: int = None,
    id_mappings: dict = None,
    time_limit: int = None,
    save_path: str = None,
    hyperparameters = None,
    column_types: dict = None,
    holdout_frac: float = None,
    teacher_predictor = None,
    seed: int = 0,
    standalone: bool = True,
    hyperparameter_tune_kwargs: dict = None,
    clean_ckpts: bool = True,
    predictions: list = None,
    labels = None,
    predictors: list = None
):
    """
    Fit MultiModalPredictor on multimodal training data.

    Parameters:
    - train_data: Training data (DataFrame with text, image, and tabular columns,
                  or a file path such as a COCO annotation file)
    - presets: Quality/speed presets
    - tuning_data: Validation data; if omitted, part of train_data is held out (see holdout_frac)
    - max_num_tuning_data: Maximum tuning samples for object detection
    - id_mappings: ID-to-content mappings for semantic matching
    - time_limit: Maximum training time in seconds
    - save_path: Directory to save models
    - hyperparameters: Custom hyperparameter configurations
    - column_types: Manual column type specifications
    - holdout_frac: Fraction of train_data held out for validation when tuning_data is not given
    - teacher_predictor: Teacher model for knowledge distillation
    - seed: Random seed for reproducibility
    - standalone: Save complete model for offline deployment
    - hyperparameter_tune_kwargs: HPO configuration
    - clean_ckpts: Clean intermediate checkpoints
    - predictions: Pre-computed predictions for ensemble
    - labels: Pre-computed labels for ensemble
    - predictors: Pre-trained predictors for ensemble

    Returns:
    MultiModalPredictor: Fitted predictor instance
    """
```
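
To make the roles of `holdout_frac` and `seed` concrete, here is a minimal sketch of a seeded holdout split. `holdout_split` is a hypothetical helper written for illustration; it is not the library's internal splitting logic, which may differ (for example by stratifying on the label).

```python
import random


def holdout_split(rows, holdout_frac=0.2, seed=0):
    """Shuffle rows deterministically and hold out a fraction for validation."""
    if not 0.0 < holdout_frac < 1.0:
        raise ValueError("holdout_frac must be strictly between 0 and 1")
    indices = list(range(len(rows)))
    # A dedicated Random instance keeps the split reproducible for a given seed.
    random.Random(seed).shuffle(indices)
    n_holdout = max(1, int(len(rows) * holdout_frac))
    holdout = [rows[i] for i in indices[:n_holdout]]
    train = [rows[i] for i in indices[n_holdout:]]
    return train, holdout


rows = [{"id": i} for i in range(10)]
train_rows, val_rows = holdout_split(rows, holdout_frac=0.2, seed=0)
print(len(train_rows), len(val_rows))
```

Calling the helper again with the same seed returns the same split, which is the property the `seed` parameter is meant to provide.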

### Prediction

Generate predictions for multimodal data across different task types.

```python { .api }
def predict(
    self,
    data,
    candidate_data = None,
    id_mappings: dict = None,
    as_pandas: bool = None,
    realtime: bool = False,
    save_results: bool = None,
    **kwargs
):
    """
    Generate predictions for multimodal data.

    Parameters:
    - data: Input data (DataFrame, dict, list, or file path)
    - candidate_data: Candidate data for semantic matching/retrieval
    - id_mappings: ID-to-content mappings
    - as_pandas: Return results as pandas DataFrame/Series
    - realtime: Use realtime inference optimization
    - save_results: Save prediction results to disk
    - **kwargs: Additional arguments (e.g., as_coco for object detection)

    Returns:
    Predictions in format appropriate for the task type
    """

def predict_proba(
    self,
    data,
    candidate_data = None,
    id_mappings: dict = None,
    as_pandas: bool = None,
    as_multiclass: bool = True,
    realtime: bool = False
):
    """
    Generate prediction probabilities for classification tasks.

    Parameters:
    - data: Input data
    - candidate_data: Candidate data for retrieval tasks
    - id_mappings: ID-to-content mappings
    - as_pandas: Return results as pandas DataFrame
    - as_multiclass: Return all class probabilities vs positive class only
    - realtime: Use realtime inference optimization

    Returns:
    Prediction probabilities as DataFrame or numpy array
    """
```
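
The difference between the two `as_multiclass` settings can be sketched with plain lists. This assumes, as the parameter description suggests, that `as_multiclass=False` yields only the positive-class probability for binary problems. The helper names and the probability values are hypothetical.

```python
def positive_class_column(proba_rows, class_labels, positive_class):
    """Reduce per-class probabilities to the positive-class probability."""
    col = class_labels.index(positive_class)
    return [row[col] for row in proba_rows]


def hard_labels(proba_rows, class_labels):
    """Pick the most probable class for each row."""
    return [
        class_labels[max(range(len(row)), key=row.__getitem__)]
        for row in proba_rows
    ]


class_labels = ["negative", "positive"]
# One row per sample, one column per class (illustrative values).
proba = [[0.9, 0.1], [0.2, 0.8], [0.4, 0.6]]

positive_only = positive_class_column(proba, class_labels, "positive")
labels = hard_labels(proba, class_labels)
print(positive_only, labels)
```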

### Feature Extraction

Extract embeddings and features from multimodal data for downstream tasks.

```python { .api }
def extract_embedding(
    self,
    data,
    id_mappings: dict = None,
    return_masks: bool = False,
    as_tensor: bool = False,
    as_pandas: bool = False,
    realtime: bool = False,
    signature: str = None
):
    """
    Extract feature embeddings from multimodal data.

    Parameters:
    - data: Input data (DataFrame, dict, or list)
    - id_mappings: ID-to-content mappings
    - return_masks: Also return masks that flag missing inputs
    - as_tensor: Return PyTorch tensors
    - as_pandas: Return pandas DataFrame
    - realtime: Use realtime inference optimization
    - signature: Signature type for semantic matching ('query' or 'response')

    Returns:
    Feature embeddings as numpy array, tensor, or DataFrame
    """
```
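
A common downstream use of embeddings is ranking candidates by cosine similarity to a query. The sketch below uses hand-written two-dimensional vectors in place of real model output, and the function names are hypothetical. It shows the similarity computation only, not how the library computes it internally.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # Treat zero vectors as dissimilar to everything.
    return dot / (norm_a * norm_b)


def rank_candidates(query_vec, candidate_vecs):
    """Return candidate indices sorted from most to least similar, plus scores."""
    scores = [cosine_similarity(query_vec, c) for c in candidate_vecs]
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order, scores


# Stand-ins for embeddings returned by extract_embedding.
query_embedding = [1.0, 0.0]
candidate_embeddings = [[0.0, 1.0], [2.0, 0.0], [1.0, 1.0]]

order, scores = rank_candidates(query_embedding, candidate_embeddings)
print(order)
```

Cosine similarity ignores vector length, so the second candidate ranks first even though its norm differs from the query's.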

### Model Evaluation

Evaluate multimodal model performance with task-specific metrics.

```python { .api }
def evaluate(
    self,
    data,
    query_data: list = None,
    response_data: list = None,
    id_mappings: dict = None,
    metrics: list = None,
    chunk_size: int = 1024,
    similarity_type: str = "cosine",
    cutoffs: list = [1, 5, 10],
    label: str = None,
    return_pred: bool = False,
    realtime: bool = False,
    eval_tool: str = None,
    predictions: list = None,
    labels = None
):
    """
    Evaluate multimodal model performance.

    Parameters:
    - data: Test data (DataFrame, dict, list, or annotation file path)
    - query_data: Query data for ranking evaluation
    - response_data: Response data for ranking evaluation
    - id_mappings: ID-to-content mappings
    - metrics: List of evaluation metrics
    - chunk_size: Batch size for similarity computation
    - similarity_type: Similarity function ('cosine', 'dot_prod')
    - cutoffs: Cutoff values for ranking metrics
    - label: Label column name
    - return_pred: Return individual predictions
    - realtime: Use realtime inference
    - eval_tool: Evaluation tool for object detection ('pycocotools', 'torchmetrics')
    - predictions: Pre-computed predictions
    - labels: Pre-computed labels

    Returns:
    dict: Evaluation metrics and optionally predictions
    """
```
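
The `cutoffs` parameter controls how many top-ranked results each ranking metric looks at. As an illustration, the sketch below computes recall at each cutoff for a single query. The function name and the `recall@k` key format are hypothetical, and the source does not state which ranking metrics the library actually reports.

```python
def recall_at_cutoffs(ranked_ids, relevant_ids, cutoffs=(1, 5, 10)):
    """Fraction of relevant items found in the top-k results, per cutoff."""
    relevant = set(relevant_ids)
    if not relevant:
        raise ValueError("relevant_ids must not be empty")
    results = {}
    for k in cutoffs:
        # Slicing past the end of the list is safe and returns what exists.
        hits = len(relevant.intersection(ranked_ids[:k]))
        results[f"recall@{k}"] = hits / len(relevant)
    return results


# Retrieved document IDs, best match first (illustrative data).
ranked = ["d3", "d7", "d1", "d9", "d4", "d2"]
relevant = ["d1", "d2"]

recall_scores = recall_at_cutoffs(ranked, relevant)
print(recall_scores)
```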

### Model Management

Save, load, and export multimodal models for deployment.

```python { .api }
def save(self, path: str, standalone: bool = True):
    """
    Save trained predictor to disk.

    Parameters:
    - path: Directory to save predictor
    - standalone: Save complete model for offline deployment
    """

@classmethod
def load(
    cls,
    path: str,
    resume: bool = False,
    verbosity: int = 3
):
    """
    Load saved predictor from disk.

    Parameters:
    - path: Directory containing saved predictor
    - resume: Resume training from checkpoint
    - verbosity: Logging verbosity level

    Returns:
    MultiModalPredictor: Loaded predictor instance
    """

def export_onnx(
    self,
    data,
    path: str = None,
    batch_size: int = None,
    verbose: bool = False,
    opset_version: int = 16,
    truncate_long_and_double: bool = False
):
    """
    Export model to ONNX format for deployment.

    Parameters:
    - data: Sample data for tracing
    - path: Export path (if None, returns bytes)
    - batch_size: Batch size for export
    - verbose: Verbose export logging
    - opset_version: ONNX opset version
    - truncate_long_and_double: Truncate precision for compatibility

    Returns:
    Export path or ONNX model bytes
    """

def optimize_for_inference(self, providers: list = None):
    """
    Optimize model for faster inference using ONNX runtime.

    Parameters:
    - providers: ONNX execution providers

    Returns:
    Optimized ONNX module for inference
    """
```
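
Because `load` is a classmethod while `save` is an instance method, a save-and-reload round trip goes through the class rather than the instance. The sketch below shows that call pattern. `StubPredictor` is a stand-in defined here so the example runs without the library installed; it only mimics the `save`/`load` signatures listed above and stores nothing but a label.

```python
import json
import os
import tempfile


class StubPredictor:
    """Stand-in with the same save/load shape as the API above (not AutoGluon)."""

    def __init__(self, label):
        self.label = label

    def save(self, path, standalone=True):
        os.makedirs(path, exist_ok=True)
        with open(os.path.join(path, "stub.json"), "w") as f:
            json.dump({"label": self.label, "standalone": standalone}, f)

    @classmethod
    def load(cls, path, resume=False, verbosity=3):
        with open(os.path.join(path, "stub.json")) as f:
            return cls(json.load(f)["label"])


def save_and_reload(predictor, path, standalone=True):
    """Save a predictor, then load a fresh instance through its class."""
    predictor.save(path, standalone=standalone)
    return type(predictor).load(path)


with tempfile.TemporaryDirectory() as tmp:
    restored = save_and_reload(StubPredictor("category"), os.path.join(tmp, "model"))
    print(restored.label)
```

With a real predictor the same helper would apply unchanged, since it relies only on `save(path, standalone=...)` and the `load(path)` classmethod.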

### Advanced Features

Advanced functionality for specialized use cases and model analysis.

```python { .api }
def fit_summary(self, verbosity: int = 0, show_plot: bool = False):
    """
    Display training summary and model information.

    Parameters:
    - verbosity: Detail level (0-4)
    - show_plot: Show training plots

    Returns:
    dict: Training summary information
    """

def list_supported_models(self, pretrained: bool = True):
    """
    List supported models for the current problem type.

    Parameters:
    - pretrained: Show only models with pretrained weights

    Returns:
    list: Available model names
    """

def dump_model(self, save_path: str = None):
    """
    Export model weights and configs to local directory.

    Parameters:
    - save_path: Directory to save model files
    """

def set_num_gpus(self, num_gpus: int):
    """
    Set number of GPUs for training/inference.

    Parameters:
    - num_gpus: Number of GPUs to use
    """
```

### Properties

Access model and training information through properties.

```python { .api }
@property
def problem_type(self) -> str:
    """Type of ML problem (classification, object_detection, etc.)"""

@property
def label(self) -> str:
    """Name of target label column"""

@property
def eval_metric(self) -> str:
    """Evaluation metric used for model selection"""

@property
def class_labels(self) -> list:
    """Original class label names for classification"""

@property
def positive_class(self):
    """Positive class label for binary classification"""

@property
def total_parameters(self) -> int:
    """Total number of model parameters"""

@property
def trainable_parameters(self) -> int:
    """Number of trainable model parameters"""

@property
def model_size(self) -> float:
    """Model size in megabytes"""
```
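
The parameter-count properties can be combined into a quick summary, for example to see how much of a model is frozen during fine-tuning. This is a minimal sketch: `summarize_parameters` is a hypothetical helper, and a `SimpleNamespace` with made-up counts stands in for a fitted predictor.

```python
from types import SimpleNamespace


def summarize_parameters(predictor):
    """Derive frozen count and trainable share from the two count properties."""
    total = predictor.total_parameters
    trainable = predictor.trainable_parameters
    return {
        "total": total,
        "trainable": trainable,
        "frozen": total - trainable,
        # Guard against division by zero for an empty model.
        "trainable_share": trainable / total if total else 0.0,
    }


# Stand-in object exposing the same property names (illustrative numbers).
fake_predictor = SimpleNamespace(total_parameters=1_000_000, trainable_parameters=250_000)

summary = summarize_parameters(fake_predictor)
print(summary)
```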

## Usage Examples

### Text and Image Classification

```python
from autogluon.multimodal import MultiModalPredictor
import pandas as pd

# Prepare multimodal dataset (toy data; real training needs many more rows)
data = pd.DataFrame({
    'image_path': ['img1.jpg', 'img2.jpg', 'img3.jpg'],
    'text_content': ['Product description 1', 'Product description 2', 'Product description 3'],
    'price': [10.5, 25.0, 15.5],
    'category': ['A', 'B', 'A']
})

# Train multimodal classifier
predictor = MultiModalPredictor(
    label='category',
    problem_type='multiclass',
    presets='high_quality'
)

predictor.fit(
    data,
    time_limit=3600,
    column_types={
        'image_path': 'image_path',
        'text_content': 'text',
        'price': 'numerical'
    }
)

# Placeholder test set; in practice, use held-out rows with the same feature columns
test_data = data.drop(columns=['category'])

# Make predictions
predictions = predictor.predict(test_data)
probabilities = predictor.predict_proba(test_data)

# Extract embeddings
embeddings = predictor.extract_embedding(test_data)
print(f"Embedding shape: {embeddings.shape}")
```

### Object Detection

```python
from autogluon.multimodal import MultiModalPredictor

# Object detection on image data
detector = MultiModalPredictor(
    problem_type='object_detection',
    presets='medium_quality',
    classes=['person', 'car', 'bicycle'],  # Target classes
    path='./detection_models'
)

# Train on COCO-format data
detector.fit(
    train_data='train_annotations.json',  # COCO format
    time_limit=7200
)

# Predict bounding boxes
detections = detector.predict(
    'test_images/',
    save_results=True,
    as_coco=True  # Return COCO format results
)

# Evaluate with COCO metrics
metrics = detector.evaluate(
    'test_annotations.json',
    eval_tool='pycocotools'
)
print(f"mAP: {metrics['map']:.3f}")
```
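
For readers unfamiliar with the COCO annotation layout referenced above, here is a minimal sketch of such a file built in Python. The file name, image size, and box coordinates are made up for illustration. The three top-level keys and the `[x, y, width, height]` box convention follow the COCO detection format.

```python
import json

coco = {
    "images": [
        {"id": 1, "file_name": "street_001.jpg", "width": 640, "height": 480},
    ],
    "annotations": [
        {
            "id": 1,
            "image_id": 1,      # Refers to an entry in "images".
            "category_id": 1,   # Refers to an entry in "categories".
            "bbox": [100.0, 120.0, 50.0, 80.0],  # x, y, width, height in pixels.
            "area": 4000.0,     # width * height of the box.
            "iscrowd": 0,
        },
    ],
    "categories": [
        {"id": 1, "name": "person"},
        {"id": 2, "name": "car"},
        {"id": 3, "name": "bicycle"},
    ],
}

# Check that every annotation points at a known image and category.
image_ids = {img["id"] for img in coco["images"]}
category_ids = {cat["id"] for cat in coco["categories"]}
for ann in coco["annotations"]:
    assert ann["image_id"] in image_ids
    assert ann["category_id"] in category_ids

coco_json = json.dumps(coco, indent=2)
print(coco_json[:60])
```

Writing `coco_json` to a file would produce an annotation file of the kind passed as `train_data` in the example above.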

### Semantic Text Matching

```python
from autogluon.multimodal import MultiModalPredictor

# qa_data is assumed to be a DataFrame with 'question', 'answer', and
# 'relevant' columns, where relevant == 1 marks a matching pair
matcher = MultiModalPredictor(
    problem_type='text_similarity',
    query='question',
    response='answer',
    label='relevant',   # Column holding the match labels
    match_label=1       # Label value that counts as a match
)

# Train on query-response pairs
matcher.fit(qa_data, time_limit=1800)

# Find similar documents
query_data = ['What is machine learning?']
candidate_data = [
    'ML is a subset of AI',
    'Python is a programming language',
    'Deep learning uses neural networks'
]

similarities = matcher.predict(
    data=query_data,
    candidate_data=candidate_data
)
print("Most similar:", candidate_data[similarities.argmax()])
```