tessl install tessl/pypi-kserve@0.16.1

KServe is a comprehensive Python SDK that provides standardized interfaces for building and deploying machine learning model serving infrastructure on Kubernetes.

This section covers edge cases, error scenarios, and advanced patterns for building robust KServe implementations.
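
Every pattern below subclasses kserve.Model and overrides load() and/or predict(). As a minimal sketch of how such a class is actually served (the model name, the joblib artifact, and the /mnt/models/model.pkl path are placeholder assumptions, not part of the original examples), a custom model is registered with ModelServer like this:

from kserve import Model, ModelServer
import joblib


class SketchModel(Model):
    def __init__(self, name: str):
        super().__init__(name)
        self.model = None

    def load(self):
        # Placeholder artifact path; KServe downloads model storage to /mnt/models by convention
        self.model = joblib.load("/mnt/models/model.pkl")
        self.ready = True

    def predict(self, payload, headers=None):
        # v1 protocol: {"instances": [...]} in, {"predictions": [...]} out
        return {"predictions": self.model.predict(payload["instances"]).tolist()}


if __name__ == "__main__":
    model = SketchModel("sketch-model")
    model.load()
    ModelServer().start([model])

The same ModelServer entry point applies to each of the specialized Model subclasses that follow.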

from kserve import Model, logger


class PartialFailureModel(Model):
    def predict(self, payload, headers=None):
        """Handle partial batch failures gracefully."""
        instances = payload["instances"]
        predictions = []
        errors = []
        for idx, instance in enumerate(instances):
            try:
                # Validate instance
                if len(instance) != 4:
                    raise ValueError(f"Invalid shape: expected 4 features, got {len(instance)}")
                # Predict
                pred = self.model.predict([instance])[0]
                predictions.append({
                    "index": idx,
                    "prediction": pred.tolist(),
                    "status": "success"
                })
            except Exception as e:
                logger.warning(f"Instance {idx} failed: {e}")
                predictions.append({
                    "index": idx,
                    "prediction": None,
                    "status": "error",
                    "error": str(e)
                })
                errors.append(idx)
        return {
            "predictions": predictions,
            "failed_indices": errors,
            "success_count": len(instances) - len(errors),
            "failure_count": len(errors)
        }

from kserve import Model, logger
import time

import joblib


class ResilientLoadModel(Model):
    def __init__(self, name: str, max_load_attempts: int = 3):
        super().__init__(name)
        self.max_load_attempts = max_load_attempts

    def load(self):
        """Load with retry logic."""
        for attempt in range(self.max_load_attempts):
            try:
                logger.info(f"Loading {self.name} (attempt {attempt + 1}/{self.max_load_attempts})")
                self.model = joblib.load("/mnt/models/model.pkl")
                self.ready = True
                logger.info(f"Model {self.name} loaded successfully")
                return
            except FileNotFoundError as e:
                logger.error(f"Model file not found: {e}")
                if attempt == self.max_load_attempts - 1:
                    raise
                time.sleep(5)  # Wait before retry
            except Exception as e:
                logger.error(f"Load failed: {e}", exc_info=True)
                if attempt == self.max_load_attempts - 1:
                    raise
                time.sleep(2 ** attempt)  # Exponential backoff
        raise RuntimeError(f"Failed to load model after {self.max_load_attempts} attempts")

from kserve import Model
from kserve.errors import InvalidInput
import numpy as np


class RobustInputModel(Model):
    def predict(self, payload, headers=None):
        """Handle corrupted or invalid input data."""
        instances = payload.get("instances")
        if instances is None:
            raise InvalidInput("Missing 'instances' field")
        # Validate and sanitize
        sanitized = []
        for idx, instance in enumerate(instances):
            try:
                # Convert to numpy array
                arr = np.array(instance, dtype=np.float32)
                # Check for NaN or Inf
                if np.isnan(arr).any():
                    raise ValueError("Contains NaN values")
                if np.isinf(arr).any():
                    raise ValueError("Contains Inf values")
                # Check shape
                if arr.shape != (4,):
                    raise ValueError(f"Invalid shape {arr.shape}, expected (4,)")
                # Check value ranges
                if (arr < -10).any() or (arr > 10).any():
                    raise ValueError("Values out of expected range [-10, 10]")
                sanitized.append(arr)
            except Exception as e:
                raise InvalidInput(f"Instance {idx} is invalid: {e}")
        # Run prediction
        predictions = self.model.predict(np.array(sanitized))
        return {"predictions": predictions.tolist()}

from kserve import Model, logger
from kserve.errors import InferenceError
import psutil
import gc


class MemoryAwareModel(Model):
    def __init__(self, name: str, memory_threshold: float = 80.0):
        super().__init__(name)
        self.memory_threshold = memory_threshold

    def predict(self, payload, headers=None):
        """Monitor memory and trigger cleanup if needed."""
        # Check memory before prediction
        memory = psutil.virtual_memory()
        if memory.percent > self.memory_threshold:
            logger.warning(
                f"Memory usage high ({memory.percent}%), "
                f"triggering garbage collection"
            )
            gc.collect()
            # Check again
            memory = psutil.virtual_memory()
            if memory.percent > 90:
                raise InferenceError(
                    f"Memory usage critical ({memory.percent}%), "
                    f"refusing request"
                )
        # Run prediction
        try:
            result = self.model.predict(payload["instances"])
            return {"predictions": result.tolist()}
        finally:
            # Cleanup after prediction
            gc.collect()

from kserve import Model
import shutil
import os

import joblib


class DiskAwareModel(Model):
    def load(self):
        """Check disk space before loading."""
        model_path = "/mnt/models/model.pkl"
        model_size = os.path.getsize(model_path)
        # Check available disk space
        stat = shutil.disk_usage("/mnt/models")
        available_gb = stat.free / (1024 ** 3)
        required_gb = model_size * 2 / (1024 ** 3)  # 2x for safety
        if available_gb < required_gb:
            raise RuntimeError(
                f"Insufficient disk space: {available_gb:.2f}GB available, "
                f"{required_gb:.2f}GB required"
            )
        self.model = joblib.load(model_path)
        self.ready = True

from kserve import Model, logger
import threading
import gc

import joblib


class ThreadSafeModel(Model):
    def __init__(self, name: str):
        super().__init__(name)
        self.model_lock = threading.RLock()
        self.update_lock = threading.Lock()

    def load(self):
        """Thread-safe model loading."""
        with self.model_lock:
            self.model = joblib.load("/mnt/models/model.pkl")
            self.ready = True

    def predict(self, payload, headers=None):
        """Thread-safe prediction."""
        # Grab a reference to the current model under the lock
        with self.model_lock:
            model_snapshot = self.model
        # Run prediction without holding the lock
        predictions = model_snapshot.predict(payload["instances"])
        return {"predictions": predictions.tolist()}

    def update_model(self, new_model_path: str):
        """Update model safely during serving."""
        # Load new model
        new_model = joblib.load(new_model_path)
        # Atomic swap
        with self.model_lock:
            old_model = self.model
            self.model = new_model
        logger.info(f"Model {self.name} updated successfully")
        # Cleanup old model
        del old_model
        gc.collect()

from kserve import ModelRepository
import asyncio


class SafeModelRepository:
    def __init__(self, repository: ModelRepository):
        self.repository = repository
        self.locks = {}
        self.global_lock = asyncio.Lock()

    async def get_model_safe(self, name: str):
        """Get model with locking."""
        async with self.global_lock:
            if name not in self.locks:
                self.locks[name] = asyncio.Lock()
        async with self.locks[name]:
            return self.repository.get_model(name)

    async def update_model_safe(self, model, name: str = None):
        """Update model with locking."""
        model_name = name or model.name
        async with self.global_lock:
            if model_name not in self.locks:
                self.locks[model_name] = asyncio.Lock()
        async with self.locks[model_name]:
            self.repository.update(model, name)

from kserve import InferenceRESTClient, RESTConfig
import httpx
import asyncio
import logging

logger = logging.getLogger(__name__)


async def resilient_inference(
    base_url: str,
    model_name: str,
    data: dict,
    max_retries: int = 3,
    timeout: float = 30.0
):
    """Make inference with connection failure handling."""
    config = RESTConfig(protocol="v2", timeout=timeout, retries=max_retries)
    client = InferenceRESTClient(config=config)
    try:
        for attempt in range(max_retries):
            try:
                return await client.infer(
                    base_url=base_url,
                    model_name=model_name,
                    data=data
                )
            except httpx.ConnectError as e:
                logger.warning(f"Connection failed (attempt {attempt + 1}): {e}")
                if attempt == max_retries - 1:
                    raise
                await asyncio.sleep(2 ** attempt)
            except httpx.TimeoutException as e:
                logger.warning(f"Request timeout (attempt {attempt + 1}): {e}")
                if attempt == max_retries - 1:
                    raise
                await asyncio.sleep(1)
            except httpx.HTTPStatusError as e:
                # Don't retry on 4xx errors
                if 400 <= e.response.status_code < 500:
                    raise
                logger.warning(f"HTTP error {e.response.status_code} (attempt {attempt + 1})")
                if attempt == max_retries - 1:
                    raise
                await asyncio.sleep(2 ** attempt)
    finally:
        await client.close()

from kserve import Model, logger
from kserve.errors import InferenceError
import asyncio


class TimeoutAwareModel(Model):
    def __init__(self, name: str, predictor_timeout: float = 30.0):
        super().__init__(name)
        self.predictor_timeout = predictor_timeout

    async def predict(self, payload, headers=None):
        """Predict with timeout protection."""
        try:
            # Run prediction with timeout
            result = await asyncio.wait_for(
                self._run_prediction(payload),
                timeout=self.predictor_timeout
            )
            return result
        except asyncio.TimeoutError:
            logger.error(f"Prediction timeout after {self.predictor_timeout}s")
            raise InferenceError(
                f"Prediction timeout after {self.predictor_timeout}s"
            )

    async def _run_prediction(self, payload):
        """Actual prediction logic."""
        instances = payload["instances"]
        predictions = self.model.predict(instances)
        return {"predictions": predictions.tolist()}

from kserve import Model, InferRequest, InferResponse, InferOutput
from kserve.errors import InvalidInput
import numpy as np


class MixedTypeModel(Model):
    def predict(self, payload, headers=None):
        """Handle mixed data types in request."""
        if isinstance(payload, InferRequest):
            # v2 protocol with typed inputs
            results = []
            for input_tensor in payload.inputs:
                if input_tensor.datatype == "FP32":
                    data = input_tensor.as_numpy()
                    pred = self.model.predict(data)
                elif input_tensor.datatype == "BYTES":
                    # Handle string/bytes data
                    strings = input_tensor.as_string()
                    pred = self.text_model.predict(strings)
                elif input_tensor.datatype == "INT64":
                    # Handle integer data
                    data = input_tensor.as_numpy()
                    pred = self.int_model.predict(data)
                else:
                    raise InvalidInput(f"Unsupported datatype: {input_tensor.datatype}")
                results.append(pred)
            # Create response (outputs are declared and cast as FP32)
            outputs = []
            for i, result in enumerate(results):
                result = np.asarray(result, dtype=np.float32)
                output = InferOutput(
                    name=f"output-{i}",
                    shape=list(result.shape),
                    datatype="FP32"
                )
                output.set_data_from_numpy(result)
                outputs.append(output)
            return InferResponse(
                response_id=payload.id,
                model_name=self.name,
                infer_outputs=outputs
            )
        else:
            # v1 protocol
            return {"predictions": self.model.predict(payload["instances"]).tolist()}

from kserve import Model, InferRequest, logger
from kserve.errors import InvalidInput
import numpy as np


class LargeTensorModel(Model):
    def __init__(self, name: str, max_tensor_size_mb: int = 100):
        super().__init__(name)
        self.max_tensor_size_mb = max_tensor_size_mb

    def predict(self, payload, headers=None):
        """Handle large tensor inputs."""
        if isinstance(payload, InferRequest):
            for input_tensor in payload.inputs:
                # Check tensor size
                data = input_tensor.as_numpy()
                size_mb = data.nbytes / (1024 * 1024)
                if size_mb > self.max_tensor_size_mb:
                    raise InvalidInput(
                        f"Tensor size {size_mb:.2f}MB exceeds maximum "
                        f"{self.max_tensor_size_mb}MB"
                    )
                # Process in chunks if needed
                if size_mb > 50:
                    logger.info(f"Processing large tensor ({size_mb:.2f}MB) in chunks")
                    return self._predict_chunked(data)
            # Small v2 tensor: predict on the first input directly
            return {"predictions": self.model.predict(payload.inputs[0].as_numpy()).tolist()}
        # Normal v1 prediction
        return {"predictions": self.model.predict(payload["instances"]).tolist()}

    def _predict_chunked(self, data: np.ndarray, chunk_size: int = 1000):
        """Process large tensor in chunks."""
        results = []
        for i in range(0, len(data), chunk_size):
            chunk = data[i:i + chunk_size]
            chunk_result = self.model.predict(chunk)
            results.append(chunk_result)
        return {"predictions": np.concatenate(results).tolist()}

from kserve import Model, InferRequest, logger


class BinaryDataModel(Model):
    def predict(self, payload, headers=None):
        """Handle binary data correctly."""
        if isinstance(payload, InferRequest):
            for input_tensor in payload.inputs:
                if input_tensor.datatype == "BYTES":
                    # Handle BYTES datatype
                    try:
                        strings = input_tensor.as_string()
                        # Process strings
                        result = self.text_model.predict(strings)
                    except Exception as e:
                        # Fallback: treat as raw bytes
                        logger.warning(f"Failed to decode as strings: {e}")
                        raw_data = input_tensor.data
                        result = self.binary_model.predict(raw_data)
                    return {"predictions": result}
        return {"predictions": self.model.predict(payload["instances"]).tolist()}

from kserve import Model, InferRequest, InferResponse, InferOutput, logger
from kserve.errors import InvalidInput, UnsupportedProtocol
import numpy as np


class ProtocolAwareModel(Model):
    def predict(self, payload, headers=None):
        """Handle both v1 and v2 protocols."""
        # Detect protocol
        if isinstance(payload, InferRequest):
            # v2 protocol
            logger.debug("Using v2 protocol")
            return self._predict_v2(payload)
        elif isinstance(payload, dict):
            if "inputs" in payload:
                # v2 dict format
                logger.debug("Using v2 dict format")
                return self._predict_v2_dict(payload)
            elif "instances" in payload:
                # v1 format
                logger.debug("Using v1 protocol")
                return self._predict_v1(payload)
            else:
                raise InvalidInput("Unknown payload format")
        else:
            raise UnsupportedProtocol("Unsupported payload type")

    def _predict_v1(self, payload):
        instances = payload["instances"]
        predictions = self.model.predict(instances)
        return {"predictions": predictions.tolist()}

    def _predict_v2(self, payload: InferRequest):
        input_data = payload.inputs[0].as_numpy()
        # Cast to float32 so the declared FP32 datatype matches the tensor dtype
        predictions = np.asarray(self.model.predict(input_data), dtype=np.float32)
        output = InferOutput(
            name="predictions",
            shape=list(predictions.shape),
            datatype="FP32"
        )
        output.set_data_from_numpy(predictions)
        return InferResponse(
            response_id=payload.id,
            model_name=self.name,
            infer_outputs=[output]
        )

    def _predict_v2_dict(self, payload):
        # Extract from dict format
        input_data = payload["inputs"][0]["data"]
        predictions = self.model.predict(input_data)
        return {"outputs": [{"data": predictions.tolist()}]}

from kserve import Model, logger
from kserve.utils.utils import is_structured_cloudevent


class CloudEventModel(Model):
    def predict(self, payload, headers=None):
        """Handle CloudEvents format."""
        # Check whether the request body is a structured CloudEvent
        if isinstance(payload, dict) and is_structured_cloudevent(payload):
            logger.info("Processing CloudEvent request")
            # Extract data from CloudEvent
            actual_payload = payload.get("data", payload)
            # Process
            instances = actual_payload.get("instances", [])
            predictions = self.model.predict(instances)
            # Return as CloudEvent
            return {
                "specversion": "1.0",
                "type": "org.kserve.inference.response",
                "source": f"model/{self.name}",
                "id": headers.get("ce-id", "unknown") if headers else "unknown",
                "data": {"predictions": predictions.tolist()}
            }
        else:
            # Regular request
            instances = payload["instances"]
            predictions = self.model.predict(instances)
            return {"predictions": predictions.tolist()}

from kserve import Model, logger
import asyncio

import numpy as np


class CancellableModel(Model):
    async def predict(self, payload, headers=None):
        """Handle cancellation gracefully."""
        try:
            # Long-running operation
            instances = payload["instances"]
            # Check for cancellation periodically
            result = await self._predict_with_cancellation_check(instances)
            return {"predictions": result.tolist()}
        except asyncio.CancelledError:
            logger.warning(f"Prediction cancelled for {self.name}")
            # Cleanup resources (assumed to release any model-specific state)
            self._cleanup()
            raise

    async def _predict_with_cancellation_check(self, instances):
        """Run prediction with cancellation checks."""
        # Split into chunks
        chunk_size = 100
        results = []
        for i in range(0, len(instances), chunk_size):
            # Yield control so a pending cancellation can be delivered between chunks
            await asyncio.sleep(0)
            chunk = instances[i:i + chunk_size]
            chunk_result = await asyncio.to_thread(self.model.predict, chunk)
            results.append(chunk_result)
        return np.concatenate(results)

from kserve import Model, logger
import asyncio
import concurrent.futures


class EventLoopSafeModel(Model):
    def __init__(self, name: str):
        super().__init__(name)
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)

    async def predict(self, payload, headers=None):
        """Run blocking code safely without stalling the event loop."""
        instances = payload["instances"]
        # Run blocking prediction in the executor
        loop = asyncio.get_running_loop()
        try:
            predictions = await loop.run_in_executor(
                self.executor,
                self.model.predict,
                instances
            )
            return {"predictions": predictions.tolist()}
        except Exception as e:
            logger.error(f"Prediction failed: {e}", exc_info=True)
            raise

from kserve import (
    KServeClient,
    V1beta1InferenceService,
    V1beta1InferenceServiceSpec,
    V1beta1PredictorSpec,
    constants,
)
from kubernetes import client as k8s_client


def deploy_with_anti_affinity():
    """Deploy with pod anti-affinity for high availability."""
    kserve_client = KServeClient()
    predictor = V1beta1PredictorSpec(
        min_replicas=3,
        sklearn={
            "storageUri": "gs://models/sklearn/iris"
        },
        affinity={
            "podAntiAffinity": {
                "requiredDuringSchedulingIgnoredDuringExecution": [{
                    "labelSelector": {
                        "matchExpressions": [{
                            "key": "serving.kserve.io/inferenceservice",
                            "operator": "In",
                            "values": ["sklearn-iris"]
                        }]
                    },
                    "topologyKey": "kubernetes.io/hostname"
                }]
            }
        }
    )
    # Wrap the predictor in an InferenceService and submit it (namespace is an example value)
    isvc = V1beta1InferenceService(
        api_version=constants.KSERVE_V1BETA1,
        kind=constants.KSERVE_KIND,
        metadata=k8s_client.V1ObjectMeta(name="sklearn-iris", namespace="default"),
        spec=V1beta1InferenceServiceSpec(predictor=predictor),
    )
    kserve_client.create(isvc)

from kserve import Model, logger
import os
import time

import joblib


class StorageResilientModel(Model):
    def load(self):
        """Load with storage failure retry."""
        model_path = "/mnt/models/model.pkl"
        max_attempts = 5
        for attempt in range(max_attempts):
            try:
                # Check if file exists
                if not os.path.exists(model_path):
                    raise FileNotFoundError(f"Model not found: {model_path}")
                # Check if file is complete (not still being written)
                initial_size = os.path.getsize(model_path)
                time.sleep(1)
                final_size = os.path.getsize(model_path)
                if initial_size != final_size:
                    logger.warning("Model file still being written")
                    time.sleep(5)
                    continue
                # Load model
                self.model = joblib.load(model_path)
                self.ready = True
                return
            except Exception as e:
                logger.warning(f"Load attempt {attempt + 1} failed: {e}")
                if attempt == max_attempts - 1:
                    raise
                time.sleep(5)
        raise RuntimeError(f"Failed to load model after {max_attempts} attempts")

from kserve import InferenceGRPCClient


async def create_grpc_client_with_large_messages():
    """Create gRPC client with increased message size limits."""
    # Configure channel options for large messages
    channel_args = [
        ('grpc.max_send_message_length', 100 * 1024 * 1024),  # 100MB
        ('grpc.max_receive_message_length', 100 * 1024 * 1024),  # 100MB
        ('grpc.keepalive_time_ms', 30000),
        ('grpc.keepalive_timeout_ms', 10000),
        ('grpc.keepalive_permit_without_calls', True),
        ('grpc.http2.max_pings_without_data', 0)
    ]
    client = InferenceGRPCClient(
        url="localhost:8081",
        channel_args=channel_args,
        timeout=120
    )
    return client

from kserve import InferenceGRPCClient
import grpc
import asyncio
import logging

logger = logging.getLogger(__name__)


async def grpc_with_retry(url: str, request, max_retries: int = 3):
    """gRPC inference with retry on connection failures."""
    for attempt in range(max_retries):
        client = None
        try:
            client = InferenceGRPCClient(url=url, retries=0)
            # The client is closed in the finally block below
            return await client.infer(request)
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.UNAVAILABLE:
                logger.warning(f"gRPC unavailable (attempt {attempt + 1})")
                if attempt < max_retries - 1:
                    await asyncio.sleep(2 ** attempt)
                    continue
            raise
        finally:
            if client:
                await client.close()

from kserve import Model, logger
import threading
import gc

import joblib


class VersionTransitionModel(Model):
    def __init__(self, name: str):
        super().__init__(name)
        self.current_version = None
        self.transition_lock = threading.Lock()
        self.in_transition = False

    def load(self):
        """Load initial version."""
        self.model = joblib.load("/mnt/models/v1/model.pkl")
        self.current_version = "v1"
        self.ready = True

    def predict(self, payload, headers=None):
        """Predict with version awareness."""
        # Acquiring the lock waits for any in-progress transition to finish
        with self.transition_lock:
            if self.in_transition:
                logger.info("Model transition in progress, waiting...")
            model = self.model
            version = self.current_version
        # Run prediction
        predictions = model.predict(payload["instances"])
        return {
            "predictions": predictions.tolist(),
            "model_version": version
        }

    def transition_to_version(self, new_version: str, model_path: str):
        """Transition to new model version."""
        with self.transition_lock:
            self.in_transition = True
            try:
                logger.info(f"Transitioning from {self.current_version} to {new_version}")
                # Load new model
                new_model = joblib.load(model_path)
                # Atomic swap
                old_model = self.model
                self.model = new_model
                self.current_version = new_version
                logger.info(f"Transition to {new_version} complete")
                # Cleanup
                del old_model
                gc.collect()
            finally:
                self.in_transition = False

from kserve import Model, get_labels, logger
from kserve.metrics import PREDICT_HIST_TIME
import time


class SafeMetricsModel(Model):
    def predict(self, payload, headers=None):
        """Predict with safe metrics collection."""
        start_time = time.time()
        try:
            # Run prediction
            result = self.model.predict(payload["instances"])
            # Try to record metrics
            try:
                elapsed = time.time() - start_time
                labels = get_labels(self.name)
                PREDICT_HIST_TIME.labels(**labels).observe(elapsed)
            except Exception as e:
                # Don't fail prediction if metrics fail
                logger.warning(f"Failed to record metrics: {e}")
            return {"predictions": result.tolist()}
        except Exception:
            # Still try to record latency for failed predictions
            try:
                elapsed = time.time() - start_time
                labels = get_labels(self.name)
                PREDICT_HIST_TIME.labels(**labels).observe(elapsed)
            except Exception:
                pass
            raise

from kserve import Model
from kserve.errors import InvalidInput
import re

import numpy as np


class SecureModel(Model):
    def predict(self, payload, headers=None):
        """Validate and sanitize input."""
        instances = payload.get("instances", [])
        # Check batch size limits
        if len(instances) > 1000:
            raise InvalidInput("Batch size exceeds security limit of 1000")
        # Validate each instance
        for idx, instance in enumerate(instances):
            if isinstance(instance, str):
                # Check for SQL injection patterns (if storing)
                if re.search(r"(--|;|'|\"|\\|/\*|\*/|xp_|sp_)", instance, re.IGNORECASE):
                    raise InvalidInput(f"Instance {idx} contains potentially malicious content")
                # Check for path traversal
                if ".." in instance or "/" in instance:
                    raise InvalidInput(f"Instance {idx} contains path traversal characters")
            # Validate numeric ranges
            if isinstance(instance, (list, np.ndarray)):
                arr = np.array(instance)
                if np.abs(arr).max() > 1e6:
                    raise InvalidInput(f"Instance {idx} contains extremely large values")
        # Run prediction
        return {"predictions": self.model.predict(instances).tolist()}

from kserve import Model
from kserve.errors import InferenceError
import time
from collections import defaultdict


class RateLimitedModel(Model):
    def __init__(self, name: str, max_requests_per_minute: int = 100):
        super().__init__(name)
        self.max_requests_per_minute = max_requests_per_minute
        self.request_counts = defaultdict(list)

    def predict(self, payload, headers=None):
        """Enforce rate limiting."""
        # Get client identifier
        client_id = headers.get("x-client-id", "default") if headers else "default"
        # Clean old timestamps
        now = time.time()
        self.request_counts[client_id] = [
            ts for ts in self.request_counts[client_id]
            if now - ts < 60
        ]
        # Check rate limit
        if len(self.request_counts[client_id]) >= self.max_requests_per_minute:
            raise InferenceError(
                f"Rate limit exceeded: {self.max_requests_per_minute} requests per minute"
            )
        # Record request
        self.request_counts[client_id].append(now)
        # Run prediction
        return {"predictions": self.model.predict(payload["instances"]).tolist()}

from kserve import Model, logger
import weakref

import joblib


class ResourceManagedModel(Model):
    def __init__(self, name: str):
        super().__init__(name)
        self.resources = []
        self._finalizer = weakref.finalize(self, self._cleanup_resources, self.resources)

    def load(self):
        """Load with resource tracking."""
        # Open resources
        self.model = joblib.load("/mnt/models/model.pkl")
        # create_db_connection is a placeholder for whatever external resource the model needs
        self.db_connection = create_db_connection()
        self.resources.append(self.db_connection)
        self.ready = True

    def predict(self, payload, headers=None):
        return {"predictions": self.model.predict(payload["instances"]).tolist()}

    def stop(self):
        """Explicit cleanup."""
        self._cleanup_resources(self.resources)

    @staticmethod
    def _cleanup_resources(resources):
        """Cleanup all resources."""
        for resource in resources:
            try:
                if hasattr(resource, 'close'):
                    resource.close()
            except Exception as e:
                logger.error(f"Failed to close resource: {e}")