test_musehub_auth_security.py
python
| 1 | """Auth security edge-case tests. |
| 2 | |
| 3 | Covers scenarios that test_musehub_auth.py's happy-path tests do not: |
| 4 | |
| 5 | - Expired JWT tokens → 401 |
| 6 | - Structurally invalid / tampered JWT → 401 |
| 7 | - Missing Authorization header entirely → 401 on write endpoints |
| 8 | - Token declaring alg=none (signature-bypass attempt) → 401 |
| 9 | - Token revocation cache: get/set/clear/TTL semantics |
| 10 | - Revoked token rejected on protected endpoints → 401 |
| 11 | |
| 12 | These tests exercise the auth layer end-to-end through the ASGI transport |
| 13 | so that FastAPI dependency injection, the token validator, and the DB |
| 14 | revocation check all participate. |
| 15 | """ |
| 16 | from __future__ import annotations |
| 17 | |
| 18 | import time |
| 19 | from datetime import datetime, timedelta, timezone |
| 20 | |
| 21 | import jwt |
| 22 | import pytest |
| 23 | from httpx import AsyncClient |
| 24 | from sqlalchemy.ext.asyncio import AsyncSession |
| 25 | |
| 26 | from musehub.auth.revocation_cache import ( |
| 27 | clear_revocation_cache, |
| 28 | get_revocation_status, |
| 29 | set_revocation_status, |
| 30 | ) |
| 31 | from musehub.auth.tokens import ( |
| 32 | AccessCodeError, |
| 33 | create_access_token, |
| 34 | hash_token, |
| 35 | validate_access_code, |
| 36 | ) |
| 37 | from musehub.config import settings |
| 38 | from musehub.db.models import User |
| 39 | |
| 40 | |
| 41 | # --------------------------------------------------------------------------- |
| 42 | # Helpers |
| 43 | # --------------------------------------------------------------------------- |
| 44 | |
def _make_expired_token(user_id: str = "test-user-id") -> str:
    """Build a correctly signed access JWT whose ``exp`` is in the past.

    The token was "issued" two hours ago and "expired" one hour ago. It is
    signed with the configured access-token secret (``"test-secret"`` when
    none is configured) using the configured access-token algorithm.
    """
    issued_at = datetime.now(tz=timezone.utc) - timedelta(hours=2)
    expires_at = issued_at + timedelta(hours=1)
    claims = {
        "type": "access",
        "sub": user_id,
        "iat": int(issued_at.timestamp()),
        "exp": int(expires_at.timestamp()),
    }
    signing_key = settings.access_token_secret or "test-secret"
    return jwt.encode(claims, signing_key, algorithm=settings.access_token_algorithm)
| 56 | |
| 57 | |
| 58 | def _make_tampered_token(valid_token: str) -> str: |
| 59 | """Flip one character in the signature portion of a valid JWT.""" |
| 60 | header, payload, sig = valid_token.rsplit(".", 2) |
| 61 | bad_sig = sig[:-1] + ("A" if sig[-1] != "A" else "B") |
| 62 | return f"{header}.{payload}.{bad_sig}" |
| 63 | |
| 64 | |
| 65 | def _make_none_alg_token(user_id: str = "test-user-id") -> str: |
| 66 | """Craft a JWT with alg=none (classic signature-bypass attempt).""" |
| 67 | import base64, json |
| 68 | header = base64.urlsafe_b64encode( |
| 69 | json.dumps({"alg": "none", "typ": "JWT"}).encode() |
| 70 | ).rstrip(b"=").decode() |
| 71 | now = int(datetime.now(tz=timezone.utc).timestamp()) |
| 72 | body = base64.urlsafe_b64encode( |
| 73 | json.dumps({"type": "access", "sub": user_id, "iat": now, "exp": now + 3600}).encode() |
| 74 | ).rstrip(b"=").decode() |
| 75 | return f"{header}.{body}." |
| 76 | |
| 77 | |
| 78 | # --------------------------------------------------------------------------- |
| 79 | # validate_access_code unit tests |
| 80 | # --------------------------------------------------------------------------- |
| 81 | |
class TestValidateAccessCodeUnit:
    """Call ``validate_access_code`` directly, with no HTTP layer involved."""

    @staticmethod
    def _sign(claims: dict[str, object]) -> str:
        """Sign *claims* with the configured secret and algorithm.

        Falls back to ``"test-secret"`` when no access-token secret is set.
        """
        key = settings.access_token_secret or "test-secret"
        return jwt.encode(claims, key, algorithm=settings.access_token_algorithm)

    @staticmethod
    def _epoch_now() -> int:
        """Return the current UTC time as whole seconds since the epoch."""
        return int(datetime.now(tz=timezone.utc).timestamp())

    def test_valid_token_returns_claims(self) -> None:
        issued = create_access_token(user_id="u1", expires_hours=1)
        decoded = validate_access_code(issued)
        assert decoded["sub"] == "u1"
        assert decoded["type"] == "access"

    def test_expired_token_raises(self) -> None:
        stale = _make_expired_token()
        with pytest.raises(AccessCodeError, match="expired"):
            validate_access_code(stale)

    def test_tampered_signature_raises(self) -> None:
        forged = _make_tampered_token(create_access_token(user_id="u2", expires_hours=1))
        with pytest.raises(AccessCodeError):
            validate_access_code(forged)

    def test_garbage_string_raises(self) -> None:
        with pytest.raises(AccessCodeError):
            validate_access_code("not.a.jwt")

    def test_none_algorithm_rejected(self) -> None:
        unsigned = _make_none_alg_token()
        with pytest.raises(AccessCodeError):
            validate_access_code(unsigned)

    def test_missing_type_claim_raises(self) -> None:
        issued_at = self._epoch_now()
        untyped = self._sign({"sub": "u3", "iat": issued_at, "exp": issued_at + 3600})
        with pytest.raises(AccessCodeError, match="Invalid token type"):
            validate_access_code(untyped)

    def test_wrong_type_claim_raises(self) -> None:
        issued_at = self._epoch_now()
        refresh = self._sign(
            {"type": "refresh", "sub": "u4", "iat": issued_at, "exp": issued_at + 3600}
        )
        with pytest.raises(AccessCodeError, match="Invalid token type"):
            validate_access_code(refresh)

    def test_admin_token_has_role_claim(self) -> None:
        admin = create_access_token(user_id="admin1", expires_hours=1, is_admin=True)
        assert validate_access_code(admin).get("role") == "admin"

    def test_anonymous_token_has_no_sub(self) -> None:
        anonymous = create_access_token(expires_hours=1)
        assert "sub" not in validate_access_code(anonymous)
| 136 | |
| 137 | |
| 138 | # --------------------------------------------------------------------------- |
| 139 | # Revocation cache unit tests |
| 140 | # --------------------------------------------------------------------------- |
| 141 | |
class TestRevocationCache:
    """Behavioural checks for the in-memory token revocation cache."""

    def setup_method(self) -> None:
        # Every test starts from an empty cache...
        clear_revocation_cache()

    def teardown_method(self) -> None:
        # ...and leaves it empty for whatever runs next.
        clear_revocation_cache()

    def test_cache_miss_returns_none(self) -> None:
        assert get_revocation_status("unknown-hash") is None

    def test_set_valid_status_readable(self) -> None:
        key = "tok1"
        set_revocation_status(key, revoked=False)
        assert get_revocation_status(key) is False

    def test_set_revoked_status_readable(self) -> None:
        key = "tok2"
        set_revocation_status(key, revoked=True)
        assert get_revocation_status(key) is True

    def test_clear_removes_all_entries(self) -> None:
        for key, revoked in (("tok3", False), ("tok4", True)):
            set_revocation_status(key, revoked=revoked)
        clear_revocation_cache()
        for key in ("tok3", "tok4"):
            assert get_revocation_status(key) is None

    def test_overwrite_status(self) -> None:
        # The later write must win.
        for revoked in (False, True):
            set_revocation_status("tok5", revoked=revoked)
        assert get_revocation_status("tok5") is True

    def test_hash_token_is_deterministic(self) -> None:
        raw = "some.jwt.token"
        first, second = hash_token(raw), hash_token(raw)
        assert first == second
        assert len(first) == 64  # SHA-256 hex

    def test_hash_token_distinct_inputs(self) -> None:
        digest_abc, digest_xyz = hash_token("abc"), hash_token("xyz")
        assert digest_abc != digest_xyz
| 181 | |
| 182 | |
| 183 | # --------------------------------------------------------------------------- |
| 184 | # HTTP integration tests — invalid tokens on write endpoints |
| 185 | # --------------------------------------------------------------------------- |
| 186 | |
@pytest.mark.anyio
async def test_expired_token_rejected_on_create_repo(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """A write endpoint answers 401 when presented with an expired JWT.

    The token's subject is persisted as a ``User`` row first, so the user
    referenced by the token exists in the database.
    """
    owner_id = "expired-user-id"
    db_session.add(User(id=owner_id))
    await db_session.commit()

    auth = {"Authorization": f"Bearer {_make_expired_token(user_id=owner_id)}"}
    response = await client.post(
        "/api/v1/musehub/repos",
        json={"name": "beats", "owner": "testuser"},
        headers=auth,
    )
    assert response.status_code == 401
| 204 | |
| 205 | |
@pytest.mark.anyio
async def test_tampered_token_rejected_on_create_repo(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """A write endpoint answers 401 when the JWT signature has been altered.

    The token is minted for a persisted user and then has its signature
    segment corrupted.
    """
    owner_id = "tamper-user-id"
    db_session.add(User(id=owner_id))
    await db_session.commit()

    forged = _make_tampered_token(create_access_token(user_id=owner_id, expires_hours=1))
    response = await client.post(
        "/api/v1/musehub/repos",
        json={"name": "beats", "owner": "testuser"},
        headers={"Authorization": f"Bearer {forged}"},
    )
    assert response.status_code == 401
| 224 | |
| 225 | |
@pytest.mark.anyio
async def test_garbage_token_rejected(client: AsyncClient, db_session: AsyncSession) -> None:
    """A bearer value that is not a JWT at all yields 401."""
    # db_session is unused here; presumably requested for its fixture setup.
    bogus_header = {"Authorization": "Bearer not-a-jwt-at-all"}
    response = await client.post(
        "/api/v1/musehub/repos",
        headers=bogus_header,
        json={"name": "beats", "owner": "testuser"},
    )
    assert response.status_code == 401
| 235 | |
| 236 | |
@pytest.mark.anyio
async def test_none_alg_token_rejected(client: AsyncClient, db_session: AsyncSession) -> None:
    """An ``alg=none`` token is refused with 401, defeating the signature bypass."""
    # db_session is unused here; presumably requested for its fixture setup.
    bogus_header = {"Authorization": f"Bearer {_make_none_alg_token()}"}
    response = await client.post(
        "/api/v1/musehub/repos",
        headers=bogus_header,
        json={"name": "beats", "owner": "testuser"},
    )
    assert response.status_code == 401
| 247 | |
| 248 | |
@pytest.mark.anyio
@pytest.mark.parametrize(
    ("endpoint", "method", "body"),
    [
        ("/api/v1/musehub/repos", "POST", {"name": "x", "owner": "y"}),
        ("/api/v1/musehub/repos/fake-id/issues", "POST", {"title": "t"}),
        ("/api/v1/musehub/repos/fake-id/issues/1/close", "POST", {}),
    ],
)
async def test_missing_auth_header_returns_401(
    client: AsyncClient,
    db_session: AsyncSession,
    endpoint: str,
    method: str,
    body: dict,
) -> None:
    """Each write endpoint rejects a request with no Authorization header (401)."""
    # AsyncClient.request(method, ...) is what the verb helpers (.post, ...) call.
    response = await client.request(method, endpoint, json=body)
    assert response.status_code == 401