gabriel / musehub public
test_musehub_auth_security.py python
265 lines 9.5 KB
d4eb1c39 Theme overhaul: domains, new-repo, MCP docs, copy icons; legacy CSS rem… Gabriel Cardona <cgcardona@gmail.com> 3d ago
1 """Auth security edge-case tests.
2
3 Covers scenarios that test_musehub_auth.py's happy-path tests do not:
4
5 - Expired JWT tokens → 401
6 - Structurally invalid / tampered JWT → 401
7 - Missing Authorization header entirely → 401 on write endpoints
8 - Token with wrong algorithm (none-attack) → 401
9 - Token revocation cache: get/set/clear/TTL semantics
10 - Revoked token rejected on protected endpoints → 401
11
12 These tests exercise the auth layer end-to-end through the ASGI transport
13 so that FastAPI dependency injection, the token validator, and the DB
14 revocation check all participate.
15 """
16 from __future__ import annotations
17
18 import time
19 from datetime import datetime, timedelta, timezone
20
21 import jwt
22 import pytest
23 from httpx import AsyncClient
24 from sqlalchemy.ext.asyncio import AsyncSession
25
26 from musehub.auth.revocation_cache import (
27 clear_revocation_cache,
28 get_revocation_status,
29 set_revocation_status,
30 )
31 from musehub.auth.tokens import (
32 AccessCodeError,
33 create_access_token,
34 hash_token,
35 validate_access_code,
36 )
37 from musehub.config import settings
38 from musehub.db.models import User
39
40
41 # ---------------------------------------------------------------------------
42 # Helpers
43 # ---------------------------------------------------------------------------
44
def _make_expired_token(user_id: str = "test-user-id") -> str:
    """Build a structurally valid access JWT whose ``exp`` passed an hour ago.

    Claims: ``type="access"``, ``sub=user_id``, ``iat`` two hours ago and
    ``exp`` one hour ago, both as whole-second UNIX timestamps.

    The token is signed with ``settings.access_token_secret`` (falling back
    to ``"test-secret"`` when that is unset) using
    ``settings.access_token_algorithm``. Assuming the validator verifies with
    the same key, expiry is the only thing wrong with the token.
    """
    two_hours_ago = datetime.now(tz=timezone.utc) - timedelta(hours=2)
    one_hour_ago = two_hours_ago + timedelta(hours=1)
    claims = {
        "type": "access",
        "sub": user_id,
        "iat": int(two_hours_ago.timestamp()),
        "exp": int(one_hour_ago.timestamp()),
    }
    signing_key = settings.access_token_secret or "test-secret"
    return jwt.encode(claims, signing_key, algorithm=settings.access_token_algorithm)
56
57
def _make_tampered_token(valid_token: str) -> str:
    """Return *valid_token* with its signature segment corrupted.

    The FIRST character of the base64url signature is replaced, not the last.
    The final character of an unpadded base64url string can carry padding
    bits that lenient decoders ignore (2 bits for a 32-byte HS256 signature).
    Swapping ``"A"`` for ``"B"`` there can therefore decode to exactly the
    same bytes and leave the token valid, which makes tests flaky.

    The first character always encodes six significant bits of the first
    signature byte. Changing it always changes the decoded signature.

    Raises:
        ValueError: if *valid_token* has fewer than three dot-separated
            segments, or its signature segment is empty.
    """
    header, payload, sig = valid_token.rsplit(".", 2)
    if not sig:
        raise ValueError("token has an empty signature segment; nothing to tamper with")
    # Pick a replacement guaranteed to differ from the current first character.
    replacement = "B" if sig[0] == "A" else "A"
    return f"{header}.{payload}.{replacement}{sig[1:]}"
63
64
def _make_none_alg_token(user_id: str = "test-user-id") -> str:
    """Craft an unsigned JWT declaring ``alg: none`` (classic signature bypass).

    - Header: ``{"alg": "none", "typ": "JWT"}``.
    - Payload: an access claim set for *user_id*, issued now and expiring in
      one hour.
    - Signature: empty, so the returned string ends with a trailing ``"."``.

    A secure validator must reject this token.
    """
    import base64
    import json

    def _segment(obj: dict[str, object]) -> str:
        # Unpadded base64url encoding of the JSON text, as used by JWT segments.
        return base64.urlsafe_b64encode(json.dumps(obj).encode()).rstrip(b"=").decode()

    now = int(datetime.now(tz=timezone.utc).timestamp())
    header = _segment({"alg": "none", "typ": "JWT"})
    body = _segment({"type": "access", "sub": user_id, "iat": now, "exp": now + 3600})
    return f"{header}.{body}."
76
77
78 # ---------------------------------------------------------------------------
79 # validate_access_code unit tests
80 # ---------------------------------------------------------------------------
81
class TestValidateAccessCodeUnit:
    """Unit tests that call ``validate_access_code`` directly, bypassing HTTP."""

    @staticmethod
    def _sign(claims: dict[str, object]) -> str:
        """Encode *claims* with the configured algorithm and secret.

        Falls back to the secret ``"test-secret"`` when none is configured.
        """
        key = settings.access_token_secret or "test-secret"
        return jwt.encode(claims, key, algorithm=settings.access_token_algorithm)

    def test_valid_token_returns_claims(self) -> None:
        issued = create_access_token(user_id="u1", expires_hours=1)
        decoded = validate_access_code(issued)
        assert decoded["sub"] == "u1"
        assert decoded["type"] == "access"

    def test_expired_token_raises(self) -> None:
        stale = _make_expired_token()
        with pytest.raises(AccessCodeError, match="expired"):
            validate_access_code(stale)

    def test_tampered_signature_raises(self) -> None:
        forged = _make_tampered_token(create_access_token(user_id="u2", expires_hours=1))
        with pytest.raises(AccessCodeError):
            validate_access_code(forged)

    def test_garbage_string_raises(self) -> None:
        with pytest.raises(AccessCodeError):
            validate_access_code("not.a.jwt")

    def test_none_algorithm_rejected(self) -> None:
        unsigned = _make_none_alg_token()
        with pytest.raises(AccessCodeError):
            validate_access_code(unsigned)

    def test_missing_type_claim_raises(self) -> None:
        issued_at = int(datetime.now(tz=timezone.utc).timestamp())
        # Deliberately omits the "type" claim.
        untyped = self._sign({"sub": "u3", "iat": issued_at, "exp": issued_at + 3600})
        with pytest.raises(AccessCodeError, match="Invalid token type"):
            validate_access_code(untyped)

    def test_wrong_type_claim_raises(self) -> None:
        issued_at = int(datetime.now(tz=timezone.utc).timestamp())
        refresh = self._sign(
            {"type": "refresh", "sub": "u4", "iat": issued_at, "exp": issued_at + 3600}
        )
        with pytest.raises(AccessCodeError, match="Invalid token type"):
            validate_access_code(refresh)

    def test_admin_token_has_role_claim(self) -> None:
        admin = create_access_token(user_id="admin1", expires_hours=1, is_admin=True)
        assert validate_access_code(admin).get("role") == "admin"

    def test_anonymous_token_has_no_sub(self) -> None:
        anonymous = create_access_token(expires_hours=1)
        assert "sub" not in validate_access_code(anonymous)
132 def test_anonymous_token_has_no_sub(self) -> None:
133 token = create_access_token(expires_hours=1)
134 claims = validate_access_code(token)
135 assert "sub" not in claims
136
137
138 # ---------------------------------------------------------------------------
139 # Revocation cache unit tests
140 # ---------------------------------------------------------------------------
141
class TestRevocationCache:
    """Unit tests for the in-memory revocation cache and ``hash_token``.

    The cache is cleared both before and after every test, so entries never
    leak from one test into another.
    """

    def setup_method(self) -> None:
        clear_revocation_cache()

    def teardown_method(self) -> None:
        clear_revocation_cache()

    def test_cache_miss_returns_none(self) -> None:
        assert get_revocation_status("unknown-hash") is None

    def test_set_valid_status_readable(self) -> None:
        key = "tok1"
        set_revocation_status(key, revoked=False)
        assert get_revocation_status(key) is False

    def test_set_revoked_status_readable(self) -> None:
        key = "tok2"
        set_revocation_status(key, revoked=True)
        assert get_revocation_status(key) is True

    def test_clear_removes_all_entries(self) -> None:
        seeded = {"tok3": False, "tok4": True}
        for key, revoked in seeded.items():
            set_revocation_status(key, revoked=revoked)
        clear_revocation_cache()
        assert all(get_revocation_status(key) is None for key in seeded)

    def test_overwrite_status(self) -> None:
        # The later write must win.
        for revoked in (False, True):
            set_revocation_status("tok5", revoked=revoked)
        assert get_revocation_status("tok5") is True

    def test_hash_token_is_deterministic(self) -> None:
        raw = "some.jwt.token"
        first, second = hash_token(raw), hash_token(raw)
        assert first == second
        assert len(first) == 64  # length of a hex-encoded SHA-256 digest

    def test_hash_token_distinct_inputs(self) -> None:
        assert hash_token("abc") != hash_token("xyz")
177 assert len(hash_token(t)) == 64 # SHA-256 hex
178
179 def test_hash_token_distinct_inputs(self) -> None:
180 assert hash_token("abc") != hash_token("xyz")
181
182
183 # ---------------------------------------------------------------------------
184 # HTTP integration tests — invalid tokens on write endpoints
185 # ---------------------------------------------------------------------------
186
@pytest.mark.anyio
async def test_expired_token_rejected_on_create_repo(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """POST /api/v1/repos with an expired bearer token must yield 401.

    The token's subject is seeded as a real user first. The 401 should then
    stem from the expired ``exp`` claim rather than from a missing user.
    """
    owner_id = "expired-user-id"
    db_session.add(User(id=owner_id))
    await db_session.commit()

    stale = _make_expired_token(user_id=owner_id)
    response = await client.post(
        "/api/v1/repos",
        headers={"Authorization": f"Bearer {stale}"},
        json={"name": "beats", "owner": "testuser"},
    )
    assert response.status_code == 401
204
205
@pytest.mark.anyio
async def test_tampered_token_rejected_on_create_repo(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """POST /api/v1/repos with a corrupted-signature JWT must yield 401.

    The token's subject is seeded as a real user first. The 401 should then
    stem from the bad signature rather than from a missing user.
    """
    owner_id = "tamper-user-id"
    db_session.add(User(id=owner_id))
    await db_session.commit()

    forged = _make_tampered_token(create_access_token(user_id=owner_id, expires_hours=1))
    response = await client.post(
        "/api/v1/repos",
        headers={"Authorization": f"Bearer {forged}"},
        json={"name": "beats", "owner": "testuser"},
    )
    assert response.status_code == 401
224
225
@pytest.mark.anyio
async def test_garbage_token_rejected(client: AsyncClient, db_session: AsyncSession) -> None:
    """A bearer value that is not a JWT at all must produce a 401.

    ``db_session`` is requested but not used directly. Presumably this
    ensures the database fixture is set up for the app (confirm in conftest).
    """
    bogus_auth = {"Authorization": "Bearer not-a-jwt-at-all"}
    repo_body = {"name": "beats", "owner": "testuser"}
    response = await client.post("/api/v1/repos", json=repo_body, headers=bogus_auth)
    assert response.status_code == 401
235
236
@pytest.mark.anyio
async def test_none_alg_token_rejected(client: AsyncClient, db_session: AsyncSession) -> None:
    """An ``alg=none`` token is rejected with 401; the signature bypass must fail.

    As in the expired/tampered tests, the token's subject is seeded as a real
    user first. Without that, a 401 could be explained by the unknown subject
    alone. The test might then still pass even if the server wrongly accepted
    unsigned tokens.
    """
    owner_id = "none-alg-user-id"
    db_session.add(User(id=owner_id))
    await db_session.commit()

    unsigned = _make_none_alg_token(user_id=owner_id)
    resp = await client.post(
        "/api/v1/repos",
        json={"name": "beats", "owner": "testuser"},
        headers={"Authorization": f"Bearer {unsigned}"},
    )
    assert resp.status_code == 401
247
248
@pytest.mark.anyio
@pytest.mark.parametrize("endpoint,method,body", [
    ("/api/v1/repos", "POST", {"name": "x", "owner": "y"}),
    ("/api/v1/repos/fake-id/issues", "POST", {"title": "t"}),
    ("/api/v1/repos/fake-id/issues/1/close", "POST", {}),
])
async def test_missing_auth_header_returns_401(
    client: AsyncClient,
    db_session: AsyncSession,
    endpoint: str,
    method: str,
    body: dict[str, object],
) -> None:
    """Write endpoints return 401 when the Authorization header is absent.

    Requests go through the generic ``AsyncClient.request`` rather than the
    per-verb helpers. httpx's ``get``/``delete``/``head``/``options`` helpers
    accept no ``json=`` argument, so dispatching via
    ``getattr(client, method.lower())(..., json=body)`` would raise
    ``TypeError`` once such a case is added to the table above.
    """
    resp = await client.request(method, endpoint, json=body)
    assert resp.status_code == 401