HTTP/2-based RPC framework with synchronous and asynchronous APIs for building distributed systems
—
Client and server-side middleware for cross-cutting concerns like logging, metrics, authentication, request/response modification, and tracing, with comprehensive support for all RPC patterns and state passing via contextvars.
Client-side interceptors that can modify requests, responses, and RPC metadata for all four RPC patterns.
class UnaryUnaryClientInterceptor(abc.ABC):
    """Affords intercepting unary-unary invocations."""

    def intercept_unary_unary(self, continuation, client_call_details, request):
        """
        Intercepts a unary-unary invocation asynchronously.

        Parameters:
        - continuation: Function to proceed with the invocation; the usage
          examples call it as ``continuation(client_call_details, request)``
        - client_call_details: ClientCallDetails object describing the outgoing RPC
        - request: The request value for the RPC

        Returns:
        Call-Future: Object that is both a Call for the RPC and a Future
        """
class UnaryStreamClientInterceptor(abc.ABC):
    """Affords intercepting unary-stream invocations."""

    def intercept_unary_stream(self, continuation, client_call_details, request):
        """
        Intercepts a unary-stream invocation.

        Parameters:
        - continuation: Function to proceed with the invocation; the usage
          examples call it as ``continuation(client_call_details, request)``
        - client_call_details: ClientCallDetails object describing the outgoing RPC
        - request: The request value for the RPC

        Returns:
        Call-iterator: Object that is both a Call for the RPC and an iterator
        of response values
        """
class StreamUnaryClientInterceptor(abc.ABC):
    """Affords intercepting stream-unary invocations."""

    def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        """
        Intercepts a stream-unary invocation asynchronously.

        Parameters:
        - continuation: Function to proceed with the invocation
        - client_call_details: ClientCallDetails object describing the outgoing RPC
        - request_iterator: An iterator that yields request values for the RPC

        Returns:
        Call-Future: Object that is both a Call for the RPC and a Future
        """
class StreamStreamClientInterceptor(abc.ABC):
    """Affords intercepting stream-stream invocations."""

    def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        """
        Intercepts a stream-stream invocation.

        Parameters:
        - continuation: Function to proceed with the invocation
        - client_call_details: ClientCallDetails object describing the outgoing RPC
        - request_iterator: An iterator that yields request values for the RPC

        Returns:
        Call-iterator: Object that is both a Call for the RPC and an iterator
        of response values
        """


# Usage Examples:
class _LoggedResponseStream:
    """Response iterator that logs progress yet still behaves like the wrapped call.

    Interceptors for response-streaming RPCs must return an object that is
    both an iterator and a Call. Iteration is logged here; every other public
    attribute (``cancel``, ``code``, ``initial_metadata``, ...) is forwarded to
    the wrapped call unchanged.
    """

    def __init__(self, call, method):
        self._call = call
        self._responses = iter(call)
        self._method = method
        self._count = 0

    def __iter__(self):
        return self

    def __next__(self):
        try:
            response = next(self._responses)
        except StopIteration:
            print(f"[CLIENT] Stream {self._method} completed with {self._count} responses")
            raise
        except grpc.RpcError as e:
            print(f"[CLIENT] Stream {self._method} failed after {self._count} responses: {e}")
            raise
        self._count += 1
        return response

    def __getattr__(self, name):
        # Only reached for names not found on the wrapper. Private names are
        # not forwarded so a half-initialised wrapper cannot recurse forever.
        if name.startswith("_"):
            raise AttributeError(name)
        return getattr(self._call, name)


class LoggingClientInterceptor(
    grpc.UnaryUnaryClientInterceptor,
    grpc.UnaryStreamClientInterceptor,
    grpc.StreamUnaryClientInterceptor,
    grpc.StreamStreamClientInterceptor,
):
    """Multi-pattern logging interceptor.

    All four intercept methods are implemented so the class can be
    instantiated; each one delegates to one of two shared helpers depending on
    whether the RPC produces a single response or a response stream.
    """

    def intercept_unary_unary(self, continuation, client_call_details, request):
        """Log a unary-unary RPC and return the Call-Future unchanged."""
        return self._log_single_response(continuation, client_call_details, request)

    def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
        """Log a stream-unary RPC and return the Call-Future unchanged."""
        return self._log_single_response(continuation, client_call_details, request_iterator)

    def intercept_unary_stream(self, continuation, client_call_details, request):
        """Log a unary-stream RPC; returns a logging Call-iterator wrapper."""
        return self._log_response_stream(continuation, client_call_details, request)

    def intercept_stream_stream(self, continuation, client_call_details, request_iterator):
        """Log a stream-stream RPC; returns a logging Call-iterator wrapper."""
        return self._log_response_stream(continuation, client_call_details, request_iterator)

    @staticmethod
    def _log_single_response(continuation, client_call_details, request_or_iterator):
        """Log start and outcome of an RPC that yields one response.

        The continuation reports failures through the returned Call-Future
        rather than by raising, so the outcome is read from ``code()`` inside
        a done-callback. The callback also keeps ``.future()`` invocations
        non-blocking.
        """
        method = client_call_details.method
        print(f"[CLIENT] Calling {method}")
        start_time = time.time()
        call = continuation(client_call_details, request_or_iterator)

        def log_outcome(finished_call):
            duration = time.time() - start_time
            code = finished_call.code()
            if code == grpc.StatusCode.OK:
                print(f"[CLIENT] {method} completed in {duration:.3f}s")
            else:
                print(
                    f"[CLIENT] {method} failed in {duration:.3f}s: "
                    f"{code.name}: {finished_call.details()}"
                )

        call.add_done_callback(log_outcome)
        return call

    @staticmethod
    def _log_response_stream(continuation, client_call_details, request_or_iterator):
        """Log start, response count and outcome of a response-streaming RPC."""
        method = client_call_details.method
        print(f"[CLIENT] Starting stream {method}")
        call = continuation(client_call_details, request_or_iterator)
        return _LoggedResponseStream(call, method)
# Authentication interceptor
class _AuthenticatedCallDetails(grpc.ClientCallDetails):
    """Copy of a ClientCallDetails whose metadata has been replaced.

    Built explicitly instead of calling ``_replace`` on the incoming object,
    because ClientCallDetails is an abstract interface and only some concrete
    implementations happen to be namedtuples.
    """

    def __init__(self, original, metadata):
        self.method = original.method
        self.timeout = original.timeout
        self.metadata = metadata
        self.credentials = original.credentials
        # Optional attributes; tolerate call-details objects that lack them.
        self.wait_for_ready = getattr(original, "wait_for_ready", None)
        self.compression = getattr(original, "compression", None)


class AuthClientInterceptor(grpc.UnaryUnaryClientInterceptor):
    """Attaches a bearer token to every outgoing unary-unary RPC."""

    def __init__(self, token_provider):
        """
        Parameters:
        - token_provider: Object whose ``get_token()`` returns the token string
        """
        self.token_provider = token_provider

    def intercept_unary_unary(self, continuation, client_call_details, request):
        """Add an ``authorization`` metadata entry and proceed with the RPC.

        Returns:
        Call-Future produced by ``continuation``.
        """
        # Metadata is optional on the call details; start from an empty list then.
        metadata = list(client_call_details.metadata or ())
        # Get fresh token
        token = self.token_provider.get_token()
        metadata.append(('authorization', f'Bearer {token}'))
        # Create new call details with auth metadata
        authenticated_call_details = _AuthenticatedCallDetails(client_call_details, metadata)
        return continuation(authenticated_call_details, request)
# Retry interceptor
class RetryClientInterceptor(grpc.UnaryUnaryClientInterceptor):
    """Retries unary-unary RPCs that fail with a transient status code."""

    # Status codes treated as transient and therefore worth retrying.
    RETRYABLE_STATUS_CODES = (
        grpc.StatusCode.UNAVAILABLE,
        grpc.StatusCode.DEADLINE_EXCEEDED,
    )

    def __init__(self, max_retries=3, backoff_factor=1.0):
        """
        Parameters:
        - max_retries: Number of retries after the first attempt
        - backoff_factor: Base delay in seconds; doubled after every retry
        """
        self.max_retries = max_retries
        self.backoff_factor = backoff_factor

    def intercept_unary_unary(self, continuation, client_call_details, request):
        """Invoke the RPC, retrying with exponential backoff on transient errors.

        The continuation hands back a Call-Future for failed RPCs instead of
        raising, so the decision to retry is taken from ``call.code()``.
        ``code()`` waits for the RPC to terminate, so this interceptor blocks
        even when the stub was invoked through ``.future()``.

        Returns:
        Call-Future of the last attempt (successful or not); the caller sees
        the failure when it reads the result.
        """
        attempt = 0
        while True:
            call = continuation(client_call_details, request)
            retryable = call.code() in self.RETRYABLE_STATUS_CODES
            if not retryable or attempt >= self.max_retries:
                # Success, non-transient error, or retries exhausted.
                return call
            sleep_time = self.backoff_factor * (2 ** attempt)
            print(f"Retrying in {sleep_time}s (attempt {attempt + 1}/{self.max_retries})")
            time.sleep(sleep_time)
            attempt += 1
# Using client interceptors
def create_intercepted_channel():
    """Return an insecure channel to localhost:50051 wrapped by the example interceptors.

    The interceptors are passed in the order auth, logging, retry.
    """
    return grpc.intercept_channel(
        grpc.insecure_channel('localhost:50051'),
        AuthClientInterceptor(TokenProvider()),
        LoggingClientInterceptor(),
        RetryClientInterceptor(max_retries=3),
    )
channel = create_intercepted_channel()
stub = my_service_pb2_grpc.MyServiceStub(channel)
response = stub.MyMethod(request) # Will be logged, authenticated, and retried if needed

Server-side interceptors for modifying incoming RPCs, implementing cross-cutting concerns, and state management.
class ServerInterceptor(abc.ABC):
    """Affords intercepting incoming RPCs on the service-side."""

    def intercept_service(self, continuation, handler_call_details):
        """
        Intercepts incoming RPCs before handing them over to a handler.

        State can be passed from an interceptor to downstream interceptors
        and handlers via contextvars. The first interceptor is called from an
        empty contextvars.Context.

        Parameters:
        - continuation: Function that takes HandlerCallDetails and proceeds
          to the next interceptor or handler
        - handler_call_details: HandlerCallDetails describing the RPC

        Returns:
        RpcMethodHandler or None: Handler if the RPC is considered serviced
        """


# Usage Examples:
import contextvars
import threading
import time
import uuid
from concurrent import futures

import grpc
# Context variables for passing state
request_id_var = contextvars.ContextVar('request_id')
user_id_var = contextvars.ContextVar('user_id')
class RequestIdServerInterceptor(grpc.ServerInterceptor):
    """Generates unique request IDs for tracing."""

    def intercept_service(self, continuation, handler_call_details):
        """Store a fresh UUID4 in ``request_id_var``, log the RPC, then continue."""
        request_id = f"{uuid.uuid4()}"
        request_id_var.set(request_id)
        print(f"[{request_id}] Incoming RPC: {handler_call_details.method}")
        # Hand over to the next interceptor/handler in the chain.
        next_handler = continuation(handler_call_details)
        return next_handler
class AuthenticationServerInterceptor(grpc.ServerInterceptor):
    """Validates authentication tokens."""

    _BEARER_PREFIX = 'Bearer '

    def __init__(self, token_validator):
        """
        Parameters:
        - token_validator: Object whose ``validate_token(token)`` returns the
          user id or raises InvalidTokenError
        """
        self.token_validator = token_validator

    @staticmethod
    def _unauthenticated_handler(details):
        """Return a handler that aborts the RPC with UNAUTHENTICATED and ``details``."""
        def abort_handler(request, context):
            context.abort(grpc.StatusCode.UNAUTHENTICATED, details)
        return grpc.unary_unary_rpc_method_handler(abort_handler)

    def intercept_service(self, continuation, handler_call_details):
        """Authenticate the RPC from its ``authorization`` metadata entry.

        On success the user id is stored in ``user_id_var`` and the RPC
        continues; otherwise a handler that aborts with UNAUTHENTICATED is
        returned.
        """
        # Extract metadata
        metadata = dict(handler_call_details.invocation_metadata or ())
        auth_header = metadata.get('authorization', '')
        if not auth_header.startswith(self._BEARER_PREFIX):
            return self._unauthenticated_handler('Missing or invalid token')

        token = auth_header[len(self._BEARER_PREFIX):]
        # Keep the try body minimal so that only token validation is guarded;
        # errors raised further down the chain must not be reported as
        # "Invalid token".
        try:
            user_id = self.token_validator.validate_token(token)
        except InvalidTokenError:
            return self._unauthenticated_handler('Invalid token')

        user_id_var.set(user_id)
        request_id = request_id_var.get('unknown')
        print(f"[{request_id}] Authenticated user: {user_id}")
        return continuation(handler_call_details)
class LoggingServerInterceptor(grpc.ServerInterceptor):
    """Logs request/response timing and errors for all four RPC patterns."""

    def intercept_service(self, continuation, handler_call_details):
        """Wrap the downstream handler's behavior with timing/error logging.

        Returns:
        A new RpcMethodHandler of the same kind as the downstream one, or
        None when no downstream handler services the RPC.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            return None

        method = handler_call_details.method

        def wrap_single_response(behavior):
            """Wrap a behavior that returns one response (unary or stream request)."""
            def wrapped_behavior(request_or_iterator, context):
                request_id = request_id_var.get('unknown')
                user_id = user_id_var.get('anonymous')
                start_time = time.time()
                try:
                    print(f"[{request_id}] Processing {method} for user {user_id}")
                    response = behavior(request_or_iterator, context)
                    duration = time.time() - start_time
                    print(f"[{request_id}] Completed in {duration:.3f}s")
                    return response
                except Exception as e:
                    duration = time.time() - start_time
                    print(f"[{request_id}] Failed in {duration:.3f}s: {type(e).__name__}: {e}")
                    raise
            return wrapped_behavior

        def wrap_response_stream(behavior):
            """Wrap a behavior that yields a stream of responses."""
            def wrapped_behavior(request_or_iterator, context):
                request_id = request_id_var.get('unknown')
                start_time = time.time()
                count = 0
                try:
                    print(f"[{request_id}] Starting stream {method}")
                    for response in behavior(request_or_iterator, context):
                        count += 1
                        yield response
                    duration = time.time() - start_time
                    print(f"[{request_id}] Stream completed in {duration:.3f}s with {count} responses")
                except Exception as e:
                    duration = time.time() - start_time
                    print(f"[{request_id}] Stream failed in {duration:.3f}s after {count} responses: {e}")
                    raise
            return wrapped_behavior

        # Rebuild the handler through the public factories so the original
        # (de)serializers are kept without relying on the handler being a
        # namedtuple.
        serializers = {
            'request_deserializer': handler.request_deserializer,
            'response_serializer': handler.response_serializer,
        }
        if handler.unary_unary:
            return grpc.unary_unary_rpc_method_handler(
                wrap_single_response(handler.unary_unary), **serializers)
        if handler.unary_stream:
            return grpc.unary_stream_rpc_method_handler(
                wrap_response_stream(handler.unary_stream), **serializers)
        if handler.stream_unary:
            return grpc.stream_unary_rpc_method_handler(
                wrap_single_response(handler.stream_unary), **serializers)
        if handler.stream_stream:
            return grpc.stream_stream_rpc_method_handler(
                wrap_response_stream(handler.stream_stream), **serializers)
        return handler
class RateLimitingServerInterceptor(grpc.ServerInterceptor):
    """Implements rate limiting per user using fixed one-minute windows."""

    # Number of past one-minute windows kept before their counters are dropped.
    _WINDOWS_TO_KEEP = 5

    def __init__(self, max_requests_per_minute=60):
        """
        Parameters:
        - max_requests_per_minute: Requests admitted per user per minute window
        """
        self.max_requests = max_requests_per_minute
        self.request_counts = {}  # (user_id, minute) -> admitted request count
        self.lock = threading.Lock()
        self._last_sweep_minute = None

    def _drop_stale_windows(self, current_minute):
        """Delete counters older than the retention period. Caller holds the lock."""
        cutoff = current_minute - self._WINDOWS_TO_KEEP
        stale_keys = [k for k in self.request_counts if k[1] < cutoff]
        for k in stale_keys:
            del self.request_counts[k]

    @staticmethod
    def _rate_limited_handler():
        """Return a handler that aborts the RPC with RESOURCE_EXHAUSTED."""
        def rate_limit_handler(request, context):
            context.abort(
                grpc.StatusCode.RESOURCE_EXHAUSTED,
                'Rate limit exceeded'
            )
        return grpc.unary_unary_rpc_method_handler(rate_limit_handler)

    def intercept_service(self, continuation, handler_call_details):
        """Admit the RPC if the user is under the limit for the current minute.

        The user id is read from ``user_id_var`` ('anonymous' when unset).
        """
        user_id = user_id_var.get('anonymous')
        current_minute = int(time.time() // 60)
        key = (user_id, current_minute)

        with self.lock:
            # Sweep once per minute rollover instead of scanning every
            # counter on each request while holding the lock.
            if current_minute != self._last_sweep_minute:
                self._drop_stale_windows(current_minute)
                self._last_sweep_minute = current_minute

            count = self.request_counts.get(key, 0)
            admitted = count < self.max_requests
            if admitted:
                self.request_counts[key] = count + 1

        if not admitted:
            return self._rate_limited_handler()
        return continuation(handler_call_details)
# Using server interceptors
def create_server_with_interceptors():
    """Return a gRPC server with a 10-worker thread pool and the example interceptors.

    Interceptors are listed in the order request-id, authentication,
    rate limiting, logging.
    """
    chain = [
        RequestIdServerInterceptor(),
        AuthenticationServerInterceptor(TokenValidator()),
        RateLimitingServerInterceptor(max_requests_per_minute=100),
        LoggingServerInterceptor(),
    ]
    worker_pool = futures.ThreadPoolExecutor(max_workers=10)
    return grpc.server(worker_pool, interceptors=chain)
# In servicer implementation, access context variables
class MyServiceServicer(my_service_pb2_grpc.MyServiceServicer):
    """Example servicer that reads the state stored in the context variables."""

    def MyMethod(self, request, context):
        """Log the request id and user id, then greet the user."""
        request_id = request_id_var.get('unknown')
        user_id = user_id_var.get('anonymous')
        print(f"[{request_id}] Processing request for user {user_id}")
        # Your business logic here
        greeting = f"Hello {user_id}"
        return my_service_pb2.MyResponse(message=greeting)


# More sophisticated interceptor implementations for complex scenarios.
Circuit Breaker Pattern:
import enum
from collections import defaultdict
class CircuitState(enum.Enum):
    """States of the per-service circuit used by CircuitBreakerInterceptor."""

    CLOSED = "closed"        # calls are passed through
    OPEN = "open"            # calls are rejected without being sent
    HALF_OPEN = "half_open"  # a limited number of trial calls is allowed
class CircuitBreakerOpenError(grpc.RpcError):
    """Raised locally when the circuit breaker rejects a call without sending it.

    Subclasses grpc.RpcError so existing ``except grpc.RpcError`` handlers keep
    working, and provides ``code()``/``details()`` like server-reported errors.
    """

    def __init__(self, details):
        super().__init__(details)
        self._details = details

    def code(self):
        return grpc.StatusCode.UNAVAILABLE

    def details(self):
        return self._details


class CircuitBreakerInterceptor(grpc.UnaryUnaryClientInterceptor):
    """Per-service circuit breaker for unary-unary RPCs."""

    def __init__(self, failure_threshold=5, timeout=60, half_open_max_calls=3):
        """
        Parameters:
        - failure_threshold: Failures since the last success that open the circuit
        - timeout: Seconds an OPEN circuit waits before moving to HALF_OPEN
        - half_open_max_calls: Trial calls admitted while HALF_OPEN
        """
        self.failure_threshold = failure_threshold
        self.timeout = timeout
        self.half_open_max_calls = half_open_max_calls
        self.failure_counts = defaultdict(int)
        self.last_failure_time = defaultdict(float)
        self.circuit_state = defaultdict(lambda: CircuitState.CLOSED)
        self.half_open_calls = defaultdict(int)
        self.lock = threading.Lock()

    def _admit(self, service):
        """Return the state under which the call is admitted, or raise if rejected."""
        with self.lock:
            state = self.circuit_state[service]
            # Check if circuit should transition from OPEN to HALF_OPEN
            if (state == CircuitState.OPEN and
                    time.time() - self.last_failure_time[service] > self.timeout):
                self.circuit_state[service] = CircuitState.HALF_OPEN
                self.half_open_calls[service] = 0
                state = CircuitState.HALF_OPEN
            # Reject calls when circuit is OPEN
            if state == CircuitState.OPEN:
                raise CircuitBreakerOpenError("Circuit breaker is OPEN")
            if state == CircuitState.HALF_OPEN:
                # Limit calls in HALF_OPEN state
                if self.half_open_calls[service] >= self.half_open_max_calls:
                    raise CircuitBreakerOpenError(
                        "Circuit breaker is HALF_OPEN with max calls reached")
                self.half_open_calls[service] += 1
            return state

    def _record_success(self, service, state):
        """Reset the failure count; close the circuit after a HALF_OPEN trial."""
        with self.lock:
            if state == CircuitState.HALF_OPEN:
                self.circuit_state[service] = CircuitState.CLOSED
            self.failure_counts[service] = 0

    def _record_failure(self, service):
        """Count a failure and open the circuit once the threshold is reached."""
        with self.lock:
            self.failure_counts[service] += 1
            self.last_failure_time[service] = time.time()
            if self.failure_counts[service] >= self.failure_threshold:
                self.circuit_state[service] = CircuitState.OPEN

    def intercept_unary_unary(self, continuation, client_call_details, request):
        """Send the RPC unless the service's circuit is open.

        Raises:
        CircuitBreakerOpenError (a grpc.RpcError) when the call is rejected.

        Returns:
        Call-Future produced by ``continuation``.
        """
        service = client_call_details.method.split('/')[1]  # Extract service name
        state = self._admit(service)
        try:
            call = continuation(client_call_details, request)
            # The continuation reports failed RPCs through the returned call
            # instead of raising; code() waits for the RPC to terminate.
            succeeded = call.code() == grpc.StatusCode.OK
        except grpc.RpcError:
            # An inner interceptor may still raise directly.
            self._record_failure(service)
            raise
        if succeeded:
            self._record_success(service, state)
        else:
            self._record_failure(service)
        return call


# Metrics Collection:
from collections import defaultdict
import threading
class MetricsInterceptor(grpc.ServerInterceptor):
    """Collects per-method request, error and latency metrics."""

    def __init__(self):
        self.request_counts = defaultdict(int)    # RPCs started
        self.error_counts = defaultdict(int)      # RPCs whose behavior raised
        self.latency_sums = defaultdict(float)    # seconds, finished RPCs only
        self.completed_counts = defaultdict(int)  # RPCs finished (ok or failed)
        self.lock = threading.Lock()

    def _record_start(self, method):
        """Count a started RPC and return its start timestamp."""
        with self.lock:
            self.request_counts[method] += 1
        return time.time()

    def _record_finish(self, method, start_time, failed):
        """Record latency (and the error, if any) of a finished RPC."""
        duration = time.time() - start_time
        with self.lock:
            self.latency_sums[method] += duration
            self.completed_counts[method] += 1
            if failed:
                self.error_counts[method] += 1

    def intercept_service(self, continuation, handler_call_details):
        """Wrap the downstream handler so every invocation is measured.

        Returns:
        A new RpcMethodHandler of the same kind as the downstream one, or
        None when no downstream handler services the RPC.
        """
        handler = continuation(handler_call_details)
        if handler is None:
            return None

        method = handler_call_details.method

        def wrap_single_response(behavior):
            def wrapped_behavior(request_or_iterator, context):
                start_time = self._record_start(method)
                try:
                    result = behavior(request_or_iterator, context)
                except Exception:
                    self._record_finish(method, start_time, failed=True)
                    raise
                self._record_finish(method, start_time, failed=False)
                return result
            return wrapped_behavior

        def wrap_response_stream(behavior):
            def wrapped_behavior(request_or_iterator, context):
                # Measured while the stream is consumed: calling a streaming
                # behavior only creates the iterator, so timing the call
                # alone would record ~0s and miss errors raised mid-stream.
                start_time = self._record_start(method)
                failed = False
                try:
                    yield from behavior(request_or_iterator, context)
                except Exception:
                    failed = True
                    raise
                finally:
                    self._record_finish(method, start_time, failed)
            return wrapped_behavior

        # Rebuild through the public factories, keeping the (de)serializers.
        serializers = {
            'request_deserializer': handler.request_deserializer,
            'response_serializer': handler.response_serializer,
        }
        if handler.unary_unary:
            return grpc.unary_unary_rpc_method_handler(
                wrap_single_response(handler.unary_unary), **serializers)
        if handler.unary_stream:
            return grpc.unary_stream_rpc_method_handler(
                wrap_response_stream(handler.unary_stream), **serializers)
        if handler.stream_unary:
            return grpc.stream_unary_rpc_method_handler(
                wrap_single_response(handler.stream_unary), **serializers)
        if handler.stream_stream:
            return grpc.stream_stream_rpc_method_handler(
                wrap_response_stream(handler.stream_stream), **serializers)
        return handler

    def get_metrics(self):
        """Return a snapshot dict with 'requests', 'errors' and 'avg_latency'.

        'avg_latency' is averaged over finished RPCs, so in-flight requests
        do not skew it.
        """
        with self.lock:
            return {
                'requests': dict(self.request_counts),
                'errors': dict(self.error_counts),
                'avg_latency': {
                    method: self.latency_sums[method] / completed
                    for method, completed in self.completed_counts.items()
                    if completed > 0
                },
            }


class ClientCallDetails(abc.ABC):
    """
    Describes an RPC to be invoked.

    Attributes:
    - method: The method name of the RPC
    - timeout: Optional duration of time in seconds to allow for the RPC
    - metadata: Optional metadata to be transmitted to the service-side
    - credentials: Optional CallCredentials for the RPC
    - wait_for_ready: Optional flag to enable wait_for_ready mechanism
    - compression: Optional compression element
    """
class HandlerCallDetails(abc.ABC):
    """
    Describes an RPC that has just arrived for service.

    Attributes:
    - method: The method name of the RPC
    - invocation_metadata: The metadata sent by the client
    """


# Install with Tessl CLI
npx tessl i tessl/pypi-grpcio