import enum import json import typing from starlette.requests import HTTPConnection from starlette.types import Message, Receive, Scope, Send class WebSocketState(enum.Enum): CONNECTING = 0 CONNECTED = 1 DISCONNECTED = 2 class WebSocketDisconnect(Exception): def __init__(self, code: int = 1000) -> None: self.code = code class WebSocket(HTTPConnection): def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: super().__init__(scope) assert scope["type"] == "websocket" self._receive = receive self._send = send self.client_state = WebSocketState.CONNECTING self.application_state = WebSocketState.CONNECTING async def receive(self) -> Message: """ Receive ASGI websocket messages, ensuring valid state transitions. """ if self.client_state == WebSocketState.CONNECTING: message = await self._receive() message_type = message["type"] assert message_type == "websocket.connect" self.client_state = WebSocketState.CONNECTED return message elif self.client_state == WebSocketState.CONNECTED: message = await self._receive() message_type = message["type"] assert message_type in {"websocket.receive", "websocket.disconnect"} if message_type == "websocket.disconnect": self.client_state = WebSocketState.DISCONNECTED return message else: raise RuntimeError( 'Cannot call "receive" once a disconnect message has been received.' ) async def send(self, message: Message) -> None: """ Send ASGI websocket messages, ensuring valid state transitions. """ if self.application_state == WebSocketState.CONNECTING: message_type = message["type"] assert message_type in {"websocket.accept", "websocket.close"} if message_type == "websocket.close": self.application_state = WebSocketState.DISCONNECTED else: self.application_state = WebSocketState.CONNECTED await self._send(message) elif self.application_state == WebSocketState.CONNECTED: message_type = message["type"] assert message_type in {"websocket.send", "websocket.close"} if message_type == "websocket.close": self.application_state = WebSocketState.DISCONNECTED await self._send(message) else: raise RuntimeError('Cannot call "send" once a close message has been sent.') async def accept(self, subprotocol: str = None) -> None: if self.client_state == WebSocketState.CONNECTING: # If we haven't yet seen the 'connect' message, then wait for it first. await self.receive() await self.send({"type": "websocket.accept", "subprotocol": subprotocol}) def _raise_on_disconnect(self, message: Message) -> None: if message["type"] == "websocket.disconnect": raise WebSocketDisconnect(message["code"]) async def receive_text(self) -> str: assert self.application_state == WebSocketState.CONNECTED message = await self.receive() self._raise_on_disconnect(message) return message["text"] async def receive_bytes(self) -> bytes: assert self.application_state == WebSocketState.CONNECTED message = await self.receive() self._raise_on_disconnect(message) return message["bytes"] async def receive_json(self, mode: str = "text") -> typing.Any: assert mode in ["text", "binary"] assert self.application_state == WebSocketState.CONNECTED message = await self.receive() self._raise_on_disconnect(message) if mode == "text": text = message["text"] else: text = message["bytes"].decode("utf-8") return json.loads(text) async def iter_text(self) -> typing.AsyncIterator[str]: try: while True: yield await self.receive_text() except WebSocketDisconnect: pass async def iter_bytes(self) -> typing.AsyncIterator[bytes]: try: while True: yield await self.receive_bytes() except WebSocketDisconnect: pass async def iter_json(self) -> typing.AsyncIterator[typing.Any]: try: while True: yield await self.receive_json() except WebSocketDisconnect: pass async def send_text(self, data: str) -> None: await self.send({"type": "websocket.send", "text": data}) async def send_bytes(self, data: bytes) -> None: await self.send({"type": "websocket.send", "bytes": data}) async def send_json(self, data: typing.Any, mode: str = "text") -> None: assert mode in ["text", "binary"] text = json.dumps(data) if mode == "text": await self.send({"type": "websocket.send", "text": text}) else: await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")}) async def close(self, code: int = 1000) -> None: await self.send({"type": "websocket.close", "code": code}) class WebSocketClose: def __init__(self, code: int = 1000) -> None: self.code = code async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await send({"type": "websocket.close", "code": self.code})