from base64 import urlsafe_b64decode
from datetime import timedelta
from hashlib import sha256
from hmac import compare_digest
from time import time
from typing import TYPE_CHECKING

from cashews import cache
from joserfc.errors import (
    JoseError,
)
from msgspec import (
    DecodeError,
    Struct,
    json,
    msgpack,
)
from structlog import get_logger

from app.config.base import get_settings
from app.lib.exceptions import UnauthorizedException
from app.lib.jwt_utils import decode_jwt, encode_jwt

if TYPE_CHECKING:
    from uuid import UUID
# Application settings, resolved when this module is imported; read below
# for `jwt.REFRESH_TOKEN_EXPIRE_DAYS`.
settings = get_settings()
# Module-level structlog logger, used for the token decode warnings below.
log = get_logger()
[docs]
class TokenPayloadBase(Struct):
    """Claims shared by every JWT payload handled in this module.

    Attributes:
        iat: The ``iat`` claim (presumably the issued-at timestamp; it is not
            read anywhere in this module).
        exp: The ``exp`` claim, compared against ``time()`` and therefore
            treated as a Unix timestamp in seconds.
        jti: The token identifier, used in this module to build cache keys.
        sub: The subject claim; the token factories below fill it with the
            stringified user id.
    """

    iat: float
    exp: float
    jti: str
    sub: str

    @property
    def is_expired(self) -> bool:
        """Report whether ``exp`` is at or before the current ``time()``."""
        return time() >= self.exp
[docs]
class TokenPayloadAccess(TokenPayloadBase):
    """Access-token claims: the base claims plus the user's email.

    Attributes:
        email: The ``email`` claim carried by access tokens.
    """

    email: str
[docs]
class TokenPayloadRefresh(TokenPayloadBase):
    """Refresh-token claims; declares no fields beyond the base claims."""
# Shared msgpack encoder; used below to serialize access payloads for the cache.
_encoder = msgpack.Encoder()
# Typed msgpack decoder for cached access-token payloads.
_access_decoder = msgpack.Decoder(TokenPayloadAccess)
# Typed msgpack decoder for refresh-token payloads.
# NOTE(review): not referenced in the visible part of this module — confirm it is used elsewhere.
_refresh_decoder = msgpack.Decoder(TokenPayloadRefresh)
# Typed JSON decoder used by `get_unverified_jti` to parse the raw JWT payload segment.
_json_decoder = json.Decoder(TokenPayloadBase)
[docs]
def get_unverified_jti(token: str) -> str | None:
    """Read the ``jti`` claim out of a JWT without verifying the token.

    The second dot-separated segment is base64url-decoded and parsed as JSON
    into ``TokenPayloadBase``. No signature or expiry check is performed here,
    so the returned value must not be trusted on its own.

    Args:
        token: The raw encoded JWT string.

    Returns:
        The ``jti`` value, or ``None`` when the token has no second segment,
        the segment cannot be base64-decoded, or the decoded bytes do not
        parse into ``TokenPayloadBase``.
    """
    try:
        segment = token.split(".")[1]
        # Append "=" padding before decoding; between one and four characters
        # are added depending on the segment length.
        padding = "=" * (4 - len(segment) % 4)
        raw_claims = urlsafe_b64decode((segment + padding).encode("utf-8"))
        unverified = _json_decoder.decode(raw_claims)
    except (IndexError, ValueError, DecodeError):
        return None
    return unverified.jti
[docs]
def create_access_token(
    user_id: "UUID",
    email: str,
) -> str:
    """Build an encoded JWT access token for a user.

    No lifetime is passed here: ``encode_jwt`` is called without an
    ``expire_timedelta``, so whatever default it applies determines expiry.

    Args:
        user_id (UUID): The user's unique identifier; stringified into the
            ``sub`` claim.
        email (str): The user's email address; stored in the ``email`` claim.

    Returns:
        str: The encoded access token string returned by ``encode_jwt``.
    """
    claims = {"sub": str(user_id), "email": email}
    return encode_jwt(payload=claims)
[docs]
def create_refresh_token(user_id: "UUID") -> str:
    """Build an encoded JWT refresh token for a user.

    The lifetime is ``settings.jwt.REFRESH_TOKEN_EXPIRE_DAYS`` days, passed to
    ``encode_jwt`` as ``expire_timedelta``.

    Args:
        user_id (UUID): The user's unique identifier; stringified into the
            ``sub`` claim.

    Returns:
        str: The encoded refresh token string returned by ``encode_jwt``.
    """
    claims = {"sub": str(user_id)}
    lifetime = timedelta(days=settings.jwt.REFRESH_TOKEN_EXPIRE_DAYS)
    return encode_jwt(payload=claims, expire_timedelta=lifetime)
[docs]
def _load_cached_access_payload(
    cached: object,
    token_digest: bytes,
) -> "TokenPayloadAccess | None":
    """Return the cached payload only if it was stored for this exact token.

    Cache entries are laid out as ``sha256(token) + msgpack(payload)``. Entries
    that are missing, not bytes, carry a different digest (including entries
    written in the older digest-less layout), or fail to decode are treated as
    a cache miss.
    """
    if not isinstance(cached, (bytes, bytearray)):
        return None
    digest_size = len(token_digest)
    stored_digest, body = cached[:digest_size], cached[digest_size:]
    # Constant-time comparison; a length mismatch simply compares unequal.
    if not compare_digest(stored_digest, token_digest):
        return None
    try:
        return _access_decoder.decode(body)
    except DecodeError:
        # A corrupt entry must not break authentication; fall back to the
        # full verification path instead.
        return None


async def get_access_token_payload(token: str) -> TokenPayloadAccess:
    """Decode an access token and validate its claims via cache fast-path.

    The cache is looked up by the token's *unverified* ``jti``. Because that
    value comes from an unauthenticated payload, each cache entry is bound to
    the SHA-256 digest of the token it was created for. A token that merely
    reuses a known ``jti`` (for example with a forged signature) does not match
    the stored digest and goes through full verification.

    Args:
        token: The raw encoded JWT string.

    Returns:
        TokenPayloadAccess: The validated and parsed access token data.

    Raises:
        UnauthorizedException: If the token has expired, its signature
            is invalid, or its claims do not fit ``TokenPayloadAccess``.
    """
    # "surrogatepass" keeps odd input from raising here; such a token is
    # rejected by the verification step below instead.
    token_digest = sha256(token.encode("utf-8", "surrogatepass")).digest()

    unverified_jti = get_unverified_jti(token=token)
    if unverified_jti:
        cached_data = await cache.get(key=f"jwt:access:{unverified_jti}")
        cached_payload = _load_cached_access_payload(cached_data, token_digest)
        if cached_payload is not None and not cached_payload.is_expired:
            return cached_payload

    try:
        claims = decode_jwt(token=token).claims
        # The Struct constructor raises TypeError on missing or unexpected
        # claims (e.g. a refresh token supplied as an access token).
        payload = TokenPayloadAccess(**claims)
        expired = payload.is_expired
    except (JoseError, DecodeError, TypeError) as exc:
        log.warning(
            "Access token decode failed",
            error_type=type(exc).__name__,
            error_detail=str(exc),
        )
        raise UnauthorizedException(message="Invalid token") from exc

    if expired:
        raise UnauthorizedException(message="Token has expired")

    ttl = int(payload.exp - time())
    # Only cache when the fast path can find the entry again, i.e. when the
    # unverified lookup produced a jti for this token.
    if unverified_jti and ttl > 0:
        await cache.set(
            key=f"jwt:access:{payload.jti}",
            value=token_digest + _encoder.encode(payload),
            expire=ttl,
        )
    return payload
[docs]
def get_refresh_token_payload(token: str) -> TokenPayloadRefresh:
    """Decode a refresh token and validate its claims.

    Args:
        token: The raw encoded JWT string.

    Returns:
        TokenPayloadRefresh: The validated and parsed refresh token data.

    Raises:
        UnauthorizedException: If the token has expired, its signature
            is invalid, or its claims do not fit ``TokenPayloadRefresh``.
    """
    try:
        claims = decode_jwt(token=token).claims
        # The Struct constructor raises TypeError on missing or unexpected
        # claims (e.g. an access token, which carries an extra ``email``).
        payload = TokenPayloadRefresh(**claims)
        expired = payload.is_expired
    except (JoseError, DecodeError, TypeError) as exc:
        log.warning(
            "Refresh token decode failed",
            error_type=type(exc).__name__,
            error_detail=str(exc),
        )
        raise UnauthorizedException(message="Invalid token") from exc

    if expired:
        raise UnauthorizedException(message="Token has expired")
    return payload
[docs]
async def add_token_to_blacklist(refresh_token_identifier: str, ttl: int) -> None:
    """Record a refresh token identifier (JTI) in the cache as revoked.

    Args:
        refresh_token_identifier: The JTI to revoke; stored under the key
            ``revoked:<identifier>``.
        ttl: Passed to the cache as the entry's ``expire`` value.
    """
    revocation_key = f"revoked:{refresh_token_identifier}"
    await cache.set(key=revocation_key, value="1", expire=ttl)
[docs]
async def is_token_in_blacklist(refresh_token_identifier: str) -> bool:
    """Tell whether a refresh token identifier (JTI) has a revocation entry.

    Args:
        refresh_token_identifier: The JTI to look up under
            ``revoked:<identifier>``.

    Returns:
        The result of the cache's ``exists`` check for that key.
    """
    revocation_key = f"revoked:{refresh_token_identifier}"
    is_revoked = await cache.exists(revocation_key)
    return is_revoked