0
# Extensions
1
2
WebSocket extension system with built-in per-message deflate compression support (RFC 7692) and extensible framework for custom extensions. Enables protocol enhancement for compression, encryption, and custom features.
3
4
## Core Imports
5
6
```python
7
# Extension base classes
8
from websockets.extensions import Extension, ClientExtensionFactory, ServerExtensionFactory
9
10
# Per-message deflate extension
11
from websockets.extensions.permessage_deflate import (
12
PerMessageDeflate,
13
ClientPerMessageDeflateFactory,
14
ServerPerMessageDeflateFactory,
15
enable_client_permessage_deflate,
16
enable_server_permessage_deflate
17
)
18
```
19
20
## Capabilities
21
22
### Extension Base Classes
23
24
Foundation classes for implementing WebSocket extensions with proper negotiation and frame processing.
25
26
```python { .api }
27
class Extension:
    """
    Root class that every WebSocket extension derives from.

    An extension gets a chance to rewrite each frame on its way out
    (``encode``) and on its way in (``decode``), which is how features
    such as compression, encryption, or custom protocols are layered on.
    The default implementation is a pass-through in both directions.
    """

    def __init__(self, name: str):
        """
        Store the extension name.

        Parameters:
        - name: Extension name for negotiation
        """
        self.name = name

    def encode(self, frame: Frame) -> Frame:
        """
        Hook applied to a frame before it is transmitted.

        Parameters:
        - frame: Original WebSocket frame

        Returns:
        Frame: The frame to transmit; this base version hands back
        ``frame`` itself, unchanged.

        Note:
        Subclasses override this method to alter outgoing frames.
        """
        return frame

    def decode(self, frame: Frame) -> Frame:
        """
        Hook applied to a frame after it has been received.

        Parameters:
        - frame: Received WebSocket frame

        Returns:
        Frame: The frame to hand to the application; this base version
        hands back ``frame`` itself, unchanged.

        Note:
        Subclasses override this method to alter incoming frames.
        """
        return frame

    def __repr__(self) -> str:
        """Return ``ClassName(name='...')`` for debugging output."""
        class_name = self.__class__.__name__
        return f"{class_name}(name={self.name!r})"
79
80
class ClientExtensionFactory:
    """
    Factory for creating client-side WebSocket extensions.

    Handles extension negotiation from client perspective,
    including parameter processing and extension instantiation.
    """

    def __init__(self, name: str, parameters: List[ExtensionParameter] | None = None):
        """
        Initialize client extension factory.

        Parameters:
        - name: Extension name
        - parameters: List of extension parameters for negotiation;
          ``None`` (the default) or an empty list means no parameters
        """
        self.name = name
        # A falsy argument is replaced by a fresh list so that instances
        # never share a default container.
        self.parameters = parameters or []

    def get_request_params(self) -> List[ExtensionParameter]:
        """
        Get parameters for extension negotiation request.

        Returns:
        List[ExtensionParameter]: Parameters to send in Sec-WebSocket-Extensions header
        """
        return self.parameters

    def process_response_params(
        self,
        parameters: List[ExtensionParameter]
    ) -> Extension:
        """
        Process server response and create extension instance.

        Parameters:
        - parameters: Parameters from server's Sec-WebSocket-Extensions header

        Returns:
        Extension: Configured extension instance

        Raises:
        - NegotiationError: If parameters are invalid or incompatible
        - NotImplementedError: Always, in this base class; subclasses
          must override the method
        """
        raise NotImplementedError("Subclasses must implement process_response_params")
125
126
class ServerExtensionFactory:
    """
    Builds server-side WebSocket extensions.

    Covers the server half of extension negotiation: checking what the
    client asked for and instantiating the matching extension.
    """

    def __init__(self, name: str):
        """
        Store the extension name.

        Parameters:
        - name: Extension name
        """
        self.name = name

    def process_request_params(
        self,
        parameters: List[ExtensionParameter]
    ) -> Tuple[List[ExtensionParameter], Extension]:
        """
        Turn a client's request into a response and an extension.

        Parameters:
        - parameters: Parameters from client's Sec-WebSocket-Extensions header

        Returns:
        Tuple[List[ExtensionParameter], Extension]:
        - Response parameters for Sec-WebSocket-Extensions header
        - Configured extension instance

        Raises:
        - NegotiationError: If parameters are invalid or unsupported
        - NotImplementedError: Always, in this base class
        """
        message = "Subclasses must implement process_request_params"
        raise NotImplementedError(message)
162
```
163
164
### Per-Message Deflate Compression
165
166
Built-in implementation of per-message deflate compression extension (RFC 7692).
167
168
```python { .api }
169
class ClientPerMessageDeflateFactory(ClientExtensionFactory):
    """
    Client factory for per-message deflate compression extension.

    Negotiates compression parameters with server and creates
    compression extension for reducing WebSocket message size.
    """

    def __init__(
        self,
        server_max_window_bits: int = 15,
        client_max_window_bits: int = 15,
        server_no_context_takeover: bool = False,
        client_no_context_takeover: bool = False,
        compress_settings: Dict | None = None,
    ):
        """
        Initialize client per-message deflate factory.

        Parameters:
        - server_max_window_bits: Maximum server LZ77 sliding window size (8-15)
        - client_max_window_bits: Maximum client LZ77 sliding window size (8-15)
        - server_no_context_takeover: Disable server context reuse between messages
        - client_no_context_takeover: Disable client context reuse between messages
        - compress_settings: Additional compression settings, or None for none

        Note:
        Smaller window sizes reduce memory usage but may decrease compression ratio.
        No context takeover reduces compression efficiency but saves memory.
        """
199
200
class ServerPerMessageDeflateFactory(ServerExtensionFactory):
    """
    Server factory for per-message deflate compression extension.

    Processes client compression requests and creates compression
    extension with negotiated parameters.
    """

    def __init__(
        self,
        server_max_window_bits: int = 15,
        client_max_window_bits: int = 15,
        server_no_context_takeover: bool = False,
        client_no_context_takeover: bool = False,
        compress_settings: Dict | None = None,
    ):
        """
        Initialize server per-message deflate factory.

        Parameters:
        - server_max_window_bits: Maximum server LZ77 sliding window size (8-15)
        - client_max_window_bits: Maximum client LZ77 sliding window size (8-15)
        - server_no_context_takeover: Disable server context reuse between messages
        - client_no_context_takeover: Disable client context reuse between messages
        - compress_settings: Additional compression settings, or None for none
        """
221
222
class PerMessageDeflate(Extension):
    """
    Per-message deflate compression extension implementation.

    Compresses outgoing messages and decompresses incoming messages
    using deflate algorithm with configurable parameters.
    """

    def __init__(
        self,
        is_server: bool,
        server_max_window_bits: int = 15,
        client_max_window_bits: int = 15,
        server_no_context_takeover: bool = False,
        client_no_context_takeover: bool = False,
        compress_settings: Dict | None = None,
    ):
        """
        Initialize per-message deflate extension.

        Parameters:
        - is_server: True if this is server-side extension
        - server_max_window_bits: Maximum server LZ77 sliding window size (8-15)
        - client_max_window_bits: Maximum client LZ77 sliding window size (8-15)
        - server_no_context_takeover: Disable server context reuse between messages
        - client_no_context_takeover: Disable client context reuse between messages
        - compress_settings: Additional compression settings, or None for none
        """

    def encode(self, frame: Frame) -> Frame:
        """
        Compress outgoing frame if it's a data frame.

        Parameters:
        - frame: Original frame

        Returns:
        Frame: Compressed frame with RSV1 bit set if compressed
        """

    def decode(self, frame: Frame) -> Frame:
        """
        Decompress incoming frame if RSV1 bit is set.

        Parameters:
        - frame: Compressed frame

        Returns:
        Frame: Decompressed frame

        Raises:
        - ProtocolError: If decompression fails
        """
271
```
272
273
### Extension Utility Functions
274
275
Helper functions for enabling compression extensions on connections.
276
277
```python { .api }
278
def enable_client_permessage_deflate(
    extensions: Sequence[ClientExtensionFactory] | None
) -> Sequence[ClientExtensionFactory]:
    """
    Enable Per-Message Deflate with default settings in client extensions.

    If the extension is already present, perhaps with non-default settings,
    the configuration isn't changed.

    Parameters:
    - extensions: Existing list of client extension factories, or None

    Returns:
    Sequence[ClientExtensionFactory]: Extensions list with deflate factory added

    Usage:
    extensions = enable_client_permessage_deflate(None)
    websockets.connect("ws://localhost:8765", extensions=extensions)

    # Or add to existing extensions:
    existing_extensions = [custom_extension_factory]
    extensions = enable_client_permessage_deflate(existing_extensions)
    """
301
302
def enable_server_permessage_deflate(
    extensions: Sequence[ServerExtensionFactory] | None
) -> Sequence[ServerExtensionFactory]:
    """
    Make sure Per-Message Deflate is among the server extensions.

    The factory is added with default settings. When a deflate factory
    is already in the list, possibly with non-default settings, that
    configuration is left as it is.

    Parameters:
    - extensions: Existing list of server extension factories, or None

    Returns:
    Sequence[ServerExtensionFactory]: Extensions list with deflate factory added

    Usage:
    extensions = enable_server_permessage_deflate(None)
    websockets.serve(handler, "localhost", 8765, extensions=extensions)

    # Or add to existing extensions:
    existing_extensions = [custom_extension_factory]
    extensions = enable_server_permessage_deflate(existing_extensions)
    """
325
```
326
327
## Usage Examples
328
329
### Basic Compression Usage
330
331
```python
332
import asyncio
333
from websockets import connect, serve
334
from websockets.extensions.permessage_deflate import enable_client_permessage_deflate, enable_server_permessage_deflate
335
336
async def compression_client():
    """Connect with compression enabled, send one large message, report the reply size."""
    # Default per-message deflate settings
    deflate_factories = enable_client_permessage_deflate(None)

    # Large, highly repetitive text message (will be compressed)
    payload = "Hello, WebSocket! " * 1000  # ~18KB message

    async with connect("ws://localhost:8765", extensions=deflate_factories) as websocket:
        await websocket.send(payload)
        reply = await websocket.recv()
        print(f"Received: {len(reply)} characters")
351
352
async def compression_server():
    """Serve an echo handler on localhost:8765 with compression enabled."""

    async def handler(websocket):
        # Echo every incoming message back with an "Echo: " prefix.
        async for message in websocket:
            print(f"Received compressed message: {len(message)} chars")
            # The reply goes through the same extension (will be compressed)
            await websocket.send(f"Echo: {message}")

    # Default per-message deflate settings
    deflate_factories = enable_server_permessage_deflate(None)
    server = serve(handler, "localhost", 8765, extensions=deflate_factories)

    async with server:
        print("Compression server started")
        # Wait on a future that is never completed to keep the server up.
        await asyncio.Future()
371
372
# Run server in one terminal, client in another
373
# asyncio.run(compression_server())
374
# asyncio.run(compression_client())
375
```
376
377
### Advanced Compression Configuration
378
379
```python
380
import asyncio
from websockets import connect, serve
from websockets.extensions.permessage_deflate import (
    ClientPerMessageDeflateFactory,
    enable_client_permessage_deflate,
    enable_server_permessage_deflate,
)


async def advanced_compression_example():
    """
    Demonstrate advanced compression configurations.

    Connects to ws://localhost:8765 once per configuration, sends four
    messages of increasing size and prints the length of each message
    next to the length of the reply. A failing configuration is reported
    and the next one is tried.
    """
    # Standard compression configuration
    standard_extensions = enable_client_permessage_deflate(None)

    # Custom compression configuration with specific settings
    custom_factory = ClientPerMessageDeflateFactory(
        server_max_window_bits=12,  # Medium server window
        client_max_window_bits=12,  # Medium client window
        compress_settings={"memLevel": 6},
    )
    custom_extensions = [custom_factory]

    # Test configurations
    configurations = [
        ("Standard Compression", standard_extensions),
        ("Custom Compression", custom_extensions),
    ]

    # Test with different message sizes
    test_messages = [
        "Small message",
        "Medium message " * 50,   # ~750 bytes
        "Large message " * 500,   # ~7KB
        "Huge message " * 5000,   # ~65KB
    ]

    for name, extensions in configurations:
        print(f"\n=== {name} Configuration ===")

        try:
            async with connect(
                "ws://localhost:8765",
                extensions=extensions,
                open_timeout=5,
            ) as websocket:
                for index, message in enumerate(test_messages, start=1):
                    await websocket.send(message)
                    response = await websocket.recv()

                    print(f"Message {index}: {len(message)} -> {len(response)} chars")

        except Exception as e:
            # Deliberately broad: this is a demo loop that reports the
            # failure and moves on to the next configuration.
            print(f"Configuration failed: {e}")
430
431
# asyncio.run(advanced_compression_example())
432
```
433
434
### Custom Extension Implementation
435
436
```python
437
import asyncio
import base64
import binascii
import dataclasses
import hashlib
import hmac

from websockets import serve, connect
from websockets.exceptions import NegotiationError
from websockets.extensions.base import Extension, ClientExtensionFactory, ServerExtensionFactory
from websockets.frames import Frame, Opcode


def _derive_key(password: str) -> bytes:
    """Derive the 16-byte demo key from a password (first half of its SHA-256)."""
    return hashlib.sha256(password.encode()).digest()[:16]


class SimpleEncryptionExtension(Extension):
    """Simple XOR encryption extension (for demonstration only)."""

    def __init__(self, key: bytes):
        """
        Parameters:
        - key: Non-empty key that is XORed, cyclically, over frame payloads

        Raises:
        - ValueError: If key is empty (the cyclic index would divide by zero)
        """
        super().__init__("simple-encrypt")
        if not key:
            raise ValueError("key must not be empty")
        self.key = key

    def _xor(self, frame: Frame) -> Frame:
        """Return frame with its payload XORed with the key; non-data frames pass through."""
        if frame.opcode not in (Opcode.TEXT, Opcode.BINARY):
            return frame
        key = self.key
        key_len = len(key)
        # Simple XOR (NOT secure!)
        transformed = bytes(b ^ key[i % key_len] for i, b in enumerate(frame.data))
        # Frame is a dataclass, so copy-with-changes goes through dataclasses.replace.
        return dataclasses.replace(frame, data=transformed)

    def encode(self, frame: Frame) -> Frame:
        """Encrypt outgoing data frames."""
        return self._xor(frame)

    def decode(self, frame: Frame) -> Frame:
        """Decrypt incoming data frames (XOR is its own inverse)."""
        return self._xor(frame)


class SimpleEncryptionClientFactory(ClientExtensionFactory):
    """Client factory for simple encryption extension."""

    def __init__(self, password: str):
        """
        Parameters:
        - password: Secret from which the key is derived
        """
        # Generate key from password
        key = _derive_key(password)
        # NOTE: the key itself is offered as a negotiation parameter;
        # acceptable only because this is a demonstration.
        super().__init__("simple-encrypt", [("key", base64.b64encode(key).decode())])
        self.key = key

    def process_response_params(self, parameters):
        """Process server response parameters and build the extension."""
        # In real implementation, would validate server parameters
        return SimpleEncryptionExtension(self.key)


class SimpleEncryptionServerFactory(ServerExtensionFactory):
    """Server factory for simple encryption extension."""

    def __init__(self, password: str):
        """
        Parameters:
        - password: Secret from which the key is derived
        """
        super().__init__("simple-encrypt")
        self.key = _derive_key(password)

    def process_request_params(self, parameters):
        """
        Process client request parameters.

        Returns:
        Tuple of (response parameters, configured extension instance)

        Raises:
        - NegotiationError: If the "key" parameter is missing, has no
          value, is not valid base64, or does not match the server's key
        """
        # Validate client key matches our key
        client_key_param = next((p for p in parameters if p[0] == "key"), None)
        if not client_key_param or client_key_param[1] is None:
            raise NegotiationError("Missing encryption key")

        try:
            client_key = base64.b64decode(client_key_param[1])
        except (binascii.Error, ValueError) as exc:
            # Malformed input from the peer is a negotiation failure,
            # not an internal error.
            raise NegotiationError("Invalid encryption key") from exc

        # Constant-time comparison of the two keys.
        if not hmac.compare_digest(client_key, self.key):
            raise NegotiationError("Invalid encryption key")

        # Return response parameters and extension
        response_params = [("key", base64.b64encode(self.key).decode())]
        extension = SimpleEncryptionExtension(self.key)

        return response_params, extension


async def custom_extension_example():
    """
    Demonstrate custom extension usage.

    Starts an echo server on localhost:8765 with the server factory,
    connects a client with the client factory, sends one message and
    prints the reply. Any failure is reported rather than raised.
    """
    password = "shared-secret-password"

    # Server with custom extension
    async def encrypted_handler(websocket):
        async for message in websocket:
            print(f"Server received (decrypted): {message}")
            await websocket.send(f"Encrypted echo: {message}")

    server_extensions = [SimpleEncryptionServerFactory(password)]

    # Client with custom extension
    client_extensions = [SimpleEncryptionClientFactory(password)]

    print("Starting server with custom encryption extension...")

    try:
        # The server lives for the duration of this block and is shut
        # down when the block exits.
        async with serve(
            encrypted_handler,
            "localhost",
            8765,
            extensions=server_extensions,
        ):
            async with connect(
                "ws://localhost:8765",
                extensions=client_extensions,
            ) as websocket:
                # Send encrypted message
                await websocket.send("This message is encrypted!")
                response = await websocket.recv()
                print(f"Client received (decrypted): {response}")

    except Exception as e:
        # Deliberately broad: the demo reports any failure and returns.
        print(f"Custom extension error: {e}")
543
544
# Note: This is a demonstration. Real encryption would use proper cryptography
545
# asyncio.run(custom_extension_example())
546
```
547
548
### Extension Chain Implementation
549
550
```python
551
import dataclasses
import time
import zlib

from websockets.extensions.base import Extension
from websockets.frames import Frame, Opcode


class CompressionExtension(Extension):
    """Simple compression extension: deflates text frames longer than 100 bytes."""

    def __init__(self):
        super().__init__("simple-compress")
        # One compressor/decompressor pair per extension instance; both
        # keep their state across frames.
        self.compressor = zlib.compressobj()
        self.decompressor = zlib.decompressobj()

    def encode(self, frame: Frame) -> Frame:
        """Compress a text frame over 100 bytes and mark it with RSV1; pass others through."""
        if frame.opcode == Opcode.TEXT and len(frame.data) > 100:
            compressed = self.compressor.compress(frame.data)
            # Sync flush so everything fed so far is emitted with this frame.
            compressed += self.compressor.flush(zlib.Z_SYNC_FLUSH)
            # Frame is a dataclass, so copy-with-changes goes through dataclasses.replace.
            return dataclasses.replace(frame, data=compressed, rsv1=True)
        return frame

    def decode(self, frame: Frame) -> Frame:
        """Decompress a text frame that has RSV1 set and clear the bit; pass others through."""
        if frame.rsv1 and frame.opcode == Opcode.TEXT:
            decompressed = self.decompressor.decompress(frame.data)
            return dataclasses.replace(frame, data=decompressed, rsv1=False)
        return frame


class LoggingExtension(Extension):
    """Logging extension for debugging."""

    def __init__(self, name: str = "logger"):
        super().__init__(name)
        # Single counter shared by both directions.
        self.message_count = 0

    def encode(self, frame: Frame) -> Frame:
        """Log an outgoing frame and return it unchanged."""
        self.message_count += 1
        print(f"[OUT #{self.message_count}] {frame.opcode.name}: {len(frame.data)} bytes")
        return frame

    def decode(self, frame: Frame) -> Frame:
        """Log an incoming frame and return it unchanged."""
        self.message_count += 1
        print(f"[IN #{self.message_count}] {frame.opcode.name}: {len(frame.data)} bytes")
        return frame


class ExtensionChain:
    """Chain multiple extensions together."""

    def __init__(self, extensions: list[Extension]):
        """
        Parameters:
        - extensions: Extensions in the order they apply to outgoing frames
        """
        self.extensions = extensions

    def encode(self, frame: Frame) -> Frame:
        """Apply all extensions to outgoing frame."""
        for extension in self.extensions:
            frame = extension.encode(frame)
        return frame

    def decode(self, frame: Frame) -> Frame:
        """Apply all extensions to incoming frame (in reverse order)."""
        for extension in reversed(self.extensions):
            frame = extension.decode(frame)
        return frame


def extension_chain_example():
    """Demonstrate extension chaining with an encode/decode roundtrip."""
    # Create extension chain
    extensions = [
        LoggingExtension("pre-compress"),
        CompressionExtension(),
        LoggingExtension("post-compress"),
    ]

    chain = ExtensionChain(extensions)

    # Test message
    test_message = "This is a test message that should be compressed! " * 10
    test_frame = Frame(Opcode.TEXT, test_message.encode())

    print("=== Extension Chain Test ===")
    print(f"Original message: {len(test_message)} characters")

    # Encode (outgoing)
    print("\nEncoding (outgoing):")
    encoded_frame = chain.encode(test_frame)
    print(f"Final encoded size: {len(encoded_frame.data)} bytes")

    # Decode (incoming)
    print("\nDecoding (incoming):")
    decoded_frame = chain.decode(encoded_frame)
    decoded_message = decoded_frame.data.decode()
    print(f"Final decoded size: {len(decoded_message)} characters")

    # Verify roundtrip
    print(f"\nRoundtrip successful: {decoded_message == test_message}")
645
646
extension_chain_example()
647
```
648
649
### Extension Performance Testing
650
651
```python
652
import time
import asyncio
from websockets import connect, serve
from websockets.extensions.permessage_deflate import enable_client_permessage_deflate


async def extension_performance_test():
    """
    Test extension performance with different configurations.

    For every configuration and message size, opens a connection to
    ws://localhost:8765, performs ten send/recv roundtrips and prints the
    average roundtrip time and the throughput. The timer starts after
    the connection is established, so the handshake is not counted.
    """
    # Each configuration is a set of keyword arguments for connect().
    # connect() enables per-message deflate by default, so the baseline
    # has to switch it off explicitly with compression=None.
    configs = [
        ("No Compression", {"compression": None}),
        ("Default Compression", {"extensions": enable_client_permessage_deflate(None)}),
    ]

    # Test messages of different sizes
    test_messages = [
        ("Small", "Hello" * 20),             # ~100 bytes
        ("Medium", "Test message " * 100),   # ~1.3KB
        ("Large", "Data " * 2000),           # ~10KB
        ("Huge", "Content " * 10000),        # ~80KB
    ]

    # Send message multiple times for better measurement
    iterations = 10

    print("=== Extension Performance Test ===")
    print(f"{'Config':<20} {'Size':<8} {'Time (ms)':<10} {'Throughput (MB/s)':<15}")
    print("-" * 65)

    for config_name, connect_kwargs in configs:
        for msg_name, message in test_messages:
            try:
                async with connect(
                    "ws://localhost:8765",
                    open_timeout=5,
                    **connect_kwargs,
                ) as websocket:
                    # perf_counter is monotonic and high-resolution,
                    # unlike the wall clock.
                    start_time = time.perf_counter()

                    for _ in range(iterations):
                        await websocket.send(message)
                        await websocket.recv()

                    elapsed = time.perf_counter() - start_time

            except Exception as e:
                # Deliberately broad: report the failure, keep benchmarking.
                print(f"{config_name:<20} {msg_name:<8} ERROR: {e}")
                continue

            # Calculate metrics
            total_time = elapsed * 1000  # ms
            # send + receive; assumes the server replies with a message
            # of the same size - TODO confirm against the test server
            data_size = len(message) * iterations * 2
            if elapsed > 0:
                throughput = (data_size / (1024 * 1024)) / elapsed  # MB/s
            else:
                throughput = float("inf")

            print(f"{config_name:<20} {msg_name:<8} {total_time/iterations:>8.1f} {throughput:>13.2f}")
706
707
# Note: Run with a test server
708
# asyncio.run(extension_performance_test())
709
```
710
711
### Extension Debugging Tools
712
713
```python
714
from websockets.extensions.base import Extension
from websockets.frames import Frame, Opcode
import json
import time


class DebugExtension(Extension):
    """Extension for debugging frame processing."""

    def __init__(self, name: str = "debug"):
        super().__init__(name)
        # Running counters; start_time is the wall-clock creation time.
        self.stats = {
            "frames_encoded": 0,
            "frames_decoded": 0,
            "bytes_in": 0,
            "bytes_out": 0,
            "start_time": time.time(),
        }

    @staticmethod
    def _dump(action: str, frame: Frame) -> None:
        """Print opcode, size, FIN and RSV bits of frame under an action heading."""
        print(f"[DEBUG] {action} frame:")
        print(f" Opcode: {frame.opcode.name}")
        print(f" Size: {len(frame.data)} bytes")
        print(f" FIN: {frame.fin}")
        print(f" RSV: {frame.rsv1}{frame.rsv2}{frame.rsv3}")

    def encode(self, frame: Frame) -> Frame:
        """Count and print an outgoing frame, then return it unchanged."""
        self.stats["frames_encoded"] += 1
        self.stats["bytes_out"] += len(frame.data)

        self._dump("Encoding", frame)

        return frame

    def decode(self, frame: Frame) -> Frame:
        """Count and print an incoming frame, then return it unchanged."""
        self.stats["frames_decoded"] += 1
        self.stats["bytes_in"] += len(frame.data)

        self._dump("Decoding", frame)

        return frame

    def get_stats(self) -> dict:
        """
        Get extension statistics.

        Returns:
        dict: The raw counters plus runtime_seconds, frames_per_second
        and bytes_per_second. Both rates are 0.0 when no measurable time
        has elapsed yet.
        """
        runtime = time.time() - self.stats["start_time"]
        total_frames = self.stats["frames_encoded"] + self.stats["frames_decoded"]
        total_bytes = self.stats["bytes_in"] + self.stats["bytes_out"]

        # The clock can return the same value twice in quick succession
        # (or step backwards), so guard the division.
        if runtime > 0:
            frames_per_second = total_frames / runtime
            bytes_per_second = total_bytes / runtime
        else:
            frames_per_second = 0.0
            bytes_per_second = 0.0

        return {
            **self.stats,
            "runtime_seconds": runtime,
            "frames_per_second": frames_per_second,
            "bytes_per_second": bytes_per_second,
        }

    def print_stats(self):
        """Print formatted statistics."""
        stats = self.get_stats()

        print(f"\n=== {self.name} Extension Statistics ===")
        print(f"Runtime: {stats['runtime_seconds']:.2f} seconds")
        print(f"Frames encoded: {stats['frames_encoded']}")
        print(f"Frames decoded: {stats['frames_decoded']}")
        print(f"Bytes in: {stats['bytes_in']:,}")
        print(f"Bytes out: {stats['bytes_out']:,}")
        print(f"Frames/sec: {stats['frames_per_second']:.2f}")
        print(f"Bytes/sec: {stats['bytes_per_second']:,.0f}")


def debug_extension_example():
    """Demonstrate debug extension by pushing four sample frames through it."""
    debug_ext = DebugExtension("test-debug")

    # Simulate frame processing
    test_frames = [
        Frame(Opcode.TEXT, b"Hello, World!"),
        Frame(Opcode.BINARY, b"\x00\x01\x02\x03"),
        Frame(Opcode.PING, b"ping"),
        Frame(Opcode.PONG, b"pong"),
    ]

    print("=== Debug Extension Example ===")

    # Process frames through extension
    for frame in test_frames:
        # Simulate the outgoing frame, then feed the result back in as
        # the incoming frame.
        debug_ext.decode(debug_ext.encode(frame))

    # Print statistics
    debug_ext.print_stats()
806
807
debug_extension_example()
808
```