# Interceptors

Client- and server-side middleware for cross-cutting concerns such as logging, metrics, authentication, request/response modification, and tracing. Interceptors are available for all four RPC patterns (unary-unary, unary-stream, stream-unary, stream-stream), and server-side interceptors can pass state to downstream interceptors and handlers via `contextvars`.

## Capabilities

### Client Interceptors

Client-side interceptors can modify requests, responses, and RPC metadata for all four RPC patterns.

```python { .api }
class UnaryUnaryClientInterceptor(abc.ABC):
    """Interface for intercepting unary-request, unary-response invocations."""

    def intercept_unary_unary(self, continuation, client_call_details, request):
        """Intercept a unary-unary invocation asynchronously.

        Args:
            continuation: Function that proceeds with the invocation, called as
                ``continuation(client_call_details, request)``.
            client_call_details: ClientCallDetails describing the outgoing RPC.
            request: The request value for the RPC.

        Returns:
            A Call-Future: an object that is both a Call for the RPC and a
            Future.
        """

class UnaryStreamClientInterceptor(abc.ABC):
    """Interface for intercepting unary-request, streaming-response invocations."""

    def intercept_unary_stream(self, continuation, client_call_details, request):
        """Intercept a unary-stream invocation.

        Args:
            continuation: Function that proceeds with the invocation, called as
                ``continuation(client_call_details, request)``.
            client_call_details: ClientCallDetails describing the outgoing RPC.
            request: The request value for the RPC.

        Returns:
            A Call-iterator: an object that is both a Call for the RPC and an
            iterator.
        """

class StreamUnaryClientInterceptor(abc.ABC):
    """Interface for intercepting streaming-request, unary-response invocations."""

    def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        """Intercept a stream-unary invocation asynchronously.

        Args:
            continuation: Function that proceeds with the invocation, called as
                ``continuation(client_call_details, request_iterator)``.
            client_call_details: ClientCallDetails describing the outgoing RPC.
            request_iterator: Iterator yielding the request values for the RPC.

        Returns:
            A Call-Future: an object that is both a Call for the RPC and a
            Future.
        """

class StreamStreamClientInterceptor(abc.ABC):
    """Interface for intercepting streaming-request, streaming-response invocations."""

    def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        """Intercept a stream-stream invocation.

        Args:
            continuation: Function that proceeds with the invocation, called as
                ``continuation(client_call_details, request_iterator)``.
            client_call_details: ClientCallDetails describing the outgoing RPC.
            request_iterator: Iterator yielding the request values for the RPC.

        Returns:
            A Call-iterator: an object that is both a Call for the RPC and an
            iterator.
        """
```

**Usage Examples:**

```python
class _LoggedCallIterator:
    """Iterator over a logging generator that still behaves like the RPC's Call.

    Streaming-response interceptors must return an object that is both a Call
    and an iterator. A bare generator would drop Call methods such as code()
    or cancel(), so attribute lookups not defined here are forwarded to the
    underlying call.
    """

    def __init__(self, call, iterator):
        self._call = call
        self._iterator = iterator

    def __iter__(self):
        return self

    def __next__(self):
        return next(self._iterator)

    def __getattr__(self, name):
        # Only invoked for names not found on this object, i.e. the Call API.
        return getattr(self._call, name)


class LoggingClientInterceptor(
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
):
    """Multi-pattern logging interceptor.

    Implements the intercept_* method of every interface it inherits. The
    gRPC interceptor base classes declare these methods abstract, so leaving
    any of them out makes the class impossible to instantiate.
    """

    def intercept_unary_unary(self, continuation, client_call_details, request):
        """Log a unary-unary call and its outcome."""
        return self._log_single_response(continuation, client_call_details, request)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        """Log a unary-stream call and the number of responses it yields."""
        return self._log_response_stream(continuation, client_call_details, request)

    def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        """Log a stream-unary call and its outcome."""
        return self._log_single_response(continuation, client_call_details, request_iterator)

    def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        """Log a stream-stream call and the number of responses it yields."""
        return self._log_response_stream(continuation, client_call_details, request_iterator)

    def _log_single_response(self, continuation, client_call_details, request_or_iterator):
        """Start an RPC with a single response and log its outcome on completion."""
        method = client_call_details.method
        print(f"[CLIENT] Calling {method}")
        start_time = time.monotonic()
        call_future = continuation(client_call_details, request_or_iterator)

        def log_outcome(completed_call):
            duration = time.monotonic() - start_time
            status = completed_call.code()
            if status == grpc.StatusCode.OK:
                print(f"[CLIENT] {method} completed in {duration:.3f}s")
            else:
                print(f"[CLIENT] {method} failed in {duration:.3f}s: "
                      f"{status} {completed_call.details()}")

        # The continuation does not raise on RPC failure. A non-OK status is
        # reported through the returned Call-Future, so log once the call
        # finishes. A done-callback also keeps .future() invocations
        # non-blocking.
        call_future.add_done_callback(log_outcome)
        return call_future

    def _log_response_stream(self, continuation, client_call_details, request_or_iterator):
        """Start an RPC with a streamed response and log how many responses it yielded."""
        method = client_call_details.method
        print(f"[CLIENT] Starting stream {method}")
        response_iterator = continuation(client_call_details, request_or_iterator)

        def logged_iterator():
            count = 0
            try:
                for response in response_iterator:
                    count += 1
                    yield response
            except grpc.RpcError as e:
                # Streaming calls raise RpcError while being iterated.
                print(f"[CLIENT] Stream {method} failed after {count} responses: {e}")
                raise
            print(f"[CLIENT] Stream {method} completed with {count} responses")

        return _LoggedCallIterator(response_iterator, logged_iterator())

# Authentication interceptor
import collections


class _ClientCallDetails(
    collections.namedtuple(
        '_ClientCallDetails',
        ('method', 'timeout', 'metadata', 'credentials', 'wait_for_ready', 'compression'),
    ),
    grpc.ClientCallDetails,
):
    """Concrete ClientCallDetails used to forward modified call details."""


class AuthClientInterceptor(
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
):
    """Attaches a freshly fetched bearer token to every outgoing RPC, for all four patterns."""

    def __init__(self, token_provider):
        # Expected to expose get_token(); it is called once per RPC so the
        # interceptor always sends a fresh token.
        self.token_provider = token_provider

    def _with_auth(self, client_call_details):
        """Return a copy of client_call_details with an authorization entry appended."""
        # metadata may be None; copy it so the caller's sequence is not mutated.
        metadata = list(client_call_details.metadata or ())
        token = self.token_provider.get_token()
        metadata.append(('authorization', f'Bearer {token}'))

        # Build a new details object rather than calling _replace():
        # ClientCallDetails is an abstract type, and an upstream interceptor
        # may hand over an object that is not a namedtuple.
        return _ClientCallDetails(
            method=client_call_details.method,
            timeout=getattr(client_call_details, 'timeout', None),
            metadata=metadata,
            credentials=getattr(client_call_details, 'credentials', None),
            wait_for_ready=getattr(client_call_details, 'wait_for_ready', None),
            compression=getattr(client_call_details, 'compression', None),
        )

    def intercept_unary_unary(self, continuation, client_call_details, request):
        return continuation(self._with_auth(client_call_details), request)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        return continuation(self._with_auth(client_call_details), request)

    def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        return continuation(self._with_auth(client_call_details), request_iterator)

    def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        return continuation(self._with_auth(client_call_details), request_iterator)

# Retry interceptor
class RetryClientInterceptor(grpc.UnaryUnaryClientInterceptor):
    """Retries unary-unary RPCs that end with a transient status code.

    Retry n (0-based) sleeps backoff_factor * 2**n seconds first. Every
    attempt reuses the same client_call_details, so each one gets the full
    timeout. Only use this for idempotent methods, since a retried call may
    already have taken effect on the server.
    """

    def __init__(self, max_retries=3, backoff_factor=1.0, retryable_codes=None):
        """
        Args:
            max_retries: Extra attempts after the first one; must be >= 0.
            backoff_factor: Base delay in seconds for exponential backoff.
            retryable_codes: Status codes that trigger a retry. Defaults to
                UNAVAILABLE and DEADLINE_EXCEEDED.

        Raises:
            ValueError: If max_retries is negative.
        """
        if max_retries < 0:
            raise ValueError('max_retries must be >= 0')
        self.max_retries = max_retries
        self.backoff_factor = backoff_factor
        if retryable_codes is None:
            retryable_codes = (grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.DEADLINE_EXCEEDED)
        self.retryable_codes = frozenset(retryable_codes)

    def intercept_unary_unary(self, continuation, client_call_details, request):
        for attempt in range(self.max_retries + 1):
            outcome = continuation(client_call_details, request)
            # The continuation does not raise on RPC failure. It returns a
            # Call-Future whose status carries the error, and code() waits
            # for the RPC to finish.
            status = outcome.code()
            if status not in self.retryable_codes or attempt == self.max_retries:
                # Success, a non-transient error, or retries exhausted. A
                # failed outcome raises its RpcError when the caller reads
                # the result.
                return outcome

            sleep_time = self.backoff_factor * (2 ** attempt)
            print(f"Retrying in {sleep_time}s (attempt {attempt + 1}/{self.max_retries})")
            time.sleep(sleep_time)

# Using client interceptors
def create_intercepted_channel(target='localhost:50051'):
    """Create a channel to target wrapped with the example client interceptors.

    Interceptors get control in the order they are listed. Authentication
    and logging therefore happen once per logical call, and retries happen
    innermost.

    NOTE: insecure_channel sends the bearer token in plaintext. Use
    grpc.secure_channel outside local development.
    """
    base_channel = grpc.insecure_channel(target)

    interceptors = [
        AuthClientInterceptor(TokenProvider()),
        LoggingClientInterceptor(),
        RetryClientInterceptor(max_retries=3),
    ]

    return grpc.intercept_channel(base_channel, *interceptors)

# Using the channel as a context manager closes it (and its underlying
# connection) when the block exits.
with create_intercepted_channel() as channel:
    stub = my_service_pb2_grpc.MyServiceStub(channel)
    response = stub.MyMethod(request)  # Will be logged, authenticated, and retried if needed
```

### Server Interceptors

Server-side interceptors can inspect or modify incoming RPCs, implement cross-cutting concerns, and manage per-request state.

```python { .api }
class ServerInterceptor(abc.ABC):
    """Interface for intercepting incoming RPCs on the service side."""

    def intercept_service(self, continuation, handler_call_details):
        """Intercept an incoming RPC before it is handed over to a handler.

        Interceptors may pass state to downstream interceptors and handlers
        through contextvars. The first interceptor is called from an empty
        contextvars.Context.

        Args:
            continuation: Function that takes a HandlerCallDetails and
                proceeds with the RPC.
            handler_call_details: HandlerCallDetails describing the RPC.

        Returns:
            An RpcMethodHandler if the RPC is considered serviced, otherwise
            None.
        """
```

**Usage Examples:**

```python
import contextvars
import threading
import time
import uuid
from concurrent import futures

# Context variables for passing state from interceptors to downstream
# interceptors and handlers. Readers supply their own fallback via .get(default).
request_id_var = contextvars.ContextVar('request_id')
user_id_var = contextvars.ContextVar('user_id')

class RequestIdServerInterceptor(grpc.ServerInterceptor):
    """Tags each incoming RPC with a freshly generated request ID for tracing."""

    def intercept_service(self, continuation, handler_call_details):
        rpc_method = handler_call_details.method
        new_request_id = str(uuid.uuid4())

        # Stored in a context variable so later interceptors and the handler
        # can read it.
        request_id_var.set(new_request_id)
        print(f"[{new_request_id}] Incoming RPC: {rpc_method}")

        # Hand off to the next interceptor in the chain, or to the handler.
        return continuation(handler_call_details)

class AuthenticationServerInterceptor(grpc.ServerInterceptor):
    """Validates bearer tokens and records the authenticated user in user_id_var."""

    def __init__(self, token_validator):
        # Expected to expose validate_token(token) -> user_id and to raise
        # InvalidTokenError for rejected tokens.
        self.token_validator = token_validator

    @staticmethod
    def _reject(details):
        """Return a handler that aborts the RPC with UNAUTHENTICATED and details."""
        def abort_handler(request, context):
            context.abort(grpc.StatusCode.UNAUTHENTICATED, details)

        return grpc.unary_unary_rpc_method_handler(abort_handler)

    def intercept_service(self, continuation, handler_call_details):
        metadata = dict(handler_call_details.invocation_metadata or ())
        auth_header = metadata.get('authorization', '')

        # The HTTP auth scheme name is case-insensitive (RFC 7235). An empty
        # token after the scheme is treated the same as a missing one.
        scheme, _, token = auth_header.partition(' ')
        if scheme.lower() != 'bearer' or not token:
            return self._reject('Missing or invalid token')

        # Keep the try narrow: only token validation may legitimately raise
        # InvalidTokenError. Errors from downstream interceptors must not be
        # turned into authentication failures.
        try:
            user_id = self.token_validator.validate_token(token)
        except InvalidTokenError:
            return self._reject('Invalid token')

        user_id_var.set(user_id)
        request_id = request_id_var.get('unknown')
        print(f"[{request_id}] Authenticated user: {user_id}")

        return continuation(handler_call_details)

class LoggingServerInterceptor(grpc.ServerInterceptor):
    """Logs request/response timing and errors for all four RPC patterns."""

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            # Nothing downstream serviced this RPC; propagate that unchanged.
            return None

        method = handler_call_details.method

        def wrap_single_response(behavior):
            """Wrap a unary-unary or stream-unary behavior with timing logs."""
            def wrapped_behavior(request_or_iterator, context):
                request_id = request_id_var.get('unknown')
                user_id = user_id_var.get('anonymous')
                start_time = time.monotonic()
                print(f"[{request_id}] Processing {method} for user {user_id}")

                try:
                    response = behavior(request_or_iterator, context)
                except Exception as e:
                    duration = time.monotonic() - start_time
                    print(f"[{request_id}] Failed in {duration:.3f}s: {type(e).__name__}: {e}")
                    raise

                duration = time.monotonic() - start_time
                print(f"[{request_id}] Completed in {duration:.3f}s")
                return response

            return wrapped_behavior

        def wrap_response_stream(behavior):
            """Wrap a unary-stream or stream-stream behavior with timing logs."""
            def wrapped_behavior(request_or_iterator, context):
                request_id = request_id_var.get('unknown')
                start_time = time.monotonic()
                count = 0
                print(f"[{request_id}] Starting stream {method}")

                try:
                    for response in behavior(request_or_iterator, context):
                        count += 1
                        yield response
                except Exception as e:
                    duration = time.monotonic() - start_time
                    print(f"[{request_id}] Stream failed in {duration:.3f}s after {count} responses: {e}")
                    raise

                duration = time.monotonic() - start_time
                print(f"[{request_id}] Stream completed in {duration:.3f}s with {count} responses")

            return wrapped_behavior

        # Rebuild the handler through the public factory functions instead of
        # namedtuple._replace(). RpcMethodHandler is an abstract type, and
        # custom handlers need not be namedtuples.
        serialization = {
            'request_deserializer': handler.request_deserializer,
            'response_serializer': handler.response_serializer,
        }
        if handler.unary_unary:
            return grpc.unary_unary_rpc_method_handler(
                wrap_single_response(handler.unary_unary), **serialization)
        if handler.unary_stream:
            return grpc.unary_stream_rpc_method_handler(
                wrap_response_stream(handler.unary_stream), **serialization)
        if handler.stream_unary:
            return grpc.stream_unary_rpc_method_handler(
                wrap_single_response(handler.stream_unary), **serialization)
        if handler.stream_stream:
            return grpc.stream_stream_rpc_method_handler(
                wrap_response_stream(handler.stream_stream), **serialization)

        return handler

class RateLimitingServerInterceptor(grpc.ServerInterceptor):
    """Implements fixed-window rate limiting per user.

    Requests are counted in wall-clock-minute buckets keyed by
    (user_id, minute). Counts restart at each minute boundary, so a user can
    make up to 2 * max_requests_per_minute calls across a boundary. The user
    is read from user_id_var, so an authentication interceptor must run
    earlier in the chain. Otherwise every caller shares the 'anonymous'
    bucket.
    """

    # How many past minute-buckets to keep before pruning them.
    _RETENTION_MINUTES = 5

    def __init__(self, max_requests_per_minute=60):
        self.max_requests = max_requests_per_minute
        self.request_counts = {}  # (user_id, minute) -> requests admitted
        self.lock = threading.Lock()
        self._last_cleanup_minute = None

    @staticmethod
    def _reject():
        """Return a handler that aborts the RPC with RESOURCE_EXHAUSTED."""
        def rate_limit_handler(request, context):
            context.abort(
                grpc.StatusCode.RESOURCE_EXHAUSTED,
                'Rate limit exceeded'
            )

        return grpc.unary_unary_rpc_method_handler(rate_limit_handler)

    def _prune(self, current_minute):
        """Drop buckets older than the retention window; caller must hold self.lock."""
        cutoff = current_minute - self._RETENTION_MINUTES
        stale_keys = [key for key in self.request_counts if key[1] < cutoff]
        for key in stale_keys:
            del self.request_counts[key]

    def intercept_service(self, continuation, handler_call_details):
        user_id = user_id_var.get('anonymous')
        current_minute = int(time.time() // 60)
        key = (user_id, current_minute)

        with self.lock:
            count = self.request_counts.get(key, 0)
            if count >= self.max_requests:
                return self._reject()

            self.request_counts[key] = count + 1

            # Buckets only go stale when the minute changes, so prune once per
            # minute rather than scanning every key on every request.
            if current_minute != self._last_cleanup_minute:
                self._prune(current_minute)
                self._last_cleanup_minute = current_minute

        return continuation(handler_call_details)

# Using server interceptors
def create_server_with_interceptors(max_workers=10):
    """Create a gRPC server whose RPCs pass through the example interceptors.

    Interceptors get control in the order they are listed, and the order
    matters here. RequestIdServerInterceptor sets request_id_var first.
    AuthenticationServerInterceptor then sets user_id_var, which
    RateLimitingServerInterceptor and LoggingServerInterceptor read.

    The caller still has to register servicers and add a port before
    calling start().
    """
    interceptors = [
        RequestIdServerInterceptor(),
        AuthenticationServerInterceptor(TokenValidator()),
        RateLimitingServerInterceptor(max_requests_per_minute=100),
        LoggingServerInterceptor(),
    ]

    return grpc.server(
        futures.ThreadPoolExecutor(max_workers=max_workers),
        interceptors=interceptors,
    )

# In servicer implementation, access context variables
class MyServiceServicer(my_service_pb2_grpc.MyServiceServicer):
    """Example servicer that reads the state stored by the server interceptors."""

    def MyMethod(self, request, context):
        """Greet the caller, tagging the log line with the request ID."""
        caller = user_id_var.get('anonymous')
        rid = request_id_var.get('unknown')

        print(f"[{rid}] Processing request for user {caller}")

        # Your business logic here
        greeting = f"Hello {caller}"
        return my_service_pb2.MyResponse(message=greeting)
```

### Advanced Interceptor Patterns

More sophisticated interceptor implementations for complex scenarios.

**Circuit Breaker Pattern:**

```python
import enum
import threading
import time
from collections import defaultdict


class CircuitState(enum.Enum):
    """States of a per-service circuit breaker."""

    CLOSED = "closed"        # calls flow normally
    OPEN = "open"            # calls are rejected without being sent
    HALF_OPEN = "half_open"  # a limited number of trial calls are let through

class CircuitBreakerOpenError(grpc.RpcError):
    """Raised instead of sending an RPC while its service's circuit rejects calls.

    Subclasses grpc.RpcError, so existing ``except grpc.RpcError`` handlers
    still catch it. It also provides code() and details(), so those handlers
    can inspect it like a failed RPC.
    """

    def __init__(self, details):
        super().__init__(details)
        self._details = details

    def code(self):
        return grpc.StatusCode.UNAVAILABLE

    def details(self):
        return self._details


class CircuitBreakerInterceptor(grpc.UnaryUnaryClientInterceptor):
    """Per-service circuit breaker for unary-unary RPCs.

    Lifecycle:
    - After failure_threshold consecutive failures for a service, the circuit
      opens and calls are rejected locally.
    - Once more than `timeout` seconds have passed since the last failure,
      the circuit moves to HALF_OPEN and admits up to half_open_max_calls
      trial calls.
    - A successful trial closes the circuit; a failed trial reopens it.
    """

    def __init__(self, failure_threshold=5, timeout=60, half_open_max_calls=3):
        self.failure_threshold = failure_threshold
        self.timeout = timeout
        self.half_open_max_calls = half_open_max_calls

        # Per-service bookkeeping keyed by service name; guarded by self.lock.
        self.failure_counts = defaultdict(int)
        self.last_failure_time = defaultdict(float)
        self.circuit_state = defaultdict(lambda: CircuitState.CLOSED)
        self.half_open_calls = defaultdict(int)
        self.lock = threading.Lock()

    @staticmethod
    def _service_name(method):
        """Extract 'pkg.Service' from '/pkg.Service/Method', falling back to method itself."""
        parts = method.split('/')
        return parts[1] if len(parts) > 2 and parts[1] else method

    def _admit(self, service):
        """Return the state the call is admitted under, or raise CircuitBreakerOpenError."""
        with self.lock:
            state = self.circuit_state[service]

            # OPEN -> HALF_OPEN once the cool-down since the last failure has elapsed.
            if (state == CircuitState.OPEN and
                    time.time() - self.last_failure_time[service] > self.timeout):
                state = CircuitState.HALF_OPEN
                self.circuit_state[service] = state
                self.half_open_calls[service] = 0

            if state == CircuitState.OPEN:
                raise CircuitBreakerOpenError("Circuit breaker is OPEN")

            if state == CircuitState.HALF_OPEN:
                if self.half_open_calls[service] >= self.half_open_max_calls:
                    raise CircuitBreakerOpenError("Circuit breaker is HALF_OPEN with max calls reached")
                self.half_open_calls[service] += 1

            return state

    def _record_outcome(self, service, admitted_state, succeeded):
        """Update the circuit for service after a call admitted under admitted_state finishes."""
        with self.lock:
            if succeeded:
                if admitted_state == CircuitState.HALF_OPEN:
                    self.circuit_state[service] = CircuitState.CLOSED
                self.failure_counts[service] = 0
                return

            self.failure_counts[service] += 1
            self.last_failure_time[service] = time.time()
            # A failed trial call reopens the circuit even if a concurrent
            # success has reset the failure count in the meantime.
            if (admitted_state == CircuitState.HALF_OPEN or
                    self.failure_counts[service] >= self.failure_threshold):
                self.circuit_state[service] = CircuitState.OPEN

    def intercept_unary_unary(self, continuation, client_call_details, request):
        service = self._service_name(client_call_details.method)
        admitted_state = self._admit(service)

        outcome = continuation(client_call_details, request)
        # The continuation does not raise on RPC failure. A non-OK status is
        # reported through the returned Call-Future, and code() waits for the
        # RPC to finish.
        succeeded = outcome.code() == grpc.StatusCode.OK
        self._record_outcome(service, admitted_state, succeeded)
        return outcome
```

**Metrics Collection:**

```python
from collections import defaultdict
import threading
import time


class MetricsInterceptor(grpc.ServerInterceptor):
    """Collects per-method request counts, error counts, and average latency.

    All four RPC patterns are instrumented. For streaming responses, latency
    covers the whole stream, and errors raised mid-stream are counted.
    """

    def __init__(self):
        self.request_counts = defaultdict(int)   # calls started
        self.error_counts = defaultdict(int)     # calls whose behavior raised
        self.latency_sums = defaultdict(float)   # seconds, successful calls only
        self.success_counts = defaultdict(int)   # calls that completed without raising
        self.lock = threading.Lock()

    def _record_start(self, method):
        """Count a new call for method and return its start timestamp."""
        with self.lock:
            self.request_counts[method] += 1
        return time.monotonic()

    def _record_success(self, method, start_time):
        """Add the latency of a successfully completed call."""
        duration = time.monotonic() - start_time
        with self.lock:
            self.latency_sums[method] += duration
            self.success_counts[method] += 1

    def _record_error(self, method):
        """Count a call whose behavior raised."""
        with self.lock:
            self.error_counts[method] += 1

    def intercept_service(self, continuation, handler_call_details):
        handler = continuation(handler_call_details)
        if handler is None:
            return None

        method = handler_call_details.method

        def wrap_single_response(behavior):
            """Instrument a unary-unary or stream-unary behavior."""
            def wrapped_behavior(request_or_iterator, context):
                start_time = self._record_start(method)
                try:
                    result = behavior(request_or_iterator, context)
                except Exception:
                    self._record_error(method)
                    raise
                self._record_success(method, start_time)
                return result

            return wrapped_behavior

        def wrap_response_stream(behavior):
            """Instrument a unary-stream or stream-stream behavior."""
            # A streaming behavior returns a generator immediately. Timing and
            # errors must be observed while it is consumed, so re-yield here
            # instead of timing the call that creates the generator.
            def wrapped_behavior(request_or_iterator, context):
                start_time = self._record_start(method)
                try:
                    yield from behavior(request_or_iterator, context)
                except Exception:
                    self._record_error(method)
                    raise
                self._record_success(method, start_time)

            return wrapped_behavior

        # Rebuild via the public factory functions rather than namedtuple._replace().
        serialization = {
            'request_deserializer': handler.request_deserializer,
            'response_serializer': handler.response_serializer,
        }
        if handler.unary_unary:
            return grpc.unary_unary_rpc_method_handler(
                wrap_single_response(handler.unary_unary), **serialization)
        if handler.unary_stream:
            return grpc.unary_stream_rpc_method_handler(
                wrap_response_stream(handler.unary_stream), **serialization)
        if handler.stream_unary:
            return grpc.stream_unary_rpc_method_handler(
                wrap_single_response(handler.stream_unary), **serialization)
        if handler.stream_stream:
            return grpc.stream_stream_rpc_method_handler(
                wrap_response_stream(handler.stream_stream), **serialization)

        return handler

    def get_metrics(self):
        """Return a snapshot of request counts, error counts, and mean success latency."""
        with self.lock:
            return {
                'requests': dict(self.request_counts),
                'errors': dict(self.error_counts),
                # Divide by successful completions only: latency_sums excludes
                # failed and still-running calls.
                'avg_latency': {
                    method: self.latency_sums[method] / count
                    for method, count in self.success_counts.items()
                    if count > 0
                }
            }
```

## Types

```python { .api }
class ClientCallDetails(abc.ABC):
    """Description of an RPC that the client is about to invoke.

    Attributes:
        method: The method name of the RPC.
        timeout: Optional time limit for the RPC, in seconds.
        metadata: Optional metadata to send to the service side.
        credentials: Optional CallCredentials for the RPC.
        wait_for_ready: Optional flag that enables the wait_for_ready mechanism.
        compression: Optional compression element.
    """

class HandlerCallDetails(abc.ABC):
    """Description of an RPC that has just arrived at the server for service.

    Attributes:
        method: The method name of the RPC.
        invocation_metadata: The metadata the client sent with the RPC.
    """
```