ML Collections is a library of Python collections designed for ML use cases.
—
Integration with absl.flags for loading configurations from files, defining command-line overrides, and managing parameterized experiment configurations. This module bridges ML Collections with command-line argument parsing, enabling flexible experiment management and configuration overrides.
Load configurations from Python files containing a get_config() function, enabling organized experiment configuration management.
def DEFINE_config_file(
name: str,
default: Optional[str] = None,
help_string: str = "path to config file.",
flag_values = FLAGS,
lock_config: bool = True,
accept_new_attributes: bool = False,
sys_argv: Optional[List[str]] = None,
**kwargs
):
"""
Define a flag that loads configuration from a Python file.
Args:
name (str): Name of the flag
default (str, optional): Default config file path
help_string (str): Help text for the flag
flag_values: FlagValues instance to register with (default: FLAGS)
lock_config (bool): Whether to lock config after loading (default: True)
accept_new_attributes (bool): Allow new attributes in overrides (default: False)
sys_argv (List[str], optional): Alternative sys.argv for parsing
**kwargs: Additional arguments passed to absl.flags
Returns:
Flag object for the configuration file
"""

Usage example:
from absl import app, flags
from ml_collections import config_flags
# Define config file flag
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file('config', default='configs/default.py')
def main(argv):
config = FLAGS.config
print(f"Model: {config.model}")
print(f"Learning rate: {config.learning_rate}")
if __name__ == '__main__':
app.run(main)

Example config file (configs/default.py):
from ml_collections import ConfigDict
def get_config():
config = ConfigDict()
config.model = 'resnet50'
config.learning_rate = 0.001
config.batch_size = 32
config.optimizer = ConfigDict()
config.optimizer.name = 'adam'
config.optimizer.beta1 = 0.9
return config

Define flags that accept ConfigDict objects directly, useful for programmatic configuration setup.
def DEFINE_config_dict(
name: str,
config: ConfigDict,
help_string: str = "ConfigDict instance.",
flag_values = FLAGS,
lock_config: bool = True,
accept_new_attributes: bool = False,
sys_argv: Optional[List[str]] = None,
**kwargs
):
"""
Define a flag that accepts a ConfigDict directly.
Args:
name (str): Name of the flag
config (ConfigDict): Default ConfigDict value
help_string (str): Help text for the flag
flag_values: FlagValues instance to register with (default: FLAGS)
lock_config (bool): Whether to lock config after loading (default: True)
accept_new_attributes (bool): Allow new attributes in overrides (default: False)
sys_argv (List[str], optional): Alternative sys.argv for parsing
**kwargs: Additional arguments passed to absl.flags
Returns:
Flag object for the ConfigDict
"""

Define flags for dataclass-based configurations, providing type-safe alternatives to dictionary-based configs.
def DEFINE_config_dataclass(
name: str,
config,
help_string: str = "Dataclass configuration.",
flag_values = FLAGS,
sys_argv: Optional[List[str]] = None,
parse_fn: Optional[Callable[[Any], Any]] = None,
**kwargs
):
"""
Define a flag for dataclass-based configurations.
Args:
name (str): Name of the flag
config: Default configuration, given as an instance of a dataclass (e.g. MyConfig(field1=1)), not the dataclass type itself
help_string (str): Help text for the flag
flag_values: FlagValues instance to register with (default: FLAGS)
sys_argv (List[str], optional): Alternative sys.argv for parsing
parse_fn (Callable, optional): Function that parses a value assigned to the whole flag (via flag.value or on the command line) into a configuration instance; by default only instances of the dataclass can be assigned
**kwargs: Additional arguments passed to absl.flags
Returns:
Flag object for the dataclass configuration
"""

Access metadata and override information from configuration flags.
def get_config_filename(config_flag) -> Optional[str]:
"""
Return the filename of a config file flag.
Args:
config_flag: The flag object itself (not its name), e.g. FLAGS['config']
Returns:
str or None: Path to the config file, or None if not set
"""
def get_override_values(config_flag) -> Dict[str, Any]:
"""
Return a flat dictionary of the override values applied from the command line.
Args:
config_flag: The flag object itself (not its name), e.g. FLAGS['config']
Returns:
dict: Flat mapping from each overridden field path to its override value
"""
def is_config_flag(flag) -> bool:
"""
Type checking utility for ConfigFlags.
Args:
flag: Flag object to check
Returns:
bool: True if flag is a config flag type
"""

Override nested configuration values directly from the command line using dot notation.
Command-line usage:
# Override nested values with dot notation
python train.py --config=configs/resnet.py \
--config.learning_rate=0.01 \
--config.model.num_layers=50 \
--config.optimizer.name=sgd
# Override with complex types
python train.py --config=configs/base.py \
--config.data.image_size="(224, 224)" \
--config.augmentation.enabled=True

Register custom parsers for specialized types in configuration overrides.
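def register_flag_parser_for_type(type_name, parser_fn):
"""
Register a parser for a specific type. Useful for types you cannot decorate, such as third-party types.
Args:
type_name: Type to register the parser for
parser_fn: An absl flags.ArgumentParser instance (not a plain function) whose parse(value) converts the command-line string into an instance of type_name
"""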
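def register_flag_parser(*, parser: flags.ArgumentParser):
"""
Decorator to register a custom flag parser for a type.
Apply it to the type (class) whose values should be parsed, not to a parsing function.
Args:
parser: ArgumentParser instance whose parse(value) converts a command-line string into an instance of the decorated type
Returns:
Decorator that registers parser for the decorated type and returns the type unchanged
"""

Usage example: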
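import dataclasses
from absl import flags
from ml_collections import config_flags

# The parser converts the command-line string into an instance of the type.
class PointParser(flags.ArgumentParser):
    def parse(self, argument):
        if isinstance(argument, Point):
            return argument
        x, y = (float(v) for v in argument.split(","))
        return Point(x=x, y=y)

# The decorator is applied to the type being parsed, not to a parse function.
@config_flags.register_flag_parser(parser=PointParser())
@dataclasses.dataclass
class Point:
    x: float = 0.0
    y: float = 0.0

# With config.origin = Point() in get_config(), override it from the command line:
# --config.origin="1.5,2.0"
# For types you cannot decorate (e.g. numpy arrays), call
# config_flags.register_flag_parser_for_type(np.ndarray, MyArrayParser()) instead.
# Parse such values with ast.literal_eval, never eval(): flag values come from the command line.

Pass parameters to config files for dynamic configuration generation.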
Config file with parameters (configs/parameterized.py):
from ml_collections import ConfigDict
def get_config(model_size='base'):
config = ConfigDict()
if model_size == 'small':
config.hidden_size = 256
config.num_layers = 6
elif model_size == 'base':
config.hidden_size = 512
config.num_layers = 12
elif model_size == 'large':
config.hidden_size = 1024
config.num_layers = 24
config.learning_rate = 0.001
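return config

Usage: the text after the first colon in the config path is passed to get_config() as a string argument.
# Pass a parameter to the config function
python train.py --config=configs/parameterized.py:large
# Without a colon, get_config() is called with no argument (model_size defaults to 'base')
python train.py --config=configs/parameterized.py
# Note: --config.model_size=large does not work here, because model_size is an
# argument of get_config(), not a field of the returned config.

ML Collections supports command-line overrides for the following types: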
- int, float, bool, str
- tuple, list (converted to tuples)
- enum.Enum values

Examples:
# Basic types
--config.learning_rate=0.01
--config.enabled=True
--config.model_name="resnet50"
# Tuples and nested structures
--config.image_size="(224, 224)"
--config.data.train_split=0.8
--config.optimizer.params.beta1=0.9
# Complex nested overrides
--config.model.layers.0.filters=64
--config.scheduler.milestones="[30, 60, 90]"

For dataclass configurations, ML Collections supports special values for optional fields:
- build / True / 1: Create a default instance of the dataclass
- none / False / 0: Set the field to None (for Optional fields)

Example:
# Create default instance
python train.py --config.model=build
# Set to None
python train.py --config.optional_field=none

ML Collections provides backward compatibility exports for legacy code:
from ml_collections.config_flags import GetValue, GetType, SetValue
# Legacy aliases (deprecated), defined in config_flags as:
# GetValue = config_path.get_value
# GetType = config_path.get_type
# SetValue = config_path.set_value

Config flags provide additional properties for introspection:
# Access config filename (including parameterization).
# Use FLAGS['config'] (the flag object); FLAGS.config is the loaded ConfigDict itself.
config_filename = FLAGS['config'].config_filename
# Get all override values as flat dictionary
override_values = FLAGS['config'].override_values

ML Collections supports enum fields with case-insensitive parsing:
from enum import Enum
class ModelType(Enum):
RESNET = "resnet"
EFFICIENTNET = "efficientnet"
# Command line: --config.model_type=RESNET or --config.model_type=resnet

Tuple fields support repeated flags for building sequences:
# Build tuple from multiple flags
python train.py --config.layers 64 --config.layers 128 --config.layers 256

Exception types:
class UnsupportedOperationError(Exception):
"""Raised for unsupported flag operations."""
class FlagOrderError(ValueError):
"""Raised when flags are accessed in wrong order."""
class UnparsedFlagError(ValueError):
"""Raised when flags haven't been parsed."""

Install with Tessl CLI
npx tessl i tessl/pypi-ml-collections