CtrlK
Blog · Docs · Log in · Get started
Tessl Logo

tessl/pypi-websockets

An implementation of the WebSocket Protocol (RFC 6455 & 7692)

Overview
Eval results
Files

docs/extensions.md

Extensions

WebSocket extension system with built-in per-message deflate compression support (RFC 7692) and an extensible framework for custom extensions. Extensions can add compression, encryption, or other custom features on top of the base protocol.

Core Imports

# Extension base classes
from websockets.extensions import Extension, ClientExtensionFactory, ServerExtensionFactory

# Per-message deflate extension
from websockets.extensions.permessage_deflate import (
    PerMessageDeflate,
    ClientPerMessageDeflateFactory,
    ServerPerMessageDeflateFactory,
    enable_client_permessage_deflate,
    enable_server_permessage_deflate
)

Capabilities

Extension Base Classes

Foundation classes for implementing WebSocket extensions with proper negotiation and frame processing.

class Extension:
    """
    Base class for WebSocket extensions.

    An extension transforms frames on their way out (``encode``) and on
    their way in (``decode``), enabling features such as compression,
    encryption, or custom protocols. The default implementations pass
    frames through unchanged.
    """

    def __init__(self, name: str):
        """
        Initialize extension.

        Parameters:
        - name: Extension token used in Sec-WebSocket-Extensions negotiation
        """
        self.name = name

    def encode(self, frame: Frame) -> Frame:
        """
        Process an outgoing frame before transmission.

        Parameters:
        - frame: Frame produced by the application (or a previous extension)

        Returns:
        Frame: Frame to transmit

        Note:
        Override this method to modify outgoing frames.
        Return the original frame if no processing is needed.
        """
        return frame

    def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame:
        """
        Process an incoming frame after reception.

        Parameters:
        - frame: Frame received from the peer (or a previous extension)
        - max_size: Optional upper bound on the decoded payload size. The
          websockets protocol passes it by keyword when it decodes incoming
          frames, so overrides must accept it; this default ignores it.

        Returns:
        Frame: Frame to hand to the application

        Note:
        Override this method to modify incoming frames.
        Return the original frame if no processing is needed.
        """
        return frame

    def __repr__(self) -> str:
        """String representation of extension."""
        return f"{self.__class__.__name__}(name={self.name!r})"

class ClientExtensionFactory:
    """
    Factory for creating client-side WebSocket extensions.

    Supplies the parameters the client offers in its handshake request and
    turns the server's response parameters into an extension instance.
    """

    def __init__(self, name: str, parameters: List[ExtensionParameter] | None = None):
        """
        Initialize client extension factory.

        Parameters:
        - name: Extension name
        - parameters: Extension parameters to offer during negotiation, or
          None for none. The sequence is copied, so later changes to the
          caller's list do not alter what the factory offers.
        """
        self.name = name
        self.parameters = list(parameters) if parameters else []

    def get_request_params(self) -> List[ExtensionParameter]:
        """
        Get parameters for extension negotiation request.

        Returns:
        List[ExtensionParameter]: A fresh copy of the parameters to send in
        the Sec-WebSocket-Extensions header; mutating it does not change the
        factory's configuration.
        """
        return list(self.parameters)

    def process_response_params(
        self,
        parameters: List[ExtensionParameter],
        accepted_extensions: List[Extension] | None = None,
    ) -> Extension:
        """
        Process server response and create extension instance.

        Parameters:
        - parameters: Parameters from server's Sec-WebSocket-Extensions header
        - accepted_extensions: Extensions already accepted earlier in the same
          negotiation, if any. websockets passes this as a second positional
          argument, so subclasses should accept it.

        Returns:
        Extension: Configured extension instance

        Raises:
        - NegotiationError: If parameters are invalid or incompatible
        """
        raise NotImplementedError("Subclasses must implement process_response_params")

class ServerExtensionFactory:
    """
    Factory for creating server-side WebSocket extensions.

    Validates the parameters a client offers during the handshake and, if
    acceptable, returns the response parameters plus an extension instance.
    """

    def __init__(self, name: str):
        """
        Initialize server extension factory.

        Parameters:
        - name: Extension name
        """
        self.name = name

    def process_request_params(
        self,
        parameters: List[ExtensionParameter],
        accepted_extensions: List[Extension] | None = None,
    ) -> Tuple[List[ExtensionParameter], Extension]:
        """
        Process client request and create extension instance.

        Parameters:
        - parameters: Parameters from client's Sec-WebSocket-Extensions header
        - accepted_extensions: Extensions already accepted earlier in the same
          negotiation, if any. websockets passes this as a second positional
          argument, so subclasses should accept it.

        Returns:
        Tuple[List[ExtensionParameter], Extension]:
            - Response parameters for Sec-WebSocket-Extensions header
            - Configured extension instance

        Raises:
        - NegotiationError: If parameters are invalid or unsupported
        """
        raise NotImplementedError("Subclasses must implement process_request_params")

Per-Message Deflate Compression

Built-in implementation of per-message deflate compression extension (RFC 7692).

class ClientPerMessageDeflateFactory(ClientExtensionFactory):
    """
    Client factory for the per-message deflate compression extension.

    Offers compression parameters to the server and, once negotiation
    succeeds, provides the compression extension that reduces the size of
    WebSocket messages on the wire.
    """

    def __init__(
        self,
        server_max_window_bits: int = 15,
        client_max_window_bits: int = 15,
        server_no_context_takeover: bool = False,
        client_no_context_takeover: bool = False,
        compress_settings: Dict | None = None
    ):
        """
        Initialize client per-message deflate factory.

        Parameters:
        - server_max_window_bits: Base-2 logarithm of the largest LZ77 sliding
          window the server may use (8-15, i.e. 256 bytes to 32 KiB)
        - client_max_window_bits: Base-2 logarithm of the client's LZ77
          sliding window (8-15)
        - server_no_context_takeover: Ask the server not to reuse its
          compression context between messages
        - client_no_context_takeover: Do not reuse the client's compression
          context between messages
        - compress_settings: Additional compressor settings
          (e.g. {"memLevel": 6}), or None for defaults

        Note:
        Smaller windows and disabling context takeover both reduce memory
        usage, at the cost of a lower compression ratio.
        """

class ServerPerMessageDeflateFactory(ServerExtensionFactory):
    """
    Server factory for the per-message deflate compression extension.

    Evaluates a client's compression offer and creates the compression
    extension with the negotiated parameters.
    """

    def __init__(
        self,
        server_max_window_bits: int = 15,
        client_max_window_bits: int = 15,
        server_no_context_takeover: bool = False,
        client_no_context_takeover: bool = False,
        compress_settings: Dict | None = None
    ):
        """
        Initialize server per-message deflate factory.

        Parameters:
        - server_max_window_bits: Base-2 logarithm of the server's LZ77
          sliding window (8-15, i.e. 256 bytes to 32 KiB)
        - client_max_window_bits: Base-2 logarithm of the largest window the
          client may use (8-15)
        - server_no_context_takeover: Do not reuse the server's compression
          context between messages
        - client_no_context_takeover: Ask the client not to reuse its
          compression context between messages
        - compress_settings: Additional compressor settings, or None for
          defaults
        """

class PerMessageDeflate(Extension):
    """
    Per-message deflate compression extension implementation.

    Compresses outgoing messages and decompresses incoming messages
    using the deflate algorithm with the negotiated parameters.
    """

    def __init__(
        self,
        is_server: bool,
        server_max_window_bits: int = 15,
        client_max_window_bits: int = 15,
        server_no_context_takeover: bool = False,
        client_no_context_takeover: bool = False,
        compress_settings: Dict | None = None
    ):
        """
        Initialize per-message deflate extension.

        Parameters:
        - is_server: True if this is the server-side extension
        - server_max_window_bits: Base-2 logarithm of the server's LZ77
          window (8-15)
        - client_max_window_bits: Base-2 logarithm of the client's LZ77
          window (8-15)
        - server_no_context_takeover: Server resets its context per message
        - client_no_context_takeover: Client resets its context per message
        - compress_settings: Additional compressor settings, or None
        """

    def encode(self, frame: Frame) -> Frame:
        """
        Compress outgoing frame if it's a data frame.

        Parameters:
        - frame: Original frame

        Returns:
        Frame: Compressed frame with RSV1 bit set if compressed
        """

    def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame:
        """
        Decompress incoming frame if RSV1 bit is set.

        Parameters:
        - frame: Compressed frame
        - max_size: Optional upper bound on the decompressed payload size

        Returns:
        Frame: Decompressed frame

        Raises:
        - ProtocolError: If decompression fails
        """

Extension Utility Functions

Helper functions for enabling compression extensions on connections.

def enable_client_permessage_deflate(
    extensions: Sequence[ClientExtensionFactory] | None
) -> Sequence[ClientExtensionFactory]:
    """
    Enable Per-Message Deflate with default settings in client extensions.

    If the extension is already present, perhaps with non-default settings,
    the configuration isn't changed.

    Parameters:
    - extensions: Existing client extension factories, or None for none

    Returns:
    Sequence[ClientExtensionFactory]: Extensions list with deflate factory added

    Usage:
    extensions = enable_client_permessage_deflate(None)
    async with websockets.connect("ws://localhost:8765", extensions=extensions) as ws:
        ...

    # Or add to existing extensions:
    existing_extensions = [custom_extension_factory]
    extensions = enable_client_permessage_deflate(existing_extensions)
    """

def enable_server_permessage_deflate(
    extensions: Sequence[ServerExtensionFactory] | None
) -> Sequence[ServerExtensionFactory]:
    """
    Add a default Per-Message Deflate factory to server extensions.

    When a Per-Message Deflate factory is already in the list (possibly
    configured with custom settings), that configuration is left as is.

    Parameters:
    - extensions: Server extension factories to extend, or None for none

    Returns:
    Sequence[ServerExtensionFactory]: The extensions, including a deflate factory

    Usage:
    extensions = enable_server_permessage_deflate(None)
    websockets.serve(handler, "localhost", 8765, extensions=extensions)

    # Combining with other extension factories:
    existing_extensions = [custom_extension_factory]
    extensions = enable_server_permessage_deflate(existing_extensions)
    """

Usage Examples

Basic Compression Usage

import asyncio
from websockets import connect, serve
from websockets.extensions.permessage_deflate import enable_client_permessage_deflate, enable_server_permessage_deflate

async def compression_client():
    """Connect with per-message deflate enabled and exchange one large text message."""
    # Default permessage-deflate offer for the handshake.
    client_exts = enable_client_permessage_deflate(None)
    # 18 characters x 1000 = ~18KB of repetitive text, a good compression candidate.
    payload = "Hello, WebSocket! " * 1000

    async with connect("ws://localhost:8765", extensions=client_exts) as ws:
        await ws.send(payload)
        reply = await ws.recv()
        print(f"Received: {len(reply)} characters")

async def compression_server():
    """Run an echo server with per-message deflate enabled (runs until cancelled)."""
    # Enable compression with default settings
    extensions = enable_server_permessage_deflate(None)

    async def handler(websocket):
        async for message in websocket:
            # The extension has already decompressed the message, so this is
            # the application-level size, not the compressed size on the wire.
            print(f"Received message: {len(message)} chars (decompressed)")
            # Echo back (will be compressed)
            await websocket.send(f"Echo: {message}")

    async with serve(handler, "localhost", 8765, extensions=extensions):
        print("Compression server started")
        # Block forever: this future is never resolved.
        await asyncio.Future()

# Run server in one terminal, client in another
# asyncio.run(compression_server())
# asyncio.run(compression_client())

Advanced Compression Configuration

import asyncio
from websockets import connect, serve
from websockets.extensions.permessage_deflate import enable_client_permessage_deflate, enable_server_permessage_deflate

async def advanced_compression_example():
    """Compare default and custom per-message deflate settings.

    Requires an echo server on ws://localhost:8765. Reported sizes are
    application-level (after decompression), so they show what was sent and
    echoed back, not the compressed size on the wire.
    """
    from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory

    configurations = [
        # Standard compression configuration
        ("Standard Compression", enable_client_permessage_deflate(None)),
        # Custom configuration: 2**12 = 4 KiB windows on both sides
        (
            "Custom Compression",
            [
                ClientPerMessageDeflateFactory(
                    server_max_window_bits=12,
                    client_max_window_bits=12,
                    compress_settings={"memLevel": 6},
                )
            ],
        ),
    ]

    test_messages = [
        "Small message",
        "Medium message " * 50,  # 750 chars
        "Large message " * 500,  # 7,000 chars
        "Huge message " * 5000,  # 65,000 chars
    ]

    for name, extensions in configurations:
        print(f"\n=== {name} Configuration ===")

        try:
            async with connect(
                "ws://localhost:8765",
                extensions=extensions,
                open_timeout=5
            ) as websocket:
                for number, message in enumerate(test_messages, start=1):
                    await websocket.send(message)
                    response = await websocket.recv()
                    print(
                        f"Message {number}: sent {len(message)} chars, "
                        f"received {len(response)} chars"
                    )
        except Exception as e:
            # Demo boundary: report the failure and move on to the next config.
            print(f"Configuration failed: {e}")

# asyncio.run(advanced_compression_example())

Custom Extension Implementation

from websockets.extensions.base import Extension, ClientExtensionFactory, ServerExtensionFactory
from websockets.frames import Frame, Opcode
from websockets.exceptions import NegotiationError
import base64
import hashlib

class SimpleEncryptionExtension(Extension):
    """Toy XOR "encryption" extension (for demonstration only).

    Each data frame's payload is XORed with a repeating key, restarting at
    the first key byte for every frame. XOR is its own inverse, so the same
    transform encrypts and decrypts. This provides NO real security.
    """

    # Opcodes whose payload carries application data. CONT is included so
    # the continuation frames of a fragmented message are not sent in clear.
    _DATA_OPCODES = (Opcode.CONT, Opcode.TEXT, Opcode.BINARY)

    def __init__(self, key: bytes):
        super().__init__("simple-encrypt")
        if not key:
            # An empty key would make ``i % len(key)`` divide by zero on the
            # first data frame; fail at construction instead.
            raise ValueError("key must not be empty")
        self.key = key

    def _xor(self, data) -> bytes:
        """Return *data* XORed byte-by-byte with the repeating key."""
        key = self.key
        key_len = len(key)
        return bytes(b ^ key[i % key_len] for i, b in enumerate(data))

    @staticmethod
    def _with_data(frame: Frame, data: bytes) -> Frame:
        """Return a copy of *frame* carrying *data* as its payload."""
        import dataclasses  # local import keeps this snippet's imports unchanged

        if dataclasses.is_dataclass(frame):
            return dataclasses.replace(frame, data=data)
        # Fallback for NamedTuple-style frames.
        return frame._replace(data=data)

    def encode(self, frame: Frame) -> Frame:
        """Encrypt outgoing data frames; pass control frames through."""
        if frame.opcode in self._DATA_OPCODES:
            return self._with_data(frame, self._xor(frame.data))
        return frame

    def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame:
        """Decrypt incoming data frames; pass control frames through.

        ``max_size`` is accepted for API compatibility; XOR never changes the
        payload size, so it is not needed here.
        """
        if frame.opcode in self._DATA_OPCODES:
            return self._with_data(frame, self._xor(frame.data))
        return frame

class SimpleEncryptionClientFactory(ClientExtensionFactory):
    """Client factory for the simple encryption extension.

    WARNING: the derived key is sent base64-encoded in the handshake, so
    anyone who can read the handshake can decrypt the traffic. This shows
    the factory API; it is not a security mechanism.
    """

    def __init__(self, password: str):
        # Derive a 16-byte key: the first half of SHA-256(password).
        key = hashlib.sha256(password.encode()).digest()[:16]
        super().__init__("simple-encrypt", [("key", base64.b64encode(key).decode())])
        self.key = key

    def process_response_params(self, parameters, accepted_extensions=None):
        """Check that the server agreed on our key, then build the extension.

        Parameters:
        - parameters: (name, value) pairs from the server's response
        - accepted_extensions: Previously accepted extensions (unused)

        Raises:
        - NegotiationError: If the server's "key" parameter is missing,
          not valid base64, or different from ours
        """
        server_key_param = next((p for p in parameters if p[0] == "key"), None)
        if server_key_param is None or server_key_param[1] is None:
            raise NegotiationError("Missing encryption key")
        try:
            server_key = base64.b64decode(server_key_param[1], validate=True)
        except ValueError as exc:  # binascii.Error subclasses ValueError
            raise NegotiationError("Invalid encryption key") from exc
        if server_key != self.key:
            raise NegotiationError("Invalid encryption key")
        return SimpleEncryptionExtension(self.key)

class SimpleEncryptionServerFactory(ServerExtensionFactory):
    """Server factory for the simple encryption extension.

    Accepts a client only if it offers the same key this server derived
    from the shared password.
    """

    def __init__(self, password: str):
        super().__init__("simple-encrypt")
        # Same derivation as the client: first 16 bytes of SHA-256(password).
        self.key = hashlib.sha256(password.encode()).digest()[:16]

    def process_request_params(self, parameters, accepted_extensions=None):
        """Validate the client's key and build the response and extension.

        Parameters:
        - parameters: (name, value) pairs offered by the (untrusted) client
        - accepted_extensions: Previously accepted extensions (unused)

        Returns:
        Tuple of (response parameters, extension instance)

        Raises:
        - NegotiationError: If the "key" parameter is missing, has no value,
          is not valid base64, or does not match
        """
        import hmac  # local import keeps this snippet's imports unchanged

        client_key_param = next((p for p in parameters if p[0] == "key"), None)
        if client_key_param is None or client_key_param[1] is None:
            raise NegotiationError("Missing encryption key")

        try:
            client_key = base64.b64decode(client_key_param[1], validate=True)
        except ValueError as exc:  # binascii.Error subclasses ValueError
            raise NegotiationError("Invalid encryption key") from exc
        # Constant-time comparison so response timing does not leak how many
        # leading bytes of a guessed key were correct.
        if not hmac.compare_digest(client_key, self.key):
            raise NegotiationError("Invalid encryption key")

        # Return response parameters and extension
        response_params = [("key", base64.b64encode(self.key).decode())]
        extension = SimpleEncryptionExtension(self.key)

        return response_params, extension

async def custom_extension_example():
    """Run an in-process echo server and a client, both using the custom extension.

    Both sides derive the same key from a shared password. The server is
    started here so the example is self-contained and needs no external
    process on port 8765.
    """
    from websockets import serve, connect

    password = "shared-secret-password"

    # Server with custom extension
    async def encrypted_handler(websocket):
        async for message in websocket:
            print(f"Server received (decrypted): {message}")
            await websocket.send(f"Encrypted echo: {message}")

    server_extensions = [SimpleEncryptionServerFactory(password)]
    client_extensions = [SimpleEncryptionClientFactory(password)]

    print("Starting server with custom encryption extension...")

    try:
        async with serve(encrypted_handler, "localhost", 8765, extensions=server_extensions):
            async with connect(
                "ws://localhost:8765",
                extensions=client_extensions
            ) as websocket:
                # Send encrypted message
                await websocket.send("This message is encrypted!")
                response = await websocket.recv()
                print(f"Client received (decrypted): {response}")
    except Exception as e:
        # Demo boundary: report negotiation/connection failures.
        print(f"Custom extension error: {e}")

# Note: This is a demonstration. Real encryption would use proper cryptography
# asyncio.run(custom_extension_example())

Extension Chain Implementation

from websockets.extensions.base import Extension
from websockets.frames import Frame, Opcode
import zlib
import time

class CompressionExtension(Extension):
    """Toy compression extension that deflates large TEXT frames.

    Compressed frames are flagged with RSV1 so the receiver knows to inflate
    them. The compressor and decompressor keep their state across frames, so
    frames must be decoded in the same order they were encoded.
    """

    def __init__(self, threshold: int = 100):
        """
        Parameters:
        - threshold: Only TEXT frames whose payload is longer than this many
          bytes are compressed; smaller ones are sent as-is.
        """
        super().__init__("simple-compress")
        self.threshold = threshold
        self.compressor = zlib.compressobj()
        self.decompressor = zlib.decompressobj()

    @staticmethod
    def _with_changes(frame: Frame, **changes) -> Frame:
        """Return a copy of *frame* with the given fields replaced."""
        import dataclasses  # local import keeps this snippet's imports unchanged

        if dataclasses.is_dataclass(frame):
            return dataclasses.replace(frame, **changes)
        # Fallback for NamedTuple-style frames.
        return frame._replace(**changes)

    def encode(self, frame: Frame) -> Frame:
        """Compress large outgoing TEXT frames and set RSV1 on them."""
        if frame.opcode == Opcode.TEXT and len(frame.data) > self.threshold:
            # Z_SYNC_FLUSH emits all pending output without ending the stream,
            # so the next frame keeps reusing the compression context.
            compressed = self.compressor.compress(frame.data)
            compressed += self.compressor.flush(zlib.Z_SYNC_FLUSH)
            return self._with_changes(frame, data=compressed, rsv1=True)
        return frame

    def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame:
        """Inflate incoming TEXT frames that carry RSV1 and clear the flag.

        ``max_size`` is accepted for API compatibility but not enforced by
        this demo extension.
        """
        if frame.rsv1 and frame.opcode == Opcode.TEXT:
            decompressed = self.decompressor.decompress(frame.data)
            return self._with_changes(frame, data=decompressed, rsv1=False)
        return frame

class LoggingExtension(Extension):
    """Pass-through extension that prints one line per frame, for debugging.

    Each line is tagged with the extension's name, so several loggers in one
    chain (e.g. before and after compression) can be told apart.
    """

    def __init__(self, name: str = "logger"):
        super().__init__(name)
        # Frames seen in both directions combined.
        self.message_count = 0
        # Per-direction counters used to number the log lines.
        self.sent_count = 0
        self.received_count = 0

    def encode(self, frame: Frame) -> Frame:
        """Log an outgoing frame and return it unchanged."""
        self.message_count += 1
        self.sent_count += 1
        print(f"[{self.name} OUT #{self.sent_count}] {frame.opcode.name}: {len(frame.data)} bytes")
        return frame

    def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame:
        """Log an incoming frame and return it unchanged (max_size is ignored)."""
        self.message_count += 1
        self.received_count += 1
        print(f"[{self.name} IN #{self.received_count}] {frame.opcode.name}: {len(frame.data)} bytes")
        return frame

class ExtensionChain:
    """Apply several extensions as a single unit.

    Outgoing frames pass through the extensions in list order; incoming
    frames pass through them in reverse order, so each decode undoes the
    matching encode.
    """

    # Annotations are quoted so defining this class does not require
    # ``typing.List`` (never imported in this snippet) or the frame types.
    def __init__(self, extensions: "list[Extension]"):
        # Copy so later changes to the caller's list don't alter the chain.
        self.extensions = list(extensions)

    def encode(self, frame: "Frame") -> "Frame":
        """Apply all extensions to an outgoing frame, first to last."""
        for extension in self.extensions:
            frame = extension.encode(frame)
        return frame

    def decode(self, frame: "Frame", *, max_size: "int | None" = None) -> "Frame":
        """Apply all extensions to an incoming frame, last to first.

        ``max_size`` is forwarded to each extension only when given, so
        extensions whose ``decode`` takes just the frame keep working.
        """
        for extension in reversed(self.extensions):
            if max_size is None:
                frame = extension.decode(frame)
            else:
                frame = extension.decode(frame, max_size=max_size)
        return frame

def extension_chain_example():
    """Round-trip one TEXT frame through a logging -> compression -> logging chain."""
    from websockets.frames import Frame, Opcode

    # Loggers on both sides of the compressor show the size before and after.
    chain = ExtensionChain([
        LoggingExtension("pre-compress"),
        CompressionExtension(),
        LoggingExtension("post-compress"),
    ])

    # Test message
    test_message = "This is a test message that should be compressed! " * 10
    test_frame = Frame(Opcode.TEXT, test_message.encode())

    print("=== Extension Chain Test ===")
    print(f"Original message: {len(test_message)} characters")

    # Encode (outgoing)
    print("\nEncoding (outgoing):")
    encoded_frame = chain.encode(test_frame)
    print(f"Final encoded size: {len(encoded_frame.data)} bytes")

    # Decode (incoming)
    print("\nDecoding (incoming):")
    decoded_frame = chain.decode(encoded_frame)
    decoded_message = decoded_frame.data.decode()
    print(f"Final decoded size: {len(decoded_message)} characters")

    # Verify roundtrip
    print(f"\nRoundtrip successful: {decoded_message == test_message}")

# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    extension_chain_example()

Extension Performance Testing

import time
import asyncio
from websockets import connect, serve
from websockets.extensions.permessage_deflate import enable_client_permessage_deflate

async def extension_performance_test():
    """Measure round-trip time and payload throughput with and without compression.

    Requires an echo server on ws://localhost:8765. Each (configuration,
    message size) pair opens its own connection and exchanges the message
    ``iterations`` times. Only the exchanges are timed, not the handshake.
    """
    # Compression is controlled only through ``extensions``. Passing
    # compression=None to connect() disables websockets' built-in default,
    # which would otherwise enable permessage-deflate for the
    # "No Compression" row as well.
    configs = [
        ("No Compression", None),
        ("Default Compression", enable_client_permessage_deflate(None)),
    ]

    # ASCII payloads, so character counts equal byte counts.
    test_messages = [
        ("Small", "Hello" * 20),            # 100 bytes
        ("Medium", "Test message " * 100),  # 1,300 bytes
        ("Large", "Data " * 2000),          # 10,000 bytes
        ("Huge", "Content " * 10000),       # 80,000 bytes
    ]
    iterations = 10

    print("=== Extension Performance Test ===")
    print(f"{'Config':<20} {'Size':<8} {'Time (ms)':>10} {'Throughput (MB/s)':>18}")
    print("-" * 59)

    for config_name, extensions in configs:
        for msg_name, message in test_messages:
            try:
                async with connect(
                    "ws://localhost:8765",
                    extensions=extensions,
                    compression=None,
                    open_timeout=5,
                ) as websocket:
                    transferred = 0
                    start = time.perf_counter()
                    for _ in range(iterations):
                        await websocket.send(message)
                        response = await websocket.recv()
                        # Count the payload actually sent and received instead
                        # of assuming the echo has the same size as the message.
                        transferred += len(message) + len(response)
                    elapsed = time.perf_counter() - start
            except Exception as e:
                # Demo boundary: report and continue with the next case.
                print(f"{config_name:<20} {msg_name:<8} ERROR: {e}")
                continue

            ms_per_round_trip = elapsed * 1000 / iterations
            throughput = transferred / (1024 * 1024) / elapsed if elapsed > 0 else float("inf")
            print(f"{config_name:<20} {msg_name:<8} {ms_per_round_trip:>10.1f} {throughput:>18.2f}")

# Note: requires an echo server listening on ws://localhost:8765
# asyncio.run(extension_performance_test())

Extension Debugging Tools

from websockets.extensions.base import Extension
from websockets.frames import Frame
import json
import time

class DebugExtension(Extension):
    """Pass-through extension that prints every frame and collects traffic statistics."""

    def __init__(self, name: str = "debug"):
        super().__init__(name)
        self.stats = {
            "frames_encoded": 0,
            "frames_decoded": 0,
            "bytes_in": 0,
            "bytes_out": 0,
            # Wall-clock creation time (epoch seconds), kept for callers that read it.
            "start_time": time.time()
        }
        # Monotonic high-resolution reading used to measure runtime; unlike
        # time.time() it is unaffected by system clock adjustments.
        self._started = time.perf_counter()

    def _print_frame(self, action: str, frame: Frame) -> None:
        """Print the header fields and payload size of *frame* under *action*."""
        print(f"[DEBUG] {action} frame:")
        print(f"  Opcode: {frame.opcode.name}")
        print(f"  Size: {len(frame.data)} bytes")
        print(f"  FIN: {frame.fin}")
        # ':d' renders each bool as 0/1, e.g. "100" rather than "TrueFalseFalse".
        print(f"  RSV: {frame.rsv1:d}{frame.rsv2:d}{frame.rsv3:d}")

    def encode(self, frame: Frame) -> Frame:
        """Record and print an outgoing frame; return it unchanged."""
        self.stats["frames_encoded"] += 1
        self.stats["bytes_out"] += len(frame.data)
        self._print_frame("Encoding", frame)
        return frame

    def decode(self, frame: Frame, *, max_size: int | None = None) -> Frame:
        """Record and print an incoming frame; return it unchanged (max_size is ignored)."""
        self.stats["frames_decoded"] += 1
        self.stats["bytes_in"] += len(frame.data)
        self._print_frame("Decoding", frame)
        return frame

    def get_stats(self) -> dict:
        """Return the counters plus runtime and per-second rates.

        Rates are 0.0 when no measurable time has elapsed, instead of
        raising ZeroDivisionError.
        """
        runtime = time.perf_counter() - self._started
        total_frames = self.stats["frames_encoded"] + self.stats["frames_decoded"]
        total_bytes = self.stats["bytes_in"] + self.stats["bytes_out"]

        return {
            **self.stats,
            "runtime_seconds": runtime,
            "frames_per_second": total_frames / runtime if runtime > 0 else 0.0,
            "bytes_per_second": total_bytes / runtime if runtime > 0 else 0.0,
        }

    def print_stats(self):
        """Print formatted statistics."""
        stats = self.get_stats()

        print(f"\n=== {self.name} Extension Statistics ===")
        print(f"Runtime: {stats['runtime_seconds']:.2f} seconds")
        print(f"Frames encoded: {stats['frames_encoded']}")
        print(f"Frames decoded: {stats['frames_decoded']}")
        print(f"Bytes in: {stats['bytes_in']:,}")
        print(f"Bytes out: {stats['bytes_out']:,}")
        print(f"Frames/sec: {stats['frames_per_second']:.2f}")
        print(f"Bytes/sec: {stats['bytes_per_second']:,.0f}")

def debug_extension_example():
    """Push a few sample frames through DebugExtension in both directions."""
    from websockets.frames import Frame, Opcode

    debug_ext = DebugExtension("test-debug")

    # One frame of each common type, data and control.
    test_frames = [
        Frame(Opcode.TEXT, b"Hello, World!"),
        Frame(Opcode.BINARY, b"\x00\x01\x02\x03"),
        Frame(Opcode.PING, b"ping"),
        Frame(Opcode.PONG, b"pong")
    ]

    print("=== Debug Extension Example ===")

    for frame in test_frames:
        # Simulate sending the frame, then receiving it back.
        debug_ext.decode(debug_ext.encode(frame))

    debug_ext.print_stats()

# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    debug_extension_example()

Install with Tessl CLI

npx tessl i tessl/pypi-websockets

docs

asyncio-client.md

asyncio-server.md

data-structures.md

exceptions.md

extensions.md

index.md

protocol.md

routing.md

sync-client.md

sync-server.md

tile.json