0
# Shared Kernel Implementations
1
2
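The `shared` module provides 50+ optimized kernel implementations that can be reused across different TensorFlow.js backends. They operate directly on raw value arrays (TypedArrays, or encoded strings for the string operations) together with explicit shapes, rather than on `Tensor` objects. They cover core mathematical operations, array manipulation, and specialized computations such as gather/scatter, string, sparse, and ragged operations; the element-wise binary operations additionally support broadcasting.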
3
4
## Module Import
5
6
```typescript { .api }
7
import { shared } from '@tensorflow/tfjs-backend-cpu/base';
8
9
// Access individual implementations
10
const { addImpl, multiplyImpl, expImpl } = shared;
11
```
12
13
## Type Definitions
14
15
### Operation Function Types
16
17
```typescript { .api }
18
// Basic binary operation signature
19
type SimpleBinaryOperation = (a: number | string, b: number | string) => number;
20
21
// Binary kernel implementation with broadcasting
22
type SimpleBinaryKernelImpl = (
23
aShape: number[],
24
bShape: number[],
25
aVals: TypedArray | string[],
26
bVals: TypedArray | string[],
27
dtype: DataType
28
) => [TypedArray, number[]];
29
30
// Complex number binary operations
31
type ComplexBinaryOperation = (
32
aReal: number,
33
aImag: number,
34
bReal: number,
35
bImag: number
36
) => { real: number, imag: number };
37
38
// Complex binary kernel implementation
39
type ComplexBinaryKernelImpl = (
40
aShape: number[],
41
bShape: number[],
42
aRealVals: Float32Array,
43
aImagVals: Float32Array,
44
bRealVals: Float32Array,
45
bImagVals: Float32Array
46
) => [TypedArray, TypedArray, number[]];
47
```
48
49
### Common Input/Output Types
50
51
```typescript { .api }
52
// TypedArray as defined in @tensorflow/tfjs-core
53
type TypedArray =
54
| Float32Array
55
| Int32Array
56
| Uint8Array;
57
58
59
60
61
62
// Return type for most operations
63
type OperationResult = [TypedArray, number[]]; // [values, outputShape]
64
65
// Utility types for complex operations
66
interface ComplexResult {
67
real: TypedArray;
68
imag: TypedArray;
69
outputShape: number[];
70
}
71
```
72
73
## Mathematical Operations
74
75
### Basic Arithmetic
76
77
#### addImpl()
78
79
```typescript { .api }
80
function addImpl(
81
aShape: number[],
82
bShape: number[],
83
aVals: TypedArray | string[],
84
bVals: TypedArray | string[],
85
dtype: DataType
86
): [TypedArray, number[]]
87
```
88
89
Element-wise addition with automatic broadcasting support.
90
91
**Parameters:**
92
- `aShape: number[]` - Shape of first tensor
93
- `bShape: number[]` - Shape of second tensor
94
- `aVals: TypedArray | string[]` - Values of first tensor
95
- `bVals: TypedArray | string[]` - Values of second tensor
96
- `dtype: DataType` - Output data type
97
98
**Returns:** `[TypedArray, number[]]` - Result values and output shape
99
100
**Example:**
101
```typescript { .api }
102
import { shared } from '@tensorflow/tfjs-backend-cpu/base';
103
104
// Element-wise addition
105
const [result, shape] = shared.addImpl(
106
[2, 2], // aShape
107
[2, 2], // bShape
108
new Float32Array([1, 2, 3, 4]), // aVals
109
new Float32Array([5, 6, 7, 8]), // bVals
110
'float32' // dtype
111
);
112
console.log(result); // Float32Array([6, 8, 10, 12])
113
console.log(shape); // [2, 2]
114
115
// Broadcasting example
116
const [broadcastResult, broadcastShape] = shared.addImpl(
117
[3, 1], // aShape
118
[1, 4], // bShape (broadcasts to [3, 4])
119
new Float32Array([1, 2, 3]), // aVals
120
new Float32Array([10, 20, 30, 40]), // bVals
121
'float32'
122
);
123
console.log(broadcastShape); // [3, 4] - broadcasted output shape
124
```
125
126
#### multiplyImpl()
127
128
```typescript { .api }
129
function multiplyImpl(
130
aShape: number[],
131
bShape: number[],
132
aVals: TypedArray,
133
bVals: TypedArray,
134
dtype: DataType
135
): [TypedArray, number[]]
136
```
137
138
Element-wise multiplication with broadcasting.
139
140
**Example:**
141
```typescript { .api }
142
const [result, shape] = shared.multiplyImpl(
143
[2, 3], // aShape
144
[2, 3], // bShape
145
new Float32Array([1, 2, 3, 4, 5, 6]), // aVals
146
new Float32Array([2, 2, 2, 2, 2, 2]), // bVals
147
'float32'
148
);
149
console.log(result); // Float32Array([2, 4, 6, 8, 10, 12])
150
```
151
152
#### subImpl()
153
154
```typescript { .api }
155
function subImpl(
156
aShape: number[],
157
bShape: number[],
158
aVals: TypedArray,
159
bVals: TypedArray,
160
dtype: DataType
161
): [TypedArray, number[]]
162
```
163
164
Element-wise subtraction with broadcasting.
165
166
**Example:**
167
```typescript { .api }
168
const [result, shape] = shared.subImpl(
169
[3], // aShape
170
[3], // bShape
171
new Float32Array([5, 7, 9]), // aVals
172
new Float32Array([1, 2, 3]), // bVals
173
'float32'
174
);
175
console.log(result); // Float32Array([4, 5, 6])
176
```
177
178
### Unary Mathematical Functions
179
180
#### simpleAbsImpl()
181
182
```typescript { .api }
183
function simpleAbsImpl(vals: TypedArray): Float32Array
184
```
185
186
Computes absolute values element-wise.
187
188
**Parameters:**
189
- `vals: TypedArray` - Input values
190
191
**Returns:** `Float32Array` - Absolute values
192
193
**Example:**
194
```typescript { .api }
195
const result = shared.simpleAbsImpl(new Float32Array([-2, -1, 0, 1, 2]));
196
console.log(result); // Float32Array([2, 1, 0, 1, 2])
197
```
198
199
#### expImpl()
200
201
```typescript { .api }
202
function expImpl(vals: TypedArray): Float32Array
203
```
204
205
Computes exponential (e^x) element-wise.
206
207
**Example:**
208
```typescript { .api }
209
const result = shared.expImpl(new Float32Array([0, 1, 2]));
210
console.log(result); // Float32Array([1, 2.718..., 7.389...])
211
```
212
213
#### logImpl()
214
215
```typescript { .api }
216
function logImpl(vals: TypedArray): Float32Array
217
```
218
219
Computes natural logarithm element-wise.
220
221
**Example:**
222
```typescript { .api }
223
const result = shared.logImpl(new Float32Array([1, Math.E, Math.E * Math.E]));
224
console.log(result); // Float32Array([0, 1, 2])
225
```
226
227
#### sqrtImpl()
228
229
```typescript { .api }
230
function sqrtImpl(vals: TypedArray): Float32Array
231
```
232
233
Computes square root element-wise.
234
235
**Example:**
236
```typescript { .api }
237
const result = shared.sqrtImpl(new Float32Array([1, 4, 9, 16]));
238
console.log(result); // Float32Array([1, 2, 3, 4])
239
```
240
241
#### rsqrtImpl()
242
243
```typescript { .api }
244
function rsqrtImpl(vals: TypedArray): Float32Array
245
```
246
247
Computes reciprocal square root (1/sqrt(x)) element-wise.
248
249
**Example:**
250
```typescript { .api }
251
const result = shared.rsqrtImpl(new Float32Array([1, 4, 9, 16]));
252
console.log(result); // Float32Array([1, 0.5, 0.333..., 0.25])
253
```
254
255
#### negImpl()
256
257
```typescript { .api }
258
function negImpl(vals: TypedArray, dtype: DataType): TypedArray
259
```
260
261
Computes negation (-x) element-wise.
262
263
**Example:**
264
```typescript { .api }
265
const result = shared.negImpl(new Float32Array([1, -2, 3, -4]), 'float32');
266
console.log(result); // Float32Array([-1, 2, -3, 4])
267
```
268
269
### Advanced Mathematical Functions
270
271
#### sigmoidImpl()
272
273
```typescript { .api }
274
function sigmoidImpl(vals: TypedArray): Float32Array
275
```
276
277
Computes sigmoid activation function: 1 / (1 + e^(-x)).
278
279
**Example:**
280
```typescript { .api }
281
const result = shared.sigmoidImpl(new Float32Array([-2, -1, 0, 1, 2]));
282
console.log(result); // Float32Array([0.119, 0.269, 0.5, 0.731, 0.881])
283
```
284
285
#### expm1Impl()
286
287
```typescript { .api }
288
function expm1Impl(vals: TypedArray): Float32Array
289
```
290
291
Computes e^x - 1, more accurate for small values than exp(x) - 1.
292
293
**Example:**
294
```typescript { .api }
295
const result = shared.expm1Impl(new Float32Array([0, 0.001, 1]));
296
console.log(result); // Float32Array([0, 0.0010005..., 1.71828...]) - more accurate than exp(x) - 1 for small x
297
```
298
299
#### ceilImpl()
300
301
```typescript { .api }
302
function ceilImpl(vals: TypedArray): Float32Array
303
```
304
305
Computes ceiling function element-wise.
306
307
**Example:**
308
```typescript { .api }
309
const result = shared.ceilImpl(new Float32Array([1.1, 2.8, -1.2, -0.5]));
310
console.log(result); // Float32Array([2, 3, -1, -0]) - Math.ceil(-0.5) is -0
311
```
312
313
#### floorImpl()
314
315
```typescript { .api }
316
function floorImpl(vals: TypedArray): Float32Array
317
```
318
319
Computes floor function element-wise.
320
321
**Example:**
322
```typescript { .api }
323
const result = shared.floorImpl(new Float32Array([1.8, 2.1, -1.2, -0.5]));
324
console.log(result); // Float32Array([1, 2, -2, -1])
325
```
326
327
## Comparison Operations
328
329
### Basic Comparisons
330
331
#### equalImpl()
332
333
```typescript { .api }
334
function equalImpl(
335
aShape: number[],
336
bShape: number[],
337
aVals: TypedArray,
338
bVals: TypedArray,
339
dtype: DataType
340
): [Uint8Array, number[]]
341
```
342
343
Element-wise equality comparison with broadcasting.
344
345
**Returns:** `[Uint8Array, number[]]` - Boolean results as Uint8Array (1 for true, 0 for false)
346
347
**Example:**
348
```typescript { .api }
349
const [result, shape] = shared.equalImpl(
350
[3], // aShape
351
[3], // bShape
352
new Float32Array([1, 2, 3]), // aVals
353
new Float32Array([1, 4, 3]), // bVals
354
'float32'
355
);
356
console.log(result); // Uint8Array([1, 0, 1]) - true, false, true
357
```
358
359
#### greaterImpl()
360
361
```typescript { .api }
362
function greaterImpl(
363
aShape: number[],
364
bShape: number[],
365
aVals: TypedArray,
366
bVals: TypedArray,
367
dtype: DataType
368
): [Uint8Array, number[]]
369
```
370
371
Element-wise greater-than comparison.
372
373
**Example:**
374
```typescript { .api }
375
const [result, shape] = shared.greaterImpl(
376
[3], // aShape
377
[3], // bShape
378
new Float32Array([3, 2, 1]), // aVals
379
new Float32Array([1, 2, 3]), // bVals
380
'float32'
381
);
382
console.log(result); // Uint8Array([1, 0, 0]) - true, false, false
383
```
384
385
#### lessImpl()
386
387
```typescript { .api }
388
function lessImpl(
389
aShape: number[],
390
bShape: number[],
391
aVals: TypedArray,
392
bVals: TypedArray,
393
dtype: DataType
394
): [Uint8Array, number[]]
395
```
396
397
Element-wise less-than comparison.
398
399
**Example:**
400
```typescript { .api }
401
const [result, shape] = shared.lessImpl(
402
[2, 2], // aShape
403
[2, 2], // bShape
404
new Float32Array([1, 5, 3, 2]), // aVals
405
new Float32Array([2, 4, 3, 1]), // bVals
406
'float32'
407
);
408
console.log(result); // Uint8Array([1, 0, 0, 0])
409
```
410
411
### Advanced Comparisons
412
413
#### greaterEqualImpl()
414
415
```typescript { .api }
416
function greaterEqualImpl(
417
aShape: number[],
418
bShape: number[],
419
aVals: TypedArray,
420
bVals: TypedArray,
421
dtype: DataType
422
): [Uint8Array, number[]]
423
```
424
425
Element-wise greater-than-or-equal comparison.
426
427
#### lessEqualImpl()
428
429
```typescript { .api }
430
function lessEqualImpl(
431
aShape: number[],
432
bShape: number[],
433
aVals: TypedArray,
434
bVals: TypedArray,
435
dtype: DataType
436
): [Uint8Array, number[]]
437
```
438
439
Element-wise less-than-or-equal comparison.
440
441
#### notEqualImpl()
442
443
```typescript { .api }
444
function notEqualImpl(
445
aShape: number[],
446
bShape: number[],
447
aVals: TypedArray,
448
bVals: TypedArray,
449
dtype: DataType
450
): [Uint8Array, number[]]
451
```
452
453
Element-wise not-equal comparison.
454
455
## Binary Operations
456
457
### Advanced Arithmetic
458
459
#### floorDivImpl()
460
461
```typescript { .api }
462
function floorDivImpl(
463
aShape: number[],
464
bShape: number[],
465
aVals: TypedArray,
466
bVals: TypedArray,
467
dtype: DataType
468
): [TypedArray, number[]]
469
```
470
471
Element-wise floor division with broadcasting.
472
473
**Example:**
474
```typescript { .api }
475
const [result, shape] = shared.floorDivImpl(
476
[3], // aShape
477
[3], // bShape
478
new Float32Array([7, 8, 9]), // aVals
479
new Float32Array([2, 3, 4]), // bVals
480
'float32'
481
);
482
console.log(result); // Float32Array([3, 2, 2]) - floor(7/2), floor(8/3), floor(9/4)
483
```
484
485
#### maximumImpl()
486
487
```typescript { .api }
488
function maximumImpl(
489
aShape: number[],
490
bShape: number[],
491
aVals: TypedArray,
492
bVals: TypedArray,
493
dtype: DataType
494
): [TypedArray, number[]]
495
```
496
497
Element-wise maximum with broadcasting.
498
499
**Example:**
500
```typescript { .api }
501
const [result, shape] = shared.maximumImpl(
502
[3], // aShape
503
[3], // bShape
504
new Float32Array([1, 5, 3]), // aVals
505
new Float32Array([4, 2, 6]), // bVals
506
'float32'
507
);
508
console.log(result); // Float32Array([4, 5, 6])
509
```
510
511
#### minimumImpl()
512
513
```typescript { .api }
514
function minimumImpl(
515
aShape: number[],
516
bShape: number[],
517
aVals: TypedArray,
518
bVals: TypedArray,
519
dtype: DataType
520
): [TypedArray, number[]]
521
```
522
523
Element-wise minimum with broadcasting.
524
525
**Example:**
526
```typescript { .api }
527
const [result, shape] = shared.minimumImpl(
528
[3], // aShape
529
[3], // bShape
530
new Float32Array([1, 5, 3]), // aVals
531
new Float32Array([4, 2, 6]), // bVals
532
'float32'
533
);
534
console.log(result); // Float32Array([1, 2, 3])
535
```
536
537
#### squaredDifferenceImpl()
538
539
```typescript { .api }
540
function squaredDifferenceImpl(
541
aShape: number[],
542
bShape: number[],
543
aVals: TypedArray,
544
bVals: TypedArray,
545
dtype: DataType
546
): [TypedArray, number[]]
547
```
548
549
Element-wise squared difference: (a - b)^2.
550
551
**Example:**
552
```typescript { .api }
553
const [result, shape] = shared.squaredDifferenceImpl(
554
[3], // aShape
555
[3], // bShape
556
new Float32Array([4, 6, 8]), // aVals
557
new Float32Array([2, 4, 6]), // bVals
558
'float32'
559
);
560
console.log(result); // Float32Array([4, 4, 4]) - (4-2)^2, (6-4)^2, (8-6)^2
561
```
562
563
## Array Manipulation
564
565
### Concatenation and Slicing
566
567
#### concatImpl()
568
569
```typescript { .api }
570
function concatImpl(
571
tensors: TypedArray[],
572
outShape: number[],
573
dtype: DataType,
574
axis: number
575
): TypedArray
576
```
577
578
Concatenates multiple tensors along specified axis.
579
580
**Parameters:**
581
- `tensors: TypedArray[]` - Array of tensor values to concatenate
582
- `outShape: number[]` - Output shape after concatenation
583
- `dtype: DataType` - Output data type
584
- `axis: number` - Axis along which to concatenate
585
586
**Example:**
587
```typescript { .api }
588
const result = shared.concatImpl(
589
[
590
new Float32Array([1, 2]), // First tensor
591
new Float32Array([3, 4]), // Second tensor
592
new Float32Array([5, 6]) // Third tensor
593
],
594
[6], // Output shape: [6] (concatenating 3 tensors of shape [2])
595
'float32', // Data type
596
0 // Axis 0
597
);
598
console.log(result); // Float32Array([1, 2, 3, 4, 5, 6])
599
600
// 2D concatenation example
601
const result2d = shared.concatImpl(
602
[
603
new Float32Array([1, 2, 3, 4]), // Shape [2, 2]
604
new Float32Array([5, 6, 7, 8]) // Shape [2, 2]
605
],
606
[4, 2], // Output shape after concat along axis 0
607
'float32',
608
0 // Concatenate along rows
609
);
610
```
611
612
#### sliceImpl()
613
614
```typescript { .api }
615
function sliceImpl(
616
vals: TypedArray,
617
begin: number[],
618
size: number[],
619
shape: number[],
620
dtype: DataType
621
): TypedArray
622
```
623
624
Extracts a slice from a tensor. Only the sliced values are returned; the output shape is the requested `size`.
625
626
**Parameters:**
627
- `vals: TypedArray` - Input tensor values
628
- `begin: number[]` - Starting indices for each dimension
629
- `size: number[]` - Size of slice in each dimension
630
- `shape: number[]` - Input tensor shape
631
- `dtype: DataType` - Data type
632
633
**Example:**
634
```typescript { .api }
635
const result = shared.sliceImpl(
636
new Float32Array([1, 2, 3, 4, 5, 6]), // Input: [[1,2,3], [4,5,6]]
637
[0, 1], // Begin at [row=0, col=1]
638
[2, 2], // Take 2 rows, 2 columns
639
[2, 3], // Input shape [2, 3]
640
'float32'
641
);
642
console.log(result); // Float32Array([2, 3, 5, 6])
643
// Output shape is the requested size: [2, 2]
644
```
645
646
#### stridedSliceImpl()
647
648
```typescript { .api }
649
function stridedSliceImpl(
650
vals: TypedArray,
651
begin: number[],
652
end: number[],
653
strides: number[],
654
beginMask: number,
655
endMask: number,
656
ellipsisMask: number,
657
newAxisMask: number,
658
shrinkAxisMask: number,
659
shape: number[],
660
dtype: DataType
661
): [TypedArray, number[]]
662
```
663
664
Advanced slicing with strides and masking options.
665
666
**Example:**
667
```typescript { .api }
668
const [result, outShape] = shared.stridedSliceImpl(
669
new Float32Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), // Input [10]
670
[1], // begin
671
[8], // end
672
[2], // strides (every 2nd element)
673
0, 0, 0, 0, 0, // masks
674
[10], // input shape
675
'float32'
676
);
677
console.log(result); // Float32Array([1, 3, 5, 7])
678
```
679
680
### Shape Manipulation
681
682
#### tileImpl()
683
684
```typescript { .api }
685
function tileImpl(
686
vals: TypedArray,
687
shape: number[],
688
dtype: DataType,
689
reps: number[]
690
): TypedArray
691
```
692
693
Tiles (repeats) a tensor according to repetition counts.
694
695
**Parameters:**
696
- `vals: TypedArray` - Input tensor values
697
- `shape: number[]` - Input tensor shape
698
- `dtype: DataType` - Data type
699
- `reps: number[]` - Repetition count for each dimension
700
701
**Example:**
702
```typescript { .api }
703
const result = shared.tileImpl(
704
new Float32Array([1, 2]), // Input: [1, 2]
705
[2], // Input shape: [2]
706
'float32',
707
[3] // Repeat 3 times
708
);
709
console.log(result); // Float32Array([1, 2, 1, 2, 1, 2])
710
711
// 2D tiling
712
const result2d = shared.tileImpl(
713
new Float32Array([1, 2, 3, 4]), // Input: [[1,2], [3,4]]
714
[2, 2], // Input shape: [2, 2]
715
'float32',
716
[2, 3] // Repeat 2x vertically, 3x horizontally
717
);
718
// Result shape: [4, 6]
719
```
720
721
#### transposeImpl()
722
723
```typescript { .api }
724
function transposeImpl(
725
vals: TypedArray,
726
shape: number[],
727
dtype: DataType,
728
perm: number[]
729
): TypedArray
730
```
731
732
Transposes tensor dimensions according to permutation.
733
734
**Parameters:**
735
- `vals: TypedArray` - Input tensor values
736
- `shape: number[]` - Input tensor shape
737
- `dtype: DataType` - Data type
738
- `perm: number[]` - Permutation of dimensions
739
740
**Example:**
741
```typescript { .api }
742
const result = shared.transposeImpl(
743
new Float32Array([1, 2, 3, 4, 5, 6]), // Input: [[1,2,3], [4,5,6]]
744
[2, 3], // Input shape [2 rows, 3 cols]
745
'float32',
746
[1, 0] // Permutation: swap dimensions [rows, cols] -> [cols, rows]
747
);
748
console.log(result); // Float32Array([1, 4, 2, 5, 3, 6])
749
// Result represents [[1,4], [2,5], [3,6]] - shape [3, 2]
750
```
751
752
## Specialized Operations
753
754
### Gather Operations
755
756
#### gatherV2Impl()
757
758
```typescript { .api }
759
function gatherV2Impl(
760
xVals: TypedArray,
761
xShape: number[],
762
xDtype: DataType,
763
indices: TypedArray,
764
indicesShape: number[],
765
axis: number
766
): [TypedArray, number[]]
767
```
768
769
Gathers values from tensor at specified indices along an axis.
770
771
**Example:**
772
```typescript { .api }
773
const [result, outShape] = shared.gatherV2Impl(
774
new Float32Array([10, 20, 30, 40, 50, 60]), // Input: [[10,20,30], [40,50,60]]
775
[2, 3], // Input shape
776
'float32',
777
new Int32Array([2, 0, 1]), // Indices to gather
778
[3], // Indices shape
779
1 // Axis 1 (columns)
780
);
781
console.log(result); // Float32Array([30, 10, 20, 60, 40, 50]) - columns 2, 0, 1 of each row; outShape [2, 3]
782
```
783
784
#### gatherNdImpl()
785
786
```typescript { .api }
787
function gatherNdImpl(
788
indicesData: TypedArray,
789
indicesShape: number[],
790
paramsData: TypedArray,
791
paramsShape: number[],
792
dtype: DataType
793
): [TypedArray, number[]]
794
```
795
796
N-dimensional gather operation using multi-dimensional indices.
797
798
**Example:**
799
```typescript { .api }
800
const [result, outShape] = shared.gatherNdImpl(
801
new Int32Array([0, 1, 1, 0]), // Indices: [[0,1], [1,0]]
802
[2, 2], // Indices shape
803
new Float32Array([1, 2, 3, 4]), // Params: [[1,2], [3,4]]
804
[2, 2], // Params shape
805
'float32'
806
);
807
console.log(result); // Float32Array([2, 3]) - values at [0,1] and [1,0]
808
```
809
810
### Scatter Operations
811
812
#### scatterImpl()
813
814
```typescript { .api }
815
function scatterImpl(
816
indices: TypedArray,
817
updates: TypedArray,
818
shape: number[],
819
outputSize: number,
820
sliceSize: number,
821
numUpdates: number,
822
sliceRank: number,
823
strides: number[],
824
defaultValue: number,
825
sumDupeIndices: boolean
826
): TypedArray
827
```
828
829
Scatters updates into output tensor at specified indices.
830
831
**Example:**
832
```typescript { .api }
833
const result = shared.scatterImpl(
834
new Int32Array([1, 3]), // Indices where to scatter
835
new Float32Array([10, 20]), // Values to scatter
836
[5], // Output shape
837
5, // Output size
838
1, // Slice size
839
2, // Number of updates
840
1, // Slice rank (each index addresses one dimension of the output)
841
[1], // Strides
842
0, // Default value
843
false // Don't sum duplicate indices
844
);
845
console.log(result); // Float32Array([0, 10, 0, 20, 0])
846
```
847
848
### Statistical Operations
849
850
#### topKImpl()
851
852
```typescript { .api }
853
function topKImpl(
854
xVals: TypedArray,
855
xShape: number[],
856
xDtype: DataType,
857
k: number,
858
sorted: boolean
859
): [TypedArray, Int32Array, number[]]
860
```
861
862
Finds top-k largest values and their indices.
863
864
**Returns:** `[values, indices, outputShape]`
865
866
**Example:**
867
```typescript { .api }
868
const [values, indices, shape] = shared.topKImpl(
869
new Float32Array([3, 1, 4, 1, 5, 9, 2, 6]), // Input values
870
[8], // Input shape
871
'float32',
872
3, // k=3 (top 3)
873
true // sorted=true
874
);
875
console.log(values); // Float32Array([9, 6, 5]) - top 3 values
876
console.log(indices); // Int32Array([5, 7, 4]) - their indices
877
```
878
879
#### uniqueImpl()
880
881
```typescript { .api }
882
function uniqueImpl(
883
values: TypedArray,
884
axis: number,
885
shape: number[],
886
dtype: DataType
887
): { outputValues: TypedArray, outputShape: number[], indices: Int32Array }
888
```
889
890
Finds unique values in tensor.
891
892
**Example:**
893
```typescript { .api }
894
const result = shared.uniqueImpl(
895
new Float32Array([1, 2, 1, 3, 2, 4]), // Input with duplicates
896
0, // Axis
897
[6], // Shape
898
'float32'
899
);
900
console.log(result.outputValues); // Float32Array([1, 2, 3, 4])
901
console.log(result.indices); // Int32Array([0, 1, 0, 2, 1, 3]) - position of each input element in outputValues
902
```
903
904
### Range Generation
905
906
#### rangeImpl()
907
908
```typescript { .api }
909
function rangeImpl(
910
start: number,
911
stop: number,
912
step: number,
913
dtype: 'float32' | 'int32'
914
): TypedArray
915
```
916
917
Generates a sequence of numbers.
918
919
**Example:**
920
```typescript { .api }
921
const result = shared.rangeImpl(0, 10, 2, 'float32');
922
console.log(result); // Float32Array([0, 2, 4, 6, 8])
923
924
const negativeStep = shared.rangeImpl(10, 0, -1.5, 'float32');
925
console.log(negativeStep); // Float32Array([10, 8.5, 7, 5.5, 4, 2.5, 1])
926
```
927
928
#### linSpaceImpl()
929
930
```typescript { .api }
931
function linSpaceImpl(
932
start: number,
933
stop: number,
934
num: number
935
): Float32Array
936
```
937
938
Generates linearly spaced values between start and stop.
939
940
**Example:**
941
```typescript { .api }
942
const result = shared.linSpaceImpl(0, 1, 5);
943
console.log(result); // Float32Array([0, 0.25, 0.5, 0.75, 1])
944
945
const reverse = shared.linSpaceImpl(10, 0, 6);
946
console.log(reverse); // Float32Array([10, 8, 6, 4, 2, 0])
947
```
948
949
## String Operations
950
951
### Text Processing
952
953
#### stringNGramsImpl()
954
955
```typescript { .api }
956
function stringNGramsImpl(
957
data: Uint8Array[],
958
dataSplits: TypedArray,
959
separator: string,
960
nGramWidths: number[],
961
leftPad: string,
962
rightPad: string,
963
padWidth: number,
964
preserveShortSequences: boolean
965
): [Uint8Array[], Int32Array]
966
```
967
968
Generates N-grams from string sequences.
969
970
**Example:**
971
```typescript { .api }
972
// Convert strings to Uint8Array format expected by implementation
973
const textEncoder = new TextEncoder();
974
const data = [
975
textEncoder.encode('hello'),
976
textEncoder.encode('world')
977
];
978
const splits = new Int32Array([0, 1, 2]); // Split positions
979
980
const [ngrams, ngramSplits] = shared.stringNGramsImpl(
981
data,
982
splits,
983
' ', // separator
984
[2, 3], // Generate 2-grams and 3-grams
985
'<s>', // left padding
986
'</s>', // right padding
987
1, // pad width
988
true // preserve short sequences
989
);
990
```
991
992
#### stringSplitImpl()
993
994
```typescript { .api }
995
function stringSplitImpl(
996
input: Uint8Array[],
997
delimiter: Uint8Array,
998
skipEmpty: boolean,
999
result: Uint8Array[],
1000
resultSplits: TypedArray,
1001
maxSplit?: number
1002
): void
1003
```
1004
1005
Splits strings by delimiter.
1006
1007
#### staticRegexReplaceImpl()
1008
1009
```typescript { .api }
1010
function staticRegexReplaceImpl(
1011
input: Uint8Array[],
1012
pattern: string,
1013
rewrite: string,
1014
global: boolean
1015
): Uint8Array[]
1016
```
1017
1018
Performs regex replacement on strings.
1019
1020
### String Utilities
1021
1022
#### stringToHashBucketFastImpl()
1023
1024
```typescript { .api }
1025
function stringToHashBucketFastImpl(
1026
input: Uint8Array[],
1027
numBuckets: number
1028
): Int32Array
1029
```
1030
1031
Hashes strings to bucket indices.
1032
1033
**Example:**
1034
```typescript { .api }
1035
const textEncoder = new TextEncoder();
1036
const strings = [
1037
textEncoder.encode('apple'),
1038
textEncoder.encode('banana'),
1039
textEncoder.encode('cherry')
1040
];
1041
1042
const buckets = shared.stringToHashBucketFastImpl(strings, 10);
1043
console.log(buckets); // Int32Array with bucket indices [0-9]
1044
```
1045
1046
## Sparse Operations
1047
1048
### Sparse Matrix Operations
1049
1050
#### sparseFillEmptyRowsImpl()
1051
1052
```typescript { .api }
1053
function sparseFillEmptyRowsImpl(
1054
indices: TypedArray,
1055
indicesShape: number[],
1056
values: TypedArray,
1057
denseShape: TypedArray,
1058
defaultValue: number | string
1059
): [TypedArray, TypedArray, Uint8Array, TypedArray]
1060
```
1061
1062
Fills empty rows in sparse matrix representation.
1063
1064
#### sparseReshapeImpl()
1065
1066
```typescript { .api }
1067
function sparseReshapeImpl(
1068
inputIndices: TypedArray,
1069
inputIndicesShape: number[],
1070
inputShape: number[],
1071
targetShape: number[]
1072
): [TypedArray, number[]]
1073
```
1074
1075
Reshapes sparse tensor representation.
1076
1077
#### sparseSegmentReductionImpl()
1078
1079
```typescript { .api }
1080
function sparseSegmentReductionImpl(
1081
input: TypedArray,
1082
inputShape: number[],
1083
inputDType: DataType,
1084
indices: TypedArray,
1085
segmentIds: TypedArray,
1086
isMean: boolean = false,
1087
defaultValue: number = 0
1088
): [TypedArray, number[]]
1089
```
1090
1091
Performs segment reduction on sparse data.
1092
1093
## Ragged Tensor Operations
1094
1095
### Ragged Data Processing
1096
1097
#### raggedGatherImpl()
1098
1099
```typescript { .api }
1100
function raggedGatherImpl(
1101
paramsNestedSplits: TypedArray[],
1102
paramsNestedSplitsShapes: number[][],
1103
paramsDenseValues: TypedArray,
1104
paramsDenseValuesShape: number[],
1105
paramsDenseValuesDType: DataType,
1106
indices: TypedArray,
1107
indicesShape: number[],
1108
outputRaggedRank: number
1109
): [TypedArray[], number[][], TypedArray, number[]]
1110
```
1111
1112
Gathers from ragged tensors using nested indices.
1113
1114
#### raggedRangeImpl()
1115
1116
```typescript { .api }
1117
function raggedRangeImpl(
1118
starts: TypedArray,
1119
startsShape: number[],
1120
limits: TypedArray,
1121
limitsShape: number[],
1122
deltas: TypedArray
1123
): [TypedArray, Int32Array]
1124
```
1125
1126
Generates ragged ranges with different limits per row.
1127
1128
#### raggedTensorToTensorImpl()
1129
1130
```typescript { .api }
1131
function raggedTensorToTensorImpl(
1132
shape: number[],
1133
shapesShape: number[],
1134
values: TypedArray,
1135
valuesShape: number[],
1136
valuesDType: DataType,
1137
defaultValue: TypedArray,
1138
defaultValueShape: number[],
1139
rowPartitionValues: TypedArray[],
1140
rowPartitionValuesShapes: number[][],
1141
rowPartitionTypes: string[]
1142
): TypedArray
1143
```
1144
1145
Converts ragged tensor to dense tensor representation.
1146
1147
## Utility Operations
1148
1149
### Type Conversion
1150
1151
#### castImpl()
1152
1153
```typescript { .api }
1154
function castImpl(
1155
values: TypedArray,
1156
shape: number[],
1157
inputDType: DataType,
1158
outputDType: DataType
1159
): [number[], DataType, TypedArray]
1160
```
1161
1162
Casts tensor values between data types. Returns `[outputShape, outputDtype, values]`.
1163
1164
**Example:**
1165
```typescript { .api }
1166
const [shape, dtype, result] = shared.castImpl(
1167
new Float32Array([1.7, 2.3, 3.9]), // Input float values
1168
[3], // Shape
1169
'float32', // Input dtype
1170
'int32' // Output dtype
1171
);
1172
console.log(result); // Int32Array([1, 2, 3]) - truncated toward zero
1173
```
1174
1175
### Reduction Operations
1176
1177
#### prodImpl()
1178
1179
```typescript { .api }
1180
function prodImpl(
1181
xVals: TypedArray,
1182
reduceSize: number,
1183
outShape: number[],
1184
dtype: DataType
1185
): TypedArray
1186
```
1187
1188
Computes product reduction.
1189
1190
**Example:**
1191
```typescript { .api }
1192
const result = shared.prodImpl(
1193
new Float32Array([2, 3, 4, 5]), // Input values
1194
4, // Reduce all 4 elements
1195
[], // Output shape (scalar)
1196
'float32'
1197
);
1198
console.log(result); // Float32Array([120]) - 2*3*4*5 = 120
1199
```
1200
1201
#### maxImpl()
1202
1203
```typescript { .api }
1204
function maxImpl(
1205
aVals: TypedArray,
1206
reduceSize: number,
1207
outShape: number[],
1208
dtype: DataType
1209
): TypedArray
1210
```
1211
1212
Computes maximum reduction.
1213
1214
**Example:**
1215
```typescript { .api }
1216
const result = shared.maxImpl(
1217
new Float32Array([3, 1, 4, 1, 5, 9, 2, 6]), // Input values
1218
8, // Reduce all elements
1219
[], // Scalar output
1220
'float32'
1221
);
1222
console.log(result); // Float32Array([9]) - maximum value
1223
```
1224
1225
### Histogram Operations
1226
1227
#### bincountImpl()
1228
1229
```typescript { .api }
1230
function bincountImpl(
1231
xVals: Int32Array,
1232
weightsVals: TypedArray,
1233
weightsDtype: DataType,
1234
weightsShape: number[],
1235
size: number
1236
): TypedArray
1237
```
1238
1239
Counts occurrences in bins with optional weights.
1240
1241
**Example:**
1242
```typescript { .api }
1243
const result = shared.bincountImpl(
1244
new Int32Array([1, 1, 2, 2, 2, 3]), // Values to bin
1245
new Float32Array([1, 1, 1, 1, 1, 1]), // Weights (all 1s = simple count)
1246
'float32', // Weights dtype
1247
[6], // Weights shape
1248
4 // Number of bins
1249
);
1250
console.log(result); // Float32Array([0, 2, 3, 1]) - counts per bin
1251
```
1252
1253
### Bitwise Operations
1254
1255
#### bitwiseAndImpl()
1256
1257
```typescript { .api }
1258
function bitwiseAndImpl(
1259
aShape: number[],
1260
bShape: number[],
1261
aVals: Int32Array,
1262
bVals: Int32Array,
1263
dtype: DataType
1264
): [Int32Array, number[]]
1265
```
1266
1267
Element-wise bitwise AND operation.
1268
1269
**Example:**
1270
```typescript { .api }
1271
const [result, shape] = shared.bitwiseAndImpl(
1272
[3], // aShape
1273
[3], // bShape
1274
new Int32Array([5, 3, 7]), // aVals (binary: 101, 011, 111)
1275
new Int32Array([3, 5, 1]), // bVals (binary: 011, 101, 001)
1276
'int32'
1277
);
1278
console.log(result); // Int32Array([1, 1, 1]) - bitwise AND results
1279
```
1280
1281
## Performance Considerations
1282
1283
### Optimized Implementations
1284
1285
All shared implementations are optimized for performance:
1286
1287
- **Broadcasting**: Automatic shape broadcasting for binary operations
1288
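- **Memory Efficiency**: Work directly on flat TypedArray buffers; the functions documented here return their results rather than modifying their inputs in place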
1289
- **Type Safety**: Proper TypedArray usage for different data types
1290
- **Simple loops**: Straightforward loops over contiguous TypedArrays, a pattern JavaScript engines generally optimize well
1291
1292
### Usage in Custom Backends
1293
1294
```typescript { .api }
1295
import { shared } from '@tensorflow/tfjs-backend-cpu/base';
1296
1297
class CustomBackend extends MathBackendCPU { // makeOutput() is provided by MathBackendCPU, not by the abstract KernelBackend
1298
customAddOperation(a: TensorInfo, b: TensorInfo): TensorInfo {
1299
const aVals = this.readSync(a.dataId) as Float32Array;
1300
const bVals = this.readSync(b.dataId) as Float32Array;
1301
1302
// Leverage optimized CPU implementation
1303
const [resultVals, resultShape] = shared.addImpl(
1304
a.shape,
1305
b.shape,
1306
aVals,
1307
bVals,
1308
a.dtype
1309
);
1310
1311
return this.makeOutput(resultVals, resultShape, a.dtype);
1312
}
1313
1314
// Custom implementation using multiple shared ops
1315
customComplexOp(input: TensorInfo): TensorInfo {
1316
const inputVals = this.readSync(input.dataId) as Float32Array;
1317
1318
// Chain multiple shared implementations
1319
const expResult = shared.expImpl(inputVals);
1320
const sqrtResult = shared.sqrtImpl(expResult);
1321
const absResult = shared.simpleAbsImpl(sqrtResult);
1322
1323
return this.makeOutput(absResult, input.shape, input.dtype);
1324
}
1325
}
1326
```
1327
1328
### Memory Management
1329
1330
```typescript { .api }
1331
// Efficient use of shared implementations
1332
function efficientBatchProcessing(backend: MathBackendCPU, inputs: TensorInfo[]): TensorInfo[] {
1333
const results: TensorInfo[] = [];
1334
1335
for (const input of inputs) {
1336
const vals = backend.readSync(input.dataId) as Float32Array;
1337
1338
// Process using shared implementation
1339
const processedVals = shared.sigmoidImpl(vals);
1340
const result = backend.makeOutput(processedVals, input.shape, input.dtype);
1341
1342
results.push(result);
1343
1344
// Dispose the input only if this function owns it; never dispose tensors the caller still needs
1345
backend.disposeIntermediateTensorInfo(input);
1346
}
1347
1348
return results;
1349
}
1350
```