# Triggers

Airflow triggers provide asynchronous monitoring for long-running OpenAI operations: while a batch job runs, the worker slot is released, and the task resumes only when the batch reaches a terminal state.

## Capabilities

### Batch Processing Trigger

Asynchronously monitor OpenAI Batch API operations with a configurable polling interval and timeout handling.

```python { .api }
class OpenAIBatchTrigger(BaseTrigger):
    """
    Triggers OpenAI Batch API monitoring for long-running batch operations.

    Args:
        conn_id (str): The OpenAI connection ID to use
        batch_id (str): The ID of the batch to monitor
        poll_interval (float): Number of seconds between status checks
        end_time (float): Unix timestamp after which monitoring times out
    """

    def __init__(
        self,
        conn_id: str,
        batch_id: str,
        poll_interval: float,
        end_time: float,
    ) -> None: ...

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """
        Serialize OpenAIBatchTrigger arguments and class path for persistence.

        Returns:
            Tuple of (class_path, serialized_arguments)
        """

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Connect to the OpenAI client and poll the batch status until it is terminal.

        Yields:
            TriggerEvent: Events indicating batch status changes or completion

        Events:
            - {"status": "success", "message": "...", "batch_id": "..."}: Batch completed successfully
            - {"status": "cancelled", "message": "...", "batch_id": "..."}: Batch was cancelled
            - {"status": "error", "message": "...", "batch_id": "..."}: Batch failed or timed out
        """
```
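
The trigger resolves `conn_id` through the provider's hook, so the connection must exist before the triggerer starts polling. As a minimal sketch, assuming the standard `AIRFLOW_CONN_<CONN_ID>` environment-variable convention and a placeholder key:

```python
import os

# Airflow resolves connections from AIRFLOW_CONN_<CONN_ID> environment
# variables; 'openai_default' maps to AIRFLOW_CONN_OPENAI_DEFAULT.
# 'sk-your-api-key' is a placeholder, not a real credential.
os.environ['AIRFLOW_CONN_OPENAI_DEFAULT'] = 'openai://:sk-your-api-key@'
```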

## Usage Examples

### Direct Trigger Usage

```python
import time

from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger

# Create trigger for batch monitoring
trigger = OpenAIBatchTrigger(
    conn_id='openai_default',
    batch_id='batch_abc123',
    poll_interval=60,  # Check every minute
    end_time=time.time() + 3600,  # Timeout after 1 hour
)

# Serialize for storage (handled automatically by Airflow)
class_path, args = trigger.serialize()
print(f"Trigger class: {class_path}")
print(f"Trigger args: {args}")
```
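
For a quick check outside the triggerer process, you can drive the trigger's async `run()` loop yourself. This sketch reuses the `trigger` object from the example above and stops at the first emitted event; in production, Airflow's triggerer runs this loop for you.

```python
import asyncio


async def watch_first_event():
    # Consume events from the trigger's async generator; the first event
    # is the one a deferred task would be resumed with.
    async for event in trigger.run():
        print(event.payload)
        break


asyncio.run(watch_first_event())
```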

### Integration with Deferrable Operator

```python
import time
from datetime import datetime

from airflow import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.providers.openai.hooks.openai import OpenAIHook
from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger


class StartBatchOperator(BaseOperator):
    """Start an OpenAI batch, then defer to the trigger for monitoring."""

    def __init__(self, file_id: str, conn_id: str = 'openai_default', **kwargs):
        super().__init__(**kwargs)
        self.file_id = file_id
        self.conn_id = conn_id

    def execute(self, context):
        hook = OpenAIHook(conn_id=self.conn_id)

        # Create batch
        batch = hook.create_batch(
            file_id=self.file_id,
            endpoint='/v1/chat/completions',
        )

        # Store batch ID for downstream tasks
        context['task_instance'].xcom_push(key='batch_id', value=batch.id)

        # Release the worker slot and hand monitoring to the triggerer;
        # the task resumes in handle_batch_completion when an event fires
        self.defer(
            trigger=OpenAIBatchTrigger(
                conn_id=self.conn_id,
                batch_id=batch.id,
                poll_interval=120,  # Check every 2 minutes
                end_time=time.time() + 86400,  # 24 hour timeout
            ),
            method_name='handle_batch_completion',
        )

    def handle_batch_completion(self, context, event=None):
        """Handle the completion event emitted by the trigger."""
        if event['status'] == 'success':
            self.log.info("Batch %s completed successfully!", event['batch_id'])
            return event['batch_id']
        if event['status'] == 'cancelled':
            raise Exception(f"Batch cancelled: {event['message']}")
        # "error" covers failures and monitoring timeouts
        raise Exception(f"Batch failed: {event['message']}")


dag = DAG(
    'deferred_batch_processing',
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
)

deferred_batch_task = StartBatchOperator(
    task_id='deferred_batch_processing',
    file_id='file-xyz789',
    dag=dag,
)
```
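
Recent versions of the OpenAI provider also ship `OpenAITriggerBatchOperator`, a ready-made deferrable operator that wraps this start-and-defer pattern. A sketch, assuming a provider version that includes the operator and these parameter names:

```python
from airflow.providers.openai.operators.openai import OpenAITriggerBatchOperator

trigger_batch = OpenAITriggerBatchOperator(
    task_id='trigger_batch',
    conn_id='openai_default',
    file_id='file-xyz789',
    endpoint='/v1/chat/completions',
    deferrable=True,  # Hand monitoring to OpenAIBatchTrigger instead of blocking a worker
    dag=dag,
)
```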

### Custom Trigger Implementation

```python
from __future__ import annotations

import asyncio
import time
from collections.abc import AsyncIterator

from airflow.providers.openai.hooks.openai import BatchStatus, OpenAIHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class CustomOpenAIBatchTrigger(BaseTrigger):
    """
    Extended batch trigger with custom monitoring logic.

    Note: a deferred task resumes on the first event it receives, so the
    extra "progress" events are only visible to consumers that iterate
    run() directly.
    """

    def __init__(
        self,
        conn_id: str,
        batch_id: str,
        poll_interval: float,
        end_time: float,
        progress_callback: str | None = None,  # Illustrative extra argument
    ):
        super().__init__()
        self.conn_id = conn_id
        self.batch_id = batch_id
        self.poll_interval = poll_interval
        self.end_time = end_time
        self.progress_callback = progress_callback

    def serialize(self) -> tuple[str, dict]:
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}",
            {
                "conn_id": self.conn_id,
                "batch_id": self.batch_id,
                "poll_interval": self.poll_interval,
                "end_time": self.end_time,
                "progress_callback": self.progress_callback,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Enhanced monitoring with progress tracking."""
        hook = OpenAIHook(conn_id=self.conn_id)
        last_status = None

        try:
            while True:
                # Check timeout
                if time.time() >= self.end_time:
                    yield TriggerEvent({
                        "status": "error",
                        "message": f"Batch {self.batch_id} monitoring timed out",
                        "batch_id": self.batch_id,
                    })
                    return

                # Get batch status
                batch = hook.get_batch(self.batch_id)

                # Emit progress events for status changes
                if batch.status != last_status:
                    yield TriggerEvent({
                        "status": "progress",
                        "message": f"Batch status changed from {last_status} to {batch.status}",
                        "batch_id": self.batch_id,
                        "batch_status": batch.status,
                    })
                    last_status = batch.status

                # Check for completion
                if not BatchStatus.is_in_progress(batch.status):
                    if batch.status == BatchStatus.COMPLETED:
                        yield TriggerEvent({
                            "status": "success",
                            "message": f"Batch {self.batch_id} completed successfully",
                            "batch_id": self.batch_id,
                            "final_status": batch.status,
                        })
                    elif batch.status in {BatchStatus.CANCELLED, BatchStatus.CANCELLING}:
                        yield TriggerEvent({
                            "status": "cancelled",
                            "message": f"Batch {self.batch_id} was cancelled",
                            "batch_id": self.batch_id,
                            "final_status": batch.status,
                        })
                    else:  # FAILED, EXPIRED, or other error states
                        yield TriggerEvent({
                            "status": "error",
                            "message": f"Batch {self.batch_id} failed with status: {batch.status}",
                            "batch_id": self.batch_id,
                            "final_status": batch.status,
                        })
                    return

                # Wait before next check
                await asyncio.sleep(self.poll_interval)

        except Exception as e:
            yield TriggerEvent({
                "status": "error",
                "message": f"Trigger error: {e}",
                "batch_id": self.batch_id,
            })
```
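
Because Airflow persists a trigger by calling `serialize()` and later re-instantiates it from the returned keyword arguments, the two must stay in sync. A minimal round-trip check for the class above:

```python
import time

# Serialize the trigger and rebuild it from the serialized kwargs,
# mirroring what Airflow does when handing a trigger to the triggerer.
original = CustomOpenAIBatchTrigger(
    conn_id='openai_default',
    batch_id='batch_abc123',
    poll_interval=30,
    end_time=time.time() + 600,
)
class_path, kwargs = original.serialize()
rebuilt = CustomOpenAIBatchTrigger(**kwargs)
assert rebuilt.batch_id == original.batch_id
assert class_path.endswith('CustomOpenAIBatchTrigger')
```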

### Monitoring Multiple Batches

```python
from __future__ import annotations

import asyncio
import time
from collections.abc import AsyncIterator

from airflow.providers.openai.hooks.openai import BatchStatus, OpenAIHook
from airflow.triggers.base import BaseTrigger, TriggerEvent


class MultiBatchTrigger(BaseTrigger):
    """Monitor multiple OpenAI batches simultaneously."""

    def __init__(
        self,
        conn_id: str,
        batch_ids: list[str],
        poll_interval: float,
        end_time: float,
    ):
        super().__init__()
        self.conn_id = conn_id
        self.batch_ids = batch_ids
        self.poll_interval = poll_interval
        self.end_time = end_time

    def serialize(self) -> tuple[str, dict]:
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}",
            {
                "conn_id": self.conn_id,
                "batch_ids": self.batch_ids,
                "poll_interval": self.poll_interval,
                "end_time": self.end_time,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Monitor multiple batches until all reach a terminal state."""
        hook = OpenAIHook(conn_id=self.conn_id)
        completed_batches = set()
        failed_batches = set()

        try:
            while len(completed_batches) + len(failed_batches) < len(self.batch_ids):
                # Check timeout
                if time.time() >= self.end_time:
                    remaining = set(self.batch_ids) - completed_batches - failed_batches
                    yield TriggerEvent({
                        "status": "timeout",
                        "message": f"Timeout reached. Remaining batches: {list(remaining)}",
                        "completed_batches": list(completed_batches),
                        "failed_batches": list(failed_batches),
                        "remaining_batches": list(remaining),
                    })
                    return

                # Check each batch that has not yet finished
                for batch_id in self.batch_ids:
                    if batch_id in completed_batches or batch_id in failed_batches:
                        continue

                    batch = hook.get_batch(batch_id)

                    if not BatchStatus.is_in_progress(batch.status):
                        if batch.status == BatchStatus.COMPLETED:
                            completed_batches.add(batch_id)
                            yield TriggerEvent({
                                "status": "batch_completed",
                                "message": f"Batch {batch_id} completed",
                                "batch_id": batch_id,
                                "completed_count": len(completed_batches),
                                "total_count": len(self.batch_ids),
                            })
                        else:
                            failed_batches.add(batch_id)
                            yield TriggerEvent({
                                "status": "batch_failed",
                                "message": f"Batch {batch_id} failed with status: {batch.status}",
                                "batch_id": batch_id,
                                "batch_status": batch.status,
                                "failed_count": len(failed_batches),
                                "total_count": len(self.batch_ids),
                            })

                # Wait before the next polling round (skipped once all are done)
                if len(completed_batches) + len(failed_batches) < len(self.batch_ids):
                    await asyncio.sleep(self.poll_interval)

            # All batches reached a terminal state
            if failed_batches:
                yield TriggerEvent({
                    "status": "partial_success",
                    "message": f"Processing complete. {len(completed_batches)} succeeded, {len(failed_batches)} failed",
                    "completed_batches": list(completed_batches),
                    "failed_batches": list(failed_batches),
                })
            else:
                yield TriggerEvent({
                    "status": "success",
                    "message": f"All {len(completed_batches)} batches completed successfully",
                    "completed_batches": list(completed_batches),
                })

        except Exception as e:
            yield TriggerEvent({
                "status": "error",
                "message": f"Multi-batch trigger error: {e}",
                "batch_ids": self.batch_ids,
            })


# Usage example
multi_batch_trigger = MultiBatchTrigger(
    conn_id='openai_default',
    batch_ids=['batch_1', 'batch_2', 'batch_3'],
    poll_interval=60,
    end_time=time.time() + 7200,  # 2 hours
)
```
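
As noted above, a deferred task resumes on the first event it receives, so `MultiBatchTrigger`'s per-batch events are most useful when the trigger is consumed directly. A sketch that drains the full event stream:

```python
import asyncio


async def collect_events(trigger):
    # Drain the trigger's event stream; run() returns after it yields
    # the final summary event, ending the iteration.
    events = []
    async for event in trigger.run():
        events.append(event.payload)
    return events


all_events = asyncio.run(collect_events(multi_batch_trigger))
print(f"Received {len(all_events)} events")
```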

### Integration with Airflow Sensors

```python
from airflow.exceptions import AirflowException
from airflow.sensors.base import BaseSensorOperator


class OpenAIBatchSensor(BaseSensorOperator):
    """Sensor that waits for OpenAI batch completion."""

    # Render the batch ID through Jinja so XCom pulls work in the example below
    template_fields = ('batch_id',)

    def __init__(
        self,
        batch_id: str,
        conn_id: str = 'openai_default',
        **kwargs,
    ):
        super().__init__(**kwargs)  # poke_interval, timeout, mode, etc.
        self.batch_id = batch_id
        self.conn_id = conn_id

    def poke(self, context) -> bool:
        """Check whether the batch has completed."""
        from airflow.providers.openai.hooks.openai import BatchStatus, OpenAIHook

        hook = OpenAIHook(conn_id=self.conn_id)
        batch = hook.get_batch(self.batch_id)

        if batch.status == BatchStatus.COMPLETED:
            return True
        if batch.status in {BatchStatus.FAILED, BatchStatus.EXPIRED, BatchStatus.CANCELLED}:
            raise AirflowException(f"Batch {self.batch_id} failed with status: {batch.status}")

        return False


# Use the sensor in a DAG
batch_sensor = OpenAIBatchSensor(
    task_id='wait_for_batch',
    batch_id="{{ task_instance.xcom_pull(task_ids='create_batch') }}",
    conn_id='openai_default',
    poke_interval=30,  # BaseSensorOperator's built-in polling cadence
    timeout=3600,
    dag=dag,
)
```
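
In poke mode the sensor occupies a worker slot for its whole runtime; `mode='reschedule'` (a standard `BaseSensorOperator` option) releases the slot between checks, a lighter-weight alternative when full deferral via the trigger is not needed:

```python
batch_sensor_reschedule = OpenAIBatchSensor(
    task_id='wait_for_batch_reschedule',
    batch_id="{{ task_instance.xcom_pull(task_ids='create_batch') }}",
    mode='reschedule',  # Free the worker slot between pokes
    poke_interval=300,
    timeout=3600,
    dag=dag,
)
```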