# ML Inference Client

Standardized client framework for machine learning model inference, supporting both individual and batch processing over REST and gRPC. The inference system provides high-performance model serving with comprehensive request/response modeling and configurable service endpoints for scalable ML operations.

## Capabilities

### Inference Client Interface

Abstract base interface defining standardized inference operations for both single predictions and batch processing, with pluggable implementations for different service protocols.

```python { .api }
class InferenceClient(ABC):
    """
    Interface for inference client.

    Attributes:
    - _config: InferenceConfig - Inference configuration
    """

    def __init__(self) -> None:
        """Initialize with InferenceConfig"""
        ...

    @classmethod
    def __subclasshook__(cls, subclass):
        """Class method for subclass checking"""
        ...

    @abstractmethod
    def infer(self, inference_request: InferenceRequest) -> InferenceResult:
        """
        Abstract method to invoke inference.

        Parameters:
        - inference_request: InferenceRequest - Request containing inference data

        Returns:
        InferenceResult - Inference prediction result
        """
        ...

    @abstractmethod
    def infer_batch(self, inference_request_batch: InferenceRequestBatch) -> list[InferenceResultBatch]:
        """
        Abstract method to invoke batch inference.

        Parameters:
        - inference_request_batch: InferenceRequestBatch - Batch of inference requests

        Returns:
        list[InferenceResultBatch] - List of batch inference results
        """
        ...
```
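
For unit tests or local development, the interface can be satisfied by a minimal in-memory implementation. The sketch below is illustrative only: `StaticInferenceClient` and its canned responses are hypothetical, and it assumes the base class constructor can load its configuration in your environment.

```python
from inference.inference_client import InferenceClient
from inference.inference_request import InferenceRequest, InferenceRequestBatch
from inference.inference_result import InferenceResult, InferenceResultBatch


class StaticInferenceClient(InferenceClient):
    """Hypothetical stub client that returns canned results, useful for tests."""

    def infer(self, inference_request: InferenceRequest) -> InferenceResult:
        # Always report "no threat" with a neutral score
        return InferenceResult(threat_detected=False, score=0)

    def infer_batch(self, inference_request_batch: InferenceRequestBatch) -> list[InferenceResultBatch]:
        # Return one benign result per request, keyed by position
        return [
            InferenceResultBatch(
                row_id_key=f"{inference_request_batch.row_id_key}_{i}",
                result=InferenceResult(threat_detected=False, score=0),
            )
            for i, _ in enumerate(inference_request_batch.data)
        ]
```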

### Inference Configuration

Configuration management for inference service endpoints supporting both REST and gRPC protocols with configurable URLs and ports for flexible deployment scenarios.

```python { .api }
class InferenceConfig:
    """
    Configurations for inference.
    """

    def __init__(self) -> None:
        """Initialize with inference.properties"""
        ...

    def rest_service_url(self) -> str:
        """Returns URL of inference REST service (default: "http://localhost")"""
        ...

    def rest_service_port(self) -> str:
        """Returns port of inference REST service (default: "7080")"""
        ...

    def grpc_service_url(self) -> str:
        """Returns URL of inference gRPC service (default: "http://localhost")"""
        ...

    def grpc_service_port(self) -> str:
        """Returns port of inference gRPC service (default: "7081")"""
        ...
```
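
The getters above expose host and port separately; callers typically join them into a full endpoint. A minimal sketch, assuming an `/infer` route as used in the high-performance example later in this document:

```python
from inference.inference_config import InferenceConfig

config = InferenceConfig()

# Compose endpoints from the configured URL and port values.
# The "/infer" path is an assumption for illustration; the actual route
# is defined by the deployed inference service.
rest_endpoint = f"{config.rest_service_url()}:{config.rest_service_port()}/infer"
grpc_target = f"{config.grpc_service_url()}:{config.grpc_service_port()}"

print(f"REST endpoint: {rest_endpoint}")  # e.g. http://localhost:7080/infer
print(f"gRPC target: {grpc_target}")      # e.g. http://localhost:7081
```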

### Inference Request Models

Comprehensive data models for capturing inference request information including source metadata, request categorization, and extensible attributes for diverse ML use cases.

```python { .api }
class InferenceRequest:
    """
    Contains details necessary for inference to be invoked.

    Properties (with getters/setters):
    - source_ip_address: str - Source IP address
    - created: int - Creation timestamp
    - kind: str - Request kind
    - category: str - Request category
    - outcome: str - Request outcome
    """

    def __init__(self, source_ip_address: str = "", created: int = 0, kind: str = "", category: str = "", outcome: str = "") -> None:
        """
        Constructor with default values.

        Parameters:
        - source_ip_address: str - Source IP address (default: "")
        - created: int - Creation timestamp (default: 0)
        - kind: str - Request kind (default: "")
        - category: str - Request category (default: "")
        - outcome: str - Request outcome (default: "")
        """
        ...

class InferenceRequestBatch:
    """
    Contains details necessary for inference to be invoked on a batch.

    Properties (with getters/setters):
    - row_id_key: str - Row ID key
    - data: list[InferenceRequest] - List of inference requests
    """

    def __init__(self, row_id_key: str, data: list[InferenceRequest]) -> None:
        """
        Constructor.

        Parameters:
        - row_id_key: str - Row identifier key
        - data: list[InferenceRequest] - List of inference requests
        """
        ...
```

### Inference Result Models

Structured result models for capturing inference predictions including threat detection indicators, confidence scores, and batch processing support with row-level result tracking.

```python { .api }
class InferenceResult:
    """
    Contains details about the results of an inference request.

    Properties (with getters/setters):
    - threat_detected: bool - Whether threat was detected
    - score: int - Inference score
    """

    def __init__(self, threat_detected: bool = False, score: int = 0) -> None:
        """
        Constructor with default values.

        Parameters:
        - threat_detected: bool - Whether threat was detected (default: False)
        - score: int - Inference score (default: 0)
        """
        ...

class InferenceResultBatch:
    """
    Represents a single result of a batch inference.

    Properties (with getters/setters):
    - row_id_key: str - Row ID key
    - result: InferenceResult - Inference result
    """

    def __init__(self, row_id_key: str, result: InferenceResult) -> None:
        """
        Constructor.

        Parameters:
        - row_id_key: str - Row identifier key
        - result: InferenceResult - Inference result for this row
        """
        ...
```

### REST Inference Client

Production-ready REST-based inference client implementation with async HTTP operations, JSON serialization, and comprehensive error handling for scalable model serving.

```python { .api }
class RestInferenceClient(InferenceClient):
    """
    REST-based implementation of InferenceClient.
    """

    async def infer(self, inference_request: InferenceRequest) -> InferenceResult:
        """
        Async method for single inference.

        Parameters:
        - inference_request: InferenceRequest - Request data for inference

        Returns:
        InferenceResult - Prediction result
        """
        ...

    async def infer_batch(self, inference_request_batch: InferenceRequestBatch) -> list[InferenceResultBatch]:
        """
        Async method for batch inference.

        Parameters:
        - inference_request_batch: InferenceRequestBatch - Batch of requests

        Returns:
        list[InferenceResultBatch] - List of batch results
        """
        ...
```

## Usage Examples

### Basic Inference Operations

```python
from inference.rest_inference_client import RestInferenceClient
from inference.inference_request import InferenceRequest
from inference.inference_result import InferenceResult
import asyncio
from datetime import datetime

async def basic_inference_example():
    """Demonstrate basic inference operations"""

    # Initialize REST inference client
    client = RestInferenceClient()

    # Create inference request
    request = InferenceRequest(
        source_ip_address="192.168.1.100",
        created=int(datetime.now().timestamp()),
        kind="security_scan",
        category="network_traffic",
        outcome=""  # Will be populated by inference
    )

    try:
        # Perform single inference
        result = await client.infer(request)

        print("Inference completed:")
        print(f" Threat detected: {result.threat_detected}")
        print(f" Confidence score: {result.score}")

        if result.threat_detected:
            print("⚠️ Security threat identified - taking protective action")
        else:
            print("✅ No threats detected - traffic appears safe")

    except Exception as e:
        print(f"Inference failed: {e}")

# Run the example
asyncio.run(basic_inference_example())
```

260

261

### Batch Inference Processing

262

263

```python

264

from inference.rest_inference_client import RestInferenceClient

265

from inference.inference_request import InferenceRequest, InferenceRequestBatch

266

from inference.inference_result import InferenceResultBatch

267

import asyncio

268

import pandas as pd

269

from datetime import datetime

270

271

class BatchInferenceProcessor:

272

"""Process large datasets using batch inference"""

273

274

def __init__(self):

275

self.client = RestInferenceClient()

276

277

async def process_security_logs(self, log_data: pd.DataFrame) -> pd.DataFrame:

278

"""Process security logs through batch inference"""

279

280

# Convert DataFrame to inference requests

281

requests = []

282

for idx, row in log_data.iterrows():

283

request = InferenceRequest(

284

source_ip_address=row.get('source_ip', ''),

285

created=int(datetime.now().timestamp()),

286

kind="log_analysis",

287

category=row.get('log_type', 'general'),

288

outcome=""

289

)

290

requests.append(request)

291

292

# Create batch request

293

batch_request = InferenceRequestBatch(

294

row_id_key="log_id",

295

data=requests

296

)

297

298

try:

299

# Perform batch inference

300

print(f"Processing {len(requests)} log entries...")

301

batch_results = await self.client.infer_batch(batch_request)

302

303

# Convert results back to DataFrame

304

results_data = []

305

for i, batch_result in enumerate(batch_results):

306

results_data.append({

307

'log_index': i,

308

'threat_detected': batch_result.result.threat_detected,

309

'threat_score': batch_result.result.score,

310

'source_ip': requests[i].source_ip_address,

311

'log_category': requests[i].category

312

})

313

314

results_df = pd.DataFrame(results_data)

315

316

# Print summary

317

threat_count = results_df['threat_detected'].sum()

318

print(f"Batch processing complete:")

319

print(f" Total logs processed: {len(results_df)}")

320

print(f" Threats detected: {threat_count}")

321

print(f" Threat percentage: {(threat_count/len(results_df)*100):.2f}%")

322

323

return results_df

324

325

except Exception as e:

326

print(f"Batch inference failed: {e}")

327

return pd.DataFrame()

328

329

async def real_time_stream_processing(self, stream_data):

330

"""Process streaming data with real-time inference"""

331

threat_alerts = []

332

333

for data_point in stream_data:

334

request = InferenceRequest(

335

source_ip_address=data_point.get('ip'),

336

created=int(datetime.now().timestamp()),

337

kind="real_time_scan",

338

category=data_point.get('type', 'unknown')

339

)

340

341

try:

342

result = await self.client.infer(request)

343

344

# Handle high-risk threats immediately

345

if result.threat_detected and result.score > 80:

346

alert = {

347

'timestamp': datetime.now().isoformat(),

348

'source_ip': request.source_ip_address,

349

'threat_score': result.score,

350

'category': request.category,

351

'severity': 'HIGH' if result.score > 90 else 'MEDIUM'

352

}

353

threat_alerts.append(alert)

354

print(f"🚨 HIGH THREAT ALERT: {alert}")

355

356

except Exception as e:

357

print(f"Real-time inference error: {e}")

358

359

return threat_alerts

360

361

# Usage example

362

async def run_batch_processing():

363

processor = BatchInferenceProcessor()

364

365

# Sample log data

366

sample_logs = pd.DataFrame({

367

'source_ip': ['192.168.1.100', '10.0.0.50', '172.16.0.25', '203.0.113.10'],

368

'log_type': ['access_log', 'error_log', 'security_log', 'access_log'],

369

'timestamp': [datetime.now().isoformat()] * 4

370

})

371

372

# Process batch

373

results = await processor.process_security_logs(sample_logs)

374

print("Batch results:")

375

print(results)

376

377

# Process real-time stream

378

stream_data = [

379

{'ip': '192.168.1.200', 'type': 'suspicious_activity'},

380

{'ip': '10.0.0.75', 'type': 'normal_traffic'},

381

{'ip': '172.16.0.100', 'type': 'potential_intrusion'}

382

]

383

384

alerts = await processor.real_time_stream_processing(stream_data)

385

print(f"Generated {len(alerts)} threat alerts")

386

387

# Run the batch processing example

388

asyncio.run(run_batch_processing())

389

```

390

391

### High-Performance Inference Pipeline

392

393

```python

394

from inference.rest_inference_client import RestInferenceClient

395

from inference.inference_config import InferenceConfig

396

from inference.inference_request import InferenceRequest

397

import asyncio

398

import aiohttp

399

from typing import List, Dict, Any

400

import time

401

from dataclasses import dataclass

402

403

@dataclass

404

class PerformanceMetrics:

405

"""Track inference performance metrics"""

406

total_requests: int = 0

407

successful_inferences: int = 0

408

failed_inferences: int = 0

409

total_latency: float = 0.0

410

min_latency: float = float('inf')

411

max_latency: float = 0.0

412

413

@property

414

def average_latency(self) -> float:

415

return self.total_latency / self.total_requests if self.total_requests > 0 else 0.0

416

417

@property

418

def success_rate(self) -> float:

419

return self.successful_inferences / self.total_requests if self.total_requests > 0 else 0.0

420

421

class HighPerformanceInferenceClient:

422

"""High-performance inference client with connection pooling and metrics"""

423

424

def __init__(self, max_concurrent_requests: int = 50):

425

self.config = InferenceConfig()

426

self.max_concurrent_requests = max_concurrent_requests

427

self.metrics = PerformanceMetrics()

428

self.semaphore = asyncio.Semaphore(max_concurrent_requests)

429

430

async def __aenter__(self):

431

"""Async context manager entry"""

432

connector = aiohttp.TCPConnector(

433

limit=100, # Total connection pool size

434

limit_per_host=50, # Connections per host

435

ttl_dns_cache=300, # DNS cache TTL

436

use_dns_cache=True,

437

)

438

439

timeout = aiohttp.ClientTimeout(

440

total=30, # Total timeout

441

connect=5, # Connection timeout

442

sock_read=10 # Socket read timeout

443

)

444

445

self.session = aiohttp.ClientSession(

446

connector=connector,

447

timeout=timeout

448

)

449

return self

450

451

async def __aexit__(self, exc_type, exc_val, exc_tb):

452

"""Async context manager exit"""

453

await self.session.close()

454

455

async def infer_with_metrics(self, request: InferenceRequest) -> Dict[str, Any]:

456

"""Perform inference with performance tracking"""

457

async with self.semaphore: # Limit concurrent requests

458

start_time = time.time()

459

460

try:

461

# Use direct HTTP call for performance

462

url = f"{self.config.rest_service_url()}:{self.config.rest_service_port()}/infer"

463

464

request_data = {

465

'source_ip_address': request.source_ip_address,

466

'created': request.created,

467

'kind': request.kind,

468

'category': request.category,

469

'outcome': request.outcome

470

}

471

472

async with self.session.post(url, json=request_data) as response:

473

response.raise_for_status()

474

result_data = await response.json()

475

476

# Calculate latency

477

latency = time.time() - start_time

478

479

# Update metrics

480

self.metrics.total_requests += 1

481

self.metrics.successful_inferences += 1

482

self.metrics.total_latency += latency

483

self.metrics.min_latency = min(self.metrics.min_latency, latency)

484

self.metrics.max_latency = max(self.metrics.max_latency, latency)

485

486

return {

487

'success': True,

488

'result': result_data,

489

'latency': latency

490

}

491

492

except Exception as e:

493

latency = time.time() - start_time

494

495

# Update error metrics

496

self.metrics.total_requests += 1

497

self.metrics.failed_inferences += 1

498

self.metrics.total_latency += latency

499

500

return {

501

'success': False,

502

'error': str(e),

503

'latency': latency

504

}

505

506

async def batch_infer_concurrent(self, requests: List[InferenceRequest]) -> List[Dict[str, Any]]:

507

"""Process multiple requests concurrently"""

508

tasks = [self.infer_with_metrics(request) for request in requests]

509

results = await asyncio.gather(*tasks, return_exceptions=True)

510

return results

511

512

def get_performance_report(self) -> Dict[str, Any]:

513

"""Generate comprehensive performance report"""

514

return {

515

'total_requests': self.metrics.total_requests,

516

'successful_requests': self.metrics.successful_inferences,

517

'failed_requests': self.metrics.failed_inferences,

518

'success_rate_percent': self.metrics.success_rate * 100,

519

'average_latency_ms': self.metrics.average_latency * 1000,

520

'min_latency_ms': self.metrics.min_latency * 1000 if self.metrics.min_latency != float('inf') else 0,

521

'max_latency_ms': self.metrics.max_latency * 1000,

522

'requests_per_second': self.metrics.total_requests / (self.metrics.total_latency / self.metrics.total_requests) if self.metrics.total_requests > 0 else 0

523

}

524

525

async def performance_benchmark():

526

"""Benchmark inference performance with different load patterns"""

527

528

async with HighPerformanceInferenceClient(max_concurrent_requests=25) as client:

529

530

# Generate test requests

531

test_requests = []

532

for i in range(100):

533

request = InferenceRequest(

534

source_ip_address=f"192.168.1.{100 + i % 50}",

535

created=int(time.time()),

536

kind="performance_test",

537

category=f"test_category_{i % 5}",

538

outcome=""

539

)

540

test_requests.append(request)

541

542

print("Starting performance benchmark...")

543

start_time = time.time()

544

545

# Process all requests concurrently

546

results = await client.batch_infer_concurrent(test_requests)

547

548

total_time = time.time() - start_time

549

550

# Analyze results

551

successful_results = [r for r in results if isinstance(r, dict) and r.get('success')]

552

failed_results = [r for r in results if isinstance(r, dict) and not r.get('success')]

553

554

print(f"\nPerformance Benchmark Results:")

555

print(f"Total processing time: {total_time:.2f} seconds")

556

print(f"Successful inferences: {len(successful_results)}")

557

print(f"Failed inferences: {len(failed_results)}")

558

print(f"Overall throughput: {len(results)/total_time:.2f} requests/second")

559

560

# Get detailed performance report

561

report = client.get_performance_report()

562

print(f"\nDetailed Performance Metrics:")

563

for metric, value in report.items():

564

print(f" {metric}: {value:.2f}")

565

566

# Run performance benchmark

567

asyncio.run(performance_benchmark())

568

```

### Custom Inference Client Implementation

```python
from inference.inference_client import InferenceClient
from inference.inference_config import InferenceConfig
from inference.inference_request import InferenceRequest, InferenceRequestBatch
from inference.inference_result import InferenceResult, InferenceResultBatch
import asyncio
import grpc
import json
import time
from typing import List

class GrpcInferenceClient(InferenceClient):
    """Custom gRPC-based inference client implementation"""

    def __init__(self):
        super().__init__()
        self.config = InferenceConfig()
        self.channel = None
        self.stub = None

    async def connect(self):
        """Establish gRPC connection"""
        server_address = f"{self.config.grpc_service_url()}:{self.config.grpc_service_port()}"

        # Create gRPC channel with configuration
        channel_options = [
            ('grpc.keepalive_time_ms', 30000),
            ('grpc.keepalive_timeout_ms', 5000),
            ('grpc.keepalive_permit_without_calls', True),
            ('grpc.http2.max_pings_without_data', 0),
            ('grpc.http2.min_time_between_pings_ms', 10000),
        ]

        self.channel = grpc.aio.insecure_channel(server_address, options=channel_options)

        # Create stub (this would use generated gRPC stub in real implementation)
        # self.stub = inference_service_pb2_grpc.InferenceServiceStub(self.channel)

        print(f"Connected to gRPC inference service at {server_address}")

    async def disconnect(self):
        """Close gRPC connection"""
        if self.channel:
            await self.channel.close()

    async def infer(self, inference_request: InferenceRequest) -> InferenceResult:
        """Perform single inference via gRPC"""
        if not self.channel:
            await self.connect()

        try:
            # Convert to gRPC request format (pseudo-code)
            grpc_request = {
                'source_ip': inference_request.source_ip_address,
                'timestamp': inference_request.created,
                'request_type': inference_request.kind,
                'category': inference_request.category
            }

            # Make gRPC call (pseudo-code)
            # response = await self.stub.Infer(grpc_request)

            # Simulate gRPC response for example
            response = {
                'threat_detected': True,  # Simulate threat detection
                'confidence_score': 85    # Simulate confidence score
            }

            # Convert gRPC response to InferenceResult
            result = InferenceResult(
                threat_detected=response['threat_detected'],
                score=response['confidence_score']
            )

            return result

        except Exception as e:
            print(f"gRPC inference failed: {e}")
            raise

    async def infer_batch(self, inference_request_batch: InferenceRequestBatch) -> List[InferenceResultBatch]:
        """Perform batch inference via gRPC"""
        if not self.channel:
            await self.connect()

        try:
            # Convert batch to gRPC format
            grpc_requests = []
            for i, request in enumerate(inference_request_batch.data):
                grpc_request = {
                    'id': f"{inference_request_batch.row_id_key}_{i}",
                    'source_ip': request.source_ip_address,
                    'timestamp': request.created,
                    'request_type': request.kind,
                    'category': request.category
                }
                grpc_requests.append(grpc_request)

            # Make batch gRPC call (pseudo-code)
            # batch_response = await self.stub.InferBatch({'requests': grpc_requests})

            # Simulate batch response
            batch_results = []
            for i, request in enumerate(grpc_requests):
                result = InferenceResultBatch(
                    row_id_key=request['id'],
                    result=InferenceResult(
                        threat_detected=i % 3 == 0,  # Simulate some threats
                        score=70 + (i * 5) % 30      # Simulate varying scores
                    )
                )
                batch_results.append(result)

            return batch_results

        except Exception as e:
            print(f"gRPC batch inference failed: {e}")
            raise

class InferenceClientFactory:
    """Factory for creating different types of inference clients"""

    @staticmethod
    def create_client(client_type: str = "rest") -> InferenceClient:
        """Create inference client based on type"""
        if client_type.lower() == "rest":
            from inference.rest_inference_client import RestInferenceClient
            return RestInferenceClient()
        elif client_type.lower() == "grpc":
            return GrpcInferenceClient()
        else:
            raise ValueError(f"Unsupported client type: {client_type}")

    @staticmethod
    def create_best_available_client() -> InferenceClient:
        """Create the best available client based on service availability"""
        config = InferenceConfig()

        # Try gRPC first for better performance
        try:
            grpc_client = GrpcInferenceClient()
            # Test connection (simplified)
            print("Using gRPC inference client")
            return grpc_client
        except Exception:
            print("gRPC not available, falling back to REST")

        # Fallback to REST
        try:
            from inference.rest_inference_client import RestInferenceClient
            rest_client = RestInferenceClient()
            print("Using REST inference client")
            return rest_client
        except Exception as e:
            raise RuntimeError(f"No inference clients available: {e}")

# Usage example
async def multi_client_example():
    """Demonstrate multiple client types"""

    # Create different client types
    rest_client = InferenceClientFactory.create_client("rest")
    grpc_client = InferenceClientFactory.create_client("grpc")
    best_client = InferenceClientFactory.create_best_available_client()

    # Test request
    request = InferenceRequest(
        source_ip_address="10.0.0.100",
        created=int(time.time()),
        kind="multi_client_test",
        category="performance_comparison"
    )

    # Compare performance between clients
    clients = [
        ("REST", rest_client),
        ("gRPC", grpc_client),
        ("Best Available", best_client)
    ]

    for client_name, client in clients:
        try:
            start_time = time.time()
            result = await client.infer(request)
            latency = time.time() - start_time

            print(f"{client_name} Client:")
            print(f" Latency: {latency*1000:.2f} ms")
            print(f" Threat detected: {result.threat_detected}")
            print(f" Score: {result.score}")
            print()

        except Exception as e:
            print(f"{client_name} Client failed: {e}")

        # Clean up gRPC connection
        if hasattr(client, 'disconnect'):
            await client.disconnect()

# Run the multi-client example
asyncio.run(multi_client_example())
```

## Best Practices

### Client Configuration

- Use appropriate client types based on performance requirements (gRPC for high throughput, REST for simplicity)
- Configure connection pooling and timeouts appropriately
- Implement health checks for service availability
- Use circuit breakers for fault tolerance (see the sketch below)
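
A circuit breaker can be layered on top of any inference client without changing the client itself. The sketch below is one possible shape, not part of the framework; the threshold and cooldown values, and the `CircuitBreakerClient` name, are illustrative.

```python
import time

from inference.inference_client import InferenceClient
from inference.inference_request import InferenceRequest
from inference.inference_result import InferenceResult


class CircuitBreakerClient:
    """Hypothetical wrapper that stops calling a failing inference service for a cooldown period."""

    def __init__(self, client: InferenceClient, failure_threshold: int = 5, cooldown_seconds: float = 30.0):
        self.client = client
        self.failure_threshold = failure_threshold
        self.cooldown_seconds = cooldown_seconds
        self.failure_count = 0
        self.opened_at = None  # Timestamp when the breaker tripped, or None

    async def infer(self, request: InferenceRequest) -> InferenceResult:
        # While the breaker is open, fail fast instead of waiting on a dead service
        if self.opened_at is not None:
            if time.time() - self.opened_at < self.cooldown_seconds:
                raise RuntimeError("Circuit open: inference service unavailable")
            self.opened_at = None  # Cooldown elapsed; allow a trial request
            self.failure_count = 0

        try:
            result = await self.client.infer(request)
            self.failure_count = 0
            return result
        except Exception:
            self.failure_count += 1
            if self.failure_count >= self.failure_threshold:
                self.opened_at = time.time()  # Trip the breaker
            raise
```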

### Performance Optimization

- Implement concurrent request processing for batch operations
- Use connection pooling to minimize connection overhead
- Monitor latency and throughput metrics
- Implement caching for frequently requested predictions (see the sketch below)
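
For caching, one lightweight approach is a TTL cache keyed on the request fields that determine the prediction. A sketch under the assumption that requests with identical fields yield identical results; `CachingInferenceWrapper` is hypothetical, not part of the framework.

```python
import time

from inference.inference_request import InferenceRequest
from inference.inference_result import InferenceResult


class CachingInferenceWrapper:
    """Hypothetical TTL cache in front of an inference client."""

    def __init__(self, client, ttl_seconds: float = 60.0):
        self.client = client
        self.ttl_seconds = ttl_seconds
        self._cache = {}  # key -> (expires_at, InferenceResult)

    @staticmethod
    def _key(request: InferenceRequest) -> tuple:
        # Cache on the fields that drive the prediction; ignore volatile ones like `created`
        return (request.source_ip_address, request.kind, request.category)

    async def infer(self, request: InferenceRequest) -> InferenceResult:
        key = self._key(request)
        cached = self._cache.get(key)
        if cached and cached[0] > time.time():
            return cached[1]  # Fresh cache hit

        result = await self.client.infer(request)
        self._cache[key] = (time.time() + self.ttl_seconds, result)
        return result
```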

### Error Handling and Reliability

- Implement retry logic with exponential backoff (see the sketch below)
- Use timeouts to prevent hanging requests
- Log detailed error information for debugging
- Implement graceful degradation for service failures
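
Retry with exponential backoff can be applied as a thin wrapper around any async `infer` call. A minimal sketch; the attempt count and base delay are illustrative, and idempotency of the inference endpoint is assumed.

```python
import asyncio

from inference.inference_request import InferenceRequest
from inference.inference_result import InferenceResult


async def infer_with_retry(client, request: InferenceRequest,
                           max_attempts: int = 3, base_delay: float = 0.5) -> InferenceResult:
    """Hypothetical helper: retry a failed inference with exponential backoff."""
    last_error = None
    for attempt in range(max_attempts):
        try:
            return await client.infer(request)
        except Exception as e:
            last_error = e
            if attempt < max_attempts - 1:
                delay = base_delay * (2 ** attempt)  # 0.5s, 1s, 2s, ...
                print(f"Inference attempt {attempt + 1} failed: {e}; retrying in {delay:.1f}s")
                await asyncio.sleep(delay)
    raise last_error
```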

### Security Considerations

- Use HTTPS/TLS for REST endpoints
- Implement authentication for inference services
- Validate and sanitize input data (see the sketch below)
- Monitor for suspicious inference patterns
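
Input validation can happen before a request ever reaches the inference service. The sketch below uses Python's standard `ipaddress` module and shows only one possible set of checks; the `validate_request` helper and the allowed-kind whitelist are illustrative, not part of the framework.

```python
import ipaddress
import time

from inference.inference_request import InferenceRequest

ALLOWED_KINDS = {"security_scan", "log_analysis", "real_time_scan"}  # illustrative whitelist


def validate_request(request: InferenceRequest) -> None:
    """Hypothetical pre-flight validation; raises ValueError on bad input."""
    # Reject malformed source addresses
    try:
        ipaddress.ip_address(request.source_ip_address)
    except ValueError:
        raise ValueError(f"Invalid source IP address: {request.source_ip_address!r}")

    # Reject timestamps that are missing or implausibly far in the future
    if request.created <= 0 or request.created > int(time.time()) + 60:
        raise ValueError(f"Suspicious creation timestamp: {request.created}")

    # Restrict request kinds to known values
    if request.kind not in ALLOWED_KINDS:
        raise ValueError(f"Unknown request kind: {request.kind!r}")
```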