# Triggers

Airflow triggers provide asynchronous monitoring of long-running OpenAI operations. Because triggers run in the triggerer process rather than on a worker, deferrable tasks can release their worker slot while an OpenAI batch runs, which keeps resource usage low in batch processing workflows.

## Capabilities

### Batch Processing Trigger

Asynchronously monitor OpenAI Batch API operations with configurable polling intervals and timeout handling.

```python { .api }
class OpenAIBatchTrigger(BaseTrigger):
    """
    Triggers OpenAI Batch API monitoring for long-running batch operations.

    Args:
        conn_id (str): The OpenAI connection ID to use
        batch_id (str): The ID of the batch to monitor
        poll_interval (float): Number of seconds between status checks
        end_time (float): Unix timestamp at which monitoring times out
    """

    def __init__(
        self,
        conn_id: str,
        batch_id: str,
        poll_interval: float,
        end_time: float,
    ) -> None: ...

    def serialize(self) -> tuple[str, dict[str, Any]]:
        """
        Serialize OpenAIBatchTrigger arguments and class path for persistence.

        Returns:
            Tuple of (class_path, serialized_arguments)
        """

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """
        Connect to the OpenAI client and poll the batch status.

        Yields:
            TriggerEvent: Events indicating batch status changes or completion

        Events:
            - {"status": "success", "message": "...", "batch_id": "..."}: Batch completed successfully
            - {"status": "cancelled", "message": "...", "batch_id": "..."}: Batch was cancelled
            - {"status": "error", "message": "...", "batch_id": "..."}: Batch failed or timed out
        """
```
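
When a task defers, Airflow persists the trigger via `serialize()` and the triggerer process reconstructs it from the returned class path and keyword arguments. A minimal sketch of that round trip, using Airflow's `import_string` helper (the connection and batch ID are illustrative):

```python
import time

from airflow.utils.module_loading import import_string
from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger

trigger = OpenAIBatchTrigger(
    conn_id='openai_default',
    batch_id='batch_abc123',  # illustrative batch ID
    poll_interval=60,
    end_time=time.time() + 3600,
)

# The triggerer rebuilds an equivalent trigger from the class path
# and the serialized keyword arguments
class_path, kwargs = trigger.serialize()
rebuilt = import_string(class_path)(**kwargs)
```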

## Usage Examples

### Direct Trigger Usage

```python
import time
from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger

# Create trigger for batch monitoring
trigger = OpenAIBatchTrigger(
    conn_id='openai_default',
    batch_id='batch_abc123',
    poll_interval=60,  # Check every minute
    end_time=time.time() + 3600,  # Timeout after 1 hour
)

# Serialize for storage (handled automatically by Airflow)
class_path, args = trigger.serialize()
print(f"Trigger class: {class_path}")
print(f"Trigger args: {args}")
```
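
For local testing you can drive the trigger's event stream yourself. A sketch that reuses the `trigger` instance above (it needs a valid `openai_default` connection and a real batch ID to actually yield anything):

```python
import asyncio

async def smoke_test():
    # run() is an async generator; OpenAIBatchTrigger yields a single
    # terminal event (success, cancelled, or error)
    async for event in trigger.run():
        print(event.payload)
        break

asyncio.run(smoke_test())
```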

### Integration with Deferrable Operator

```python
import time
from datetime import datetime

from airflow import DAG
from airflow.models.baseoperator import BaseOperator
from airflow.providers.openai.hooks.openai import OpenAIHook
from airflow.providers.openai.triggers.openai import OpenAIBatchTrigger

dag = DAG(
    'deferred_batch_processing',
    start_date=datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
)


class StartBatchOperator(BaseOperator):
    """Start a batch and defer to the trigger for monitoring.

    Deferral is only available from inside an operator, so batch creation
    and the completion handler live on a small custom operator rather
    than a PythonOperator callable.
    """

    def __init__(self, file_id: str, conn_id: str = 'openai_default', **kwargs):
        super().__init__(**kwargs)
        self.file_id = file_id
        self.conn_id = conn_id

    def execute(self, context):
        hook = OpenAIHook(conn_id=self.conn_id)

        # Create batch
        batch = hook.create_batch(
            file_id=self.file_id,
            endpoint="/v1/chat/completions",
        )

        # Store batch ID for downstream tasks
        context['task_instance'].xcom_push(key='batch_id', value=batch.id)

        # Defer to the trigger; this frees the worker slot until the
        # trigger fires. method_name must name a method on this operator.
        self.defer(
            trigger=OpenAIBatchTrigger(
                conn_id=self.conn_id,
                batch_id=batch.id,
                poll_interval=120,  # Check every 2 minutes
                end_time=time.time() + 86400,  # 24 hour timeout
            ),
            method_name='handle_batch_completion',
        )

    def handle_batch_completion(self, context, event=None):
        """Handle the batch completion event when the task resumes."""
        if event['status'] == 'success':
            self.log.info("Batch %s completed successfully!", event['batch_id'])
            return event['batch_id']
        if event['status'] == 'cancelled':
            raise Exception(f"Batch cancelled: {event['message']}")
        raise Exception(f"Batch failed: {event['message']}")  # error


deferred_batch_task = StartBatchOperator(
    task_id='deferred_batch_processing',
    file_id='file-xyz789',
    dag=dag,
)
```
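
Recent versions of the provider also ship an operator that wraps this create-and-defer pattern; see operators.md. A minimal sketch, assuming `OpenAITriggerBatchOperator` with the parameters documented there:

```python
from airflow.providers.openai.operators.openai import OpenAITriggerBatchOperator

# Creates the batch, then defers to OpenAIBatchTrigger until it finishes
trigger_batch = OpenAITriggerBatchOperator(
    task_id='trigger_batch',
    conn_id='openai_default',
    file_id='file-xyz789',
    endpoint='/v1/chat/completions',
    deferrable=True,
    dag=dag,
)
```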

### Custom Trigger Implementation

```python
from __future__ import annotations

import asyncio
import time
from collections.abc import AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.providers.openai.hooks.openai import OpenAIHook, BatchStatus


class CustomOpenAIBatchTrigger(BaseTrigger):
    """Extended batch trigger with custom monitoring logic."""

    def __init__(
        self,
        conn_id: str,
        batch_id: str,
        poll_interval: float,
        end_time: float,
        progress_callback: str | None = None,
    ):
        super().__init__()
        self.conn_id = conn_id
        self.batch_id = batch_id
        self.poll_interval = poll_interval
        self.end_time = end_time
        self.progress_callback = progress_callback

    def serialize(self) -> tuple[str, dict]:
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}",
            {
                "conn_id": self.conn_id,
                "batch_id": self.batch_id,
                "poll_interval": self.poll_interval,
                "end_time": self.end_time,
                "progress_callback": self.progress_callback,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Enhanced monitoring with progress tracking."""
        hook = OpenAIHook(conn_id=self.conn_id)
        last_status = None

        try:
            while True:
                # Check timeout
                if time.time() >= self.end_time:
                    yield TriggerEvent({
                        "status": "error",
                        "message": f"Batch {self.batch_id} monitoring timed out",
                        "batch_id": self.batch_id,
                    })
                    return

                # Get batch status; the hook call is synchronous, so run
                # it in a thread to keep the event loop responsive
                batch = await asyncio.to_thread(hook.get_batch, self.batch_id)

                # Emit progress events for status changes
                if batch.status != last_status:
                    yield TriggerEvent({
                        "status": "progress",
                        "message": f"Batch status changed from {last_status} to {batch.status}",
                        "batch_id": self.batch_id,
                        "batch_status": batch.status,
                    })
                    last_status = batch.status

                # Check for completion
                if not BatchStatus.is_in_progress(batch.status):
                    if batch.status == BatchStatus.COMPLETED:
                        yield TriggerEvent({
                            "status": "success",
                            "message": f"Batch {self.batch_id} completed successfully",
                            "batch_id": self.batch_id,
                            "final_status": batch.status,
                        })
                    elif batch.status in {BatchStatus.CANCELLED, BatchStatus.CANCELLING}:
                        yield TriggerEvent({
                            "status": "cancelled",
                            "message": f"Batch {self.batch_id} was cancelled",
                            "batch_id": self.batch_id,
                            "final_status": batch.status,
                        })
                    else:  # FAILED, EXPIRED, or other error states
                        yield TriggerEvent({
                            "status": "error",
                            "message": f"Batch {self.batch_id} failed with status: {batch.status}",
                            "batch_id": self.batch_id,
                            "final_status": batch.status,
                        })
                    return

                # Wait before next check
                await asyncio.sleep(self.poll_interval)

        except Exception as e:
            yield TriggerEvent({
                "status": "error",
                "message": f"Trigger error: {e}",
                "batch_id": self.batch_id,
            })
```
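
Note that when a trigger backs a deferred task, the task resumes on the first `TriggerEvent` it receives and later events are discarded, so the `progress` events above are only observable when `run()` is consumed directly (as in the smoke test earlier). Omit them, or gate them behind a flag, if this trigger will be used with `defer()`.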

### Monitoring Multiple Batches

```python
import asyncio
import time
from collections.abc import AsyncIterator

from airflow.triggers.base import BaseTrigger, TriggerEvent
from airflow.providers.openai.hooks.openai import OpenAIHook, BatchStatus


class MultiBatchTrigger(BaseTrigger):
    """Monitor multiple OpenAI batches simultaneously."""

    def __init__(
        self,
        conn_id: str,
        batch_ids: list[str],
        poll_interval: float,
        end_time: float,
    ):
        super().__init__()
        self.conn_id = conn_id
        self.batch_ids = batch_ids
        self.poll_interval = poll_interval
        self.end_time = end_time

    def serialize(self) -> tuple[str, dict]:
        return (
            f"{self.__class__.__module__}.{self.__class__.__name__}",
            {
                "conn_id": self.conn_id,
                "batch_ids": self.batch_ids,
                "poll_interval": self.poll_interval,
                "end_time": self.end_time,
            },
        )

    async def run(self) -> AsyncIterator[TriggerEvent]:
        """Monitor multiple batches until all complete."""
        hook = OpenAIHook(conn_id=self.conn_id)
        completed_batches = set()
        failed_batches = set()

        try:
            while len(completed_batches) + len(failed_batches) < len(self.batch_ids):
                # Check timeout
                if time.time() >= self.end_time:
                    remaining = set(self.batch_ids) - completed_batches - failed_batches
                    yield TriggerEvent({
                        "status": "timeout",
                        "message": f"Timeout reached. Remaining batches: {list(remaining)}",
                        "completed_batches": list(completed_batches),
                        "failed_batches": list(failed_batches),
                        "remaining_batches": list(remaining),
                    })
                    return

                # Check each batch that is still in flight
                for batch_id in self.batch_ids:
                    if batch_id in completed_batches or batch_id in failed_batches:
                        continue

                    # Synchronous hook call: run it in a thread so the
                    # event loop is not blocked
                    batch = await asyncio.to_thread(hook.get_batch, batch_id)

                    if not BatchStatus.is_in_progress(batch.status):
                        if batch.status == BatchStatus.COMPLETED:
                            completed_batches.add(batch_id)
                            yield TriggerEvent({
                                "status": "batch_completed",
                                "message": f"Batch {batch_id} completed",
                                "batch_id": batch_id,
                                "completed_count": len(completed_batches),
                                "total_count": len(self.batch_ids),
                            })
                        else:
                            failed_batches.add(batch_id)
                            yield TriggerEvent({
                                "status": "batch_failed",
                                "message": f"Batch {batch_id} failed with status: {batch.status}",
                                "batch_id": batch_id,
                                "batch_status": batch.status,
                                "failed_count": len(failed_batches),
                                "total_count": len(self.batch_ids),
                            })

                await asyncio.sleep(self.poll_interval)

            # All batches completed
            if failed_batches:
                yield TriggerEvent({
                    "status": "partial_success",
                    "message": f"Processing complete. {len(completed_batches)} succeeded, {len(failed_batches)} failed",
                    "completed_batches": list(completed_batches),
                    "failed_batches": list(failed_batches),
                })
            else:
                yield TriggerEvent({
                    "status": "success",
                    "message": f"All {len(completed_batches)} batches completed successfully",
                    "completed_batches": list(completed_batches),
                })

        except Exception as e:
            yield TriggerEvent({
                "status": "error",
                "message": f"Multi-batch trigger error: {e}",
                "batch_ids": self.batch_ids,
            })


# Usage example
multi_batch_trigger = MultiBatchTrigger(
    conn_id='openai_default',
    batch_ids=['batch_1', 'batch_2', 'batch_3'],
    poll_interval=60,
    end_time=time.time() + 7200,  # 2 hours
)
```
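
Because `MultiBatchTrigger` emits several intermediate events, it is best consumed by iterating `run()` directly; the deferral caveat above applies here as well. A sketch that gathers every event until a terminal status arrives:

```python
import asyncio

TERMINAL_STATUSES = {"success", "partial_success", "timeout", "error"}

async def collect_events(trigger):
    events = []
    async for event in trigger.run():
        events.append(event.payload)
        if event.payload["status"] in TERMINAL_STATUSES:
            break
    return events

all_events = asyncio.run(collect_events(multi_batch_trigger))
```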

### Integration with Airflow Sensors

```python
from airflow.exceptions import AirflowException
from airflow.providers.openai.hooks.openai import OpenAIHook, BatchStatus
from airflow.sensors.base import BaseSensorOperator


class OpenAIBatchSensor(BaseSensorOperator):
    """Sensor that waits for OpenAI batch completion.

    Poll frequency is controlled by BaseSensorOperator's built-in
    poke_interval argument, so no extra parameter is needed.
    """

    # Render Jinja templates (e.g. XCom pulls) in batch_id
    template_fields = ('batch_id',)

    def __init__(
        self,
        batch_id: str,
        conn_id: str = 'openai_default',
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.batch_id = batch_id
        self.conn_id = conn_id

    def poke(self, context) -> bool:
        """Check if the batch is complete."""
        hook = OpenAIHook(conn_id=self.conn_id)
        batch = hook.get_batch(self.batch_id)

        if batch.status == BatchStatus.COMPLETED:
            return True
        if batch.status in {BatchStatus.FAILED, BatchStatus.EXPIRED, BatchStatus.CANCELLED}:
            raise AirflowException(f"Batch {self.batch_id} failed with status: {batch.status}")

        return False


# Use the sensor in a DAG
batch_sensor = OpenAIBatchSensor(
    task_id='wait_for_batch',
    batch_id="{{ task_instance.xcom_pull(task_ids='create_batch') }}",
    conn_id='openai_default',
    poke_interval=30,
    timeout=3600,
    dag=dag,
)
```
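
Since `poke()` is a cheap API call, the sensor also works well in reschedule mode, a standard `BaseSensorOperator` option that frees the worker slot between pokes:

```python
batch_sensor_reschedule = OpenAIBatchSensor(
    task_id='wait_for_batch_reschedule',
    batch_id="{{ task_instance.xcom_pull(task_ids='create_batch') }}",
    mode='reschedule',  # release the worker slot between pokes
    poke_interval=300,  # poke every 5 minutes
    timeout=86400,
    dag=dag,
)
```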