FastAPI WebSocket Guide: Build Real-Time Applications with Python

WebSockets enable real-time, bidirectional communication between clients and servers, which makes them a natural fit for features that push data to users as it changes. FastAPI supports WebSockets natively on top of its asynchronous architecture, so you can build scalable real-time features such as chat applications, live notifications, and collaborative tools.

This comprehensive guide covers everything you need to know about implementing WebSockets in FastAPI, from basic concepts to production deployment strategies.

Understanding WebSockets in FastAPI

What are WebSockets?

WebSockets provide a persistent, full-duplex communication channel between clients and servers. Unlike traditional HTTP requests that follow a request-response pattern, WebSockets maintain an open connection that allows both parties to send messages at any time.

Key advantages:

  • Real-time communication: Instant message exchange
  • Low latency: No per-message HTTP request overhead once the initial handshake completes
  • Bidirectional: Both client and server can initiate communication
  • Efficient: Reduced server load compared to polling
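The "no overhead after handshake" point can be made concrete. Per RFC 6455, the server completes the HTTP Upgrade by hashing the client's Sec-WebSocket-Key together with a fixed GUID. After that, each message travels in a frame whose header is only 2 to 14 bytes, whereas an HTTP poll repeats the request line and headers every time. Uvicorn and Starlette handle all of this for you. The helper names below are illustrative and are not part of any library:

```python
import base64
import hashlib

# Fixed GUID defined by RFC 6455, section 1.3
WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


def accept_key(sec_websocket_key: str) -> str:
    """Compute the Sec-WebSocket-Accept header value the server returns."""
    digest = hashlib.sha1((sec_websocket_key + WEBSOCKET_GUID).encode("ascii")).digest()
    return base64.b64encode(digest).decode("ascii")


def frame_header_size(payload_length: int, masked: bool) -> int:
    """Bytes of framing added to one message (client-to-server frames are masked)."""
    if payload_length <= 125:
        size = 2  # length fits in the 7-bit field
    elif payload_length <= 0xFFFF:
        size = 4  # 16-bit extended length
    else:
        size = 10  # 64-bit extended length
    return size + (4 if masked else 0)  # plus the 4-byte masking key


# Example key taken from RFC 6455
print(accept_key("dGhlIHNhbXBsZSBub25jZQ=="))  # s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
print(frame_header_size(11, masked=False))  # 2
```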

FastAPI WebSocket Features

FastAPI's WebSocket support comes from Starlette and runs on ASGI (Asynchronous Server Gateway Interface), so each connection is handled by a coroutine instead of tying up a thread:

from fastapi import FastAPI, WebSocket
from fastapi.websockets import WebSocketDisconnect

app = FastAPI()

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Echo: {data}")
    except WebSocketDisconnect:
        print("Client disconnected")
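Under the hood, FastAPI (via Starlette) exchanges ASGI event dictionaries with the server. A bare ASGI app that performs the same echo makes the protocol visible. The event types follow the ASGI WebSocket specification. The scripted "drive" harness is only a stand-in for a real server such as uvicorn, used here to show which events flow in each direction:

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

Message = Dict[str, Any]


async def echo_app(scope: Message,
                   receive: Callable[[], Awaitable[Message]],
                   send: Callable[[Message], Awaitable[None]]) -> None:
    """Minimal ASGI WebSocket echo app with no framework involved."""
    assert scope["type"] == "websocket"
    while True:
        event = await receive()
        if event["type"] == "websocket.connect":
            # Accepting completes the handshake
            await send({"type": "websocket.accept"})
        elif event["type"] == "websocket.receive":
            await send({"type": "websocket.send", "text": f"Echo: {event.get('text')}"})
        elif event["type"] == "websocket.disconnect":
            return


async def drive(inbound: List[Message]) -> List[Message]:
    """Feed scripted client events to the app and collect what it sends back."""
    queue: asyncio.Queue = asyncio.Queue()
    for event in inbound:
        queue.put_nowait(event)
    sent: List[Message] = []

    async def send(message: Message) -> None:
        sent.append(message)

    await echo_app({"type": "websocket", "path": "/ws"}, queue.get, send)
    return sent


outbound = asyncio.run(drive([
    {"type": "websocket.connect"},
    {"type": "websocket.receive", "text": "hi"},
    {"type": "websocket.disconnect", "code": 1000},
]))
print(outbound)  # [{'type': 'websocket.accept'}, {'type': 'websocket.send', 'text': 'Echo: hi'}]
```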

Basic WebSocket Implementation

Simple Echo Server

Let's start with a basic WebSocket server that echoes messages back to the client:

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse

app = FastAPI()

# HTML client for testing
html = """
<!DOCTYPE html>
<html>
<head>
    <title>WebSocket Test</title>
</head>
<body>
    <h1>WebSocket Test</h1>
    <input type="text" id="messageInput" placeholder="Enter message">
    <button onclick="sendMessage()">Send</button>
    <div id="messages"></div>

    <script>
        const ws = new WebSocket("ws://localhost:8000/ws");

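        ws.onmessage = function(event) {
            // textContent shows server data as plain text (innerHTML would allow XSS)
            const item = document.createElement("div");
            item.textContent = event.data;
            document.getElementById("messages").appendChild(item);
        };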

        function sendMessage() {
            const input = document.getElementById("messageInput");
            ws.send(input.value);
            input.value = "";
        }
    </script>
</body>
</html>
"""

@app.get("/")
async def get():
    return HTMLResponse(html)

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            data = await websocket.receive_text()
            await websocket.send_text(f"Message received: {data}")
    except WebSocketDisconnect:
        print("Client disconnected")

Handling Different Message Types

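WebSockets carry text and binary frames. JSON is sent as text and parsed by the application. The endpoint below routes JSON messages by their type field:

import json
from typing import Any, Dict

from fastapi import WebSocket, WebSocketDisconnect

# Placeholder handlers: replace with your application logic
async def handle_chat_message(websocket: WebSocket, data: Dict[str, Any]):
    await websocket.send_text(json.dumps({"type": "chat", "message": data.get("message")}))

async def handle_notification(websocket: WebSocket, data: Dict[str, Any]):
    await websocket.send_text(json.dumps({"type": "notification_ack"}))

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    try:
        while True:
            # Receive text message
            message = await websocket.receive_text()

            # Parse JSON, replying with an error instead of crashing on bad input
            try:
                data = json.loads(message)
            except json.JSONDecodeError:
                await websocket.send_text(json.dumps({
                    "type": "error",
                    "message": "Invalid JSON format"
                }))
                continue

            if not isinstance(data, dict):
                await websocket.send_text(json.dumps({
                    "type": "error",
                    "message": "Expected a JSON object"
                }))
                continue

            message_type = data.get("type")
            if message_type == "ping":
                await websocket.send_text(json.dumps({"type": "pong"}))
            elif message_type == "chat":
                await handle_chat_message(websocket, data)
            elif message_type == "notification":
                await handle_notification(websocket, data)
            else:
                await websocket.send_text(json.dumps({
                    "type": "error",
                    "message": "Unknown message type"
                }))
    except WebSocketDisconnect:
        print("Client disconnected")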
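The if/elif chain grows with every new message type. One alternative is a registry that maps the type field to a handler, which also makes the routing testable without a server. This sketch works on the raw ASGI event returned by websocket.receive(), so it can also acknowledge binary frames. The "on" decorator, the handler names, and the reply shapes are hypothetical:

```python
import asyncio
import json
from typing import Any, Awaitable, Callable, Dict

Handler = Callable[[Dict[str, Any]], Awaitable[Dict[str, Any]]]
HANDLERS: Dict[str, Handler] = {}


def on(message_type: str) -> Callable[[Handler], Handler]:
    """Register a coroutine as the handler for one JSON message type."""
    def register(func: Handler) -> Handler:
        HANDLERS[message_type] = func
        return func
    return register


@on("ping")
async def handle_ping(data: Dict[str, Any]) -> Dict[str, Any]:
    return {"type": "pong"}


@on("chat")
async def handle_chat(data: Dict[str, Any]) -> Dict[str, Any]:
    return {"type": "chat", "message": data.get("message", "")}


async def dispatch(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Route one 'websocket.receive' event (text or bytes) and build the reply."""
    if raw.get("bytes") is not None:
        return {"type": "binary_ack", "size": len(raw["bytes"])}
    try:
        data = json.loads(raw.get("text") or "")
    except json.JSONDecodeError:
        return {"type": "error", "message": "Invalid JSON format"}
    if not isinstance(data, dict):
        return {"type": "error", "message": "Expected a JSON object"}
    message_type = data.get("type")
    # Only string types can be looked up (a list would be unhashable)
    handler = HANDLERS.get(message_type) if isinstance(message_type, str) else None
    if handler is None:
        return {"type": "error", "message": "Unknown message type"}
    return await handler(data)
```

Inside the endpoint, the loop would call raw = await websocket.receive(), stop when raw["type"] == "websocket.disconnect" (receive() returns that event rather than raising WebSocketDisconnect), and reply with await websocket.send_json(await dispatch(raw)).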

Connection Management

Connection Manager Class

For applications with multiple clients, you need a connection manager to handle multiple WebSocket connections:

from typing import Dict, List, Optional

from fastapi import WebSocket, WebSocketDisconnect

class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []
        self.user_connections: Dict[str, WebSocket] = {}

    async def connect(self, websocket: WebSocket, user_id: Optional[str] = None):
        await websocket.accept()
        self.active_connections.append(websocket)
        if user_id:
            self.user_connections[user_id] = websocket

    def disconnect(self, websocket: WebSocket, user_id: Optional[str] = None):
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
        # Only unmap the user if this socket is still the one registered for them
        if user_id and self.user_connections.get(user_id) is websocket:
            del self.user_connections[user_id]

    async def send_personal_message(self, message: str, user_id: str):
        websocket = self.user_connections.get(user_id)
        if websocket is not None:
            await websocket.send_text(message)

    async def broadcast(self, message: str):
        disconnected = []
        # Iterate over a copy: other coroutines may connect or disconnect during the awaits
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except Exception:
                disconnected.append(connection)

        # Clean up connections whose send failed
        for connection in disconnected:
            if connection in self.active_connections:
                self.active_connections.remove(connection)
            for uid, ws in list(self.user_connections.items()):
                if ws is connection:
                    del self.user_connections[uid]

manager = ConnectionManager()

@app.websocket("/ws/{user_id}")
async def websocket_endpoint(websocket: WebSocket, user_id: str):
    await manager.connect(websocket, user_id)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(f"User {user_id}: {data}")
    except WebSocketDisconnect:
        manager.disconnect(websocket, user_id)
        await manager.broadcast(f"User {user_id} left the chat")
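broadcast() awaits each send in turn, so one slow client delays everyone behind it. Below is a sketch of a concurrent variant using asyncio.gather with a per-send timeout. It accepts any object with an async send_text method (FastAPI's WebSocket qualifies). The function name and the 5-second default timeout are assumptions:

```python
import asyncio
from typing import Iterable, List, Protocol


class TextSender(Protocol):
    async def send_text(self, data: str) -> None: ...


async def broadcast_concurrently(connections: Iterable[TextSender], message: str,
                                 timeout: float = 5.0) -> List[TextSender]:
    """Send to all connections at once; return the ones that failed or timed out."""
    targets = list(connections)  # snapshot: the caller's list may change meanwhile
    results = await asyncio.gather(
        *(asyncio.wait_for(conn.send_text(message), timeout) for conn in targets),
        return_exceptions=True,  # collect failures instead of aborting the whole broadcast
    )
    return [conn for conn, result in zip(targets, results)
            if isinstance(result, BaseException)]
```

The returned sockets can then be passed to manager.disconnect(). The total time is bounded by the timeout rather than by the sum of all send times.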

Advanced Connection Management

For production applications, implement more sophisticated connection management:

import asyncio
import json
import time
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

from fastapi import WebSocket

@dataclass
class ConnectionInfo:
    websocket: WebSocket
    user_id: str
    connected_at: float
    last_ping: float
    metadata: Dict[str, Any]

class AdvancedConnectionManager:
    def __init__(self):
        self.connections: Dict[str, ConnectionInfo] = {}
        self.room_connections: Dict[str, List[str]] = {}
        self.heartbeat_interval = 30  # seconds
        # Keep references to heartbeat tasks so they are not garbage-collected mid-run
        self.heartbeat_tasks: Dict[str, asyncio.Task] = {}

    async def connect(self, websocket: WebSocket, user_id: str, metadata: Optional[Dict[str, Any]] = None):
        await websocket.accept()

        now = time.time()
        self.connections[user_id] = ConnectionInfo(
            websocket=websocket,
            user_id=user_id,
            connected_at=now,
            last_ping=now,
            metadata=metadata or {}
        )

        # Start one heartbeat task per user, replacing any left over from an earlier connection
        old_task = self.heartbeat_tasks.pop(user_id, None)
        if old_task is not None:
            old_task.cancel()
        self.heartbeat_tasks[user_id] = asyncio.create_task(self.heartbeat_task(user_id))

    async def disconnect(self, user_id: str):
        self.connections.pop(user_id, None)

        # Stop the heartbeat, unless we are currently running inside it
        task = self.heartbeat_tasks.pop(user_id, None)
        if task is not None and task is not asyncio.current_task():
            task.cancel()

        # Remove from all rooms
        for users in self.room_connections.values():
            if user_id in users:
                users.remove(user_id)

    async def join_room(self, user_id: str, room_id: str):
        if room_id not in self.room_connections:
            self.room_connections[room_id] = []

        if user_id not in self.room_connections[room_id]:
            self.room_connections[room_id].append(user_id)

    async def leave_room(self, user_id: str, room_id: str):
        if room_id in self.room_connections:
            if user_id in self.room_connections[room_id]:
                self.room_connections[room_id].remove(user_id)

    async def send_to_room(self, room_id: str, message: str):
        # Iterate over a copy: a failed send removes the user from the room list
        for user_id in list(self.room_connections.get(room_id, [])):
            await self.send_personal_message(message, user_id)

    async def send_personal_message(self, message: str, user_id: str):
        if user_id in self.connections:
            try:
                await self.connections[user_id].websocket.send_text(message)
            except Exception:
                await self.disconnect(user_id)

    async def heartbeat_task(self, user_id: str):
        while user_id in self.connections:
            try:
                await self.connections[user_id].websocket.send_text(
                    json.dumps({"type": "ping"})
                )
            except Exception:
                await self.disconnect(user_id)
                break
            await asyncio.sleep(self.heartbeat_interval)
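The heartbeat above only notices a dead client when send_text raises. That can take a long time, because the operating system buffers writes to a peer that vanished without closing its TCP connection. A common complement is to have the client answer each ping and to drop connections whose last answer is too old, which is what the unused last_ping field could track. The class name and the caller-chosen timeout below are assumptions. Uvicorn also offers protocol-level pings through its --ws-ping-interval and --ws-ping-timeout options.

```python
import time
from typing import Dict, List, Optional


class LivenessTracker:
    """Record when each client last answered a ping and report the silent ones."""

    def __init__(self, timeout: float) -> None:
        self.timeout = timeout
        self.last_seen: Dict[str, float] = {}

    def mark_alive(self, user_id: str, now: Optional[float] = None) -> None:
        # Call on connect and whenever a {"type": "pong"} message arrives
        self.last_seen[user_id] = time.monotonic() if now is None else now

    def forget(self, user_id: str) -> None:
        self.last_seen.pop(user_id, None)

    def stale(self, now: Optional[float] = None) -> List[str]:
        """User IDs that have been silent for longer than the timeout."""
        current = time.monotonic() if now is None else now
        return [uid for uid, seen in self.last_seen.items() if current - seen > self.timeout]
```

A periodic task can call stale() once per heartbeat interval, close those sockets, and then forget() them.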

WebSocket Authentication

Token-Based Authentication

The browser WebSocket API cannot attach custom headers such as Authorization to the handshake. A common approach is therefore to pass a JWT in the query string and verify it before accepting the connection:

from typing import Any, Dict, Optional

from fastapi import HTTPException, WebSocket, WebSocketDisconnect, status
from jose import JWTError, jwt

SECRET_KEY = "your-secret-key"  # load from an environment variable or secret store in production
ALGORITHM = "HS256"

def verify_token(token: str) -> Dict[str, Any]:
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token"
        )

# Method 1: Token in query parameter (ws://host/ws?token=...)
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: Optional[str] = None):
    if not token:
        # Closing before accept() rejects the handshake
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    try:
        user_id = verify_token(token).get("sub")
    except HTTPException:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    if not user_id:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    await manager.connect(websocket, user_id)
    try:
        while True:
            data = await websocket.receive_text()
            await manager.broadcast(f"User {user_id}: {data}")
    except WebSocketDisconnect:
        manager.disconnect(websocket, user_id)

Advanced Authentication Patterns

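Query strings tend to end up in proxy and server access logs, so a safer alternative is the first-message pattern: accept the connection, then require the client to send its token as the first message within a short timeout: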

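import asyncio
import json
from typing import Optional

from fastapi import WebSocket, WebSocketDisconnect, status
from starlette.websockets import WebSocketState

async def authenticate_websocket(websocket: WebSocket) -> Optional[str]:
    """Authenticate an accepted WebSocket via its first message; return the user ID or None"""
    try:
        # Wait for the authentication message, but not indefinitely
        auth_message = await asyncio.wait_for(
            websocket.receive_text(),
            timeout=10.0
        )

        auth_data = json.loads(auth_message)

        if not isinstance(auth_data, dict) or auth_data.get("type") != "auth":
            await websocket.send_text(json.dumps({
                "type": "error",
                "message": "Authentication required"
            }))
            return None

        payload = verify_token(auth_data.get("token") or "")
        user_id = payload.get("sub")
        if not user_id:
            raise ValueError("token has no 'sub' claim")

        await websocket.send_text(json.dumps({
            "type": "auth_success",
            "user_id": user_id
        }))

        return user_id

    except asyncio.TimeoutError:
        await websocket.send_text(json.dumps({
            "type": "error",
            "message": "Authentication timeout"
        }))
        return None
    except WebSocketDisconnect:
        # The client left during authentication; there is no one to reply to
        return None
    except Exception:
        # Invalid JSON, invalid or expired token, missing claims, ...
        await websocket.send_text(json.dumps({
            "type": "error",
            "message": "Authentication failed"
        }))
        return None

# Placeholder: replace with your application logic
async def handle_chat_message(user_id: str, message: dict):
    await manager.broadcast(f"User {user_id}: {message.get('message')}")

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()

    user_id = await authenticate_websocket(websocket)
    if not user_id:
        if websocket.client_state != WebSocketState.DISCONNECTED:
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    # The socket is already accepted, so register it directly;
    # manager.connect() would call accept() a second time and raise an error
    manager.active_connections.append(websocket)
    manager.user_connections[user_id] = websocket

    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                message = None
            if not isinstance(message, dict):
                await websocket.send_text(json.dumps({
                    "type": "error",
                    "message": "Invalid message"
                }))
                continue

            if message.get("type") == "chat":
                await handle_chat_message(user_id, message)
            elif message.get("type") == "heartbeat":
                await websocket.send_text(json.dumps({"type": "heartbeat_ack"}))

    except WebSocketDisconnect:
        manager.disconnect(websocket, user_id)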
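Because authenticate_websocket talks to the socket directly, its timeout and rejection paths are awkward to unit-test. One option is to pass in the receive function and the token check as parameters. The sketch below keeps the same rules: the first message must be {"type": "auth", "token": ...}, and the default timeout is 10 seconds. Its name, its signature, and the "verify returns None for a bad token" convention are hypothetical:

```python
import asyncio
import json
from typing import Any, Awaitable, Callable, Dict, Optional


async def first_message_auth(
    receive_text: Callable[[], Awaitable[str]],
    verify: Callable[[str], Optional[Dict[str, Any]]],
    timeout: float = 10.0,
) -> Optional[str]:
    """Return the authenticated user ID, or None if the first message does not authenticate."""
    try:
        raw = await asyncio.wait_for(receive_text(), timeout=timeout)
    except asyncio.TimeoutError:
        return None
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None
    if not isinstance(data, dict) or data.get("type") != "auth":
        return None
    token = data.get("token")
    if not isinstance(token, str):
        return None
    payload = verify(token)  # expected to return None for an invalid token
    if not payload:
        return None
    user_id = payload.get("sub")
    return user_id if isinstance(user_id, str) and user_id else None
```

In the endpoint this would be called as first_message_auth(websocket.receive_text, verify), where verify wraps verify_token and turns its HTTPException into None.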

Real-Time Features Implementation

Chat Application

Building a complete chat application with rooms and user management:

import json
from datetime import datetime
from typing import Dict, List, Optional

class ChatRoom:
    def __init__(self, room_id: str, name: str):
        self.room_id = room_id
        self.name = name
        self.users: List[str] = []
        self.messages: List[Dict] = []
        self.created_at = datetime.now()

    def add_user(self, user_id: str):
        if user_id not in self.users:
            self.users.append(user_id)

    def remove_user(self, user_id: str):
        if user_id in self.users:
            self.users.remove(user_id)

    def add_message(self, user_id: str, message: str):
        message_data = {
            "id": len(self.messages) + 1,
            "user_id": user_id,
            "message": message,
            "timestamp": datetime.now().isoformat()
        }
        self.messages.append(message_data)
        return message_data

class ChatManager:
    def __init__(self):
        self.rooms: Dict[str, ChatRoom] = {}
        self.connection_manager = ConnectionManager()

    def create_room(self, room_id: str, name: str):
        self.rooms[room_id] = ChatRoom(room_id, name)
        return self.rooms[room_id]

    def get_room(self, room_id: str) -> Optional[ChatRoom]:
        return self.rooms.get(room_id)

    async def join_room(self, user_id: str, room_id: str):
        room = self.get_room(room_id)
        if room:
            room.add_user(user_id)

            # Notify everyone in the room, including the user who just joined
            await self.broadcast_to_room(room_id, {
                "type": "user_joined",
                "user_id": user_id,
                "room_id": room_id
            })

            return room
        return None

    async def leave_room(self, user_id: str, room_id: str):
        room = self.get_room(room_id)
        if room:
            room.remove_user(user_id)

            # Notify other users
            await self.broadcast_to_room(room_id, {
                "type": "user_left",
                "user_id": user_id,
                "room_id": room_id
            })

    async def send_message(self, user_id: str, room_id: str, message: str):
        room = self.get_room(room_id)
        if room and user_id in room.users:
            message_data = room.add_message(user_id, message)

            await self.broadcast_to_room(room_id, {
                "type": "message",
                "data": message_data
            })

            return message_data
        return None

    async def broadcast_to_room(self, room_id: str, message: Dict):
        room = self.get_room(room_id)
        if room:
            for user_id in room.users:
                await self.connection_manager.send_personal_message(
                    json.dumps(message), user_id
                )

chat_manager = ChatManager()

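@app.websocket("/ws/chat/{room_id}")
async def chat_websocket(websocket: WebSocket, room_id: str):
    await websocket.accept()

    user_id = await authenticate_websocket(websocket)
    if not user_id:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    # Nothing else creates rooms, so this sketch creates them on demand;
    # remove this if rooms are provisioned elsewhere via create_room()
    if chat_manager.get_room(room_id) is None:
        chat_manager.create_room(room_id, name=room_id)

    # Register the already-accepted socket directly:
    # connection_manager.connect() would call accept() a second time and fail
    connections = chat_manager.connection_manager
    connections.active_connections.append(websocket)
    connections.user_connections[user_id] = websocket

    # Join room
    await chat_manager.join_room(user_id, room_id)

    try:
        while True:
            data = await websocket.receive_text()
            try:
                message = json.loads(data)
            except json.JSONDecodeError:
                continue  # ignore malformed frames in this example

            if isinstance(message, dict) and message.get("type") == "chat_message":
                await chat_manager.send_message(
                    user_id,
                    room_id,
                    str(message.get("message", ""))
                )

    except WebSocketDisconnect:
        await chat_manager.leave_room(user_id, room_id)
        connections.disconnect(websocket, user_id)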
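ChatRoom keeps every message in a list and derives IDs from its length, so memory grows without bound and IDs would repeat if old messages were ever trimmed. Below is a sketch of a bounded history with a monotonically increasing counter. It also lets a reconnecting client fetch only what it missed. The class name and the 500-message default are assumptions:

```python
import itertools
from collections import deque
from datetime import datetime, timezone
from typing import Any, Deque, Dict, List


class BoundedHistory:
    """Keep the most recent messages of a room, with stable, increasing IDs."""

    def __init__(self, max_messages: int = 500) -> None:
        self._messages: Deque[Dict[str, Any]] = deque(maxlen=max_messages)
        self._ids = itertools.count(1)

    def add(self, user_id: str, message: str) -> Dict[str, Any]:
        entry = {
            "id": next(self._ids),
            "user_id": user_id,
            "message": message,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
        self._messages.append(entry)  # the oldest entry is dropped once maxlen is reached
        return entry

    def since(self, last_seen_id: int) -> List[Dict[str, Any]]:
        """Messages newer than last_seen_id, e.g. to replay after a reconnect."""
        return [m for m in self._messages if m["id"] > last_seen_id]
```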

Live Notifications System

Implementing a real-time notification system:

import json
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional

class NotificationType(Enum):
    INFO = "info"
    WARNING = "warning"
    ERROR = "error"
    SUCCESS = "success"

class NotificationManager:
    def __init__(self):
        self.connection_manager = ConnectionManager()
        self.user_preferences: Dict[str, Dict] = {}

    async def send_notification(
        self, 
        user_id: str, 
        title: str, 
        message: str,
        notification_type: NotificationType = NotificationType.INFO,
        metadata: Dict = None
    ):
        notification = {
            "type": "notification",
            "data": {
                "title": title,
                "message": message,
                "notification_type": notification_type.value,
                "timestamp": datetime.now().isoformat(),
                "metadata": metadata or {}
            }
        }

        await self.connection_manager.send_personal_message(
            json.dumps(notification), user_id
        )

    async def broadcast_notification(
        self, 
        title: str, 
        message: str,
        notification_type: NotificationType = NotificationType.INFO,
        exclude_users: List[str] = None
    ):
        notification = {
            "type": "broadcast_notification",
            "data": {
                "title": title,
                "message": message,
                "notification_type": notification_type.value,
                "timestamp": datetime.now().isoformat()
            }
        }

        message_text = json.dumps(notification)

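        # Copy the keys: the dict can change while we await each send
        for user_id in list(self.connection_manager.user_connections):
            if exclude_users and user_id in exclude_users:
                continue

            try:
                await self.connection_manager.send_personal_message(
                    message_text, user_id
                )
            except Exception:
                # One dead socket should not stop delivery to everyone else
                continue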

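notification_manager = NotificationManager()

# Clients must register with notification_manager's own ConnectionManager;
# otherwise send_personal_message() has no socket to deliver to.
# In production, derive user_id from an authenticated token rather than the path.
@app.websocket("/ws/notifications/{user_id}")
async def notifications_websocket(websocket: WebSocket, user_id: str):
    await notification_manager.connection_manager.connect(websocket, user_id)
    try:
        while True:
            # Keep the connection open; incoming messages are ignored here
            await websocket.receive_text()
    except WebSocketDisconnect:
        notification_manager.connection_manager.disconnect(websocket, user_id)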

# API endpoint to send notifications
@app.post("/api/notifications/send")
async def send_notification(
    user_id: str,
    title: str,
    message: str,
    notification_type: NotificationType = NotificationType.INFO
):
    await notification_manager.send_notification(
        user_id, title, message, notification_type
    )
    return {"status": "sent"}

@app.post("/api/notifications/broadcast")
async def broadcast_notification(
    title: str,
    message: str,
    notification_type: NotificationType = NotificationType.INFO
):
    await notification_manager.broadcast_notification(
        title, message, notification_type
    )
    return {"status": "broadcasted"}
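send_notification silently drops a notification when the user has no open connection. If users should see what they missed, one option is to buffer undelivered notifications per user and flush them on the next connect. This sketch is in-memory only, so buffers are lost on restart and are not shared between workers. It accepts any object with an async send_text. Its names and the per-user cap are assumptions:

```python
from collections import defaultdict, deque
from typing import Deque, Dict, Protocol


class TextSender(Protocol):
    async def send_text(self, data: str) -> None: ...


class BufferedNotifier:
    def __init__(self, max_pending: int = 100) -> None:
        self.connections: Dict[str, TextSender] = {}
        # Per-user buffer; the oldest notification is dropped beyond max_pending
        self.pending: Dict[str, Deque[str]] = defaultdict(lambda: deque(maxlen=max_pending))

    async def connect(self, user_id: str, websocket: TextSender) -> None:
        # Flush what was queued while offline, oldest first. Notifications arriving
        # during the flush land in the same deque, so ordering is preserved.
        while self.pending.get(user_id):
            await websocket.send_text(self.pending[user_id].popleft())
        self.pending.pop(user_id, None)
        # Switch to live delivery only once the backlog is empty (no await in between)
        self.connections[user_id] = websocket

    def disconnect(self, user_id: str) -> None:
        self.connections.pop(user_id, None)

    async def notify(self, user_id: str, message: str) -> bool:
        """Deliver now if possible, otherwise buffer. Returns True if delivered."""
        websocket = self.connections.get(user_id)
        if websocket is not None:
            try:
                await websocket.send_text(message)
                return True
            except Exception:
                self.disconnect(user_id)  # dead socket: fall through to buffering
        self.pending[user_id].append(message)
        return False
```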

Performance Optimization

Connection Limits and Message Queue Workers

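import asyncio
import time
from typing import Dict, List

from fastapi import WebSocket, status

# ConnectionInfo is the dataclass from "Advanced Connection Management"
class OptimizedConnectionManager:
    def __init__(self, max_connections: int = 1000, num_workers: int = 10):
        self.max_connections = max_connections
        self.num_workers = num_workers
        self.connections: Dict[str, ConnectionInfo] = {}
        self.message_queue: asyncio.Queue = asyncio.Queue()
        self.worker_tasks: List[asyncio.Task] = []

    async def start(self):
        """Start the message workers from a running event loop (e.g. at app startup).
        Creating tasks in __init__ fails when the manager is built at import time."""
        for _ in range(self.num_workers):
            self.worker_tasks.append(asyncio.create_task(self.message_worker()))

    async def stop(self):
        for task in self.worker_tasks:
            task.cancel()
        await asyncio.gather(*self.worker_tasks, return_exceptions=True)
        self.worker_tasks.clear()

    async def connect(self, websocket: WebSocket, user_id: str) -> bool:
        if len(self.connections) >= self.max_connections:
            # Refuse the handshake when this instance is at capacity
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
            return False

        # Reserve the slot before the first await so concurrent handshakes
        # cannot push the count past max_connections
        now = time.time()
        self.connections[user_id] = ConnectionInfo(
            websocket=websocket,
            user_id=user_id,
            connected_at=now,
            last_ping=now,
            metadata={}
        )
        try:
            await websocket.accept()
        except Exception:
            self.connections.pop(user_id, None)
            raise
        return True

    async def disconnect(self, user_id: str):
        self.connections.pop(user_id, None)

    async def message_worker(self):
        # Note: with several workers, two messages for the same user can be
        # sent concurrently and arrive out of order
        while True:
            message_data = await self.message_queue.get()  # cancellation ends the worker here
            try:
                await self.process_message(message_data)
            except Exception as e:
                print(f"Message worker error: {e}")
            finally:
                self.message_queue.task_done()

    async def queue_message(self, user_id: str, message: str):
        await self.message_queue.put({
            "user_id": user_id,
            "message": message,
            "timestamp": time.time()
        })

    async def process_message(self, message_data: Dict):
        user_id = message_data["user_id"]
        connection = self.connections.get(user_id)
        if connection is None:
            return
        try:
            await connection.websocket.send_text(message_data["message"])
        except Exception:
            await self.disconnect(user_id)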
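With one shared queue and several workers, two messages for the same user can be picked up by different workers and arrive out of order, and a single slow client can tie up workers. A common alternative is one bounded outbox and one writer task per connection, dropping the oldest message when a client falls behind. Below is a sketch under those assumptions. The class name, the queue size, and the drop-oldest policy are illustrative choices:

```python
import asyncio
from typing import Optional, Protocol


class TextSender(Protocol):
    async def send_text(self, data: str) -> None: ...


class Outbox:
    """Per-connection send queue: keeps messages in order and bounds memory for slow clients."""

    def __init__(self, websocket: TextSender, max_queued: int = 100) -> None:
        self.websocket = websocket
        self.queue: asyncio.Queue = asyncio.Queue(maxsize=max_queued)
        self.dropped = 0
        self.failed = False
        self._writer: Optional[asyncio.Task] = None

    def start(self) -> None:
        """Spawn the writer task; must be called from a running event loop."""
        self._writer = asyncio.create_task(self._run())

    def put(self, message: str) -> None:
        """Enqueue without blocking the producer; drop the oldest message when full."""
        if self.queue.full():
            self.queue.get_nowait()
            self.queue.task_done()  # keep join() accounting consistent
            self.dropped += 1
        self.queue.put_nowait(message)

    async def _run(self) -> None:
        # The only task writing to this socket, so messages leave in queue order
        while True:
            message = await self.queue.get()
            try:
                if not self.failed:
                    await self.websocket.send_text(message)
            except Exception:
                self.failed = True  # the owner should disconnect this client
            finally:
                self.queue.task_done()

    async def drain(self) -> None:
        """Wait until every queued message has been sent (or skipped after a failure)."""
        await self.queue.join()

    async def close(self) -> None:
        if self._writer is not None:
            self._writer.cancel()
            try:
                await self._writer
            except asyncio.CancelledError:
                pass
            self._writer = None
```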

Message Rate Limiting

import time
from collections import defaultdict
from typing import Dict, List

class RateLimiter:
    def __init__(self, max_messages: int = 10, window_seconds: int = 60):
        self.max_messages = max_messages
        self.window_seconds = window_seconds
        self.user_messages: Dict[str, List[float]] = defaultdict(list)

    def is_rate_limited(self, user_id: str) -> bool:
        now = time.time()
        user_messages = self.user_messages[user_id]

        # Remove old messages outside the window
        self.user_messages[user_id] = [
            msg_time for msg_time in user_messages
            if now - msg_time < self.window_seconds
        ]

        # Check if user has exceeded rate limit
        if len(self.user_messages[user_id]) >= self.max_messages:
            return True

        # Add current message
        self.user_messages[user_id].append(now)
        return False

rate_limiter = RateLimiter()

@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()

    user_id = await authenticate_websocket(websocket)
    if not user_id:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    try:
        while True:
            data = await websocket.receive_text()

            if rate_limiter.is_rate_limited(user_id):
                await websocket.send_text(json.dumps({
                    "type": "error",
                    "message": "Rate limit exceeded"
                }))
                continue

            # Process message (handle_message is a placeholder for your application logic)
            await handle_message(user_id, data)

    except WebSocketDisconnect:
        pass
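The sliding-window log stores one timestamp per allowed message in the window. A token bucket enforces a similar average rate with two numbers per user and also tolerates short bursts. In the sketch below, the defaults (10 tokens, refilled at 10 per 60 seconds) mirror RateLimiter's defaults, and the injectable clock exists only to make it testable. Class and parameter names are assumptions:

```python
import time
from typing import Callable, Dict, Tuple


class TokenBucketLimiter:
    def __init__(self, capacity: float = 10, refill_per_second: float = 10 / 60,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.capacity = capacity
        self.refill_per_second = refill_per_second
        self.clock = clock
        # user_id -> (tokens left, time of last update)
        self.buckets: Dict[str, Tuple[float, float]] = {}

    def is_rate_limited(self, user_id: str) -> bool:
        now = self.clock()
        tokens, last = self.buckets.get(user_id, (self.capacity, now))
        # Refill in proportion to elapsed time, never above capacity
        tokens = min(self.capacity, tokens + (now - last) * self.refill_per_second)
        if tokens < 1:
            self.buckets[user_id] = (tokens, now)
            return True
        self.buckets[user_id] = (tokens - 1, now)  # spend one token on this message
        return False
```

It is a drop-in replacement for RateLimiter in the endpoint above, since is_rate_limited keeps the same signature.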

Production Deployment

Docker Configuration

FROM python:3.11-slim

WORKDIR /app

COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 8000

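# Each worker process keeps its own in-memory connection manager, so a broadcast
# only reaches clients connected to the same worker. Share state across workers
# (e.g. Redis pub/sub) before running more than one.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]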

Nginx Configuration

upstream websocket_backend {
    server app1:8000;
    server app2:8000;
    server app3:8000;
}

server {
    listen 80;
    server_name your-domain.com;

    location /ws {
        proxy_pass http://websocket_backend;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_set_header X-Forwarded-Proto $scheme;
        proxy_read_timeout 86400;
    }
}

Kubernetes Deployment

apiVersion: apps/v1
kind: Deployment
metadata:
  name: websocket-app
spec:
  replicas: 3
  selector:
    matchLabels:
      app: websocket-app
  template:
    metadata:
      labels:
        app: websocket-app
    spec:
      containers:
      - name: app
        image: your-registry/websocket-app:latest
        ports:
        - containerPort: 8000
        env:
        - name: REDIS_URL
          value: "redis://redis-service:6379"
        resources:
          requests:
            memory: "256Mi"
            cpu: "100m"
          limits:
            memory: "512Mi"
            cpu: "500m"
---
apiVersion: v1
kind: Service
metadata:
  name: websocket-service
spec:
  selector:
    app: websocket-app
  ports:
  - port: 8000
    targetPort: 8000
  type: ClusterIP
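The deployment passes REDIS_URL to every replica because each replica only holds its own sockets. The usual pattern is to publish every broadcast to a shared channel and have each replica subscribe and deliver to its local connections. The sketch below shows that shape with an in-memory stand-in for the broker, so it runs without Redis. With redis-py's asyncio client, publish() and a pubsub() subscription would take the place of InMemoryBroker. All class names here are hypothetical:

```python
import asyncio
from typing import Dict, List, Protocol


class TextSender(Protocol):
    async def send_text(self, data: str) -> None: ...


class InMemoryBroker:
    """Stand-in for Redis pub/sub: every subscriber queue receives every published message."""

    def __init__(self) -> None:
        self.subscribers: List[asyncio.Queue] = []

    def subscribe(self) -> asyncio.Queue:
        queue: asyncio.Queue = asyncio.Queue()
        self.subscribers.append(queue)
        return queue

    async def publish(self, message: str) -> None:
        for queue in self.subscribers:
            queue.put_nowait(message)


class Replica:
    """One app instance: its local sockets plus a listener for cluster-wide broadcasts."""

    def __init__(self, broker: InMemoryBroker) -> None:
        self.broker = broker
        self.local: Dict[str, TextSender] = {}
        self.inbox = broker.subscribe()

    async def broadcast(self, message: str) -> None:
        # Do not send locally here; this replica's own listener delivers it
        await self.broker.publish(message)

    async def listen(self) -> None:
        while True:
            message = await self.inbox.get()
            for websocket in list(self.local.values()):
                try:
                    await websocket.send_text(message)
                except Exception:
                    pass  # a real manager would disconnect this socket
```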

Testing WebSocket Endpoints

Unit Testing

import pytest
from fastapi.testclient import TestClient
from fastapi.websockets import WebSocketDisconnect
from unittest.mock import AsyncMock

from main import app, ConnectionManager  # adjust to where your app is defined

@pytest.fixture
def client():
    return TestClient(app)

def test_websocket_connection(client):
    # Assumes /ws is the echo endpoint
    with client.websocket_connect("/ws") as websocket:
        websocket.send_text("Hello")
        data = websocket.receive_text()
        assert "Hello" in data

def test_websocket_authentication(client):
    # Assumes /ws is the query-token endpoint: without ?token= the handshake is rejected
    with pytest.raises(WebSocketDisconnect):
        with client.websocket_connect("/ws"):
            pass

@pytest.mark.asyncio  # requires the pytest-asyncio plugin
async def test_connection_manager():
    manager = ConnectionManager()
    mock_websocket = AsyncMock()

    await manager.connect(mock_websocket, "user1")
    assert "user1" in manager.user_connections

    await manager.send_personal_message("test", "user1")
    mock_websocket.send_text.assert_called_once_with("test")

Integration Testing

import asyncio
import json

import websockets

async def test_websocket_integration():
    # Targets the chat endpoint; "general" is an example room ID
    uri = "ws://localhost:8000/ws/chat/general"

    async with websockets.connect(uri) as websocket:
        # Send authentication; replace with a JWT signed with the server's SECRET_KEY
        await websocket.send(json.dumps({
            "type": "auth",
            "token": "valid-token"
        }))

        # Receive auth response
        auth_data = json.loads(await websocket.recv())
        assert auth_data["type"] == "auth_success"

        # Send chat message
        await websocket.send(json.dumps({
            "type": "chat_message",
            "message": "Hello World"
        }))

        # Skip other events (such as the user_joined broadcast) until our message arrives
        while True:
            message_data = json.loads(await asyncio.wait_for(websocket.recv(), timeout=5))
            if message_data["type"] == "message":
                break
        assert message_data["data"]["message"] == "Hello World"

if __name__ == "__main__":
    asyncio.run(test_websocket_integration())

Monitoring and Observability

Metrics Collection

import time
from typing import Dict

from fastapi import WebSocket
from prometheus_client import Counter, Gauge, Histogram

# Metrics
websocket_connections_total = Counter(
    'websocket_connections_total',
    'Total WebSocket connections',
    ['status']
)

websocket_messages_total = Counter(
    'websocket_messages_total',
    'Total WebSocket messages',
    ['type', 'status']
)

websocket_connection_duration = Histogram(
    'websocket_connection_duration_seconds',
    'WebSocket connection duration'
)

active_connections_gauge = Gauge(
    'websocket_active_connections',
    'Active WebSocket connections'
)

class MetricsConnectionManager(ConnectionManager):
    def __init__(self):
        super().__init__()
        self.connected_at: Dict[str, float] = {}  # user_id -> connect time

    async def connect(self, websocket: WebSocket, user_id: str):
        await super().connect(websocket, user_id)
        self.connected_at[user_id] = time.monotonic()
        websocket_connections_total.labels(status='connected').inc()
        active_connections_gauge.set(len(self.active_connections))

    def disconnect(self, websocket: WebSocket, user_id: str):
        super().disconnect(websocket, user_id)
        started = self.connected_at.pop(user_id, None)
        if started is not None:
            websocket_connection_duration.observe(time.monotonic() - started)
        websocket_connections_total.labels(status='disconnected').inc()
        active_connections_gauge.set(len(self.active_connections))

Health Check Endpoint

import time

from fastapi import Response
from prometheus_client import CONTENT_TYPE_LATEST, generate_latest

START_TIME = time.monotonic()  # recorded when this module is imported

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "active_connections": len(manager.active_connections),  # this process only
        "uptime": time.monotonic() - START_TIME,  # seconds
        "version": "1.0.0"
    }

@app.get("/metrics")
async def get_metrics():
    return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)

Best Practices Summary

Security

  • Always authenticate WebSocket connections
  • Use HTTPS/WSS in production
  • Implement rate limiting
  • Validate all incoming messages
  • Use secure token-based authentication

Performance

  • Cap concurrent connections per instance
  • Use message queues for high-traffic scenarios
  • Monitor connection limits
  • Implement heartbeat mechanisms
  • Keep payloads small; consider a binary format such as MessagePack when JSON becomes a bottleneck

Scalability

  • Use Redis for multi-instance deployments
  • Implement horizontal scaling with load balancers
  • Use sticky sessions when needed
  • Monitor resource usage
  • Implement graceful shutdowns
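For graceful shutdown, closing every socket with status 1001 ("going away" in RFC 6455) lets clients tell a planned restart apart from an abnormal drop and reconnect. Below is a sketch of a helper that closes connections concurrently. In FastAPI it would typically be awaited in the shutdown half of a lifespan handler. The helper's name and its 5-second timeout are assumptions:

```python
import asyncio
from typing import Iterable, Protocol

GOING_AWAY = 1001  # RFC 6455 close code: the endpoint is going away


class Closable(Protocol):
    async def close(self, code: int = 1000) -> None: ...


async def close_all(connections: Iterable[Closable], timeout: float = 5.0) -> int:
    """Close all sockets concurrently with code 1001; return how many closed cleanly."""
    targets = list(connections)
    results = await asyncio.gather(
        *(asyncio.wait_for(ws.close(code=GOING_AWAY), timeout) for ws in targets),
        return_exceptions=True,  # one already-closed socket must not abort the rest
    )
    return sum(1 for result in results if not isinstance(result, BaseException))
```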

Error Handling

  • Handle WebSocket disconnections gracefully
  • Implement retry mechanisms, such as client reconnection with exponential backoff
  • Log all errors appropriately
  • Use structured error responses
  • Implement circuit breakers for external services
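On the client side, retrying usually means reconnecting with exponential backoff plus random jitter, so that thousands of clients dropped by the same deploy do not reconnect in lockstep. Below is a sketch of the delay schedule using the "full jitter" variant. The base, the cap, and the function name are assumptions:

```python
import random
from typing import Optional


def reconnect_delay(attempt: int, base: float = 0.5, cap: float = 30.0,
                    rng: Optional[random.Random] = None) -> float:
    """Seconds to wait before reconnect attempt number `attempt` (0-based)."""
    # Clamp the exponent so huge attempt counts cannot overflow a float
    ceiling = min(cap, base * 2 ** min(attempt, 32))
    source = rng if rng is not None else random
    return source.uniform(0, ceiling)  # full jitter: anywhere between 0 and the ceiling
```

A client loop would sleep for reconnect_delay(attempt) after each failed connection attempt and reset attempt to 0 once a connection succeeds.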

Conclusion

FastAPI provides excellent WebSocket support for building real-time applications. By following the patterns and best practices outlined in this guide, you can create scalable, secure, and performant WebSocket applications.

Key takeaways:

  • Use proper connection management for multiple clients
  • Implement secure authentication patterns
  • Design for scalability with Redis and load balancing
  • Monitor and test your WebSocket endpoints
  • Follow security best practices for production deployment

WebSockets enable powerful real-time features that enhance user experience and engagement. With FastAPI's async capabilities and the patterns shown in this guide, you can build robust real-time applications that scale with your needs.


This guide covers FastAPI WebSocket implementation as of 2025. Always refer to the latest FastAPI documentation for the most current features and best practices.