The little ASGI library that shines.
Starlette provides first-class WebSocket support for real-time, bidirectional communication between clients and servers, enabling interactive applications like chat systems, live updates, and collaborative tools.
import random
import time
from typing import Any, AsyncIterator

from starlette.datastructures import URL, Headers, QueryParams
from starlette.types import Scope, Receive, Send
from starlette.websockets import WebSocket, WebSocketState, WebSocketDisconnect
class WebSocket:
    """
    WebSocket connection handler providing bidirectional communication.

    API reference stub: every member below carries a signature and a
    docstring only. It shadows the ``WebSocket`` name imported from
    ``starlette.websockets`` at the top of this file.

    The WebSocket class manages:

    - Connection lifecycle (accept, close)
    - Message sending and receiving (text, binary, JSON)
    - Connection state management
    - Subprotocol negotiation
    - Error handling and disconnection
    """

    # Address, State, Sequence and Response are not imported in this file, so
    # their annotations are written as strings: unquoted, they are evaluated
    # when each ``def`` executes and would raise NameError while the class
    # body runs.

    def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Initialize WebSocket from ASGI scope.

        Args:
            scope: ASGI WebSocket scope.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """

    # Inherited from HTTPConnection

    @property
    def url(self) -> URL:
        """WebSocket connection URL."""

    @property
    def base_url(self) -> URL:
        """Base URL for generating absolute URLs."""

    @property
    def headers(self) -> Headers:
        """WebSocket handshake headers."""

    @property
    def query_params(self) -> QueryParams:
        """URL query parameters."""

    @property
    def path_params(self) -> dict[str, Any]:
        """Path parameters captured by the route's URL pattern."""

    @property
    def cookies(self) -> dict[str, str]:
        """Cookies sent with the handshake request."""

    @property
    def client(self) -> "Address | None":
        """Client address information, or ``None`` when unavailable."""

    @property
    def session(self) -> dict[str, Any]:
        """Session data (requires SessionMiddleware)."""

    @property
    def auth(self) -> Any:
        """Authentication data (requires AuthenticationMiddleware)."""

    @property
    def user(self) -> Any:
        """User object (requires AuthenticationMiddleware)."""

    @property
    def state(self) -> "State":
        """WebSocket-scoped state storage."""

    # WebSocket-specific properties

    @property
    def client_state(self) -> WebSocketState:
        """Connection state as seen from the client side."""

    @property
    def application_state(self) -> WebSocketState:
        """Connection state as seen from the application side."""

    # Connection management

    async def accept(
        self,
        subprotocol: str | None = None,
        headers: "Sequence[tuple[bytes, bytes]] | None" = None,
    ) -> None:
        """
        Accept the WebSocket connection.

        Args:
            subprotocol: Subprotocol selected from the client's offered list.
            headers: Additional response headers as ``(name, value)`` byte pairs.

        Raises:
            RuntimeError: If the connection was already accepted or closed.
        """

    async def close(
        self,
        code: int = 1000,
        reason: str | None = None
    ) -> None:
        """
        Close the WebSocket connection.

        Args:
            code: WebSocket close code (default: 1000, normal closure).
            reason: Optional close reason string.
        """

    async def send_denial_response(self, response: "Response") -> None:
        """
        Send an HTTP response instead of accepting the WebSocket.

        Used for authentication/authorization failures during the handshake.

        Args:
            response: HTTP response to send.
        """

    # Message receiving

    async def receive(self) -> dict[str, Any]:
        """
        Receive a raw WebSocket message.

        Returns:
            dict: The ASGI WebSocket message.

        Raises:
            WebSocketDisconnect: When the connection closes.
        """

    async def receive_text(self) -> str:
        """
        Receive a text message.

        Returns:
            str: Text message content.

        Raises:
            WebSocketDisconnect: When the connection closes.
            RuntimeError: If the message is not text.
        """

    async def receive_bytes(self) -> bytes:
        """
        Receive a binary message.

        Returns:
            bytes: Binary message content.

        Raises:
            WebSocketDisconnect: When the connection closes.
            RuntimeError: If the message is not binary.
        """

    async def receive_json(self, mode: str = "text") -> Any:
        """
        Receive a JSON message.

        Args:
            mode: ``"text"`` or ``"binary"`` message mode.

        Returns:
            Any: Parsed JSON data.

        Raises:
            WebSocketDisconnect: When the connection closes.
            ValueError: If the message is not valid JSON.
        """

    # Message iteration

    def iter_text(self) -> AsyncIterator[str]:
        """
        Async iterator over text messages.

        Yields:
            str: Text messages until the connection closes.
        """

    def iter_bytes(self) -> AsyncIterator[bytes]:
        """
        Async iterator over binary messages.

        Yields:
            bytes: Binary messages until the connection closes.
        """

    def iter_json(self, mode: str = "text") -> AsyncIterator[Any]:
        """
        Async iterator over JSON messages.

        Args:
            mode: ``"text"`` or ``"binary"`` message mode.

        Yields:
            Any: Parsed JSON data until the connection closes.
        """

    # Message sending

    async def send(self, message: dict[str, Any]) -> None:
        """
        Send a raw ASGI WebSocket message.

        Args:
            message: ASGI WebSocket message dict.
        """

    async def send_text(self, data: str) -> None:
        """
        Send a text message.

        Args:
            data: Text content to send.
        """

    async def send_bytes(self, data: bytes) -> None:
        """
        Send a binary message.

        Args:
            data: Binary content to send.
        """

    async def send_json(
        self,
        data: Any,
        mode: str = "text"
    ) -> None:
        """
        Send a JSON message.

        Args:
            data: JSON-serializable data.
            mode: ``"text"`` or ``"binary"`` message mode.
        """

    def url_for(self, name: str, /, **path_params: Any) -> URL:
        """
        Generate an absolute URL for a named route.

        Args:
            name: Route name.
            **path_params: Path parameter values.

        Returns:
            URL: Absolute URL.
        """


from enum import Enum
class WebSocketState(Enum):
    """
    Lifecycle states of a WebSocket connection.

    A connection starts out CONNECTING. It leaves that state either by being
    accepted (CONNECTED) or by being answered with a plain HTTP response
    (RESPONSE). DISCONNECTED marks a connection that has been closed.
    """

    CONNECTING = 0  # handshake seen; not yet accepted or denied
    CONNECTED = 1  # accepted and active
    DISCONNECTED = 2  # closed
    RESPONSE = 3  # an HTTP response was sent instead of a WebSocket
class WebSocketDisconnect(Exception):
"""
Exception raised when WebSocket connection closes.
Contains close code and optional reason.
"""
def __init__(self, code: int = 1000, reason: str | None = None) -> None:
"""
Initialize disconnect exception.
Args:
code: WebSocket close code
reason: Optional close reason
"""
self.code = code
self.reason = reason
class WebSocketClose:
    """
    ASGI application for closing WebSocket connections.

    Can be used as a route endpoint to immediately close connections.
    """

    def __init__(self, code: int = 1000, reason: str | None = None) -> None:
        """
        Initialize the WebSocket close handler.

        Args:
            code: Close code to send.
            reason: Close reason to send; ``None`` is sent as an empty string.
        """
        # Keep both values so __call__ can send them.
        self.code = code
        self.reason = reason

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Close the connection by sending a single ``websocket.close`` event.

        Args:
            scope: ASGI WebSocket scope (unused).
            receive: ASGI receive callable (unused).
            send: ASGI send callable that receives the close event.
        """
        await send({
            "type": "websocket.close",
            "code": self.code,
            # The ASGI close event carries the reason as a string.
            "reason": self.reason or "",
        })


from starlette.applications import Starlette
from starlette.routing import WebSocketRoute
from starlette.websockets import WebSocket, WebSocketDisconnect
async def websocket_endpoint(websocket: WebSocket):
    """Accept the connection and echo each text frame back, prefixed with ``Echo: ``."""
    await websocket.accept()
    try:
        while True:
            incoming = await websocket.receive_text()
            reply = f"Echo: {incoming}"
            await websocket.send_text(reply)
    except WebSocketDisconnect:
        # receive_text() raises this once the client goes away.
        print("Client disconnected")
app = Starlette(routes=[
    WebSocketRoute("/ws", websocket_endpoint)
])


async def json_websocket(websocket: WebSocket):
    """
    Answer JSON messages according to their ``type`` field.

    ``ping`` gets a ``pong`` with a timestamp. ``message`` gets its
    ``content`` echoed back together with an upper-cased copy. Any other
    type gets an error reply. Malformed JSON gets an error reply, and the
    session is then closed with code 1003.
    """
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_json()
            # Valid JSON need not be an object; ``data.get`` would raise
            # AttributeError on a list, string or number and kill the handler.
            if not isinstance(data, dict):
                await websocket.send_json({
                    "type": "error",
                    "message": "Expected a JSON object"
                })
                continue
            msg_type = data.get("type")
            if msg_type == "ping":
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": time.time()
                })
            elif msg_type == "message":
                content = data.get("content")
                await websocket.send_json({
                    "type": "response",
                    "original": content,
                    # ``content`` may be missing, null or not a string.
                    "processed": content.upper() if isinstance(content, str) else ""
                })
            else:
                await websocket.send_json({
                    "type": "error",
                    "message": "Unknown message type"
                })
    except WebSocketDisconnect:
        print("Client disconnected")
    except ValueError:
        # Invalid JSON
        await websocket.send_json({
            "type": "error",
            "message": "Invalid JSON"
        })
        await websocket.close(code=1003)  # Unsupported data


from starlette.routing import WebSocketRoute
async def user_websocket(websocket: WebSocket):
    """
    Per-user JSON channel, optionally scoped to a room.

    Both entries in ``routes`` point here; only the room-scoped route
    supplies ``room_id``, so it is ``None`` otherwise. Each JSON object
    received is answered with a ``message`` envelope carrying its
    ``content``.
    """
    # Convert once instead of on every message.
    user_id = int(websocket.path_params["user_id"])
    raw_room_id = websocket.path_params.get("room_id")
    room_id = None if raw_room_id is None else int(raw_room_id)
    await websocket.accept()
    try:
        # Send welcome message
        await websocket.send_json({
            "type": "welcome",
            "user_id": user_id,
            "room_id": room_id
        })
        async for message in websocket.iter_json():
            # ``message.get`` would raise AttributeError on non-object JSON.
            if not isinstance(message, dict):
                await websocket.send_json({
                    "type": "error",
                    "message": "Expected a JSON object"
                })
                continue
            # Process user messages in room context
            await websocket.send_json({
                "type": "message",
                "user_id": user_id,
                "room_id": room_id,
                "content": message.get("content")
            })
    except WebSocketDisconnect:
        pass
    # iter_json() yields "until connection closes" and then simply stops,
    # so the disconnect is reported after the try, not only inside ``except``.
    print(f"User {user_id} disconnected from room {room_id}")
routes = [
    WebSocketRoute("/ws/user/{user_id:int}", user_websocket),
    WebSocketRoute("/ws/room/{room_id:int}/user/{user_id:int}", user_websocket),
]


async def websocket_with_auth(websocket: WebSocket):
    """
    Authenticate during the handshake, then hand JSON messages to the user handler.

    The token comes from the ``token`` query parameter, or else from an
    ``Authorization: Bearer <token>`` header. A missing or rejected token
    closes the connection with code 1008 before it is accepted.
    """
    token = websocket.query_params.get("token")
    if not token:
        auth_header = websocket.headers.get("authorization") or ""
        scheme, _, credentials = auth_header.partition(" ")
        # Only a Bearer credential is a token. Neither the bare scheme name
        # ("Bearer" with nothing after it) nor another scheme's credentials
        # (e.g. Basic) may be passed on as one.
        if scheme.lower() == "bearer":
            token = credentials.strip()
    if not token:
        # Closing before accept() rejects the handshake. Per the ASGI spec the
        # server then answers HTTP 403, so the reason may never reach the
        # client; send_denial_response() allows a custom HTTP reply instead.
        await websocket.close(code=1008, reason="Authentication required")
        return
    # Extract user info from token
    user = authenticate_token(token)
    if not user:
        await websocket.close(code=1008, reason="Invalid token")
        return
    # Store user in WebSocket state
    websocket.state.user = user
    await websocket.accept()
    try:
        await websocket.send_json({
            "type": "authenticated",
            "user": user.username
        })
        async for message in websocket.iter_json():
            # Process authenticated user messages
            await process_user_message(websocket.state.user, message)
    except WebSocketDisconnect:
        pass
    # iter_json() ends quietly when the connection closes, so report the
    # disconnect here rather than only inside the ``except`` branch.
    print(f"User {user.username} disconnected")


import json
from typing import Dict, Set
class ConnectionManager:
    """Manage multiple WebSocket connections, grouped by room and by user."""

    def __init__(self):
        # room_id -> sockets currently in that room
        self.active_connections: "dict[str, set[WebSocket]]" = {}
        # user_id -> the most recently connected socket for that user
        self.user_connections: "dict[str, WebSocket]" = {}

    async def connect(self, websocket: "WebSocket", room_id: str, user_id: str):
        """
        Accept ``websocket``, register it under ``room_id`` and ``user_id``, and
        announce the join to the other sockets in the room.
        """
        await websocket.accept()
        room = self.active_connections.setdefault(room_id, set())
        room.add(websocket)
        self.user_connections[user_id] = websocket
        await self.broadcast_to_room(room_id, {
            "type": "user_joined",
            "user_id": user_id,
            "count": len(room)
        }, exclude=websocket)

    def disconnect(self, websocket: "WebSocket", room_id: str, user_id: str):
        """
        Remove ``websocket`` from its room (dropping the room once empty) and
        from the user mapping.
        """
        room = self.active_connections.get(room_id)
        if room is not None:
            room.discard(websocket)
            if not room:
                del self.active_connections[room_id]
        # A newer connection for the same user_id may have replaced this one;
        # only drop the mapping if it still points at this socket.
        if self.user_connections.get(user_id) is websocket:
            del self.user_connections[user_id]

    async def send_personal_message(self, user_id: str, message: dict):
        """Send ``message`` as JSON to ``user_id``'s socket, if one is registered."""
        websocket = self.user_connections.get(user_id)
        if websocket is not None:
            await websocket.send_json(message)

    async def broadcast_to_room(self, room_id: str, message: dict, exclude: "WebSocket | None" = None):
        """
        Send ``message`` as JSON to every socket in ``room_id`` except ``exclude``.

        Sockets whose send raises are removed from the room.
        """
        room = self.active_connections.get(room_id)
        if not room:
            return
        # Iterate over a snapshot: discarding from the set being iterated
        # raises RuntimeError on the loop's next step.
        for connection in list(room):
            if connection is exclude:
                continue
            try:
                await connection.send_json(message)
            except Exception:
                # Connection is broken, remove it. ``Exception`` (not a bare
                # ``except``) lets task cancellation propagate.
                room.discard(connection)
# Global connection manager
manager = ConnectionManager()


async def chat_websocket(websocket: WebSocket):
    """
    Room chat endpoint: relay each JSON message to everyone in the room.

    ``room_id`` comes from the path and ``user_id`` from the ``user_id``
    query parameter (default ``"anonymous"``). On exit the socket is
    unregistered and the room is told that the user left.
    """
    room_id = websocket.path_params["room_id"]
    user_id = websocket.query_params.get("user_id", "anonymous")
    await manager.connect(websocket, room_id, user_id)
    try:
        async for data in websocket.iter_json():
            # ``data.get`` would raise AttributeError on non-object JSON.
            if not isinstance(data, dict):
                await websocket.send_json({
                    "type": "error",
                    "message": "Expected a JSON object"
                })
                continue
            message = {
                "type": "message",
                "user_id": user_id,
                "room_id": room_id,
                "content": data.get("content"),
                "timestamp": time.time()
            }
            # Broadcast to room
            await manager.broadcast_to_room(room_id, message)
    except WebSocketDisconnect:
        pass
    finally:
        # iter_json() yields "until connection closes" and then just stops,
        # so cleanup cannot live only in ``except WebSocketDisconnect``;
        # ``finally`` also covers errors raised while handling a message.
        manager.disconnect(websocket, room_id, user_id)
        # Notify room of disconnection
        await manager.broadcast_to_room(room_id, {
            "type": "user_left",
            "user_id": user_id,
            "count": len(manager.active_connections.get(room_id, set()))
        })


async def subprotocol_websocket(websocket: WebSocket):
    """
    Negotiate a subprotocol with the client and dispatch to its handler.

    The first offered protocol that is supported wins; with no match, the
    connection is accepted without a subprotocol and the default handler runs.
    """
    # Sec-WebSocket-Protocol may be sent on several header lines, and
    # headers.get() only sees the first one. The ASGI server collects every
    # offered value into scope["subprotocols"].
    offered = websocket.scope.get("subprotocols") or []
    supported = ("chat", "notification", "api.v1")
    selected = next((protocol for protocol in offered if protocol in supported), None)
    # Accept with selected subprotocol
    await websocket.accept(subprotocol=selected)
    try:
        # Handle different protocols
        if selected == "chat":
            await handle_chat_protocol(websocket)
        elif selected == "notification":
            await handle_notification_protocol(websocket)
        elif selected == "api.v1":
            await handle_api_protocol(websocket)
        else:
            await handle_default_protocol(websocket)
    except WebSocketDisconnect:
        pass
    # Handlers built on iter_text()/iter_json() return normally once the
    # connection closes, so log here instead of only in ``except``.
    print(f"Client disconnected (protocol: {selected})")
async def handle_chat_protocol(websocket: WebSocket):
    """Serve the ``chat`` subprotocol: answer each plain-text frame tagged with ``[CHAT]``."""
    async for text in websocket.iter_text():
        tagged = f"[CHAT] {text}"
        await websocket.send_text(tagged)
async def handle_api_protocol(websocket: WebSocket):
    """
    Serve the ``api.v1`` subprotocol: structured JSON commands.

    ``list_users`` replies with the active users. ``send_message``
    acknowledges ``message_id``. Anything else, including JSON that is not
    an object, gets an error reply instead of being silently ignored.
    """
    async for data in websocket.iter_json():
        # ``data.get`` would raise AttributeError on non-object JSON.
        command = data.get("command") if isinstance(data, dict) else None
        if command == "list_users":
            await websocket.send_json({
                "type": "user_list",
                "users": get_active_users()
            })
        elif command == "send_message":
            await websocket.send_json({
                "type": "message_sent",
                "id": data.get("message_id")
            })
        else:
            await websocket.send_json({
                "type": "error",
                "message": f"Unknown command: {command}"
            })


import asyncio
import signal
from typing import Set
class WebSocketServer:
    """
    WebSocket endpoint with per-message error handling and graceful shutdown.

    Open connections are tracked in ``connections`` so ``broadcast_shutdown``
    can notify and close them. Setting ``shutdown_event`` makes every receive
    loop exit.
    """

    def __init__(self):
        self.connections: Set[WebSocket] = set()
        self.shutdown_event = asyncio.Event()

    async def websocket_endpoint(self, websocket: WebSocket):
        """
        Serve one connection until the client disconnects or shutdown starts.

        JSON is received with a 1-second timeout so the loop re-checks
        ``shutdown_event`` regularly. Invalid JSON gets an error reply and the
        loop continues. Any other error gets an error reply and ends the loop.
        """
        await websocket.accept()
        self.connections.add(websocket)
        try:
            # Send initial connection message
            await websocket.send_json({
                "type": "connected",
                "server_time": time.time()
            })
            while not self.shutdown_event.is_set():
                try:
                    # Use timeout to allow checking shutdown event
                    message = await asyncio.wait_for(
                        websocket.receive_json(),
                        timeout=1.0
                    )
                    await self.process_message(websocket, message)
                except asyncio.TimeoutError:
                    # Timeout is normal, continue loop
                    continue
                except ValueError:
                    # Invalid JSON
                    await websocket.send_json({
                        "type": "error",
                        "message": "Invalid JSON format"
                    })
                except WebSocketDisconnect:
                    # Must reach the outer handler. Otherwise the generic
                    # ``except Exception`` below would catch it and try to
                    # send on a connection that is already closed.
                    raise
                except Exception:
                    # Unexpected error
                    await websocket.send_json({
                        "type": "error",
                        "message": "Server error occurred"
                    })
                    break
        except WebSocketDisconnect:
            pass
        finally:
            # Cleanup connection
            self.connections.discard(websocket)
            print(f"Connection closed. Active: {len(self.connections)}")

    async def process_message(self, websocket: WebSocket, message: dict):
        """Reply to ``ping`` and ``echo`` messages; report unknown types and failures."""
        try:
            msg_type = message.get("type")
            if msg_type == "ping":
                await websocket.send_json({"type": "pong"})
            elif msg_type == "echo":
                await websocket.send_json({
                    "type": "echo_response",
                    "data": message.get("data")
                })
            else:
                await websocket.send_json({
                    "type": "error",
                    "message": f"Unknown message type: {msg_type}"
                })
        except WebSocketDisconnect:
            # A disconnect is not a processing error; let the caller see it.
            raise
        except Exception as e:
            print(f"Error processing message: {e}")
            await websocket.send_json({
                "type": "error",
                "message": "Failed to process message"
            })

    async def broadcast_shutdown(self):
        """Send a ``server_shutdown`` notice to every connection, then close each with code 1001."""
        if self.connections:
            message = {
                "type": "server_shutdown",
                "message": "Server is shutting down"
            }
            # Send to all connections; individual failures are collected, not raised.
            await asyncio.gather(
                *[conn.send_json(message) for conn in self.connections],
                return_exceptions=True
            )
            # Close all connections
            await asyncio.gather(
                *[conn.close(code=1001, reason="Server shutdown") for conn in self.connections],
                return_exceptions=True
            )

    def setup_shutdown_handlers(self):
        """Make SIGTERM and SIGINT set ``shutdown_event``."""
        # NOTE(review): signal.signal() replaces whatever handlers were
        # installed before (including Python's default KeyboardInterrupt
        # handler for SIGINT and any installed by the ASGI server). Confirm
        # the process still exits on these signals. Setting an asyncio.Event
        # from a plain signal handler relies on the 1-second receive timeout
        # to be noticed; loop.add_signal_handler is the asyncio-native option.
        def signal_handler(sig, frame):
            print(f"Received signal {sig}, starting graceful shutdown...")
            self.shutdown_event.set()

        signal.signal(signal.SIGTERM, signal_handler)
        signal.signal(signal.SIGINT, signal_handler)
# Usage: one shared server instance. Its signal handlers are installed as
# soon as this module is imported.
server = WebSocketServer()
server.setup_shutdown_handlers()


async def websocket_endpoint(websocket: WebSocket):
    """Route endpoint that hands each connection to the shared ``server``."""
    handler = server.websocket_endpoint
    await handler(websocket)


from starlette.testclient import TestClient
def test_websocket():
    """The ``/ws`` echo route answers ``Hello`` with ``Echo: Hello``."""
    with TestClient(app) as client, client.websocket_connect("/ws") as websocket:
        websocket.send_text("Hello")
        reply = websocket.receive_text()
        assert reply == "Echo: Hello"
def test_websocket_json():
    """A JSON ``ping`` is answered with a message whose type is ``pong``."""
    # NOTE(review): the ``app`` built earlier only routes "/ws"; this assumes
    # json_websocket is mounted at "/ws/json" on the app under test — confirm.
    with TestClient(app) as client, client.websocket_connect("/ws/json") as websocket:
        websocket.send_json({"type": "ping"})
        response = websocket.receive_json()
        assert response["type"] == "pong"
def test_websocket_with_params():
    """The welcome message reports the integer user id parsed from the path."""
    # NOTE(review): "/ws/user/{user_id:int}" maps to user_websocket in
    # ``routes``, which does no authentication (``token`` is ignored there).
    # The ``app`` built earlier does not include those routes — confirm
    # which app this test targets.
    target = "/ws/user/123?token=abc"
    with TestClient(app) as client, client.websocket_connect(target) as websocket:
        welcome = websocket.receive_json()
        assert welcome["user_id"] == 123
def test_websocket_disconnect():
    """A client-initiated close is handled by the echo endpoint without errors."""
    with TestClient(app) as client:
        with client.websocket_connect("/ws") as websocket:
            websocket.send_text("Hello")
            # Consume the echo so the exchange is complete before closing.
            assert websocket.receive_text() == "Echo: Hello"
            websocket.close()
    # ``client_state`` is a property of the server-side WebSocket. The test
    # session returned by websocket_connect() does not expose it, so asserting
    # on it raised AttributeError. Leaving both context managers without an
    # exception is the check that the close was handled gracefully.


import asyncio
from datetime import datetime
async def live_updates_websocket(websocket: WebSocket):
    """
    Push a status snapshot to the client every 5 seconds.

    Each update carries a timezone-aware UTC ISO-8601 timestamp, so clients
    in other time zones cannot misread the server's local wall-clock time.
    """
    from datetime import timezone

    await websocket.accept()
    try:
        # Send periodic updates
        while True:
            data = {
                "type": "update",
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "data": get_latest_data(),
                "metrics": get_system_metrics()
            }
            await websocket.send_json(data)
            # Wait before next update
            await asyncio.sleep(5)
    except WebSocketDisconnect:
        # NOTE(review): nothing here calls receive(), so this only fires if
        # send_json() raises WebSocketDisconnect once the client is gone —
        # confirm for the server in use.
        print("Client disconnected from live updates")
def get_latest_data():
    """Return a placeholder reading: ``{"value": <random int from 1 to 100>}``."""
    reading = random.randint(1, 100)
    return {"value": reading}
def get_system_metrics():
# Get system metrics
return {
"cpu": psutil.cpu_percent(),
"memory": psutil.virtual_memory().percent
}class DocumentManager:
def __init__(self):
self.documents = {}
self.connections = {}
async def join_document(self, websocket: WebSocket, doc_id: str, user_id: str):
if doc_id not in self.documents:
self.documents[doc_id] = {"content": "", "version": 0}
if doc_id not in self.connections:
self.connections[doc_id] = {}
self.connections[doc_id][user_id] = websocket
# Send current document state
await websocket.send_json({
"type": "document_state",
"content": self.documents[doc_id]["content"],
"version": self.documents[doc_id]["version"]
})
# Notify other users
await self.broadcast_to_document(doc_id, {
"type": "user_joined",
"user_id": user_id
}, exclude=user_id)
async def handle_edit(self, doc_id: str, user_id: str, operation: dict):
# Apply operation to document
self.documents[doc_id]["content"] = apply_operation(
self.documents[doc_id]["content"],
operation
)
self.documents[doc_id]["version"] += 1
# Broadcast change to other users
await self.broadcast_to_document(doc_id, {
"type": "document_change",
"operation": operation,
"user_id": user_id,
"version": self.documents[doc_id]["version"]
}, exclude=user_id)
async def broadcast_to_document(self, doc_id: str, message: dict, exclude: str = None):
if doc_id in self.connections:
for user_id, websocket in self.connections[doc_id].items():
if user_id != exclude:
try:
await websocket.send_json(message)
except:
# Remove broken connection
del self.connections[doc_id][user_id]Starlette's WebSocket support enables building real-time, interactive applications with robust connection management, error handling, and testing capabilities.
Install with Tessl CLI
npx tessl i tessl/pypi-starlette