Mabasej_Team/.venv/lib/python3.9/site-packages/starlette/formparsers.py

243 lines
8.6 KiB
Python
Raw Normal View History

2021-03-17 08:57:57 +01:00
import typing
from enum import Enum
from urllib.parse import unquote_plus
from starlette.datastructures import FormData, Headers, UploadFile
try:
    import multipart
    from multipart.multipart import parse_options_header
except ImportError:  # pragma: nocover
    # Optional dependency: bind both names to ``None`` so this module can
    # still be imported. The parser classes below assert that ``multipart``
    # is not ``None`` when they are instantiated.
    parse_options_header = None
    multipart = None
class FormMessage(Enum):
    """Event kinds queued by the ``FormParser`` callbacks.

    Each callback appends a ``(FormMessage, bytes)`` tuple to the parser's
    message list; ``FormParser.parse`` dispatches on the enum member.
    """

    FIELD_START = 1  # a new field begins; accumulated name/value are reset
    FIELD_NAME = 2  # a fragment of the field name
    FIELD_DATA = 3  # a fragment of the field value
    FIELD_END = 4  # the field is complete and can be stored
    END = 5  # the parser reported the end of input
class MultiPartMessage(Enum):
    """Event kinds queued by the ``MultiPartParser`` callbacks.

    Each callback appends a ``(MultiPartMessage, bytes)`` tuple to the
    parser's message list; ``MultiPartParser.parse`` dispatches on the enum
    member.
    """

    PART_BEGIN = 1  # a new part begins; per-part state is reset
    PART_DATA = 2  # a fragment of the part body
    PART_END = 3  # the part is complete and can be stored
    HEADER_FIELD = 4  # a fragment of a part header name
    HEADER_VALUE = 5  # a fragment of a part header value
    HEADER_END = 6  # one header line of the part is complete
    HEADERS_FINISHED = 7  # all headers of the part have been seen
    END = 8  # the parser reported the end of input
def _user_safe_decode(src: bytes, codec: str) -> str:
try:
return src.decode(codec)
except (UnicodeDecodeError, LookupError):
return src.decode("latin-1")
class FormParser:
    """Parse a URL-encoded form body from an async byte stream.

    Chunks are fed to ``multipart.QuerystringParser``. Its callbacks queue
    ``(FormMessage, bytes)`` tuples on ``self.messages``; ``parse`` drains
    that queue after every chunk and assembles the decoded fields.
    """

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        """Store the request headers and the body stream.

        Raises:
            AssertionError: if the optional ``python-multipart`` dependency
                could not be imported.
        """
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.messages = []  # type: typing.List[typing.Tuple[FormMessage, bytes]]

    # -- parser callbacks -------------------------------------------------
    # Each callback only records what happened; the data callbacks copy the
    # ``data[start:end]`` slice they are handed.

    def on_field_start(self) -> None:
        """Queue a FIELD_START event."""
        message = (FormMessage.FIELD_START, b"")
        self.messages.append(message)

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        """Queue ``data[start:end]`` as a field-name fragment."""
        message = (FormMessage.FIELD_NAME, data[start:end])
        self.messages.append(message)

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        """Queue ``data[start:end]`` as a field-value fragment."""
        message = (FormMessage.FIELD_DATA, data[start:end])
        self.messages.append(message)

    def on_field_end(self) -> None:
        """Queue a FIELD_END event."""
        message = (FormMessage.FIELD_END, b"")
        self.messages.append(message)

    def on_end(self) -> None:
        """Queue an END event."""
        message = (FormMessage.END, b"")
        self.messages.append(message)

    async def parse(self) -> FormData:
        """Consume ``self.stream`` and return the parsed fields.

        Field names and values are decoded as latin-1 and then passed
        through ``unquote_plus``. An empty chunk finalizes the parser; if
        the stream ends without ever yielding one, the parser is finalized
        once after the loop.

        Returns:
            A ``FormData`` built from the ``(name, value)`` pairs in the
            order their FIELD_END events were seen.
        """
        # Callbacks dictionary.
        callbacks = {
            "on_field_start": self.on_field_start,
            "on_field_name": self.on_field_name,
            "on_field_data": self.on_field_data,
            "on_field_end": self.on_field_end,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.QuerystringParser(callbacks)
        field_name = b""
        field_value = b""

        items = (
            []
        )  # type: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]]

        def drain() -> None:
            """Apply all queued messages to the in-progress field state."""
            nonlocal field_name, field_value
            # Snapshot and clear first so the queue is empty for the next
            # chunk.
            pending = list(self.messages)
            self.messages.clear()
            for message_type, message_bytes in pending:
                if message_type == FormMessage.FIELD_START:
                    field_name = b""
                    field_value = b""
                elif message_type == FormMessage.FIELD_NAME:
                    field_name += message_bytes
                elif message_type == FormMessage.FIELD_DATA:
                    field_value += message_bytes
                elif message_type == FormMessage.FIELD_END:
                    name = unquote_plus(field_name.decode("latin-1"))
                    value = unquote_plus(field_value.decode("latin-1"))
                    items.append((name, value))
                elif message_type == FormMessage.END:
                    pass

        finalized = False

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            if chunk:
                parser.write(chunk)
            else:
                parser.finalize()
                finalized = True
            drain()

        if not finalized:
            # The stream ended without an empty terminating chunk. Finalize
            # here so that a field still in progress is not dropped.
            # NOTE(review): relies on python-multipart's finalize() closing a
            # field that is still open - confirm against the library docs.
            parser.finalize()
            drain()

        return FormData(items)
class MultiPartParser:
    """Parse a multipart form body from an async byte stream.

    Chunks are fed to ``multipart.MultipartParser``. Its callbacks queue
    ``(MultiPartMessage, bytes)`` tuples on ``self.messages``; ``parse``
    drains that queue after every chunk and builds either plain string
    fields or ``UploadFile`` objects.
    """

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        """Store the request headers and the body stream.

        Raises:
            AssertionError: if the optional ``python-multipart`` dependency
                could not be imported.
        """
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.messages = []  # type: typing.List[typing.Tuple[MultiPartMessage, bytes]]

    # -- parser callbacks -------------------------------------------------
    # Each callback only records what happened; the data callbacks copy the
    # ``data[start:end]`` slice they are handed.

    def on_part_begin(self) -> None:
        """Queue a PART_BEGIN event."""
        message = (MultiPartMessage.PART_BEGIN, b"")
        self.messages.append(message)

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        """Queue ``data[start:end]`` as a fragment of the part body."""
        message = (MultiPartMessage.PART_DATA, data[start:end])
        self.messages.append(message)

    def on_part_end(self) -> None:
        """Queue a PART_END event."""
        message = (MultiPartMessage.PART_END, b"")
        self.messages.append(message)

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        """Queue ``data[start:end]`` as a fragment of a header name."""
        message = (MultiPartMessage.HEADER_FIELD, data[start:end])
        self.messages.append(message)

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        """Queue ``data[start:end]`` as a fragment of a header value."""
        message = (MultiPartMessage.HEADER_VALUE, data[start:end])
        self.messages.append(message)

    def on_header_end(self) -> None:
        """Queue a HEADER_END event."""
        message = (MultiPartMessage.HEADER_END, b"")
        self.messages.append(message)

    def on_headers_finished(self) -> None:
        """Queue a HEADERS_FINISHED event."""
        message = (MultiPartMessage.HEADERS_FINISHED, b"")
        self.messages.append(message)

    def on_end(self) -> None:
        """Queue an END event."""
        message = (MultiPartMessage.END, b"")
        self.messages.append(message)

    async def parse(self) -> FormData:
        """Consume ``self.stream`` and return the parsed parts.

        The boundary and optional charset (default ``"utf-8"``) are read
        from the request's ``Content-Type`` header. A part whose
        ``Content-Disposition`` carries a ``filename`` option becomes an
        ``UploadFile`` (rewound to offset 0 once the part ends); any other
        part is decoded to ``str`` with ``_user_safe_decode``.

        Returns:
            A ``FormData`` built from ``(name, value)`` pairs in the order
            their PART_END events were seen.

        Raises:
            KeyError: if the ``Content-Type`` header is missing, or a
                part's ``Content-Disposition`` has no ``name`` option.
        """
        # Parse the Content-Type header to get the multipart boundary.
        # Only the header parameters are needed; the media type is unused.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        boundary = params.get(b"boundary")

        # Callbacks dictionary.
        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }

        # Create the parser.
        parser = multipart.MultipartParser(boundary, callbacks)

        # State of the header line currently being assembled.
        header_field = b""
        header_value = b""

        # State of the part currently being assembled.
        content_disposition = None
        content_type = b""
        field_name = ""
        data = b""
        file = None  # type: typing.Optional[UploadFile]

        items = (
            []
        )  # type: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]]

        # Feed the parser with data from the request.
        async for chunk in self.stream:
            parser.write(chunk)
            # Snapshot and clear first so the queue is empty for the next
            # chunk.
            messages = list(self.messages)
            self.messages.clear()
            for message_type, message_bytes in messages:
                if message_type == MultiPartMessage.PART_BEGIN:
                    content_disposition = None
                    content_type = b""
                    data = b""
                elif message_type == MultiPartMessage.HEADER_FIELD:
                    header_field += message_bytes
                elif message_type == MultiPartMessage.HEADER_VALUE:
                    header_value += message_bytes
                elif message_type == MultiPartMessage.HEADER_END:
                    # Header names are matched case-insensitively; only the
                    # two headers below are kept.
                    field = header_field.lower()
                    if field == b"content-disposition":
                        content_disposition = header_value
                    elif field == b"content-type":
                        content_type = header_value
                    header_field = b""
                    header_value = b""
                elif message_type == MultiPartMessage.HEADERS_FINISHED:
                    disposition, options = parse_options_header(content_disposition)
                    field_name = _user_safe_decode(options[b"name"], charset)
                    if b"filename" in options:
                        filename = _user_safe_decode(options[b"filename"], charset)
                        file = UploadFile(
                            filename=filename,
                            content_type=content_type.decode("latin-1"),
                        )
                    else:
                        file = None
                elif message_type == MultiPartMessage.PART_DATA:
                    if file is None:
                        data += message_bytes
                    else:
                        await file.write(message_bytes)
                elif message_type == MultiPartMessage.PART_END:
                    if file is None:
                        items.append((field_name, _user_safe_decode(data, charset)))
                    else:
                        # Rewind so the consumer reads from the beginning.
                        await file.seek(0)
                        items.append((field_name, file))
                elif message_type == MultiPartMessage.END:
                    pass

        parser.finalize()
        return FormData(items)