# ML Inference Client

Standardized client framework for machine learning model inference, supporting both single-request and batch processing over REST and gRPC. The framework provides high-performance model serving with structured request/response models and configurable service endpoints for scalable ML operations.

## Capabilities

### Inference Client Interface

Abstract base interface defining standardized inference operations for both single predictions and batch processing with pluggable implementation support for different service protocols.

```python { .api }
class InferenceClient(ABC):
    """
    Interface for inference client.

    Attributes:
    - _config: InferenceConfig - Inference configuration
    """

    def __init__(self) -> None:
        """Initialize with InferenceConfig"""
        ...

    @classmethod
    def __subclasshook__(cls, subclass):
        """Class method for subclass checking"""
        ...

    @abstractmethod
    def infer(self, inference_request: InferenceRequest) -> InferenceResult:
        """
        Abstract method to invoke inference.

        Parameters:
        - inference_request: InferenceRequest - Request containing inference data

        Returns:
        InferenceResult - Inference prediction result
        """
        ...

    @abstractmethod
    def infer_batch(self, inference_request_batch: InferenceRequestBatch) -> list[InferenceResultBatch]:
        """
        Abstract method to invoke batch inference.

        Parameters:
        - inference_request_batch: InferenceRequestBatch - Batch of inference requests

        Returns:
        list[InferenceResultBatch] - List of batch inference results
        """
        ...
```

### Inference Configuration

Configuration management for inference service endpoints supporting both REST and gRPC protocols with configurable URLs and ports for flexible deployment scenarios.

```python { .api }
class InferenceConfig:
    """
    Configurations for inference.
    """

    def __init__(self) -> None:
        """Initialize with inference.properties"""
        ...

    def rest_service_url(self) -> str:
        """Returns URL of inference REST service (default: "http://localhost")"""
        ...

    def rest_service_port(self) -> str:
        """Returns port of inference REST service (default: "7080")"""
        ...

    def grpc_service_url(self) -> str:
        """Returns URL of inference gRPC service (default: "http://localhost")"""
        ...

    def grpc_service_port(self) -> str:
        """Returns port of inference gRPC service (default: "7081")"""
        ...
```

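The getters return the URL and port separately, so callers typically join them when composing an endpoint. A minimal sketch, assuming the defaults from `inference.properties` shown above; the `/infer` path mirrors the one used in the high-performance example later in this document:

```python
from inference.inference_config import InferenceConfig

config = InferenceConfig()

# Compose endpoints from the URL and port getters; the "/infer" path is
# illustrative and depends on how the serving side exposes its routes.
rest_endpoint = f"{config.rest_service_url()}:{config.rest_service_port()}/infer"
grpc_target = f"{config.grpc_service_url()}:{config.grpc_service_port()}"

print(rest_endpoint)  # e.g. http://localhost:7080/infer with default settings
print(grpc_target)    # e.g. http://localhost:7081 with default settings
```
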
### Inference Request Models

Comprehensive data models for capturing inference request information including source metadata, request categorization, and extensible attributes for diverse ML use cases.

```python { .api }
class InferenceRequest:
    """
    Contains details necessary for inference to be invoked.

    Properties (with getters/setters):
    - source_ip_address: str - Source IP address
    - created: int - Creation timestamp
    - kind: str - Request kind
    - category: str - Request category
    - outcome: str - Request outcome
    """

    def __init__(self, source_ip_address: str = "", created: int = 0, kind: str = "", category: str = "", outcome: str = "") -> None:
        """
        Constructor with default values.

        Parameters:
        - source_ip_address: str - Source IP address (default: "")
        - created: int - Creation timestamp (default: 0)
        - kind: str - Request kind (default: "")
        - category: str - Request category (default: "")
        - outcome: str - Request outcome (default: "")
        """
        ...

class InferenceRequestBatch:
    """
    Contains details necessary for inference to be invoked on a batch.

    Properties (with getters/setters):
    - row_id_key: str - Row ID key
    - data: list[InferenceRequest] - List of inference requests
    """

    def __init__(self, row_id_key: str, data: list[InferenceRequest]) -> None:
        """
        Constructor.

        Parameters:
        - row_id_key: str - Row identifier key
        - data: list[InferenceRequest] - List of inference requests
        """
        ...
```

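As a quick illustration of how these models fit together, the sketch below builds a single request and wraps it in a batch; the field values are placeholders:

```python
from datetime import datetime

from inference.inference_request import InferenceRequest, InferenceRequestBatch

# Single request describing one event to score (placeholder values).
request = InferenceRequest(
    source_ip_address="192.168.1.100",
    created=int(datetime.now().timestamp()),
    kind="security_scan",
    category="network_traffic",
)

# A batch groups multiple requests under a shared row identifier key.
batch = InferenceRequestBatch(row_id_key="log_id", data=[request])
```
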
### Inference Result Models

Structured result models for capturing inference predictions including threat detection indicators, confidence scores, and batch processing support with row-level result tracking.

```python { .api }
class InferenceResult:
    """
    Contains details about the results of an inference request.

    Properties (with getters/setters):
    - threat_detected: bool - Whether threat was detected
    - score: int - Inference score
    """

    def __init__(self, threat_detected: bool = False, score: int = 0) -> None:
        """
        Constructor with default values.

        Parameters:
        - threat_detected: bool - Whether threat was detected (default: False)
        - score: int - Inference score (default: 0)
        """
        ...

class InferenceResultBatch:
    """
    Represents a single result of a batch inference.

    Properties (with getters/setters):
    - row_id_key: str - Row ID key
    - result: InferenceResult - Inference result
    """

    def __init__(self, row_id_key: str, result: InferenceResult) -> None:
        """
        Constructor.

        Parameters:
        - row_id_key: str - Row identifier key
        - result: InferenceResult - Inference result for this row
        """
        ...
```

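Each `InferenceResultBatch` carries its own `row_id_key`, which lets callers correlate batch results back to the rows that produced them. A minimal sketch of that correlation, using hypothetical result values:

```python
from inference.inference_result import InferenceResult, InferenceResultBatch

# Hypothetical results as they might come back from infer_batch().
batch_results = [
    InferenceResultBatch(row_id_key="log_id_0", result=InferenceResult(True, 85)),
    InferenceResultBatch(row_id_key="log_id_1", result=InferenceResult(False, 12)),
]

# Index results by row identifier so they can be joined back to input rows.
results_by_row = {item.row_id_key: item.result for item in batch_results}
for row_id, result in results_by_row.items():
    print(f"{row_id}: threat={result.threat_detected}, score={result.score}")
```
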
### REST Inference Client

Production-ready REST-based inference client implementation with async HTTP operations, JSON serialization, and comprehensive error handling for scalable model serving.

```python { .api }
class RestInferenceClient(InferenceClient):
    """
    REST-based implementation of InferenceClient.
    """

    async def infer(self, inference_request: InferenceRequest) -> InferenceResult:
        """
        Async method for single inference.

        Parameters:
        - inference_request: InferenceRequest - Request data for inference

        Returns:
        InferenceResult - Prediction result
        """
        ...

    async def infer_batch(self, inference_request_batch: InferenceRequestBatch) -> list[InferenceResultBatch]:
        """
        Async method for batch inference.

        Parameters:
        - inference_request_batch: InferenceRequestBatch - Batch of requests

        Returns:
        list[InferenceResultBatch] - List of batch results
        """
        ...
```

## Usage Examples

### Basic Inference Operations

```python
from inference.rest_inference_client import RestInferenceClient
from inference.inference_request import InferenceRequest
from inference.inference_result import InferenceResult
import asyncio
from datetime import datetime

async def basic_inference_example():
    """Demonstrate basic inference operations"""

    # Initialize REST inference client
    client = RestInferenceClient()

    # Create inference request
    request = InferenceRequest(
        source_ip_address="192.168.1.100",
        created=int(datetime.now().timestamp()),
        kind="security_scan",
        category="network_traffic",
        outcome=""  # Will be populated by inference
    )

    try:
        # Perform single inference
        result = await client.infer(request)

        print("Inference completed:")
        print(f"  Threat detected: {result.threat_detected}")
        print(f"  Confidence score: {result.score}")

        if result.threat_detected:
            print("⚠️ Security threat identified - taking protective action")
        else:
            print("✅ No threats detected - traffic appears safe")

    except Exception as e:
        print(f"Inference failed: {e}")

# Run the example
asyncio.run(basic_inference_example())
```

### Batch Inference Processing

```python
from inference.rest_inference_client import RestInferenceClient
from inference.inference_request import InferenceRequest, InferenceRequestBatch
from inference.inference_result import InferenceResultBatch
import asyncio
import pandas as pd
from datetime import datetime

class BatchInferenceProcessor:
    """Process large datasets using batch inference"""

    def __init__(self):
        self.client = RestInferenceClient()

    async def process_security_logs(self, log_data: pd.DataFrame) -> pd.DataFrame:
        """Process security logs through batch inference"""

        # Convert DataFrame to inference requests
        requests = []
        for _, row in log_data.iterrows():
            request = InferenceRequest(
                source_ip_address=row.get('source_ip', ''),
                created=int(datetime.now().timestamp()),
                kind="log_analysis",
                category=row.get('log_type', 'general'),
                outcome=""
            )
            requests.append(request)

        # Create batch request
        batch_request = InferenceRequestBatch(
            row_id_key="log_id",
            data=requests
        )

        try:
            # Perform batch inference
            print(f"Processing {len(requests)} log entries...")
            batch_results = await self.client.infer_batch(batch_request)

            # Convert results back to DataFrame
            results_data = []
            for i, batch_result in enumerate(batch_results):
                results_data.append({
                    'log_index': i,
                    'threat_detected': batch_result.result.threat_detected,
                    'threat_score': batch_result.result.score,
                    'source_ip': requests[i].source_ip_address,
                    'log_category': requests[i].category
                })

            results_df = pd.DataFrame(results_data)

            # Print summary
            threat_count = results_df['threat_detected'].sum()
            print("Batch processing complete:")
            print(f"  Total logs processed: {len(results_df)}")
            print(f"  Threats detected: {threat_count}")
            print(f"  Threat percentage: {(threat_count/len(results_df)*100):.2f}%")

            return results_df

        except Exception as e:
            print(f"Batch inference failed: {e}")
            return pd.DataFrame()

    async def real_time_stream_processing(self, stream_data):
        """Process streaming data with real-time inference"""
        threat_alerts = []

        for data_point in stream_data:
            request = InferenceRequest(
                source_ip_address=data_point.get('ip', ''),
                created=int(datetime.now().timestamp()),
                kind="real_time_scan",
                category=data_point.get('type', 'unknown')
            )

            try:
                result = await self.client.infer(request)

                # Handle high-risk threats immediately
                if result.threat_detected and result.score > 80:
                    alert = {
                        'timestamp': datetime.now().isoformat(),
                        'source_ip': request.source_ip_address,
                        'threat_score': result.score,
                        'category': request.category,
                        'severity': 'HIGH' if result.score > 90 else 'MEDIUM'
                    }
                    threat_alerts.append(alert)
                    print(f"🚨 HIGH THREAT ALERT: {alert}")

            except Exception as e:
                print(f"Real-time inference error: {e}")

        return threat_alerts

# Usage example
async def run_batch_processing():
    processor = BatchInferenceProcessor()

    # Sample log data
    sample_logs = pd.DataFrame({
        'source_ip': ['192.168.1.100', '10.0.0.50', '172.16.0.25', '203.0.113.10'],
        'log_type': ['access_log', 'error_log', 'security_log', 'access_log'],
        'timestamp': [datetime.now().isoformat()] * 4
    })

    # Process batch
    results = await processor.process_security_logs(sample_logs)
    print("Batch results:")
    print(results)

    # Process real-time stream
    stream_data = [
        {'ip': '192.168.1.200', 'type': 'suspicious_activity'},
        {'ip': '10.0.0.75', 'type': 'normal_traffic'},
        {'ip': '172.16.0.100', 'type': 'potential_intrusion'}
    ]

    alerts = await processor.real_time_stream_processing(stream_data)
    print(f"Generated {len(alerts)} threat alerts")

# Run the batch processing example
asyncio.run(run_batch_processing())
```

### High-Performance Inference Pipeline

```python
from inference.inference_config import InferenceConfig
from inference.inference_request import InferenceRequest
import asyncio
import aiohttp
from typing import List, Dict, Any
import time
from dataclasses import dataclass

@dataclass
class PerformanceMetrics:
    """Track inference performance metrics"""
    total_requests: int = 0
    successful_inferences: int = 0
    failed_inferences: int = 0
    total_latency: float = 0.0
    min_latency: float = float('inf')
    max_latency: float = 0.0

    @property
    def average_latency(self) -> float:
        return self.total_latency / self.total_requests if self.total_requests > 0 else 0.0

    @property
    def success_rate(self) -> float:
        return self.successful_inferences / self.total_requests if self.total_requests > 0 else 0.0

class HighPerformanceInferenceClient:
    """High-performance inference client with connection pooling and metrics"""

    def __init__(self, max_concurrent_requests: int = 50):
        self.config = InferenceConfig()
        self.max_concurrent_requests = max_concurrent_requests
        self.metrics = PerformanceMetrics()
        self.semaphore = asyncio.Semaphore(max_concurrent_requests)

    async def __aenter__(self):
        """Async context manager entry"""
        connector = aiohttp.TCPConnector(
            limit=100,            # Total connection pool size
            limit_per_host=50,    # Connections per host
            ttl_dns_cache=300,    # DNS cache TTL
            use_dns_cache=True,
        )

        timeout = aiohttp.ClientTimeout(
            total=30,      # Total timeout
            connect=5,     # Connection timeout
            sock_read=10   # Socket read timeout
        )

        self.session = aiohttp.ClientSession(
            connector=connector,
            timeout=timeout
        )
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit"""
        await self.session.close()

    async def infer_with_metrics(self, request: InferenceRequest) -> Dict[str, Any]:
        """Perform inference with performance tracking"""
        async with self.semaphore:  # Limit concurrent requests
            start_time = time.time()

            try:
                # Use direct HTTP call for performance
                url = f"{self.config.rest_service_url()}:{self.config.rest_service_port()}/infer"

                request_data = {
                    'source_ip_address': request.source_ip_address,
                    'created': request.created,
                    'kind': request.kind,
                    'category': request.category,
                    'outcome': request.outcome
                }

                async with self.session.post(url, json=request_data) as response:
                    response.raise_for_status()
                    result_data = await response.json()

                # Calculate latency
                latency = time.time() - start_time

                # Update metrics
                self.metrics.total_requests += 1
                self.metrics.successful_inferences += 1
                self.metrics.total_latency += latency
                self.metrics.min_latency = min(self.metrics.min_latency, latency)
                self.metrics.max_latency = max(self.metrics.max_latency, latency)

                return {
                    'success': True,
                    'result': result_data,
                    'latency': latency
                }

            except Exception as e:
                latency = time.time() - start_time

                # Update error metrics
                self.metrics.total_requests += 1
                self.metrics.failed_inferences += 1
                self.metrics.total_latency += latency

                return {
                    'success': False,
                    'error': str(e),
                    'latency': latency
                }

    async def batch_infer_concurrent(self, requests: List[InferenceRequest]) -> List[Dict[str, Any]]:
        """Process multiple requests concurrently"""
        tasks = [self.infer_with_metrics(request) for request in requests]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        return results

    def get_performance_report(self) -> Dict[str, Any]:
        """Generate comprehensive performance report"""
        return {
            'total_requests': self.metrics.total_requests,
            'successful_requests': self.metrics.successful_inferences,
            'failed_requests': self.metrics.failed_inferences,
            'success_rate_percent': self.metrics.success_rate * 100,
            'average_latency_ms': self.metrics.average_latency * 1000,
            'min_latency_ms': self.metrics.min_latency * 1000 if self.metrics.min_latency != float('inf') else 0,
            'max_latency_ms': self.metrics.max_latency * 1000,
            # Rough serial-equivalent throughput (1 / average latency); actual wall-clock
            # throughput with concurrent requests will be higher.
            'requests_per_second': self.metrics.total_requests / self.metrics.total_latency if self.metrics.total_latency > 0 else 0
        }

async def performance_benchmark():
    """Benchmark inference performance with different load patterns"""

    async with HighPerformanceInferenceClient(max_concurrent_requests=25) as client:

        # Generate test requests
        test_requests = []
        for i in range(100):
            request = InferenceRequest(
                source_ip_address=f"192.168.1.{100 + i % 50}",
                created=int(time.time()),
                kind="performance_test",
                category=f"test_category_{i % 5}",
                outcome=""
            )
            test_requests.append(request)

        print("Starting performance benchmark...")
        start_time = time.time()

        # Process all requests concurrently
        results = await client.batch_infer_concurrent(test_requests)

        total_time = time.time() - start_time

        # Analyze results
        successful_results = [r for r in results if isinstance(r, dict) and r.get('success')]
        failed_results = [r for r in results if isinstance(r, dict) and not r.get('success')]

        print("\nPerformance Benchmark Results:")
        print(f"Total processing time: {total_time:.2f} seconds")
        print(f"Successful inferences: {len(successful_results)}")
        print(f"Failed inferences: {len(failed_results)}")
        print(f"Overall throughput: {len(results)/total_time:.2f} requests/second")

        # Get detailed performance report
        report = client.get_performance_report()
        print("\nDetailed Performance Metrics:")
        for metric, value in report.items():
            print(f"  {metric}: {value:.2f}")

# Run performance benchmark
asyncio.run(performance_benchmark())
```

### Custom Inference Client Implementation

```python
from inference.inference_client import InferenceClient
from inference.inference_config import InferenceConfig
from inference.inference_request import InferenceRequest, InferenceRequestBatch
from inference.inference_result import InferenceResult, InferenceResultBatch
import asyncio
import grpc
import time
from typing import List

class GrpcInferenceClient(InferenceClient):
    """Custom gRPC-based inference client implementation"""

    def __init__(self):
        super().__init__()
        self.config = InferenceConfig()
        self.channel = None
        self.stub = None

    async def connect(self):
        """Establish gRPC connection"""
        server_address = f"{self.config.grpc_service_url()}:{self.config.grpc_service_port()}"

        # Create gRPC channel with configuration
        channel_options = [
            ('grpc.keepalive_time_ms', 30000),
            ('grpc.keepalive_timeout_ms', 5000),
            ('grpc.keepalive_permit_without_calls', True),
            ('grpc.http2.max_pings_without_data', 0),
            ('grpc.http2.min_time_between_pings_ms', 10000),
        ]

        self.channel = grpc.aio.insecure_channel(server_address, options=channel_options)

        # Create stub (this would use the generated gRPC stub in a real implementation)
        # self.stub = inference_service_pb2_grpc.InferenceServiceStub(self.channel)

        print(f"Connected to gRPC inference service at {server_address}")

    async def disconnect(self):
        """Close gRPC connection"""
        if self.channel:
            await self.channel.close()

    async def infer(self, inference_request: InferenceRequest) -> InferenceResult:
        """Perform single inference via gRPC"""
        if not self.channel:
            await self.connect()

        try:
            # Convert to gRPC request format (pseudo-code)
            grpc_request = {
                'source_ip': inference_request.source_ip_address,
                'timestamp': inference_request.created,
                'request_type': inference_request.kind,
                'category': inference_request.category
            }

            # Make gRPC call (pseudo-code)
            # response = await self.stub.Infer(grpc_request)

            # Simulate gRPC response for example
            response = {
                'threat_detected': True,  # Simulate threat detection
                'confidence_score': 85    # Simulate confidence score
            }

            # Convert gRPC response to InferenceResult
            result = InferenceResult(
                threat_detected=response['threat_detected'],
                score=response['confidence_score']
            )

            return result

        except Exception as e:
            print(f"gRPC inference failed: {e}")
            raise

    async def infer_batch(self, inference_request_batch: InferenceRequestBatch) -> List[InferenceResultBatch]:
        """Perform batch inference via gRPC"""
        if not self.channel:
            await self.connect()

        try:
            # Convert batch to gRPC format
            grpc_requests = []
            for i, request in enumerate(inference_request_batch.data):
                grpc_request = {
                    'id': f"{inference_request_batch.row_id_key}_{i}",
                    'source_ip': request.source_ip_address,
                    'timestamp': request.created,
                    'request_type': request.kind,
                    'category': request.category
                }
                grpc_requests.append(grpc_request)

            # Make batch gRPC call (pseudo-code)
            # batch_response = await self.stub.InferBatch({'requests': grpc_requests})

            # Simulate batch response
            batch_results = []
            for i, request in enumerate(grpc_requests):
                result = InferenceResultBatch(
                    row_id_key=request['id'],
                    result=InferenceResult(
                        threat_detected=i % 3 == 0,  # Simulate some threats
                        score=70 + (i * 5) % 30      # Simulate varying scores
                    )
                )
                batch_results.append(result)

            return batch_results

        except Exception as e:
            print(f"gRPC batch inference failed: {e}")
            raise

class InferenceClientFactory:
    """Factory for creating different types of inference clients"""

    @staticmethod
    def create_client(client_type: str = "rest") -> InferenceClient:
        """Create inference client based on type"""
        if client_type.lower() == "rest":
            from inference.rest_inference_client import RestInferenceClient
            return RestInferenceClient()
        elif client_type.lower() == "grpc":
            return GrpcInferenceClient()
        else:
            raise ValueError(f"Unsupported client type: {client_type}")

    @staticmethod
    def create_best_available_client() -> InferenceClient:
        """Create the best available client based on service availability"""
        config = InferenceConfig()

        # Try gRPC first for better performance
        try:
            grpc_client = GrpcInferenceClient()
            # Test connection (simplified)
            print("Using gRPC inference client")
            return grpc_client
        except Exception:
            print("gRPC not available, falling back to REST")

        # Fallback to REST
        try:
            from inference.rest_inference_client import RestInferenceClient
            rest_client = RestInferenceClient()
            print("Using REST inference client")
            return rest_client
        except Exception as e:
            raise RuntimeError(f"No inference clients available: {e}")

# Usage example
async def multi_client_example():
    """Demonstrate multiple client types"""

    # Create different client types
    rest_client = InferenceClientFactory.create_client("rest")
    grpc_client = InferenceClientFactory.create_client("grpc")
    best_client = InferenceClientFactory.create_best_available_client()

    # Test request
    request = InferenceRequest(
        source_ip_address="10.0.0.100",
        created=int(time.time()),
        kind="multi_client_test",
        category="performance_comparison"
    )

    # Compare performance between clients
    clients = [
        ("REST", rest_client),
        ("gRPC", grpc_client),
        ("Best Available", best_client)
    ]

    for client_name, client in clients:
        try:
            start_time = time.time()
            result = await client.infer(request)
            latency = time.time() - start_time

            print(f"{client_name} Client:")
            print(f"  Latency: {latency*1000:.2f} ms")
            print(f"  Threat detected: {result.threat_detected}")
            print(f"  Score: {result.score}")
            print()

        except Exception as e:
            print(f"{client_name} Client failed: {e}")

        # Clean up gRPC connection if the client holds one
        if hasattr(client, 'disconnect'):
            await client.disconnect()

# Run the multi-client example
asyncio.run(multi_client_example())
```

## Best Practices

### Client Configuration
- Use appropriate client types based on performance requirements (gRPC for high throughput, REST for simplicity)
- Configure connection pooling and timeouts appropriately
- Implement health checks for service availability (see the sketch after this list)
- Use circuit breakers for fault tolerance

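A minimal health-check and circuit-breaker sketch, assuming the REST service exposes a `/health` endpoint; the path and the failure threshold below are illustrative, not part of the documented API:

```python
import aiohttp

from inference.inference_config import InferenceConfig

async def is_service_healthy(timeout_seconds: float = 2.0) -> bool:
    """Ping the REST service; '/health' is an assumed, not documented, endpoint."""
    config = InferenceConfig()
    url = f"{config.rest_service_url()}:{config.rest_service_port()}/health"
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, timeout=aiohttp.ClientTimeout(total=timeout_seconds)) as resp:
                return resp.status == 200
    except Exception:
        return False

class SimpleCircuitBreaker:
    """Naive circuit breaker: open the circuit after N consecutive failures."""

    def __init__(self, failure_threshold: int = 5):
        self.failure_threshold = failure_threshold
        self.consecutive_failures = 0

    @property
    def is_open(self) -> bool:
        return self.consecutive_failures >= self.failure_threshold

    def record_success(self) -> None:
        self.consecutive_failures = 0

    def record_failure(self) -> None:
        self.consecutive_failures += 1
```
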
### Performance Optimization
- Implement concurrent request processing for batch operations
- Use connection pooling to minimize connection overhead
- Monitor latency and throughput metrics
- Implement caching for frequently requested predictions (see the sketch after this list)

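A small in-memory TTL cache sketch for repeated predictions; the cache-key fields and the TTL are illustrative, and caching only pays off when identical requests genuinely recur:

```python
import time
from typing import Dict, Optional, Tuple

from inference.inference_request import InferenceRequest
from inference.inference_result import InferenceResult

class InferenceResultCache:
    """In-memory TTL cache keyed on request fields (illustrative sketch)."""

    def __init__(self, ttl_seconds: float = 60.0):
        self.ttl_seconds = ttl_seconds
        self._entries: Dict[Tuple[str, str, str], Tuple[float, InferenceResult]] = {}

    def _key(self, request: InferenceRequest) -> Tuple[str, str, str]:
        # Assumed cache key: adjust to whatever fields actually determine the prediction.
        return (request.source_ip_address, request.kind, request.category)

    def get(self, request: InferenceRequest) -> Optional[InferenceResult]:
        entry = self._entries.get(self._key(request))
        if entry is None:
            return None
        stored_at, result = entry
        if time.time() - stored_at > self.ttl_seconds:
            del self._entries[self._key(request)]
            return None
        return result

    def put(self, request: InferenceRequest, result: InferenceResult) -> None:
        self._entries[self._key(request)] = (time.time(), result)
```
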
### Error Handling and Reliability
- Implement retry logic with exponential backoff (see the sketch after this list)
- Use timeouts to prevent hanging requests
- Log detailed error information for debugging
- Implement graceful degradation for service failures

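A retry wrapper sketch with exponential backoff and jitter around `InferenceClient.infer`; the attempt count and base delay are illustrative defaults:

```python
import asyncio
import random

from inference.inference_client import InferenceClient
from inference.inference_request import InferenceRequest
from inference.inference_result import InferenceResult

async def infer_with_retry(
    client: InferenceClient,
    request: InferenceRequest,
    max_attempts: int = 3,
    base_delay_seconds: float = 0.5,
) -> InferenceResult:
    """Retry a single inference with exponential backoff plus jitter."""
    last_error = None
    for attempt in range(max_attempts):
        try:
            return await client.infer(request)
        except Exception as error:
            last_error = error
            if attempt == max_attempts - 1:
                break
            # Exponential backoff: 0.5s, 1s, 2s, ... plus up to 100 ms of jitter.
            delay = base_delay_seconds * (2 ** attempt) + random.uniform(0, 0.1)
            await asyncio.sleep(delay)
    raise RuntimeError(f"Inference failed after {max_attempts} attempts") from last_error
```
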
### Security Considerations
- Use HTTPS/TLS for REST endpoints
- Implement authentication for inference services (see the sketch after this list)
- Validate and sanitize input data
- Monitor for suspicious inference patterns

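A sketch of attaching a bearer token over HTTPS with aiohttp; the `INFERENCE_API_TOKEN` environment variable, the header scheme, and the HTTPS URL are assumptions about a particular deployment, not documented behavior:

```python
import os

import aiohttp

async def post_inference_request(payload: dict) -> dict:
    """Send an authenticated inference request over HTTPS (illustrative only)."""
    # Assumed deployment details: HTTPS endpoint and a bearer token from the environment.
    url = "https://inference.example.com:7080/infer"
    token = os.environ.get("INFERENCE_API_TOKEN", "")
    headers = {"Authorization": f"Bearer {token}"}

    async with aiohttp.ClientSession(headers=headers) as session:
        async with session.post(url, json=payload) as response:
            response.raise_for_status()
            return await response.json()
```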