FastAPI WebSocket Guide: Build Real-Time Applications with Python
WebSockets enable real-time, bidirectional communication between clients and servers. FastAPI supports them natively through its asynchronous (ASGI) foundation, so you can build real-time features such as chat applications, live notifications, and collaborative tools.
This guide covers implementing WebSockets in FastAPI, from basic concepts to production deployment strategies.
Understanding WebSockets in FastAPI
What are WebSockets?
WebSockets provide a persistent, full-duplex communication channel between clients and servers. Unlike traditional HTTP requests that follow a request-response pattern, WebSockets maintain an open connection that allows both parties to send messages at any time.
Key advantages:
- Real-time communication: Instant message exchange
- Low latency: No per-message HTTP request overhead after the initial handshake
- Bidirectional: Both client and server can initiate communication
- Efficient: Reduced server load compared to polling
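The handshake is the only HTTP round trip: a WebSocket starts as an HTTP/1.1 `Upgrade` request, and the server proves it understood the request by hashing the client's `Sec-WebSocket-Key` with a fixed GUID defined in RFC 6455. After that single exchange, messages flow as lightweight frames. The sketch below computes that response header with the standard library. FastAPI and your ASGI server already do this for you, so it is purely illustrative; `sec_websocket_accept` is a hypothetical helper name.

```python
import base64
import hashlib

# Fixed GUID defined by RFC 6455 (section 1.3)
WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


def sec_websocket_accept(sec_websocket_key: str) -> str:
    """Value a server returns in the Sec-WebSocket-Accept response header."""
    digest = hashlib.sha1((sec_websocket_key + WEBSOCKET_GUID).encode("ascii")).digest()
    return base64.b64encode(digest).decode("ascii")


# Sample key from RFC 6455
print(sec_websocket_accept("dGhlIHNhbXBsZSBub25jZQ=="))  # s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
```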
FastAPI WebSocket Features
FastAPI's WebSocket support comes from Starlette and runs on ASGI (Asynchronous Server Gateway Interface), so each connection is served by a coroutine on the event loop rather than by a dedicated thread:
```python
from fastapi import FastAPI, WebSocket
from fastapi.websockets import WebSocketDisconnect

app = FastAPI()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Echo: {data}")
    except WebSocketDisconnect:
        print("Client disconnected")
```
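FastAPI's `WebSocket` object (re-exported from Starlette) wraps the ASGI WebSocket message protocol. Seeing the raw messages helps when debugging: the app receives `websocket.connect`, replies `websocket.accept`, then exchanges `websocket.receive` / `websocket.send` events until `websocket.disconnect`. The sketch below is a framework-free echo app plus a tiny scripted driver standing in for a real server such as Uvicorn; `echo_app` and `drive` are hypothetical names, and the driver exists only to make the flow runnable.

```python
import asyncio


async def echo_app(scope, receive, send):
    """Bare ASGI WebSocket echo app, no framework involved."""
    assert scope["type"] == "websocket"
    while True:
        event = await receive()
        if event["type"] == "websocket.connect":
            await send({"type": "websocket.accept"})
        elif event["type"] == "websocket.receive":
            # Each frame carries either "text" or "bytes"
            if event.get("text") is not None:
                await send({"type": "websocket.send", "text": f"Echo: {event['text']}"})
            else:
                await send({"type": "websocket.send", "bytes": event["bytes"]})
        elif event["type"] == "websocket.disconnect":
            break


async def drive(incoming):
    """Feed scripted events to the app and collect what it sends (server stand-in)."""
    inbox: asyncio.Queue = asyncio.Queue()
    for event in incoming:
        inbox.put_nowait(event)
    sent = []

    async def send(message):
        sent.append(message)

    await echo_app({"type": "websocket", "path": "/ws"}, inbox.get, send)
    return sent


events = [
    {"type": "websocket.connect"},
    {"type": "websocket.receive", "text": "hi"},
    {"type": "websocket.disconnect", "code": 1000},
]
print(asyncio.run(drive(events)))
```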
Basic WebSocket Implementation
Simple Echo Server
Let's start with a basic WebSocket server that echoes messages back to the client:
```python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse

app = FastAPI()

# HTML client for testing
html = """
<!DOCTYPE html>
<html>
    <head>
        <title>WebSocket Test</title>
    </head>
    <body>
        <h1>WebSocket Test</h1>
        <input type="text" id="messageInput" placeholder="Enter message">
        <button onclick="sendMessage()">Send</button>
        <div id="messages"></div>
        <script>
            const ws = new WebSocket("ws://localhost:8000/ws");
            ws.onmessage = function(event) {
                // textContent (not innerHTML) so echoed input is never parsed as HTML
                const item = document.createElement("div");
                item.textContent = event.data;
                document.getElementById("messages").appendChild(item);
            };
            function sendMessage() {
                const input = document.getElementById("messageInput");
                ws.send(input.value);
                input.value = "";
            }
        </script>
    </body>
</html>
"""


@app.get("/")
async def get():
    return HTMLResponse(html)


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message received: {data}")
    except WebSocketDisconnect:
        print("Client disconnected")
```
Handling Different Message Types
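A WebSocket frame carries either text or binary data (`receive_text()` / `receive_bytes()`). JSON is simply text: parse it yourself with `json.loads` (or use `receive_json()`), then route on a field such as `type`: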
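```python
import json

from fastapi import WebSocket, WebSocketDisconnect

# `app` is the FastAPI instance created earlier


async def handle_chat_message(websocket: WebSocket, data: dict):
    # Placeholder: replace with your chat logic
    await websocket.send_text(json.dumps({"type": "chat", "message": data.get("message")}))


async def handle_notification(websocket: WebSocket, data: dict):
    # Placeholder: replace with your notification logic
    await websocket.send_text(json.dumps({"type": "notification_ack"}))


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            # JSON arrives as a text frame and must be parsed explicitly
            message = await websocket.receive_text()
            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                await websocket.send_text(json.dumps({
                    "type": "error",
                    "message": "Invalid JSON format"
                }))
                continue

            # Valid JSON may still be a list or a scalar; only objects carry a "type"
            message_type = data.get("type") if isinstance(data, dict) else None
            if message_type == "ping":
                await websocket.send_text(json.dumps({"type": "pong"}))
            elif message_type == "chat":
                await handle_chat_message(websocket, data)
            elif message_type == "notification":
                await handle_notification(websocket, data)
            else:
                await websocket.send_text(json.dumps({
                    "type": "error",
                    "message": "Unknown message type"
                }))
    except WebSocketDisconnect:
        print("Client disconnected")
```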
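As the number of message types grows, an `if`/`elif` chain gets unwieldy. A common alternative is a dictionary that maps each `type` to a handler coroutine, so adding a type means adding one entry. The sketch below uses hypothetical handler names and a minimal `FakeSocket` exposing only `send_text`, the same method FastAPI's `WebSocket` provides, so the routing can be exercised without a server.

```python
import asyncio
import json
from typing import Any, Awaitable, Callable, Dict, List


class FakeSocket:
    """Stand-in for fastapi.WebSocket that records outgoing text frames."""

    def __init__(self):
        self.sent: List[str] = []

    async def send_text(self, text: str) -> None:
        self.sent.append(text)


Handler = Callable[[Any, dict], Awaitable[None]]


async def handle_ping(ws, data: dict) -> None:
    await ws.send_text(json.dumps({"type": "pong"}))


async def handle_chat(ws, data: dict) -> None:
    # Hypothetical chat handler: echo the text back
    await ws.send_text(json.dumps({"type": "chat", "message": data.get("message", "")}))


HANDLERS: Dict[str, Handler] = {"ping": handle_ping, "chat": handle_chat}


async def dispatch(ws, raw: str) -> None:
    """Parse one text frame and route it by its "type" field."""
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        await ws.send_text(json.dumps({"type": "error", "message": "Invalid JSON format"}))
        return
    handler = HANDLERS.get(data.get("type")) if isinstance(data, dict) else None
    if handler is None:
        await ws.send_text(json.dumps({"type": "error", "message": "Unknown message type"}))
        return
    await handler(ws, data)


async def main() -> List[str]:
    ws = FakeSocket()
    for raw in ['{"type": "ping"}', '{"type": "chat", "message": "hi"}', "not json", '{"type": "x"}']:
        await dispatch(ws, raw)
    return ws.sent


sent = asyncio.run(main())
for frame in sent:
    print(frame)
```

Inside a FastAPI endpoint the loop body would become `await dispatch(websocket, await websocket.receive_text())`.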
Connection Management
Connection Manager Class
For applications with multiple clients, you need a connection manager to handle multiple WebSocket connections:
```python
from typing import Dict, List, Optional

from fastapi import WebSocket, WebSocketDisconnect
from starlette.websockets import WebSocketState


class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []
        self.user_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, user_id: Optional[str] = None):
        # Accept only while the handshake is pending, so endpoints that already
        # called accept() (e.g. to authenticate first) can reuse this method
        if websocket.client_state == WebSocketState.CONNECTING:
            await websocket.accept()
        self.active_connections.append(websocket)
        if user_id:
            self.user_connections[user_id] = websocket

    def disconnect(self, websocket: WebSocket, user_id: Optional[str] = None):
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        # Only drop the mapping if it still points at this socket (the user may have reconnected)
        if user_id and self.user_connections.get(user_id) is websocket:
            del self.user_connections[user_id]

    async def send_personal_message(self, message: str, user_id: str):
        websocket = self.user_connections.get(user_id)
        if websocket is not None:
            try:
                await websocket.send_text(message)
            except Exception:
                self.disconnect(websocket, user_id)

    async def broadcast(self, message: str):
        disconnected = []
        # Iterate over a copy: other tasks may connect or disconnect during the awaits
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except Exception:
                disconnected.append(connection)
        # Clean up dead connections, including any user mapping that points at them
        for connection in disconnected:
            self.disconnect(connection)
            for uid, ws in list(self.user_connections.items()):
                if ws is connection:
                    del self.user_connections[uid]


manager = ConnectionManager()


@app.websocket("/ws/{user_id}")
async def websocket_endpoint(websocket: WebSocket, user_id: str):
    await manager.connect(websocket, user_id)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(f"User {user_id}: {data}")
    except WebSocketDisconnect:
        manager.disconnect(websocket, user_id)
        await manager.broadcast(f"User {user_id} left the chat")
```
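`broadcast` above awaits each send in turn, so one slow client delays everyone after it. One option is to send concurrently with `asyncio.gather` and bound each send with `asyncio.wait_for`, treating a timeout like any other failure. A sketch with fake sockets; the `broadcast` function signature and the timeout value are assumptions to tune, not part of the manager above.

```python
import asyncio
from typing import List


class FakeSocket:
    """Stand-in socket that can be slow or broken."""

    def __init__(self, name: str, delay: float = 0.0, fail: bool = False):
        self.name, self.delay, self.fail = name, delay, fail
        self.received: List[str] = []

    async def send_text(self, text: str) -> None:
        await asyncio.sleep(self.delay)
        if self.fail:
            raise ConnectionError(f"{self.name} is gone")
        self.received.append(text)


async def broadcast(connections: list, message: str, timeout: float = 1.0) -> list:
    """Send to all connections concurrently; return the ones that failed or timed out."""
    results = await asyncio.gather(
        *(asyncio.wait_for(ws.send_text(message), timeout) for ws in connections),
        return_exceptions=True,  # one failure must not cancel the other sends
    )
    return [ws for ws, result in zip(connections, results) if isinstance(result, Exception)]


async def main():
    ok = FakeSocket("ok")
    slow = FakeSocket("slow", delay=5.0)
    broken = FakeSocket("broken", fail=True)
    dead = await broadcast([ok, slow, broken], "hello", timeout=0.1)
    return ok, [ws.name for ws in dead]


ok, dead_names = asyncio.run(main())
print(ok.received, dead_names)
```

The returned sockets are the ones you would then remove from `active_connections`.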
Advanced Connection Management
Production applications usually need more than a list of sockets: per-connection metadata, rooms, and a heartbeat that detects dead connections. A more complete manager:
```python
import asyncio
import json
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from fastapi import WebSocket


@dataclass
class ConnectionInfo:
    websocket: WebSocket
    user_id: str
    connected_at: float
    last_ping: float  # when the last heartbeat was successfully sent
    metadata: Dict[str, Any] = field(default_factory=dict)


class AdvancedConnectionManager:
    def __init__(self, heartbeat_interval: float = 30.0):
        self.connections: Dict[str, ConnectionInfo] = {}
        self.room_connections: Dict[str, List[str]] = {}
        self.heartbeat_interval = heartbeat_interval  # seconds
        # Keep references to background tasks so they are not garbage-collected
        self.heartbeat_tasks: Dict[str, asyncio.Task] = {}

    async def connect(self, websocket: WebSocket, user_id: str,
                      metadata: Optional[Dict[str, Any]] = None):
        await websocket.accept()
        now = time.time()
        self.connections[user_id] = ConnectionInfo(
            websocket=websocket,
            user_id=user_id,
            connected_at=now,
            last_ping=now,
            metadata=metadata or {}
        )
        # Start heartbeat task
        self.heartbeat_tasks[user_id] = asyncio.create_task(self.heartbeat_task(user_id))

    async def disconnect(self, user_id: str):
        self.connections.pop(user_id, None)
        task = self.heartbeat_tasks.pop(user_id, None)
        # Cancel the heartbeat, unless disconnect() is being called from it
        if task is not None and task is not asyncio.current_task():
            task.cancel()
        # Remove from all rooms
        for users in self.room_connections.values():
            if user_id in users:
                users.remove(user_id)

    async def join_room(self, user_id: str, room_id: str):
        users = self.room_connections.setdefault(room_id, [])
        if user_id not in users:
            users.append(user_id)

    async def leave_room(self, user_id: str, room_id: str):
        users = self.room_connections.get(room_id)
        if users and user_id in users:
            users.remove(user_id)

    async def send_to_room(self, room_id: str, message: str):
        # Copy the member list: a failed send calls disconnect(), which mutates it
        for user_id in list(self.room_connections.get(room_id, [])):
            await self.send_personal_message(message, user_id)

    async def send_personal_message(self, message: str, user_id: str):
        info = self.connections.get(user_id)
        if info is None:
            return
        try:
            await info.websocket.send_text(message)
        except Exception:
            await self.disconnect(user_id)

    async def heartbeat_task(self, user_id: str):
        while True:
            info = self.connections.get(user_id)
            if info is None:
                break  # connection was removed elsewhere
            try:
                await info.websocket.send_text(json.dumps({"type": "ping"}))
                info.last_ping = time.time()
            except Exception:
                await self.disconnect(user_id)
                break
            await asyncio.sleep(self.heartbeat_interval)
```
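The heartbeat above only notices a dead connection when a send fails, which can take a long time on a half-open TCP connection. A common complement is to have clients answer each ping with a pong, record when the last pong arrived, and drop connections that stay silent past a deadline. Below is a sketch of that bookkeeping; the `PongTracker` class, the `pong` message type, and the 60-second (2 × interval) deadline are assumptions, not part of the manager above. The injectable `clock` only exists to make it testable.

```python
import time
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class Liveness:
    last_pong: float


@dataclass
class PongTracker:
    """Tracks when each user last answered a ping (hypothetical helper)."""
    timeout: float = 60.0  # e.g. twice a 30 s heartbeat interval
    clock: Callable[[], float] = time.monotonic
    users: Dict[str, Liveness] = field(default_factory=dict)

    def register(self, user_id: str) -> None:
        self.users[user_id] = Liveness(last_pong=self.clock())

    def record_pong(self, user_id: str) -> None:
        # Call when a {"type": "pong"} message arrives from this user
        if user_id in self.users:
            self.users[user_id].last_pong = self.clock()

    def stale_users(self) -> List[str]:
        now = self.clock()
        return [uid for uid, info in self.users.items() if now - info.last_pong > self.timeout]


now = [0.0]
tracker = PongTracker(timeout=60.0, clock=lambda: now[0])
tracker.register("alice")
tracker.register("bob")
now[0] = 50.0
tracker.record_pong("alice")
now[0] = 100.0
print(tracker.stale_users())  # ['bob']
```

In `heartbeat_task`, you would check `stale_users()` each cycle and `await self.disconnect(...)` for every user it returns.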
WebSocket Authentication
Token-Based Authentication
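The browser `WebSocket` API cannot set custom headers such as `Authorization`, so the token is usually sent as a query parameter or in the first message. The query-parameter approach validates the token before accepting the connection: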
```python
import os
from typing import Any, Dict, Optional

from fastapi import WebSocket, WebSocketDisconnect, status
from jose import JWTError, jwt

# Load the signing key from the environment (the variable name is an example);
# never hard-code it in source
SECRET_KEY = os.environ.get("JWT_SECRET_KEY", "change-me")
ALGORITHM = "HS256"


def verify_token(token: str) -> Dict[str, Any]:
    """Decode and validate a JWT; raises ValueError if it is invalid or expired."""
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as exc:
        raise ValueError("Invalid token") from exc


# Method 1: Token in query parameter (ws://host/ws?token=...)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = None):
    try:
        payload = verify_token(token) if token else None
    except ValueError:
        payload = None
    user_id = payload.get("sub") if payload else None
    if not user_id:
        # Closing before accept() rejects the handshake
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    # manager is the ConnectionManager from "Connection Management"
    await manager.connect(websocket, user_id)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(f"User {user_id}: {data}")
    except WebSocketDisconnect:
        manager.disconnect(websocket, user_id)
```
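Query-string tokens are convenient but may be written to access logs by proxies and servers. One common mitigation is a short-lived, single-use ticket: an authenticated HTTP endpoint issues a random ticket, the client connects with `?ticket=...`, and the WebSocket endpoint redeems it exactly once. The in-memory `TicketStore` below is a hypothetical sketch; a multi-instance deployment would need a shared store such as Redis, and expired tickets that are never redeemed should be swept periodically.

```python
import secrets
import time
from typing import Callable, Dict, Optional, Tuple


class TicketStore:
    """Single-use, short-lived WebSocket tickets (in-memory sketch)."""

    def __init__(self, ttl_seconds: float = 30.0, clock: Callable[[], float] = time.monotonic):
        self.ttl = ttl_seconds
        self.clock = clock
        self._tickets: Dict[str, Tuple[str, float]] = {}  # ticket -> (user_id, expires_at)

    def issue(self, user_id: str) -> str:
        ticket = secrets.token_urlsafe(32)
        self._tickets[ticket] = (user_id, self.clock() + self.ttl)
        return ticket

    def redeem(self, ticket: str) -> Optional[str]:
        """Return the user ID the first time; repeated or expired redemptions return None."""
        entry = self._tickets.pop(ticket, None)  # pop makes the ticket single-use
        if entry is None:
            return None
        user_id, expires_at = entry
        return user_id if self.clock() <= expires_at else None


store = TicketStore()
ticket = store.issue("alice")
print(store.redeem(ticket), store.redeem(ticket))  # alice None
```

In the endpoint, `user_id = store.redeem(ticket)` would replace `verify_token(token)` before `accept()`.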
Advanced Authentication Patterns
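Tokens in query strings can end up in server and proxy access logs. The first-message pattern avoids this: accept the connection, then require an `auth` message within a timeout before doing anything else: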
```python
import asyncio
import json
from typing import Optional

from fastapi import WebSocket, WebSocketDisconnect, status


async def authenticate_websocket(websocket: WebSocket) -> Optional[str]:
    """Authenticate an accepted WebSocket via its first message; return the user ID or None."""
    try:
        # Wait for authentication message with timeout
        auth_message = await asyncio.wait_for(websocket.receive_text(), timeout=10.0)
        auth_data = json.loads(auth_message)
        if not isinstance(auth_data, dict) or auth_data.get("type") != "auth":
            await websocket.send_text(json.dumps({
                "type": "error",
                "message": "Authentication required"
            }))
            return None
        payload = verify_token(auth_data.get("token", ""))
        user_id = payload.get("sub")
        if not user_id:
            raise ValueError("Token has no subject")
        await websocket.send_text(json.dumps({
            "type": "auth_success",
            "user_id": user_id
        }))
        return user_id
    except asyncio.TimeoutError:
        await websocket.send_text(json.dumps({
            "type": "error",
            "message": "Authentication timeout"
        }))
        return None
    except WebSocketDisconnect:
        return None  # client left before authenticating
    except Exception:
        # Bad JSON, invalid or expired token, etc.; don't leak details to the client
        await websocket.send_text(json.dumps({
            "type": "error",
            "message": "Authentication failed"
        }))
        return None


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    user_id = await authenticate_websocket(websocket)
    if not user_id:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    # manager is the ConnectionManager from "Connection Management"
    await manager.connect(websocket, user_id)
    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                await websocket.send_text(json.dumps({"type": "error", "message": "Invalid JSON format"}))
                continue
            if not isinstance(message, dict):
                continue
            if message.get("type") == "chat":
                # handle_chat_message is a placeholder for your own chat handler
                await handle_chat_message(user_id, message)
            elif message.get("type") == "heartbeat":
                await websocket.send_text(json.dumps({"type": "heartbeat_ack"}))
    except WebSocketDisconnect:
        manager.disconnect(websocket, user_id)
```
Real-Time Features Implementation
Chat Application
Building a complete chat application with rooms and user management:
```python
import json
from datetime import datetime, timezone
from typing import Dict, List, Optional

from fastapi import WebSocket, WebSocketDisconnect, status

# Reuses ConnectionManager and authenticate_websocket from earlier sections


class ChatRoom:
    def __init__(self, room_id: str, name: str):
        self.room_id = room_id
        self.name = name
        self.users: List[str] = []
        self.messages: List[Dict] = []  # unbounded: cap or persist in production
        self.created_at = datetime.now(timezone.utc)

    def add_user(self, user_id: str):
        if user_id not in self.users:
            self.users.append(user_id)

    def remove_user(self, user_id: str):
        if user_id in self.users:
            self.users.remove(user_id)

    def add_message(self, user_id: str, message: str) -> Dict:
        message_data = {
            "id": len(self.messages) + 1,
            "user_id": user_id,
            "message": message,
            "timestamp": datetime.now(timezone.utc).isoformat()
        }
        self.messages.append(message_data)
        return message_data


class ChatManager:
    def __init__(self):
        self.rooms: Dict[str, ChatRoom] = {}
        self.connection_manager = ConnectionManager()

    def create_room(self, room_id: str, name: str) -> ChatRoom:
        self.rooms[room_id] = ChatRoom(room_id, name)
        return self.rooms[room_id]

    def get_room(self, room_id: str) -> Optional[ChatRoom]:
        return self.rooms.get(room_id)

    async def join_room(self, user_id: str, room_id: str) -> Optional[ChatRoom]:
        room = self.get_room(room_id)
        if room is None:
            return None
        room.add_user(user_id)
        # Notify everyone in the room, including the new member
        await self.broadcast_to_room(room_id, {
            "type": "user_joined",
            "user_id": user_id,
            "room_id": room_id
        })
        return room

    async def leave_room(self, user_id: str, room_id: str):
        room = self.get_room(room_id)
        if room:
            room.remove_user(user_id)
            # Notify the remaining users
            await self.broadcast_to_room(room_id, {
                "type": "user_left",
                "user_id": user_id,
                "room_id": room_id
            })

    async def send_message(self, user_id: str, room_id: str, message: str) -> Optional[Dict]:
        room = self.get_room(room_id)
        if room and user_id in room.users:
            message_data = room.add_message(user_id, message)
            await self.broadcast_to_room(room_id, {
                "type": "message",
                "data": message_data
            })
            return message_data
        return None

    async def broadcast_to_room(self, room_id: str, message: Dict):
        room = self.get_room(room_id)
        if room:
            payload = json.dumps(message)
            # Iterate over a copy in case membership changes during the awaits
            for user_id in list(room.users):
                await self.connection_manager.send_personal_message(payload, user_id)


chat_manager = ChatManager()
# Rooms must exist before clients can join; create them at startup or via an API
chat_manager.create_room("general", "General")


@app.websocket("/ws/chat/{room_id}")
async def chat_websocket(websocket: WebSocket, room_id: str):
    await websocket.accept()
    user_id = await authenticate_websocket(websocket)
    if not user_id or chat_manager.get_room(room_id) is None:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await chat_manager.connection_manager.connect(websocket, user_id)
    await chat_manager.join_room(user_id, room_id)
    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                continue  # ignore malformed frames
            if isinstance(message, dict) and message.get("type") == "chat_message":
                await chat_manager.send_message(user_id, room_id, str(message.get("message", "")))
    except WebSocketDisconnect:
        await chat_manager.leave_room(user_id, room_id)
        chat_manager.connection_manager.disconnect(websocket, user_id)
```
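`ChatRoom.messages` grows without limit, so a busy room will eventually exhaust memory. If you only need recent history in memory (persisting the rest elsewhere), `collections.deque` with `maxlen` discards the oldest entry automatically. A sketch; the `BoundedHistory` name and the 100-message cap are illustrative assumptions. It uses a separate counter for IDs because `len(messages) + 1` would start repeating once old entries are evicted.

```python
from collections import deque
from datetime import datetime, timezone
from itertools import count
from typing import Deque, Dict, List


class BoundedHistory:
    def __init__(self, max_messages: int = 100):
        self.messages: Deque[Dict] = deque(maxlen=max_messages)
        self._ids = count(1)  # monotonically increasing IDs, unaffected by eviction

    def add_message(self, user_id: str, message: str) -> Dict:
        message_data = {
            "id": next(self._ids),
            "user_id": user_id,
            "message": message,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        self.messages.append(message_data)  # evicts the oldest entry when full
        return message_data

    def recent(self, n: int = 20) -> List[Dict]:
        """Last n messages, e.g. to send to a user who just joined."""
        return list(self.messages)[-n:] if n > 0 else []


history = BoundedHistory(max_messages=3)
for i in range(5):
    history.add_message("alice", f"msg {i}")
print([m["id"] for m in history.messages])  # [3, 4, 5]
```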
Live Notifications System
Implementing a real-time notification system:
```python
import json
from datetime import datetime, timezone
from enum import Enum
from typing import Dict, List, Optional

from fastapi import WebSocket, WebSocketDisconnect


class NotificationType(str, Enum):
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    SUCCESS = "success"


class NotificationManager:
    def __init__(self):
        self.connection_manager = ConnectionManager()
        self.user_preferences: Dict[str, Dict] = {}  # reserved for per-user settings

    async def send_notification(
        self,
        user_id: str,
        title: str,
        message: str,
        notification_type: NotificationType = NotificationType.INFO,
        metadata: Optional[Dict] = None
    ):
        notification = {
            "type": "notification",
            "data": {
                "title": title,
                "message": message,
                "notification_type": notification_type.value,
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "metadata": metadata or {}
            }
        }
        await self.connection_manager.send_personal_message(
            json.dumps(notification), user_id
        )

    async def broadcast_notification(
        self,
        title: str,
        message: str,
        notification_type: NotificationType = NotificationType.INFO,
        exclude_users: Optional[List[str]] = None
    ):
        notification = {
            "type": "broadcast_notification",
            "data": {
                "title": title,
                "message": message,
                "notification_type": notification_type.value,
                "timestamp": datetime.now(timezone.utc).isoformat()
            }
        }
        message_text = json.dumps(notification)
        excluded = set(exclude_users or [])
        # Snapshot the keys: failed sends may remove entries while we iterate
        for user_id in list(self.connection_manager.user_connections):
            if user_id in excluded:
                continue
            await self.connection_manager.send_personal_message(message_text, user_id)


notification_manager = NotificationManager()


# Clients connect here to receive notifications (authenticate as shown earlier in production)
@app.websocket("/ws/notifications/{user_id}")
async def notifications_websocket(websocket: WebSocket, user_id: str):
    await notification_manager.connection_manager.connect(websocket, user_id)
    try:
        while True:
            await websocket.receive_text()  # keep the socket open; client messages are ignored
    except WebSocketDisconnect:
        notification_manager.connection_manager.disconnect(websocket, user_id)


# API endpoints to send notifications (parameters are read from the query string)
@app.post("/api/notifications/send")
async def send_notification(
    user_id: str,
    title: str,
    message: str,
    notification_type: NotificationType = NotificationType.INFO
):
    await notification_manager.send_notification(
        user_id, title, message, notification_type
    )
    return {"status": "sent"}


@app.post("/api/notifications/broadcast")
async def broadcast_notification(
    title: str,
    message: str,
    notification_type: NotificationType = NotificationType.INFO
):
    await notification_manager.broadcast_notification(
        title, message, notification_type
    )
    return {"status": "broadcasted"}
```
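`ConnectionManager.user_connections` stores one socket per user, so a user with two browser tabs only receives notifications in the tab that connected last. Mapping each user to all of their sockets fixes that. A sketch with fake sockets; `MultiConnectionRegistry` is a hypothetical name, and it keys sockets by `id()` so that it relies only on object identity.

```python
import asyncio
import json
from collections import defaultdict
from typing import Any, Dict, List


class MultiConnectionRegistry:
    """Tracks every open socket per user (hypothetical helper)."""

    def __init__(self):
        self._sockets: Dict[str, Dict[int, Any]] = defaultdict(dict)

    def add(self, user_id: str, websocket: Any) -> None:
        self._sockets[user_id][id(websocket)] = websocket

    def remove(self, user_id: str, websocket: Any) -> None:
        sockets = self._sockets.get(user_id)
        if sockets is not None:
            sockets.pop(id(websocket), None)
            if not sockets:
                del self._sockets[user_id]

    async def send_to_user(self, user_id: str, message: str) -> int:
        """Send to all of the user's sockets, dropping failed ones; return deliveries."""
        delivered = 0
        for websocket in list(self._sockets.get(user_id, {}).values()):
            try:
                await websocket.send_text(message)
                delivered += 1
            except Exception:
                self.remove(user_id, websocket)
        return delivered


class FakeSocket:
    def __init__(self, fail: bool = False):
        self.fail = fail
        self.sent: List[str] = []

    async def send_text(self, text: str) -> None:
        if self.fail:
            raise ConnectionError("closed")
        self.sent.append(text)


registry = MultiConnectionRegistry()
tab1, tab2, stale = FakeSocket(), FakeSocket(), FakeSocket(fail=True)
for ws in (tab1, tab2, stale):
    registry.add("alice", ws)
delivered = asyncio.run(registry.send_to_user("alice", json.dumps({"type": "notification"})))
print(delivered)  # 2
```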
Performance Optimization
Connection Limits and Message Queuing
```python
import asyncio
import time
from typing import Dict, List

from fastapi import WebSocket, status

# ConnectionInfo is the dataclass from "Advanced Connection Management"


class OptimizedConnectionManager:
    def __init__(self, max_connections: int = 1000, num_workers: int = 10):
        self.max_connections = max_connections
        self.num_workers = num_workers
        self.connections: Dict[str, ConnectionInfo] = {}
        self.message_queue: asyncio.Queue = asyncio.Queue()
        self.worker_tasks: List[asyncio.Task] = []

    def start(self):
        """Start the message workers. Call from a startup/lifespan handler:
        asyncio.create_task() needs a running event loop, so this cannot run at import time."""
        for _ in range(self.num_workers):
            self.worker_tasks.append(asyncio.create_task(self.message_worker()))

    async def stop(self):
        for task in self.worker_tasks:
            task.cancel()
        await asyncio.gather(*self.worker_tasks, return_exceptions=True)
        self.worker_tasks.clear()

    async def connect(self, websocket: WebSocket, user_id: str) -> bool:
        # Per-process cap; concurrent handshakes may overshoot it slightly
        if len(self.connections) >= self.max_connections:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            return False
        await websocket.accept()
        now = time.time()
        self.connections[user_id] = ConnectionInfo(
            websocket=websocket,
            user_id=user_id,
            connected_at=now,
            last_ping=now,
            metadata={}
        )
        return True

    def disconnect(self, user_id: str):
        self.connections.pop(user_id, None)

    async def message_worker(self):
        # Note: with several workers, messages to the same user may arrive out of order
        while True:
            message_data = await self.message_queue.get()
            try:
                await self.process_message(message_data)
            except Exception as e:
                print(f"Message worker error: {e}")
            finally:
                self.message_queue.task_done()

    async def queue_message(self, user_id: str, message: str):
        await self.message_queue.put({
            "user_id": user_id,
            "message": message,
            "timestamp": time.time()
        })

    async def process_message(self, message_data: Dict):
        user_id = message_data["user_id"]
        info = self.connections.get(user_id)
        if info is None:
            return
        try:
            await info.websocket.send_text(message_data["message"])
        except Exception:
            self.disconnect(user_id)
```
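A shared queue smooths bursts, but it does not protect against a client that reads slowly: messages for it keep piling up. A bounded queue per connection, with an explicit overflow policy, is a common alternative. The sketch below drops the oldest pending message when a client's queue is full; dropping versus disconnecting the slow client is a design choice, and `ClientOutbox` is a hypothetical name.

```python
import asyncio
from typing import List


class ClientOutbox:
    """Bounded per-connection send queue that drops the oldest message on overflow."""

    def __init__(self, maxsize: int = 100):
        self.queue: asyncio.Queue = asyncio.Queue(maxsize=maxsize)
        self.dropped = 0

    def push(self, message: str) -> None:
        if self.queue.full():
            self.queue.get_nowait()  # discard the oldest pending message
            self.queue.task_done()
            self.dropped += 1
        self.queue.put_nowait(message)

    async def pump(self, websocket) -> None:
        """Run as one task per connection: forward queued messages in order."""
        while True:
            message = await self.queue.get()
            try:
                await websocket.send_text(message)
            finally:
                self.queue.task_done()


class FakeSocket:
    def __init__(self):
        self.sent: List[str] = []

    async def send_text(self, text: str) -> None:
        self.sent.append(text)


async def demo():
    outbox = ClientOutbox(maxsize=3)
    for i in range(5):
        outbox.push(f"m{i}")  # m0 and m1 get dropped
    ws = FakeSocket()
    pump = asyncio.create_task(outbox.pump(ws))
    await outbox.queue.join()  # wait until every message was sent or dropped
    pump.cancel()
    await asyncio.gather(pump, return_exceptions=True)
    return ws.sent, outbox.dropped


print(asyncio.run(demo()))
```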
Message Rate Limiting
```python
import json
import time
from collections import defaultdict
from typing import Dict, List

from fastapi import WebSocket, WebSocketDisconnect, status


class RateLimiter:
    """Sliding window: at most max_messages per user within window_seconds."""

    def __init__(self, max_messages: int = 10, window_seconds: int = 60):
        self.max_messages = max_messages
        self.window_seconds = window_seconds
        self.user_messages: Dict[str, List[float]] = defaultdict(list)

    def is_rate_limited(self, user_id: str) -> bool:
        now = time.monotonic()  # monotonic clock: unaffected by system time changes
        # Remove old messages outside the window
        self.user_messages[user_id] = [
            msg_time for msg_time in self.user_messages[user_id]
            if now - msg_time < self.window_seconds
        ]
        # Check if user has exceeded rate limit
        if len(self.user_messages[user_id]) >= self.max_messages:
            return True
        # Record the current message
        self.user_messages[user_id].append(now)
        return False


rate_limiter = RateLimiter()


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    user_id = await authenticate_websocket(websocket)
    if not user_id:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return
    try:
        while True:
            data = await websocket.receive_text()
            if rate_limiter.is_rate_limited(user_id):
                await websocket.send_text(json.dumps({
                    "type": "error",
                    "message": "Rate limit exceeded"
                }))
                continue
            # handle_message is a placeholder for your own message handler
            await handle_message(user_id, data)
    except WebSocketDisconnect:
        pass
```
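The sliding-window limiter stores one timestamp per message, so memory per user grows with `max_messages`. A token bucket (a swapped-in technique, not what the code above implements) needs only two numbers per user and tolerates short bursts: tokens refill continuously at `rate` per second up to `capacity`, and each message spends one. It keeps the same `is_rate_limited(user_id)` method, so it could replace `RateLimiter` in the endpoint; the injectable `clock` exists only to make the sketch testable.

```python
import time
from typing import Callable, Dict, Tuple


class TokenBucketLimiter:
    def __init__(self, rate: float, capacity: float, clock: Callable[[], float] = time.monotonic):
        self.rate = rate          # tokens added per second
        self.capacity = capacity  # maximum burst size
        self.clock = clock
        self._buckets: Dict[str, Tuple[float, float]] = {}  # user_id -> (tokens, last_refill)

    def is_rate_limited(self, user_id: str) -> bool:
        now = self.clock()
        tokens, last = self._buckets.get(user_id, (self.capacity, now))
        # Refill for the time elapsed since the last check, capped at capacity
        tokens = min(self.capacity, tokens + (now - last) * self.rate)
        if tokens < 1:
            self._buckets[user_id] = (tokens, now)
            return True
        self._buckets[user_id] = (tokens - 1, now)  # spend one token
        return False


# Roughly comparable to "10 messages per 60 s", but allowing a burst of 10
limiter = TokenBucketLimiter(rate=10 / 60, capacity=10)
print(limiter.is_rate_limited("alice"))  # False
```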
Production Deployment
Docker Configuration
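```dockerfile
FROM python:3.11-slim
WORKDIR /app
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt
COPY . .
EXPOSE 8000
# Each worker is a separate process with its own in-memory connection manager:
# broadcasts only reach clients on the same worker unless you fan out through
# a shared channel such as Redis pub/sub
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
```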
Nginx Configuration
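```nginx
upstream websocket_backend {
    # A WebSocket stays on the backend that accepted it; cross-instance
    # broadcasts still need a shared channel such as Redis
    server app1:8000;
    server app2:8000;
    server app3:8000;
}

server {
    listen 80;
    server_name your-domain.com;

    location /ws {
        proxy_pass http://websocket_backend;
        # The upgrade requires HTTP/1.1 and forwarding the Upgrade/Connection headers
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        # Close connections idle for longer than 24 hours
        proxy_read_timeout 86400;
    }
}
```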
Kubernetes Deployment
```yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: websocket-app
spec:
  replicas: 3
  selector:
    matchLabels:
      app: websocket-app
  template:
    metadata:
      labels:
        app: websocket-app
    spec:
      containers:
        - name: app
          image: your-registry/websocket-app:latest
          ports:
            - containerPort: 8000
          env:
            - name: REDIS_URL
              value: "redis://redis-service:6379"
          resources:
            requests:
              memory: "256Mi"
              cpu: "100m"
            limits:
              memory: "512Mi"
              cpu: "500m"
---
apiVersion: v1
kind: Service
metadata:
  name: websocket-service
spec:
  selector:
    app: websocket-app
  ports:
    - port: 8000
      targetPort: 8000
  type: ClusterIP
```
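With several replicas (or several Uvicorn workers), each process has its own in-memory connection manager, so a broadcast only reaches clients connected to that process. The usual fix is to publish every broadcast to a shared channel (for example Redis pub/sub, which the `REDIS_URL` setting above suggests) and have each instance relay what it receives to its own local clients. The sketch below shows that flow with a hypothetical `InMemoryBroker` standing in for Redis so it runs without a server; with a real Redis client you would swap in its publish/subscribe calls.

```python
import asyncio
from typing import Dict, List


class InMemoryBroker:
    """Stand-in for Redis pub/sub: every subscriber receives every published message."""

    def __init__(self):
        self._subscribers: Dict[str, List[asyncio.Queue]] = {}

    async def publish(self, channel: str, message: str) -> None:
        for queue in self._subscribers.get(channel, []):
            queue.put_nowait(message)

    def subscribe(self, channel: str) -> asyncio.Queue:
        queue: asyncio.Queue = asyncio.Queue()
        self._subscribers.setdefault(channel, []).append(queue)
        return queue


class Instance:
    """One app process: its local clients plus a relay task fed by the broker."""

    def __init__(self, broker: InMemoryBroker):
        self.broker = broker
        self.local_inbox: List[str] = []  # stands in for this process's WebSockets
        self.subscription = broker.subscribe("broadcast")

    async def broadcast(self, message: str) -> None:
        # Publish instead of sending locally; every instance (this one included) relays it
        await self.broker.publish("broadcast", message)

    async def relay(self) -> None:
        while True:
            message = await self.subscription.get()
            self.local_inbox.append(message)  # real code: await manager.broadcast(message)


async def main():
    broker = InMemoryBroker()
    a, b = Instance(broker), Instance(broker)
    relays = [asyncio.create_task(inst.relay()) for inst in (a, b)]
    await a.broadcast("hello from a")
    await asyncio.sleep(0.01)  # let the relay tasks run
    for task in relays:
        task.cancel()
    await asyncio.gather(*relays, return_exceptions=True)
    return a.local_inbox, b.local_inbox


print(asyncio.run(main()))
```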
Testing WebSocket Endpoints
Unit Testing
```python
from unittest.mock import AsyncMock

import pytest
from fastapi.testclient import TestClient
from fastapi.websockets import WebSocketDisconnect

# Assumes the app lives in main.py, as in the Dockerfile's "main:app"
from main import ConnectionManager, app


@pytest.fixture
def client():
    return TestClient(app)


def test_websocket_connection(client):
    # Assumes /ws is the echo endpoint
    with client.websocket_connect("/ws") as websocket:
        websocket.send_text("Hello")
        data = websocket.receive_text()
        assert "Hello" in data


def test_websocket_authentication(client):
    # Assumes /ws is the query-token endpoint, which closes the handshake
    # when no token is supplied; TestClient then raises WebSocketDisconnect
    with pytest.raises(WebSocketDisconnect):
        with client.websocket_connect("/ws"):
            pass


@pytest.mark.asyncio  # requires the pytest-asyncio plugin
async def test_connection_manager():
    manager = ConnectionManager()
    mock_websocket = AsyncMock()
    await manager.connect(mock_websocket, "user1")
    assert "user1" in manager.user_connections
    await manager.send_personal_message("test", "user1")
    mock_websocket.send_text.assert_awaited_once_with("test")
```
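`TestClient` exercises the full ASGI stack. For logic-heavy handlers it can also help to write the loop against a small interface and drive it with a scripted fake, which needs nothing beyond `asyncio.run`. The sketch below is hypothetical: `ScriptedSocket` mimics only `receive_text`/`send_text`, and it raises a local `Disconnect` exception where FastAPI would raise `WebSocketDisconnect`.

```python
import asyncio
from typing import List


class Disconnect(Exception):
    """Local stand-in for fastapi.WebSocketDisconnect."""


class ScriptedSocket:
    def __init__(self, incoming: List[str]):
        self.incoming = list(incoming)
        self.sent: List[str] = []

    async def receive_text(self) -> str:
        if not self.incoming:
            raise Disconnect()  # script exhausted: simulate the client leaving
        return self.incoming.pop(0)

    async def send_text(self, text: str) -> None:
        self.sent.append(text)


async def echo_loop(websocket) -> None:
    """Handler body under test: same shape as the echo endpoint, minus accept()."""
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message received: {data}")
    except Disconnect:
        pass


fake = ScriptedSocket(["one", "two"])
asyncio.run(echo_loop(fake))
print(fake.sent)  # ['Message received: one', 'Message received: two']
```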
Integration Testing
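```python
import asyncio
import json

import websockets

# Assumes the server runs locally, the chat endpoint is mounted at
# /ws/chat/{room_id} with a "general" room, and VALID_TOKEN is a real JWT
VALID_TOKEN = "valid-token"


# Named without a test_ prefix so pytest doesn't collect it as an unmarked async test
async def run_websocket_integration():
    uri = "ws://localhost:8000/ws/chat/general"
    async with websockets.connect(uri) as websocket:
        # Send authentication
        await websocket.send(json.dumps({
            "type": "auth",
            "token": VALID_TOKEN
        }))
        # Receive auth response
        auth_data = json.loads(await websocket.recv())
        assert auth_data["type"] == "auth_success"

        # Send chat message
        await websocket.send(json.dumps({
            "type": "chat_message",
            "message": "Hello World"
        }))

        # Skip other events (e.g. our own "user_joined") until the chat message arrives
        while True:
            message_data = json.loads(await asyncio.wait_for(websocket.recv(), timeout=5))
            if message_data["type"] == "message":
                break
        assert message_data["data"]["message"] == "Hello World"


if __name__ == "__main__":
    asyncio.run(run_websocket_integration())
```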
Monitoring and Observability
Metrics Collection
```python
import time
from typing import Dict, Optional

from fastapi import WebSocket
from prometheus_client import Counter, Gauge, Histogram

# Metrics
websocket_connections_total = Counter(
    'websocket_connections_total',
    'Total WebSocket connections',
    ['status']
)
websocket_messages_total = Counter(
    'websocket_messages_total',
    'Total WebSocket messages',
    ['type', 'status']
)  # increment this from your receive loop
websocket_connection_duration = Histogram(
    'websocket_connection_duration_seconds',
    'WebSocket connection duration'
)
active_connections = Gauge(
    'websocket_active_connections',
    'Active WebSocket connections'
)


class MetricsConnectionManager(ConnectionManager):
    def __init__(self):
        super().__init__()
        # Start times keyed by id(websocket), used for the duration histogram
        self.connected_at: Dict[int, float] = {}

    async def connect(self, websocket: WebSocket, user_id: Optional[str] = None):
        await super().connect(websocket, user_id)
        self.connected_at[id(websocket)] = time.monotonic()
        websocket_connections_total.labels(status='connected').inc()
        active_connections.set(len(self.active_connections))

    def disconnect(self, websocket: WebSocket, user_id: Optional[str] = None):
        super().disconnect(websocket, user_id)
        started = self.connected_at.pop(id(websocket), None)
        if started is not None:
            websocket_connection_duration.observe(time.monotonic() - started)
        websocket_connections_total.labels(status='disconnected').inc()
        active_connections.set(len(self.active_connections))
```
Health Check Endpoint
```python
import time

from fastapi import Response
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest

START_TIME = time.time()  # recorded when the module is imported


@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "active_connections": len(manager.active_connections),  # this process only
        "uptime": time.time() - START_TIME,
        "version": "1.0.0"
    }


@app.get("/metrics")
async def get_metrics():
    return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
```
Best Practices Summary
Security
- Always authenticate WebSocket connections
- Use HTTPS/WSS in production
- Implement rate limiting
- Validate all incoming messages
- Use secure token-based authentication
Performance
- Cap concurrent connections per process
- Use message queues for high-traffic scenarios
- Monitor connection limits
- Implement heartbeat mechanisms
- Use compact serialization (e.g. MessagePack in binary frames) if JSON becomes a bottleneck
Scalability
- Use Redis for multi-instance deployments
- Implement horizontal scaling with load balancers
- Use sticky sessions when needed
- Monitor resource usage
- Implement graceful shutdowns
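For graceful shutdowns, one approach is to close every open connection with close code 1001 ("going away" in RFC 6455) when the app stops, so clients know to reconnect, ideally to another instance. In FastAPI this would run in the shutdown phase of a lifespan handler, e.g. over `manager.active_connections`. The sketch uses fake sockets whose `close(code=..., reason=...)` mirrors the parameters of FastAPI's `WebSocket.close`; `close_all` is a hypothetical helper.

```python
import asyncio
from typing import List, Optional

GOING_AWAY = 1001  # RFC 6455 close code: the endpoint is going away


class FakeSocket:
    def __init__(self, fail: bool = False):
        self.fail = fail
        self.closed_with: Optional[int] = None

    async def close(self, code: int = 1000, reason: Optional[str] = None) -> None:
        if self.fail:
            raise RuntimeError("already closed")
        self.closed_with = code


async def close_all(connections: List, timeout: float = 5.0) -> int:
    """Close every connection concurrently; return how many closed cleanly."""
    results = await asyncio.gather(
        *(asyncio.wait_for(ws.close(code=GOING_AWAY), timeout) for ws in connections),
        return_exceptions=True,  # a socket that is already gone must not abort shutdown
    )
    connections.clear()
    return sum(1 for result in results if not isinstance(result, BaseException))


a, b, c = FakeSocket(), FakeSocket(), FakeSocket(fail=True)
conns = [a, b, c]
print(asyncio.run(close_all(conns)))  # 2
```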
Error Handling
- Handle WebSocket disconnections gracefully
- Implement retry mechanisms
- Log all errors appropriately
- Use structured error responses
- Implement circuit breakers for external services
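Retry logic matters most on the client side: when a server restarts, thousands of clients reconnecting at the same instant can overload it again. Exponential backoff with jitter spreads those reconnects out. The generator below uses "full jitter" (a random delay between 0 and a capped exponential bound); the `base` and `cap` values and the `backoff_delays` name are assumptions to tune, not part of any library.

```python
import random
from typing import Iterator, Optional


def backoff_delays(base: float = 0.5, cap: float = 30.0,
                   rng: Optional[random.Random] = None) -> Iterator[float]:
    """Yield reconnect delays drawn uniformly from [0, min(cap, base * 2**attempt)]."""
    rng = rng or random.Random()
    attempt = 0
    while True:
        bound = min(cap, base * (2 ** attempt))
        yield rng.uniform(0, bound)
        attempt = min(attempt + 1, 30)  # stop growing the exponent once far past the cap


delays = backoff_delays(rng=random.Random(42))
print([round(next(delays), 2) for _ in range(5)])
```

A client would sleep for `next(delays)` after each failed connection attempt and create a fresh generator once a connection succeeds.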
Conclusion
FastAPI provides excellent WebSocket support for building real-time applications. By following the patterns and best practices outlined in this guide, you can create scalable, secure, and performant WebSocket applications.
Key takeaways:
- Use proper connection management for multiple clients
- Implement secure authentication patterns
- Design for scalability with Redis and load balancing
- Monitor and test your WebSocket endpoints
- Follow security best practices for production deployment
WebSockets enable powerful real-time features that enhance user experience and engagement. With FastAPI's async capabilities and the patterns shown in this guide, you can build robust real-time applications that scale with your needs.
This guide covers FastAPI WebSocket implementation as of 2025. Always refer to the latest FastAPI documentation for the most current features and best practices.