AI Registry

Manage datasets and evaluators in the SageMaker AI Registry Hub for model customization workflows.

Package Information

Important: The sagemaker.ai_registry module does not export classes from its __init__.py. Classes must be imported using their full module paths:

# Correct imports (full module paths required)
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.evaluator import Evaluator, EvaluatorMethod
from sagemaker.ai_registry.dataset_utils import DataSetMethod
from sagemaker.ai_registry.air_constants import HubContentStatus

# Incorrect (will not work)
# from sagemaker.ai_registry import DataSet  # ImportError

Capabilities

AIRHub

Central hub client for managing AI Registry content operations.

class AIRHub:
    """
    AI Registry Hub client for managing hub content.

    Class Methods:
        get_hub_name() -> str
            Get hub name for current account/region.
            
            Returns:
                str: Hub name
        
        import_hub_content(hub_content_type, hub_content_name, document_schema_version, 
                          hub_content_document, hub_content_version, tags, session) -> Dict
            Import content to hub.
            
            Parameters:
                hub_content_type: str - Content type ("Dataset" or "Evaluator")
                hub_content_name: str - Content name (required)
                document_schema_version: str - Schema version (required)
                hub_content_document: Dict - Content document (required)
                hub_content_version: str - Content version (required)
                tags: Optional[List[Dict]] - Tags
                session: Optional[Session] - SageMaker session
            
            Returns:
                Dict: Import response with ARN
        
        list_hub_content(hub_content_type, max_results=100, next_token=None, session=None) -> Dict
            List hub content with details.
            
            Parameters:
                hub_content_type: str - Content type filter
                max_results: int - Maximum results (default: 100)
                next_token: Optional[str] - Pagination token
                session: Optional[Session] - Session
            
            Returns:
                Dict: Content list with pagination
        
        describe_hub_content(hub_content_type, hub_content_name, hub_content_version, session=None) -> Dict
            Describe hub content.
            
            Parameters:
                hub_content_type: str - Content type (required)
                hub_content_name: str - Content name (required)
                hub_content_version: str - Content version (required)
                session: Optional[Session] - Session
            
            Returns:
                Dict: Content details
        
        delete_hub_content(hub_content_type, hub_content_name, hub_content_version, session=None) -> None
            Delete hub content.
            
            Parameters:
                hub_content_type: str - Content type (required)
                hub_content_name: str - Content name (required)
                hub_content_version: str - Content version (required)
                session: Optional[Session] - Session
        
        list_hub_content_versions(hub_content_type, hub_content_name, session=None) -> List[Dict]
            List content versions.
            
            Returns:
                List[Dict]: Version information

    Notes:
        - Low-level hub operations
        - Prefer DataSet and Evaluator classes
        - Use for advanced hub management
    """

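The `next_token` parameter on `list_hub_content` implies the usual page-by-page loop. A minimal sketch of that loop follows, assuming each response is a dict with an `items` list and a `next_token` entry; those key names, the `iter_hub_items` helper, and the fake page data are illustrative assumptions, not confirmed parts of the SDK.

```python
from typing import Any, Callable, Dict, Iterator, Optional


def iter_hub_items(
    fetch_page: Callable[[Optional[str]], Dict[str, Any]],
    items_key: str = "items",        # assumed response key
    token_key: str = "next_token",   # assumed response key
) -> Iterator[Any]:
    """Yield every item across pages until no pagination token is returned."""
    token: Optional[str] = None
    while True:
        page = fetch_page(token)
        yield from page.get(items_key, [])
        token = page.get(token_key)
        if not token:
            break


# Fake pages standing in for a real call such as:
#   lambda t: AIRHub.list_hub_content("Dataset", max_results=2, next_token=t)
_FAKE_PAGES: Dict[Optional[str], Dict[str, Any]] = {
    None: {"items": ["ds-a", "ds-b"], "next_token": "p2"},
    "p2": {"items": ["ds-c"], "next_token": None},
}


def fake_fetch(token: Optional[str]) -> Dict[str, Any]:
    """Return the canned page for the given token."""
    return _FAKE_PAGES[token]


all_names = list(iter_hub_items(fake_fetch))
print(all_names)
```

Passing the fetch function in as a callable keeps the loop independent of the actual response shape, so only the two key names need adjusting once the real response is inspected.
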
AIRHubEntity

Base class for AI Registry Hub entities.

class AIRHubEntity(ABC):
    """
    Base entity for AI Registry Hub content.

    Parameters:
        name: str - Name of the hub content (required)
        version: str - Version of the hub content (required)
        arn: str - ARN of the hub content (required)
        status: Optional[HubContentStatus] - Status of the hub content
        created_time: Optional[str] - Creation timestamp
        updated_time: Optional[str] - Last update timestamp
        description: Optional[str] - Description of the hub content
        sagemaker_session: Optional[Session] - SageMaker session

    Abstract Methods:
        hub_content_type() -> str
            Return hub content type.
            
            Returns:
                str: "Dataset" or "Evaluator"
        
        _get_hub_content_type_for_list() -> str
            Return content type for list operation.

    Methods:
        list(max_results=100, next_token=None) -> Dict
            List all entities of this type.
            
            Parameters:
                max_results: int - Maximum results (default: 100)
                next_token: Optional[str] - Pagination token
            
            Returns:
                Dict: Entities with pagination info
        
        get_versions() -> List[Dict]
            List all versions of this entity.
            
            Returns:
                List[Dict]: Version information
        
        delete(version=None) -> bool
            Delete this entity instance.
            
            Parameters:
                version: Optional[str] - Specific version (default: current)
            
            Returns:
                bool: True if deleted successfully
        
        delete_by_name(name, version=None) -> bool
            Delete entity by name.
            
            Parameters:
                name: str - Entity name (required)
                version: Optional[str] - Version
            
            Returns:
                bool: True if deleted
        
        wait(poll=5, timeout=300) -> None
            Wait for entity to reach terminal state.
            
            Parameters:
                poll: int - Polling interval (default: 5 seconds)
                timeout: int - Timeout (default: 300 seconds)
            
            Raises:
                TimeoutError: If timeout exceeded
        
        refresh() -> None
            Refresh entity state from API.

    Notes:
        - Base class for DataSet and Evaluator
        - Provides common operations
        - Status tracked through lifecycle
    """

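The `wait(poll, timeout)` contract described above (poll until a terminal state, raise `TimeoutError` otherwise) can be sketched as a plain polling loop. This is an illustration of the pattern, not the SDK's implementation; `wait_for_terminal` and the status strings passed to it are assumptions based on the values documented on this page.

```python
import time
from typing import Callable, Iterable


def wait_for_terminal(
    get_status: Callable[[], str],
    terminal: Iterable[str] = ("Available", "ImportFailed", "DeleteFailed"),
    poll: float = 5,
    timeout: float = 300,
) -> str:
    """Poll get_status() until it returns a terminal value or the timeout passes."""
    terminal_set = set(terminal)
    deadline = time.monotonic() + timeout
    while True:
        status = get_status()
        if status in terminal_set:
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(f"still {status!r} after {timeout} seconds")
        time.sleep(poll)


# Simulated status sequence standing in for repeated refresh() calls.
_statuses = iter(["Importing", "Importing", "Available"])
final = wait_for_terminal(lambda: next(_statuses), poll=0.01, timeout=5)
print(final)
```

Note that the loop returns on any terminal state, including a failed one, so the caller still has to check which state was reached.
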
DataSet

Dataset entity for AI Registry.

class DataSet(AIRHubEntity):
    """
    Dataset entity for AI Registry.

    Parameters:
        name: str - Name of the dataset (required)
        arn: str - ARN of the dataset (required)
        version: str - Version of the dataset (required)
        status: HubContentStatus - Current status (required)
        source: Optional[str] - S3 location of the dataset
        description: Optional[str] - Description
        customization_technique: Optional[CustomizationTechnique] - Customization technique
            - SFT, DPO, RLAIF, RLVR
        method: Optional[DataSetMethod] - Method used to create dataset
            - UPLOADED or GENERATED
        created_time: Optional[datetime] - Creation timestamp
        updated_time: Optional[datetime] - Last update timestamp
        sagemaker_session: Optional[Session] - SageMaker session

    Class Methods:
        get(name, sagemaker_session=None) -> DataSet
            Get dataset by name.
            
            Parameters:
                name: str - Dataset name (required)
                sagemaker_session: Optional[Session] - Session
            
            Returns:
                DataSet: Dataset object
            
            Raises:
                ClientError: If dataset not found
        
        create(name, source, customization_technique=None, wait=True, description=None, 
               tags=None, role=None, sagemaker_session=None) -> DataSet
            Create new dataset.
            
            Parameters:
                name: str - Dataset name (required)
                source: str - Local file path or S3 URI (required)
                customization_technique: Optional[CustomizationTechnique] - Technique
                wait: bool - Wait for availability (default: True)
                description: Optional[str] - Description
                tags: Optional[List[Dict]] - Tags
                role: Optional[str] - IAM role ARN
                sagemaker_session: Optional[Session] - Session
            
            Returns:
                DataSet: Created dataset
            
            Raises:
                ValueError: Invalid source or configuration
                ClientError: Creation errors
        
        list(max_results=100, next_token=None) -> Dict
            List datasets.
            
            Returns:
                Dict: Datasets with pagination

    Methods:
        refresh() -> None
            Refresh dataset state.
        
        wait(poll=5, timeout=300) -> None
            Wait for dataset to be available.
        
        delete(version=None) -> bool
            Delete dataset.

    Notes:
        - Datasets for fine-tuning foundation models
        - Local files automatically uploaded to S3
        - S3 files imported by reference
        - Status transitions: Importing -> Available or ImportFailed
    """

Usage:

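from botocore.exceptions import ClientError

from sagemaker.ai_registry.air_constants import HubContentStatus
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.train.common import CustomizationTechnique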

# Create dataset from local file
try:
    dataset = DataSet.create(
        name="my-training-dataset",
        source="./data/train.jsonl",  # Local file uploaded automatically
        customization_technique=CustomizationTechnique.SFT,
        description="Training dataset for customer support chatbot fine-tuning",
        wait=True,  # Block until available
        tags=[
            {"Key": "Project", "Value": "Chatbot"},
            {"Key": "DataSource", "Value": "CustomerSupport"}
        ]
    )
    
    print(f"Dataset created: {dataset.arn}")
    print(f"Status: {dataset.status}")
    print(f"S3 location: {dataset.source}")
    
except ValueError as e:
    print(f"Invalid dataset: {e}")
except ClientError as e:
    error_code = e.response['Error']['Code']
    if error_code == 'ResourceInUse':
        print("Dataset with this name already exists")
        dataset = DataSet.get("my-training-dataset")
    else:
        raise  # Surface unexpected errors instead of silently ignoring them

# Create dataset from S3 (by reference)
dataset = DataSet.create(
    name="eval-dataset",
    source="s3://my-bucket/data/eval.jsonl",  # Already in S3
    customization_technique=CustomizationTechnique.SFT,
    wait=True
)

# Get existing dataset
dataset = DataSet.get("my-training-dataset")
print(f"Dataset location: {dataset.source}")
print(f"Created: {dataset.created_time}")
print(f"Method: {dataset.method}")  # UPLOADED or GENERATED

# List all datasets
datasets_response = DataSet.list(max_results=50)

print(f"Datasets returned: {len(datasets_response['items'])}")  # At most max_results per page
for ds_item in datasets_response['items']:
    print(f"  {ds_item['name']}: {ds_item['status']}")

# Delete dataset
if dataset.status == HubContentStatus.AVAILABLE:
    dataset.delete()
    print("Dataset deleted")

Wait for Dataset Availability:

# Create dataset without waiting
dataset = DataSet.create(
    name="large-dataset",
    source="s3://bucket/large-data.jsonl",
    wait=False  # Don't block
)

print(f"Dataset import started: {dataset.arn}")
print(f"Status: {dataset.status}")  # IMPORTING

# Do other work...

# Later: wait for a terminal state
try:
    dataset.wait(poll=10, timeout=600)  # Wait up to 10 minutes
except TimeoutError:
    print("Dataset import timed out")

# wait() returns on any terminal state, so confirm the import actually succeeded
dataset.refresh()
if dataset.status == HubContentStatus.AVAILABLE:
    print(f"Dataset available: {dataset.status}")
elif dataset.status == HubContentStatus.IMPORT_FAILED:
    print("Import failed - check source file format")

Evaluator

Evaluator entity for AI Registry.

class Evaluator(AIRHubEntity):
    """
    Evaluator entity for AI Registry.

    Parameters:
        name: Optional[str] - Name of the evaluator
        version: Optional[str] - Version of the evaluator
        arn: Optional[str] - ARN of the evaluator
        type: Optional[str] - Type of evaluator (required at creation)
            - "RewardFunction": Programmatic reward function
            - "RewardPrompt": Text prompt for LLM-based evaluation
        method: Optional[EvaluatorMethod] - Method used by evaluator
            - BYOC: Bring your own code
            - LAMBDA: AWS Lambda function
        reference: Optional[str] - Reference to implementation (ARN, S3 URI, etc.)
        status: Optional[HubContentStatus] - Current status
        created_time: Optional[datetime] - Creation timestamp
        updated_time: Optional[datetime] - Last update timestamp
        sagemaker_session: Optional[Session] - SageMaker session

    Class Methods:
        get(name, sagemaker_session=None) -> Evaluator
            Get evaluator by name.
            
            Parameters:
                name: str - Evaluator name (required)
                sagemaker_session: Optional[Session] - Session
            
            Returns:
                Evaluator: Evaluator object
            
            Raises:
                ClientError: If evaluator not found
        
        create(name, type, source, wait=True, role=None, sagemaker_session=None) -> Evaluator
            Create new evaluator.
            
            Parameters:
                name: str - Evaluator name (required)
                type: str - Evaluator type (required)
                    - "RewardFunction" or "RewardPrompt"
                source: str - Evaluator source (required)
                    - Lambda ARN: "arn:aws:lambda:..."
                    - Local file: "./reward_function.py" (uploaded to Lambda)
                    - S3 URI: "s3://bucket/reward.py"
                    - Text prompt: "Evaluate the response quality..."
                wait: bool - Wait for availability (default: True)
                role: Optional[str] - IAM role ARN for Lambda creation
                sagemaker_session: Optional[Session] - Session
            
            Returns:
                Evaluator: Created evaluator
            
            Raises:
                ValueError: Invalid type or source
                ClientError: Creation errors
        
        list(max_results=100, next_token=None) -> Dict
            List evaluators.
            
            Returns:
                Dict: Evaluators with pagination

    Methods:
        refresh() -> None
            Refresh evaluator state.
        
        wait(poll=5, timeout=300) -> None
            Wait for evaluator to be available.
        
        delete(version=None) -> bool
            Delete evaluator.

    Notes:
        - RewardFunction: Code-based reward computation
        - RewardPrompt: LLM-based evaluation
        - Used in RLAIF and RLVR training
        - Lambda functions automatically created from local files
    """

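The `source` argument of `Evaluator.create` accepts four kinds of input (Lambda ARN, local file, S3 URI, text prompt). How the SDK tells them apart is not documented here; the sketch below shows one plausible heuristic for checking a value before passing it in. The `classify_source` helper, its return labels, and the `.py` suffix rule are assumptions made for illustration.

```python
import re
from pathlib import Path

# Loose pattern for a Lambda function ARN (any partition, 12-digit account ID).
_LAMBDA_ARN = re.compile(r"^arn:aws[a-z-]*:lambda:[a-z0-9-]+:\d{12}:function:.+$")


def classify_source(source: str) -> str:
    """Guess which kind of evaluator source a string represents."""
    text = source.strip()
    if _LAMBDA_ARN.match(text):
        return "lambda_arn"
    if text.startswith("s3://"):
        return "s3_uri"
    # Assumption: single-line values ending in .py are local code files.
    if "\n" not in text and Path(text).suffix == ".py":
        return "local_file"
    return "text_prompt"


for candidate in (
    "arn:aws:lambda:us-west-2:123456789012:function:evaluate-code",
    "s3://bucket/reward.py",
    "./reward_function.py",
    "Evaluate the response quality...",
):
    print(classify_source(candidate))
```

A check like this is mainly useful for catching a mistyped ARN or path early, before it is silently treated as prompt text.
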
Usage:

from sagemaker.ai_registry.evaluator import Evaluator
from sagemaker.ai_registry.air_constants import HubContentStatus

# Create reward function evaluator from Lambda
try:
    evaluator = Evaluator.create(
        name="code-correctness-evaluator",
        type="RewardFunction",
        source="arn:aws:lambda:us-west-2:123456789012:function:evaluate-code",
        wait=True
    )
    
    print(f"Evaluator created: {evaluator.arn}")
    print(f"Method: {evaluator.method}")  # LAMBDA
    
except ValueError as e:
    print(f"Invalid evaluator configuration: {e}")

# Create reward function from local code
evaluator = Evaluator.create(
    name="custom-reward",
    type="RewardFunction",
    source="./reward_function.py",  # Local file
    role="arn:aws:iam::123456789012:role/LambdaRole",  # For Lambda creation
    wait=True
)

print(f"Lambda created and evaluator registered: {evaluator.reference}")

# Create reward prompt for LLM-based evaluation
prompt_evaluator = Evaluator.create(
    name="helpfulness-evaluator",
    type="RewardPrompt",
    source="""
Evaluate the helpfulness of the following response on a scale of 1-10.

Question: {question}
Response: {response}

Consider:
- Relevance to the question
- Completeness of information
- Clarity of explanation

Score (1-10):
""",
    wait=True
)

# Get existing evaluator
evaluator = Evaluator.get("code-correctness-evaluator")
print(f"Evaluator type: {evaluator.type}")
print(f"Method: {evaluator.method}")
print(f"Reference: {evaluator.reference}")

# List all evaluators
evaluators_response = Evaluator.list(max_results=20)

print(f"Evaluators returned: {len(evaluators_response['items'])}")  # At most max_results per page
for ev_item in evaluators_response['items']:
    print(f"  {ev_item['name']} ({ev_item['type']}): {ev_item['status']}")

# Delete evaluator
evaluator.delete()

Reward Function Example:

# reward_function.py
import json

def lambda_handler(event, context):
    """
    Reward function for RLVR training.
    
    Args:
        event: Dict with 'prompt' and 'response' fields
            {
                "prompt": "Write a function to sort a list",
                "response": "def sort_list(lst): return sorted(lst)"
            }
    
    Returns:
        Dict with 'reward' field (float 0-1)
            {
                "reward": 0.95,
                "details": "Correct implementation"
            }
    """
    prompt = event['prompt']
    response = event['response']
    
    try:
        # Extract test cases from prompt
        test_cases = extract_test_cases(prompt)
        
        # Execute code with test cases
        passed = 0
        total = len(test_cases)
        
        for test_input, expected_output in test_cases:
            try:
                actual = execute_code(response, test_input)
                if actual == expected_output:
                    passed += 1
            except SyntaxError:
                # Unparseable code fails every test; let the outer handler report it
                raise
            except Exception:
                # A failing test case simply does not count as passed
                pass
        
        # Reward = fraction of tests passed
        reward = passed / total if total > 0 else 0.0
        
        return {
            'reward': reward,
            'details': f"Passed {passed}/{total} tests"
        }
        
    except SyntaxError:
        # Syntax errors get 0 reward
        return {
            'reward': 0.0,
            'details': "Syntax error in generated code"
        }
    
    except Exception as e:
        # Errors outside individual test runs (e.g. test extraction) get 0 reward
        return {
            'reward': 0.0,
            'details': f"Runtime error: {str(e)}"
        }

def execute_code(code, test_input):
    """Safely execute generated code."""
    # Placeholder: implement safe execution with timeout and sandboxing
    raise NotImplementedError("Provide a sandboxed executor")

def extract_test_cases(prompt):
    """Extract test cases from prompt as (input, expected_output) pairs."""
    # Placeholder: parse prompt for test cases
    return []

Use Evaluator in Training:

from sagemaker.ai_registry.evaluator import Evaluator
from sagemaker.train.rlvr_trainer import RLVRTrainer

# Get evaluator
evaluator = Evaluator.get("code-correctness-evaluator")

# Use in RLVR training
trainer = RLVRTrainer(
    model="codellama/CodeLlama-7b-hf",
    custom_reward_function=evaluator,  # Use registered evaluator
    training_dataset="s3://bucket/coding-tasks.jsonl",
    accept_eula=True
)

job = trainer.train()

Enums

HubContentStatus

class HubContentStatus(Enum):
    """
    Hub content status enumeration.

    Values:
        AVAILABLE = "Available"
            Content is available for use
            - Dataset ready for training
            - Evaluator ready for evaluation
        
        IMPORTING = "Importing"
            Content is being imported
            - Upload in progress
            - Processing ongoing
        
        DELETING = "Deleting"
            Content is being deleted
        
        IMPORT_FAILED = "ImportFailed"
            Import failed
            - Check failure reason
            - May need to retry
        
        DELETE_FAILED = "DeleteFailed"
            Deletion failed
            - May need manual cleanup

    Usage:
        Check dataset/evaluator status.
        Wait for AVAILABLE before using.
    
    Notes:
        - Terminal states: AVAILABLE, IMPORT_FAILED, DELETE_FAILED
        - Transient states: IMPORTING, DELETING
        - Always wait for AVAILABLE before training
    """

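The terminal/transient split in the notes above can be turned into small guard functions. The sketch below defines a local stand-in enum that mirrors the documented values so it runs without the SDK; in real code the enum would come from `sagemaker.ai_registry.air_constants`. The helper names `is_terminal` and `ready_for_training` are hypothetical.

```python
from enum import Enum


class HubContentStatus(Enum):
    """Local stand-in mirroring the values documented above."""
    AVAILABLE = "Available"
    IMPORTING = "Importing"
    DELETING = "Deleting"
    IMPORT_FAILED = "ImportFailed"
    DELETE_FAILED = "DeleteFailed"


# Terminal states as listed in the notes; everything else is transient.
TERMINAL_STATES = frozenset({
    HubContentStatus.AVAILABLE,
    HubContentStatus.IMPORT_FAILED,
    HubContentStatus.DELETE_FAILED,
})


def is_terminal(status: HubContentStatus) -> bool:
    """True once the status will no longer change without a new operation."""
    return status in TERMINAL_STATES


def ready_for_training(status: HubContentStatus) -> bool:
    """Only AVAILABLE content should be handed to a trainer."""
    return status is HubContentStatus.AVAILABLE


# Raw API strings map onto enum members by value.
status = HubContentStatus("ImportFailed")
print(status.name, is_terminal(status), ready_for_training(status))
```
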
DataSetMethod

class DataSetMethod(Enum):
    """
    Dataset method enumeration.

    Values:
        UPLOADED = "uploaded"
            Dataset was uploaded by user
            - From local file or S3
        
        GENERATED = "generated"
            Dataset was generated programmatically
            - By AI Registry system or custom process

    Usage:
        Track dataset provenance.
    
    Notes:
        - Read-only attribute
        - Set automatically during creation
    """

EvaluatorMethod

class EvaluatorMethod(Enum):
    """
    Evaluator method enumeration.

    Values:
        BYOC = "byoc"
            Bring your own code
            - Custom code uploaded to Lambda
        
        LAMBDA = "lambda"
            Existing AWS Lambda function
            - Lambda ARN provided directly

    Usage:
        Track evaluator implementation method.
    
    Notes:
        - BYOC: Code uploaded and Lambda created automatically
        - LAMBDA: Existing Lambda function referenced
        - Read-only attribute
    """

Complete Example

from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.evaluator import Evaluator
from sagemaker.ai_registry.air_constants import HubContentStatus
from sagemaker.train.common import CustomizationTechnique

# Step 1: Create training dataset
print("Creating training dataset...")
train_dataset = DataSet.create(
    name="chatbot-training-v1",
    source="./data/train.jsonl",  # Local file
    customization_technique=CustomizationTechnique.SFT,
    description="Customer support chatbot training data",
    tags=[
        {"Key": "Project", "Value": "Chatbot"},
        {"Key": "Version", "Value": "1.0"}
    ],
    wait=True
)

# create(wait=True) already blocks, so this guard only triggers if the import has not finished
if train_dataset.status != HubContentStatus.AVAILABLE:
    print("Waiting for dataset import...")
    train_dataset.wait(poll=5, timeout=300)

print(f"Training dataset ready: {train_dataset.arn}")
print(f"  S3 location: {train_dataset.source}")
print(f"  Status: {train_dataset.status}")

# Step 2: Create evaluation dataset
print("\nCreating evaluation dataset...")
eval_dataset = DataSet.create(
    name="chatbot-eval-v1",
    source="s3://my-bucket/data/eval.jsonl",  # S3 file
    customization_technique=CustomizationTechnique.SFT,
    wait=True
)

# Step 3: Create reward evaluator
print("\nCreating reward evaluator...")
evaluator = Evaluator.create(
    name="helpfulness-evaluator",
    type="RewardFunction",
    source="arn:aws:lambda:us-west-2:123456789012:function:evaluate-helpfulness",
    wait=True
)

print(f"Evaluator ready: {evaluator.arn}")

# Step 4: List all registry content
print("\n=== AI Registry Contents ===")

print("\nDatasets:")
datasets = DataSet.list(max_results=20)
for ds in datasets['items']:
    print(f"  - {ds['name']} (v{ds['version']}): {ds['status']}")

print("\nEvaluators:")
evaluators = Evaluator.list(max_results=20)
for ev in evaluators['items']:
    print(f"  - {ev['name']} ({ev['type']}): {ev['status']}")

# Step 5: Use in training
from sagemaker.train.sft_trainer import SFTTrainer

trainer = SFTTrainer(
    model="meta-llama/Llama-2-7b-hf",
    training_dataset=train_dataset,  # Use DataSet object
    validation_dataset=eval_dataset,
    accept_eula=True
)

job = trainer.train()

# Step 6: Cleanup (optional)
# train_dataset.delete()
# eval_dataset.delete()
# evaluator.delete()

Dataset Version Management:

# List versions of dataset
dataset = DataSet.get("chatbot-training-v1")
versions = dataset.get_versions()

print(f"Dataset versions for {dataset.name}:")
for version_info in versions:
    print(f"  Version {version_info['Version']}: {version_info['Status']}")
    print(f"    Created: {version_info['CreationTime']}")

# Delete specific version
dataset.delete(version="1")

# Delete latest version
dataset.delete()

Validation and Constraints

Dataset Constraints

  • Dataset name: 1-63 characters, alphanumeric and hyphens
  • Maximum file size: 10 GB
  • Supported formats: JSONL (for fine-tuning)
  • Minimum examples: 10 (recommended 100+)
  • Maximum datasets per account: 1000

Evaluator Constraints

  • Evaluator name: 1-63 characters, alphanumeric and hyphens
  • Lambda timeout: Maximum 15 minutes
  • Lambda memory: 128 MB - 10 GB
  • Reward value range: 0.0 - 1.0 (normalized)
  • Maximum evaluators per account: 100
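
Both name constraints (1-63 characters, alphanumeric and hyphens) are easy to check locally before calling `create`. A minimal sketch, assuming the rule is exactly as stated and that ASCII letters and digits are what "alphanumeric" means here; `is_valid_registry_name` is a hypothetical helper, and the service may enforce further rules not listed on this page.

```python
import re

# 1 to 63 characters drawn from ASCII letters, digits, and hyphens.
_NAME_PATTERN = re.compile(r"[A-Za-z0-9-]{1,63}")


def is_valid_registry_name(name: str) -> bool:
    """Check a dataset or evaluator name against the documented constraint."""
    return _NAME_PATTERN.fullmatch(name) is not None


for name in ("chatbot-training-v1", "bad_name", "", "a" * 64):
    print(repr(name[:20]), is_valid_registry_name(name))
```

`fullmatch` is used rather than `match` with `$` so that a trailing newline cannot slip through.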

Dataset Format Requirements

For SFT (Supervised Fine-Tuning):

{"prompt": "Question or instruction", "completion": "Expected response"}
{"prompt": "Another question", "completion": "Another response"}

For DPO (Direct Preference Optimization):

{"prompt": "Question", "chosen": "Better response", "rejected": "Worse response"}

For RLAIF/RLVR:

{"prompt": "Task or question"}
{"prompt": "Another task"}

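Since malformed JSONL is listed as the main cause of import failures, checking a file locally first can save a round trip. The sketch below validates each line against the field sets shown above; the `REQUIRED_KEYS` table is derived from those examples, and `find_jsonl_problems` is a hypothetical helper, not part of the SDK.

```python
import json
from typing import Dict, List, Set

# Required fields per customization technique, taken from the examples above.
REQUIRED_KEYS: Dict[str, Set[str]] = {
    "SFT": {"prompt", "completion"},
    "DPO": {"prompt", "chosen", "rejected"},
    "RLAIF": {"prompt"},
    "RLVR": {"prompt"},
}


def find_jsonl_problems(text: str, technique: str) -> List[str]:
    """Return one message per line that is not a JSON object with the required keys."""
    required = REQUIRED_KEYS[technique]
    problems: List[str] = []
    for number, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue  # Ignore blank lines
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            problems.append(f"line {number}: invalid JSON ({exc.msg})")
            continue
        if not isinstance(record, dict):
            problems.append(f"line {number}: expected a JSON object")
            continue
        missing = sorted(required - set(record))
        if missing:
            problems.append(f"line {number}: missing {', '.join(missing)}")
    return problems


sample = '{"prompt": "Q", "completion": "A"}\n{"prompt": "Q2"}\nnot json\n'
for problem in find_jsonl_problems(sample, "SFT"):
    print(problem)
```

For a real file, read it with `encoding="utf-8"` and pass the contents in; a decode error at that step points to the encoding problem mentioned under common errors.
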
Common Error Scenarios

  1. Import Failed:
    • Cause: Invalid file format or corrupted data
    • Solution: Validate JSONL format, check file encoding (UTF-8)

  2. Dataset Not Found:
    • Cause: Wrong name or not created yet
    • Solution: Verify name, check if creation completed

  3. Evaluator Lambda Timeout:
    • Cause: Reward function exceeds the configured Lambda timeout (15 minutes at most)
    • Solution: Optimize the reward function; raise the Lambda timeout only if it is set below the 15-minute maximum

  4. Invalid Reward Value:
    • Cause: Reward function returns value outside [0, 1]
    • Solution: Normalize reward to 0-1 range

  5. Source File Too Large:
    • Cause: Dataset >10 GB
    • Solution: Split into multiple datasets or sample data

  6. Concurrent Import Limit:
    • Cause: Too many simultaneous imports
    • Solution: Wait for current imports to complete, then retry
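
For the invalid reward value case, one common way to bring a raw score into the 0-1 range is min-max scaling followed by clamping. A minimal sketch, assuming the raw score has known bounds (for example the 1-10 scale used in the reward prompt example); `normalize_reward` is a hypothetical helper, and treating NaN as 0.0 is a design choice made for this illustration.

```python
import math


def normalize_reward(raw: float, low: float, high: float) -> float:
    """Scale raw from [low, high] to [0, 1], clamping values outside the bounds."""
    if not high > low:
        raise ValueError("high must be greater than low")
    if math.isnan(raw):
        return 0.0  # Assumption: an undefined score earns no reward
    scaled = (raw - low) / (high - low)
    return min(1.0, max(0.0, scaled))


for raw_score in (1, 7, 10, 12, -3):
    print(raw_score, round(normalize_reward(raw_score, low=1, high=10), 3))
```

Applying this as the last step of a reward function keeps the returned value inside the documented range even when the scoring logic overshoots.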