CtrlK
BlogDocsLog inGet started
Tessl Logo

tessl/pypi-jax

Differentiate, compile, and transform Numpy code.

Pending
Overview
Eval results
Files

neural-networks.mddocs/

Neural Network Functions

JAX provides a comprehensive set of neural network functions through jax.nn including activation functions, normalization utilities, and attention mechanisms commonly used in machine learning and deep learning applications.

Core Imports

import jax.nn as jnn
from jax.nn import relu, sigmoid, softmax, gelu

Capabilities

ReLU and Variants

Rectified Linear Unit activations and their variants for introducing non-linearity while maintaining computational efficiency.

def relu(x) -> Array:
    """
    Rectified Linear Unit activation: max(0, x).
    
    Args:
        x: Input array
        
    Returns:
        Array with ReLU applied element-wise
    """

def relu6(x) -> Array:
    """
    ReLU capped at 6: min(max(0, x), 6).
    
    Args:
        x: Input array
        
    Returns:
        Array with ReLU6 applied element-wise
    """

def leaky_relu(x, negative_slope=0.01) -> Array:
    """
    Leaky ReLU: max(negative_slope * x, x).
    
    Args:
        x: Input array
        negative_slope: Slope for negative values (default: 0.01)
        
    Returns:
        Array with Leaky ReLU applied element-wise
    """

def elu(x, alpha=1.0) -> Array:
    """
    Exponential Linear Unit: x if x > 0 else alpha * (exp(x) - 1).
    
    Args:
        x: Input array
        alpha: Scale for negative values (default: 1.0)
        
    Returns:
        Array with ELU applied element-wise
    """

def selu(x) -> Array:
    """
    Scaled Exponential Linear Unit with fixed alpha and scale.
    
    Args:
        x: Input array
        
    Returns:
        Array with SELU applied element-wise
    """

def celu(x, alpha=1.0) -> Array:
    """
    Continuously Differentiable Exponential Linear Unit.
    
    Args:
        x: Input array
        alpha: Scale parameter (default: 1.0)
        
    Returns:
        Array with CELU applied element-wise
    """

Modern Activations

Contemporary activation functions that have shown improved performance in various architectures.

def gelu(x, approximate=True) -> Array:
    """
    Gaussian Error Linear Unit: x * Φ(x) where Φ is CDF of standard normal.
    
    Args:
        x: Input array
        approximate: Whether to use tanh approximation (default: True)
        
    Returns:
        Array with GELU applied element-wise
    """

def silu(x) -> Array:
    """
    Sigmoid Linear Unit (Swish): x * sigmoid(x).
    
    Args:
        x: Input array
        
    Returns:
        Array with SiLU applied element-wise
    """

def swish(x) -> Array:
    """
    Swish activation (alias for SiLU): x * sigmoid(x).
    
    Args:
        x: Input array
        
    Returns:
        Array with Swish applied element-wise
    """

def mish(x) -> Array:
    """
    Mish activation: x * tanh(softplus(x)).
    
    Args:
        x: Input array
        
    Returns:
        Array with Mish applied element-wise
    """

def hard_silu(x) -> Array:
    """
    Hard SiLU (alias of hard_swish): x * hard_sigmoid(x).
    
    Args:
        x: Input array
        
    Returns:
        Array with Hard SiLU applied element-wise
    """

def hard_swish(x) -> Array:
    """
    Hard Swish: x * relu6(x + 3) / 6.
    
    Args:
        x: Input array
        
    Returns:
        Array with Hard Swish applied element-wise
    """

def squareplus(x, b=4.0) -> Array:
    """
    Squareplus activation: (x + sqrt(x^2 + b)) / 2.
    
    Args:
        x: Input array
        b: Shape parameter (default: 4.0)
        
    Returns:
        Array with Squareplus applied element-wise
    """

Sigmoid and Tanh Variants

Sigmoid-based activations and their approximations for bounded outputs.

def sigmoid(x) -> Array:
    """
    Sigmoid activation: 1 / (1 + exp(-x)).
    
    Args:
        x: Input array
        
    Returns:
        Array with sigmoid applied element-wise
    """

def hard_sigmoid(x) -> Array:
    """
    Hard sigmoid approximation: relu6(x + 3) / 6, i.e. max(0, min(1, (x + 3) / 6)).
    
    Args:
        x: Input array
        
    Returns:
        Array with hard sigmoid applied element-wise
    """

def log_sigmoid(x) -> Array:
    """
    Log sigmoid: log(sigmoid(x)) computed in numerically stable way.
    
    Args:
        x: Input array
        
    Returns:
        Array with log sigmoid applied element-wise
    """

def soft_sign(x) -> Array:
    """
    Soft sign activation: x / (1 + |x|).
    
    Args:
        x: Input array
        
    Returns:
        Array with soft sign applied element-wise
    """

def tanh(x) -> Array:
    """
    Hyperbolic tangent activation.
    
    Args:
        x: Input array
        
    Returns:
        Array with tanh applied element-wise
    """

def hard_tanh(x) -> Array:
    """
    Hard tanh activation: max(-1, min(1, x)).
    
    Args:
        x: Input array
        
    Returns:
        Array with hard tanh applied element-wise
    """

Softmax and Normalization

Normalization functions for probability distributions and feature standardization.

def softmax(x, axis=-1, where=None, initial=None) -> Array:
    """
    Softmax activation: exp(x_i) / sum(exp(x)) along axis.
    
    Args:
        x: Input array
        axis: Axis to apply softmax along (default: -1)
        where: Mask for conditional computation
        initial: Initial value for reduction
        
    Returns:
        Array with softmax applied along specified axis
    """

def log_softmax(x, axis=-1, where=None, initial=None) -> Array:
    """
    Log softmax: log(softmax(x)) computed in numerically stable way.
    
    Args:
        x: Input array
        axis: Axis to apply log softmax along (default: -1)
        where: Mask for conditional computation
        initial: Initial value for reduction
        
    Returns:
        Array with log softmax applied along specified axis
    """

def softplus(x) -> Array:
    """
    Softplus activation: log(1 + exp(x)).
    
    Args:
        x: Input array
        
    Returns:
        Array with softplus applied element-wise
    """

def standardize(x, axis=None, mean=None, variance=None, epsilon=1e-5) -> Array:
    """
    Standardize array to zero mean and unit variance.
    
    Args:
        x: Input array to standardize
        axis: Axis to compute statistics along
        mean: Pre-computed mean (computed if None)
        variance: Pre-computed variance (computed if None)
        epsilon: Small value for numerical stability
        
    Returns:
        Standardized array
    """

def glu(x, axis=-1) -> Array:
    """
    Gated Linear Unit: split x in half along axis, return a * sigmoid(b).
    
    Args:
        x: Input array (size along axis must be even)
        axis: Axis to split along (default: -1)
        
    Returns:
        Array with GLU applied
    """

Specialized Functions

Utility functions for neural network operations and transformations.

def one_hot(x, num_classes, dtype=None, axis=-1) -> Array:
    """
    One-hot encode array of integers.
    
    Args:
        x: Integer array to encode
        num_classes: Number of classes
        dtype: Output data type
        axis: Axis to insert one-hot dimension
        
    Returns:
        One-hot encoded array
    """

def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False, where=None) -> Array:
    """
    Compute log(sum(exp(a))) in numerically stable way.
    
    Args:
        a: Input array
        axis: Axis to sum along
        b: Scaling factor array
        keepdims: Whether to keep reduced dimensions
        return_sign: Whether to return sign separately
        where: Mask for conditional computation
        
    Returns:
        Log-sum-exp result (a pair of result and sign arrays if return_sign=True)
    """

def logmeanexp(a, axis=None, b=None, keepdims=False, where=None) -> Array:
    """
    Compute log(mean(exp(a))) in numerically stable way.
    
    Args:
        a: Input array
        axis: Axis to average along
        b: Scaling factor array
        keepdims: Whether to keep reduced dimensions
        where: Mask for conditional computation
        
    Returns:
        Log-mean-exp result
    """

def log1mexp(x) -> Array:
    """
    Compute log(1 - exp(x)) in numerically stable way.
    
    Args:
        x: Input array (should be <= 0)
        
    Returns:
        Array with log(1 - exp(x)) applied element-wise
    """

def sparse_plus(x) -> Array:
    """
    Sparse plus activation: a piecewise-smooth approximation of ReLU that is
    exactly 0 for x <= -1, ((x + 1)**2) / 4 for -1 < x < 1, and x for x >= 1.
    
    Args:
        x: Input array
        
    Returns:
        Array with sparse plus applied element-wise
    """

def sparse_sigmoid(x) -> Array:
    """
    Sparse sigmoid: a piecewise-linear sigmoid approximation that saturates
    exactly at 0 for x <= -1 and 1 for x >= 1, with (x + 1) / 2 in between
    (the derivative of sparse_plus).
    
    Args:
        x: Input array
        
    Returns:
        Array with sparse sigmoid applied element-wise
    """

Attention Mechanisms

Attention functions for transformer and neural attention models.

def dot_product_attention(
    query, 
    key, 
    value,
    bias=None,
    mask=None,
    broadcast_dropout=True,
    dropout_rng=None,
    dropout_rate=0.0,
    deterministic=False,
    dtype=None,
    precision=None
) -> Array:
    """
    Dot-product attention mechanism.
    
    Args:
        query: Query array (..., length_q, depth_q)
        key: Key array (..., length_kv, depth_q) 
        value: Value array (..., length_kv, depth_v)
        bias: Optional attention bias
        mask: Optional attention mask
        broadcast_dropout: Whether to broadcast dropout
        dropout_rng: Random key for dropout
        dropout_rate: Dropout probability
        deterministic: Whether to use deterministic mode
        dtype: Output data type
        precision: Computation precision
        
    Returns:
        Attention output array (..., length_q, depth_v)
    """

def scaled_dot_general(
    lhs,
    rhs, 
    dimension_numbers,
    alpha=1.0,
    precision=None,
    preferred_element_type=None
) -> Array:
    """
    Scaled general dot product for attention computations.
    
    Args:
        lhs: Left-hand side array
        rhs: Right-hand side array
        dimension_numbers: Contraction specification
        alpha: Scaling factor
        precision: Computation precision
        preferred_element_type: Preferred output type
        
    Returns:
        Scaled dot product result
    """

def scaled_matmul(
    a,
    b,
    alpha=1.0,
    precision=None,
    preferred_element_type=None
) -> Array:
    """
    Scaled matrix multiplication: alpha * (a @ b).
    
    Args:
        a: First matrix
        b: Second matrix  
        alpha: Scaling factor
        precision: Computation precision
        preferred_element_type: Preferred output type
        
    Returns:
        Scaled matrix multiplication result
    """

def get_scaled_dot_general_config() -> dict:
    """
    Get configuration for scaled dot product attention.
    
    Returns:
        Configuration dictionary for attention operations
    """

Utility Functions

Additional utilities for neural network operations.

def identity(x) -> Array:
    """
    Identity function that returns input unchanged.
    
    Args:
        x: Input array
        
    Returns:
        Input array unchanged
    """

Neural Network Initializers

JAX provides weight initialization functions through jax.nn.initializers:

import jax.nn.initializers as init

# Standard initializers
init.zeros(key, shape, dtype=jnp.float32) -> Array
init.ones(key, shape, dtype=jnp.float32) -> Array
init.constant(value, dtype=jnp.float32) -> Callable

# Random initializers  
init.uniform(scale=1e-2, dtype=jnp.float32) -> Callable
init.normal(stddev=1e-2, dtype=jnp.float32) -> Callable
init.truncated_normal(stddev=1e-2, dtype=jnp.float32) -> Callable

# Variance scaling initializers
init.variance_scaling(scale, mode, distribution, dtype=jnp.float32) -> Callable
init.glorot_uniform(dtype=jnp.float32) -> Callable  
init.glorot_normal(dtype=jnp.float32) -> Callable
init.lecun_uniform(dtype=jnp.float32) -> Callable
init.lecun_normal(dtype=jnp.float32) -> Callable
init.he_uniform(dtype=jnp.float32) -> Callable
init.he_normal(dtype=jnp.float32) -> Callable

# Orthogonal initializer
init.orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32) -> Callable

# Delta orthogonal initializer (for RNNs)
init.delta_orthogonal(scale=1.0, column_axis=-1, dtype=jnp.float32) -> Callable

Usage examples:

import jax
import jax.numpy as jnp
import jax.nn as jnn
from jax.nn import initializers as init

# Initialize weights
key = jax.random.key(42)
weights = init.glorot_uniform()(key, (784, 128))
biases = init.zeros(key, (128,))

# Apply activations in a simple neural network layer
def dense_layer(x, weights, biases):
    return jnn.relu(x @ weights + biases)

# Multi-layer example with different activations
def mlp(x, params):
    x = jnn.relu(x @ params['w1'] + params['b1'])
    x = jnn.gelu(x @ params['w2'] + params['b2']) 
    x = jnn.softmax(x @ params['w3'] + params['b3'])
    return x

# Attention example
def simple_attention(q, k, v):
    # Scaled dot-product attention
    scores = jnn.dot_product_attention(q, k, v)
    return scores

Install with Tessl CLI

npx tessl i tessl/pypi-jax

docs

core-transformations.md

device-memory.md

experimental.md

index.md

low-level-ops.md

neural-networks.md

numpy-compatibility.md

random-numbers.md

scipy-compatibility.md

tree-operations.md

tile.json