CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-apache-airflow-providers-openai

Provider package that enables OpenAI integration for Apache Airflow, including hooks, operators, and triggers for AI-powered data pipelines.

Pending

Quality

Pending

Does it follow best practices?

Impact

Pending

No eval scenarios have been run

Overview
Eval results
Files

docs/triggers.md

Triggers

Airflow triggers provide asynchronous monitoring capabilities for long-running OpenAI operations, enabling efficient resource usage and proper handling of batch processing workflows.

Capabilities

Batch Processing Trigger

Asynchronously monitor OpenAI Batch API operations with configurable polling intervals and timeout handling.

class OpenAIBatchTrigger(BaseTrigger):
    """
    Triggers OpenAI Batch API monitoring for long-running batch operations.

    The trigger polls the batch status every ``poll_interval`` seconds until
    the batch reaches a terminal state or ``end_time`` is exceeded.

    Args:
        conn_id (str): The OpenAI connection ID to use
        batch_id (str): The ID of the batch to monitor
        poll_interval (float): Number of seconds between status checks
        end_time (float): Unix timestamp (seconds since epoch) after which
            monitoring times out and an error event is emitted
    """

    def __init__(
        self,
        conn_id: str,
        batch_id: str,
        poll_interval: float,
        end_time: float,
    ) -> None: ...

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """
        Serialize OpenAIBatchTrigger arguments and class path for persistence.

        Airflow stores the returned tuple so the trigger can be re-created
        inside the triggerer process.

        Returns:
            Tuple of (class_path, serialized_arguments)
        """

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Make connection to OpenAI Client and poll the status of batch.

        Yields:
            TriggerEvent: Events indicating batch status changes or completion

        Events:
            - {"status": "success", "message": "...", "batch_id": "..."}: Batch completed successfully
            - {"status": "cancelled", "message": "...", "batch_id": "..."}: Batch was cancelled
            - {"status": "error", "message": "...", "batch_id": "..."}: Batch failed or timed out
        """

Usage Examples

Direct Trigger Usage

import time
from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger

# Build a trigger that watches a single batch job.
batch_trigger = OpenAIBatchTrigger(
    conn_id='openai_default',
    batch_id='batch_abc123',
    poll_interval=60,  # poll once per minute
    end_time=time.time() + 3600  # give up after one hour
)

# Airflow serializes triggers automatically; shown here for illustration.
class_path, args = batch_trigger.serialize()
print(f"Trigger class: {class_path}")
print(f"Trigger args: {args}")

Integration with Deferrable Operator

import time  # fix: time.time() is used below but `time` was never imported
from datetime import datetime, timedelta
from airflow import DAG
# fix: `airflow.operators.python_operator` is the deprecated/removed path
from airflow.operators.python import PythonOperator
from airflow.providers.openai.hooks.openai import OpenAIHook
from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger

dag = DAG(
    'deferred_batch_processing',
    start_date=datetime(2024, 1, 1),
    schedule_interval=None,
    catchup=False
)

def start_batch_processing(**context):
    """Start an OpenAI batch and defer to a trigger for asynchronous monitoring."""
    hook = OpenAIHook(conn_id='openai_default')

    # Submit the batch job to the OpenAI Batch API.
    batch = hook.create_batch(
        file_id=context['params']['file_id'],
        endpoint="/v1/chat/completions"
    )

    # Store batch ID so downstream tasks can retrieve it.
    context['task_instance'].xcom_push(key='batch_id', value=batch.id)

    # Defer to the trigger; execution resumes via `method_name` when the
    # trigger fires an event.
    # NOTE(review): deferral is normally done with `self.defer(...)` inside a
    # deferrable operator, not on the TaskInstance -- confirm this call
    # against the Airflow version in use.
    context['task_instance'].defer(
        trigger=OpenAIBatchTrigger(
            conn_id='openai_default',
            batch_id=batch.id,
            poll_interval=120,  # Check every 2 minutes
            end_time=time.time() + 86400  # 24 hour timeout
        ),
        method_name='handle_batch_completion'
    )

def handle_batch_completion(**context):
    """Handle the TriggerEvent emitted when batch monitoring finishes.

    Raises:
        Exception: If the batch was cancelled, failed, or timed out.
    """
    event = context['event']

    if event['status'] == 'success':
        print(f"Batch {event['batch_id']} completed successfully!")
        return event['batch_id']
    elif event['status'] == 'cancelled':
        print(f"Batch {event['batch_id']} was cancelled: {event['message']}")
        raise Exception(f"Batch cancelled: {event['message']}")
    else:  # error
        print(f"Batch {event['batch_id']} failed: {event['message']}")
        raise Exception(f"Batch failed: {event['message']}")

deferred_batch_task = PythonOperator(
    task_id='deferred_batch_processing',
    python_callable=start_batch_processing,
    params={'file_id': 'file-xyz789'},
    dag=dag
)

Custom Trigger Implementation

import asyncio
import time
from collections.abc import AsyncIterator
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.providers.openai.hooks.openai import OpenAIHook, BatchStatus

class CustomOpenAIBatchTrigger(BaseTrigger):
    """Extended batch trigger with custom monitoring logic.

    In addition to the terminal success/cancelled/error events, this trigger
    emits a "progress" event whenever the batch status changes.

    Args:
        conn_id: Airflow connection ID for OpenAI.
        batch_id: ID of the batch to monitor.
        poll_interval: Seconds to sleep between status checks.
        end_time: Unix timestamp after which monitoring aborts with an error.
        progress_callback: Optional identifier serialized with the trigger;
            not invoked by this class itself.
    """

    def __init__(
        self,
        conn_id: str,
        batch_id: str,
        poll_interval: float,
        end_time: float,
        progress_callback: str | None = None  # fix: was `str = None`
    ):
        super().__init__()
        self.conn_id = conn_id
        self.batch_id = batch_id
        self.poll_interval = poll_interval
        self.end_time = end_time
        self.progress_callback = progress_callback

    def serialize(self) -> tuple[str, dict]:
        """Return (class path, constructor kwargs) for trigger persistence."""
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}",
            {
                "conn_id": self.conn_id,
                "batch_id": self.batch_id,
                "poll_interval": self.poll_interval,
                "end_time": self.end_time,
                "progress_callback": self.progress_callback
            }
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Poll the batch until it completes, is cancelled, fails, or times out."""
        hook = OpenAIHook(conn_id=self.conn_id)
        last_status = None

        try:
            while True:
                current_time = time.time()

                # Abort once the deadline has passed.
                if current_time >= self.end_time:
                    yield TriggerEvent({
                        "status": "error",
                        # fix: old message reported the overshoot
                        # (current_time - end_time) as the monitoring duration
                        "message": f"Batch {self.batch_id} monitoring timed out: deadline {self.end_time} passed at {current_time}",
                        "batch_id": self.batch_id
                    })
                    return

                # NOTE(review): get_batch is a synchronous call inside an
                # async trigger and blocks the triggerer event loop while it
                # runs -- confirm no async variant is available.
                batch = hook.get_batch(self.batch_id)

                # Emit progress events for status changes
                if batch.status != last_status:
                    yield TriggerEvent({
                        "status": "progress",
                        "message": f"Batch status changed from {last_status} to {batch.status}",
                        "batch_id": self.batch_id,
                        "batch_status": batch.status
                    })
                    last_status = batch.status

                # Terminal state reached: emit the final event and stop.
                if not BatchStatus.is_in_progress(batch.status):
                    if batch.status == BatchStatus.COMPLETED:
                        yield TriggerEvent({
                            "status": "success",
                            "message": f"Batch {self.batch_id} completed successfully",
                            "batch_id": self.batch_id,
                            "final_status": batch.status
                        })
                    elif batch.status in {BatchStatus.CANCELLED, BatchStatus.CANCELLING}:
                        yield TriggerEvent({
                            "status": "cancelled",
                            "message": f"Batch {self.batch_id} was cancelled",
                            "batch_id": self.batch_id,
                            "final_status": batch.status
                        })
                    else:  # FAILED, EXPIRED, or other error states
                        yield TriggerEvent({
                            "status": "error",
                            "message": f"Batch {self.batch_id} failed with status: {batch.status}",
                            "batch_id": self.batch_id,
                            "final_status": batch.status
                        })
                    return

                # Wait before next check
                await asyncio.sleep(self.poll_interval)

        except Exception as e:
            # Surface any unexpected failure as an error event rather than
            # letting the trigger die silently.
            yield TriggerEvent({
                "status": "error",
                "message": f"Trigger error: {str(e)}",
                "batch_id": self.batch_id
            })

Monitoring Multiple Batches

import asyncio
from typing import Dict, List
from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.providers.openai.hooks.openai import OpenAIHook, BatchStatus

class MultiBatchTrigger(BaseTrigger):
    """Monitor multiple OpenAI batches simultaneously.

    Emits a per-batch event as each batch reaches a terminal state, then a
    final summary event ("success", "partial_success", or "timeout").

    Args:
        conn_id: Airflow connection ID for OpenAI.
        batch_ids: IDs of the batches to monitor.
        poll_interval: Seconds to sleep between polling rounds.
        end_time: Unix timestamp after which monitoring stops with a
            "timeout" event listing the unfinished batches.
    """

    def __init__(
        self,
        conn_id: str,
        batch_ids: List[str],
        poll_interval: float,
        end_time: float
    ):
        super().__init__()
        self.conn_id = conn_id
        self.batch_ids = batch_ids
        self.poll_interval = poll_interval
        self.end_time = end_time

    def serialize(self) -> tuple[str, dict]:
        """Return (class path, constructor kwargs) for trigger persistence."""
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}",
            {
                "conn_id": self.conn_id,
                "batch_ids": self.batch_ids,
                "poll_interval": self.poll_interval,
                "end_time": self.end_time
            }
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Monitor multiple batches until all complete."""
        import time  # fix: `time` was not imported at the top of this example

        hook = OpenAIHook(conn_id=self.conn_id)
        completed_batches = set()
        failed_batches = set()

        try:
            # Loop until every batch is in a terminal state.
            while len(completed_batches) + len(failed_batches) < len(self.batch_ids):
                current_time = time.time()

                # Check timeout
                if current_time >= self.end_time:
                    remaining = set(self.batch_ids) - completed_batches - failed_batches
                    yield TriggerEvent({
                        "status": "timeout",
                        "message": f"Timeout reached. Remaining batches: {list(remaining)}",
                        "completed_batches": list(completed_batches),
                        "failed_batches": list(failed_batches),
                        "remaining_batches": list(remaining)
                    })
                    return

                # Poll only batches not yet in a terminal state.
                for batch_id in self.batch_ids:
                    if batch_id in completed_batches or batch_id in failed_batches:
                        continue

                    # NOTE(review): synchronous call inside an async trigger;
                    # blocks the event loop while polling -- confirm no async
                    # variant exists.
                    batch = hook.get_batch(batch_id)

                    if not BatchStatus.is_in_progress(batch.status):
                        if batch.status == BatchStatus.COMPLETED:
                            completed_batches.add(batch_id)
                            yield TriggerEvent({
                                "status": "batch_completed",
                                "message": f"Batch {batch_id} completed",
                                "batch_id": batch_id,
                                "completed_count": len(completed_batches),
                                "total_count": len(self.batch_ids)
                            })
                        else:
                            failed_batches.add(batch_id)
                            yield TriggerEvent({
                                "status": "batch_failed",
                                "message": f"Batch {batch_id} failed with status: {batch.status}",
                                "batch_id": batch_id,
                                "batch_status": batch.status,
                                "failed_count": len(failed_batches),
                                "total_count": len(self.batch_ids)
                            })

                await asyncio.sleep(self.poll_interval)

            # All batches reached a terminal state: emit the summary.
            if failed_batches:
                yield TriggerEvent({
                    "status": "partial_success",
                    "message": f"Processing complete. {len(completed_batches)} succeeded, {len(failed_batches)} failed",
                    "completed_batches": list(completed_batches),
                    "failed_batches": list(failed_batches)
                })
            else:
                yield TriggerEvent({
                    "status": "success",
                    "message": f"All {len(completed_batches)} batches completed successfully",
                    "completed_batches": list(completed_batches)
                })

        except Exception as e:
            # Surface unexpected failures as an error event.
            yield TriggerEvent({
                "status": "error",
                "message": f"Multi-batch trigger error: {str(e)}",
                "batch_ids": self.batch_ids
            })

# Usage example
import time  # fix: required for end_time below; not imported in this snippet

multi_batch_trigger = MultiBatchTrigger(
    conn_id='openai_default',
    batch_ids=['batch_1', 'batch_2', 'batch_3'],
    poll_interval=60,
    end_time=time.time() + 7200  # 2 hours
)

Integration with Airflow Sensors

from airflow.sensors.base import BaseSensorOperator
from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger

class OpenAIBatchSensor(BaseSensorOperator):
    """Sensor that waits for OpenAI batch completion.

    Args:
        batch_id: ID of the batch to wait for.
        conn_id: Airflow connection ID for OpenAI.
        poll_interval: Polling interval in seconds, stored on the sensor.
    """

    def __init__(
        self,
        batch_id: str,
        conn_id: str = 'openai_default',
        poll_interval: float = 60,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.batch_id = batch_id
        self.conn_id = conn_id
        self.poll_interval = poll_interval

    def poke(self, context) -> bool:
        """Return True once the batch completes; raise on terminal failure."""
        from airflow.providers.openai.hooks.openai import OpenAIHook, BatchStatus

        batch = OpenAIHook(conn_id=self.conn_id).get_batch(self.batch_id)
        status = batch.status

        if status == BatchStatus.COMPLETED:
            return True
        if status in {BatchStatus.FAILED, BatchStatus.EXPIRED, BatchStatus.CANCELLED}:
            raise Exception(f"Batch {self.batch_id} failed with status: {batch.status}")
        return False

# Wire the sensor into a DAG
wait_for_batch = OpenAIBatchSensor(
    task_id='wait_for_batch',
    conn_id='openai_default',
    batch_id="{{ task_instance.xcom_pull(task_ids='create_batch') }}",
    poll_interval=30,
    timeout=3600,
    dag=dag
)

Install with Tessl CLI

npx tessl i tessl/pypi-apache-airflow-providers-openai

docs

exceptions.md

hooks.md

index.md

operators.md

triggers.md

version_compat.md

tile.json