gabriel / musehub public
dependencies.py python
234 lines 8.5 KB
6b53f1af feat: supercharge all pages, full SOC refactor, and Python 3.14 upgrade (#7) Gabriel Cardona <cgcardona@gmail.com> 5d ago
1 """
2 FastAPI Authentication Dependencies
3
4 Provides dependency injection for protecting endpoints with access token validation
5 and for asset endpoints with device-ID-only (X-Device-ID) validation.
6 """
7
8 import logging
9 import uuid
10
11 from fastapi import Depends, Header, HTTPException, status
12 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
13 from sqlalchemy import select
14 from sqlalchemy.ext.asyncio import AsyncSession
15 from musehub.auth.tokens import validate_access_code, AccessCodeError, hash_token, TokenClaims as TokenClaims
16 from musehub.auth.revocation_cache import get_revocation_status, set_revocation_status
17
logger = logging.getLogger(__name__)

# Bearer-credential extractor for the "Authorization: Bearer <token>" header.
# With auto_error=False a missing or malformed header yields None instead of
# FastAPI raising on its own, so the dependencies below can send their own
# error messages.
security = HTTPBearer(auto_error=False)
23
24
async def _check_and_register_token(token: str, claims: TokenClaims) -> bool:
    """
    Check whether a token has been revoked, registering it if it is unknown.

    Opens its own database session (``AsyncSessionLocal``) instead of receiving
    one as a dependency. If no ``AccessToken`` row matches the token's hash and
    the claims carry a ``sub``, a non-revoked row is inserted so the token can
    be revoked later. Registration is best-effort: if it fails, the failure is
    logged at debug level and the request is not rejected.

    Args:
        token: The raw bearer token string.
        claims: Claims already decoded from ``token``. ``sub`` is used as the
            owning user ID when registering.

    Returns:
        True if the stored token is marked revoked, otherwise False. False is
        also returned when the token was just registered or could not be
        registered.

    Raises:
        HTTPException 503: If the revocation lookup itself fails, for example
            because the database is unavailable (fail closed).
    """
    try:
        # Imported at call time. Presumably this avoids an import cycle with
        # the DB layer (confirm before hoisting to module level).
        from musehub.db.database import AsyncSessionLocal
        from musehub.db.models import AccessToken
        from musehub.auth.tokens import get_token_expiration

        async with AsyncSessionLocal() as db:
            token_hash = hash_token(token)
            result = await db.execute(
                select(AccessToken).where(AccessToken.token_hash == token_hash)
            )
            access_token = result.scalar_one_or_none()

            if access_token is not None:
                return access_token.revoked

            # Token not in database: register it for revocation tracking.
            user_id = claims.get("sub")
            if user_id:
                try:
                    new_token = AccessToken(
                        user_id=user_id,
                        token_hash=token_hash,
                        expires_at=get_token_expiration(token),
                        revoked=False,
                    )
                    db.add(new_token)
                    await db.commit()
                except Exception as exc:
                    # Don't fail if registration fails (user might not exist yet).
                    logger.debug("Could not register token: %s", exc)
                else:
                    # Logged outside the try, and via str(), so a non-string
                    # ``sub`` cannot make a successful registration be
                    # reported as a failure.
                    logger.debug(
                        "Auto-registered token for user %s...", str(user_id)[:8]
                    )
            return False
    except Exception as exc:
        # Fail closed: if we cannot verify revocation status, reject the request.
        logger.error("Token revocation check failed: %s", exc)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Unable to verify token status. Please try again later.",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
79
80
async def require_valid_token(
    credentials: HTTPAuthorizationCredentials | None = Depends(security),
) -> TokenClaims:
    """
    FastAPI dependency that validates access tokens.

    Checks:
    1. Token is present.
    2. Token is accepted by ``validate_access_code`` (signature / expiry).
    3. Token has not been revoked. The in-process revocation cache is checked
       first, then the database; the result is cached on a cache miss.

    Usage:
        @router.post("/protected")
        async def protected_endpoint(
            token_claims: TokenClaims = Depends(require_valid_token)
        ):
            # token_claims contains: type, iat, exp, sub (optional)
            ...

    Returns:
        Decoded token claims.

    Raises:
        HTTPException 401: If the token is missing, invalid, expired, or revoked.
        HTTPException 503: If the token's revocation status cannot be verified.
    """
    if credentials is None:
        logger.warning("Access attempt without token")
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Access code required. Please provide a valid access code.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token = credentials.credentials

    # Only validation sits in the try: AccessCodeError is the one failure that
    # maps to "invalid or expired".
    try:
        claims = validate_access_code(token)
    except AccessCodeError as e:
        logger.warning("Invalid token: %s", e)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired access code.",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e

    # .get(): a debug log line must never turn a valid request into a 500.
    logger.debug("Valid token, expires at %s", claims.get("exp"))

    token_hash = hash_token(token)
    revoked = get_revocation_status(token_hash)
    if revoked is None:
        # Cache miss: check the DB (auto-registering new tokens), then cache
        # the result as a plain bool.
        revoked = bool(await _check_and_register_token(token, claims))
        set_revocation_status(token_hash, revoked)

    if revoked:
        logger.warning("Revoked token used by user %s", claims.get("sub", "unknown"))
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Access code has been revoked.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return claims
152
153
async def optional_token(
    credentials: HTTPAuthorizationCredentials | None = Depends(security),
) -> TokenClaims | None:
    """FastAPI dependency for endpoints that are publicly readable but optionally authed.

    - No token -> returns ``None`` (anonymous access allowed).
    - Token present but invalid/expired/revoked -> raises 401. A bad credential
      the caller clearly intended to use is not silently ignored.
    - Token present and valid -> returns decoded claims, like ``require_valid_token``.

    Use this on GET endpoints for public resources. Pair it with a visibility
    check in the handler: if the resource is private and claims is None, raise 401.

    Raises:
        HTTPException 401: If a supplied token is invalid, expired, or revoked.
        HTTPException 503: If the token's revocation status cannot be verified.
    """
    if credentials is None:
        return None

    token = credentials.credentials

    try:
        claims = validate_access_code(token)
    except AccessCodeError as e:
        logger.warning("Invalid optional token: %s", e)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired access code.",
            headers={"WWW-Authenticate": "Bearer"},
        ) from e

    token_hash = hash_token(token)
    revoked = get_revocation_status(token_hash)
    if revoked is None:
        # Cache miss: check the DB (auto-registering new tokens), then cache
        # the result as a plain bool.
        revoked = bool(await _check_and_register_token(token, claims))
        set_revocation_status(token_hash, revoked)

    if revoked:
        # Log the same way require_valid_token does, so revoked-token use is
        # auditable on optional-auth endpoints too.
        logger.warning("Revoked token used by user %s", claims.get("sub", "unknown"))
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Access code has been revoked.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    return claims
202
203
async def require_device_id(
    x_device_id: str | None = Header(None, alias="X-Device-ID"),
) -> str:
    """
    Validate the ``X-Device-ID`` header for asset endpoints.

    No JWT is required. Intended for the drum-kit, soundfont, and bundle
    download endpoints so the macOS app can fetch assets without touching
    Keychain.

    Returns:
        The header value with surrounding whitespace removed, once it has
        been confirmed to parse as a UUID.

    Raises:
        HTTPException 400: If the header is absent, blank, or not a UUID.
    """
    device_id = (x_device_id or "").strip()
    if not device_id:
        logger.warning("Asset request without X-Device-ID")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="X-Device-ID header required",
        )

    try:
        uuid.UUID(device_id)
    except ValueError:
        # Truncate the untrusted value before it reaches the log.
        logger.warning("Invalid X-Device-ID format: %s", device_id[:32])
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid X-Device-ID format",
        )
    return device_id