Provider package that enables OpenAI integration for Apache Airflow, including hooks, operators, and triggers for AI-powered data pipelines.
—
Quality: Pending — a best-practices review has not yet been completed.
Impact: Pending — no eval scenarios have been run.
Airflow triggers provide asynchronous monitoring capabilities for long-running OpenAI operations, enabling efficient resource usage and proper handling of batch processing workflows.
Asynchronously monitor OpenAI Batch API operations with configurable polling intervals and timeout handling.
class OpenAIBatchTrigger(BaseTrigger):
    """
    Triggers OpenAI Batch API monitoring for long-running batch operations.

    Args:
        conn_id (str): The OpenAI connection ID to use
        batch_id (str): The ID of the batch to monitor
        poll_interval (float): Number of seconds between status checks
        end_time (float): Unix timestamp when monitoring should timeout
    """

    def __init__(
        self,
        conn_id: str,
        batch_id: str,
        poll_interval: float,
        end_time: float,
    ) -> None: ...

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """
        Serialize OpenAIBatchTrigger arguments and class path for persistence.

        Returns:
            Tuple of (class_path, serialized_arguments)
        """

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Make connection to OpenAI Client and poll the status of batch.

        Yields:
            TriggerEvent: Events indicating batch status changes or completion

        Events:
            - {"status": "success", "message": "...", "batch_id": "..."}: Batch completed successfully
            - {"status": "cancelled", "message": "...", "batch_id": "..."}: Batch was cancelled
            - {"status": "error", "message": "...", "batch_id": "..."}: Batch failed or timed out
        """

import time
from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger

# Build a trigger that watches a single batch job.
batch_trigger = OpenAIBatchTrigger(
    conn_id='openai_default',
    batch_id='batch_abc123',
    poll_interval=60,  # Check every minute
    end_time=time.time() + 3600,  # Timeout after 1 hour
)

# Airflow persists triggers through serialize(); called directly here
# only to illustrate what gets stored.
trigger_path, trigger_kwargs = batch_trigger.serialize()
print(f"Trigger class: {trigger_path}")
print(f"Trigger args: {trigger_kwargs}")
from datetime import datetime, timedelta
from airflow import DAG
from airflow.operators.python_operator import PythonOperator
from airflow.providers.openai.hooks.openai import OpenAIHook
from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger
import time  # fixed: time.time() is used below but was never imported in this snippet

dag = DAG(
    'deferred_batch_processing',
    start_date=datetime(2024, 1, 1),
    schedule_interval=None,
    catchup=False
)


def start_batch_processing(**context):
    """Start an OpenAI batch and defer to a trigger for monitoring.

    Creates the batch via OpenAIHook, stores its ID in XCom, then defers so
    the triggerer process polls the batch instead of a worker slot.
    """
    hook = OpenAIHook(conn_id='openai_default')
    # Create batch
    batch = hook.create_batch(
        file_id=context['params']['file_id'],
        endpoint="/v1/chat/completions"
    )
    # Store batch ID so downstream tasks can retrieve it
    context['task_instance'].xcom_push(key='batch_id', value=batch.id)
    # Defer to trigger.
    # NOTE(review): ``method_name`` must name a method on the operator class
    # that resumes execution; a module-level function like
    # handle_batch_completion will not be found when the trigger fires —
    # confirm this example against a custom deferrable operator.
    context['task_instance'].defer(
        trigger=OpenAIBatchTrigger(
            conn_id='openai_default',
            batch_id=batch.id,
            poll_interval=120,  # Check every 2 minutes
            end_time=time.time() + 86400  # 24 hour timeout
        ),
        method_name='handle_batch_completion'
    )


def handle_batch_completion(**context):
    """Handle the TriggerEvent emitted when the batch reaches a terminal state.

    Returns:
        The batch ID on success.

    Raises:
        Exception: If the batch was cancelled or failed.
    """
    event = context['event']
    if event['status'] == 'success':
        print(f"Batch {event['batch_id']} completed successfully!")
        return event['batch_id']
    elif event['status'] == 'cancelled':
        print(f"Batch {event['batch_id']} was cancelled: {event['message']}")
        raise Exception(f"Batch cancelled: {event['message']}")
    else:  # error
        print(f"Batch {event['batch_id']} failed: {event['message']}")
        raise Exception(f"Batch failed: {event['message']}")


deferred_batch_task = PythonOperator(
    task_id='deferred_batch_processing',
    python_callable=start_batch_processing,
    params={'file_id': 'file-xyz789'},
    dag=dag
)
import asyncio
import time
from collections.abc import AsyncIterator
from typing import Optional  # added for the progress_callback annotation fix
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.providers.openai.hooks.openai import OpenAIHook, BatchStatus


class CustomOpenAIBatchTrigger(BaseTrigger):
    """Extended batch trigger with custom monitoring logic.

    In addition to the terminal success/cancelled/error events, emits a
    "progress" event every time the batch status changes.

    Args:
        conn_id: Airflow connection ID for OpenAI.
        batch_id: ID of the batch to monitor.
        poll_interval: Seconds to sleep between status checks.
        end_time: Unix timestamp after which monitoring errors out.
        progress_callback: Optional callback identifier carried through
            serialization (not invoked by this trigger itself).
    """

    def __init__(
        self,
        conn_id: str,
        batch_id: str,
        poll_interval: float,
        end_time: float,
        progress_callback: Optional[str] = None  # fixed: was annotated ``str`` with a None default
    ):
        super().__init__()
        self.conn_id = conn_id
        self.batch_id = batch_id
        self.poll_interval = poll_interval
        self.end_time = end_time
        self.progress_callback = progress_callback

    def serialize(self) -> tuple[str, dict]:
        """Return (class_path, kwargs) so Airflow can recreate the trigger."""
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}",
            {
                "conn_id": self.conn_id,
                "batch_id": self.batch_id,
                "poll_interval": self.poll_interval,
                "end_time": self.end_time,
                "progress_callback": self.progress_callback
            }
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Enhanced monitoring with progress tracking.

        Yields:
            TriggerEvent: a "progress" event on each status change, then
            exactly one terminal "success"/"cancelled"/"error" event.
        """
        hook = OpenAIHook(conn_id=self.conn_id)
        last_status = None
        try:
            while True:
                current_time = time.time()
                # Stop if the deadline has passed.
                if current_time >= self.end_time:
                    yield TriggerEvent({
                        "status": "error",
                        # fixed: the original reported the overshoot past the
                        # deadline as if it were the total monitoring time
                        "message": f"Batch {self.batch_id} monitoring timed out (deadline exceeded by {current_time - self.end_time:.1f} seconds)",
                        "batch_id": self.batch_id
                    })
                    return
                # NOTE(review): get_batch is a blocking call inside the
                # triggerer's event loop — confirm no async variant exists.
                batch = hook.get_batch(self.batch_id)
                # Emit progress events for status changes
                if batch.status != last_status:
                    yield TriggerEvent({
                        "status": "progress",
                        "message": f"Batch status changed from {last_status} to {batch.status}",
                        "batch_id": self.batch_id,
                        "batch_status": batch.status
                    })
                    last_status = batch.status
                # Terminal state: emit exactly one final event and stop.
                if not BatchStatus.is_in_progress(batch.status):
                    if batch.status == BatchStatus.COMPLETED:
                        yield TriggerEvent({
                            "status": "success",
                            "message": f"Batch {self.batch_id} completed successfully",
                            "batch_id": self.batch_id,
                            "final_status": batch.status
                        })
                    elif batch.status in {BatchStatus.CANCELLED, BatchStatus.CANCELLING}:
                        yield TriggerEvent({
                            "status": "cancelled",
                            "message": f"Batch {self.batch_id} was cancelled",
                            "batch_id": self.batch_id,
                            "final_status": batch.status
                        })
                    else:  # FAILED, EXPIRED, or other error states
                        yield TriggerEvent({
                            "status": "error",
                            "message": f"Batch {self.batch_id} failed with status: {batch.status}",
                            "batch_id": self.batch_id,
                            "final_status": batch.status
                        })
                    return
                # Wait before next check
                await asyncio.sleep(self.poll_interval)
        except Exception as e:
            yield TriggerEvent({
                "status": "error",
                "message": f"Trigger error: {str(e)}",
                "batch_id": self.batch_id
            })
import asyncio
from typing import Dict, List
import time  # fixed: time.time() is used below but was never imported in this snippet
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.providers.openai.hooks.openai import OpenAIHook, BatchStatus


class MultiBatchTrigger(BaseTrigger):
    """Monitor multiple OpenAI batches simultaneously.

    Emits one "batch_completed"/"batch_failed" event per batch as it
    finishes, then a single terminal "success"/"partial_success" event
    (or "timeout"/"error").

    Args:
        conn_id: Airflow connection ID for OpenAI.
        batch_ids: IDs of the batches to monitor.
        poll_interval: Seconds to sleep between polling rounds.
        end_time: Unix timestamp after which monitoring times out.
    """

    def __init__(
        self,
        conn_id: str,
        batch_ids: List[str],
        poll_interval: float,
        end_time: float
    ):
        super().__init__()
        self.conn_id = conn_id
        self.batch_ids = batch_ids
        self.poll_interval = poll_interval
        self.end_time = end_time

    def serialize(self) -> tuple[str, dict]:
        """Return (class_path, kwargs) so Airflow can recreate the trigger."""
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}",
            {
                "conn_id": self.conn_id,
                "batch_ids": self.batch_ids,
                "poll_interval": self.poll_interval,
                "end_time": self.end_time
            }
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Monitor multiple batches until all complete."""
        hook = OpenAIHook(conn_id=self.conn_id)
        completed_batches = set()
        failed_batches = set()
        try:
            while len(completed_batches) + len(failed_batches) < len(self.batch_ids):
                current_time = time.time()
                # Check timeout
                if current_time >= self.end_time:
                    remaining = set(self.batch_ids) - completed_batches - failed_batches
                    yield TriggerEvent({
                        "status": "timeout",
                        "message": f"Timeout reached. Remaining batches: {list(remaining)}",
                        "completed_batches": list(completed_batches),
                        "failed_batches": list(failed_batches),
                        "remaining_batches": list(remaining)
                    })
                    return
                # Check each batch still in flight
                for batch_id in self.batch_ids:
                    if batch_id in completed_batches or batch_id in failed_batches:
                        continue
                    batch = hook.get_batch(batch_id)
                    if not BatchStatus.is_in_progress(batch.status):
                        if batch.status == BatchStatus.COMPLETED:
                            completed_batches.add(batch_id)
                            yield TriggerEvent({
                                "status": "batch_completed",
                                "message": f"Batch {batch_id} completed",
                                "batch_id": batch_id,
                                "completed_count": len(completed_batches),
                                "total_count": len(self.batch_ids)
                            })
                        else:
                            failed_batches.add(batch_id)
                            yield TriggerEvent({
                                "status": "batch_failed",
                                "message": f"Batch {batch_id} failed with status: {batch.status}",
                                "batch_id": batch_id,
                                "batch_status": batch.status,
                                "failed_count": len(failed_batches),
                                "total_count": len(self.batch_ids)
                            })
                # fixed: skip the final sleep once every batch has finished, so
                # the terminal event is not delayed by up to poll_interval
                if len(completed_batches) + len(failed_batches) >= len(self.batch_ids):
                    break
                await asyncio.sleep(self.poll_interval)
            # All batches reached a terminal state
            if failed_batches:
                yield TriggerEvent({
                    "status": "partial_success",
                    "message": f"Processing complete. {len(completed_batches)} succeeded, {len(failed_batches)} failed",
                    "completed_batches": list(completed_batches),
                    "failed_batches": list(failed_batches)
                })
            else:
                yield TriggerEvent({
                    "status": "success",
                    "message": f"All {len(completed_batches)} batches completed successfully",
                    "completed_batches": list(completed_batches)
                })
        except Exception as e:
            yield TriggerEvent({
                "status": "error",
                "message": f"Multi-batch trigger error: {str(e)}",
                "batch_ids": self.batch_ids
            })


# Usage example
multi_batch_trigger = MultiBatchTrigger(
    conn_id='openai_default',
    batch_ids=['batch_1', 'batch_2', 'batch_3'],
    poll_interval=60,
    end_time=time.time() + 7200  # 2 hours
)
from airflow.sensors.base import BaseSensorOperator
from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger
class OpenAIBatchSensor(BaseSensorOperator):
"""Sensor that waits for OpenAI batch completion."""
def __init__(
self,
batch_id: str,
conn_id: str = 'openai_default',
poll_interval: float = 60,
**kwargs
):
super().__init__(**kwargs)
self.batch_id = batch_id
self.conn_id = conn_id
self.poll_interval = poll_interval
def poke(self, context) -> bool:
"""Check if batch is complete."""
from airflow.providers.openai.hooks.openai import OpenAIHook, BatchStatus
hook = OpenAIHook(conn_id=self.conn_id)
batch = hook.get_batch(self.batch_id)
if batch.status == BatchStatus.COMPLETED:
return True
elif batch.status in {BatchStatus.FAILED, BatchStatus.EXPIRED, BatchStatus.CANCELLED}:
raise Exception(f"Batch {self.batch_id} failed with status: {batch.status}")
return False
# Use the sensor in a DAG
batch_sensor = OpenAIBatchSensor(
task_id='wait_for_batch',
batch_id="{{ task_instance.xcom_pull(task_ids='create_batch') }}",
conn_id='openai_default',
poll_interval=30,
timeout=3600,
dag=dag
)Install with Tessl CLI
npx tessl i tessl/pypi-apache-airflow-providers-openai