gabriel / musehub public
test_high_security_h2_h8.py python
377 lines 13.1 KB
9396568e fix: HIGH security patch (H1–H8) — dev → main (#25) Gabriel Cardona <cgcardona@gmail.com> 2d ago
1 """Regression tests for HIGH security fixes H2–H8.
2
3 H1 (private repo visibility) is already covered by test_wire_protocol.py.
4
5 Tests grouped by finding:
6 H2 — Namespace squatting in musehub_publish_domain
7 H3 — Unbounded push bundle (Pydantic max_length enforcement)
8 H4 — content_b64 pre-decode size check
9 H5 — Batched BFS fetch (was N sequential queries)
10 H6 — MCP session store cap → 503
11 H7 — muse_push per-user storage quota
12 H8 — compute_pull_delta pagination (500 objects/page)
13 """
14 from __future__ import annotations
15
16 import base64
17 import uuid
18 from datetime import datetime, timezone
19
20 import pytest
21 import pytest_asyncio
22 from httpx import AsyncClient
23 from sqlalchemy.ext.asyncio import AsyncSession
24
25 from musehub.auth.tokens import create_access_token
26 from musehub.mcp.session import (
27 SessionCapacityError,
28 _MAX_SESSIONS,
29 _SESSIONS,
30 create_session,
31 delete_session,
32 )
33 from musehub.models.wire import (
34 MAX_B64_SIZE,
35 MAX_COMMITS_PER_PUSH,
36 MAX_OBJECTS_PER_PUSH,
37 MAX_WANT_PER_FETCH,
38 WireBundle,
39 WireFetchRequest,
40 WireObject,
41 )
42 from musehub.services.musehub_mcp_executor import execute_musehub_publish_domain
43 from musehub.services.musehub_sync import compute_pull_delta, _PULL_OBJECTS_PAGE_SIZE
44 from musehub.db import musehub_models as db
45 from tests.factories import create_repo as factory_create_repo
46
47
48 # ── fixtures ──────────────────────────────────────────────────────────────────
49
@pytest.fixture
def wire_token() -> str:
    """Short-lived bearer token for wire-protocol requests."""
    token = create_access_token(user_id="test-user-wire", expires_hours=1)
    return token
53
54
@pytest.fixture
def wire_headers(wire_token: str) -> dict[str, str]:
    """HTTP headers carrying the wire_token plus a JSON content type."""
    headers: dict[str, str] = {}
    headers["Authorization"] = f"Bearer {wire_token}"
    headers["Content-Type"] = "application/json"
    return headers
61
62
63 # ── H2: namespace squatting ───────────────────────────────────────────────────
64
@pytest.mark.asyncio
async def test_publish_domain_blocked_if_handle_owned_by_other(
    db_session: AsyncSession,
) -> None:
    """A user cannot publish under a handle that belongs to a different account."""
    # Register @alice under an account the caller does not control.
    victim_id = str(uuid.uuid4())
    db_session.add(
        db.MusehubIdentity(
            id=victim_id,
            handle="alice",
            identity_type="human",
        )
    )
    await db_session.commit()

    # The attacker claims @alice's namespace with a mismatched user_id.
    result = await execute_musehub_publish_domain(
        author_slug="alice",
        slug="hijacked-domain",
        display_name="Hijacked",
        description="Should be rejected",
        capabilities={},
        viewer_type="generic",
        user_id="attacker-user-id",  # deliberately != victim_id
    )

    assert result.ok is False
    assert result.error_code == "forbidden"
    assert "alice" in (result.error_message or "")
92
93
@pytest.mark.asyncio
async def test_publish_domain_allowed_if_handle_matches_caller(
    db_session: AsyncSession,
) -> None:
    """A user CAN publish under their own handle."""
    caller_id = str(uuid.uuid4())
    db_session.add(
        db.MusehubIdentity(
            id=caller_id,
            handle="bobthebuilder",
            identity_type="human",
        )
    )
    await db_session.commit()

    domain_slug = f"my-domain-{uuid.uuid4().hex[:6]}"
    result = await execute_musehub_publish_domain(
        author_slug="bobthebuilder",
        slug=domain_slug,
        display_name="Bob's Domain",
        description="Legitimate publish",
        capabilities={},
        viewer_type="generic",
        user_id=caller_id,
    )

    # The anti-squatting guard must not fire for the handle's own account.
    # (Later DB-level steps may still fail without extra domain-table setup.)
    assert result.error_code != "forbidden"
120
121
@pytest.mark.asyncio
async def test_publish_domain_allowed_if_no_identity_registered(
    db_session: AsyncSession,
) -> None:
    """Publishing is allowed when no identity row exists for the handle yet."""
    # No MusehubIdentity row is seeded for this handle on purpose.
    outcome = await execute_musehub_publish_domain(
        author_slug="brand-new-user",
        slug=f"first-domain-{uuid.uuid4().hex[:6]}",
        display_name="First Domain",
        description="No identity row yet — should not be forbidden",
        capabilities={},
        viewer_type="generic",
        user_id="some-user-id",
    )
    assert outcome.error_code != "forbidden"
137
138
139 # ── H3: unbounded push bundle ─────────────────────────────────────────────────
140
def test_wire_bundle_commits_capped() -> None:
    """WireBundle rejects more than MAX_COMMITS_PER_PUSH commits."""
    from pydantic import ValidationError

    # One past the cap must trip Pydantic's max_length validation.
    payload = [
        {"commit_id": f"c{i}", "repo_id": "r"}
        for i in range(MAX_COMMITS_PER_PUSH + 1)
    ]
    with pytest.raises(ValidationError, match="List should have at most"):
        WireBundle(commits=payload)  # type: ignore[arg-type]
147
148
def test_wire_bundle_objects_capped() -> None:
    """WireBundle rejects more than MAX_OBJECTS_PER_PUSH objects."""
    from pydantic import ValidationError

    encoded = base64.b64encode(b"x").decode()
    payload = [
        {"object_id": f"o{i}", "content_b64": encoded}
        for i in range(MAX_OBJECTS_PER_PUSH + 1)
    ]
    with pytest.raises(ValidationError, match="List should have at most"):
        WireBundle(objects=payload)  # type: ignore[arg-type]
156
157
def test_wire_fetch_request_want_capped() -> None:
    """WireFetchRequest rejects more than MAX_WANT_PER_FETCH want entries."""
    from pydantic import ValidationError

    oversize_want = ["sha" for _ in range(MAX_WANT_PER_FETCH + 1)]
    with pytest.raises(ValidationError, match="List should have at most"):
        WireFetchRequest(want=oversize_want)
163
164
def test_wire_bundle_at_limit_is_accepted() -> None:
    """Exactly MAX items is valid — the cap is inclusive."""
    encoded = base64.b64encode(b"x").decode()
    max_commits = [
        {"commit_id": f"c{i}", "repo_id": "r"}
        for i in range(MAX_COMMITS_PER_PUSH)
    ]
    max_objects = [
        {"object_id": f"o{i}", "content_b64": encoded}
        for i in range(MAX_OBJECTS_PER_PUSH)
    ]
    bundle = WireBundle(commits=max_commits, objects=max_objects)  # type: ignore[arg-type]
    assert len(bundle.commits) == MAX_COMMITS_PER_PUSH
    assert len(bundle.objects) == MAX_OBJECTS_PER_PUSH
174
175
176 # ── H4: content_b64 pre-decode size check ─────────────────────────────────────
177
def test_wire_object_rejects_oversized_b64() -> None:
    """WireObject rejects content_b64 longer than MAX_B64_SIZE."""
    from pydantic import ValidationError

    # The guard measures the raw string length before decoding, so a run of
    # plain "A"s just over the cap is enough — no real base64 payload needed.
    too_long = "A" * (MAX_B64_SIZE + 1)
    with pytest.raises(ValidationError, match="exceeds maximum size|at most"):
        WireObject(object_id="too-big", content_b64=too_long)
186
187
def test_wire_object_accepts_at_limit() -> None:
    """WireObject accepts content_b64 right at the MAX_B64_SIZE boundary.

    The previous version used a ~1.3 KB payload "well under limit", so the
    inclusive boundary was never actually exercised.  Here we build the
    largest valid base64 string whose encoded length does not exceed
    MAX_B64_SIZE (every 3 raw bytes encode to 4 chars).
    """
    raw_len = (MAX_B64_SIZE // 4) * 3
    at_limit = base64.b64encode(b"x" * raw_len).decode()
    # Sanity: we are at (or within 3 chars of) the cap, never over it.
    assert len(at_limit) <= MAX_B64_SIZE
    obj = WireObject(object_id="ok", content_b64=at_limit)
    assert obj.object_id == "ok"
193
194
195 # ── H6: MCP session store cap ────────────────────────────────────────────────
196
def test_create_session_raises_when_store_full() -> None:
    """create_session raises SessionCapacityError when _SESSIONS is at _MAX_SESSIONS.

    The global store is snapshotted up front and restored wholesale in
    ``finally`` so a failure part-way through the fill loop cannot leak fake
    sessions into later tests.  (The previous version captured the snapshot
    but never used it, and cleaned up only the ids it managed to record.)
    """
    from musehub.mcp.session import MCPSession  # hoisted out of the fill loop

    original_sessions = dict(_SESSIONS)
    try:
        # Fill the store to the exact cap.
        for i in range(_MAX_SESSIONS - len(_SESSIONS)):
            sid = f"fake-session-{i}"
            _SESSIONS[sid] = MCPSession(
                session_id=sid,
                user_id=None,
                client_capabilities={},
            )

        assert len(_SESSIONS) == _MAX_SESSIONS
        with pytest.raises(SessionCapacityError):
            create_session(user_id="overflow-user", client_capabilities={})
    finally:
        # Restore the exact pre-test state of the module-level store.
        _SESSIONS.clear()
        _SESSIONS.update(original_sessions)
220
221
222 # ── H8: compute_pull_delta pagination ────────────────────────────────────────
223
@pytest.mark.asyncio
async def test_pull_delta_paginates_objects(db_session: AsyncSession) -> None:
    """compute_pull_delta returns at most _PULL_OBJECTS_PAGE_SIZE objects per call."""
    repo = await factory_create_repo(db_session, slug="pull-pagination-test")

    # Seed more objects than fit on a single page.
    object_count = _PULL_OBJECTS_PAGE_SIZE + 5
    db_session.add_all(
        [
            db.MusehubObject(
                object_id=f"obj-{idx:05d}",
                repo_id=repo.repo_id,
                path=f"file_{idx}.mid",
                size_bytes=100,
                disk_path=f"local://{repo.repo_id}/obj-{idx:05d}",
                storage_uri=f"local://{repo.repo_id}/obj-{idx:05d}",
            )
            for idx in range(object_count)
        ]
    )
    await db_session.commit()

    first_page = await compute_pull_delta(
        db_session,
        repo_id=repo.repo_id,
        branch="main",
        have_commits=[],
        have_objects=[],
    )
    # The page is clamped to the cap and signals that more data remains.
    assert len(first_page.objects) == _PULL_OBJECTS_PAGE_SIZE
    assert first_page.has_more is True
    assert first_page.next_cursor is not None
253
254
@pytest.mark.asyncio
async def test_pull_delta_second_page(db_session: AsyncSession) -> None:
    """The second page contains the remaining objects."""
    repo = await factory_create_repo(db_session, slug="pull-page2-test")

    # One full page plus a small remainder that should spill onto page two.
    remainder = 3
    for idx in range(_PULL_OBJECTS_PAGE_SIZE + remainder):
        db_session.add(
            db.MusehubObject(
                object_id=f"p2obj-{idx:05d}",
                repo_id=repo.repo_id,
                path=f"file_{idx}.mid",
                size_bytes=50,
                disk_path=f"local://{repo.repo_id}/p2obj-{idx:05d}",
                storage_uri=f"local://{repo.repo_id}/p2obj-{idx:05d}",
            )
        )
    await db_session.commit()

    first = await compute_pull_delta(
        db_session,
        repo_id=repo.repo_id,
        branch="main",
        have_commits=[],
        have_objects=[],
    )
    assert first.has_more is True

    # Resume from the cursor: only the overflow objects should remain.
    second = await compute_pull_delta(
        db_session,
        repo_id=repo.repo_id,
        branch="main",
        have_commits=[],
        have_objects=[],
        cursor=first.next_cursor,
    )
    assert len(second.objects) == remainder
    assert second.has_more is False
    assert second.next_cursor is None
293
294
@pytest.mark.asyncio
async def test_pull_delta_no_pagination_when_under_limit(db_session: AsyncSession) -> None:
    """When fewer objects than the page limit exist, has_more is False."""
    repo = await factory_create_repo(db_session, slug="pull-under-limit-test")

    seeded = 5
    for idx in range(seeded):
        row = db.MusehubObject(
            object_id=f"small-{idx}",
            repo_id=repo.repo_id,
            path=f"f{idx}.mid",
            size_bytes=10,
            disk_path=f"local://x/small-{idx}",
            storage_uri=f"local://x/small-{idx}",
        )
        db_session.add(row)
    await db_session.commit()

    delta = await compute_pull_delta(
        db_session,
        repo_id=repo.repo_id,
        branch="main",
        have_commits=[],
        have_objects=[],
    )
    # Everything fits on one page, so no cursor is handed back.
    assert len(delta.objects) == seeded
    assert delta.has_more is False
    assert delta.next_cursor is None
321
322
323 # ── H5: batched BFS fetch (wire protocol) ────────────────────────────────────
324
@pytest.mark.asyncio
async def test_fetch_batched_bfs_returns_all_commits(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """wire_fetch BFS returns all reachable commits without individual PK queries."""
    repo = await factory_create_repo(db_session, slug="fetch-bfs-batch-test")

    # Linear history A → B → C, expressed via parent_ids.
    now = datetime.now(timezone.utc)
    chain = [
        ("bfs-commit-a", "A", []),
        ("bfs-commit-b", "B", ["bfs-commit-a"]),
        ("bfs-commit-c", "C", ["bfs-commit-b"]),
    ]
    db_session.add_all(
        [
            db.MusehubCommit(
                commit_id=cid,
                repo_id=repo.repo_id,
                branch="main",
                message=msg,
                parent_ids=parents,
                author="tester",
                commit_meta={},
                timestamp=now,
            )
            for cid, msg, parents in chain
        ]
    )
    await db_session.commit()

    # Want only the tip; BFS over parent_ids must surface the whole chain.
    resp = await client.post(
        f"/{repo.owner}/{repo.slug}/fetch",
        json={"want": ["bfs-commit-c"], "have": []},
    )
    assert resp.status_code == 200
    returned = {c["commit_id"] for c in resp.json().get("commits", [])}
    assert "bfs-commit-a" in returned
    assert "bfs-commit-b" in returned
    assert "bfs-commit-c" in returned