Manage datasets and evaluators in the SageMaker AI Registry Hub for model customization workflows.
Important: The sagemaker.ai_registry module does not export classes from its __init__.py. Classes must be imported using their full module paths:
# Correct imports (full module paths required)
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.evaluator import Evaluator, EvaluatorMethod
from sagemaker.ai_registry.dataset_utils import DataSetMethod
from sagemaker.ai_registry.air_constants import HubContentStatus
# Incorrect (will not work)
# from sagemaker.ai_registry import DataSet  # ImportError
Central hub client for managing AI Registry content operations.
class AIRHub:
"""
AI Registry Hub client for managing hub content.
Class Methods:
get_hub_name() -> str
Get hub name for current account/region.
Returns:
str: Hub name
import_hub_content(hub_content_type, hub_content_name, document_schema_version,
hub_content_document, hub_content_version, tags, session) -> Dict
Import content to hub.
Parameters:
hub_content_type: str - Content type ("Dataset" or "Evaluator")
hub_content_name: str - Content name (required)
document_schema_version: str - Schema version (required)
hub_content_document: Dict - Content document (required)
hub_content_version: str - Content version (required)
tags: Optional[List[Dict]] - Tags
session: Optional[Session] - SageMaker session
Returns:
Dict: Import response with ARN
list_hub_content(hub_content_type, max_results=100, next_token=None, session=None) -> Dict
List hub content with details.
Parameters:
hub_content_type: str - Content type filter
max_results: int - Maximum results (default: 100)
next_token: Optional[str] - Pagination token
session: Optional[Session] - Session
Returns:
Dict: Content list with pagination
describe_hub_content(hub_content_type, hub_content_name, hub_content_version, session=None) -> Dict
Describe hub content.
Parameters:
hub_content_type: str - Content type (required)
hub_content_name: str - Content name (required)
hub_content_version: str - Content version (required)
session: Optional[Session] - Session
Returns:
Dict: Content details
delete_hub_content(hub_content_type, hub_content_name, hub_content_version, session=None) -> None
Delete hub content.
Parameters:
hub_content_type: str - Content type (required)
hub_content_name: str - Content name (required)
hub_content_version: str - Content version (required)
session: Optional[Session] - Session
list_hub_content_versions(hub_content_type, hub_content_name, session=None) -> List[Dict]
List content versions.
Returns:
List[Dict]: Version information
Notes:
- Low-level hub operations
- Prefer DataSet and Evaluator classes
- Use for advanced hub management
"""Base class for AI Registry Hub entities.
Base class for AI Registry Hub entities.
class AIRHubEntity(ABC):
"""
Base entity for AI Registry Hub content.
Parameters:
name: str - Name of the hub content (required)
version: str - Version of the hub content (required)
arn: str - ARN of the hub content (required)
status: Optional[HubContentStatus] - Status of the hub content
created_time: Optional[str] - Creation timestamp
updated_time: Optional[str] - Last update timestamp
description: Optional[str] - Description of the hub content
sagemaker_session: Optional[Session] - SageMaker session
Abstract Methods:
hub_content_type() -> str
Return hub content type.
Returns:
str: "Dataset" or "Evaluator"
_get_hub_content_type_for_list() -> str
Return content type for list operation.
Methods:
list(max_results=100, next_token=None) -> Dict
List all entities of this type.
Parameters:
max_results: int - Maximum results (default: 100)
next_token: Optional[str] - Pagination token
Returns:
Dict: Entities with pagination info
get_versions() -> List[Dict]
List all versions of this entity.
Returns:
List[Dict]: Version information
delete(version=None) -> bool
Delete this entity instance.
Parameters:
version: Optional[str] - Specific version (default: current)
Returns:
bool: True if deleted successfully
delete_by_name(name, version=None) -> bool
Delete entity by name.
Parameters:
name: str - Entity name (required)
version: Optional[str] - Version
Returns:
bool: True if deleted
wait(poll=5, timeout=300) -> None
Wait for entity to reach terminal state.
Parameters:
poll: int - Polling interval (default: 5 seconds)
timeout: int - Timeout (default: 300 seconds)
Raises:
TimeoutError: If timeout exceeded
refresh() -> None
Refresh entity state from API.
Notes:
- Base class for DataSet and Evaluator
- Provides common operations
- Status tracked through lifecycle
"""Dataset entity for AI Registry.
class DataSet(AIRHubEntity):
"""
Dataset entity for AI Registry.
Parameters:
name: str - Name of the dataset (required)
arn: str - ARN of the dataset (required)
version: str - Version of the dataset (required)
status: HubContentStatus - Current status (required)
source: Optional[str] - S3 location of the dataset
description: Optional[str] - Description
customization_technique: Optional[CustomizationTechnique] - Customization technique
- SFT, DPO, RLAIF, RLVR
method: Optional[DataSetMethod] - Method used to create dataset
- UPLOADED or GENERATED
created_time: Optional[datetime] - Creation timestamp
updated_time: Optional[datetime] - Last update timestamp
sagemaker_session: Optional[Session] - SageMaker session
Class Methods:
get(name, sagemaker_session=None) -> DataSet
Get dataset by name.
Parameters:
name: str - Dataset name (required)
sagemaker_session: Optional[Session] - Session
Returns:
DataSet: Dataset object
Raises:
ClientError: If dataset not found
create(name, source, customization_technique=None, wait=True, description=None,
tags=None, role=None, sagemaker_session=None) -> DataSet
Create new dataset.
Parameters:
name: str - Dataset name (required)
source: str - Local file path or S3 URI (required)
customization_technique: Optional[CustomizationTechnique] - Technique
wait: bool - Wait for availability (default: True)
description: Optional[str] - Description
tags: Optional[List[Dict]] - Tags
role: Optional[str] - IAM role ARN
sagemaker_session: Optional[Session] - Session
Returns:
DataSet: Created dataset
Raises:
ValueError: Invalid source or configuration
ClientError: Creation errors
list(max_results=100, next_token=None) -> Dict
List datasets.
Returns:
Dict: Datasets with pagination
Methods:
refresh() -> None
Refresh dataset state.
wait(poll=5, timeout=300) -> None
Wait for dataset to be available.
delete(version=None) -> bool
Delete dataset.
Notes:
- Datasets for fine-tuning foundation models
- Local files automatically uploaded to S3
- S3 files imported by reference
- Status transitions: Importing -> Available or ImportFailed
"""Usage:
from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.air_constants import HubContentStatus
from sagemaker.train.common import CustomizationTechnique
from botocore.exceptions import ClientError
# Create dataset from local file
try:
dataset = DataSet.create(
name="my-training-dataset",
source="./data/train.jsonl", # Local file uploaded automatically
customization_technique=CustomizationTechnique.SFT,
description="Training dataset for customer support chatbot fine-tuning",
wait=True, # Block until available
tags=[
{"Key": "Project", "Value": "Chatbot"},
{"Key": "DataSource", "Value": "CustomerSupport"}
]
)
print(f"Dataset created: {dataset.arn}")
print(f"Status: {dataset.status}")
print(f"S3 location: {dataset.source}")
except ValueError as e:
print(f"Invalid dataset: {e}")
except ClientError as e:
error_code = e.response['Error']['Code']
if error_code == 'ResourceInUse':
print("Dataset with this name already exists")
dataset = DataSet.get("my-training-dataset")
# Create dataset from S3 (by reference)
dataset = DataSet.create(
name="eval-dataset",
source="s3://my-bucket/data/eval.jsonl", # Already in S3
customization_technique=CustomizationTechnique.SFT,
wait=True
)
# Get existing dataset
dataset = DataSet.get("my-training-dataset")
print(f"Dataset location: {dataset.source}")
print(f"Created: {dataset.created_time}")
print(f"Method: {dataset.method}") # UPLOADED or GENERATED
# List all datasets
datasets_response = DataSet.list(max_results=50)
print(f"Total datasets: {len(datasets_response['items'])}")
for ds_item in datasets_response['items']:
print(f" {ds_item['name']}: {ds_item['status']}")
# Delete dataset
if dataset.status == HubContentStatus.AVAILABLE:
dataset.delete()
print("Dataset deleted")Wait for Dataset Availability:
# Create dataset without waiting
dataset = DataSet.create(
name="large-dataset",
source="s3://bucket/large-data.jsonl",
wait=False # Don't block
)
print(f"Dataset import started: {dataset.arn}")
print(f"Status: {dataset.status}") # IMPORTING
# Do other work...
# Later: wait for availability
try:
dataset.wait(poll=10, timeout=600) # Wait up to 10 minutes
print(f"Dataset available: {dataset.status}")
except TimeoutError:
print("Dataset import timed out")
# Check for import failures
dataset.refresh()
if dataset.status == HubContentStatus.IMPORT_FAILED:
print("Import failed - check source file format")Evaluator entity for AI Registry.
class Evaluator(AIRHubEntity):
"""
Evaluator entity for AI Registry.
Parameters:
name: Optional[str] - Name of the evaluator
version: Optional[str] - Version of the evaluator
arn: Optional[str] - ARN of the evaluator
type: Optional[str] - Type of evaluator (required at creation)
- "RewardFunction": Programmatic reward function
- "RewardPrompt": Text prompt for LLM-based evaluation
method: Optional[EvaluatorMethod] - Method used by evaluator
- BYOC: Bring your own code
- LAMBDA: AWS Lambda function
reference: Optional[str] - Reference to implementation (ARN, S3 URI, etc.)
status: Optional[HubContentStatus] - Current status
created_time: Optional[datetime] - Creation timestamp
updated_time: Optional[datetime] - Last update timestamp
sagemaker_session: Optional[Session] - SageMaker session
Class Methods:
get(name, sagemaker_session=None) -> Evaluator
Get evaluator by name.
Parameters:
name: str - Evaluator name (required)
sagemaker_session: Optional[Session] - Session
Returns:
Evaluator: Evaluator object
Raises:
ClientError: If evaluator not found
create(name, type, source, wait=True, role=None, sagemaker_session=None) -> Evaluator
Create new evaluator.
Parameters:
name: str - Evaluator name (required)
type: str - Evaluator type (required)
- "RewardFunction" or "RewardPrompt"
source: str - Evaluator source (required)
- Lambda ARN: "arn:aws:lambda:..."
- Local file: "./reward_function.py" (uploaded to Lambda)
- S3 URI: "s3://bucket/reward.py"
- Text prompt: "Evaluate the response quality..."
wait: bool - Wait for availability (default: True)
role: Optional[str] - IAM role ARN for Lambda creation
sagemaker_session: Optional[Session] - Session
Returns:
Evaluator: Created evaluator
Raises:
ValueError: Invalid type or source
ClientError: Creation errors
list(max_results=100, next_token=None) -> Dict
List evaluators.
Returns:
Dict: Evaluators with pagination
Methods:
refresh() -> None
Refresh evaluator state.
wait(poll=5, timeout=300) -> None
Wait for evaluator to be available.
delete(version=None) -> bool
Delete evaluator.
Notes:
- RewardFunction: Code-based reward computation
- RewardPrompt: LLM-based evaluation
- Used in RLAIF and RLVR training
- Lambda functions automatically created from local files
"""Usage:
from sagemaker.ai_registry.evaluator import Evaluator
from sagemaker.ai_registry.air_constants import HubContentStatus
# Create reward function evaluator from Lambda
try:
evaluator = Evaluator.create(
name="code-correctness-evaluator",
type="RewardFunction",
source="arn:aws:lambda:us-west-2:123456789012:function:evaluate-code",
wait=True
)
print(f"Evaluator created: {evaluator.arn}")
print(f"Method: {evaluator.method}") # LAMBDA
except ValueError as e:
print(f"Invalid evaluator configuration: {e}")
# Create reward function from local code
evaluator = Evaluator.create(
name="custom-reward",
type="RewardFunction",
source="./reward_function.py", # Local file
role="arn:aws:iam::123456789012:role/LambdaRole", # For Lambda creation
wait=True
)
print(f"Lambda created and evaluator registered: {evaluator.reference}")
# Create reward prompt for LLM-based evaluation
prompt_evaluator = Evaluator.create(
name="helpfulness-evaluator",
type="RewardPrompt",
source="""
Evaluate the helpfulness of the following response on a scale of 1-10.
Question: {question}
Response: {response}
Consider:
- Relevance to the question
- Completeness of information
- Clarity of explanation
Score (1-10):
""",
wait=True
)
# Get existing evaluator
evaluator = Evaluator.get("code-correctness-evaluator")
print(f"Evaluator type: {evaluator.type}")
print(f"Method: {evaluator.method}")
print(f"Reference: {evaluator.reference}")
# List all evaluators
evaluators_response = Evaluator.list(max_results=20)
print(f"Total evaluators: {len(evaluators_response['items'])}")
for ev_item in evaluators_response['items']:
print(f" {ev_item['name']} ({ev_item['type']}): {ev_item['status']}")
# Delete evaluator
evaluator.delete()
Reward Function Example:
# reward_function.py
import json
def lambda_handler(event, context):
"""
Reward function for RLVR training.
Args:
event: Dict with 'prompt' and 'response' fields
{
"prompt": "Write a function to sort a list",
"response": "def sort_list(lst): return sorted(lst)"
}
Returns:
Dict with 'reward' field (float 0-1)
{
"reward": 0.95,
"details": "Correct implementation"
}
"""
prompt = event['prompt']
response = event['response']
# Extract test cases from prompt
test_cases = extract_test_cases(prompt)
# Test generated code
try:
# Execute code with test cases
passed = 0
total = len(test_cases)
for test_input, expected_output in test_cases:
try:
actual = execute_code(response, test_input)
if actual == expected_output:
passed += 1
except Exception:
pass
# Reward = fraction of tests passed
reward = passed / total if total > 0 else 0.0
return {
'reward': reward,
'details': f"Passed {passed}/{total} tests"
}
except SyntaxError:
# Syntax errors get 0 reward
return {
'reward': 0.0,
'details': "Syntax error in generated code"
}
except Exception as e:
# Runtime errors get 0 reward
return {
'reward': 0.0,
'details': f"Runtime error: {str(e)}"
}
def execute_code(code, test_input):
"""Safely execute generated code."""
# Implement safe execution with timeout and sandboxing
pass
def extract_test_cases(prompt):
"""Extract test cases from prompt."""
# Parse prompt for test cases
pass
Use Evaluator in Training:
from sagemaker.train.rlvr_trainer import RLVRTrainer
# Get evaluator
evaluator = Evaluator.get("code-correctness-evaluator")
# Use in RLVR training
trainer = RLVRTrainer(
model="codellama/CodeLlama-7b-hf",
custom_reward_function=evaluator, # Use registered evaluator
training_dataset="s3://bucket/coding-tasks.jsonl",
accept_eula=True
)
job = trainer.train()
class HubContentStatus(Enum):
"""
Hub content status enumeration.
Values:
AVAILABLE = "Available"
Content is available for use
- Dataset ready for training
- Evaluator ready for evaluation
IMPORTING = "Importing"
Content is being imported
- Upload in progress
- Processing ongoing
DELETING = "Deleting"
Content is being deleted
IMPORT_FAILED = "ImportFailed"
Import failed
- Check failure reason
- May need to retry
DELETE_FAILED = "DeleteFailed"
Deletion failed
- May need manual cleanup
Usage:
Check dataset/evaluator status.
Wait for AVAILABLE before using.
Notes:
- Terminal states: AVAILABLE, IMPORT_FAILED, DELETE_FAILED
- Transient states: IMPORTING, DELETING
- Always wait for AVAILABLE before training
"""class DataSetMethod(Enum):
"""
Dataset method enumeration.
Values:
UPLOADED = "uploaded"
Dataset was uploaded by user
- From local file or S3
GENERATED = "generated"
Dataset was generated programmatically
- By AI Registry system or custom process
Usage:
Track dataset provenance.
Notes:
- Read-only attribute
- Set automatically during creation
"""class EvaluatorMethod(Enum):
"""
Evaluator method enumeration.
Values:
BYOC = "byoc"
Bring your own code
- Custom code uploaded to Lambda
LAMBDA = "lambda"
Existing AWS Lambda function
- Lambda ARN provided directly
Usage:
Track evaluator implementation method.
Notes:
- BYOC: Code uploaded and Lambda created automatically
- LAMBDA: Existing Lambda function referenced
- Read-only attribute
"""from sagemaker.ai_registry.dataset import DataSet
from sagemaker.ai_registry.evaluator import Evaluator
from sagemaker.ai_registry.air_constants import HubContentStatus
from sagemaker.train.common import CustomizationTechnique
# Step 1: Create training dataset
print("Creating training dataset...")
train_dataset = DataSet.create(
name="chatbot-training-v1",
source="./data/train.jsonl", # Local file
customization_technique=CustomizationTechnique.SFT,
description="Customer support chatbot training data",
tags=[
{"Key": "Project", "Value": "Chatbot"},
{"Key": "Version", "Value": "1.0"}
],
wait=True
)
# Wait for import to complete
if train_dataset.status != HubContentStatus.AVAILABLE:
print(f"Waiting for dataset import...")
train_dataset.wait(poll=5, timeout=300)
print(f"Training dataset ready: {train_dataset.arn}")
print(f" S3 location: {train_dataset.source}")
print(f" Status: {train_dataset.status}")
# Step 2: Create evaluation dataset
print("\nCreating evaluation dataset...")
eval_dataset = DataSet.create(
name="chatbot-eval-v1",
source="s3://my-bucket/data/eval.jsonl", # S3 file
customization_technique=CustomizationTechnique.SFT,
wait=True
)
# Step 3: Create reward evaluator
print("\nCreating reward evaluator...")
evaluator = Evaluator.create(
name="helpfulness-evaluator",
type="RewardFunction",
source="arn:aws:lambda:us-west-2:123456789012:function:evaluate-helpfulness",
wait=True
)
print(f"Evaluator ready: {evaluator.arn}")
# Step 4: List all registry content
print("\n=== AI Registry Contents ===")
print("\nDatasets:")
datasets = DataSet.list(max_results=20)
for ds in datasets['items']:
print(f" - {ds['name']} (v{ds['version']}): {ds['status']}")
print("\nEvaluators:")
evaluators = Evaluator.list(max_results=20)
for ev in evaluators['items']:
print(f" - {ev['name']} ({ev['type']}): {ev['status']}")
# Step 5: Use in training
from sagemaker.train.sft_trainer import SFTTrainer
trainer = SFTTrainer(
model="meta-llama/Llama-2-7b-hf",
training_dataset=train_dataset, # Use DataSet object
validation_dataset=eval_dataset,
accept_eula=True
)
job = trainer.train()
# Step 6: Cleanup (optional)
# train_dataset.delete()
# eval_dataset.delete()
# evaluator.delete()
Dataset Version Management:
# List versions of dataset
dataset = DataSet.get("chatbot-training-v1")
versions = dataset.get_versions()
print(f"Dataset versions for {dataset.name}:")
for version_info in versions:
print(f" Version {version_info['Version']}: {version_info['Status']}")
print(f" Created: {version_info['CreationTime']}")
# Delete specific version
dataset.delete(version="1")
# Delete latest version
dataset.delete()
Dataset Formats:
For SFT (Supervised Fine-Tuning):
{"prompt": "Question or instruction", "completion": "Expected response"}
{"prompt": "Another question", "completion": "Another response"}For DPO (Direct Preference Optimization):
{"prompt": "Question", "chosen": "Better response", "rejected": "Worse response"}For RLAIF/RLVR:
{"prompt": "Task or question"}
{"prompt": "Another task"}Import Failed:
Troubleshooting:
Import Failed:
Dataset Not Found:
Evaluator Lambda Timeout:
Invalid Reward Value:
Source File Too Large:
Concurrent Import Limit: