gabriel / musehub public
test_musehub_auth_security.py python
274 lines 10.0 KB
f5048eae dev → main: domain creation, supercharged pages, wire protocol, full hi… Gabriel Cardona <cgcardona@gmail.com> 3d ago
1 """Auth security edge-case tests.
2
3 Covers scenarios that test_musehub_auth.py's happy-path tests do not:
4
5 - Expired JWT tokens → 401
6 - Structurally invalid / tampered JWT → 401
7 - Missing Authorization header entirely → 401 on write endpoints
8 - Token with wrong algorithm (none-attack) → 401
9 - Token revocation cache: get/set/clear/TTL semantics
10 - Revoked token rejected on protected endpoints → 401
11
12 These tests exercise the auth layer end-to-end through the ASGI transport
13 so that FastAPI dependency injection, the token validator, and the DB
14 revocation check all participate.
15 """
16 from __future__ import annotations
17
18 import time
19 from datetime import datetime, timedelta, timezone
20
21 import jwt
22 import pytest
23 from httpx import AsyncClient
24 from sqlalchemy.ext.asyncio import AsyncSession
25
26 from musehub.auth.revocation_cache import (
27 clear_revocation_cache,
28 get_revocation_status,
29 set_revocation_status,
30 )
31 from musehub.auth.tokens import (
32 AccessCodeError,
33 create_access_token,
34 hash_token,
35 validate_access_code,
36 )
37 from musehub.config import settings
38 from musehub.db.models import User
39
40
41 # ---------------------------------------------------------------------------
42 # Helpers
43 # ---------------------------------------------------------------------------
44
def _make_expired_token(user_id: str = "test-user-id") -> str:
    """Build a well-formed access JWT whose expiry is already in the past.

    The token is issued two hours ago and expired one hour ago. It is signed
    with ``settings.access_token_secret`` (falling back to ``"test-secret"``
    when that setting is falsy) using ``settings.access_token_algorithm``.

    Args:
        user_id: Value placed in the ``sub`` claim.

    Returns:
        The encoded JWT string.
    """
    signing_key = settings.access_token_secret or "test-secret"
    expired_at = datetime.now(tz=timezone.utc) - timedelta(hours=1)
    issued_at = expired_at - timedelta(hours=1)
    # Keep the claim order type/sub/iat/exp so the encoded bytes match the
    # token this helper has always produced.
    claims = {
        "type": "access",
        "sub": user_id,
        "iat": int(issued_at.timestamp()),
        "exp": int(expired_at.timestamp()),
    }
    return jwt.encode(claims, signing_key, algorithm=settings.access_token_algorithm)
56
57
def _make_tampered_token(valid_token: str) -> str:
    """Return *valid_token* with one character of its signature segment replaced.

    The character in the middle of the signature is swapped ("A", or "B" if it
    already was "A"). The last character is deliberately avoided. With a
    32-byte HMAC-SHA256 signature, the final base64url character holds only
    4 data bits; its bottom 2 bits are unused. A change confined to those
    bits decodes to the same bytes and would leave the signature valid.
    A middle character always carries a full 6 bits of data, so replacing
    it is guaranteed to change the decoded signature.

    Args:
        valid_token: A ``header.payload.signature`` JWT string.

    Returns:
        The same token with a corrupted signature segment.
    """
    header, payload, signature = valid_token.rsplit(".", 2)
    pivot = len(signature) // 2
    replacement = "B" if signature[pivot] == "A" else "A"
    corrupted = "".join((signature[:pivot], replacement, signature[pivot + 1:]))
    return ".".join((header, payload, corrupted))
72
73
def _make_none_alg_token(user_id: str = "test-user-id") -> str:
    """Craft an unsigned JWT whose header declares ``alg: none``.

    This is the classic signature-bypass attempt: a validator that trusts the
    header's algorithm would accept the token without verifying anything.
    The claims mirror a normal access token (``type``, ``sub``, ``iat``, and
    ``exp`` one hour ahead), so the only thing wrong is the missing signature.

    Args:
        user_id: Value placed in the ``sub`` claim.

    Returns:
        ``"<header>.<payload>."``. Both segments are base64url with padding
        stripped, and the signature segment is empty.
    """
    import base64
    import json

    def _b64url(segment: dict[str, object]) -> str:
        # JWT segments are JSON encoded as base64url with "=" padding removed.
        raw = json.dumps(segment).encode()
        return base64.urlsafe_b64encode(raw).rstrip(b"=").decode()

    now = int(datetime.now(tz=timezone.utc).timestamp())
    header = _b64url({"alg": "none", "typ": "JWT"})
    body = _b64url({"type": "access", "sub": user_id, "iat": now, "exp": now + 3600})
    return f"{header}.{body}."
85
86
87 # ---------------------------------------------------------------------------
88 # validate_access_code unit tests
89 # ---------------------------------------------------------------------------
90
class TestValidateAccessCodeUnit:
    """Token-validator tests that call ``validate_access_code`` directly.

    No HTTP transport is involved. Each test builds a token in-process and
    checks either the returned claims or the ``AccessCodeError`` raised.
    """

    @staticmethod
    def _sign(claims: dict[str, object]) -> str:
        """Encode *claims* with the configured secret (or "test-secret") and algorithm."""
        key = settings.access_token_secret or "test-secret"
        return jwt.encode(claims, key, algorithm=settings.access_token_algorithm)

    @staticmethod
    def _now() -> int:
        """Current UTC time as integer epoch seconds."""
        return int(datetime.now(tz=timezone.utc).timestamp())

    def test_valid_token_returns_claims(self) -> None:
        claims = validate_access_code(create_access_token(user_id="u1", expires_hours=1))
        assert claims["sub"] == "u1"
        assert claims["type"] == "access"

    def test_expired_token_raises(self) -> None:
        stale = _make_expired_token()
        with pytest.raises(AccessCodeError, match="expired"):
            validate_access_code(stale)

    def test_tampered_signature_raises(self) -> None:
        forged = _make_tampered_token(create_access_token(user_id="u2", expires_hours=1))
        with pytest.raises(AccessCodeError):
            validate_access_code(forged)

    def test_garbage_string_raises(self) -> None:
        with pytest.raises(AccessCodeError):
            validate_access_code("not.a.jwt")

    def test_none_algorithm_rejected(self) -> None:
        unsigned = _make_none_alg_token()
        with pytest.raises(AccessCodeError):
            validate_access_code(unsigned)

    def test_missing_type_claim_raises(self) -> None:
        issued = self._now()
        untyped = self._sign({"sub": "u3", "iat": issued, "exp": issued + 3600})
        with pytest.raises(AccessCodeError, match="Invalid token type"):
            validate_access_code(untyped)

    def test_wrong_type_claim_raises(self) -> None:
        issued = self._now()
        refresh = self._sign(
            {"type": "refresh", "sub": "u4", "iat": issued, "exp": issued + 3600}
        )
        with pytest.raises(AccessCodeError, match="Invalid token type"):
            validate_access_code(refresh)

    def test_admin_token_has_role_claim(self) -> None:
        admin = create_access_token(user_id="admin1", expires_hours=1, is_admin=True)
        assert validate_access_code(admin).get("role") == "admin"

    def test_anonymous_token_has_no_sub(self) -> None:
        anonymous = create_access_token(expires_hours=1)
        assert "sub" not in validate_access_code(anonymous)
140
141 def test_anonymous_token_has_no_sub(self) -> None:
142 token = create_access_token(expires_hours=1)
143 claims = validate_access_code(token)
144 assert "sub" not in claims
145
146
class TestRevocationCache:
    """Behaviour of the in-memory revocation cache, plus ``hash_token`` checks.

    The cache is cleared before and after each test, so entries written by
    one test are never visible to another.
    """

    def setup_method(self) -> None:
        clear_revocation_cache()

    def teardown_method(self) -> None:
        clear_revocation_cache()

    def test_cache_miss_returns_none(self) -> None:
        status = get_revocation_status("unknown-hash")
        assert status is None

    def test_set_valid_status_readable(self) -> None:
        key = "tok1"
        set_revocation_status(key, revoked=False)
        assert get_revocation_status(key) is False

    def test_set_revoked_status_readable(self) -> None:
        key = "tok2"
        set_revocation_status(key, revoked=True)
        assert get_revocation_status(key) is True

    def test_clear_removes_all_entries(self) -> None:
        seeded = {"tok3": False, "tok4": True}
        for key, revoked in seeded.items():
            set_revocation_status(key, revoked=revoked)
        clear_revocation_cache()
        assert all(get_revocation_status(key) is None for key in seeded)

    def test_overwrite_status(self) -> None:
        # The later write must win.
        for revoked in (False, True):
            set_revocation_status("tok5", revoked=revoked)
        assert get_revocation_status("tok5") is True

    def test_hash_token_is_deterministic(self) -> None:
        sample = "some.jwt.token"
        first, second = hash_token(sample), hash_token(sample)
        assert first == second
        assert len(first) == 64  # SHA-256 hex

    def test_hash_token_distinct_inputs(self) -> None:
        assert hash_token("xyz") != hash_token("abc")
186 assert len(hash_token(t)) == 64 # SHA-256 hex
187
188 def test_hash_token_distinct_inputs(self) -> None:
189 assert hash_token("abc") != hash_token("xyz")
190
191
192 # ---------------------------------------------------------------------------
193 # HTTP integration tests — invalid tokens on write endpoints
194 # ---------------------------------------------------------------------------
195
@pytest.mark.anyio
async def test_expired_token_rejected_on_create_repo(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """An expired JWT returns 401 on a write endpoint.

    A ``User`` row with the token's subject is committed first, so the token
    refers to an existing user and only its expiry is wrong.
    """
    owner_id = "expired-user-id"
    db_session.add(User(id=owner_id))
    await db_session.commit()

    auth = {"Authorization": f"Bearer {_make_expired_token(user_id=owner_id)}"}
    response = await client.post(
        "/api/v1/repos",
        json={"name": "beats", "owner": "testuser"},
        headers=auth,
    )
    assert response.status_code == 401
213
214
@pytest.mark.anyio
async def test_tampered_token_rejected_on_create_repo(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """A tampered JWT (bad signature) returns 401 on a write endpoint.

    The token starts as a genuine, unexpired access token for an existing
    user. Only its signature segment is altered.
    """
    owner_id = "tamper-user-id"
    db_session.add(User(id=owner_id))
    await db_session.commit()

    forged = _make_tampered_token(create_access_token(user_id=owner_id, expires_hours=1))
    response = await client.post(
        "/api/v1/repos",
        json={"name": "beats", "owner": "testuser"},
        headers={"Authorization": f"Bearer {forged}"},
    )
    assert response.status_code == 401
233
234
@pytest.mark.anyio
async def test_garbage_token_rejected(client: AsyncClient, db_session: AsyncSession) -> None:
    """A bearer value that is not a JWT at all returns 401."""
    bogus_auth = {"Authorization": "Bearer not-a-jwt-at-all"}
    response = await client.post(
        "/api/v1/repos",
        json={"name": "beats", "owner": "testuser"},
        headers=bogus_auth,
    )
    assert response.status_code == 401
244
245
@pytest.mark.anyio
async def test_none_alg_token_rejected(client: AsyncClient, db_session: AsyncSession) -> None:
    """An ``alg=none`` token gets 401: the unsigned-token bypass must not work."""
    unsigned = _make_none_alg_token()
    bearer = {"Authorization": f"Bearer {unsigned}"}
    response = await client.post(
        "/api/v1/repos",
        json={"name": "beats", "owner": "testuser"},
        headers=bearer,
    )
    assert response.status_code == 401
256
257
@pytest.mark.anyio
@pytest.mark.parametrize("endpoint,method,body", [
    ("/api/v1/repos", "POST", {"name": "x", "owner": "y"}),
    ("/api/v1/repos/fake-id/issues", "POST", {"title": "t"}),
    ("/api/v1/repos/fake-id/issues/1/close", "POST", {}),
])
async def test_missing_auth_header_returns_401(
    client: AsyncClient,
    db_session: AsyncSession,  # NOTE(review): unused in the body; presumably requested for DB setup — verify
    endpoint: str,
    method: str,
    body: dict[str, object],
) -> None:
    """Write endpoints return 401 when the Authorization header is absent.

    The request goes through ``client.request(method, ...)`` rather than
    ``getattr(client, method.lower())``. httpx's ``get``/``delete``
    shortcuts do not accept ``json=``, so the getattr form would raise
    ``TypeError`` as soon as such a verb is added to the table above.
    For the existing POST rows the request sent is identical.
    """
    resp = await client.request(method, endpoint, json=body)
    assert resp.status_code == 401