class FlowSpec:
    """
    Base class for defining flows. Subclass and add @step methods.

    Example:
        class MyFlow(FlowSpec):
            @step
            def start(self):
                self.next(self.end)

            @step
            def end(self):
                pass
    """

@step
"""
Marks a method as a workflow step.
Steps must call self.next() to specify the next step(s).
"""@step
def start(self):
self.next(self.end)@step
# Branching: fan out to parallel steps, then join
@step
def start(self):
    self.next(self.a, self.b)  # Both run in parallel

@step
def a(self):
    self.next(self.join)

@step
def b(self):
    self.next(self.join)

@step
def join(self, inputs):
    # inputs contains results from a and b
    self.merge_artifacts(inputs)
    self.next(self.end)
# Foreach: fan out over a list of items
@step
def start(self):
    self.items = [1, 2, 3]
    self.next(self.process, foreach='items')

@step
def process(self):
    self.result = self.input * 2  # self.input = current item
    self.next(self.join)

@step
def join(self, inputs):
    self.results = [i.result for i in inputs]
    self.next(self.end)
"""
Define flow parameter (command-line argument).
Args:
name (str): Parameter name
help (str, optional): Help text
default: Default value
type: Type (int, float, str, bool, JSONType, IncludeFile)
required (bool): Whether required (default False)
Example:
from metaflow import FlowSpec, Parameter, step
class MyFlow(FlowSpec):
alpha = Parameter('alpha', default=0.01, type=float)
config = Parameter('config', type=JSONType)
@step
def start(self):
print(f"Alpha: {self.alpha}")
self.next(self.end)
"""class JSONType:
"""Parse JSON string to Python object"""
class IncludeFile:
"""Load parameter value from file"""from metaflow import current

from metaflow import current

current.run_id # Run ID (e.g. '123')
current.step_name # Step name (e.g. 'train')
current.task_id # Task ID (e.g. '456')
current.retry_count # Retry attempt number
current.origin_run_id # Original run ID if resumed
current.is_production # True if deployed run
current.username # User running the flow
current.flow_name # Flow name
current.pathspec # Full path (FlowName/RunID/StepName/TaskID)
current.namespace # Current namespace
current.card # Access cards (see Cards section)
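
Example (a minimal sketch of using current inside a step):
    from metaflow import FlowSpec, current, step

    class LoggingFlow(FlowSpec):
        @step
        def start(self):
            # e.g. prints 'LoggingFlow/1688/start/8276'
            print(f"task {current.pathspec}, user {current.username}, "
                  f"retry {current.retry_count}")
            self.next(self.end)

        @step
        def end(self):
            pass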
@step
def join(self, inputs):
    """
    Args:
        inputs (list): Results of the parallel branches; each input
            exposes that branch's artifacts as attributes
    """
    # Access artifacts from each input
    for i in inputs:
        print(i.result)  # Access artifact

    # Merge common artifacts
    self.merge_artifacts(inputs)                      # Merge artifacts present in all inputs
    self.merge_artifacts(inputs, include=['a', 'b'])  # Merge only 'a' and 'b'
    self.merge_artifacts(inputs, exclude=['temp'])    # Merge all except 'temp'
    self.next(self.end)

@batch(cpu: int = None, gpu: int = None, memory: int = None, image: str = None,
       queue: str = None, iam_role: str = None, execution_role: str = None,
       shared_memory: int = None, max_swap: int = None,
       swappiness: int = None, host_volumes: list = None,
       efs_volumes: list = None, use_tmpfs: bool = None,
       tmpfs_tempdir: bool = None, tmpfs_size: int = None,
       tmpfs_path: str = None, inferentia: int = None,
       trainium: int = None, ephemeral_storage: int = None)
"""
Execute step on AWS Batch.
Args:
cpu (int): CPU units (1024 = 1 vCPU)
memory (int): Memory in MB
image (str): Docker image
queue (str): Batch queue name
gpu (int): Number of GPUs
Example:
@batch(cpu=4096, memory=16000, gpu=1, queue='gpu-queue')
@step
def train(self):
pass
"""@kubernetes(cpu: int = None, memory: int = None, image: str = None,
secrets: list = None, namespace: str = None,
service_account: str = None, persistent_volume_claims: dict = None,
node_selector: dict = None, tolerations: list = None)
"""
Execute step on Kubernetes.
Example:
@kubernetes(cpu=4, memory=8000, gpu=1)
@step
def train(self):
pass
"""@resources(cpu: int = None, memory: int = None, gpu: int = None)
"""
Specify resource requirements (works with @batch/@kubernetes).
Example:
@batch
@resources(cpu=8, memory=32000, gpu=2)
@step
def train(self):
pass
"""@retry(times: int = 3, minutes_between_retries: int = 2)
"""
Retry step on failure.
Example:
@retry(times=5, minutes_between_retries=1)
@step
def fetch_data(self):
# Retries up to 5 times with 1 min between attempts
pass
"""@catch(var: str = '_exception', print_exception: bool = True)
"""
Catch exceptions without failing flow.
Args:
var (str): Variable name to store exception (default '_exception')
print_exception (bool): Print exception traceback
Example:
@catch(var='training_error')
@step
def train(self):
if self.training_error:
print(f"Training failed: {self.training_error}")
pass
"""@timeout(seconds: int = None, minutes: int = None, hours: int = None)
"""
Set step execution timeout.
Example:
@timeout(hours=2)
@step
def long_running(self):
pass
"""@card(type: str = 'default', id: str = None, options: dict = None,
timeout: int = None, refresh_interval: int = None)
"""
Attach visualization card to step.
Args:
type (str): Card type ('default', 'blank', or custom)
id (str): Card identifier for multiple cards
refresh_interval (int): Real-time refresh interval (seconds)
Example:
@card(type='blank')
@step
def analyze(self):
from metaflow import current
from metaflow.cards import Markdown, Table
current.card.append(Markdown("# Results"))
current.card.append(Table([[1, 2], [3, 4]]))
self.next(self.end)
"""@conda(libraries: dict = None, python: str = None,
sources: list = None, disabled: bool = False)
"""
Specify Conda environment.
Args:
libraries (dict): {'package': 'version'}
python (str): Python version
disabled (bool): Disable conda for this step
Example:
@conda(libraries={'pandas': '1.3.0', 'scikit-learn': '0.24.2'})
@step
def train(self):
import pandas as pd
pass
"""@pypi(packages: dict = None, python: str = None)
"""
Specify PyPI packages.
Args:
packages (dict): {'package': 'version'}
Example:
@pypi(packages={'torch': '1.9.0', 'transformers': '4.18.0'})
@step
def train(self):
import torch
pass
"""@environment(vars: dict = None)
"""
Set environment variables for step.
Example:
@environment(vars={'AWS_DEFAULT_REGION': 'us-west-2'})
@step
def process(self):
pass
"""@schedule(hourly: bool = False, daily: bool = False,
weekly: bool = False, cron: str = None,
timezone: str = 'UTC')
"""
Schedule flow execution (requires Argo Workflows).
Example:
from metaflow import FlowSpec, schedule, step
@schedule(daily=True, timezone='America/New_York')
class DailyFlow(FlowSpec):
@step
def start(self):
self.next(self.end)
"""@project(name: str)
"""
Assign flow to project namespace.
Example:
@project(name='ml_models')
class TrainingFlow(FlowSpec):
pass
"""@trigger(event: str)
"""
Trigger flow on events (requires Argo Workflows).
Example:
@trigger(event='data_available')
class DataPipeline(FlowSpec):
pass
"""@trigger_on_finish(flow: str)
"""
Trigger when another flow completes.
Example:
@trigger_on_finish(flow='DataPrep')
class Training(FlowSpec):
pass
"""from metaflow import namespace, default_namespace, get_namespace

from metaflow import namespace, default_namespace, get_namespace

namespace('production') # Switch to namespace
default_namespace()     # Switch to user's default namespace
ns = get_namespace()    # Get current namespace

Example:
    from metaflow import Flow, namespace

    # Access production runs
    namespace('production')
    flow = Flow('TrainingFlow')
    run = flow.latest_run

from metaflow import FlowSpec, step, Parameter, batch, resources, retry, card, current
from metaflow import conda

class MLPipeline(FlowSpec):
    learning_rate = Parameter('lr', default=0.01, type=float)

    @step
    def start(self):
        self.data_files = ['file1.csv', 'file2.csv', 'file3.csv']
        self.next(self.load, foreach='data_files')

    @retry(times=3)
    @step
    def load(self):
        self.data = self.load_file(self.input)
        self.next(self.join_data)

    @step
    def join_data(self, inputs):
        self.all_data = [i.data for i in inputs]
        self.next(self.train)

    @batch(cpu=8192, memory=32000)
    @resources(gpu=2)
    @conda(libraries={'torch': '1.9.0', 'pandas': '1.3.0'})
    @card(type='blank')
    @step
    def train(self):
        from metaflow.cards import Markdown
        self.model = self.train_model(self.all_data, self.learning_rate)
        self.metrics = self.evaluate(self.model)
        current.card.append(Markdown(f"# Training Complete\n\nAccuracy: {self.metrics['acc']:.2%}"))
        self.next(self.end)

    @step
    def end(self):
        print(f"Pipeline complete. Metrics: {self.metrics}")

    # Helper stubs so the example runs end to end
    def load_file(self, path):
        return []

    def train_model(self, data, lr):
        return "model"

    def evaluate(self, model):
        return {'acc': 0.95}

if __name__ == '__main__':
    MLPipeline()
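
Typical invocations (assuming the file is saved as ml_pipeline.py):
    python ml_pipeline.py show          # print the flow graph
    python ml_pipeline.py run --lr 0.05 # run with a custom learning rate
    python ml_pipeline.py resume        # resume from the last failure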