Provider package that enables OpenAI integration for Apache Airflow, including hooks, operators, and triggers for AI-powered data pipelines.
Quality: Pending — best-practices review has not yet been completed.
Impact: Pending — no eval scenarios have been run.
Airflow operators provide task-level abstractions for OpenAI operations, integrating seamlessly with DAG workflows and providing proper task lifecycle management, templating, and error handling.
Generate OpenAI embeddings as part of an Airflow DAG task, with support for text templating and various input formats.
class OpenAIEmbeddingOperator(BaseOperator):
    """
    Operator that accepts input text to generate OpenAI embeddings using the specified model.

    Args:
        conn_id (str): The OpenAI connection ID to use
        input_text (str | list[str] | list[int] | list[list[int]]): The text to generate embeddings for
        model (str): The OpenAI model to use for generating embeddings, defaults to "text-embedding-ada-002"
        embedding_kwargs (dict, optional): Additional keyword arguments for the create_embeddings method
        **kwargs: Additional BaseOperator arguments
    """

    # "input_text" is rendered by Airflow's Jinja templating before execute() runs.
    template_fields: Sequence[str] = ("input_text",)

    def __init__(
        self,
        conn_id: str,
        input_text: str | list[str] | list[int] | list[list[int]],
        model: str = "text-embedding-ada-002",
        embedding_kwargs: dict | None = None,
        **kwargs: Any,
    ): ...

    @cached_property
    def hook(self) -> OpenAIHook:
        """Return an instance of the OpenAIHook."""

    def execute(self, context: Context) -> list[float]:
        """
        Execute the embedding generation task.

        Args:
            context: Airflow task context

        Returns:
            List of embedding values (floats)

        Raises:
            ValueError: If input_text is empty or invalid format
        """

# Trigger OpenAI Batch API operations with support for both synchronous and
# asynchronous (deferrable) execution modes.
class OpenAITriggerBatchOperator(BaseOperator):
    """
    Operator that triggers an OpenAI Batch API endpoint and waits for the batch to complete.

    Args:
        file_id (str): The ID of the batch file to trigger
        endpoint (Literal): The OpenAI Batch API endpoint ("/v1/chat/completions", "/v1/embeddings", "/v1/completions")
        conn_id (str): The OpenAI connection ID, defaults to 'openai_default'
        deferrable (bool): Run operator in deferrable mode, defaults to system configuration setting
        wait_seconds (float): Number of seconds between checks when not deferrable, defaults to 3
        timeout (float): Time to wait for completion in seconds, defaults to 24 hours
        wait_for_completion (bool): Whether to wait for batch completion, defaults to True
        **kwargs: Additional BaseOperator arguments
    """

    # "file_id" is rendered by Airflow's Jinja templating before execute() runs.
    template_fields: Sequence[str] = ("file_id",)
    batch_id: str | None = None  # Set during execution with the created batch ID

    def __init__(
        self,
        file_id: str,
        endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
        conn_id: str = OpenAIHook.default_conn_name,
        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
        wait_seconds: float = 3,
        timeout: float = 24 * 60 * 60,
        wait_for_completion: bool = True,
        **kwargs: Any,
    ): ...

    @cached_property
    def hook(self) -> OpenAIHook:
        """Return an instance of the OpenAIHook."""

    def execute(self, context: Context) -> str | None:
        """
        Execute the batch operation.

        Args:
            context: Airflow task context

        Returns:
            Batch ID if successful, None if not waiting for completion
        """

    def execute_complete(self, context: Context, event: Any = None) -> str:
        """
        Callback for deferrable execution completion.

        Args:
            context: Airflow task context
            event: Event data from trigger

        Returns:
            Batch ID

        Raises:
            OpenAIBatchJobException: If batch processing failed
        """

    def on_kill(self) -> None:
        """Cancel the batch if task is cancelled."""

from datetime import datetime
from airflow import DAG
from airflow.providers.openai.operators.openai import OpenAIEmbeddingOperator

# Minimal DAG for the embedding examples below; no schedule (manual trigger only).
dag = DAG(
    'embedding_example',
    start_date=datetime(2024, 1, 1),
    schedule_interval=None,
    catchup=False
)

# Simple text embedding
embedding_task = OpenAIEmbeddingOperator(
    task_id='generate_embeddings',
    conn_id='openai_default',
    input_text="This is sample text for embedding generation",
    model="text-embedding-ada-002",
    dag=dag
)

# Multiple texts with custom parameters
batch_embedding_task = OpenAIEmbeddingOperator(
    task_id='batch_embeddings',
    conn_id='openai_default',
    input_text=[
        "First document to embed",
        "Second document to embed",
        "Third document to embed"
    ],
    model="text-embedding-3-large",
    embedding_kwargs={
        "dimensions": 1024,
        "encoding_format": "float"
    },
    dag=dag
)

# Using Airflow templating for dynamic input
templated_embedding_task = OpenAIEmbeddingOperator(
    task_id='templated_embeddings',
    conn_id='openai_default',
    input_text="{{ dag_run.conf.get('text_content', 'Default text') }}",
    model="text-embedding-ada-002",
    dag=dag
)

from airflow.providers.openai.operators.openai import OpenAITriggerBatchOperator
# Synchronous batch processing: the worker polls every `wait_seconds` until done.
sync_batch_task = OpenAITriggerBatchOperator(
    task_id='process_batch_sync',
    file_id="{{ task_instance.xcom_pull(task_ids='upload_batch_file') }}",
    endpoint="/v1/chat/completions",
    conn_id='openai_default',
    deferrable=False,
    wait_seconds=5,
    timeout=7200,  # 2 hours
    dag=dag
)

# Asynchronous (deferrable) batch processing: frees the worker slot while waiting.
async_batch_task = OpenAITriggerBatchOperator(
    task_id='process_batch_async',
    file_id="file-abc123",
    endpoint="/v1/embeddings",
    conn_id='openai_default',
    deferrable=True,
    timeout=86400,  # 24 hours
    dag=dag
)

# Trigger batch without waiting for completion (fire-and-forget).
fire_and_forget_batch = OpenAITriggerBatchOperator(
    task_id='trigger_batch_only',
    file_id="file-def456",
    endpoint="/v1/completions",
    conn_id='openai_default',
    wait_for_completion=False,
    dag=dag
)

from datetime import datetime, timedelta
from datetime import datetime, timedelta

from airflow import DAG
# fix: 'airflow.operators.python_operator' is the removed Airflow 1.x import path;
# Airflow 2.x (required by the OpenAI provider) exposes it as 'airflow.operators.python'.
from airflow.operators.python import PythonOperator
from airflow.providers.openai.hooks.openai import OpenAIHook
from airflow.providers.openai.operators.openai import (
    OpenAIEmbeddingOperator,
    OpenAITriggerBatchOperator
)

# Defaults applied to every task in the DAG.
default_args = {
    'owner': 'data-team',
    'depends_on_past': False,
    'start_date': datetime(2024, 1, 1),
    'email_on_failure': False,
    'email_on_retry': False,
    'retries': 1,
    'retry_delay': timedelta(minutes=5)
}

dag = DAG(
    'openai_processing_pipeline',
    default_args=default_args,
    description='Process data using OpenAI services',
    schedule_interval=timedelta(days=1),  # run once per day
    catchup=False
)
def upload_batch_file(**context):
    """Build a JSONL batch file of chat-completion requests and upload it to OpenAI.

    Returns:
        The OpenAI file ID of the uploaded batch file (pushed to XCom by the
        PythonOperator return-value mechanism).
    """
    import json
    import os
    import tempfile

    hook = OpenAIHook(conn_id='openai_default')

    # One request object per line, following the OpenAI Batch API JSONL format.
    batch_requests = [
        {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-3.5-turbo",
                "messages": [{"role": "user", "content": f"Process item {i}"}],
                "max_tokens": 100
            }
        }
        for i in range(10)
    ]

    # Write to a temporary file; delete=False so the file survives the `with`
    # block long enough to be uploaded.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.jsonl', delete=False) as f:
        for request in batch_requests:
            f.write(json.dumps(request) + '\n')
        temp_file = f.name

    try:
        # Upload file for batch processing.
        file_obj = hook.upload_file(temp_file, purpose="batch")
    finally:
        # fix: the original leaked the temp file (delete=False, never removed).
        os.unlink(temp_file)

    return file_obj.id
# Task to upload batch file
upload_task = PythonOperator(
    task_id='upload_batch_file',
    python_callable=upload_batch_file,
    dag=dag
)

# Generate embeddings for input data
embedding_task = OpenAIEmbeddingOperator(
    task_id='generate_embeddings',
    conn_id='openai_default',
    input_text="{{ dag_run.conf.get('input_texts', ['Default text']) }}",
    model="text-embedding-ada-002",
    dag=dag
)

# Process batch requests
batch_task = OpenAITriggerBatchOperator(
    task_id='process_chat_batch',
    file_id="{{ task_instance.xcom_pull(task_ids='upload_batch_file') }}",
    endpoint="/v1/chat/completions",
    conn_id='openai_default',
    deferrable=True,
    dag=dag
)

# Set task dependencies: the batch waits for the upload; the embedding task
# has no upstream dependencies and runs independently.
upload_task >> batch_task
embedding_task  # standalone reference — no dependency edges

from airflow.providers.openai.exceptions import OpenAIBatchJobException, OpenAIBatchTimeout
def handle_batch_with_retry(**context):
    """Handle batch processing with custom retry logic.

    Runs OpenAITriggerBatchOperator inline so that batch timeouts and job
    failures can be intercepted before re-raising to Airflow.
    """
    from airflow.providers.openai.operators.openai import OpenAITriggerBatchOperator

    # Construct outside the try block: only execute() raises the batch
    # exceptions we handle, so keep the try body minimal.
    operator = OpenAITriggerBatchOperator(
        task_id='batch_with_handling',
        file_id=context['params']['file_id'],
        endpoint="/v1/chat/completions",
        conn_id='openai_default',
        timeout=1800  # 30 minutes
    )
    try:
        return operator.execute(context)
    except OpenAIBatchTimeout as e:
        print(f"Batch timed out: {e}")
        # Implement custom timeout handling
        raise
    except OpenAIBatchJobException as e:
        print(f"Batch job failed: {e}")
        # Implement custom failure handling
        raise
# Wrap the handler in a PythonOperator so failures surface as task failures.
error_handling_task = PythonOperator(
    task_id='batch_with_error_handling',
    python_callable=handle_batch_with_retry,
    params={'file_id': 'file-123'},
    dag=dag
)

# Install with Tessl CLI
npx tessl i tessl/pypi-apache-airflow-providers-openai