import importlib.util
import os
import stat
import typing
from email.utils import parsedate

from aiofiles.os import stat as aio_stat

from starlette.datastructures import URL, Headers
from starlette.responses import (
    FileResponse,
    PlainTextResponse,
    RedirectResponse,
    Response,
)
from starlette.types import Receive, Scope, Send


class NotModifiedResponse(Response):
    """
    An HTTP 304 "Not Modified" response.

    Only the headers named in `NOT_MODIFIED_HEADERS` are copied over from the
    headers of the full response that would otherwise have been sent; all
    other headers are dropped.
    """

    NOT_MODIFIED_HEADERS = (
        "cache-control",
        "content-location",
        "date",
        "etag",
        "expires",
        "vary",
    )

    def __init__(self, headers: Headers):
        super().__init__(
            status_code=304,
            headers={
                name: value
                for name, value in headers.items()
                if name in self.NOT_MODIFIED_HEADERS
            },
        )


class StaticFiles:
    """
    ASGI application that serves static files.

    Files are looked up, in order, in `directory` (if given) and then in the
    `statics` directory located next to each package's `spec.origin` for the
    packages named in `packages`. With `html=True`, directory URLs serve their
    `index.html`, and a missing path serves `404.html` (if one exists) with a
    404 status.
    """

    def __init__(
        self,
        *,
        directory: typing.Optional[str] = None,
        packages: typing.Optional[typing.List[str]] = None,
        html: bool = False,
        check_dir: bool = True,
    ) -> None:
        """
        Raises:
            AssertionError: if a package in `packages`, or its `statics`
                directory, cannot be found.
            RuntimeError: if `check_dir` is true and `directory` is given but
                is not an existing directory.
        """
        self.directory = directory
        self.packages = packages
        self.all_directories = self.get_directories(directory, packages)
        self.html = html
        self.config_checked = False
        if check_dir and directory is not None and not os.path.isdir(directory):
            raise RuntimeError(f"Directory '{directory}' does not exist")

    def get_directories(
        self,
        directory: typing.Optional[str] = None,
        packages: typing.Optional[typing.List[str]] = None,
    ) -> typing.List[str]:
        """
        Given `directory` and `packages` arguments, return a list of all the
        directories that should be used for serving static files from.

        `directory` (if given) comes first, followed by one `statics`
        directory per package, in the order the packages were listed.
        """
        directories = []
        if directory is not None:
            directories.append(directory)

        for package in packages or []:
            spec = importlib.util.find_spec(package)
            assert spec is not None, f"Package {package!r} could not be found."
            assert (
                spec.origin is not None
            ), f"Directory 'statics' in package {package!r} could not be found."
            # `spec.origin` is a file (e.g. the package's __init__.py), so
            # ".." steps up to the directory that contains it.
            package_directory = os.path.normpath(
                os.path.join(spec.origin, "..", "statics")
            )
            assert os.path.isdir(
                package_directory
            ), f"Directory 'statics' in package {package!r} could not be found."
            directories.append(package_directory)

        return directories

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        The ASGI entry point.
        """
        assert scope["type"] == "http"

        if not self.config_checked:
            # The configuration is validated lazily, on the first request.
            await self.check_config()
            self.config_checked = True

        path = self.get_path(scope)
        response = await self.get_response(path, scope)
        await response(scope, receive, send)

    def get_path(self, scope: Scope) -> str:
        """
        Given the ASGI scope, return the `path` string to serve up, with OS
        specific path separators and normalized via `os.path.normpath`.

        '.' components are removed and '..' components are collapsed where
        possible, but leading '..' components can remain (e.g. "/../x" gives
        "../x"); `lookup_path` is what refuses paths that escape the served
        directories.
        """
        return os.path.normpath(os.path.join(*scope["path"].split("/")))

    async def get_response(self, path: str, scope: Scope) -> Response:
        """
        Returns an HTTP response, given the incoming path, method and request
        headers.

        Only GET and HEAD are allowed; anything else gets a 405.
        """
        if scope["method"] not in ("GET", "HEAD"):
            return PlainTextResponse("Method Not Allowed", status_code=405)

        full_path, stat_result = await self.lookup_path(path)

        if stat_result and stat.S_ISREG(stat_result.st_mode):
            # We have a static file to serve.
            return self.file_response(full_path, stat_result, scope)

        elif stat_result and stat.S_ISDIR(stat_result.st_mode) and self.html:
            # We're in HTML mode, and have got a directory URL.
            # Check if we have 'index.html' file to serve.
            index_path = os.path.join(path, "index.html")
            full_path, stat_result = await self.lookup_path(index_path)
            if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
                if not scope["path"].endswith("/"):
                    # Directory URLs should redirect to always end in "/".
                    url = URL(scope=scope)
                    url = url.replace(path=url.path + "/")
                    return RedirectResponse(url=url)
                return self.file_response(full_path, stat_result, scope)

        if self.html:
            # Check for '404.html' if we're in HTML mode.
            full_path, stat_result = await self.lookup_path("404.html")
            if stat_result is not None and stat.S_ISREG(stat_result.st_mode):
                return self.file_response(
                    full_path, stat_result, scope, status_code=404
                )

        return PlainTextResponse("Not Found", status_code=404)

    async def lookup_path(
        self, path: str
    ) -> typing.Tuple[str, typing.Optional[os.stat_result]]:
        """
        Search each configured directory, in order, for `path`.

        Returns `(full_path, stat_result)` for the first directory in which the
        resolved path exists and lies inside that directory, or `("", None)`
        when no directory has it.
        """
        for directory in self.all_directories:
            full_path = os.path.realpath(os.path.join(directory, path))
            directory = os.path.realpath(directory)
            try:
                # Don't allow misbehaving clients to break out of the static
                # files directory. `commonpath` compares whole path components;
                # `commonprefix` compared characters, which let
                # "/srv/static_secret/x" pass as being inside "/srv/static".
                if os.path.commonpath([full_path, directory]) != directory:
                    continue
            except ValueError:
                # Raised when the paths cannot share a root (e.g. different
                # drives on Windows), so `full_path` is certainly outside.
                continue
            try:
                stat_result = await aio_stat(full_path)
            except (FileNotFoundError, NotADirectoryError):
                # NotADirectoryError: a component of `path` is a regular file,
                # e.g. "style.css/extra" -- that is simply "not here".
                continue
            return (full_path, stat_result)
        return ("", None)

    def file_response(
        self,
        full_path: str,
        stat_result: os.stat_result,
        scope: Scope,
        status_code: int = 200,
    ) -> Response:
        """
        Build a `FileResponse` for `full_path`, or a `NotModifiedResponse` when
        the request's conditional headers show the client's copy is current.

        Conditional headers are only honoured for 2xx statuses, so e.g. the
        `404.html` page is never turned into a 304.
        """
        method = scope["method"]
        request_headers = Headers(scope=scope)

        response = FileResponse(
            full_path, status_code=status_code, stat_result=stat_result, method=method
        )
        if 200 <= status_code < 300 and self.is_not_modified(
            response.headers, request_headers
        ):
            return NotModifiedResponse(response.headers)
        return response

    async def check_config(self) -> None:
        """
        Perform a one-off configuration check that StaticFiles is actually
        pointed at a directory, so that we can raise loud errors rather than
        just returning 404 responses.

        Raises:
            RuntimeError: if `self.directory` is set but does not exist or is
                not a directory.
        """
        if self.directory is None:
            return

        try:
            stat_result = await aio_stat(self.directory)
        except FileNotFoundError:
            raise RuntimeError(
                f"StaticFiles directory '{self.directory}' does not exist."
            )
        # NOTE(review): aio_stat presumably follows symlinks like os.stat, in
        # which case the S_ISLNK test can never be true -- confirm before
        # removing it.
        if not (stat.S_ISDIR(stat_result.st_mode) or stat.S_ISLNK(stat_result.st_mode)):
            raise RuntimeError(
                f"StaticFiles path '{self.directory}' is not a directory."
            )

    def is_not_modified(
        self, response_headers: Headers, request_headers: Headers
    ) -> bool:
        """
        Given the request and response headers, return `True` if an HTTP
        "Not Modified" response could be returned instead.

        If the request carries a non-empty `If-None-Match`, only that header
        is consulted (RFC 7232, section 3.3). It may be `*`, or a
        comma-separated list of entity tags compared weakly (ignoring `W/` and
        surrounding quotes) against the response `ETag`. Otherwise
        `If-Modified-Since` is compared against `Last-Modified`.
        """
        if_none_match = request_headers.get("if-none-match")
        if if_none_match:
            if if_none_match.strip() == "*":
                # "*" matches any existing representation, and we have one.
                return True
            etag = response_headers.get("etag")
            if etag is None:
                return False
            wanted = self._normalize_etag(etag)
            return any(
                self._normalize_etag(tag) == wanted
                for tag in if_none_match.split(",")
            )

        if_modified_since = request_headers.get("if-modified-since")
        last_modified = response_headers.get("last-modified")
        if if_modified_since is None or last_modified is None:
            return False

        if_modified_since_date = parsedate(if_modified_since)
        last_modified_date = parsedate(last_modified)
        # parsedate() returns None for dates it cannot parse; in that case the
        # full response is sent.
        return (
            if_modified_since_date is not None
            and last_modified_date is not None
            and if_modified_since_date >= last_modified_date
        )

    @staticmethod
    def _normalize_etag(tag: str) -> str:
        """
        Reduce an entity tag to its opaque value for weak comparison by
        stripping whitespace, a leading `W/` marker, and surrounding quotes.
        """
        tag = tag.strip()
        if tag.startswith("W/"):
            tag = tag[2:]
        return tag.strip('"')