dependencies.py
python
| 1 | """ |
| 2 | FastAPI Authentication Dependencies |
| 3 | |
| 4 | Provides dependency injection for protecting endpoints with access token validation |
| 5 | and for asset endpoints with device-ID-only (X-Device-ID) validation. |
| 6 | """ |
| 7 | |
| 8 | import logging |
| 9 | import uuid |
| 10 | |
| 11 | from fastapi import Depends, Header, HTTPException, status |
| 12 | from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials |
| 13 | from sqlalchemy import select |
| 14 | from sqlalchemy.ext.asyncio import AsyncSession |
| 15 | from musehub.auth.tokens import validate_access_code, AccessCodeError, hash_token, TokenClaims as TokenClaims |
| 16 | from musehub.auth.revocation_cache import get_revocation_status, set_revocation_status |
| 17 | |
logger = logging.getLogger(__name__)

# HTTPBearer extracts the token from the "Authorization: Bearer <token>" header.
# auto_error=False makes it hand the dependencies below ``None`` when no bearer
# credentials are present instead of raising, so require_valid_token can return
# its own error message and optional_token can allow anonymous access.
security = HTTPBearer(auto_error=False)
| 23 | |
| 24 | |
async def _check_and_register_token(token: str, claims: TokenClaims) -> bool:
    """
    Check if a token has been revoked, registering it if not found.

    Opens its own database session (``AsyncSessionLocal``) instead of taking
    one as a FastAPI dependency. A token whose hash has no ``AccessToken`` row
    yet is inserted on first use so that it can be revoked later. That
    registration is best-effort: if it fails, the request is not rejected.

    Args:
        token: The raw bearer token string (already validated by the caller).
        claims: The decoded claims of ``token``. Only ``sub`` is read, to
            attribute an auto-registered row to a user. Without ``sub`` the
            token is not registered.

    Returns:
        True if the stored row is marked revoked. False otherwise, including
        when the token was just registered or could not be registered.

    Raises:
        HTTPException 503: If the revocation lookup itself fails, e.g. the
            database is unavailable (fail closed).
    """
    try:
        # Imported here rather than at module top. Any failure here is also
        # caught by the fail-closed handler below and reported as a 503.
        from musehub.db.database import AsyncSessionLocal
        from musehub.db.models import AccessToken
        from musehub.auth.tokens import get_token_expiration

        async with AsyncSessionLocal() as db:
            token_hash = hash_token(token)
            result = await db.execute(
                select(AccessToken).where(AccessToken.token_hash == token_hash)
            )
            access_token = result.scalar_one_or_none()

            if access_token is None:
                # Token not in database - register it for revocation tracking.
                user_id = claims.get("sub")
                if user_id:
                    try:
                        expires_at = get_token_expiration(token)
                        db.add(
                            AccessToken(
                                user_id=user_id,
                                token_hash=token_hash,
                                expires_at=expires_at,
                                revoked=False,
                            )
                        )
                        await db.commit()
                        # Only a prefix of the user id is logged.
                        logger.debug(
                            "Auto-registered token for user %s...", user_id[:8]
                        )
                    except Exception as e:
                        # Deliberately best-effort: don't fail the request if
                        # registration fails (user might not exist yet).
                        logger.debug("Could not register token: %s", e)
                # An unregistered (or just-registered) token is not revoked.
                return False

            return access_token.revoked
    except Exception as e:
        # Fail closed: if we cannot verify revocation status, reject the
        # request. Keep the traceback so the underlying failure is diagnosable.
        logger.exception("Token revocation check failed: %s", e)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Unable to verify token status. Please try again later.",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e
| 79 | |
| 80 | |
async def require_valid_token(
    credentials: HTTPAuthorizationCredentials | None = Depends(security),
) -> TokenClaims:
    """
    FastAPI dependency that validates access tokens.

    Checks:
    1. Token is present
    2. Token signature is valid and it has not expired (``validate_access_code``)
    3. Token has not been revoked. The revocation cache is consulted first.
       On a cache miss the database is checked via ``_check_and_register_token``
       and the answer is stored back in the cache.

    Usage:
        @router.post("/protected")
        async def protected_endpoint(
            token_claims: TokenClaims = Depends(require_valid_token),
        ):
            # token_claims contains: type, iat, exp, sub (optional)
            ...

    Returns:
        The decoded token claims.

    Raises:
        HTTPException 401: If the token is missing, invalid, expired, or revoked.
        HTTPException 503: If the revocation status cannot be verified
            (raised by ``_check_and_register_token``; fail closed).
    """
    if credentials is None:
        logger.warning("Access attempt without token")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Access code required. Please provide a valid access code.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = credentials.credentials

    try:
        claims = validate_access_code(token)
        logger.debug("Valid token, expires at %s", claims["exp"])

        token_hash = hash_token(token)
        revoked = get_revocation_status(token_hash)
        if revoked is None:
            # Cache miss: check DB (and auto-register new tokens), then cache
            # the result so later requests can be answered from the cache.
            revoked = bool(await _check_and_register_token(token, claims))
            set_revocation_status(token_hash, revoked)

        if revoked:
            logger.warning(
                "Revoked token used by user %s", claims.get("sub", "unknown")
            )
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Access code has been revoked.",
                headers={"WWW-Authenticate": "Bearer"},
            )
        return claims

    except AccessCodeError as e:
        logger.warning("Invalid token: %s", e)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired access code.",
            headers={"WWW-Authenticate": "Bearer"},
        )
| 152 | |
| 153 | |
async def optional_token(
    credentials: HTTPAuthorizationCredentials | None = Depends(security),
) -> TokenClaims | None:
    """FastAPI dependency for endpoints that are publicly readable but optionally authed.

    - No token → returns ``None`` (anonymous access allowed).
    - Token present but invalid/expired/revoked → raises 401 (don't silently
      ignore a bad credential that the caller clearly intended to use).
    - Token present and valid → returns decoded claims like ``require_valid_token``.

    Use this on GET endpoints for public resources. Pair with a visibility
    check in the handler: if the resource is private and claims is None, raise 401.

    Raises:
        HTTPException 401: If a supplied token is invalid, expired, or revoked.
        HTTPException 503: If the revocation status of a supplied token
            cannot be verified (fail closed).
    """
    if credentials is None:
        return None

    # A credential was supplied, so hold it to exactly the same standard as a
    # required one. Delegating (rather than duplicating the validation and
    # revocation logic) keeps both dependencies in sync, including the
    # audit log for revoked-token use.
    return await require_valid_token(credentials)
| 202 | |
| 203 | |
async def require_device_id(
    x_device_id: str | None = Header(None, alias="X-Device-ID"),
) -> str:
    """
    Require a UUID-formatted ``X-Device-ID`` header (asset endpoints).

    This dependency performs no JWT check. It guards the drum-kit, soundfont,
    and bundle download endpoints so the macOS app can fetch assets without
    touching Keychain.

    Returns:
        The header value with surrounding whitespace stripped.

    Raises:
        HTTPException 400: If the header is absent, blank, or not parseable
            by ``uuid.UUID``.
    """
    # Treat a missing header the same as an all-whitespace one.
    device_id = (x_device_id or "").strip()
    if not device_id:
        logger.warning("Asset request without X-Device-ID")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="X-Device-ID header required",
        )

    try:
        uuid.UUID(device_id)
    except ValueError:
        # Truncate before logging so an oversized header can't flood the logs.
        logger.warning("Invalid X-Device-ID format: %s", device_id[:32])
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid X-Device-ID format",
        )
    else:
        return device_id