Convert scikit-learn models to ONNX format for cross-platform inference and deployment
—
System for registering custom converters, shape calculators, and parsers that extend skl2onnx to new model types and third-party libraries. Each registered callback handles one stage of conversion: parsing the model, computing its output shapes, or emitting its ONNX operators. This lets the conversion process be customized without modifying skl2onnx itself.
Register custom conversion functions for new model types or override existing converters.
def update_registered_converter(model, alias=None, shape_fct=None,
                                convert_fct=None, overwrite=False,
                                parser=None, options=None):
    """Register a converter for *model*, or update an existing one.

    Reference stub: only the signature and contract are documented here.
    The usage examples import the working function via
    ``from skl2onnx import update_registered_converter``.

    Args:
        model: Model class (or its name as a ``str``) the converter handles.
        alias: Alias name for the model. Optional; defaults to the class
            name.
        shape_fct: Callable that calculates the model's output shapes.
        convert_fct: Callable that generates the ONNX operators for the
            model.
        overwrite: When ``True``, replace a converter that is already
            registered. Defaults to ``False``.
        parser: Optional custom parser callable for the model.
        options: Optional ``dict`` of default options for this converter.
    """


# Register custom parsing functions that extract conversion-relevant
# information from models.
def update_registered_parser(model, parser_fct=None, overwrite=False):
    """Register a parser for *model*, or update an existing one.

    Reference stub documenting the contract of
    ``skl2onnx.update_registered_parser``.

    Args:
        model: Model class the parser is registered for.
        parser_fct: Callable that extracts the model information needed
            for conversion.
        overwrite: When ``True``, replace a parser that is already
            registered. Defaults to ``False``.
    """


# Discover supported models and their aliases in the conversion system.
def supported_converters(from_sklearn=False):
    """List every model converter currently registered.

    Reference stub documenting the contract of
    ``skl2onnx.supported_converters``.

    Args:
        from_sklearn: When ``True``, report scikit-learn model names with
            the ``'Sklearn'`` prefix stripped.

    Returns:
        list: The supported model names / aliases.
    """
def get_model_alias(model_type):
    """Look up the alias under which *model_type* is registered.

    Reference stub documenting the contract of
    ``skl2onnx.get_model_alias``.

    Args:
        model_type: Model class whose alias is requested.

    Returns:
        str: The alias registered for ``model_type``.

    Raises:
        KeyError: If ``model_type`` has not been registered.
    """


# The library provides extensive support across all major scikit-learn
# model categories:
from skl2onnx import update_registered_converter
from skl2onnx.common.data_types import FloatTensorType, Int64TensorType
# Define custom model
class CustomModel:
    """Minimal placeholder estimator used to illustrate converter registration.

    ``coef_`` and ``intercept_`` start as ``None``. They are placeholders
    for fitted parameters, presumably filled in by a real ``fit``.
    """

    def __init__(self):
        self.coef_ = None
        self.intercept_ = None

    def fit(self, X, y):
        """Fit the model on ``X`` / ``y`` and return ``self``.

        Returning ``self`` follows the scikit-learn estimator convention and
        allows chaining such as ``CustomModel().fit(X, y).predict(X)``.
        """
        # Custom fitting logic
        return self

    def predict(self, X):
        """Predict targets for ``X`` (placeholder: currently returns ``None``)."""
        # Custom prediction logic
        return None
# Define shape calculator
def custom_shape_calculator(operator):
    """Calculate the output shape for ``CustomModel``.

    skl2onnx shape calculators declare their result by assigning the type
    of each output variable on ``operator``. This one declares a float
    tensor with the same shape as the first input.

    Args:
        operator: The operator whose outputs are being typed. Assumes it
            already has at least one input and one output variable.

    Returns:
        list: ``[('output', <FloatTensorType>)]``, the value the original
        example returned. It is kept so existing callers still receive it.
    """
    input_shape = operator.inputs[0].shape
    output_type = FloatTensorType(input_shape)
    # The declared output type must live on ``operator.outputs``; only
    # returning it does not declare anything to the converter framework.
    operator.outputs[0].type = output_type
    return [('output', output_type)]
# Define converter function
def custom_converter(scope, operator, container):
    """Emit the ONNX operators for ``CustomModel`` (placeholder).

    Receives the ``(scope, operator, container)`` arguments every converter
    is called with. This skeleton emits nothing and returns ``None``.
    """
    # Generate the ONNX operators for ``operator`` here.
    return None
# Register the converter: bind CustomModel to its shape calculator and
# converter under the alias 'CustomModel'.
update_registered_converter(
    CustomModel,
    convert_fct=custom_converter,
    shape_fct=custom_shape_calculator,
    alias='CustomModel',
)

from skl2onnx import update_registered_parser
def custom_parser(scope, model, inputs, custom_parsers=None):
    """Parse a ``CustomModel`` instance and create its operator (placeholder).

    Args:
        scope: Conversion scope passed in by the caller.
        model: The model instance to parse.
        inputs: Input variables for the model.
        custom_parsers: Optional extra parsers forwarded by the caller.
    """
    # Extract model information and create the operator here.
    return None
# Register the parser so CustomModel instances are parsed by custom_parser.
update_registered_parser(CustomModel, parser_fct=custom_parser)

from skl2onnx import supported_converters, get_model_alias
from sklearn.ensemble import RandomForestClassifier
# Get all supported converters
all_converters = supported_converters()
print(f"Total supported converters: {len(all_converters)}")
# Get sklearn model names without prefix
sklearn_models = supported_converters(from_sklearn=True)
print(f"Supported sklearn models: {len(sklearn_models)}")
# Get alias for specific model
alias = get_model_alias(RandomForestClassifier)
print(f"RandomForestClassifier alias: {alias}")


def advanced_custom_converter(scope, operator, container):
    """Advanced converter that reads per-model conversion options.

    Options are resolved through ``container``, not through the model.
    ``operator.raw_operator`` is presumably the model being converted (a
    ``CustomModel`` here), and ``CustomModel`` defines no ``get_options``
    method, so calling it on the model would raise ``AttributeError``.

    Args:
        scope: Conversion scope passed in by the caller.
        operator: The operator being converted.
        container: Conversion container; provides ``get_options``.
    """
    # Look up the options supplied for this model, falling back on the
    # defaults given here.
    options = container.get_options(
        operator.raw_operator, dict(custom_param='default_value'))
    custom_param = options.get('custom_param', 'default_value')
    # Generate ONNX operators based on ``custom_param``.
    return None
# Register with default options. CustomModel was already registered with
# custom_converter earlier, and ``overwrite`` defaults to False, so the
# replacement must be requested explicitly.
update_registered_converter(
    CustomModel,
    alias='AdvancedCustomModel',
    shape_fct=custom_shape_calculator,
    convert_fct=advanced_custom_converter,
    overwrite=True,
    options={'custom_param': 'optimized_value'},
)

# Callback signatures used on this page:
#   convert_fct: (scope, operator, container)
#   shape_fct:   (operator)
#   (name, type) for outputs   # NOTE(review): row label unclear in source table
#   parser:      (scope, model, inputs, custom_parsers=None)

# Install with Tessl CLI
npx tessl i tessl/pypi-skl2onnx