import tempfile
import typing
from collections import namedtuple
from collections.abc import Sequence
from shlex import shlex
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit

from starlette.concurrency import run_in_threadpool
from starlette.types import Scope

# Simple (host, port) pair.
Address = namedtuple("Address", ["host", "port"])


class URL:
    """
    A URL string with lazily parsed components.

    An instance is built from exactly one of: a plain `url` string, a
    `scope` mapping (scheme / server / path / query string / headers are
    read from it), or keyword `**components` accepted by `replace()`.
    """

    def __init__(
        self, url: str = "", scope: Scope = None, **components: typing.Any
    ) -> None:
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope.get("root_path", "") + scope["path"]
            query_string = scope.get("query_string", b"")

            # The first "host" header, when present, wins over scope["server"].
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                # No host information at all: fall back to a path-only URL.
                url = path
            else:
                host, port = server
                default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The `urlsplit()` result for this URL, computed once and cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> typing.Union[None, str]:
        return self.components.username

    @property
    def password(self) -> typing.Union[None, str]:
        return self.components.password

    @property
    def hostname(self) -> typing.Union[None, str]:
        return self.components.hostname

    @property
    def port(self) -> typing.Optional[int]:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True when the scheme is "https" or "wss"."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: typing.Any) -> "URL":
        """
        Return a new URL with the given components replaced.

        Accepts the `SplitResult` field names (`scheme`, `netloc`, `path`,
        `query`, `fragment`) and additionally `username`, `password`,
        `hostname` and `port`. When any of the latter four is given, the
        netloc is rebuilt from them, using the current values for the ones
        not supplied.
        """
        if (
            "username" in kwargs
            or "password" in kwargs
            or "hostname" in kwargs
            or "port" in kwargs
        ):
            if "hostname" in kwargs:
                hostname = kwargs.pop("hostname")
            else:
                hostname = self.hostname
                # `SplitResult.hostname` drops the brackets around an IPv6
                # literal; restore them so the rebuilt netloc stays valid.
                if hostname is not None and ":" in hostname:
                    hostname = f"[{hostname}]"
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            # A URL without a host has `hostname is None`; treat it as empty
            # rather than failing on `None += str` or emitting "None".
            netloc = hostname if hostname is not None else ""
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: typing.Any) -> "URL":
        """Return a new URL with `kwargs` merged into the existing query."""
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: typing.Any) -> "URL":
        """Return a new URL whose query consists of `kwargs` only."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(
        self, keys: typing.Union[str, typing.Sequence[str]]
    ) -> "URL":
        """Return a new URL with the given query key(s) removed."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: typing.Any) -> bool:
        # Compared by string form, so a URL equals the matching plain `str`.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        url = str(self)
        if self.password:
            # Never show the password in the repr.
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"


class URLPath(str):
    """
    A URL path string that may also hold an associated protocol and/or host.
    Used by the routing to return `url_path_for` matches.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> "URLPath":
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)  # type: ignore

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: typing.Union[str, URL]) -> str:
        """
        Combine this path with `base_url` and return the absolute URL string.

        When `protocol` is set, the scheme is its secure or insecure variant
        depending on `base_url.is_secure`; otherwise the base scheme is used.
        `host`, when set, replaces the base netloc.
        """
        if isinstance(base_url, str):
            base_url = URL(base_url)
        if self.protocol:
            scheme = {
                "http": {True: "https", False: "http"},
                "websocket": {True: "wss", False: "ws"},
            }[self.protocol][base_url.is_secure]
        else:
            scheme = base_url.scheme

        if self.host:
            netloc = self.host
        else:
            netloc = base_url.netloc

        path = base_url.path.rstrip("/") + str(self)
        return str(URL(scheme=scheme, netloc=netloc, path=path))


class Secret:
    """
    Holds a string value that should not be revealed in tracebacks etc.
    You should cast the value to `str` at the point it is required.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"{class_name}('**********')"

    def __str__(self) -> str:
        return self._value


class CommaSeparatedStrings(Sequence):
    """
    A read-only sequence of strings.

    A `str` argument is split on commas using `shlex` in POSIX mode and each
    item is stripped; any other sequence is copied into a list as-is.
    """

    def __init__(self, value: typing.Union[str, typing.Sequence[str]]):
        if isinstance(value, str):
            splitter = shlex(value, posix=True)
            splitter.whitespace = ","
            splitter.whitespace_split = True
            self._items = [item.strip() for item in splitter]
        else:
            self._items = list(value)

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: typing.Union[int, slice]) -> typing.Any:
        return self._items[index]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = list(self)
        return f"{class_name}({items!r})"

    def __str__(self) -> str:
        return ", ".join(repr(item) for item in self)


class ImmutableMultiDict(typing.Mapping):
    """
    A read-only mapping that may hold several values per key.

    All (key, value) pairs are kept in `_list`; `_dict` maps each key to the
    last value given for it, and backs the regular mapping interface.
    """

    def __init__(
        self,
        *args: typing.Union[
            "ImmutableMultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
        ],
        **kwargs: typing.Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        if args:
            value = args[0]
        else:
            value = []
        if kwargs:
            # Keyword items are appended after the positional items.
            value = (
                ImmutableMultiDict(value).multi_items()
                + ImmutableMultiDict(kwargs).multi_items()
            )

        if not value:
            _items = []  # type: typing.List[typing.Tuple[typing.Any, typing.Any]]
        elif hasattr(value, "multi_items"):
            value = typing.cast(ImmutableMultiDict, value)
            _items = list(value.multi_items())
        elif hasattr(value, "items"):
            value = typing.cast(typing.Mapping, value)
            _items = list(value.items())
        else:
            value = typing.cast(
                typing.List[typing.Tuple[typing.Any, typing.Any]], value
            )
            _items = list(value)

        # Later pairs overwrite earlier ones, so `_dict` holds the last value.
        self._dict = {k: v for k, v in _items}
        self._list = _items

    def getlist(self, key: typing.Any) -> typing.List[str]:
        """Return every value stored under `key`, in insertion order."""
        return [item_value for item_key, item_value in self._list if item_key == key]

    def keys(self) -> typing.KeysView:
        return self._dict.keys()

    def values(self) -> typing.ValuesView:
        return self._dict.values()

    def items(self) -> typing.ItemsView:
        return self._dict.items()

    def multi_items(self) -> typing.List[typing.Tuple[str, str]]:
        """Return a copy of all (key, value) pairs, duplicates included."""
        return list(self._list)

    def get(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        if key in self._dict:
            return self._dict[key]
        return default

    def __getitem__(self, key: typing.Any) -> str:
        return self._dict[key]

    def __contains__(self, key: typing.Any) -> bool:
        return key in self._dict

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, self.__class__):
            return False
        # Order-insensitive comparison of the full pair lists.
        return sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        items = self.multi_items()
        return f"{class_name}({items!r})"


class MultiDict(ImmutableMultiDict):
    """A mutable multidict; every mutator keeps `_list` and `_dict` in step."""

    def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
        self.setlist(key, [value])

    def __delitem__(self, key: typing.Any) -> None:
        self._list = [(k, v) for k, v in self._list if k != key]
        del self._dict[key]

    def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Remove every entry for `key`; return its last value or `default`."""
        self._list = [(k, v) for k, v in self._list if k != key]
        return self._dict.pop(key, default)

    def popitem(self) -> typing.Tuple:
        key, value = self._dict.popitem()
        self._list = [(k, v) for k, v in self._list if k != key]
        return key, value

    def poplist(self, key: typing.Any) -> typing.List:
        """Remove every entry for `key` and return all of its values."""
        values = [v for k, v in self._list if k == key]
        self.pop(key)
        return values

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        if key not in self:
            self._dict[key] = default
            self._list.append((key, default))

        return self[key]

    def setlist(self, key: typing.Any, values: typing.List) -> None:
        """Replace all entries for `key` with `values`; empty removes the key."""
        if not values:
            self.pop(key, None)
        else:
            existing_items = [(k, v) for (k, v) in self._list if k != key]
            self._list = existing_items + [(key, value) for value in values]
            self._dict[key] = values[-1]

    def append(self, key: typing.Any, value: typing.Any) -> None:
        """Add one more value for `key`, keeping the existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: typing.Union[
            "MultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
        ],
        **kwargs: typing.Any,
    ) -> None:
        """Replace the entries of every key present in the given items."""
        value = MultiDict(*args, **kwargs)
        existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
        self._list = existing_items + value.multi_items()
        self._dict.update(value)


class QueryParams(ImmutableMultiDict):
    """
    An immutable multidict.
    """

    def __init__(
        self,
        *args: typing.Union[
            "ImmutableMultiDict",
            typing.Mapping,
            typing.List[typing.Tuple[typing.Any, typing.Any]],
            str,
            bytes,
        ],
        **kwargs: typing.Any,
    ) -> None:
        assert len(args) < 2, "Too many arguments."

        value = args[0] if args else []

        # `str` / `bytes` input is parsed as a query string; blank values kept.
        if isinstance(value, str):
            super().__init__(parse_qsl(value, keep_blank_values=True), **kwargs)
        elif isinstance(value, bytes):
            super().__init__(
                parse_qsl(value.decode("latin-1"), keep_blank_values=True), **kwargs
            )
        else:
            super().__init__(*args, **kwargs)  # type: ignore
        # Keys and values are always stored as strings.
        self._list = [(str(k), str(v)) for k, v in self._list]
        self._dict = {str(k): str(v) for k, v in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        query_string = str(self)
        return f"{class_name}({query_string!r})"


class UploadFile:
    """
    An uploaded file included as part of the request data.
    """

    # `max_size` passed to SpooledTemporaryFile when no file is supplied.
    spool_max_size = 1024 * 1024

    def __init__(
        self, filename: str, file: typing.IO = None, content_type: str = ""
    ) -> None:
        self.filename = filename
        self.content_type = content_type
        if file is None:
            file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size)
        self.file = file

    @property
    def _in_memory(self) -> bool:
        # File objects without a `_rolled` attribute are treated as on-disk.
        rolled_to_disk = getattr(self.file, "_rolled", True)
        return not rolled_to_disk

    async def write(self, data: typing.Union[bytes, str]) -> None:
        """Write `data`; uses the threadpool unless the file is in memory."""
        if self._in_memory:
            self.file.write(data)  # type: ignore
        else:
            await run_in_threadpool(self.file.write, data)

    async def read(self, size: int = -1) -> typing.Union[bytes, str]:
        """Read up to `size`; uses the threadpool unless the file is in memory."""
        if self._in_memory:
            return self.file.read(size)
        return await run_in_threadpool(self.file.read, size)

    async def seek(self, offset: int) -> None:
        if self._in_memory:
            self.file.seek(offset)
        else:
            await run_in_threadpool(self.file.seek, offset)

    async def close(self) -> None:
        if self._in_memory:
            self.file.close()
        else:
            await run_in_threadpool(self.file.close)


class FormData(ImmutableMultiDict):
    """
    An immutable multidict, containing both file uploads and text input.
    """

    def __init__(
        self,
        *args: typing.Union[
            "FormData",
            typing.Mapping[str, typing.Union[str, UploadFile]],
            typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]],
        ],
        **kwargs: typing.Union[str, UploadFile],
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every `UploadFile` value held by this form."""
        for key, value in self.multi_items():
            if isinstance(value, UploadFile):
                await value.close()


class Headers(typing.Mapping[str, str]):
    """
    An immutable, case-insensitive multidict.
    """

    def __init__(
        self,
        headers: typing.Mapping[str, str] = None,
        raw: typing.List[typing.Tuple[bytes, bytes]] = None,
        scope: Scope = None,
    ) -> None:
        self._list = []  # type: typing.List[typing.Tuple[bytes, bytes]]
        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            # Keys are lower-cased; keys and values are stored latin-1 encoded.
            self._list = [
                (key.lower().encode("latin-1"), value.encode("latin-1"))
                for key, value in headers.items()
            ]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            # Stored by reference, not copied.
            self._list = raw
        elif scope is not None:
            # Stored by reference, not copied.
            self._list = scope["headers"]

    @property
    def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        """A copy of the underlying list of (key, value) byte pairs."""
        return list(self._list)

    def keys(self) -> typing.List[str]:  # type: ignore
        return [key.decode("latin-1") for key, value in self._list]

    def values(self) -> typing.List[str]:  # type: ignore
        return [value.decode("latin-1") for key, value in self._list]

    def items(self) -> typing.List[typing.Tuple[str, str]]:  # type: ignore
        return [
            (key.decode("latin-1"), value.decode("latin-1"))
            for key, value in self._list
        ]

    def get(self, key: str, default: typing.Any = None) -> typing.Any:
        try:
            return self[key]
        except KeyError:
            return default

    def getlist(self, key: str) -> typing.List[str]:
        """Return every value stored for `key` (matched lower-cased)."""
        get_header_key = key.lower().encode("latin-1")
        return [
            item_value.decode("latin-1")
            for item_key, item_value in self._list
            if item_key == get_header_key
        ]

    def mutablecopy(self) -> "MutableHeaders":
        """Return a `MutableHeaders` built from a copy of the header list."""
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        # Returns the first matching header; raises KeyError when absent.
        get_header_key = key.lower().encode("latin-1")
        for header_key, header_value in self._list:
            if header_key == get_header_key:
                return header_value.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: typing.Any) -> bool:
        get_header_key = key.lower().encode("latin-1")
        for header_key, header_value in self._list:
            if header_key == get_header_key:
                return True
        return False

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._list)

    def __eq__(self, other: typing.Any) -> bool:
        if not isinstance(other, Headers):
            return False
        # Order-insensitive comparison of the raw pairs.
        return sorted(self._list) == sorted(other._list)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        as_dict = dict(self.items())
        # The dict form is only used when it loses no duplicate keys.
        if len(as_dict) == len(self):
            return f"{class_name}({as_dict!r})"
        return f"{class_name}(raw={self.raw!r})"


class MutableHeaders(Headers):
    """A `Headers` subclass that mutates its underlying list in place."""

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        found_indexes = []
        for idx, (item_key, item_value) in enumerate(self._list):
            if item_key == set_key:
                found_indexes.append(idx)

        # Delete from the end so the remaining indexes stay valid.
        for idx in reversed(found_indexes[1:]):
            del self._list[idx]

        if found_indexes:
            idx = found_indexes[0]
            self._list[idx] = (set_key, set_value)
        else:
            self._list.append((set_key, set_value))

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key`.
        """
        del_key = key.lower().encode("latin-1")

        pop_indexes = []
        for idx, (item_key, item_value) in enumerate(self._list):
            if item_key == del_key:
                pop_indexes.append(idx)

        # Delete from the end so the remaining indexes stay valid.
        for idx in reversed(pop_indexes):
            del self._list[idx]

    @property
    def raw(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        # Unlike `Headers.raw`, this returns the live list, not a copy.
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value.
        """
        set_key = key.lower().encode("latin-1")
        set_value = value.encode("latin-1")

        for idx, (item_key, item_value) in enumerate(self._list):
            if item_key == set_key:
                return item_value.decode("latin-1")
        self._list.append((set_key, set_value))
        return value

    def update(self, other: dict) -> None:
        for key, val in other.items():
            self[key] = val

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        append_key = key.lower().encode("latin-1")
        append_value = value.encode("latin-1")
        self._list.append((append_key, append_value))

    def add_vary_header(self, vary: str) -> None:
        """Add `vary` to the "vary" header, joining with any existing value."""
        existing = self.get("vary")
        if existing is not None:
            vary = ", ".join([existing, vary])
        self["vary"] = vary


class State(object):
    """
    An object that can be used to store arbitrary state.

    Used for `request.state` and `app.state`.
    """

    def __init__(self, state: typing.Dict = None):
        if state is None:
            state = {}
        # Bypass our own `__setattr__`, which writes into `_state` itself.
        super(State, self).__setattr__("_state", state)

    def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: typing.Any) -> typing.Any:
        try:
            # Read `_state` straight from the instance dict: going through
            # `self._state` would re-enter `__getattr__` and recurse without
            # end on an instance whose `__init__` has not run (e.g. while
            # being copied or unpickled).
            return self.__dict__["_state"][key]
        except KeyError:
            message = "'{}' object has no attribute '{}'"
            raise AttributeError(message.format(self.__class__.__name__, key))

    def __delattr__(self, key: typing.Any) -> None:
        del self._state[key]