# Array Type Annotations (array_types)

Comprehensive type annotations for NumPy, JAX, TensorFlow, and PyTorch arrays, including precision-specific dtypes, for writing type-safe machine learning code. Provides a unified typing system across ML frameworks.

## Capabilities

### Core Array Types

Fundamental array type definitions for general use.

```python { .api }
# Inputs that can be treated as array data: NumPy arrays, lists, tuples,
# and int/float scalars.
ArrayLike = Union[np.ndarray, list, tuple, int, float]

# Array with no dtype and no shape constraint.
Array = ArrayAliasMeta(shape=None, dtype=None)

# Array whose dtype is any floating-point type (no shape constraint).
FloatArray = ArrayAliasMeta(shape=None, dtype=AnyFloat())

# Array whose dtype is any integer type (no shape constraint).
IntArray = ArrayAliasMeta(shape=None, dtype=AnyInt())

# Array of booleans (dtype np.bool_, no shape constraint).
BoolArray = ArrayAliasMeta(shape=None, dtype=np.bool_)

# Array of strings (dtype np.str_, no shape constraint).
StrArray = ArrayAliasMeta(shape=None, dtype=np.str_)
```

### Precision-Specific Integer Types

Fixed-width integer types, for precise control over memory use and value range.

```python { .api }
# --- Unsigned integer arrays (no shape constraint) ---------------------------
ui8 = ArrayAliasMeta(shape=None, dtype=np.uint8)    # [0, 255]
ui16 = ArrayAliasMeta(shape=None, dtype=np.uint16)  # [0, 65_535]
ui32 = ArrayAliasMeta(shape=None, dtype=np.uint32)  # [0, 2**32 - 1] (~4.29e9)
ui64 = ArrayAliasMeta(shape=None, dtype=np.uint64)  # [0, 2**64 - 1] (~1.84e19)

# --- Signed integer arrays (no shape constraint) -----------------------------
i8 = ArrayAliasMeta(shape=None, dtype=np.int8)      # [-128, 127]
i16 = ArrayAliasMeta(shape=None, dtype=np.int16)    # [-32_768, 32_767]
i32 = ArrayAliasMeta(shape=None, dtype=np.int32)    # [-2**31, 2**31 - 1] (~±2.15e9)
i64 = ArrayAliasMeta(shape=None, dtype=np.int64)    # [-2**63, 2**63 - 1] (~±9.22e18)
```

### Precision-Specific Float Types

Floating-point types of a fixed precision.

```python { .api }
# IEEE 754 half precision (binary16): 16 bits per element.
f16 = ArrayAliasMeta(shape=None, dtype=np.float16)
# IEEE 754 single precision (binary32): 32 bits per element.
f32 = ArrayAliasMeta(shape=None, dtype=np.float32)
# IEEE 754 double precision (binary64): 64 bits per element.
f64 = ArrayAliasMeta(shape=None, dtype=np.float64)
```

### Complex Number Types

Complex-number array types at single and double precision.

```python { .api }
# Each element is a (real, imaginary) pair of float32 values: 64 bits total.
complex64 = ArrayAliasMeta(shape=None, dtype=np.complex64)
# Each element is a (real, imaginary) pair of float64 values: 128 bits total.
complex128 = ArrayAliasMeta(shape=None, dtype=np.complex128)
65
```
66
67
### Special Types
68
69
Additional specialized array types.
70
71
```python { .api }
72
# Boolean arrays: dtype np.bool_, no shape constraint (same arguments as BoolArray).
bool_ = ArrayAliasMeta(shape=None, dtype=np.bool_)
# `ui32` subscripted with 2: a 2-element uint32 array, the layout of JAX's
# legacy raw uint32 PRNG keys (e.g. as returned by `jax.random.PRNGKey`).
PRNGKey = ui32[2]
```

### Meta Classes

The metaclass used to create array type aliases, plus the dtype matcher classes used by the generic aliases.

```python { .api }
class ArrayAliasMeta(type):
    """Metaclass behind array type aliases such as `f32`, `i64`, and `Array`.

    Each alias carries an optional shape constraint and an optional dtype
    constraint. Aliases are created internally (e.g.
    `f32 = ArrayAliasMeta(shape=None, dtype=np.float32)`) and can be
    subscripted to add a shape, e.g. `ui32[2]`.
    """

    def __new__(
        cls,
        shape: tuple[int, ...] | str | None = None,
        dtype: np.dtype | DType | None = None,
    ) -> 'ArrayAliasMeta':
        """Create an alias constrained to `shape` and `dtype`.

        Passing None leaves that aspect unconstrained, as `Array` does for both.
        """
        ...

    def __getitem__(self, shape_spec: str | tuple | int) -> 'ArrayAliasMeta':
        """Return an alias constrained to the shape described by `shape_spec`."""
        ...
95
96
# --- DType matchers, passed as the `dtype` of generic aliases -----------------


class AnyFloat:
    """DType matcher that accepts every floating-point dtype (used by FloatArray)."""


class AnyInt:
    """DType matcher that accepts every integer dtype (used by IntArray)."""


class AnyDType:
    """DType matcher that accepts any dtype whatsoever."""
```

## Usage Examples

### Basic Type Annotations

```python
from etils import array_types
113
import numpy as np
114
115
def process_images(
    images: array_types.FloatArray,
    labels: array_types.IntArray,
) -> array_types.BoolArray:
    """Report, element by element, whether the model's prediction matches the label.

    Args:
        images: Floating-point image data passed to the model.
        labels: Integer ground-truth labels.

    Returns:
        Boolean array that is True wherever the prediction equals the label.

    Note:
        `model` is not defined in this snippet; it is assumed to be an object
        with a `predict` method available in the enclosing scope.
    """
    return model.predict(images) == labels
132
133
def compute_statistics(
    data: array_types.ArrayLike,
) -> tuple[array_types.FloatArray, array_types.FloatArray]:
    """Return the mean and standard deviation over all elements of `data`.

    Args:
        data: Anything `np.asarray` accepts (arrays, lists, tuples, scalars).

    Returns:
        `(mean, std)` over the flattened input; `std` is the population
        standard deviation (NumPy's default `ddof=0`).
        NOTE(review): NumPy returns scalars here, not arrays, despite the
        FloatArray annotation.
    """
    values = np.asarray(data)
    return values.mean(), values.std()
```

### Precision-Specific Types

```python
from etils import array_types
153
import numpy as np
154
155
def process_uint8_image(image: array_types.ui8) -> array_types.f32:
    """Rescale a uint8 image from [0, 255] to float32 values in [0.0, 1.0].

    Args:
        image: 8-bit unsigned integer pixel values.

    Returns:
        `image` converted to float32 and divided by 255.
    """
    as_float = image.astype(np.float32)
    return as_float / 255.0
166
167
def high_precision_computation(
    data: array_types.f64,
) -> array_types.f64:
    """Return `sqrt(data**2 + 1e-15)`, a smoothed absolute value of `data`.

    The small additive constant keeps the result strictly positive (and the
    expression differentiable) where `data == 0`.

    Args:
        data: float64 input values.

    Returns:
        float64 array of the same shape as `data`.
    """
    epsilon = 1e-15
    return np.sqrt(data**2 + epsilon)
180
181
def memory_efficient_indices(
    large_array: array_types.Array,
) -> array_types.ui32:
    """Return `0, 1, ..., len(large_array) - 1` as a uint32 array.

    uint32 uses 4 bytes per index. The result is only meaningful while
    `len(large_array)` fits in uint32 (at most 2**32 elements).

    Args:
        large_array: Any sized array; only its length (first axis) is used.

    Returns:
        uint32 index array.
    """
    num_items = len(large_array)
    return np.arange(num_items, dtype=np.uint32)
```

### Machine Learning Applications

```python
from etils import array_types
200
import numpy as np
201
202
class NeuralNetwork:
    """Single affine layer mapping 784 input features to 128 scores (float32 parameters)."""

    def __init__(self):
        # Weights are drawn from NumPy's global RNG; biases start at zero.
        self.weights: array_types.f32 = np.random.randn(784, 128).astype(np.float32)
        self.biases: array_types.f32 = np.zeros(128, dtype=np.float32)

    def forward(self, inputs: array_types.f32) -> array_types.f32:
        """Apply the layer: `inputs · weights + biases`.

        Args:
            inputs: Features whose last axis has length 784.

        Returns:
            Scores whose last axis has length 128 (float32 when `inputs` is float32).
        """
        projected = np.dot(inputs, self.weights)
        return projected + self.biases

    def predict(self, inputs: array_types.f32) -> array_types.IntArray:
        """Return the index of the highest score along the last axis.

        Args:
            inputs: Features whose last axis has length 784.

        Returns:
            Integer class indices in [0, 128).
        """
        scores = self.forward(inputs)
        return scores.argmax(axis=-1)
239
240
def batch_process(
    batch: array_types.f32,
    targets: array_types.IntArray,
) -> dict[str, array_types.FloatArray]:
    """Evaluate a freshly initialised NeuralNetwork on one batch of data.

    Args:
        batch: Input samples; the last axis must have the model's 784 features.
        targets: Integer class labels, one per sample, each in [0, 128).

    Returns:
        Dictionary of 1-element float arrays:
        'accuracy': fraction of argmax predictions equal to `targets`;
        'loss': mean softmax cross-entropy of the logits against `targets`.
    """
    model = NeuralNetwork()
    logits = model.forward(batch)
    targets = np.asarray(targets)

    # Argmax of the logits is exactly what model.predict returns; computing it
    # here avoids a second forward pass.
    predictions = np.argmax(logits, axis=-1)
    accuracy = (predictions == targets).mean()

    # The loss is computed from the logits, not from predicted class indices:
    # squared error between indices is not a classification loss (and can wrap
    # around for unsigned label dtypes). Subtracting the per-row max before
    # exp() is the log-sum-exp trick that prevents overflow.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    target_log_probs = np.take_along_axis(
        log_probs, targets[..., None].astype(np.intp), axis=-1
    )
    loss = -target_log_probs.mean()

    return {
        'accuracy': np.array([accuracy]),
        'loss': np.array([loss]),
    }
```

### Complex Number Operations

```python
from etils import array_types
270
import numpy as np
271
272
def fft_analysis(
    signal: array_types.FloatArray,
) -> array_types.complex128:
    """Return the discrete Fourier transform of `signal` as complex128.

    Args:
        signal: Real-valued samples, transformed along the last axis.

    Returns:
        Full (two-sided) complex spectrum with the same length as the last
        axis of `signal`.
    """
    spectrum = np.fft.fft(signal)
    # Explicit cast: the output precision of np.fft can follow the input dtype
    # (e.g. float32 input in recent NumPy), but callers expect complex128.
    return spectrum.astype(np.complex128)
285
286
def complex_multiplication(
    a: array_types.complex64,
    b: array_types.complex64,
) -> array_types.complex64:
    """Return the element-wise product `a * b`.

    complex64 stores each element as two 32-bit floats (real, imaginary), so
    when both inputs are complex64 the product is computed in single precision.

    Args:
        a: First complex array.
        b: Second complex array (broadcast against `a` by NumPy rules).

    Returns:
        Element-wise product.
    """
    product = a * b
    return product
```

### JAX Integration

```python
from etils import array_types
307
import numpy as np
308
309
def generate_random_key(seed: int = 42) -> array_types.PRNGKey:
    """Build a 2-element uint32 PRNG key from an integer seed.

    The layout follows `jax.random.PRNGKey` with the default threefry
    implementation: `[high 32 bits, low 32 bits]` of the seed, so
    `generate_random_key(42)` returns `[0, 42]`. The seed is first reduced
    modulo 2**64 (two's complement for negative seeds), so large or negative
    seeds no longer overflow the uint32 conversion.

    Args:
        seed: Integer random seed.

    Returns:
        uint32 array of shape (2,).
    """
    # int() first so NumPy integer seeds do not overflow when masked.
    seed = int(seed) & 0xFFFF_FFFF_FFFF_FFFF
    return np.array([seed >> 32, seed & 0xFFFF_FFFF], dtype=np.uint32)


def random_sampling(
    key: array_types.PRNGKey,
    shape: tuple[int, ...],
) -> array_types.f32:
    """Draw standard-normal float32 samples determined by `key`.

    Simplified stand-in for `jax.random.normal`. Every word of the key seeds a
    local `np.random.Generator`, so keys differing in either word give
    different samples and NumPy's global RNG state is left untouched.

    Args:
        key: PRNG key (2-element uint32 array).
        shape: Shape of the output array.

    Returns:
        float32 array of standard-normal samples with the given shape.
    """
    rng = np.random.default_rng([int(word) for word in np.asarray(key).ravel()])
    return rng.standard_normal(shape, dtype=np.float32)
```

### String and Boolean Operations

```python
from etils import array_types
345
import numpy as np
346
347
def process_text_data(
    texts: array_types.StrArray,
) -> array_types.IntArray:
    """Return the length of each entry in `texts`.

    Args:
        texts: Array of strings, assumed 1-D (iteration is over the first axis).

    Returns:
        Integer array of lengths, one per entry. The dtype is NumPy's default
        integer even for empty input; without an explicit dtype,
        `np.array([])` would be float64, breaking the IntArray contract.
    """
    return np.array([len(text) for text in texts], dtype=int)
360
361
def create_mask(
    data: array_types.FloatArray,
    threshold: float = 0.5,
) -> array_types.BoolArray:
    """Mark which elements of `data` are strictly greater than `threshold`.

    Args:
        data: Float values to test.
        threshold: Cut-off; elements equal to it are not selected.

    Returns:
        Boolean array shaped like `data`.
    """
    above_threshold = data > threshold
    return above_threshold
376
377
def filter_by_mask(
    data: array_types.Array,
    mask: array_types.BoolArray,
) -> array_types.Array:
    """Select the elements of `data` where `mask` is True.

    Args:
        data: Array to filter.
        mask: Boolean index; NumPy requires its shape to match the leading
            dimensions of `data`.

    Returns:
        The selected elements (boolean indexing returns a copy, not a view).
    """
    selected = data[mask]
    return selected
```

### Type Checking and Validation

```python
from etils import array_types
398
import numpy as np
399
from typing import TypeGuard
400
401
def is_float_array(arr: array_types.Array) -> TypeGuard[array_types.FloatArray]:
    """Narrow `arr` to FloatArray when its dtype is a real floating-point type.

    Complex arrays do not qualify: `np.complexfloating` is not a subtype of
    `np.floating` in NumPy's dtype hierarchy.

    Args:
        arr: Array whose dtype is inspected.

    Returns:
        True if `arr.dtype` is a floating-point dtype.
    """
    dtype = arr.dtype
    return np.issubdtype(dtype, np.floating)
412
413
def validate_array_type(
    arr: array_types.ArrayLike,
    expected_dtype: np.dtype,
) -> array_types.Array:
    """Convert `arr` to a NumPy array with dtype `expected_dtype`.

    Conversion uses `ndarray.astype` with its default (unsafe) casting, so
    lossy casts such as float -> int truncation succeed silently.

    Args:
        arr: Input array-like data.
        expected_dtype: Required data type.

    Returns:
        `arr` as an array; returned without copying when it is already an
        ndarray of `expected_dtype`.

    Raises:
        TypeError: If `arr` cannot be turned into an array (e.g. ragged nested
            lists) or the array cannot be cast to `expected_dtype`. The
            original NumPy error is chained as `__cause__`.
    """
    try:
        # np.asarray is inside the try so that its failures (ValueError for
        # ragged input) also honour the documented TypeError contract.
        result = np.asarray(arr)
        if result.dtype != expected_dtype:
            result = result.astype(expected_dtype)
    except (ValueError, TypeError) as e:
        raise TypeError(f"Cannot convert to {expected_dtype}: {e}") from e
    return result
```