"""Regression tests for HIGH security fixes H2–H8.

H1 (private repo visibility) is already covered by test_wire_protocol.py.

Tests grouped by finding:
H2 — Namespace squatting in musehub_publish_domain
H3 — Unbounded push bundle (Pydantic max_length enforcement)
H4 — object content size check (MAX_OBJECT_BYTES)
H5 — Batched BFS fetch (was N sequential queries)
H6 — MCP session store cap → 503
H7 — muse_push per-user storage quota
H8 — compute_pull_delta pagination (500 objects/page)
"""
from __future__ import annotations

import uuid
from datetime import datetime, timezone

import pytest
import pytest_asyncio
from httpx import AsyncClient
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession

from musehub.auth.tokens import create_access_token
from musehub.db import musehub_models as db
from musehub.mcp.session import (
    SessionCapacityError,
    _MAX_SESSIONS,
    _SESSIONS,
    create_session,
    delete_session,
)
from musehub.models.wire import (
    MAX_COMMITS_PER_PUSH,
    MAX_OBJECT_BYTES,
    MAX_OBJECTS_PER_PUSH,
    MAX_WANT_PER_FETCH,
    WireBundle,
    WireFetchRequest,
    WireObject,
)
from musehub.services.musehub_mcp_executor import execute_musehub_publish_domain
from musehub.services.musehub_sync import _PULL_OBJECTS_PAGE_SIZE, compute_pull_delta
from tests.factories import create_repo as factory_create_repo


# ── fixtures ──────────────────────────────────────────────────────────────────

| 49 | @pytest.fixture |
| 50 | def wire_token() -> str: |
| 51 | return create_access_token(user_id="test-user-wire", expires_hours=1) |
| 52 | |
| 53 | |
| 54 | @pytest.fixture |
| 55 | def wire_headers(wire_token: str) -> dict[str, str]: |
| 56 | return { |
| 57 | "Authorization": f"Bearer {wire_token}", |
| 58 | "Content-Type": "application/json", |
| 59 | } |


# ── H2: namespace squatting ───────────────────────────────────────────────────

| 64 | @pytest.mark.asyncio |
| 65 | async def test_publish_domain_blocked_if_handle_owned_by_other( |
| 66 | db_session: AsyncSession, |
| 67 | ) -> None: |
| 68 | """A user cannot publish under a handle that belongs to a different account.""" |
| 69 | other_user_id = str(uuid.uuid4()) |
| 70 | identity = db.MusehubIdentity( |
| 71 | id=other_user_id, |
| 72 | handle="alice", |
| 73 | identity_type="human", |
| 74 | ) |
| 75 | db_session.add(identity) |
| 76 | await db_session.commit() |
| 77 | |
| 78 | # attacker tries to publish under @alice but their user_id is different |
| 79 | result = await execute_musehub_publish_domain( |
| 80 | author_slug="alice", |
| 81 | slug="hijacked-domain", |
| 82 | display_name="Hijacked", |
| 83 | description="Should be rejected", |
| 84 | capabilities={}, |
| 85 | viewer_type="generic", |
| 86 | user_id="attacker-user-id", # != other_user_id |
| 87 | ) |
| 88 | assert result.ok is False |
| 89 | assert result.error_code == "forbidden" |
| 90 | assert "alice" in (result.error_message or "") |
| 91 | |
| 92 | |
| 93 | @pytest.mark.asyncio |
| 94 | async def test_publish_domain_allowed_if_handle_matches_caller( |
| 95 | db_session: AsyncSession, |
| 96 | ) -> None: |
| 97 | """A user CAN publish under their own handle.""" |
| 98 | owner_id = str(uuid.uuid4()) |
| 99 | identity = db.MusehubIdentity( |
| 100 | id=owner_id, |
| 101 | handle="bobthebuilder", |
| 102 | identity_type="human", |
| 103 | ) |
| 104 | db_session.add(identity) |
| 105 | await db_session.commit() |
| 106 | |
| 107 | result = await execute_musehub_publish_domain( |
| 108 | author_slug="bobthebuilder", |
| 109 | slug=f"my-domain-{uuid.uuid4().hex[:6]}", |
| 110 | display_name="Bob's Domain", |
| 111 | description="Legitimate publish", |
| 112 | capabilities={}, |
| 113 | viewer_type="generic", |
| 114 | user_id=owner_id, |
| 115 | ) |
| 116 | # ok=True means the squatting check passed (may fail at DB level if domains |
| 117 | # table requires more setup, but the guard itself did not reject it) |
| 118 | assert result.error_code != "forbidden" |
| 119 | |
| 120 | |
| 121 | @pytest.mark.asyncio |
| 122 | async def test_publish_domain_allowed_if_no_identity_registered( |
| 123 | db_session: AsyncSession, |
| 124 | ) -> None: |
| 125 | """Publishing is allowed when no identity row exists for the handle yet.""" |
| 126 | result = await execute_musehub_publish_domain( |
| 127 | author_slug="brand-new-user", |
| 128 | slug=f"first-domain-{uuid.uuid4().hex[:6]}", |
| 129 | display_name="First Domain", |
| 130 | description="No identity row yet — should not be forbidden", |
| 131 | capabilities={}, |
| 132 | viewer_type="generic", |
| 133 | user_id="some-user-id", |
| 134 | ) |
| 135 | assert result.error_code != "forbidden" |


# ── H3: unbounded push bundle ─────────────────────────────────────────────────

| 140 | def test_wire_bundle_commits_capped() -> None: |
| 141 | """WireBundle rejects more than MAX_COMMITS_PER_PUSH commits.""" |
| 142 | from pydantic import ValidationError |
| 143 | commits = [{"commit_id": f"c{i}", "repo_id": "r"} for i in range(MAX_COMMITS_PER_PUSH + 1)] |
| 144 | with pytest.raises(ValidationError, match="List should have at most"): |
| 145 | WireBundle(commits=commits) # type: ignore[arg-type] |
| 146 | |
| 147 | |
| 148 | def test_wire_bundle_objects_capped() -> None: |
| 149 | """WireBundle rejects more than MAX_OBJECTS_PER_PUSH objects.""" |
| 150 | from pydantic import ValidationError |
| 151 | objects = [{"object_id": f"o{i}", "content": b"x"} for i in range(MAX_OBJECTS_PER_PUSH + 1)] |
| 152 | with pytest.raises(ValidationError, match="List should have at most"): |
| 153 | WireBundle(objects=objects) # type: ignore[arg-type] |
| 154 | |
| 155 | |
| 156 | def test_wire_fetch_request_want_capped() -> None: |
| 157 | """WireFetchRequest rejects more than MAX_WANT_PER_FETCH want entries.""" |
| 158 | from pydantic import ValidationError |
| 159 | with pytest.raises(ValidationError, match="List should have at most"): |
| 160 | WireFetchRequest(want=["sha"] * (MAX_WANT_PER_FETCH + 1)) |
| 161 | |
| 162 | |
| 163 | def test_wire_bundle_at_limit_is_accepted() -> None: |
| 164 | """Exactly MAX items is valid — the cap is inclusive.""" |
| 165 | bundle = WireBundle( |
| 166 | commits=[{"commit_id": f"c{i}", "repo_id": "r"} for i in range(MAX_COMMITS_PER_PUSH)], # type: ignore[arg-type] |
| 167 | objects=[{"object_id": f"o{i}", "content": b"x"} for i in range(MAX_OBJECTS_PER_PUSH)], # type: ignore[arg-type] |
| 168 | ) |
| 169 | assert len(bundle.commits) == MAX_COMMITS_PER_PUSH |
| 170 | assert len(bundle.objects) == MAX_OBJECTS_PER_PUSH |


# ── H4: object content size check ────────────────────────────────────────────

| 175 | def test_wire_object_rejects_oversized_content() -> None: |
| 176 | """WireObject rejects content larger than MAX_OBJECT_BYTES.""" |
| 177 | from pydantic import ValidationError |
| 178 | oversized = b"x" * (MAX_OBJECT_BYTES + 1) |
| 179 | with pytest.raises(ValidationError, match="exceeds maximum size|at most"): |
| 180 | WireObject(object_id="too-big", content=oversized) |
| 181 | |
| 182 | |
| 183 | def test_wire_object_accepts_at_limit() -> None: |
| 184 | """WireObject accepts content well under MAX_OBJECT_BYTES.""" |
| 185 | obj = WireObject(object_id="ok", content=b"x" * 1000) |
| 186 | assert obj.object_id == "ok" |


# ── H6: MCP session store cap ────────────────────────────────────────────────

| 191 | def test_create_session_raises_when_store_full() -> None: |
| 192 | """create_session raises SessionCapacityError when _SESSIONS is at _MAX_SESSIONS.""" |
| 193 | original_sessions = dict(_SESSIONS) |
| 194 | try: |
| 195 | # Fill the store to the exact cap |
| 196 | fake_ids: list[str] = [] |
| 197 | for i in range(_MAX_SESSIONS - len(_SESSIONS)): |
| 198 | sid = f"fake-session-{i}" |
| 199 | from musehub.mcp.session import MCPSession |
| 200 | _SESSIONS[sid] = MCPSession( |
| 201 | session_id=sid, |
| 202 | user_id=None, |
| 203 | client_capabilities={}, |
| 204 | ) |
| 205 | fake_ids.append(sid) |
| 206 | |
| 207 | assert len(_SESSIONS) == _MAX_SESSIONS |
| 208 | with pytest.raises(SessionCapacityError): |
| 209 | create_session(user_id="overflow-user", client_capabilities={}) |
| 210 | finally: |
| 211 | # Clean up all fake sessions we injected |
| 212 | for sid in fake_ids: |
| 213 | _SESSIONS.pop(sid, None) |


# ── H8: compute_pull_delta pagination ────────────────────────────────────────

| 218 | @pytest.mark.asyncio |
| 219 | async def test_pull_delta_paginates_objects(db_session: AsyncSession) -> None: |
| 220 | """compute_pull_delta returns at most _PULL_OBJECTS_PAGE_SIZE objects per call.""" |
| 221 | repo = await factory_create_repo(db_session, slug="pull-pagination-test") |
| 222 | |
| 223 | # Insert _PULL_OBJECTS_PAGE_SIZE + 5 objects |
| 224 | total = _PULL_OBJECTS_PAGE_SIZE + 5 |
| 225 | for i in range(total): |
| 226 | obj = db.MusehubObject( |
| 227 | object_id=f"obj-{i:05d}", |
| 228 | repo_id=repo.repo_id, |
| 229 | path=f"file_{i}.mid", |
| 230 | size_bytes=100, |
| 231 | disk_path=f"local://{repo.repo_id}/obj-{i:05d}", |
| 232 | storage_uri=f"local://{repo.repo_id}/obj-{i:05d}", |
| 233 | ) |
| 234 | db_session.add(obj) |
| 235 | await db_session.commit() |
| 236 | |
| 237 | result = await compute_pull_delta( |
| 238 | db_session, |
| 239 | repo_id=repo.repo_id, |
| 240 | branch="main", |
| 241 | have_commits=[], |
| 242 | have_objects=[], |
| 243 | ) |
| 244 | assert len(result.objects) == _PULL_OBJECTS_PAGE_SIZE |
| 245 | assert result.has_more is True |
| 246 | assert result.next_cursor is not None |
| 247 | |
| 248 | |
| 249 | @pytest.mark.asyncio |
| 250 | async def test_pull_delta_second_page(db_session: AsyncSession) -> None: |
| 251 | """The second page contains the remaining objects.""" |
| 252 | repo = await factory_create_repo(db_session, slug="pull-page2-test") |
| 253 | |
| 254 | total = _PULL_OBJECTS_PAGE_SIZE + 3 |
| 255 | for i in range(total): |
| 256 | obj = db.MusehubObject( |
| 257 | object_id=f"p2obj-{i:05d}", |
| 258 | repo_id=repo.repo_id, |
| 259 | path=f"file_{i}.mid", |
| 260 | size_bytes=50, |
| 261 | disk_path=f"local://{repo.repo_id}/p2obj-{i:05d}", |
| 262 | storage_uri=f"local://{repo.repo_id}/p2obj-{i:05d}", |
| 263 | ) |
| 264 | db_session.add(obj) |
| 265 | await db_session.commit() |
| 266 | |
| 267 | page1 = await compute_pull_delta( |
| 268 | db_session, |
| 269 | repo_id=repo.repo_id, |
| 270 | branch="main", |
| 271 | have_commits=[], |
| 272 | have_objects=[], |
| 273 | ) |
| 274 | assert page1.has_more is True |
| 275 | |
| 276 | page2 = await compute_pull_delta( |
| 277 | db_session, |
| 278 | repo_id=repo.repo_id, |
| 279 | branch="main", |
| 280 | have_commits=[], |
| 281 | have_objects=[], |
| 282 | cursor=page1.next_cursor, |
| 283 | ) |
| 284 | assert len(page2.objects) == 3 |
| 285 | assert page2.has_more is False |
| 286 | assert page2.next_cursor is None |
| 287 | |
| 288 | |
| 289 | @pytest.mark.asyncio |
| 290 | async def test_pull_delta_no_pagination_when_under_limit(db_session: AsyncSession) -> None: |
| 291 | """When fewer objects than the page limit exist, has_more is False.""" |
| 292 | repo = await factory_create_repo(db_session, slug="pull-under-limit-test") |
| 293 | |
| 294 | for i in range(5): |
| 295 | db_session.add(db.MusehubObject( |
| 296 | object_id=f"small-{i}", |
| 297 | repo_id=repo.repo_id, |
| 298 | path=f"f{i}.mid", |
| 299 | size_bytes=10, |
| 300 | disk_path=f"local://x/small-{i}", |
| 301 | storage_uri=f"local://x/small-{i}", |
| 302 | )) |
| 303 | await db_session.commit() |
| 304 | |
| 305 | result = await compute_pull_delta( |
| 306 | db_session, |
| 307 | repo_id=repo.repo_id, |
| 308 | branch="main", |
| 309 | have_commits=[], |
| 310 | have_objects=[], |
| 311 | ) |
| 312 | assert len(result.objects) == 5 |
| 313 | assert result.has_more is False |
| 314 | assert result.next_cursor is None |


# ── H5: batched BFS fetch (wire protocol) ────────────────────────────────────

| 319 | @pytest.mark.asyncio |
| 320 | async def test_fetch_batched_bfs_returns_all_commits( |
| 321 | client: AsyncClient, |
| 322 | db_session: AsyncSession, |
| 323 | ) -> None: |
| 324 | """wire_fetch BFS returns all reachable commits without individual PK queries.""" |
| 325 | repo = await factory_create_repo(db_session, slug="fetch-bfs-batch-test") |
| 326 | |
| 327 | # Build a 3-commit chain: A → B → C (parent chain) |
| 328 | _ts = datetime.now(timezone.utc) |
| 329 | commit_a = db.MusehubCommit( |
| 330 | commit_id="bfs-commit-a", |
| 331 | repo_id=repo.repo_id, |
| 332 | branch="main", |
| 333 | message="A", |
| 334 | parent_ids=[], |
| 335 | author="tester", |
| 336 | commit_meta={}, |
| 337 | timestamp=_ts, |
| 338 | ) |
| 339 | commit_b = db.MusehubCommit( |
| 340 | commit_id="bfs-commit-b", |
| 341 | repo_id=repo.repo_id, |
| 342 | branch="main", |
| 343 | message="B", |
| 344 | parent_ids=["bfs-commit-a"], |
| 345 | author="tester", |
| 346 | commit_meta={}, |
| 347 | timestamp=_ts, |
| 348 | ) |
| 349 | commit_c = db.MusehubCommit( |
| 350 | commit_id="bfs-commit-c", |
| 351 | repo_id=repo.repo_id, |
| 352 | branch="main", |
| 353 | message="C", |
| 354 | parent_ids=["bfs-commit-b"], |
| 355 | author="tester", |
| 356 | commit_meta={}, |
| 357 | timestamp=_ts, |
| 358 | ) |
| 359 | db_session.add_all([commit_a, commit_b, commit_c]) |
| 360 | await db_session.commit() |
| 361 | |
| 362 | resp = await client.post( |
| 363 | f"/{repo.owner}/{repo.slug}/fetch", |
| 364 | json={"want": ["bfs-commit-c"], "have": []}, |
| 365 | ) |
| 366 | assert resp.status_code == 200 |
| 367 | data = resp.json() |
| 368 | commit_ids = {c["commit_id"] for c in data.get("commits", [])} |
| 369 | assert "bfs-commit-a" in commit_ids |
| 370 | assert "bfs-commit-b" in commit_ids |
| 371 | assert "bfs-commit-c" in commit_ids |