# Machine Learning

Comprehensive machine learning capabilities through PyTorch, TensorFlow Lite, ONNX, ONNX Runtime, and TensorRT bindings for training and inference operations.

## Capabilities

### PyTorch Integration

Deep learning framework with dynamic computation graphs, comprehensive tensor operations, and neural network building blocks.

```java { .api }
11
/**
12
* Multi-dimensional tensor for numerical computations
13
*/
14
public class Tensor extends Pointer {
15
/**
16
* Create tensor filled with zeros
17
* @param sizes Tensor dimensions
18
* @return Zero-filled tensor
19
*/
20
public static native Tensor zeros(long[] sizes);
21
22
/**
23
* Create tensor filled with ones
24
* @param sizes Tensor dimensions
25
* @return One-filled tensor
26
*/
27
public static native Tensor ones(long[] sizes);
28
29
/**
30
* Create tensor from data
31
* @param data Input data array
32
* @param sizes Tensor dimensions
33
* @return Tensor containing data
34
*/
35
public static native Tensor from_blob(FloatPointer data, long[] sizes);
36
37
/**
38
* Add tensors element-wise
39
* @param other Tensor to add
40
* @return Result tensor
41
*/
42
public native Tensor add(Tensor other);
43
44
/**
45
* Add scalar to tensor
46
* @param scalar Scalar value to add
47
* @return Result tensor
48
*/
49
public native Tensor add(double scalar);
50
51
/**
52
* Matrix multiplication
53
* @param mat2 Second matrix
54
* @return Matrix product
55
*/
56
public native Tensor mm(Tensor mat2);
57
58
/**
59
* Batch matrix multiplication
60
* @param mat2 Second batch of matrices
61
* @return Batch matrix product
62
*/
63
public native Tensor bmm(Tensor mat2);
64
65
/**
66
* Get tensor size
67
* @return Array of dimension sizes
68
*/
69
public native long[] sizes();
70
71
/**
72
* Get number of dimensions
73
* @return Number of dimensions
74
*/
75
public native int dim();
76
77
/**
78
* Get data type
79
* @return Tensor data type
80
*/
81
public native ScalarType dtype();
82
83
/**
84
* Move tensor to device
85
* @param device Target device
86
* @return Tensor on specified device
87
*/
88
public native Tensor to(Device device);
89
90
/**
91
* Reshape tensor
92
* @param shape New shape
93
* @return Reshaped tensor
94
*/
95
public native Tensor reshape(long[] shape);
96
97
/**
98
* Get tensor as CPU float data
99
* @return Float pointer to tensor data
100
*/
101
public native FloatPointer data_ptr_float();
102
}
103
104
/**
105
* Device specification for tensor placement
106
*/
107
public class Device extends Pointer {
108
/**
109
* CPU device
110
*/
111
public static final int kCPU = 0;
112
113
/**
114
* CUDA device
115
*/
116
public static final int kCUDA = 1;
117
118
/**
119
* Create device specification
120
* @param type Device type (kCPU, kCUDA)
121
* @param index Device index (for multi-GPU)
122
*/
123
public Device(int type, int index);
124
125
/**
126
* Create CPU device
127
* @return CPU device
128
*/
129
public static native Device cpu();
130
131
/**
132
* Create CUDA device
133
* @param index GPU index
134
* @return CUDA device
135
*/
136
public static native Device cuda(int index);
137
}
138
139
/**
140
* Neural network module base class
141
*/
142
public class Module extends Pointer {
143
/**
144
* Forward pass
145
* @param inputs Input tensors
146
* @return Output tensors
147
*/
148
public native TensorVector forward(TensorVector inputs);
149
150
/**
151
* Set training mode
152
* @param mode Training mode flag
153
*/
154
public native void train(boolean mode);
155
156
/**
157
* Set evaluation mode
158
*/
159
public native void eval();
160
161
/**
162
* Get module parameters
163
* @return Parameter tensors
164
*/
165
public native TensorVector parameters();
166
167
/**
168
* Move module to device
169
* @param device Target device
170
*/
171
public native void to(Device device);
172
}
```

### TensorFlow Lite Integration

Lightweight ML inference framework optimized for mobile and edge devices.

```java { .api }
180
/**
181
* TensorFlow Lite interpreter for model inference
182
*/
183
public class Interpreter extends Pointer {
184
/**
185
* Create interpreter from model buffer
186
* @param modelBuffer Serialized model data
187
*/
188
public Interpreter(ByteBuffer modelBuffer);
189
190
/**
191
* Create interpreter from byte array
192
* @param modelData Serialized model data
193
*/
194
public Interpreter(byte[] modelData);
195
196
/**
197
* Run inference
198
* @param inputs Input tensors array
199
* @param outputs Output tensors array
200
*/
201
public native void run(Object[] inputs, Object[] outputs);
202
203
/**
204
* Run inference with single input/output
205
* @param input Input tensor
206
* @param output Output tensor
207
*/
208
public native void run(Object input, Object output);
209
210
/**
211
* Resize input tensor
212
* @param inputIndex Input tensor index
213
* @param shape New input shape
214
*/
215
public native void resizeInput(int inputIndex, int[] shape);
216
217
/**
218
* Allocate tensors for inference
219
*/
220
public native void allocateTensors();
221
222
/**
223
* Get number of input tensors
224
* @return Number of inputs
225
*/
226
public native int getInputTensorCount();
227
228
/**
229
* Get number of output tensors
230
* @return Number of outputs
231
*/
232
public native int getOutputTensorCount();
233
234
/**
235
* Get input tensor info
236
* @param inputIndex Input index
237
* @return Tensor information
238
*/
239
public native Tensor getInputTensor(int inputIndex);
240
241
/**
242
* Get output tensor info
243
* @param outputIndex Output index
244
* @return Tensor information
245
*/
246
public native Tensor getOutputTensor(int outputIndex);
247
}
248
249
/**
250
* TensorFlow Lite model representation
251
*/
252
public class Model extends Pointer {
253
/**
254
* Load model from file
255
* @param modelPath Path to model file
256
* @return Loaded model
257
*/
258
public static native Model fromFile(String modelPath);
259
260
/**
261
* Load model from buffer
262
* @param modelData Model data buffer
263
* @return Loaded model
264
*/
265
public static native Model fromBuffer(ByteBuffer modelData);
266
}
```

### ONNX Integration

Open Neural Network Exchange format for model interoperability.

```java { .api }
274
/**
275
* ONNX model representation
276
*/
277
public class ModelProto extends Pointer {
278
/**
279
* Parse model from file
280
* @param filename Path to ONNX model file
281
* @return Parsed model
282
*/
283
public static native ModelProto parseFromFile(String filename);
284
285
/**
286
* Parse model from bytes
287
* @param data Serialized model data
288
* @return Parsed model
289
*/
290
public static native ModelProto parseFromBytes(BytePointer data);
291
292
/**
293
* Get model graph
294
* @return Model computation graph
295
*/
296
public native GraphProto graph();
297
298
/**
299
* Get model version
300
* @return Model version
301
*/
302
public native long ir_version();
303
304
/**
305
* Get model opset imports
306
* @return Opset version information
307
*/
308
public native OperatorSetIdProtoVector opset_import();
309
}
310
311
/**
312
* ONNX computation graph
313
*/
314
public class GraphProto extends Pointer {
315
/**
316
* Get graph nodes
317
* @return Graph computation nodes
318
*/
319
public native NodeProtoVector node();
320
321
/**
322
* Get graph inputs
323
* @return Input value info
324
*/
325
public native ValueInfoProtoVector input();
326
327
/**
328
* Get graph outputs
329
* @return Output value info
330
*/
331
public native ValueInfoProtoVector output();
332
333
/**
334
* Get graph initializers
335
* @return Constant tensor initializers
336
*/
337
public native TensorProtoVector initializer();
338
}
339
340
/**
341
* ONNX computation node
342
*/
343
public class NodeProto extends Pointer {
344
/**
345
* Get node operation type
346
* @return Operation name
347
*/
348
public native String op_type();
349
350
/**
351
* Get node inputs
352
* @return Input tensor names
353
*/
354
public native StringVector input();
355
356
/**
357
* Get node outputs
358
* @return Output tensor names
359
*/
360
public native StringVector output();
361
362
/**
363
* Get node attributes
364
* @return Operation attributes
365
*/
366
public native AttributeProtoVector attribute();
367
}
```

### ONNX Runtime Integration

High-performance inference engine for ONNX models with hardware acceleration.

```java { .api }
375
/**
376
* ONNX Runtime environment
377
*/
378
public class OrtEnv extends Pointer {
379
/**
380
* Create environment with logging level
381
* @param logLevel Logging level
382
* @param logId Log identifier
383
* @return Environment instance
384
*/
385
public static native OrtEnv create(int logLevel, String logId);
386
387
/**
388
* Create default environment
389
* @return Default environment
390
*/
391
public static native OrtEnv create();
392
}
393
394
/**
395
* ONNX Runtime inference session
396
*/
397
public class OrtSession extends Pointer {
398
/**
399
* Create session from model file
400
* @param env Environment
401
* @param modelPath Path to ONNX model
402
* @param sessionOptions Session options
403
* @return Inference session
404
*/
405
public static native OrtSession create(OrtEnv env, String modelPath,
406
OrtSessionOptions sessionOptions);
407
408
/**
409
* Run inference
410
* @param inputs Input tensors map
411
* @return Output tensors
412
*/
413
public native OrtValueVector run(OrtValueVector inputs);
414
415
/**
416
* Run inference with input/output names
417
* @param inputNames Input tensor names
418
* @param inputs Input tensors
419
* @param outputNames Output tensor names
420
* @return Output tensors
421
*/
422
public native OrtValueVector run(StringVector inputNames, OrtValueVector inputs,
423
StringVector outputNames);
424
425
/**
426
* Get input count
427
* @return Number of model inputs
428
*/
429
public native int getInputCount();
430
431
/**
432
* Get output count
433
* @return Number of model outputs
434
*/
435
public native int getOutputCount();
436
437
/**
438
* Get input name
439
* @param index Input index
440
* @return Input name
441
*/
442
public native String getInputName(int index);
443
444
/**
445
* Get output name
446
* @param index Output index
447
* @return Output name
448
*/
449
public native String getOutputName(int index);
450
}
451
452
/**
453
* ONNX Runtime tensor value
454
*/
455
public class OrtValue extends Pointer {
456
/**
457
* Create tensor from float array
458
* @param data Float data
459
* @param shape Tensor shape
460
* @return Tensor value
461
*/
462
public static native OrtValue createTensor(float[] data, long[] shape);
463
464
/**
465
* Create tensor from byte array
466
* @param data Byte data
467
* @param shape Tensor shape
468
* @return Tensor value
469
*/
470
public static native OrtValue createTensor(byte[] data, long[] shape);
471
472
/**
473
* Get tensor data as float array
474
* @return Float data array
475
*/
476
public native float[] getFloatArray();
477
478
/**
479
* Get tensor shape
480
* @return Shape dimensions
481
*/
482
public native long[] getShape();
483
484
/**
485
* Check if value is tensor
486
* @return true if tensor, false otherwise
487
*/
488
public native boolean isTensor();
489
}
```

### TensorRT Integration

NVIDIA GPU-accelerated inference optimization for deep learning models.

```java { .api }
497
/**
498
* TensorRT inference builder
499
*/
500
public class IBuilder extends Pointer {
501
/**
502
* Create inference builder
503
* @param logger Logger instance
504
* @return Builder instance
505
*/
506
public static native IBuilder createInferBuilder(ILogger logger);
507
508
/**
509
* Create network definition
510
* @param flags Network creation flags
511
* @return Network definition
512
*/
513
public native INetworkDefinition createNetworkV2(int flags);
514
515
/**
516
* Create builder configuration
517
* @return Builder configuration
518
*/
519
public native IBuilderConfig createBuilderConfig();
520
521
/**
522
* Build serialized network
523
* @param network Network definition
524
* @param config Builder configuration
525
* @return Serialized engine
526
*/
527
public native IHostMemory buildSerializedNetwork(INetworkDefinition network,
528
IBuilderConfig config);
529
}
530
531
/**
532
* TensorRT network definition for building models
533
*/
534
public class INetworkDefinition extends Pointer {
535
/**
536
* Add input tensor
537
* @param name Input name
538
* @param type Data type
539
* @param dims Tensor dimensions
540
* @return Input tensor
541
*/
542
public native ITensor addInput(String name, DataType type, Dims dims);
543
544
/**
545
* Add convolution layer
546
* @param input Input tensor
547
* @param nbOutputMaps Number of output feature maps
548
* @param kernelSize Convolution kernel size
549
* @param kernelWeights Kernel weights
550
* @param biasWeights Bias weights
551
* @return Convolution layer
552
*/
553
public native IConvolutionLayer addConvolutionNd(ITensor input, int nbOutputMaps,
554
DimsHW kernelSize, Weights kernelWeights, Weights biasWeights);
555
556
/**
557
* Add activation layer
558
* @param input Input tensor
559
* @param type Activation type
560
* @return Activation layer
561
*/
562
public native IActivationLayer addActivation(ITensor input, ActivationType type);
563
564
/**
565
* Add pooling layer
566
* @param input Input tensor
567
* @param type Pooling type
568
* @param windowSize Pooling window size
569
* @return Pooling layer
570
*/
571
public native IPoolingLayer addPoolingNd(ITensor input, PoolingType type,
572
DimsHW windowSize);
573
574
/**
575
* Mark tensor as network output
576
* @param tensor Output tensor
577
*/
578
public native void markOutput(ITensor tensor);
579
}
580
581
/**
582
* TensorRT CUDA inference engine
583
*/
584
public class ICudaEngine extends Pointer {
585
/**
586
* Create execution context
587
* @return Execution context
588
*/
589
public native IExecutionContext createExecutionContext();
590
591
/**
592
* Get number of bindings (inputs + outputs)
593
* @return Number of bindings
594
*/
595
public native int getNbBindings();
596
597
/**
598
* Get binding name
599
* @param index Binding index
600
* @return Binding name
601
*/
602
public native String getBindingName(int index);
603
604
/**
605
* Check if binding is input
606
* @param index Binding index
607
* @return true if input binding
608
*/
609
public native boolean bindingIsInput(int index);
610
611
/**
612
* Get binding dimensions
613
* @param index Binding index
614
* @return Tensor dimensions
615
*/
616
public native Dims getBindingDimensions(int index);
617
}
618
619
/**
620
* TensorRT execution context for inference
621
*/
622
public class IExecutionContext extends Pointer {
623
/**
624
* Execute inference synchronously
625
* @param bindings Array of device memory pointers for inputs/outputs
626
* @return true if successful
627
*/
628
public native boolean execute(PointerPointer bindings);
629
630
/**
631
* Execute inference asynchronously
632
* @param bindings Array of device memory pointers for inputs/outputs
633
* @param stream CUDA stream
634
* @return true if successful
635
*/
636
public native boolean enqueue(PointerPointer bindings, Pointer stream);
637
638
/**
639
* Set binding dimensions for dynamic shapes
640
* @param index Binding index
641
* @param dimensions New dimensions
642
* @return true if successful
643
*/
644
public native boolean setBindingDimensions(int index, Dims dimensions);
645
}
```

## Usage Examples

### PyTorch Tensor Operations

```java
653
import org.bytedeco.pytorch.*;
654
import org.bytedeco.pytorch.torch.*;
655
import static org.bytedeco.pytorch.global.torch.*;
656
657
public class TensorExample {
658
static {
659
Loader.load(torch.class);
660
}
661
662
public static void tensorOperations() {
663
try (PointerScope scope = new PointerScope()) {
664
// Create tensors
665
Tensor a = zeros(new long[]{3, 4});
666
Tensor b = ones(new long[]{3, 4});
667
668
// Basic operations
669
Tensor sum = a.add(b);
670
Tensor scaled = sum.mul(2.0);
671
672
// Matrix operations
673
Tensor matrix1 = rand(new long[]{3, 4});
674
Tensor matrix2 = rand(new long[]{4, 5});
675
Tensor product = matrix1.mm(matrix2);
676
677
System.out.println("Matrix product shape: " +
678
java.util.Arrays.toString(product.sizes()));
679
680
// Move to GPU if available
681
if (cuda_is_available()) {
682
Device gpu = Device.cuda(0);
683
Tensor gpuTensor = matrix1.to(gpu);
684
System.out.println("Tensor moved to GPU");
685
}
686
}
687
}
688
}
```

### TensorFlow Lite Inference

```java
694
import org.bytedeco.tensorflowlite.*;
695
import java.nio.ByteBuffer;
696
import java.nio.ByteOrder;
697
698
public class TFLiteInference {
699
static {
700
Loader.load(tensorflowlite.class);
701
}
702
703
public static void runInference(String modelPath, float[] inputData) {
704
try (PointerScope scope = new PointerScope()) {
705
// Load model
706
Model model = Model.fromFile(modelPath);
707
Interpreter interpreter = new Interpreter(model);
708
709
// Prepare input
710
int inputSize = inputData.length;
711
ByteBuffer inputBuffer = ByteBuffer.allocateDirect(inputSize * 4);
712
inputBuffer.order(ByteOrder.nativeOrder());
713
for (float value : inputData) {
714
inputBuffer.putFloat(value);
715
}
716
717
// Prepare output
718
int outputSize = interpreter.getOutputTensor(0).numElements();
719
ByteBuffer outputBuffer = ByteBuffer.allocateDirect(outputSize * 4);
720
outputBuffer.order(ByteOrder.nativeOrder());
721
722
// Run inference
723
interpreter.run(inputBuffer, outputBuffer);
724
725
// Process results
726
outputBuffer.rewind();
727
float[] results = new float[outputSize];
728
outputBuffer.asFloatBuffer().get(results);
729
730
System.out.println("Inference completed. Output size: " + outputSize);
731
}
732
}
733
}
```

### ONNX Runtime Inference

```java
739
import org.bytedeco.onnxruntime.*;
740
import static org.bytedeco.onnxruntime.global.onnxruntime.*;
741
742
public class ONNXInference {
743
static {
744
Loader.load(onnxruntime.class);
745
}
746
747
public static void runONNXModel(String modelPath, float[] inputData, long[] inputShape) {
748
try (PointerScope scope = new PointerScope()) {
749
// Create environment and session
750
OrtEnv env = OrtEnv.create(ORT_LOGGING_LEVEL_WARNING, "ONNXInference");
751
OrtSessionOptions sessionOptions = new OrtSessionOptions();
752
OrtSession session = OrtSession.create(env, modelPath, sessionOptions);
753
754
// Create input tensor
755
OrtValue inputTensor = OrtValue.createTensor(inputData, inputShape);
756
OrtValueVector inputs = new OrtValueVector(inputTensor);
757
758
// Get input/output names
759
String inputName = session.getInputName(0);
760
String outputName = session.getOutputName(0);
761
762
StringVector inputNames = new StringVector(inputName);
763
StringVector outputNames = new StringVector(outputName);
764
765
// Run inference
766
OrtValueVector outputs = session.run(inputNames, inputs, outputNames);
767
768
// Get results
769
OrtValue outputTensor = outputs.get(0);
770
float[] results = outputTensor.getFloatArray();
771
long[] outputShape = outputTensor.getShape();
772
773
System.out.println("ONNX inference completed");
774
System.out.println("Output shape: " + java.util.Arrays.toString(outputShape));
775
System.out.println("First 5 results: " +
776
java.util.Arrays.toString(java.util.Arrays.copyOf(results, 5)));
777
}
778
}
779
}
```

### TensorRT Optimized Inference

```java
785
import org.bytedeco.tensorrt.*;
786
import org.bytedeco.cuda.cudart.*;
787
import static org.bytedeco.tensorrt.global.nvinfer.*;
788
import static org.bytedeco.cuda.global.cudart.*;
789
790
public class TensorRTInference {
791
static {
792
Loader.load(nvinfer.class);
793
Loader.load(cudart.class);
794
}
795
796
public static void optimizeAndInfer() {
797
try (PointerScope scope = new PointerScope()) {
798
// Create logger and builder
799
ILogger logger = new Logger();
800
IBuilder builder = IBuilder.createInferBuilder(logger);
801
802
// Create network
803
INetworkDefinition network = builder.createNetworkV2(
804
1 << NetworkDefinitionCreationFlag.kEXPLICIT_BATCH);
805
806
// Define network architecture (simplified example)
807
Dims inputDims = new Dims(4);
808
inputDims.d(0, 1).d(1, 3).d(2, 224).d(3, 224); // NCHW format
809
810
ITensor input = network.addInput("input", DataType.kFLOAT, inputDims);
811
812
// Add layers (this is a simplified example)
813
// In practice, you would load weights and build complete network
814
815
network.markOutput(input); // Placeholder output
816
817
// Build engine
818
IBuilderConfig config = builder.createBuilderConfig();
819
config.setMaxWorkspaceSize(1L << 30); // 1GB workspace
820
821
IHostMemory serializedEngine = builder.buildSerializedNetwork(network, config);
822
823
// Create runtime and deserialize engine
824
IRuntime runtime = IRuntime.createInferRuntime(logger);
825
ICudaEngine engine = runtime.deserializeCudaEngine(
826
serializedEngine.data(), serializedEngine.size());
827
828
// Create execution context
829
IExecutionContext context = engine.createExecutionContext();
830
831
// Allocate GPU memory for inputs/outputs
832
int inputSize = 1 * 3 * 224 * 224 * 4; // batch_size * channels * height * width * sizeof(float)
833
Pointer inputGPU = new Pointer();
834
Pointer outputGPU = new Pointer();
835
836
cudaMalloc(inputGPU, inputSize);
837
cudaMalloc(outputGPU, inputSize); // Assume same size for output
838
839
// Setup bindings
840
PointerPointer bindings = new PointerPointer(2);
841
bindings.put(0, inputGPU);
842
bindings.put(1, outputGPU);
843
844
// Copy input data to GPU (simplified)
845
// cudaMemcpy(inputGPU, hostInputData, inputSize, cudaMemcpyHostToDevice);
846
847
// Execute inference
848
boolean success = context.execute(bindings);
849
850
if (success) {
851
System.out.println("TensorRT inference completed successfully");
852
853
// Copy results back to host
854
// cudaMemcpy(hostOutputData, outputGPU, outputSize, cudaMemcpyDeviceToHost);
855
}
856
857
// Cleanup GPU memory
858
cudaFree(inputGPU);
859
cudaFree(outputGPU);
860
}
861
}
862
863
static class Logger extends ILogger {
864
@Override
865
public void log(Severity severity, String msg) {
866
System.out.println("[TensorRT " + severity + "] " + msg);
867
}
868
}
869
}
```