from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
NoReturn,
cast,
)
from advanced_alchemy.exceptions import NotFoundError
from advanced_alchemy.extensions.fastapi import (
repository,
service,
)
from advanced_alchemy.service import (
ModelDictT,
OffsetPagination,
schema_dump,
)
from cashews import cache
from sqlalchemy import delete
from sqlalchemy.orm import (
load_only,
noload,
raiseload,
selectinload,
)
from app.db import models as m
from app.domain.catalogs.services import (
BaseCatalogService,
EquipmentService,
ExerciseTagService,
MuscleGroupService,
)
from app.domain.exercises.schemas import ExerciseRead
from app.lib.deps import (
CacheKeyBuilder,
CompositeServiceMixin,
)
from app.lib.exceptions import (
BadRequestException,
NotFoundException,
PermissionDeniedException,
)
if TYPE_CHECKING:
from uuid import UUID
from app.domain.catalogs.schemas import FieldsReadBase
from app.domain.exercises.filters import ExerciseFilters
from app.domain.users.schemas import UserAuth
class ExerciseService(CompositeServiceMixin, service.SQLAlchemyAsyncRepositoryService[m.Exercise]):
    """Service for managing Exercise entities.

    Provides high-level business logic for exercises, including resolving the
    ``primary_muscles``, ``secondary_muscles``, ``equipment`` and ``tags``
    relationships from catalog IDs using the catalog services exposed through
    ``CompositeServiceMixin``.
    """

    class ExerciseRepository(repository.SQLAlchemyAsyncRepository[m.Exercise]):
        """Exercise SQLAlchemy Repository."""

        model_type = m.Exercise

    repository_type = ExerciseRepository

    # Relationship attributes whose ID lists are resolved into catalog objects
    # and then assigned onto the model explicitly in ``_populate_model``.
    _rel_keys = ("primary_muscles", "secondary_muscles", "equipment", "tags")

    @property
    def muscles(self) -> MuscleGroupService:
        """Muscle-group catalog service."""
        return self._get_service(MuscleGroupService)

    @property
    def equipment(self) -> EquipmentService:
        """Equipment catalog service."""
        return self._get_service(EquipmentService)

    @property
    def tags(self) -> ExerciseTagService:
        """Exercise-tag catalog service."""
        return self._get_service(ExerciseTagService)

    async def to_model_on_create(self, data: ModelDictT[m.Exercise]) -> ModelDictT[m.Exercise]:
        """Create hook: resolve relationship IDs and build the model."""
        return await self._populate_model(data)

    async def to_model_on_update(self, data: ModelDictT[m.Exercise]) -> ModelDictT[m.Exercise]:
        """Update hook: resolve relationship IDs and build the model."""
        return await self._populate_model(data)

    async def to_model_on_upsert(self, data: ModelDictT[m.Exercise]) -> ModelDictT[m.Exercise]:
        """Upsert hook: resolve relationship IDs and build the model."""
        return await self._populate_model(data)

    async def _validate_and_populate_fields(
        self,
        data: dict[str, Any],
    ) -> dict[str, Any]:
        """Replace relationship ID lists in ``data`` with managed catalog objects.

        Muscle IDs from ``primary_muscles`` and ``secondary_muscles`` are validated
        together against the cached muscle catalog. When any muscle ID is
        requested, each muscle key present in ``data`` is rewritten to its list
        of managed objects. ``equipment`` and ``tags`` are validated against
        their own cached catalogs. They are rewritten only when they hold a
        non-empty ID list.

        Args:
            data: Plain dict (already passed through ``schema_dump``). It is
                mutated in place.

        Returns:
            The same ``data`` dict with the resolved relationship values.

        Raises:
            NotFoundException: If any requested ID is absent from its catalog.
        """
        primary_ids = set(data.get("primary_muscles") or [])
        secondary_ids = set(data.get("secondary_muscles") or [])
        if all_muscle_ids := primary_ids | secondary_ids:
            all_cached_muscles = await self.muscles.get_all_cached()
            found_muscles = [obj for obj in all_cached_muscles if obj.id in all_muscle_ids]
            found_ids = {obj.id for obj in found_muscles}
            # Compare ID sets, not list lengths. Duplicate cache entries would
            # otherwise hide a missing ID or raise an error listing no IDs.
            if all_muscle_ids - found_ids:
                self._raise_muscle_not_found(primary_ids - found_ids, secondary_ids - found_ids)
            muscles = await self.muscles.get_managed_objs(
                target_objs=found_muscles,
            )
            if "primary_muscles" in data:
                data["primary_muscles"] = [obj for obj in muscles if obj.id in primary_ids]
            if "secondary_muscles" in data:
                data["secondary_muscles"] = [obj for obj in muscles if obj.id in secondary_ids]

        services: list[tuple[str, BaseCatalogService[Any, Any]]] = [
            ("equipment", self.equipment),
            ("tags", self.tags),
        ]
        for key, serv in services:
            if requested_ids := set(data.get(key) or []):
                cached_items = await serv.get_all_cached()
                found_items = self._validate_ids(
                    requested_ids=requested_ids,
                    cached_data=cached_items,
                    error_prefix=key.capitalize(),
                )
                data[key] = await serv.get_managed_objs(target_objs=found_items)
        return data

    @staticmethod
    def _raise_muscle_not_found(
        missing_prim: set[int],
        missing_sec: set[int],
    ) -> NoReturn:
        """Raise ``NotFoundException`` naming the missing primary/secondary muscle IDs.

        IDs are sorted so the error message is deterministic.
        """
        if missing_prim and missing_sec:
            msg = f"Primary {sorted(missing_prim)} and secondary {sorted(missing_sec)} muscles not found"
        elif missing_prim:
            msg = f"Primary muscles not found: {sorted(missing_prim)}"
        else:
            msg = f"Secondary muscles not found: {sorted(missing_sec)}"
        raise NotFoundException(message=msg)

    @staticmethod
    def _validate_ids(
        requested_ids: set[int],
        cached_data: list[FieldsReadBase],
        error_prefix: str,
    ) -> list[FieldsReadBase]:
        """Return the cached catalog entries whose ``id`` is in ``requested_ids``.

        Args:
            requested_ids: IDs that must all be present in ``cached_data``.
            cached_data: Cached catalog entries to search.
            error_prefix: Leading text of the error message (e.g. ``"Tags"``).

        Returns:
            The matching entries from ``cached_data``, in cache order.

        Raises:
            NotFoundException: If any requested ID has no cached entry. The
                missing IDs are listed in sorted order.
        """
        found = [obj for obj in cached_data if obj.id in requested_ids]
        if missing := requested_ids - {obj.id for obj in found}:
            raise NotFoundException(message=f"{error_prefix} not found: {sorted(missing)}")
        return found

    async def _populate_model(self, data: ModelDictT[m.Exercise]) -> ModelDictT[m.Exercise]:
        """Convert ``data`` into an ``m.Exercise`` with resolved relationships.

        Steps:

        * ``data`` is normalised to a dict via ``schema_dump``.
        * Relationship IDs are resolved with ``_validate_and_populate_fields``.
        * The dict is converted with ``to_model(..., operation=None)``.
        * Each relationship key present in the dict is assigned onto the model.

        NOTE(review): ``operation=None`` presumably keeps ``to_model`` from
        re-dispatching to the ``to_model_on_*`` hooks, which would recurse.
        Likewise, the explicit ``setattr`` presumably covers ``to_model`` not
        copying relationship attributes. Confirm both against advanced_alchemy.
        """
        data = schema_dump(data)
        data = await self._validate_and_populate_fields(data)
        model = await self.to_model(data, operation=None)
        for key in self._rel_keys:
            if key in data:
                setattr(model, key, data[key])
        return model

    async def get_exercise_by_filter(
        self,
        user_id: UUID,
        name: str | None,
        slug: str | None,
    ) -> ExerciseRead:
        """Fetch a single exercise by ``name`` (custom) or ``slug`` (system default).

        Exactly one of ``name`` / ``slug`` must be given:

        * ``name`` matches an exercise whose ``created_by`` equals ``user_id``.
        * ``slug`` matches an exercise with ``is_system_default`` set.

        Raises:
            BadRequestException: If both or neither of ``name`` and ``slug`` are given.
            NotFoundException: If no matching exercise exists.
        """
        if name and slug:
            msg = "You must specify only one of the following: slug or name"
            raise BadRequestException(message=msg)
        if not (name or slug):
            msg = "Either name or slug must be provided"
            raise BadRequestException(message=msg)

        # Keep the try minimal: only the lookup can raise NotFoundError.
        try:
            if name:
                db_obj = await self.get_one(
                    m.Exercise.created_by == user_id,
                    name=name,
                )
            else:
                db_obj = await self.get_one(
                    m.Exercise.is_system_default.is_(True),
                    slug=slug,
                )
        except NotFoundError as exc:
            msg = f"Exercise with '{slug or name}' not found"
            raise NotFoundException(message=msg) from exc
        return self.to_schema(db_obj, schema_type=ExerciseRead)

    async def update_exercise(
        self,
        exercise_id: UUID,
        data: dict[str, Any],
        extra_filters: dict[str, Any],
    ) -> m.Exercise:
        """Update an exercise after checking it exists under ``extra_filters``.

        ``extra_filters`` are keyword filters ANDed with the ID in the
        existence check. NOTE(review): presumably the caller uses them for
        ownership / system-default scoping; confirm at call sites.

        The muscle and equipment relationships are eager-loaded with
        ``selectinload``. ``tags`` is eager-loaded only when
        ``extra_filters["is_system_default"]`` is truthy; otherwise it gets
        ``noload``.

        Raises:
            NotFoundException: If no exercise matches ``exercise_id`` and ``extra_filters``.
        """
        if not await self.exists(id=exercise_id, **extra_filters):
            msg = "Exercise not found"
            raise NotFoundException(message=msg)
        tags_loader = (
            selectinload(m.Exercise.tags) if extra_filters.get("is_system_default") else noload(m.Exercise.tags)
        )
        return await self.update(
            data=data,
            item_id=exercise_id,
            load=[
                selectinload(m.Exercise.primary_muscles),
                selectinload(m.Exercise.secondary_muscles),
                selectinload(m.Exercise.equipment),
                tags_loader,
            ],
            auto_refresh=False,
        )

    async def delete_exercise(
        self,
        exercise_id: UUID,
        user_auth: UserAuth,
    ) -> None:
        """Delete an exercise after a permission check.

        * System-default exercises may only be deleted by superusers.
        * Other exercises may only be deleted by their creator
          (``created_by == user_auth.id``). Superuser status does not bypass
          this ownership check.

        Only the columns needed for the check are loaded (``raiseload`` on
        everything else). The row is then removed with a DELETE statement.

        Raises:
            NotFoundException: If the exercise does not exist.
            PermissionDeniedException: If the user may not delete it.
        """
        db_obj = await self.get_one_or_none(
            id=exercise_id,
            load=[
                load_only(m.Exercise.id, m.Exercise.created_by, m.Exercise.is_system_default),
                raiseload("*"),
            ],
        )
        if db_obj is None:
            msg = "Exercise not found"
            raise NotFoundException(message=msg)
        if (db_obj.is_system_default and not user_auth.is_superuser) or (
            not db_obj.is_system_default and db_obj.created_by != user_auth.id
        ):
            msg = "You do not have permission to delete this exercise"
            raise PermissionDeniedException(message=msg)
        stmt = delete(m.Exercise).where(m.Exercise.id == exercise_id)
        await self.repository.session.execute(stmt)

    async def get_exercises_paginated_dto(
        self,
        params: ExerciseFilters,
        user_id: UUID,
    ) -> OffsetPagination[ExerciseRead]:
        """Return a filtered, paginated list of exercises, cached for 3 minutes.

        The cache key is built from ``params`` and ``user_id``. On a miss, the
        page is queried, converted to ``ExerciseRead`` and stored with a
        ``"3m"`` expiry.

        NOTE(review): none of the write methods in this class invalidate this
        cache, so lists may be stale for up to 3 minutes after a change.
        Confirm that is intended.
        """
        params_key = CacheKeyBuilder.for_exercises(params=params, user_id=user_id)
        cached_data = await cache.get(key=params_key)
        if cached_data is not None:
            return cast("OffsetPagination[ExerciseRead]", cached_data)
        filters = params.build_exercise_filters(user_id=user_id)
        results, total = await self.get_many_and_count(*filters)
        exercises = self.to_schema(data=results, total=total, filters=filters, schema_type=ExerciseRead)
        await cache.set(key=params_key, value=exercises, expire="3m")
        return exercises