test_musehub_auth_security.py
python
| 1 | """Auth security edge-case tests. |
| 2 | |
| 3 | Covers scenarios that test_musehub_auth.py's happy-path tests do not: |
| 4 | |
| 5 | - Expired JWT tokens → 401 |
| 6 | - Structurally invalid / tampered JWT → 401 |
| 7 | - Missing Authorization header entirely → 401 on write endpoints |
| 8 | - Token with wrong algorithm (none-attack) → 401 |
| 9 | - Token revocation cache: get/set/clear/TTL semantics |
| 10 | - Revoked token rejected on protected endpoints → 401 |
| 11 | |
| 12 | These tests exercise the auth layer end-to-end through the ASGI transport |
| 13 | so that FastAPI dependency injection, the token validator, and the DB |
| 14 | revocation check all participate. |
| 15 | """ |
| 16 | from __future__ import annotations |
| 17 | |
| 18 | import time |
| 19 | from datetime import datetime, timedelta, timezone |
| 20 | |
| 21 | import jwt |
| 22 | import pytest |
| 23 | from httpx import AsyncClient |
| 24 | from sqlalchemy.ext.asyncio import AsyncSession |
| 25 | |
| 26 | from musehub.auth.revocation_cache import ( |
| 27 | clear_revocation_cache, |
| 28 | get_revocation_status, |
| 29 | set_revocation_status, |
| 30 | ) |
| 31 | from musehub.auth.tokens import ( |
| 32 | AccessCodeError, |
| 33 | create_access_token, |
| 34 | hash_token, |
| 35 | validate_access_code, |
| 36 | ) |
| 37 | from musehub.config import settings |
| 38 | from musehub.db.models import User |
| 39 | |
| 40 | |
| 41 | # --------------------------------------------------------------------------- |
| 42 | # Helpers |
| 43 | # --------------------------------------------------------------------------- |
| 44 | |
def _make_expired_token(user_id: str = "test-user-id") -> str:
    """Build a signed access JWT whose expiry already lies in the past.

    The token is stamped as issued two hours ago and expired one hour ago.
    It is signed with the configured access-token secret (falling back to
    ``"test-secret"`` when that setting is empty) using the configured
    algorithm.

    Args:
        user_id: Value stored in the ``sub`` claim.

    Returns:
        The encoded JWT string.
    """
    signing_key = settings.access_token_secret or "test-secret"
    current = datetime.now(tz=timezone.utc)
    issued_at = current - timedelta(hours=2)
    expired_at = current - timedelta(hours=1)
    claims = {
        "type": "access",
        "sub": user_id,
        "iat": int(issued_at.timestamp()),
        "exp": int(expired_at.timestamp()),
    }
    return jwt.encode(
        claims,
        signing_key,
        algorithm=settings.access_token_algorithm,
    )
| 56 | |
| 57 | |
| 58 | def _make_tampered_token(valid_token: str) -> str: |
| 59 | """Corrupt the signature portion of a valid JWT. |
| 60 | |
| 61 | We change a character in the middle of the signature rather than the last |
| 62 | character. The last base64url character of a 32-byte HMAC-SHA256 encodes |
| 63 | only 4 data bits (the bottom 2 bits are unused padding); changing only |
| 64 | those padding bits produces the same decoded bytes, leaving the signature |
| 65 | intact. A middle character carries a full 6 bits of data, so flipping it |
| 66 | is guaranteed to corrupt the signature regardless of the token value. |
| 67 | """ |
| 68 | header, payload, sig = valid_token.rsplit(".", 2) |
| 69 | mid = len(sig) // 2 |
| 70 | bad_sig = sig[:mid] + ("A" if sig[mid] != "A" else "B") + sig[mid + 1:] |
| 71 | return f"{header}.{payload}.{bad_sig}" |
| 72 | |
| 73 | |
| 74 | def _make_none_alg_token(user_id: str = "test-user-id") -> str: |
| 75 | """Craft a JWT with alg=none (classic signature-bypass attempt).""" |
| 76 | import base64, json |
| 77 | header = base64.urlsafe_b64encode( |
| 78 | json.dumps({"alg": "none", "typ": "JWT"}).encode() |
| 79 | ).rstrip(b"=").decode() |
| 80 | now = int(datetime.now(tz=timezone.utc).timestamp()) |
| 81 | body = base64.urlsafe_b64encode( |
| 82 | json.dumps({"type": "access", "sub": user_id, "iat": now, "exp": now + 3600}).encode() |
| 83 | ).rstrip(b"=").decode() |
| 84 | return f"{header}.{body}." |
| 85 | |
| 86 | |
| 87 | # --------------------------------------------------------------------------- |
| 88 | # validate_access_code unit tests |
| 89 | # --------------------------------------------------------------------------- |
| 90 | |
class TestValidateAccessCodeUnit:
    """Exercise ``validate_access_code`` directly, with no HTTP layer involved."""

    @staticmethod
    def _sign(claims: dict) -> str:
        """Encode *claims* with the configured secret (or ``"test-secret"``) and algorithm."""
        key = settings.access_token_secret or "test-secret"
        return jwt.encode(claims, key, algorithm=settings.access_token_algorithm)

    def test_valid_token_returns_claims(self) -> None:
        """A freshly issued token validates and exposes its ``sub`` and ``type`` claims."""
        decoded = validate_access_code(
            create_access_token(user_id="u1", expires_hours=1)
        )
        assert decoded["sub"] == "u1"
        assert decoded["type"] == "access"

    def test_expired_token_raises(self) -> None:
        """An expired token is rejected with an error mentioning "expired"."""
        stale = _make_expired_token()
        with pytest.raises(AccessCodeError, match="expired"):
            validate_access_code(stale)

    def test_tampered_signature_raises(self) -> None:
        """A token whose signature was altered is rejected."""
        genuine = create_access_token(user_id="u2", expires_hours=1)
        forged = _make_tampered_token(genuine)
        with pytest.raises(AccessCodeError):
            validate_access_code(forged)

    def test_garbage_string_raises(self) -> None:
        """A three-segment string that is not a JWT is rejected."""
        with pytest.raises(AccessCodeError):
            validate_access_code("not.a.jwt")

    def test_none_algorithm_rejected(self) -> None:
        """An unsigned ``alg: none`` token is rejected."""
        unsigned = _make_none_alg_token()
        with pytest.raises(AccessCodeError):
            validate_access_code(unsigned)

    def test_missing_type_claim_raises(self) -> None:
        """A validly signed token lacking a ``type`` claim is rejected."""
        issued = int(datetime.now(tz=timezone.utc).timestamp())
        untyped = self._sign({"sub": "u3", "iat": issued, "exp": issued + 3600})
        with pytest.raises(AccessCodeError, match="Invalid token type"):
            validate_access_code(untyped)

    def test_wrong_type_claim_raises(self) -> None:
        """A validly signed token whose ``type`` is not ``access`` is rejected."""
        issued = int(datetime.now(tz=timezone.utc).timestamp())
        refresh = self._sign(
            {"type": "refresh", "sub": "u4", "iat": issued, "exp": issued + 3600}
        )
        with pytest.raises(AccessCodeError, match="Invalid token type"):
            validate_access_code(refresh)

    def test_admin_token_has_role_claim(self) -> None:
        """Tokens issued with ``is_admin=True`` carry ``role == "admin"``."""
        admin_token = create_access_token(
            user_id="admin1", expires_hours=1, is_admin=True
        )
        decoded = validate_access_code(admin_token)
        assert decoded.get("role") == "admin"

    def test_anonymous_token_has_no_sub(self) -> None:
        """Tokens issued without a user id contain no ``sub`` claim."""
        decoded = validate_access_code(create_access_token(expires_hours=1))
        assert "sub" not in decoded
| 145 | |
| 146 | |
| 147 | # --------------------------------------------------------------------------- |
| 148 | # Revocation cache unit tests |
| 149 | # --------------------------------------------------------------------------- |
| 150 | |
class TestRevocationCache:
    """Check get/set/clear behaviour of the revocation cache and ``hash_token``."""

    def setup_method(self) -> None:
        """Start every test from an empty cache."""
        clear_revocation_cache()

    def teardown_method(self) -> None:
        """Leave no entries behind for later tests."""
        clear_revocation_cache()

    def test_cache_miss_returns_none(self) -> None:
        """Looking up a key that was never stored yields ``None``."""
        looked_up = get_revocation_status("unknown-hash")
        assert looked_up is None

    def test_set_valid_status_readable(self) -> None:
        """A stored ``revoked=False`` entry reads back as ``False``."""
        key = "tok1"
        set_revocation_status(key, revoked=False)
        assert get_revocation_status(key) is False

    def test_set_revoked_status_readable(self) -> None:
        """A stored ``revoked=True`` entry reads back as ``True``."""
        key = "tok2"
        set_revocation_status(key, revoked=True)
        assert get_revocation_status(key) is True

    def test_clear_removes_all_entries(self) -> None:
        """Clearing the cache drops every entry, whatever its status."""
        seeded = (("tok3", False), ("tok4", True))
        for key, flag in seeded:
            set_revocation_status(key, revoked=flag)
        clear_revocation_cache()
        for key, _ in seeded:
            assert get_revocation_status(key) is None

    def test_overwrite_status(self) -> None:
        """A second write for the same key replaces the first value."""
        key = "tok5"
        for flag in (False, True):
            set_revocation_status(key, revoked=flag)
        assert get_revocation_status(key) is True

    def test_hash_token_is_deterministic(self) -> None:
        """Hashing the same input twice gives the same 64-character digest."""
        sample = "some.jwt.token"
        first, second = hash_token(sample), hash_token(sample)
        assert first == second
        # 64 characters is the length of a SHA-256 hex digest.
        assert len(hash_token(sample)) == 64

    def test_hash_token_distinct_inputs(self) -> None:
        """Different inputs produce different digests."""
        digest_a, digest_b = hash_token("abc"), hash_token("xyz")
        assert digest_a != digest_b
| 190 | |
| 191 | |
| 192 | # --------------------------------------------------------------------------- |
| 193 | # HTTP integration tests — invalid tokens on write endpoints |
| 194 | # --------------------------------------------------------------------------- |
| 195 | |
@pytest.mark.anyio
async def test_expired_token_rejected_on_create_repo(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """Creating a repo with an expired bearer token must yield HTTP 401."""
    owner_id = "expired-user-id"
    # Persist the user first so the only thing wrong with the request is
    # the token's expiry.
    db_session.add(User(id=owner_id))
    await db_session.commit()

    stale_token = _make_expired_token(user_id=owner_id)
    auth_header = {"Authorization": f"Bearer {stale_token}"}
    response = await client.post(
        "/api/v1/repos",
        json={"name": "beats", "owner": "testuser"},
        headers=auth_header,
    )
    assert response.status_code == 401
| 213 | |
| 214 | |
@pytest.mark.anyio
async def test_tampered_token_rejected_on_create_repo(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """Creating a repo with a signature-corrupted bearer token must yield HTTP 401."""
    owner_id = "tamper-user-id"
    # Persist the user first so the only thing wrong with the request is
    # the token's signature.
    db_session.add(User(id=owner_id))
    await db_session.commit()

    genuine = create_access_token(user_id=owner_id, expires_hours=1)
    forged = _make_tampered_token(genuine)
    auth_header = {"Authorization": f"Bearer {forged}"}
    response = await client.post(
        "/api/v1/repos",
        json={"name": "beats", "owner": "testuser"},
        headers=auth_header,
    )
    assert response.status_code == 401
| 233 | |
| 234 | |
@pytest.mark.anyio
async def test_garbage_token_rejected(client: AsyncClient, db_session: AsyncSession) -> None:
    """A bearer value that is not a JWT at all must yield HTTP 401."""
    auth_header = {"Authorization": "Bearer not-a-jwt-at-all"}
    repo_payload = {"name": "beats", "owner": "testuser"}
    response = await client.post(
        "/api/v1/repos",
        json=repo_payload,
        headers=auth_header,
    )
    assert response.status_code == 401
| 244 | |
| 245 | |
@pytest.mark.anyio
async def test_none_alg_token_rejected(client: AsyncClient, db_session: AsyncSession) -> None:
    """An unsigned ``alg: none`` bearer token must yield HTTP 401 (no signature bypass)."""
    unsigned = _make_none_alg_token()
    auth_header = {"Authorization": f"Bearer {unsigned}"}
    response = await client.post(
        "/api/v1/repos",
        json={"name": "beats", "owner": "testuser"},
        headers=auth_header,
    )
    assert response.status_code == 401
| 256 | |
| 257 | |
@pytest.mark.anyio
@pytest.mark.parametrize(
    "endpoint,method,body",
    [
        ("/api/v1/repos", "POST", {"name": "x", "owner": "y"}),
        ("/api/v1/repos/fake-id/issues", "POST", {"title": "t"}),
        ("/api/v1/repos/fake-id/issues/1/close", "POST", {}),
    ],
)
async def test_missing_auth_header_returns_401(
    client: AsyncClient,
    db_session: AsyncSession,
    endpoint: str,
    method: str,
    body: dict,
) -> None:
    """Each listed write endpoint must yield HTTP 401 when no Authorization header is sent."""
    # Resolve the client coroutine named after the HTTP verb (e.g. "POST" -> client.post).
    send = getattr(client, method.lower())
    response = await send(endpoint, json=body)
    assert response.status_code == 401