Differentiate, compile, and transform NumPy code.
—
JAX provides a comprehensive set of neural network functions through jax.nn including activation functions, normalization utilities, and attention mechanisms commonly used in machine learning and deep learning applications.
import jax.nn as jnn
from jax.nn import relu, sigmoid, softmax, gelu

Rectified Linear Unit activations and their variants for introducing non-linearity while maintaining computational efficiency.
def relu(x) -> Array:
"""
Rectified Linear Unit activation: max(0, x).
Args:
x: Input array
Returns:
Array with ReLU applied element-wise
"""
def relu6(x) -> Array:
"""
ReLU capped at 6: min(max(0, x), 6).
Args:
x: Input array
Returns:
Array with ReLU6 applied element-wise
"""
def leaky_relu(x, negative_slope=0.01) -> Array:
"""
Leaky ReLU: max(negative_slope * x, x).
Args:
x: Input array
negative_slope: Slope for negative values (default: 0.01)
Returns:
Array with Leaky ReLU applied element-wise
"""
def elu(x, alpha=1.0) -> Array:
"""
Exponential Linear Unit: x if x > 0 else alpha * (exp(x) - 1).
Args:
x: Input array
alpha: Scale for negative values (default: 1.0)
Returns:
Array with ELU applied element-wise
"""
def selu(x) -> Array:
"""
Scaled Exponential Linear Unit with fixed alpha and scale.
Args:
x: Input array
Returns:
Array with SELU applied element-wise
"""
def celu(x, alpha=1.0) -> Array:
"""
Continuously Differentiable Exponential Linear Unit.
Args:
x: Input array
alpha: Scale parameter (default: 1.0)
Returns:
Array with CELU applied element-wise
"""Contemporary activation functions that have shown improved performance in various architectures.
def gelu(x, approximate=True) -> Array:
"""
Gaussian Error Linear Unit: x * Φ(x) where Φ is CDF of standard normal.
Args:
x: Input array
approximate: Whether to use tanh approximation (default: True)
Returns:
Array with GELU applied element-wise
"""
def silu(x) -> Array:
"""
Sigmoid Linear Unit (Swish): x * sigmoid(x).
Args:
x: Input array
Returns:
Array with SiLU applied element-wise
"""
def swish(x) -> Array:
"""
Swish activation (alias for SiLU): x * sigmoid(x).
Args:
x: Input array
Returns:
Array with Swish applied element-wise
"""
def mish(x) -> Array:
"""
Mish activation: x * tanh(softplus(x)).
Args:
x: Input array
Returns:
Array with Mish applied element-wise
"""
def hard_silu(x) -> Array:
"""
Hard SiLU (Hard Swish variant): x * hard_sigmoid(x).
Args:
x: Input array
Returns:
Array with Hard SiLU applied element-wise
"""
def hard_swish(x) -> Array:
"""
Hard Swish: x * relu6(x + 3) / 6.
Args:
x: Input array
Returns:
Array with Hard Swish applied element-wise
"""
def squareplus(x, b=4.0) -> Array:
"""
Squareplus activation: (x + sqrt(x^2 + b)) / 2.
Args:
x: Input array
b: Shape parameter (default: 4.0)
Returns:
Array with Squareplus applied element-wise
"""Sigmoid-based activations and their approximations for bounded outputs.
def sigmoid(x) -> Array:
"""
Sigmoid activation: 1 / (1 + exp(-x)).
Args:
x: Input array
Returns:
Array with sigmoid applied element-wise
"""
def hard_sigmoid(x) -> Array:
"""
Hard sigmoid approximation: relu6(x + 3) / 6, i.e. max(0, min(1, (x + 3) / 6)).
Args:
x: Input array
Returns:
Array with hard sigmoid applied element-wise
"""
def log_sigmoid(x) -> Array:
"""
Log sigmoid: log(sigmoid(x)) computed in numerically stable way.
Args:
x: Input array
Returns:
Array with log sigmoid applied element-wise
"""
def soft_sign(x) -> Array:
"""
Soft sign activation: x / (1 + |x|).
Args:
x: Input array
Returns:
Array with soft sign applied element-wise
"""
def tanh(x) -> Array:
"""
Hyperbolic tangent activation.
Args:
x: Input array
Returns:
Array with tanh applied element-wise
"""
def hard_tanh(x) -> Array:
"""
Hard tanh activation: max(-1, min(1, x)).
Args:
x: Input array
Returns:
Array with hard tanh applied element-wise
"""Normalization functions for probability distributions and feature standardization.
def softmax(x, axis=-1, where=None, initial=None) -> Array:
"""
Softmax activation: exp(x_i) / sum(exp(x)) along axis.
Args:
x: Input array
axis: Axis to apply softmax along (default: -1)
where: Mask for conditional computation
initial: Initial value for reduction
Returns:
Array with softmax applied along specified axis
"""
def log_softmax(x, axis=-1, where=None, initial=None) -> Array:
"""
Log softmax: log(softmax(x)) computed in numerically stable way.
Args:
x: Input array
axis: Axis to apply log softmax along (default: -1)
where: Mask for conditional computation
initial: Initial value for reduction
Returns:
Array with log softmax applied along specified axis
"""
def softplus(x) -> Array:
"""
Softplus activation: log(1 + exp(x)).
Args:
x: Input array
Returns:
Array with softplus applied element-wise
"""
def standardize(x, axis=None, mean=None, variance=None, epsilon=1e-5) -> Array:
"""
Standardize array to zero mean and unit variance.
Args:
x: Input array to standardize
axis: Axis to compute statistics along
mean: Pre-computed mean (computed if None)
variance: Pre-computed variance (computed if None)
epsilon: Small value for numerical stability
Returns:
Standardized array
"""
def glu(x, axis=-1) -> Array:
"""
Gated Linear Unit: split x in half along axis, return a * sigmoid(b).
Args:
x: Input array (size along axis must be even)
axis: Axis to split along (default: -1)
Returns:
Array with GLU applied
"""Utility functions for neural network operations and transformations.
def one_hot(x, num_classes, dtype=None, axis=-1) -> Array:
"""
One-hot encode array of integers.
Args:
x: Integer array to encode
num_classes: Number of classes
dtype: Output data type
axis: Axis to insert one-hot dimension
Returns:
One-hot encoded array
"""
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, where=None) -> Array:
"""
Compute log(sum(exp(a))) in numerically stable way.
Args:
a: Input array
axis: Axis to sum along
b: Scaling factor array
keepdims: Whether to keep reduced dimensions
return_sign: Whether to return sign separately
where: Mask for conditional computation
Returns:
Log-sum-exp result
"""
def logmeanexp(a, axis=None, b=None, keepdims=False, where=None) -> Array:
"""
Compute log(mean(exp(a))) in numerically stable way.
Args:
a: Input array
axis: Axis to average along
b: Scaling factor array
keepdims: Whether to keep reduced dimensions
where: Mask for conditional computation
Returns:
Log-mean-exp result
"""
def log1mexp(x) -> Array:
"""
Compute log(1 - exp(x)) in numerically stable way.
Args:
x: Input array (should be <= 0)
Returns:
Array with log(1 - exp(x)) applied element-wise
"""
def sparse_plus(x) -> Array:
"""
Sparse plus function: 0 for x <= -1, (x + 1)**2 / 4 for -1 < x < 1, and x for x >= 1.
A twin of softplus that is exactly zero for inputs <= -1.
Args:
x: Input array
Returns:
Array with sparse plus applied element-wise
"""
def sparse_sigmoid(x) -> Array:
"""
Sparse sigmoid: 0 for x <= -1, (x + 1) / 2 for -1 < x < 1, and 1 for x >= 1 (the derivative of sparse_plus).
Args:
x: Input array
Returns:
Sigmoid activation with sparse support
"""Attention functions for transformer and neural attention models.
def dot_product_attention(
query,
key,
value,
bias=None,
mask=None,
broadcast_dropout=True,
dropout_rng=None,
dropout_rate=0.0,
deterministic=False,
dtype=None,
precision=None
) -> Array:
"""
Dot-product attention mechanism.
Args:
query: Query array (..., length_q, depth_q)
key: Key array (..., length_kv, depth_q)
value: Value array (..., length_kv, depth_v)
bias: Optional attention bias
mask: Optional attention mask
broadcast_dropout: Whether to broadcast dropout
dropout_rng: Random key for dropout
dropout_rate: Dropout probability
deterministic: Whether to use deterministic mode
dtype: Output data type
precision: Computation precision
Returns:
Attention output array (..., length_q, depth_v)
"""
def scaled_dot_general(
lhs,
rhs,
dimension_numbers,
alpha=1.0,
precision=None,
preferred_element_type=None
) -> Array:
"""
Scaled general dot product for attention computations.
Args:
lhs: Left-hand side array
rhs: Right-hand side array
dimension_numbers: Contraction specification
alpha: Scaling factor
precision: Computation precision
preferred_element_type: Preferred output type
Returns:
Scaled dot product result
"""
def scaled_matmul(
a,
b,
alpha=1.0,
precision=None,
preferred_element_type=None
) -> Array:
"""
Scaled matrix multiplication: alpha * (a @ b).
Args:
a: First matrix
b: Second matrix
alpha: Scaling factor
precision: Computation precision
preferred_element_type: Preferred output type
Returns:
Scaled matrix multiplication result
"""
def get_scaled_dot_general_config() -> dict:
"""
Get configuration for scaled dot product attention.
Returns:
Configuration dictionary for attention operations
"""Additional utilities for neural network operations.
def identity(x) -> Array:
"""
Identity function that returns input unchanged.
Args:
x: Input array
Returns:
Input array unchanged
"""JAX provides weight initialization functions through jax.nn.initializers:
import jax.nn.initializers as init
# Standard initializers
init.zeros(key, shape, dtype=jnp.float32) -> Array
init.ones(key, shape, dtype=jnp.float32) -> Array
init.constant(value, dtype=jnp.float32) -> Callable
# Random initializers
init.uniform(scale=1e-2, dtype=jnp.float32) -> Callable
init.normal(stddev=1e-2, dtype=jnp.float32) -> Callable
init.truncated_normal(stddev=1e-2, dtype=jnp.float32) -> Callable
# Variance scaling initializers
init.variance_scaling(scale, mode, distribution, dtype=jnp.float32) -> Callable
init.glorot_uniform(dtype=jnp.float32) -> Callable
init.glorot_normal(dtype=jnp.float32) -> Callable
init.lecun_uniform(dtype=jnp.float32) -> Callable
init.lecun_normal(dtype=jnp.float32) -> Callable
init.he_uniform(dtype=jnp.float32) -> Callable
init.he_normal(dtype=jnp.float32) -> Callable
# Orthogonal initializer
init.orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32) -> Callable
# Delta orthogonal initializer (for RNNs)
init.delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32) -> Callable

Usage examples:
import jax
import jax.numpy as jnp
import jax.nn as jnn
from jax.nn import initializers as init
# Initialize weights
key = jax.random.key(42)
weights = init.glorot_uniform()(key, (784, 128))
biases = init.zeros(key, (128,))
# Apply activations in a simple neural network layer
def dense_layer(x, weights, biases):
return jnn.relu(x @ weights + biases)
# Multi-layer example with different activations
def mlp(x, params):
x = jnn.relu(x @ params['w1'] + params['b1'])
x = jnn.gelu(x @ params['w2'] + params['b2'])
x = jnn.softmax(x @ params['w3'] + params['b3'])
return x
# Attention example
def simple_attention(q, k, v):
# Scaled dot-product attention
scores = jnn.dot_product_attention(q, k, v)
return scores

Install with Tessl CLI
npx tessl i tessl/pypi-jax