208 lines
7.7 KiB
Python
208 lines
7.7 KiB
Python
from typing import Any, Dict, List, Optional, Union
|
|
|
|
from fastapi.exceptions import HTTPException
|
|
from fastapi.openapi.models import OAuth2 as OAuth2Model
|
|
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
|
|
from fastapi.param_functions import Form
|
|
from fastapi.security.base import SecurityBase
|
|
from fastapi.security.utils import get_authorization_scheme_param
|
|
from starlette.requests import Request
|
|
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
|
|
|
|
|
|
class OAuth2PasswordRequestForm:
    """
    This is a dependency class, use it like:

        @app.post("/login")
        def login(form_data: OAuth2PasswordRequestForm = Depends()):
            print(form_data.username)
            print(form_data.password)
            for scope in form_data.scopes:
                print(scope)
            if form_data.client_id:
                print(form_data.client_id)
            if form_data.client_secret:
                print(form_data.client_secret)
            return form_data


    It creates the following Form request parameters in your endpoint:

    grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
        Nevertheless, this dependency class is permissive and allows not passing it. If it
        IS passed, it must be exactly "password". If you want to require it,
        use instead the OAuth2PasswordRequestFormStrict dependency.
    username: username string. The OAuth2 spec requires the exact field name "username".
    password: password string. The OAuth2 spec requires the exact field name "password".
    scope: Optional string. Several scopes (each one a string) separated by spaces. E.g.
        "items:read items:write users:read profile openid"
    client_id: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
        using HTTP Basic auth, as: client_id:client_secret
    client_secret: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
        using HTTP Basic auth, as: client_id:client_secret

    The received values are stored on the instance as ``grant_type``, ``username``,
    ``password``, ``client_id``, ``client_secret`` and ``scopes``. ``scopes`` is the
    ``scope`` string split on whitespace (an empty list when no scope was sent).
    """

    def __init__(
        self,
        # Anchored so the value must be exactly "password". The previous
        # unanchored pattern "password" also accepted values that merely start
        # with (or, under search semantics, contain) "password", e.g. "passwordX".
        grant_type: Optional[str] = Form(None, regex="^password$"),
        username: str = Form(...),
        password: str = Form(...),
        scope: str = Form(""),
        client_id: Optional[str] = Form(None),
        client_secret: Optional[str] = Form(None),
    ):
        self.grant_type = grant_type
        self.username = username
        self.password = password
        # str.split() with no argument collapses runs of whitespace and yields
        # [] for an empty string, so an omitted scope gives no scopes.
        self.scopes = scope.split()
        self.client_id = client_id
        self.client_secret = client_secret
|
|
|
|
|
|
class OAuth2PasswordRequestFormStrict(OAuth2PasswordRequestForm):
    """
    This is a dependency class, use it like:

        @app.post("/login")
        def login(form_data: OAuth2PasswordRequestFormStrict = Depends()):
            print(form_data.username)
            print(form_data.password)
            for scope in form_data.scopes:
                print(scope)
            if form_data.client_id:
                print(form_data.client_id)
            if form_data.client_secret:
                print(form_data.client_secret)
            return form_data


    It creates the following Form request parameters in your endpoint:

    grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
        This dependency is strict about it: the field is required and must be exactly
        "password". If you want to be permissive, use instead the
        OAuth2PasswordRequestForm dependency class.
    username: username string. The OAuth2 spec requires the exact field name "username".
    password: password string. The OAuth2 spec requires the exact field name "password".
    scope: Optional string. Several scopes (each one a string) separated by spaces. E.g.
        "items:read items:write users:read profile openid"
    client_id: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
        using HTTP Basic auth, as: client_id:client_secret
    client_secret: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
        using HTTP Basic auth, as: client_id:client_secret

    All values are forwarded unchanged to OAuth2PasswordRequestForm.__init__,
    which stores them on the instance (with ``scope`` split into ``scopes``).
    """

    def __init__(
        self,
        # Required (...) and anchored: only the exact string "password" is
        # accepted. The unanchored pattern "password" let through values that
        # merely start with (or contain) it, defeating the "strict" contract.
        grant_type: str = Form(..., regex="^password$"),
        username: str = Form(...),
        password: str = Form(...),
        scope: str = Form(""),
        client_id: Optional[str] = Form(None),
        client_secret: Optional[str] = Form(None),
    ):
        super().__init__(
            grant_type=grant_type,
            username=username,
            password=password,
            scope=scope,
            client_id=client_id,
            client_secret=client_secret,
        )
|
|
|
|
|
|
class OAuth2(SecurityBase):
    """
    Base OAuth2 security scheme.

    On construction it builds ``self.model`` (an OpenAPI ``OAuth2`` model) from
    ``flows`` and sets ``scheme_name``, which defaults to the concrete class name.
    When used as a dependency, it returns the raw ``Authorization`` header
    value without parsing it.
    """

    def __init__(
        self,
        *,
        # NOTE(review): this default OAuthFlowsModel() instance is created once,
        # when the function is defined, and is shared by every instance that
        # relies on the default. Whether OAuth2Model copies it depends on the
        # model validation settings, so confirm this before mutating
        # self.model.flows in place.
        flows: Union[OAuthFlowsModel, Dict[str, Dict[str, Any]]] = OAuthFlowsModel(),
        scheme_name: Optional[str] = None,
        auto_error: bool = True
    ):
        self.model = OAuth2Model(flows=flows)
        self.scheme_name = scheme_name or self.__class__.__name__
        # When True, a missing header raises; when False, __call__ returns None
        # so the endpoint can handle unauthenticated requests itself.
        self.auto_error = auto_error

    async def __call__(self, request: Request) -> Optional[str]:
        """
        Return the raw ``Authorization`` header value.

        If the header is absent or empty, raise ``HTTPException`` with status
        403 and detail "Not authenticated" when ``auto_error`` is true.
        Otherwise return ``None``.
        """
        # headers.get returns None when the header is absent.
        authorization: Optional[str] = request.headers.get("Authorization")
        if not authorization:
            if self.auto_error:
                raise HTTPException(
                    status_code=HTTP_403_FORBIDDEN, detail="Not authenticated"
                )
            else:
                return None
        return authorization
|
|
|
|
|
|
class OAuth2PasswordBearer(OAuth2):
    """
    OAuth2 scheme declaring a "password" flow whose tokens arrive as Bearer.

    Args:
        tokenUrl: URL recorded as the password flow's ``tokenUrl``.
        scheme_name: optional name; the base class falls back to the class name.
        scopes: mapping recorded as the flow's scopes (``{}`` when falsy).
        auto_error: if true, requests without a Bearer ``Authorization`` header
            are rejected with 401. Otherwise the dependency yields ``None``.
    """

    def __init__(
        self,
        tokenUrl: str,
        scheme_name: Optional[str] = None,
        scopes: Optional[Dict[str, str]] = None,
        auto_error: bool = True,
    ):
        password_flow = {"tokenUrl": tokenUrl, "scopes": scopes or {}}
        super().__init__(
            flows=OAuthFlowsModel(password=password_flow),
            scheme_name=scheme_name,
            auto_error=auto_error,
        )

    async def __call__(self, request: Request) -> Optional[str]:
        """
        Return the parameter part of a Bearer ``Authorization`` header.

        The header is split by ``get_authorization_scheme_param``; the second
        element is returned, presumably the token after "Bearer ". The scheme
        comparison is case-insensitive.
        """
        header_value = request.headers.get("Authorization")
        scheme, token = get_authorization_scheme_param(header_value)
        if header_value and scheme.lower() == "bearer":
            return token
        if not self.auto_error:
            return None
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
|
|
|
|
|
|
class OAuth2AuthorizationCodeBearer(OAuth2):
    """
    OAuth2 scheme declaring an "authorizationCode" flow with Bearer tokens.

    Args:
        authorizationUrl: URL recorded as the flow's ``authorizationUrl``.
        tokenUrl: URL recorded as the flow's ``tokenUrl``.
        refreshUrl: optional URL recorded as the flow's ``refreshUrl``.
        scheme_name: optional name; the base class falls back to the class name.
        scopes: mapping recorded as the flow's scopes (``{}`` when falsy).
        auto_error: if true, requests without a Bearer ``Authorization`` header
            are rejected with 401. Otherwise the dependency yields ``None``.
    """

    def __init__(
        self,
        authorizationUrl: str,
        tokenUrl: str,
        refreshUrl: Optional[str] = None,
        scheme_name: Optional[str] = None,
        scopes: Optional[Dict[str, str]] = None,
        auto_error: bool = True,
    ):
        code_flow = {
            "authorizationUrl": authorizationUrl,
            "tokenUrl": tokenUrl,
            "refreshUrl": refreshUrl,
            "scopes": scopes or {},
        }
        super().__init__(
            flows=OAuthFlowsModel(authorizationCode=code_flow),
            scheme_name=scheme_name,
            auto_error=auto_error,
        )

    async def __call__(self, request: Request) -> Optional[str]:
        """
        Return the parameter part of a Bearer ``Authorization`` header.

        The header is split by ``get_authorization_scheme_param``; the second
        element is returned, presumably the token after "Bearer ". The scheme
        comparison is case-insensitive.
        """
        header_value = request.headers.get("Authorization")
        scheme, token = get_authorization_scheme_param(header_value)
        if header_value and scheme.lower() == "bearer":
            return token
        if not self.auto_error:
            return None  # pragma: nocover
        raise HTTPException(
            status_code=HTTP_401_UNAUTHORIZED,
            detail="Not authenticated",
            headers={"WWW-Authenticate": "Bearer"},
        )
|
|
|
|
|
|
class SecurityScopes:
    """
    Holder for a list of OAuth2 scope names.

    Attributes:
        scopes: the given list (the same object when non-empty), or a new
            empty list when ``None`` or empty was passed.
        scope_str: the scopes joined with single spaces ("" when none).
    """

    def __init__(self, scopes: Optional[List[str]] = None):
        if not scopes:
            scopes = []
        self.scopes = scopes
        self.scope_str = " ".join(scopes)
|