```
tessl install tessl/pypi-kserve@0.16.1
```

KServe is a comprehensive Python SDK that provides standardized interfaces for building and deploying machine learning model serving infrastructure on Kubernetes. This section covers techniques and patterns for optimizing KServe model server performance, including caching, batching, async processing, and resource optimization.
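Most of the examples below call `self.model.predict(...)` and assume that `load()` has populated `self.model`. As a point of reference, here is a minimal unoptimized baseline sketch; the joblib format and `/mnt/models/model.joblib` path are illustrative assumptions, not a fixed KServe convention.

```python
import joblib
from kserve import Model, ModelServer

class BaselineModel(Model):
    """Unoptimized reference model that the later examples build on."""

    def __init__(self, name: str):
        super().__init__(name)
        self.model = None

    def load(self):
        # Illustrative path/format; later examples assume self.model is set here
        self.model = joblib.load("/mnt/models/model.joblib")
        self.ready = True

    def predict(self, payload, headers=None):
        instances = payload["instances"]
        return {"predictions": self.model.predict(instances).tolist()}

if __name__ == "__main__":
    model = BaselineModel("my-model")
    model.load()
    ModelServer().start([model])
```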
In-process prediction caching with `functools.lru_cache` (one cache per server worker), keyed by a hash of each input instance:

```python
from functools import lru_cache
from kserve import Model
import hashlib
import json

class CachedModel(Model):
    def __init__(self, name: str, cache_size: int = 1000):
        super().__init__(name)
        # Per-instance LRU cache so cache_size is actually honored
        self._cached_predict = lru_cache(maxsize=cache_size)(self._predict_by_hash)
        # Original inputs are stored separately so they can be recovered from their hash
        self._inputs_by_hash = {}

    def _predict_by_hash(self, input_hash: str):
        """Run inference for a single instance, keyed by its input hash"""
        instance = self._inputs_by_hash[input_hash]
        predictions = self.model.predict([instance])  # self.model is set in load()
        return tuple(predictions.tolist()[0])

    def _hash_input(self, instance):
        """Create a deterministic hash of one input instance for caching"""
        input_str = json.dumps(instance, sort_keys=True)
        input_hash = hashlib.md5(input_str.encode()).hexdigest()
        self._inputs_by_hash[input_hash] = instance
        return input_hash

    def predict(self, payload, headers=None):
        instances = payload["instances"]
        predictions = []
        for instance in instances:
            input_hash = self._hash_input(instance)
            pred = self._cached_predict(input_hash)
            predictions.append(list(pred))
        return {"predictions": predictions}
```
Cross-replica caching with Redis, so cached predictions are shared by all workers and pods:

```python
from kserve import Model
import redis
import json
import logging
import pickle

logger = logging.getLogger(__name__)

class RedisCachedModel(Model):
    def __init__(self, name: str, redis_host: str = "localhost", ttl: int = 3600):
        super().__init__(name)
        self.redis_client = redis.Redis(host=redis_host, decode_responses=False)
        self.ttl = ttl  # Cache TTL in seconds

    def predict(self, payload, headers=None):
        instances = payload["instances"]

        # Create cache key from the serialized request
        cache_key = f"prediction:{self.name}:{json.dumps(instances, sort_keys=True)}"

        # Check cache
        cached = self.redis_client.get(cache_key)
        if cached:
            logger.debug(f"Cache hit for {cache_key}")
            return pickle.loads(cached)

        # Cache miss - run inference
        logger.debug(f"Cache miss for {cache_key}")
        predictions = self.model.predict(instances)  # self.model is set in load()
        result = {"predictions": predictions.tolist()}

        # Store in cache with expiry
        self.redis_client.setex(cache_key, self.ttl, pickle.dumps(result))
        return result
```
Dynamic request batching: individual requests are queued briefly and combined into a single batched inference call, trading a small amount of latency for higher throughput:

```python
from kserve import Model
import asyncio

class BatchingModel(Model):
    def __init__(self, name: str, max_batch_size: int = 32, max_latency_ms: int = 100):
        super().__init__(name)
        self.max_batch_size = max_batch_size
        self.max_latency_ms = max_latency_ms / 1000.0  # Convert to seconds
        self.batch_queue = []
        self.batch_lock = asyncio.Lock()

    async def predict(self, payload, headers=None):
        """Batch requests for improved throughput"""
        future = asyncio.get_running_loop().create_future()
        async with self.batch_lock:
            self.batch_queue.append((payload, future))
            # Trigger batch processing if the batch is full
            if len(self.batch_queue) >= self.max_batch_size:
                asyncio.create_task(self._process_batch())
        # Wait for the result; shield the future so a timeout does not cancel it
        try:
            return await asyncio.wait_for(asyncio.shield(future), timeout=self.max_latency_ms)
        except asyncio.TimeoutError:
            # Max latency reached - flush whatever has accumulated
            asyncio.create_task(self._process_batch())
            return await future

    async def _process_batch(self):
        """Process accumulated batch"""
        async with self.batch_lock:
            if not self.batch_queue:
                return
            # Extract and clear the pending batch
            batch = self.batch_queue[:]
            self.batch_queue.clear()

        # Combine all instances into one batch
        all_instances = []
        for payload, _ in batch:
            all_instances.extend(payload["instances"])

        # Run batch inference and distribute results back to waiting requests
        try:
            predictions = self.model.predict(all_instances)
            idx = 0
            for payload, future in batch:
                batch_size = len(payload["instances"])
                batch_predictions = predictions[idx:idx + batch_size]
                future.set_result({"predictions": batch_predictions.tolist()})
                idx += batch_size
        except Exception as e:
            # Propagate the failure to every request in the batch
            for _, future in batch:
                future.set_exception(e)
```
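Because the batch queue lives in process memory, each server worker batches independently. A serving sketch with a single worker, so all requests share one queue; the model path and joblib format are illustrative assumptions:

```python
from kserve import ModelServer
import joblib

if __name__ == "__main__":
    model = BatchingModel("my-model", max_batch_size=32, max_latency_ms=50)
    # The batching wrapper delegates to self.model; path/format are illustrative
    model.model = joblib.load("/mnt/models/model.joblib")
    model.ready = True
    # A single worker keeps one shared batch queue per pod
    ModelServer(workers=1).start([model])
```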
Offloading blocking inference to a thread pool so the event loop stays responsive:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from kserve import Model

class ThreadPoolModel(Model):
    def __init__(self, name: str, max_workers: int = 4):
        super().__init__(name)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    async def predict(self, payload, headers=None):
        """Run blocking inference in a thread pool"""
        instances = payload["instances"]
        # Offload to the thread pool so other requests are not blocked
        loop = asyncio.get_running_loop()
        predictions = await loop.run_in_executor(
            self.executor,
            self.model.predict,  # self.model is set in load()
            instances
        )
        return {"predictions": predictions.tolist()}

    def stop(self):
        """Cleanup executor"""
        self.executor.shutdown(wait=True)
```
For inference that is CPU-bound in pure Python, a process pool avoids the GIL; the model is loaded inside each worker process:

```python
import asyncio
from concurrent.futures import ProcessPoolExecutor
from kserve import Model

def predict_in_process(model_path, instances):
    """Prediction function executed inside a worker process"""
    import joblib
    # Loaded per call here; a real implementation would cache the model per worker
    model = joblib.load(model_path)
    return model.predict(instances).tolist()

class ProcessPoolModel(Model):
    def __init__(self, name: str, model_path: str, max_workers: int = 4):
        super().__init__(name)
        self.model_path = model_path
        self.executor = ProcessPoolExecutor(max_workers=max_workers)

    async def predict(self, payload, headers=None):
        """Run inference in a process pool"""
        instances = payload["instances"]
        # Offload to the process pool
        loop = asyncio.get_running_loop()
        predictions = await loop.run_in_executor(
            self.executor,
            predict_in_process,
            self.model_path,
            instances
        )
        return {"predictions": predictions}

    def stop(self):
        """Cleanup executor"""
        self.executor.shutdown(wait=True)
```
Lazy loading: report ready immediately and defer the actual model load to the first prediction, reducing startup time:

```python
from kserve import Model
import joblib
import logging

logger = logging.getLogger(__name__)

class LazyLoadModel(Model):
    def __init__(self, name: str, model_path: str):
        super().__init__(name)
        self.model_path = model_path
        self.model = None

    def load(self):
        """Mark as ready without actually loading the model"""
        self.ready = True

    def predict(self, payload, headers=None):
        # Load the model on the first prediction
        if self.model is None:
            logger.info(f"Lazy loading model {self.name}")
            self.model = joblib.load(self.model_path)
        return {"predictions": self.model.predict(payload["instances"]).tolist()}
```
Dynamic int8 quantization at load time to reduce memory use:

```python
from kserve import Model
import logging
import torch

logger = logging.getLogger(__name__)

class QuantizedModel(Model):
    def load(self):
        """Load a quantized model for reduced memory footprint"""
        # Load the full-precision model
        model = torch.load("/mnt/models/model.pth")
        # Quantize Linear layers to int8
        self.model = torch.quantization.quantize_dynamic(
            model,
            {torch.nn.Linear},
            dtype=torch.qint8
        )
        self.ready = True
        logger.info(f"Model {self.name} loaded and quantized")
```
On the client side, connection pooling and HTTP/2 reduce per-request connection overhead:

```python
from kserve import InferenceRESTClient, RESTConfig
import httpx

# Configure connection pooling
transport = httpx.AsyncHTTPTransport(
    limits=httpx.Limits(
        max_keepalive_connections=20,
        max_connections=100,
        keepalive_expiry=30
    )
)

config = RESTConfig(
    transport=transport,
    protocol="v2",
    http2=True  # Enable HTTP/2 for request multiplexing
)

client = InferenceRESTClient(config=config)
```
For large tensors, binary tensor encoding shrinks request payloads compared with JSON lists:

```python
from kserve import InferInput
import numpy as np

# Use binary encoding for large tensors
large_tensor = np.random.rand(1000, 1000).astype(np.float32)

input_tensor = InferInput(
    name="input-0",
    shape=list(large_tensor.shape),
    datatype="FP32"
)

# Binary encoding reduces payload size significantly compared to JSON
input_tensor.set_data_from_numpy(large_tensor, binary_data=True)
```
Sizing server workers to the available CPUs:

```python
from kserve import ModelServer
import os

# Calculate a reasonable worker count from available CPUs
num_cpus = os.cpu_count() or 1
num_workers = min(num_cpus * 2, 8)  # 2x CPUs, capped at 8

server = ModelServer(
    http_port=8080,
    workers=num_workers,
    max_threads=4,
    max_asyncio_workers=100
)
```
Operational commands for running and managing workers:

```bash
# Start with multiple workers
python model.py --workers 4 --max_threads 8

# Monitor worker processes
ps aux | grep "python model.py"

# Graceful restart
kill -TERM <pid>  # Sends SIGTERM for graceful shutdown
```
Running inference on GPU when one is available:

```python
from kserve import Model
import logging
import torch

logger = logging.getLogger(__name__)

class GPUModel(Model):
    def load(self):
        """Load the model onto the GPU if available"""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Loading model on {device}")

        self.model = torch.load("/mnt/models/model.pth")
        self.model = self.model.to(device)
        self.model.eval()
        self.device = device
        self.ready = True

    def predict(self, payload, headers=None):
        instances = torch.tensor(payload["instances"]).to(self.device)
        with torch.no_grad():
            predictions = self.model(instances)
        return {"predictions": predictions.cpu().numpy().tolist()}
```
Spreading batches across multiple GPUs with `torch.nn.DataParallel`:

```python
import logging
import torch
from kserve import Model

logger = logging.getLogger(__name__)

class MultiGPUModel(Model):
    def load(self):
        """Load the model across multiple GPUs"""
        # Load the base model (same layout as the single-GPU example)
        self.model = torch.load("/mnt/models/model.pth")
        if torch.cuda.device_count() > 1:
            logger.info(f"Using {torch.cuda.device_count()} GPUs")
            self.model = torch.nn.DataParallel(self.model)
        self.model = self.model.cuda()
        self.ready = True
```
Profiling the prediction path with `cProfile`:

```python
import cProfile
import pstats
from kserve import Model

class ProfiledModel(Model):
    def predict(self, payload, headers=None):
        # Profile this prediction call
        profiler = cProfile.Profile()
        profiler.enable()

        result = self.model.predict(payload["instances"])

        profiler.disable()

        # Print the top functions by cumulative time
        stats = pstats.Stats(profiler)
        stats.sort_stats('cumulative')
        stats.print_stats(10)  # Top 10 functions

        return {"predictions": result.tolist()}
```
Line-by-line memory profiling with `memory_profiler`:

```python
from memory_profiler import profile
from kserve import Model

class MemoryProfiledModel(Model):
    @profile
    def load(self):
        """Profile memory usage during model load"""
        self.model = load_large_model()  # placeholder for your model-loading code
        self.ready = True

    @profile
    def predict(self, payload, headers=None):
        """Profile memory usage during inference"""
        return {"predictions": self.model.predict(payload["instances"])}
```
Compressing large JSON responses with GZip middleware:

```python
from kserve import ModelServer
from fastapi.middleware.gzip import GZipMiddleware

if __name__ == "__main__":
    model = MyModel("my-model")
    model.load()

    server = ModelServer()
    app = server.create_application()

    # Add compression middleware for responses larger than 1 KB
    app.add_middleware(GZipMiddleware, minimum_size=1000)

    server.start([model])
```
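To confirm compression is applied end to end, a quick client-side check against the locally started server above; the URL, model name, and payload size are illustrative, and the v1 data plane endpoint `POST /v1/models/<name>:predict` is assumed:

```python
import httpx

# Enough instances that the JSON response exceeds GZipMiddleware's minimum_size
payload = {"instances": [[0.0] * 64 for _ in range(256)]}

resp = httpx.post(
    "http://localhost:8080/v1/models/my-model:predict",
    json=payload,
    headers={"Accept-Encoding": "gzip"},
    timeout=30.0,
)
# "gzip" here means the middleware compressed the response
print(resp.headers.get("content-encoding"))
print(len(resp.content), "bytes after decompression")
```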
Reusing database connections through a pool instead of reconnecting on every request:

```python
from kserve import Model
import psycopg2.pool

class DatabaseModel(Model):
    def __init__(self, name: str):
        super().__init__(name)
        # Create a connection pool once, at startup
        self.db_pool = psycopg2.pool.SimpleConnectionPool(
            minconn=1,
            maxconn=10,
            host="localhost",
            database="models",
            user="kserve",
            password="password"
        )

    def predict(self, payload, headers=None):
        # Get a connection from the pool
        conn = self.db_pool.getconn()
        try:
            # Query model metadata
            cursor = conn.cursor()
            cursor.execute("SELECT version FROM models WHERE name = %s", (self.name,))
            version = cursor.fetchone()[0]

            # Run prediction
            predictions = self.model.predict(payload["instances"])
            return {
                "predictions": predictions.tolist(),
                "model_version": version
            }
        finally:
            # Return the connection to the pool
            self.db_pool.putconn(conn)

    def stop(self):
        """Close all pooled connections"""
        self.db_pool.closeall()
```

Approximate impact of each optimization (actual gains depend on the workload):

| Optimization | Latency Improvement | Throughput Improvement |
|---|---|---|
| Caching | 90-99% (cache hits) | 10-100x |
| Batching | 20-40% | 2-5x |
| Async Processing | 30-50% | 2-3x |
| Binary Encoding | 10-20% | 1.2-1.5x |
| Connection Pooling | 5-15% | 1.1-1.3x |
| GPU Acceleration | 50-90% | 5-20x |
| Compression | N/A | 1.5-3x (bandwidth) |
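These ranges are workload-dependent; a small timing sketch to measure them for a given model, reusing the hypothetical `BaselineModel` and `CachedModel` classes above and measuring in-process predict latency rather than end-to-end HTTP latency:

```python
import statistics
import time

def time_predictions(model, payload, runs=100):
    """Return (median, p95) latency in milliseconds for model.predict."""
    latencies = []
    for _ in range(runs):
        start = time.perf_counter()
        model.predict(payload)
        latencies.append((time.perf_counter() - start) * 1000)
    latencies.sort()
    return statistics.median(latencies), latencies[int(len(latencies) * 0.95) - 1]

payload = {"instances": [[1.0, 2.0, 3.0]]}  # illustrative feature vector

baseline = BaselineModel("baseline")
baseline.load()

cached = CachedModel("cached")
cached.model = baseline.model  # reuse the already-loaded predictor

print("baseline p50/p95 ms:", time_predictions(baseline, payload))
print("cached   p50/p95 ms:", time_predictions(cached, payload))
```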