tessl install tessl/pypi-kserve@0.16.1

KServe is a comprehensive Python SDK that provides standardized interfaces for building and deploying machine learning model serving infrastructure on Kubernetes.
Complete, production-ready examples demonstrating common KServe use cases with proper error handling, monitoring, and best practices.
Complete image classification model server with preprocessing, validation, and monitoring.
from kserve import Model, ModelServer, logger
from kserve.errors import InvalidInput, InferenceError
from kserve.metrics import PREDICT_HIST_TIME, get_labels
import torch
import torchvision.transforms as transforms
from PIL import Image
import io
import base64
import time
class ImageClassifier(Model):
def __init__(self, name: str):
super().__init__(name)
self.model = None
self.transform = None
self.classes = None
self.device = None
def load(self):
"""Load ResNet model"""
logger.info(f"Loading image classifier {self.name}")
# Setup device
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {self.device}")
# Load model
self.model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
self.model = self.model.to(self.device)
self.model.eval()
# Setup transforms
self.transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# Load class labels
self.classes = self._load_imagenet_classes()
self.ready = True
logger.info(f"Image classifier {self.name} loaded successfully")
def preprocess(self, body, headers=None):
"""Decode and preprocess images"""
instances = body.get("instances")
if not instances:
raise InvalidInput("Missing 'instances' in request")
if len(instances) > 32:
raise InvalidInput(f"Batch size {len(instances)} exceeds maximum of 32")
processed_images = []
for idx, instance in enumerate(instances):
try:
# Decode base64 image
if isinstance(instance, str):
image_bytes = base64.b64decode(instance)
image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
elif isinstance(instance, dict) and "b64" in instance:
image_bytes = base64.b64decode(instance["b64"])
image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
else:
raise InvalidInput(f"Instance {idx}: Invalid format, expected base64 string")
# Apply transforms
tensor = self.transform(image)
processed_images.append(tensor)
except Exception as e:
raise InvalidInput(f"Instance {idx}: Failed to process image: {e}")
# Stack into batch
batch = torch.stack(processed_images)
return {"tensor": batch}
def predict(self, payload, headers=None):
"""Run inference"""
labels = get_labels(self.name)
start_time = time.time()
try:
# Get preprocessed tensor
batch = payload["tensor"].to(self.device)
# Run inference
with torch.no_grad():
outputs = self.model(batch)
probabilities = torch.nn.functional.softmax(outputs, dim=1)
# Get top-5 predictions
top5_prob, top5_idx = torch.topk(probabilities, 5, dim=1)
# Record latency
elapsed = time.time() - start_time
PREDICT_HIST_TIME.labels(**labels).observe(elapsed)
return {
"probabilities": top5_prob.cpu().numpy().tolist(),
"class_indices": top5_idx.cpu().numpy().tolist()
}
except Exception as e:
logger.error(f"Inference failed: {e}", exc_info=True)
raise InferenceError(f"Image classification failed: {e}")
def postprocess(self, response, headers=None):
"""Add class names to response"""
class_indices = response["class_indices"]
probabilities = response["probabilities"]
results = []
for indices, probs in zip(class_indices, probabilities):
predictions = [
{
"class": self.classes[idx],
"confidence": float(prob)
}
for idx, prob in zip(indices, probs)
]
results.append(predictions)
return {"predictions": results}
def _load_imagenet_classes(self):
"""Load ImageNet class labels"""
# Simplified - load from file in production
return [f"class_{i}" for i in range(1000)]
if __name__ == "__main__":
model = ImageClassifier("resnet50")
model.load()
ModelServer(
http_port=8080,
workers=2,
enable_docs_url=True
).start([model])
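Usage (a sketch assuming the server above on port 8080 and a POSIX shell; cat.jpg is a placeholder image file):
# Classify an image (V1 protocol); each instance is a base64-encoded image
curl -X POST http://localhost:8080/v1/models/resnet50:predict \
-H "Content-Type: application/json" \
-d "{\"instances\": [{\"b64\": \"$(base64 < cat.jpg | tr -d '\n')\"}]}"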
Ensemble model combining multiple models with weighted voting.
from kserve import Model, ModelServer, logger
from kserve.errors import InferenceError
import joblib
import numpy as np
class EnsembleModel(Model):
def __init__(self, name: str, model_paths: dict, weights: dict):
super().__init__(name)
self.model_paths = model_paths
self.weights = weights
self.models = {}
def load(self):
"""Load all ensemble models"""
logger.info(f"Loading ensemble {self.name}")
for model_name, model_path in self.model_paths.items():
logger.info(f"Loading sub-model: {model_name}")
self.models[model_name] = joblib.load(model_path)
logger.info(f"Loaded {len(self.models)} models")
self.ready = True
def predict(self, payload, headers=None):
"""Run ensemble prediction with weighted voting"""
instances = payload["instances"]
# Get predictions from all models
all_predictions = {}
for model_name, model in self.models.items():
try:
preds = model.predict_proba(instances)
all_predictions[model_name] = preds
except Exception as e:
logger.warning(f"Model {model_name} failed: {e}")
# Continue with other models
if not all_predictions:
raise InferenceError("All ensemble models failed")
# Weighted average
weighted_sum = None
total_weight = 0
for model_name, preds in all_predictions.items():
weight = self.weights.get(model_name, 1.0)
if weighted_sum is None:
weighted_sum = preds * weight
else:
weighted_sum += preds * weight
total_weight += weight
# Final predictions
ensemble_probs = weighted_sum / total_weight
ensemble_classes = np.argmax(ensemble_probs, axis=1)
return {
"predictions": ensemble_classes.tolist(),
"probabilities": ensemble_probs.tolist(),
"model_contributions": {
name: self.weights.get(name, 1.0)
for name in all_predictions.keys()
}
}
if __name__ == "__main__":
model = EnsembleModel(
name="ensemble-classifier",
model_paths={
"model_a": "/mnt/models/model_a.pkl",
"model_b": "/mnt/models/model_b.pkl",
"model_c": "/mnt/models/model_c.pkl"
},
weights={
"model_a": 0.5,
"model_b": 0.3,
"model_c": 0.2
}
)
model.load()
ModelServer().start([model])
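Usage (an illustrative request assuming the default HTTP port 8080; the feature values are placeholders for whatever the sub-models were trained on):
curl -X POST http://localhost:8080/v1/models/ensemble-classifier:predict \
-H "Content-Type: application/json" \
-d '{"instances": [[5.1, 3.5, 1.4, 0.2]]}'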
Pre-processing transformer that chains to a predictor service.
from kserve import Model, ModelServer, logger
from kserve.context import set_predictor_config, get_predictor_config
from kserve import PredictorConfig
import httpx
class ImageTransformer(Model):
def __init__(self, name: str, predictor_host: str):
super().__init__(name)
self.predictor_host = predictor_host
def load(self):
"""Setup transformer"""
# Configure predictor
config = PredictorConfig()
config.predictor_host = self.predictor_host
config.predictor_protocol = "v2"
config.predictor_use_ssl = False
config.predictor_request_timeout_seconds = 60
set_predictor_config(config)
self.ready = True
logger.info(f"Transformer {self.name} ready")
def preprocess(self, body, headers=None):
"""Transform images"""
instances = body["instances"]
# Apply transformations
transformed = []
for image in instances:
# Resize, normalize, etc.
processed = self._transform_image(image)
transformed.append(processed)
return {"instances": transformed}
async def predict(self, payload, headers=None):
"""Forward to predictor"""
config = get_predictor_config()
# Call predictor
async with httpx.AsyncClient() as client:
url = f"{config.predictor_base_url}/v2/models/predictor/infer"
response = await client.post(
url,
json=payload,
timeout=config.predictor_request_timeout_seconds
)
return response.json()
def postprocess(self, response, headers=None):
"""Add transformer metadata"""
return {
**response,
"transformer": self.name,
"transformer_version": "1.0.0"
}
if __name__ == "__main__":
transformer = ImageTransformer(
name="image-transformer",
predictor_host="resnet-predictor.default.svc.cluster.local:8080"
)
transformer.load()
ModelServer().start([transformer])
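Usage (a sketch: the transformer serves the standard predict route under its own name and forwards the preprocessed payload to the configured predictor host; the base64 string is a placeholder):
curl -X POST http://localhost:8080/v1/models/image-transformer:predict \
-H "Content-Type: application/json" \
-d '{"instances": ["<base64-encoded image>"]}'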
Serve a large language model with OpenAI-compatible endpoints.
from kserve import ModelServer, logger
from kserve.protocol.rest.openai import OpenAIGenerativeModel
from kserve.protocol.rest.openai.types import (
ChatCompletionRequest,
ChatCompletion,
CompletionRequest,
Completion
)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import time
import uuid
class LLMModel(OpenAIGenerativeModel):
def __init__(self, name: str, model_name: str = "gpt2"):
super().__init__(name)
self.model_name = model_name
self.model = None
self.tokenizer = None
self.device = None
def load(self):
"""Load LLM model"""
logger.info(f"Loading LLM {self.model_name}")
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load model and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model = AutoModelForCausalLM.from_pretrained(self.model_name)
self.model = self.model.to(self.device)
self.model.eval()
self.ready = True
logger.info(f"LLM {self.name} loaded on {self.device}")
async def create_chat_completion(self, request, raw_request, context):
"""Generate chat completion"""
# Extract messages
messages = request.messages
# Format prompt from messages
prompt = self._format_chat_prompt(messages)
# Generate response
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model.generate(
inputs.input_ids,
max_length=request.max_tokens or 100,
temperature=request.temperature or 0.7,
top_p=request.top_p or 0.9,
do_sample=True
)
# Decode response
response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
response_text = response_text[len(prompt):].strip()
# Create OpenAI-compatible response
return ChatCompletion(
id=f"chatcmpl-{uuid.uuid4()}",
object="chat.completion",
created=int(time.time()),
model=self.name,
choices=[{
"index": 0,
"message": {
"role": "assistant",
"content": response_text
},
"finish_reason": "stop"
}],
usage={
"prompt_tokens": len(inputs.input_ids[0]),
"completion_tokens": len(outputs[0]) - len(inputs.input_ids[0]),
"total_tokens": len(outputs[0])
}
)
async def create_completion(self, request, raw_request, context):
"""Generate text completion"""
prompt = request.prompt
# Generate
inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model.generate(
inputs.input_ids,
max_length=request.max_tokens or 100,
temperature=request.temperature or 1.0
)
response_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
return Completion(
id=f"cmpl-{uuid.uuid4()}",
object="text_completion",
created=int(time.time()),
model=self.name,
choices=[{
"text": response_text,
"index": 0,
"finish_reason": "stop"
}]
)
def _format_chat_prompt(self, messages):
"""Format messages into prompt"""
prompt_parts = []
for msg in messages:
# Messages may arrive as plain dicts or as pydantic message objects
role = msg["role"] if isinstance(msg, dict) else msg.role
content = msg["content"] if isinstance(msg, dict) else msg.content
prompt_parts.append(f"{role.capitalize()}: {content}")
prompt_parts.append("Assistant:")
return "\n".join(prompt_parts)
if __name__ == "__main__":
model = LLMModel("gpt2-model", model_name="gpt2")
model.load()
ModelServer(
http_port=8080,
enable_docs_url=True
).start([model])

Usage:
# Chat completion
curl -X POST http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "gpt2-model",
"messages": [
{"role": "user", "content": "What is machine learning?"}
],
"max_tokens": 100
}'
# Text completion
curl -X POST http://localhost:8080/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "gpt2-model",
"prompt": "Machine learning is",
"max_tokens": 50
}'
Time series model with sliding window preprocessing.
from kserve import Model, ModelServer
from kserve.errors import InvalidInput
import joblib
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
class TimeSeriesModel(Model):
def __init__(self, name: str, window_size: int = 24):
super().__init__(name)
self.window_size = window_size
self.scaler = None
def load(self):
"""Load forecasting model"""
self.model = joblib.load("/mnt/models/lstm_model.pkl")
self.scaler = joblib.load("/mnt/models/scaler.pkl")
self.ready = True
def preprocess(self, body, headers=None):
"""Create sliding windows from time series"""
instances = body["instances"]
windows = []
for series in instances:
# Validate length
if len(series) < self.window_size:
raise InvalidInput(
f"Series length {len(series)} is less than window size {self.window_size}"
)
# Normalize
normalized = self.scaler.transform(np.array(series).reshape(-1, 1))
# Create windows
for i in range(len(normalized) - self.window_size + 1):
window = normalized[i:i+self.window_size]
windows.append(window.flatten())
return {"instances": windows}
def predict(self, payload, headers=None):
"""Forecast next values"""
windows = np.array(payload["instances"])
# Predict
forecasts = self.model.predict(windows)
# Inverse transform
forecasts = self.scaler.inverse_transform(forecasts.reshape(-1, 1))
return {"predictions": forecasts.flatten().tolist()}
def postprocess(self, response, headers=None):
"""Add forecast metadata"""
predictions = response["predictions"]
# Add timestamps for forecasts
now = datetime.now()
timestamps = [
(now + timedelta(hours=i)).isoformat()
for i in range(1, len(predictions) + 1)
]
return {
"forecasts": [
{"timestamp": ts, "value": val}
for ts, val in zip(timestamps, predictions)
]
}
if __name__ == "__main__":
model = TimeSeriesModel("lstm-forecaster", window_size=24)
model.load()
ModelServer().start([model])
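Usage (illustrative values: a single series of 24 hourly readings, matching the configured window size):
curl -X POST http://localhost:8080/v1/models/lstm-forecaster:predict \
-H "Content-Type: application/json" \
-d '{"instances": [[10.2, 11.5, 12.1, 13.0, 12.8, 12.4, 11.9, 11.2, 10.8, 10.5, 10.3, 10.1, 9.8, 9.9, 10.0, 10.4, 10.9, 11.6, 12.3, 12.9, 13.2, 13.1, 12.7, 12.0]]}'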
Serve multiple model versions with traffic splitting.
from kserve import Model, ModelServer, logger
import joblib
import random
class VersionedModel(Model):
def __init__(self, name: str, version: str, model_path: str):
super().__init__(f"{name}-{version}")
self.version = version
self.model_path = model_path
def load(self):
self.model = joblib.load(self.model_path)
self.ready = True
logger.info(f"Loaded {self.name} version {self.version}")
def predict(self, payload, headers=None):
predictions = self.model.predict(payload["instances"])
return {
"predictions": predictions.tolist(),
"model_version": self.version
}
class ABTestingRouter(Model):
def __init__(self, name: str, models: dict, traffic_split: dict):
super().__init__(name)
self.models = models
self.traffic_split = traffic_split
def load(self):
"""Load all model versions"""
for model in self.models.values():
model.load()
self.ready = True
def predict(self, payload, headers=None):
"""Route to model based on traffic split"""
# Select model based on traffic split
rand = random.random()
cumulative = 0
selected_version = None
for version, percentage in self.traffic_split.items():
cumulative += percentage
if rand <= cumulative:
selected_version = version
break
# Route to selected model
model = self.models[selected_version]
result = model.predict(payload)
logger.info(f"Routed to version {selected_version}")
return result
if __name__ == "__main__":
# Create model versions
model_v1 = VersionedModel("classifier", "v1", "/mnt/models/v1/model.pkl")
model_v2 = VersionedModel("classifier", "v2", "/mnt/models/v2/model.pkl")
# Create A/B testing router
router = ABTestingRouter(
name="classifier-router",
models={"v1": model_v1, "v2": model_v2},
traffic_split={"v1": 0.9, "v2": 0.1} # 90% v1, 10% v2
)
router.load()
ModelServer().start([router])
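Usage (repeat the call a few times; the model_version field in each response reflects the 90/10 split. Feature values are placeholders):
curl -X POST http://localhost:8080/v1/models/classifier-router:predict \
-H "Content-Type: application/json" \
-d '{"instances": [[5.1, 3.5, 1.4, 0.2]]}'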
Model that fetches features from external feature store.
from kserve import Model, ModelServer
import joblib
import httpx
import asyncio
class FeatureStoreModel(Model):
def __init__(self, name: str, feature_store_url: str):
super().__init__(name)
self.feature_store_url = feature_store_url
def load(self):
self.model = joblib.load("/mnt/models/model.pkl")
self.ready = True
async def preprocess(self, body, headers=None):
"""Fetch features from feature store"""
entity_ids = body["entity_ids"]
# Fetch features
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.feature_store_url}/features",
json={"entity_ids": entity_ids},
timeout=5.0
)
features = response.json()["features"]
return {"instances": features}
def predict(self, payload, headers=None):
"""Run prediction with fetched features"""
instances = payload["instances"]
predictions = self.model.predict(instances)
return {"predictions": predictions.tolist()}
if __name__ == "__main__":
model = FeatureStoreModel(
name="feature-model",
feature_store_url="http://feature-store:8080"
)
model.load()
ModelServer().start([model])
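Usage (the request body carries entity IDs instead of feature vectors; the IDs shown are placeholders):
curl -X POST http://localhost:8080/v1/models/feature-model:predict \
-H "Content-Type: application/json" \
-d '{"entity_ids": ["user_123", "user_456"]}'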
Model that provides explanations using SHAP.
from kserve import Model, ModelServer, logger
import joblib
import shap
import numpy as np
class ExplainableModel(Model):
def __init__(self, name: str):
super().__init__(name)
self.explainer = None
def load(self):
"""Load model and explainer"""
self.model = joblib.load("/mnt/models/model.pkl")
# Load background data for SHAP
background = np.load("/mnt/models/background_data.npy")
self.explainer = shap.KernelExplainer(
self.model.predict_proba,
background
)
self.ready = True
logger.info(f"Model {self.name} loaded with explainer")
def predict(self, payload, headers=None):
"""Run prediction"""
instances = payload["instances"]
predictions = self.model.predict(instances)
probabilities = self.model.predict_proba(instances)
return {
"predictions": predictions.tolist(),
"probabilities": probabilities.tolist()
}
def explain(self, payload, headers=None):
"""Generate SHAP explanations"""
instances = np.array(payload["instances"])
# Compute SHAP values
shap_values = self.explainer.shap_values(instances)
# Format explanations
explanations = []
for idx, instance in enumerate(instances):
explanation = {
"instance": instance.tolist(),
"shap_values": shap_values[idx].tolist(),
"base_value": self.explainer.expected_value.tolist()
}
explanations.append(explanation)
return {"explanations": explanations}
if __name__ == "__main__":
model = ExplainableModel("explainable-classifier")
model.load()
ModelServer().start([model])

Usage:
# Make prediction
curl -X POST http://localhost:8080/v1/models/explainable-classifier:predict \
-d '{"instances": [[5.1, 3.5, 1.4, 0.2]]}'
# Get explanation
curl -X POST http://localhost:8080/v1/models/explainable-classifier:explain \
-d '{"instances": [[5.1, 3.5, 1.4, 0.2]]}'Comprehensive input validation with detailed error messages.
Comprehensive input validation with detailed error messages.
from kserve import Model, ModelServer
import joblib
from kserve.errors import InvalidInput
import numpy as np
from pydantic import BaseModel, validator
from typing import List
class PredictionRequest(BaseModel):
"""Validated request schema"""
instances: List[List[float]]
@validator('instances')
def validate_instances(cls, v):
if not v:
raise ValueError("instances cannot be empty")
if len(v) > 32:
raise ValueError(f"batch size {len(v)} exceeds maximum of 32")
for idx, instance in enumerate(v):
if len(instance) != 4:
raise ValueError(f"instance {idx} must have 4 features")
return v
class ValidatedModel(Model):
def load(self):
"""Load the underlying model (path follows the convention used in the other examples)"""
self.model = joblib.load("/mnt/models/model.pkl")
self.ready = True
def predict(self, payload, headers=None):
"""Predict with Pydantic validation"""
try:
# Validate using Pydantic
request = PredictionRequest(**payload)
instances = request.instances
except ValueError as e:
raise InvalidInput(str(e))
# Run prediction
predictions = self.model.predict(instances)
return {"predictions": predictions.tolist()}
if __name__ == "__main__":
model = ValidatedModel("validated-model")
model.load()
ModelServer().start([model])
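Usage (the second request fails the Pydantic validator and is rejected as invalid input; values are placeholders):
# Valid: four features per instance
curl -X POST http://localhost:8080/v1/models/validated-model:predict \
-d '{"instances": [[5.1, 3.5, 1.4, 0.2]]}'
# Invalid: wrong feature count
curl -X POST http://localhost:8080/v1/models/validated-model:predict \
-d '{"instances": [[5.1, 3.5]]}'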
Model repository with dynamic model loading/unloading.
from kserve import ModelServer, ModelRepository, Model, logger
import joblib
from fastapi import FastAPI, HTTPException
import os
class DynamicModel(Model):
def __init__(self, name: str, model_path: str):
super().__init__(name)
self.model_path = model_path
def load(self):
if not os.path.exists(self.model_path):
raise FileNotFoundError(f"Model not found: {self.model_path}")
self.model = joblib.load(self.model_path)
self.ready = True
logger.info(f"Loaded {self.name} from {self.model_path}")
def predict(self, payload, headers=None):
return {"predictions": self.model.predict(payload["instances"])}
# Create repository
repository = ModelRepository(models_dir="/mnt/models")
# Create server with repository
server = ModelServer(registered_models=repository)
app = server.create_application()
# Custom endpoints for model management
@app.post("/v1/models/{model_name}/load")
async def load_model(model_name: str, model_path: str):
"""Load a model dynamically"""
try:
model = DynamicModel(model_name, model_path)
model.load()
repository.update(model)
return {"status": "loaded", "model": model_name}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.delete("/v1/models/{model_name}")
async def unload_model(model_name: str):
"""Unload a model"""
try:
repository.unload(model_name)
return {"status": "unloaded", "model": model_name}
except Exception as e:
raise HTTPException(status_code=404, detail=str(e))
@app.get("/v1/models")
async def list_models():
"""List all loaded models"""
models = repository.get_models()
return {
"models": [
{
"name": name,
"ready": model.ready
}
for name, model in models.items()
]
}
if __name__ == "__main__":
server.start([])
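Usage (these management routes are custom endpoints defined in this example, not built-in KServe paths; my-model and the model path are placeholders):
# Load a model dynamically (model_path is passed as a query parameter)
curl -X POST "http://localhost:8080/v1/models/my-model/load?model_path=/mnt/models/my_model.pkl"
# List loaded models
curl http://localhost:8080/v1/models
# Unload a model
curl -X DELETE http://localhost:8080/v1/models/my-model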
Complete Kubernetes deployment with auto-scaling configuration.
from kserve import (
KServeClient,
V1beta1InferenceService,
V1beta1InferenceServiceSpec,
V1beta1PredictorSpec,
V1beta1SKLearnSpec,
logger
)
def deploy_autoscaling_model():
"""Deploy model with auto-scaling"""
client = KServeClient()
isvc = V1beta1InferenceService(
api_version="serving.kserve.io/v1beta1",
kind="InferenceService",
metadata={
"name": "sklearn-iris",
"namespace": "default",
"annotations": {
"autoscaling.knative.dev/target": "100",
"autoscaling.knative.dev/metric": "concurrency"
}
},
spec=V1beta1InferenceServiceSpec(
predictor=V1beta1PredictorSpec(
min_replicas=2,
max_replicas=10,
sklearn=V1beta1SKLearnSpec(
storage_uri="gs://models/sklearn/iris",
protocol_version="v2",
resources={
"limits": {"cpu": "2", "memory": "4Gi"},
"requests": {"cpu": "1", "memory": "2Gi"}
},
env=[
{"name": "WORKERS", "value": "2"},
{"name": "MAX_THREADS", "value": "4"}
]
),
timeout=60,
logger={
"mode": "all",
"url": "http://logger-service:8080"
}
)
)
)
# Create and wait for ready
logger.info("Creating InferenceService...")
client.create(isvc, namespace="default", watch=True, timeout_seconds=300)
# Get status
isvc = client.get("sklearn-iris", namespace="default")
url = isvc["status"]["url"]
logger.info(f"InferenceService ready at: {url}")
return url
if __name__ == "__main__":
url = deploy_autoscaling_model()
print(f"Model deployed at: {url}")Production model with Redis caching and comprehensive monitoring.
Production model with Redis caching and comprehensive monitoring.
from kserve import Model, ModelServer, logger
import joblib
from kserve.metrics import PREDICT_HIST_TIME, get_labels
from prometheus_client import Counter, Histogram
import redis
import json
import time
# Custom metrics
CACHE_HITS = Counter('cache_hits_total', 'Cache hits', ['model_name'])
CACHE_MISSES = Counter('cache_misses_total', 'Cache misses', ['model_name'])
CACHE_LATENCY = Histogram('cache_latency_seconds', 'Cache lookup latency', ['model_name'])
class CachedMonitoredModel(Model):
def __init__(self, name: str, redis_host: str = "localhost", cache_ttl: int = 3600):
super().__init__(name)
self.redis_client = redis.Redis(host=redis_host, decode_responses=True)
self.cache_ttl = cache_ttl
def load(self):
self.model = joblib.load("/mnt/models/model.pkl")
self.ready = True
def predict(self, payload, headers=None):
instances = payload["instances"]
cache_key = f"pred:{self.name}:{json.dumps(instances)}"
# Check cache
cache_start = time.time()
cached = self.redis_client.get(cache_key)
cache_elapsed = time.time() - cache_start
CACHE_LATENCY.labels(model_name=self.name).observe(cache_elapsed)
if cached:
CACHE_HITS.labels(model_name=self.name).inc()
logger.debug(f"Cache hit for {cache_key}")
return json.loads(cached)
# Cache miss
CACHE_MISSES.labels(model_name=self.name).inc()
logger.debug(f"Cache miss for {cache_key}")
# Run inference
labels = get_labels(self.name)
start = time.time()
predictions = self.model.predict(instances)
result = {"predictions": predictions.tolist()}
elapsed = time.time() - start
PREDICT_HIST_TIME.labels(**labels).observe(elapsed)
# Store in cache
self.redis_client.setex(cache_key, self.cache_ttl, json.dumps(result))
return result
if __name__ == "__main__":
model = CachedMonitoredModel("cached-model", redis_host="redis")  # default Redis port 6379
model.load()
ModelServer(enable_docs_url=True).start([model])
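Usage (send the same payload twice, assuming the default port 8080; the second call is served from Redis and increments the cache_hits_total counter):
curl -X POST http://localhost:8080/v1/models/cached-model:predict \
-d '{"instances": [[5.1, 3.5, 1.4, 0.2]]}'
curl -X POST http://localhost:8080/v1/models/cached-model:predict \
-d '{"instances": [[5.1, 3.5, 1.4, 0.2]]}'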
Optimized batch inference with request accumulation.
from kserve import Model, ModelServer, logger
import joblib
import asyncio
from typing import List, Tuple
import time
class BatchInferenceModel(Model):
def __init__(self, name: str, max_batch_size: int = 64, max_wait_ms: int = 50):
super().__init__(name)
self.max_batch_size = max_batch_size
self.max_wait_ms = max_wait_ms / 1000.0
self.batch_queue: List[Tuple[dict, asyncio.Future]] = []
self.batch_lock = asyncio.Lock()
self.batch_task = None
def load(self):
self.model = joblib.load("/mnt/models/model.pkl")
self.ready = True
async def predict(self, payload, headers=None):
"""Add request to batch queue"""
# Start the background batch processor lazily, once an event loop is running
# (load() runs before the server's event loop exists, so it cannot create the task)
if self.batch_task is None:
self.batch_task = asyncio.create_task(self._batch_processor())
future = asyncio.get_running_loop().create_future()
async with self.batch_lock:
self.batch_queue.append((payload, future))
# Process immediately if batch is full
if len(self.batch_queue) >= self.max_batch_size:
await self._process_batch()
return await future
async def _batch_processor(self):
"""Background task to process batches"""
while True:
await asyncio.sleep(self.max_wait_ms)
async with self.batch_lock:
if self.batch_queue:
await self._process_batch()
async def _process_batch(self):
"""Process accumulated batch"""
if not self.batch_queue:
return
# Extract batch
batch = self.batch_queue[:]
self.batch_queue.clear()
logger.info(f"Processing batch of {len(batch)} requests")
# Combine instances
all_instances = []
batch_sizes = []
for payload, _ in batch:
instances = payload["instances"]
all_instances.extend(instances)
batch_sizes.append(len(instances))
# Run batch inference
try:
predictions = self.model.predict(all_instances)
# Distribute results
idx = 0
for (payload, future), size in zip(batch, batch_sizes):
batch_preds = predictions[idx:idx+size]
future.set_result({"predictions": batch_preds.tolist()})
idx += size
except Exception as e:
# Set exception for all futures
for _, future in batch:
if not future.done():
future.set_exception(e)
def stop(self):
"""Cancel batch processor"""
if self.batch_task:
self.batch_task.cancel()
if __name__ == "__main__":
model = BatchInferenceModel("batch-model", max_batch_size=64, max_wait_ms=50)
model.load()
ModelServer().start([model])
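Usage (a sketch: fire several requests concurrently so the server groups them into one batch; the payload values are placeholders):
for i in 1 2 3 4; do
curl -s -X POST http://localhost:8080/v1/models/batch-model:predict \
-d '{"instances": [[5.1, 3.5, 1.4, 0.2]]}' &
done
wait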