An implementation of the WebSocket Protocol (RFC 6455 & 7692)
WebSocket extension system with built-in per-message deflate compression support (RFC 7692) and an extensible framework for custom extensions. Extensions enhance the protocol with features such as compression, encryption, and custom behavior.
# Extension base classes
from websockets.extensions import Extension, ClientExtensionFactory, ServerExtensionFactory
# Per-message deflate extension
from websockets.extensions.permessage_deflate import (
PerMessageDeflate,
ClientPerMessageDeflateFactory,
ServerPerMessageDeflateFactory,
enable_client_permessage_deflate,
enable_server_permessage_deflate
)
Foundation classes for implementing WebSocket extensions with proper negotiation and frame processing.
class Extension:
    """
    Base class for WebSocket extensions.

    An extension transforms frames on their way out (``encode``) and on their
    way in (``decode``). This is how features such as compression, encryption
    or custom protocols are layered onto a connection.
    """

    def __init__(self, name: str):
        """
        Initialize extension.

        Parameters:
        - name: Extension name used during negotiation
        """
        self.name = name

    def encode(self, frame: Frame) -> Frame:
        """
        Process an outgoing frame before transmission.

        The base implementation is a pass-through. Subclasses override it to
        transform outgoing frames. They must always return a frame: the
        original one when nothing needs to change.
        """
        return frame

    def decode(self, frame: Frame) -> Frame:
        """
        Process an incoming frame after reception.

        The base implementation is a pass-through. Subclasses override it to
        undo their transformation. They must always return a frame: the
        original one when nothing needs to change.
        """
        return frame

    def __repr__(self) -> str:
        """Return ``ClassName(name='...')`` for debugging."""
        cls_name = type(self).__name__
        return f"{cls_name}(name={self.name!r})"
class ClientExtensionFactory:
    """
    Factory for creating client-side WebSocket extensions.

    Handles extension negotiation from the client's perspective: it supplies
    the parameters to offer and turns the server's answer into an extension.
    """

    def __init__(
        self,
        name: str,
        parameters: List[ExtensionParameter] | None = None,
    ):
        """
        Initialize client extension factory.

        Parameters:
        - name: Extension name
        - parameters: Extension parameters to offer during negotiation. Any
          iterable is accepted. It is copied, so later changes to the
          caller's list do not affect this factory. None means no parameters.
        """
        self.name = name
        self.parameters = list(parameters or ())

    def get_request_params(self) -> List[ExtensionParameter]:
        """
        Get parameters for the extension negotiation request.

        Returns:
        List[ExtensionParameter]: Parameters to send in the
        Sec-WebSocket-Extensions header
        """
        return self.parameters

    def process_response_params(
        self,
        parameters: List[ExtensionParameter],
    ) -> Extension:
        """
        Process the server response and create an extension instance.

        Parameters:
        - parameters: Parameters from the server's Sec-WebSocket-Extensions header

        Returns:
        Extension: Configured extension instance

        Raises:
        - NegotiationError: If parameters are invalid or incompatible
          (raised by subclasses)
        - NotImplementedError: Always, in this base class
        """
        raise NotImplementedError("Subclasses must implement process_response_params")
class ServerExtensionFactory:
    """
    Factory for creating server-side WebSocket extensions.

    Handles extension negotiation from the server's perspective: it validates
    the client's offered parameters and builds the extension plus the
    parameters to send back.
    """

    def __init__(self, name: str):
        """
        Initialize server extension factory.

        Parameters:
        - name: Extension name
        """
        self.name = name

    def process_request_params(
        self,
        parameters: List[ExtensionParameter],
    ) -> Tuple[List[ExtensionParameter], Extension]:
        """
        Process the client request and create an extension instance.

        Parameters:
        - parameters: Parameters from the client's Sec-WebSocket-Extensions header

        Returns:
        Tuple[List[ExtensionParameter], Extension]:
        - Response parameters for the Sec-WebSocket-Extensions header
        - Configured extension instance

        Raises:
        - NegotiationError: If parameters are invalid or unsupported
          (raised by subclasses)
        - NotImplementedError: Always, in this base class
        """
        raise NotImplementedError("Subclasses must implement process_request_params")


# Built-in implementation of per-message deflate compression extension (RFC 7692).
class ClientPerMessageDeflateFactory(ClientExtensionFactory):
    """
    Client factory for per-message deflate compression extension.

    Negotiates compression parameters with the server and creates the
    compression extension used to reduce WebSocket message size.

    Only settings that differ from the RFC 7692 defaults are offered in the
    handshake. The defaults are 15-bit windows and context takeover allowed,
    and they apply whenever a parameter is absent.
    """

    def __init__(
        self,
        server_max_window_bits: int = 15,
        client_max_window_bits: int = 15,
        server_no_context_takeover: bool = False,
        client_no_context_takeover: bool = False,
        compress_settings: Dict | None = None,
    ):
        """
        Initialize client per-message deflate factory.

        Parameters:
        - server_max_window_bits: Maximum server LZ77 sliding window size (8-15)
        - client_max_window_bits: Maximum client LZ77 sliding window size (8-15)
        - server_no_context_takeover: Disable server context reuse between messages
        - client_no_context_takeover: Disable client context reuse between messages
        - compress_settings: Additional compression settings. The dict is
          copied; None means no extra settings.

        Raises:
        - ValueError: If a window size is outside the 8-15 range

        Note:
        Smaller window sizes reduce memory usage but may decrease compression ratio.
        No context takeover reduces compression efficiency but saves memory.
        """
        for option, bits in (
            ("server_max_window_bits", server_max_window_bits),
            ("client_max_window_bits", client_max_window_bits),
        ):
            if not 8 <= bits <= 15:
                raise ValueError(f"{option} must be between 8 and 15, got {bits!r}")

        # Offer only the parameters that differ from the RFC 7692 defaults.
        # Flag parameters carry no value, hence None.
        offer = []
        if server_no_context_takeover:
            offer.append(("server_no_context_takeover", None))
        if client_no_context_takeover:
            offer.append(("client_no_context_takeover", None))
        if server_max_window_bits != 15:
            offer.append(("server_max_window_bits", str(server_max_window_bits)))
        if client_max_window_bits != 15:
            offer.append(("client_max_window_bits", str(client_max_window_bits)))
        super().__init__("permessage-deflate", offer)

        self.server_max_window_bits = server_max_window_bits
        self.client_max_window_bits = client_max_window_bits
        self.server_no_context_takeover = server_no_context_takeover
        self.client_no_context_takeover = client_no_context_takeover
        # Copy so later changes to the caller's dict don't leak into this factory.
        self.compress_settings = dict(compress_settings or {})
class ServerPerMessageDeflateFactory(ServerExtensionFactory):
    """
    Server factory for per-message deflate compression extension.

    Processes client compression requests and creates the compression
    extension with the negotiated parameters.
    """

    def __init__(
        self,
        server_max_window_bits: int = 15,
        client_max_window_bits: int = 15,
        server_no_context_takeover: bool = False,
        client_no_context_takeover: bool = False,
        compress_settings: Dict | None = None,
    ):
        """
        Initialize server per-message deflate factory.

        Parameters:
        - server_max_window_bits: Maximum server LZ77 sliding window size (8-15)
        - client_max_window_bits: Maximum client LZ77 sliding window size (8-15)
        - server_no_context_takeover: Disable server context reuse between messages
        - client_no_context_takeover: Disable client context reuse between messages
        - compress_settings: Additional compression settings. The dict is
          copied; None means no extra settings.

        Raises:
        - ValueError: If a window size is outside the 8-15 range
        """
        for option, bits in (
            ("server_max_window_bits", server_max_window_bits),
            ("client_max_window_bits", client_max_window_bits),
        ):
            if not 8 <= bits <= 15:
                raise ValueError(f"{option} must be between 8 and 15, got {bits!r}")

        super().__init__("permessage-deflate")

        self.server_max_window_bits = server_max_window_bits
        self.client_max_window_bits = client_max_window_bits
        self.server_no_context_takeover = server_no_context_takeover
        self.client_no_context_takeover = client_no_context_takeover
        # Copy so later changes to the caller's dict don't leak into this factory.
        self.compress_settings = dict(compress_settings or {})
class PerMessageDeflate(Extension):
    """
    Per-message deflate compression extension (RFC 7692): interface excerpt.

    Intended to compress outgoing data messages and decompress incoming ones
    with the negotiated deflate parameters. The compression logic itself is
    not reproduced here. ``encode`` and ``decode`` therefore raise
    NotImplementedError rather than silently returning None.
    """

    def __init__(
        self,
        is_server: bool,
        server_max_window_bits: int = 15,
        client_max_window_bits: int = 15,
        server_no_context_takeover: bool = False,
        client_no_context_takeover: bool = False,
        compress_settings: Dict | None = None,
    ):
        """
        Initialize per-message deflate extension.

        Parameters:
        - is_server: True if this is the server-side extension
        - server_max_window_bits: Negotiated server LZ77 window size
        - client_max_window_bits: Negotiated client LZ77 window size
        - server_no_context_takeover: Server resets its context between messages
        - client_no_context_takeover: Client resets its context between messages
        - compress_settings: Additional compression settings. The dict is
          copied; None means no extra settings.
        """
        # Set self.name through the base class so __repr__ works.
        super().__init__("permessage-deflate")
        self.is_server = is_server
        self.server_max_window_bits = server_max_window_bits
        self.client_max_window_bits = client_max_window_bits
        self.server_no_context_takeover = server_no_context_takeover
        self.client_no_context_takeover = client_no_context_takeover
        self.compress_settings = dict(compress_settings or {})

    def encode(self, frame: Frame) -> Frame:
        """
        Compress an outgoing frame if it's a data frame.

        Parameters:
        - frame: Original frame

        Returns:
        Frame: Compressed frame with RSV1 bit set if compressed

        Raises:
        - NotImplementedError: The compression logic is not part of this excerpt
        """
        raise NotImplementedError("PerMessageDeflate.encode is not implemented in this excerpt")

    def decode(self, frame: Frame) -> Frame:
        """
        Decompress an incoming frame if the RSV1 bit is set.

        Parameters:
        - frame: Compressed frame

        Returns:
        Frame: Decompressed frame

        Raises:
        - ProtocolError: If decompression fails (in a full implementation)
        - NotImplementedError: The decompression logic is not part of this excerpt
        """
        raise NotImplementedError("PerMessageDeflate.decode is not implemented in this excerpt")


# Helper functions for enabling compression extensions on connections.
def enable_client_permessage_deflate(
    extensions: Sequence[ClientExtensionFactory] | None
) -> Sequence[ClientExtensionFactory]:
    """
    Enable Per-Message Deflate with default settings in client extensions.

    If the extension is already present, perhaps with non-default settings,
    the configuration isn't changed and ``extensions`` is returned as is.
    Otherwise a new list is returned. The caller's sequence is never mutated.

    Parameters:
    - extensions: Existing client extension factories, or None for none

    Returns:
    Sequence[ClientExtensionFactory]: Extensions including a deflate factory

    Usage:
    extensions = enable_client_permessage_deflate(None)
    websockets.connect("ws://localhost:8765", extensions=extensions)
    # Or add to existing extensions:
    existing_extensions = [custom_extension_factory]
    extensions = enable_client_permessage_deflate(existing_extensions)
    """
    if extensions is None:
        extensions = []
    # Match by negotiated name so a factory with custom settings counts as present.
    if any(
        getattr(factory, "name", None) == "permessage-deflate"
        for factory in extensions
    ):
        return extensions
    return [*extensions, ClientPerMessageDeflateFactory()]
def enable_server_permessage_deflate(
    extensions: Sequence[ServerExtensionFactory] | None
) -> Sequence[ServerExtensionFactory]:
    """
    Enable Per-Message Deflate with default settings in server extensions.

    If the extension is already present, perhaps with non-default settings,
    the configuration isn't changed and ``extensions`` is returned as is.
    Otherwise a new list is returned. The caller's sequence is never mutated.

    Parameters:
    - extensions: Existing server extension factories, or None for none

    Returns:
    Sequence[ServerExtensionFactory]: Extensions including a deflate factory

    Usage:
    extensions = enable_server_permessage_deflate(None)
    websockets.serve(handler, "localhost", 8765, extensions=extensions)
    # Or add to existing extensions:
    existing_extensions = [custom_extension_factory]
    extensions = enable_server_permessage_deflate(existing_extensions)
    """
    if extensions is None:
        extensions = []
    # Match by negotiated name so a factory with custom settings counts as present.
    if any(
        getattr(factory, "name", None) == "permessage-deflate"
        for factory in extensions
    ):
        return extensions
    return [*extensions, ServerPerMessageDeflateFactory()]


import asyncio
from websockets import connect, serve
from websockets.extensions.permessage_deflate import enable_client_permessage_deflate, enable_server_permessage_deflate
async def compression_client():
    """Client with compression enabled: send one large text message and report the reply size."""
    # Passing None starts from an empty extension list with default deflate settings.
    client_extensions = enable_client_permessage_deflate(None)
    async with connect("ws://localhost:8765", extensions=client_extensions) as ws:
        payload = "Hello, WebSocket! " * 1000  # 18,000 characters; compressed on the wire
        await ws.send(payload)
        reply = await ws.recv()
        print(f"Received: {len(reply)} characters")
async def compression_server():
    """Serve a compressing echo endpoint on localhost:8765 until cancelled."""
    server_extensions = enable_server_permessage_deflate(None)

    async def handler(websocket):
        # Iteration ends when the client closes the connection.
        async for incoming in websocket:
            print(f"Received compressed message: {len(incoming)} chars")
            reply = f"Echo: {incoming}"
            await websocket.send(reply)  # compressed again on the way out

    async with serve(handler, "localhost", 8765, extensions=server_extensions):
        print("Compression server started")
        # A bare Future never completes, so the server runs until the task is cancelled.
        await asyncio.Future()
# Run server in one terminal, client in another
# asyncio.run(compression_server())
# asyncio.run(compression_client())
import asyncio
from websockets import connect, serve
from websockets.extensions import enable_client_permessage_deflate, enable_server_permessage_deflate
async def advanced_compression_example():
    """Compare default and tuned deflate settings across several message sizes."""
    standard_extensions = enable_client_permessage_deflate(None)

    from websockets.extensions.permessage_deflate import ClientPerMessageDeflateFactory
    # 12-bit windows trade some compression ratio for lower memory use.
    tuned_factory = ClientPerMessageDeflateFactory(
        server_max_window_bits=12,
        client_max_window_bits=12,
        compress_settings={"memLevel": 6},
    )
    configurations = [
        ("Standard Compression", standard_extensions),
        ("Custom Compression", [tuned_factory]),
    ]

    # Roughly 13 B, 750 B, 7 KB and 65 KB of text.
    payloads = [
        "Small message",
        "Medium message " * 50,
        "Large message " * 500,
        "Huge message " * 5000,
    ]

    for label, extensions in configurations:
        print(f"\n=== {label} Configuration ===")
        try:
            async with connect(
                "ws://localhost:8765",
                extensions=extensions,
                open_timeout=5,
            ) as websocket:
                for number, payload in enumerate(payloads, start=1):
                    await websocket.send(payload)
                    reply = await websocket.recv()
                    print(f"Message {number}: {len(payload)} -> {len(reply)} chars")
        except Exception as exc:
            # Demo boundary: report the failure and move on to the next configuration.
            print(f"Configuration failed: {exc}")
# asyncio.run(advanced_compression_example())
from websockets.extensions.base import Extension, ClientExtensionFactory, ServerExtensionFactory
from websockets.datastructures import Frame, Opcode
from websockets.exceptions import NegotiationError
import base64
import hashlib
class SimpleEncryptionExtension(Extension):
    """
    Simple XOR encryption extension (for demonstration only).

    XOR with a short repeating key is trivially breakable. Never use it to
    protect real data. Only TEXT and BINARY frames are transformed; all
    other frames, including continuation frames of fragmented messages,
    pass through untouched.
    """

    def __init__(self, key: bytes):
        """
        Parameters:
        - key: Non-empty XOR key, repeated over each frame's payload

        Raises:
        - ValueError: If key is empty (the keystream would be undefined)
        """
        if not key:
            raise ValueError("key must be non-empty")
        super().__init__("simple-encrypt")
        self.key = key

    def _xor(self, data: bytes) -> bytes:
        """Return *data* XORed with the repeating key. The operation is its own inverse."""
        key = self.key
        key_len = len(key)
        return bytes(b ^ key[i % key_len] for i, b in enumerate(data))

    def encode(self, frame: Frame) -> Frame:
        """Encrypt outgoing data frames."""
        if frame.opcode in (Opcode.TEXT, Opcode.BINARY):
            return frame._replace(data=self._xor(frame.data))
        return frame

    def decode(self, frame: Frame) -> Frame:
        """Decrypt incoming data frames. XOR is symmetric, so this mirrors encode."""
        if frame.opcode in (Opcode.TEXT, Opcode.BINARY):
            return frame._replace(data=self._xor(frame.data))
        return frame
class SimpleEncryptionClientFactory(ClientExtensionFactory):
    """
    Client factory for simple encryption extension.

    NOTE: the derived key is sent base64-encoded in the handshake, so anyone
    who can read the handshake can decrypt the traffic. Demonstration only.
    """

    def __init__(self, password: str):
        # The first 16 bytes of SHA-256(password) form the XOR key.
        key = hashlib.sha256(password.encode()).digest()[:16]
        super().__init__("simple-encrypt", [("key", base64.b64encode(key).decode())])
        self.key = key

    def process_response_params(self, parameters):
        """
        Validate the server's response and build the extension.

        The server is expected to echo the key back as a
        ``("key", <base64>)`` parameter. A missing or different key is rejected.

        Raises:
        - NegotiationError: If the server's key parameter is missing or differs
        """
        expected = base64.b64encode(self.key).decode()
        server_key = next((p[1] for p in parameters if p[0] == "key"), None)
        if server_key != expected:
            raise NegotiationError("Server encryption key missing or mismatched")
        return SimpleEncryptionExtension(self.key)
class SimpleEncryptionServerFactory(ServerExtensionFactory):
    """Server factory for simple encryption extension."""

    def __init__(self, password: str):
        super().__init__("simple-encrypt")
        # The first 16 bytes of SHA-256(password) form the XOR key.
        self.key = hashlib.sha256(password.encode()).digest()[:16]

    def process_request_params(self, parameters):
        """
        Validate the client's offered key and build the extension.

        Parameters:
        - parameters: (name, value) pairs from the client's offer

        Returns:
        (response_params, extension): the key echoed back for the client to
        verify, plus the configured SimpleEncryptionExtension.

        Raises:
        - NegotiationError: If the key is missing, not decodable, or wrong
        """
        import hmac

        client_key_param = next((p for p in parameters if p[0] == "key"), None)
        # A flag-style parameter without a value is treated as missing.
        if client_key_param is None or client_key_param[1] is None:
            raise NegotiationError("Missing encryption key")
        try:
            client_key = base64.b64decode(client_key_param[1])
        except (TypeError, ValueError) as exc:
            # Malformed client input surfaces as a negotiation failure rather
            # than an unrelated binascii/ValueError.
            raise NegotiationError("Invalid encryption key") from exc
        # Constant-time comparison avoids leaking how much of the key matched.
        if not hmac.compare_digest(client_key, self.key):
            raise NegotiationError("Invalid encryption key")

        response_params = [("key", base64.b64encode(self.key).decode())]
        extension = SimpleEncryptionExtension(self.key)
        return response_params, extension
async def custom_extension_example():
    """
    Run a local server and a client that negotiate the XOR demo extension.

    The server is started in-process with ``serve`` so the client has a peer
    that actually uses ``server_extensions``. It shuts down when the
    ``async with`` block exits.
    """
    from websockets import serve, connect

    password = "shared-secret-password"

    async def encrypted_handler(websocket):
        async for message in websocket:
            print(f"Server received (decrypted): {message}")
            await websocket.send(f"Encrypted echo: {message}")

    server_extensions = [SimpleEncryptionServerFactory(password)]
    client_extensions = [SimpleEncryptionClientFactory(password)]

    try:
        print("Starting server with custom encryption extension...")
        async with serve(
            encrypted_handler,
            "localhost",
            8765,
            extensions=server_extensions,
        ):
            async with connect(
                "ws://localhost:8765",
                extensions=client_extensions,
            ) as websocket:
                await websocket.send("This message is encrypted!")
                response = await websocket.recv()
                print(f"Client received (decrypted): {response}")
    except Exception as e:
        # Demo boundary: report any startup, negotiation or I/O failure.
        print(f"Custom extension error: {e}")
# Note: This is a demonstration. Real encryption would use proper cryptography
# asyncio.run(custom_extension_example())
from websockets.extensions.base import Extension
from websockets.datastructures import Frame, Opcode
import zlib
import time
class CompressionExtension(Extension):
    """
    Simple compression extension (demo; not the RFC 7692 wire format).

    Outgoing TEXT frames longer than ``min_size`` bytes are compressed and
    marked with RSV1. Incoming TEXT frames with RSV1 set are decompressed.

    One zlib stream with the default zlib wrapper is shared by all outgoing
    frames, and another by all incoming frames. Frames must therefore be
    decoded in the order they were encoded.
    """

    def __init__(self, min_size: int = 100):
        """
        Parameters:
        - min_size: TEXT payloads of at most this many bytes are sent
          uncompressed (default 100)
        """
        super().__init__("simple-compress")
        self.min_size = min_size
        self.compressor = zlib.compressobj()
        self.decompressor = zlib.decompressobj()

    def encode(self, frame: Frame) -> Frame:
        """Compress large outgoing TEXT frames and set RSV1 on them."""
        if frame.opcode == Opcode.TEXT and len(frame.data) > self.min_size:
            compressed = self.compressor.compress(frame.data)
            # Z_SYNC_FLUSH emits everything buffered so far, so the receiver can
            # decode this frame on its own while the shared dictionary is kept.
            compressed += self.compressor.flush(zlib.Z_SYNC_FLUSH)
            return frame._replace(data=compressed, rsv1=True)
        return frame

    def decode(self, frame: Frame) -> Frame:
        """Decompress incoming TEXT frames flagged with RSV1 and clear the flag."""
        if frame.rsv1 and frame.opcode == Opcode.TEXT:
            decompressed = self.decompressor.decompress(frame.data)
            return frame._replace(data=decompressed, rsv1=False)
        return frame
class LoggingExtension(Extension):
    """
    Debugging extension that prints a one-line summary of every frame.

    Frames pass through unchanged. A single counter is shared by both
    directions, so OUT and IN lines draw their numbers from one sequence.
    """

    def __init__(self, name: str = "logger"):
        super().__init__(name)
        self.message_count = 0

    def _tick(self) -> None:
        """Advance the shared frame counter."""
        self.message_count += 1

    def encode(self, frame: Frame) -> Frame:
        """Log an outgoing frame and return it untouched."""
        self._tick()
        print(f"[OUT #{self.message_count}] {frame.opcode.name}: {len(frame.data)} bytes")
        return frame

    def decode(self, frame: Frame) -> Frame:
        """Log an incoming frame and return it untouched."""
        self._tick()
        print(f"[IN #{self.message_count}] {frame.opcode.name}: {len(frame.data)} bytes")
        return frame
class ExtensionChain:
    """
    Apply several extensions in sequence.

    Outgoing frames pass through the extensions first-to-last. Incoming
    frames pass last-to-first, so each extension undoes its transformation
    only after the extensions layered on top of it have undone theirs.
    """

    def __init__(self, extensions: list[Extension]):
        """
        Parameters:
        - extensions: Extensions in outgoing order. Any iterable is accepted.
          It is copied, so later changes to the caller's list don't affect
          the chain.
        """
        # Builtin `list[...]` is used because `typing.List` is not imported here.
        self.extensions = list(extensions)

    def encode(self, frame: Frame) -> Frame:
        """Apply all extensions to an outgoing frame, in order."""
        for extension in self.extensions:
            frame = extension.encode(frame)
        return frame

    def decode(self, frame: Frame) -> Frame:
        """Apply all extensions to an incoming frame, in reverse order."""
        for extension in reversed(self.extensions):
            frame = extension.decode(frame)
        return frame
def extension_chain_example():
    """Round-trip one text frame through logging and compression extensions."""
    from websockets.datastructures import Frame, Opcode

    chain = ExtensionChain([
        LoggingExtension("pre-compress"),
        CompressionExtension(),
        LoggingExtension("post-compress"),
    ])

    original_text = "This is a test message that should be compressed! " * 10
    outgoing = Frame(Opcode.TEXT, original_text.encode())

    print("=== Extension Chain Test ===")
    print(f"Original message: {len(original_text)} characters")

    print("\nEncoding (outgoing):")
    wire_frame = chain.encode(outgoing)
    print(f"Final encoded size: {len(wire_frame.data)} bytes")

    print("\nDecoding (incoming):")
    restored_text = chain.decode(wire_frame).data.decode()
    print(f"Final decoded size: {len(restored_text)} characters")

    # The decoded text should match the original exactly.
    print(f"\nRoundtrip successful: {restored_text == original_text}")
extension_chain_example()
import time
import asyncio
from websockets import connect, serve
from websockets.extensions import enable_client_permessage_deflate
async def extension_performance_test():
    """
    Measure echo round-trip time with and without per-message deflate.

    Expects an echo server on ws://localhost:8765. For each configuration
    and message size, a connection is opened and ``iterations`` send/recv
    round trips are timed with ``time.perf_counter``. The opening handshake
    is excluded so the numbers reflect message handling, not connection
    setup.
    """
    configs = [
        ("No Compression", None),
        ("Default Compression", enable_client_permessage_deflate(None)),
    ]
    test_messages = [
        ("Small", "Hello" * 20),  # 100 bytes
        ("Medium", "Test message " * 100),  # 1.3 KB
        ("Large", "Data " * 2000),  # 10 KB
        ("Huge", "Content " * 10000),  # 80 KB
    ]
    iterations = 10

    print("=== Extension Performance Test ===")
    # Header widths match the row format below (total width 58).
    print(f"{'Config':<20} {'Size':<8} {'Time (ms)':>10} {'Throughput (MB/s)':>17}")
    print("-" * 58)

    for config_name, extensions in configs:
        for msg_name, message in test_messages:
            try:
                async with connect(
                    "ws://localhost:8765",
                    extensions=extensions,
                    open_timeout=5,
                ) as websocket:
                    # Start the clock only after the handshake has completed.
                    start = time.perf_counter()
                    for _ in range(iterations):
                        await websocket.send(message)
                        await websocket.recv()
                    elapsed = time.perf_counter() - start
            except Exception as e:
                print(f"{config_name:<20} {msg_name:<8} ERROR: {e}")
                continue

            per_iteration_ms = elapsed * 1000 / iterations
            # Payload bytes in both directions (send + receive).
            data_size = len(message.encode()) * iterations * 2
            throughput = (
                data_size / (1024 * 1024) / elapsed if elapsed > 0 else float("inf")
            )
            print(f"{config_name:<20} {msg_name:<8} {per_iteration_ms:>10.1f} {throughput:>17.2f}")
# Note: Run with a test server
# asyncio.run(extension_performance_test())
from websockets.extensions.base import Extension
from websockets.datastructures import Frame
import json
import time
class DebugExtension(Extension):
    """Extension that prints every frame it sees and keeps simple traffic counters."""

    def __init__(self, name: str = "debug"):
        super().__init__(name)
        self.stats = {
            "frames_encoded": 0,
            "frames_decoded": 0,
            "bytes_in": 0,
            "bytes_out": 0,
            "start_time": time.time(),
        }

    @staticmethod
    def _print_frame(action: str, frame: Frame) -> None:
        """Print a multi-line summary of *frame*. *action* is "Encoding" or "Decoding"."""
        # Render the RSV flags as bits ("100"), not as concatenated bools.
        rsv_bits = f"{int(frame.rsv1)}{int(frame.rsv2)}{int(frame.rsv3)}"
        print(f"[DEBUG] {action} frame:")
        print(f" Opcode: {frame.opcode.name}")
        print(f" Size: {len(frame.data)} bytes")
        print(f" FIN: {frame.fin}")
        print(f" RSV: {rsv_bits}")

    def encode(self, frame: Frame) -> Frame:
        """Count and print an outgoing frame, returning it unchanged."""
        self.stats["frames_encoded"] += 1
        self.stats["bytes_out"] += len(frame.data)
        self._print_frame("Encoding", frame)
        return frame

    def decode(self, frame: Frame) -> Frame:
        """Count and print an incoming frame, returning it unchanged."""
        self.stats["frames_decoded"] += 1
        self.stats["bytes_in"] += len(frame.data)
        self._print_frame("Decoding", frame)
        return frame

    def get_stats(self) -> dict:
        """
        Get extension statistics.

        Returns:
        dict: All counters in ``self.stats`` plus ``runtime_seconds``,
        ``frames_per_second`` and ``bytes_per_second``. The rates are 0.0
        when no positive time has elapsed, because time.time() may not
        advance between close calls. This avoids raising ZeroDivisionError.
        """
        runtime = time.time() - self.stats["start_time"]
        frame_total = self.stats["frames_encoded"] + self.stats["frames_decoded"]
        byte_total = self.stats["bytes_in"] + self.stats["bytes_out"]
        if runtime > 0:
            frames_per_second = frame_total / runtime
            bytes_per_second = byte_total / runtime
        else:
            frames_per_second = bytes_per_second = 0.0
        return {
            **self.stats,
            "runtime_seconds": runtime,
            "frames_per_second": frames_per_second,
            "bytes_per_second": bytes_per_second,
        }

    def print_stats(self):
        """Print formatted statistics."""
        stats = self.get_stats()
        print(f"\n=== {self.name} Extension Statistics ===")
        print(f"Runtime: {stats['runtime_seconds']:.2f} seconds")
        print(f"Frames encoded: {stats['frames_encoded']}")
        print(f"Frames decoded: {stats['frames_decoded']}")
        print(f"Bytes in: {stats['bytes_in']:,}")
        print(f"Bytes out: {stats['bytes_out']:,}")
        print(f"Frames/sec: {stats['frames_per_second']:.2f}")
        print(f"Bytes/sec: {stats['bytes_per_second']:,.0f}")
def debug_extension_example():
    """Push a few sample frames through DebugExtension, then print its statistics."""
    from websockets.datastructures import Frame, Opcode

    debug_ext = DebugExtension("test-debug")
    samples = (
        (Opcode.TEXT, b"Hello, World!"),
        (Opcode.BINARY, b"\x00\x01\x02\x03"),
        (Opcode.PING, b"ping"),
        (Opcode.PONG, b"pong"),
    )
    test_frames = [Frame(opcode, payload) for opcode, payload in samples]

    print("=== Debug Extension Example ===")
    for frame in test_frames:
        # Simulate the outgoing path, then feed the result back as incoming.
        debug_ext.decode(debug_ext.encode(frame))

    debug_ext.print_stats()
debug_extension_example()
# Install with Tessl CLI
npx tessl i tessl/pypi-websockets