Source code for app.domain.users.auth

import copy
from collections.abc import (
    Awaitable,
    Callable,
)
from typing import Annotated

from advanced_alchemy.exceptions import NotFoundError
from cashews import cache
from fastapi import Depends
from sqlalchemy.orm import joinedload, load_only

from app.config.base import get_settings
from app.config.constants import FITNESS_TRAINER_ROLE_SLUG
from app.db.models.role import Role
from app.db.models.user import User
from app.domain.users.deps import UserServiceDep
from app.domain.users.jwt_helpers import (
    TokenPayloadBase,
    get_access_token_payload,
    get_refresh_token_payload,
    is_token_in_blacklist,
)
from app.domain.users.schemas import UserAuth
from app.lib.auth import (
    access_token,
    refresh_token,
)
from app.lib.exceptions import (
    PermissionDeniedException,
    UnauthorizedException,
)

# Application settings, resolved once when this module is imported. The
# `Authenticate._get_user_from_payload` cache decorator reads its TTL
# (`settings.app.user_auth_cache_ttl`) from here at class-definition time.
settings = get_settings()


class Authenticate:
    """Provides FastAPI dependency factories for authentication and authorization."""

    @classmethod
    @cache(
        ttl=settings.app.user_auth_cache_ttl,
        key="user_auth:{token_payload.sub}",
    )
    async def _get_user_from_payload(
        cls,
        users_service: UserServiceDep,
        token_payload: TokenPayloadBase,
    ) -> UserAuth:
        """Load the UserAuth schema from the database using the JWT 'sub' claim.

        The result is cached under ``user_auth:<sub>`` for
        ``settings.app.user_auth_cache_ttl`` to reduce database load. Depending
        on the cache backend, the returned object may be the very instance held
        in the cache and handed to every other caller for the same user, so
        callers must not mutate it in place; copy it first.

        Args:
            users_service (UserService): Dependency for user service operations.
            token_payload (TokenPayloadBase): The decoded token data.

        Returns:
            UserAuth: The authenticated user.

        Raises:
            UnauthorizedException: If the user is not found (HTTP 401).
        """
        user_id = token_payload.sub
        try:
            db_obj = await users_service.get(
                item_id=user_id,
                load=[
                    # Restrict the query to these columns plus the role slug.
                    load_only(
                        User.id,
                        User.name,
                        User.email,
                        User.is_active,
                        User.is_superuser,
                    ),
                    joinedload(User.role).load_only(Role.slug),
                ],
            )
        except NotFoundError as exc:
            msg = "Invalid authentication credentials"
            raise UnauthorizedException(message=msg) from exc
        return users_service.to_schema(db_obj, schema_type=UserAuth)

    @classmethod
    async def get_current_user_for_refresh(
        cls,
        token: Annotated[str, Depends(refresh_token)],
        users_service: UserServiceDep,
    ) -> UserAuth:
        """Authenticate the user using the refresh token.

        Performs critical security checks including token blacklisting.
        This dependency is used exclusively by the token refresh endpoint.

        Args:
            token (str): The refresh token extracted from the cookie.
            users_service (UserService): Dependency for user service operations.

        Returns:
            UserAuth: A per-request copy of the authenticated user with the
            refresh token's JTI and expiration attached as ``_refresh_jti`` and
            ``_refresh_exp``.

        Raises:
            UnauthorizedException: If the token is invalid, blacklisted,
                or the user is inactive (HTTP 401).
        """
        token_payload = get_refresh_token_payload(token=token)
        refresh_jti = token_payload.jti

        # Reject refresh tokens whose JTI has been placed on the blacklist.
        token_exists = await is_token_in_blacklist(
            refresh_token_identifier=refresh_jti,
        )
        if token_exists:
            msg = "Invalid credentials"
            raise UnauthorizedException(message=msg)

        cached_user = await cls._get_user_from_payload(
            token_payload=token_payload,
            users_service=users_service,
        )
        # NOTE(review): this status check reads the cached user, so an account
        # deactivated within the cache TTL can presumably still refresh until the
        # entry expires unless it is invalidated elsewhere — confirm.
        if not cached_user.is_active:
            msg = "Invalid credentials or account is unavailable"
            raise UnauthorizedException(message=msg)

        # The cached object may be shared with other requests, so attach the
        # per-request token metadata to a private shallow copy instead of
        # mutating the cached instance.
        user_auth = copy.copy(cached_user)
        user_auth._refresh_jti = refresh_jti  # noqa: SLF001
        user_auth._refresh_exp = token_payload.exp  # noqa: SLF001
        return user_auth

    @classmethod
    async def get_current_user(
        cls,
        token: Annotated[str, Depends(access_token)],
        users_service: UserServiceDep,
    ) -> UserAuth:
        """Authenticate the user using the access token.

        This only authenticates; it does not check ``is_active``. Use
        ``get_current_active_user()`` when account status matters.

        Args:
            token (str): The access token extracted from the cookie.
            users_service (UserService): Dependency for user service operations.

        Returns:
            UserAuth: The authenticated user (possibly a cached, shared instance).

        Raises:
            UnauthorizedException: If the user referenced by the token is not
                found (HTTP 401). Token validation errors come from
                ``get_access_token_payload``.
        """
        token_payload = await get_access_token_payload(token=token)
        return await cls._get_user_from_payload(
            token_payload=token_payload,
            users_service=users_service,
        )

    @classmethod
    def get_current_active_user(cls) -> Callable[[UserAuth], Awaitable[UserAuth]]:
        """Dependency factory to ensure the user is active.

        It chains with `get_current_user` to perform both authentication
        and basic authorization (account status check).

        Returns:
            Callable: A FastAPI dependency function.

        Raises:
            UnauthorizedException: If the user is found but not active (HTTP 401).
        """

        async def current_user(
            user_auth: Annotated[UserAuth, Depends(cls.get_current_user)],
        ) -> UserAuth:
            if not user_auth.is_active:
                raise UnauthorizedException(message="Invalid credentials or account is unavailable")
            return user_auth

        return current_user

    @classmethod
    def superuser_required(cls) -> Callable[[UserAuth], Awaitable[UserAuth]]:
        """Dependency factory requiring superuser privileges.

        It chains with `get_current_active_user` and performs the final
        authorization check.

        Returns:
            Callable: A FastAPI dependency function.

        Raises:
            PermissionDeniedException: If the user is not a superuser (HTTP 403).
        """

        async def current_user(
            user_auth: Annotated[UserAuth, Depends(cls.get_current_active_user())],
        ) -> UserAuth:
            if not user_auth.is_superuser:
                msg = "Access denied: Superuser privileges required"
                raise PermissionDeniedException(message=msg)
            return user_auth

        return current_user

    @classmethod
    def trainer_required(cls) -> Callable[[UserAuth], Awaitable[UserAuth]]:
        """Dependency factory requiring the Fitness Trainer role.

        It chains with `get_current_active_user` and performs the final
        role check against ``FITNESS_TRAINER_ROLE_SLUG``.

        Returns:
            Callable: A FastAPI dependency function.

        Raises:
            PermissionDeniedException: If the user does not have the required
                role (HTTP 403).
        """

        async def current_user(
            user_auth: Annotated[UserAuth, Depends(cls.get_current_active_user())],
        ) -> UserAuth:
            if user_auth.role_slug == FITNESS_TRAINER_ROLE_SLUG:
                return user_auth
            msg = "Access restricted to fitness trainers"
            raise PermissionDeniedException(message=msg)

        return current_user

    @classmethod
    def get_refresh_jti(
        cls,
        token: Annotated[str, Depends(refresh_token)],
    ) -> tuple[str, float]:
        """Extract the JWT ID (jti) and expiration (exp) from a refresh token.

        Unlike ``get_current_user_for_refresh``, this does not consult the
        token blacklist and does not load the user.

        Args:
            token (str): The refresh token string from the cookie.

        Returns:
            tuple[str, float]: The JTI claim and expiration timestamp.
        """
        token_payload = get_refresh_token_payload(token=token)
        return token_payload.jti, token_payload.exp