import typing
from enum import Enum
from urllib.parse import unquote_plus

from starlette.datastructures import FormData, Headers, UploadFile

try:
    import multipart
    from multipart.multipart import parse_options_header
except ImportError:  # pragma: nocover
    parse_options_header = None
    multipart = None


class FormMessage(Enum):
    """Event kinds queued by the `FormParser` callbacks."""

    FIELD_START = 1
    FIELD_NAME = 2
    FIELD_DATA = 3
    FIELD_END = 4
    END = 5


class MultiPartMessage(Enum):
    """Event kinds queued by the `MultiPartParser` callbacks."""

    PART_BEGIN = 1
    PART_DATA = 2
    PART_END = 3
    HEADER_FIELD = 4
    HEADER_VALUE = 5
    HEADER_END = 6
    HEADERS_FINISHED = 7
    END = 8


def _user_safe_decode(src: bytes, codec: str) -> str:
    """
    Decode `src` with `codec`, falling back to latin-1.

    The fallback is used when `codec` is not a known codec name
    (`LookupError`) or when `src` is not valid in it (`UnicodeDecodeError`).
    """
    try:
        return src.decode(codec)
    except (UnicodeDecodeError, LookupError):
        return src.decode("latin-1")


class FormParser:
    """
    Parse a request body with `multipart.QuerystringParser` into `FormData`.

    The parser callbacks are synchronous, so they only queue
    `(FormMessage, bytes)` tuples on `self.messages`; `parse()` drains that
    queue after each chunk has been fed to the parser.
    """

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.messages = []  # type: typing.List[typing.Tuple[FormMessage, bytes]]

    def on_field_start(self) -> None:
        message = (FormMessage.FIELD_START, b"")
        self.messages.append(message)

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        message = (FormMessage.FIELD_NAME, data[start:end])
        self.messages.append(message)

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        message = (FormMessage.FIELD_DATA, data[start:end])
        self.messages.append(message)

    def on_field_end(self) -> None:
        message = (FormMessage.FIELD_END, b"")
        self.messages.append(message)

    def on_end(self) -> None:
        message = (FormMessage.END, b"")
        self.messages.append(message)

    def _consume_messages(
        self,
        field_name: bytes,
        field_value: bytes,
        items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
    ) -> typing.Tuple[bytes, bytes]:
        """
        Drain `self.messages`, appending each completed field to `items`.

        `field_name` and `field_value` are the partial buffers for the field
        currently being read. The updated buffers are returned so the caller
        can carry them over to the next chunk.
        """
        messages = list(self.messages)
        self.messages.clear()
        for message_type, message_bytes in messages:
            if message_type == FormMessage.FIELD_START:
                field_name = b""
                field_value = b""
            elif message_type == FormMessage.FIELD_NAME:
                field_name += message_bytes
            elif message_type == FormMessage.FIELD_DATA:
                field_value += message_bytes
            elif message_type == FormMessage.FIELD_END:
                name = unquote_plus(field_name.decode("latin-1"))
                value = unquote_plus(field_value.decode("latin-1"))
                items.append((name, value))
            elif message_type == FormMessage.END:
                pass
        return field_name, field_value

    async def parse(self) -> FormData:
        """
        Consume `self.stream` and return the decoded fields as `FormData`.

        An empty chunk finalizes the parser. If the stream ends without ever
        yielding an empty chunk, the parser is finalized once after the loop,
        so anything it was still buffering is flushed rather than dropped.
        """
        # Callbacks dictionary.
        callbacks = {
            "on_field_start": self.on_field_start,
            "on_field_name": self.on_field_name,
            "on_field_data": self.on_field_data,
            "on_field_end": self.on_field_end,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.QuerystringParser(callbacks)
        field_name = b""
        field_value = b""
        finalized = False

        items = (
            []
        )  # type: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]]

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            if chunk:
                parser.write(chunk)
            else:
                parser.finalize()
                finalized = True
            field_name, field_value = self._consume_messages(
                field_name, field_value, items
            )

        # The stream ended without the empty sentinel chunk: finalize here,
        # mirroring MultiPartParser, which always finalizes after its loop.
        if not finalized:
            parser.finalize()
            self._consume_messages(field_name, field_value, items)

        return FormData(items)


class MultiPartParser:
    """
    Parse a request body with `multipart.MultipartParser` into `FormData`.

    The parser callbacks are synchronous, so they only queue
    `(MultiPartMessage, bytes)` tuples on `self.messages`; `parse()` drains
    that queue after each chunk, which lets it `await` the `UploadFile`
    writes.
    """

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.messages = []  # type: typing.List[typing.Tuple[MultiPartMessage, bytes]]

    def on_part_begin(self) -> None:
        message = (MultiPartMessage.PART_BEGIN, b"")
        self.messages.append(message)

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        message = (MultiPartMessage.PART_DATA, data[start:end])
        self.messages.append(message)

    def on_part_end(self) -> None:
        message = (MultiPartMessage.PART_END, b"")
        self.messages.append(message)

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        message = (MultiPartMessage.HEADER_FIELD, data[start:end])
        self.messages.append(message)

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        message = (MultiPartMessage.HEADER_VALUE, data[start:end])
        self.messages.append(message)

    def on_header_end(self) -> None:
        message = (MultiPartMessage.HEADER_END, b"")
        self.messages.append(message)

    def on_headers_finished(self) -> None:
        message = (MultiPartMessage.HEADERS_FINISHED, b"")
        self.messages.append(message)

    def on_end(self) -> None:
        message = (MultiPartMessage.END, b"")
        self.messages.append(message)

    async def parse(self) -> FormData:
        """
        Consume `self.stream` and return the parts as `FormData`.

        Parts whose Content-Disposition carries a `filename` option become
        `UploadFile` values, rewound to offset 0. All other parts are decoded
        to `str` using the charset from the request's Content-Type header
        ("utf-8" when absent), with a latin-1 fallback.

        If parsing raises, every `UploadFile` created so far is closed before
        the exception is re-raised.
        """
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        boundary = params.get(b"boundary")

        # Callbacks dictionary.
        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)
        header_field = b""
        header_value = b""
        content_disposition = None
        content_type = b""
        field_name = ""
        data = b""
        file = None  # type: typing.Optional[UploadFile]
        # Every UploadFile created so far, so they can be closed on failure.
        opened_files = []  # type: typing.List[UploadFile]

        items = (
            []
        )  # type: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]]

        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                messages = list(self.messages)
                self.messages.clear()
                for message_type, message_bytes in messages:
                    if message_type == MultiPartMessage.PART_BEGIN:
                        content_disposition = None
                        content_type = b""
                        data = b""
                    elif message_type == MultiPartMessage.HEADER_FIELD:
                        # Header names/values may arrive split across callbacks.
                        header_field += message_bytes
                    elif message_type == MultiPartMessage.HEADER_VALUE:
                        header_value += message_bytes
                    elif message_type == MultiPartMessage.HEADER_END:
                        field = header_field.lower()
                        if field == b"content-disposition":
                            content_disposition = header_value
                        elif field == b"content-type":
                            content_type = header_value
                        header_field = b""
                        header_value = b""
                    elif message_type == MultiPartMessage.HEADERS_FINISHED:
                        _, options = parse_options_header(content_disposition)
                        field_name = _user_safe_decode(options[b"name"], charset)
                        if b"filename" in options:
                            filename = _user_safe_decode(
                                options[b"filename"], charset
                            )
                            file = UploadFile(
                                filename=filename,
                                content_type=content_type.decode("latin-1"),
                            )
                            opened_files.append(file)
                        else:
                            file = None
                    elif message_type == MultiPartMessage.PART_DATA:
                        if file is None:
                            data += message_bytes
                        else:
                            await file.write(message_bytes)
                    elif message_type == MultiPartMessage.PART_END:
                        if file is None:
                            items.append(
                                (field_name, _user_safe_decode(data, charset))
                            )
                        else:
                            await file.seek(0)
                            items.append((field_name, file))
                    elif message_type == MultiPartMessage.END:
                        pass

            parser.finalize()
        except Exception:
            # The caller never receives these files on failure, so release
            # them here instead of leaving them open.
            for opened in opened_files:
                await opened.close()
            raise

        return FormData(items)