gabriel / musehub public
test_high_security_h2_h8.py python
371 lines 12.9 KB
4a63ff4c feat(mwp): full msgpack wire protocol — replace JSON+base64 on all push… Gabriel Cardona <cgcardona@gmail.com> 12h ago
1 """Regression tests for HIGH security fixes H2–H8.
2
3 H1 (private repo visibility) is already covered by test_wire_protocol.py.
4
5 Tests grouped by finding:
6 H2 — Namespace squatting in musehub_publish_domain
7 H3 — Unbounded push bundle (Pydantic max_length enforcement)
8 H4 — object content size check (MAX_OBJECT_BYTES)
9 H5 — Batched BFS fetch (was N sequential queries)
10 H6 — MCP session store cap → 503
11 H7 — muse_push per-user storage quota
12 H8 — compute_pull_delta pagination (500 objects/page)
13 """
from __future__ import annotations

import uuid
from datetime import datetime, timezone

import pytest
import pytest_asyncio
from httpx import AsyncClient
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession

from musehub.auth.tokens import create_access_token
from musehub.db import musehub_models as db
from musehub.mcp.session import (
    SessionCapacityError,
    _MAX_SESSIONS,
    _SESSIONS,
    create_session,
    delete_session,
)
from musehub.models.wire import (
    MAX_COMMITS_PER_PUSH,
    MAX_OBJECT_BYTES,
    MAX_OBJECTS_PER_PUSH,
    MAX_WANT_PER_FETCH,
    WireBundle,
    WireFetchRequest,
    WireObject,
)
from musehub.services.musehub_mcp_executor import execute_musehub_publish_domain
from musehub.services.musehub_sync import _PULL_OBJECTS_PAGE_SIZE, compute_pull_delta
from tests.factories import create_repo as factory_create_repo
45
46
47 # ── fixtures ──────────────────────────────────────────────────────────────────
48
@pytest.fixture
def wire_token() -> str:
    """Short-lived bearer token for the synthetic wire-protocol test user."""
    token = create_access_token(user_id="test-user-wire", expires_hours=1)
    return token
52
53
@pytest.fixture
def wire_headers(wire_token: str) -> dict[str, str]:
    """JSON request headers carrying *wire_token* as a Bearer credential."""
    headers: dict[str, str] = {"Authorization": f"Bearer {wire_token}"}
    headers["Content-Type"] = "application/json"
    return headers
60
61
62 # ── H2: namespace squatting ───────────────────────────────────────────────────
63
@pytest.mark.asyncio
async def test_publish_domain_blocked_if_handle_owned_by_other(
    db_session: AsyncSession,
) -> None:
    """A user cannot publish under a handle that belongs to a different account."""
    victim_id = str(uuid.uuid4())
    db_session.add(
        db.MusehubIdentity(
            id=victim_id,
            handle="alice",
            identity_type="human",
        )
    )
    await db_session.commit()

    # attacker tries to publish under @alice but their user_id is different
    outcome = await execute_musehub_publish_domain(
        author_slug="alice",
        slug="hijacked-domain",
        display_name="Hijacked",
        description="Should be rejected",
        capabilities={},
        viewer_type="generic",
        user_id="attacker-user-id",  # != victim_id
    )
    assert outcome.ok is False
    assert outcome.error_code == "forbidden"
    assert "alice" in (outcome.error_message or "")
91
92
@pytest.mark.asyncio
async def test_publish_domain_allowed_if_handle_matches_caller(
    db_session: AsyncSession,
) -> None:
    """A user CAN publish under their own handle."""
    owner_id = str(uuid.uuid4())
    db_session.add(
        db.MusehubIdentity(
            id=owner_id,
            handle="bobthebuilder",
            identity_type="human",
        )
    )
    await db_session.commit()

    outcome = await execute_musehub_publish_domain(
        author_slug="bobthebuilder",
        slug=f"my-domain-{uuid.uuid4().hex[:6]}",
        display_name="Bob's Domain",
        description="Legitimate publish",
        capabilities={},
        viewer_type="generic",
        user_id=owner_id,
    )
    # ok=True means the squatting check passed (may fail at DB level if domains
    # table requires more setup, but the guard itself did not reject it)
    assert outcome.error_code != "forbidden"
119
120
@pytest.mark.asyncio
async def test_publish_domain_allowed_if_no_identity_registered(
    db_session: AsyncSession,
) -> None:
    """Publishing is allowed when no identity row exists for the handle yet."""
    outcome = await execute_musehub_publish_domain(
        author_slug="brand-new-user",
        slug=f"first-domain-{uuid.uuid4().hex[:6]}",
        display_name="First Domain",
        description="No identity row yet — should not be forbidden",
        capabilities={},
        viewer_type="generic",
        user_id="some-user-id",
    )
    assert outcome.error_code != "forbidden"
136
137
138 # ── H3: unbounded push bundle ─────────────────────────────────────────────────
139
def test_wire_bundle_commits_capped() -> None:
    """WireBundle rejects more than MAX_COMMITS_PER_PUSH commits.

    ValidationError comes from the file-level pydantic import; the previous
    function-scope import was redundant across these H3 tests.
    """
    commits = [
        {"commit_id": f"c{i}", "repo_id": "r"}
        for i in range(MAX_COMMITS_PER_PUSH + 1)
    ]
    with pytest.raises(ValidationError, match="List should have at most"):
        WireBundle(commits=commits)  # type: ignore[arg-type]
146
147
def test_wire_bundle_objects_capped() -> None:
    """WireBundle rejects more than MAX_OBJECTS_PER_PUSH objects.

    Uses the file-level pydantic import rather than a per-test local import.
    """
    objects = [
        {"object_id": f"o{i}", "content": b"x"}
        for i in range(MAX_OBJECTS_PER_PUSH + 1)
    ]
    with pytest.raises(ValidationError, match="List should have at most"):
        WireBundle(objects=objects)  # type: ignore[arg-type]
154
155
def test_wire_fetch_request_want_capped() -> None:
    """WireFetchRequest rejects more than MAX_WANT_PER_FETCH want entries.

    Uses the file-level pydantic import rather than a per-test local import.
    """
    with pytest.raises(ValidationError, match="List should have at most"):
        WireFetchRequest(want=["sha"] * (MAX_WANT_PER_FETCH + 1))
161
162
def test_wire_bundle_at_limit_is_accepted() -> None:
    """Exactly MAX items is valid — the cap is inclusive."""
    max_commits = [{"commit_id": f"c{i}", "repo_id": "r"} for i in range(MAX_COMMITS_PER_PUSH)]
    max_objects = [{"object_id": f"o{i}", "content": b"x"} for i in range(MAX_OBJECTS_PER_PUSH)]
    bundle = WireBundle(commits=max_commits, objects=max_objects)  # type: ignore[arg-type]
    assert len(bundle.commits) == MAX_COMMITS_PER_PUSH
    assert len(bundle.objects) == MAX_OBJECTS_PER_PUSH
171
172
173 # ── H4: object content size check ────────────────────────────────────────────
174
def test_wire_object_rejects_oversized_content() -> None:
    """WireObject rejects content larger than MAX_OBJECT_BYTES.

    Uses the file-level pydantic import rather than a per-test local import.
    """
    oversized = b"x" * (MAX_OBJECT_BYTES + 1)
    with pytest.raises(ValidationError, match="exceeds maximum size|at most"):
        WireObject(object_id="too-big", content=oversized)
181
182
def test_wire_object_accepts_at_limit() -> None:
    """WireObject accepts content of exactly MAX_OBJECT_BYTES bytes.

    The original body only exercised 1000 bytes ("well under"), which does
    not match the test name and leaves the inclusive boundary untested.
    Mirrors test_wire_bundle_at_limit_is_accepted: the cap is inclusive.
    """
    obj = WireObject(object_id="ok", content=b"x" * MAX_OBJECT_BYTES)
    assert obj.object_id == "ok"
    assert len(obj.content) == MAX_OBJECT_BYTES
187
188
189 # ── H6: MCP session store cap ────────────────────────────────────────────────
190
def test_create_session_raises_when_store_full() -> None:
    """create_session raises SessionCapacityError when _SESSIONS is at _MAX_SESSIONS.

    Fixes two defects in the original:
    - ``original_sessions = dict(_SESSIONS)`` was captured but never used or
      restored (dead code) — cleanup already pops only the injected ids.
    - ``from musehub.mcp.session import MCPSession`` ran inside the fill loop,
      re-executing the import once per injected session; hoisted to the top.
    """
    from musehub.mcp.session import MCPSession

    # fake_ids is bound before the try so the finally block is always safe,
    # even if the fill loop itself raises partway through.
    fake_ids: list[str] = []
    try:
        # Fill the store to the exact cap
        for i in range(_MAX_SESSIONS - len(_SESSIONS)):
            sid = f"fake-session-{i}"
            _SESSIONS[sid] = MCPSession(
                session_id=sid,
                user_id=None,
                client_capabilities={},
            )
            fake_ids.append(sid)

        assert len(_SESSIONS) == _MAX_SESSIONS
        with pytest.raises(SessionCapacityError):
            create_session(user_id="overflow-user", client_capabilities={})
    finally:
        # Clean up all fake sessions we injected
        for sid in fake_ids:
            _SESSIONS.pop(sid, None)
214
215
216 # ── H8: compute_pull_delta pagination ────────────────────────────────────────
217
@pytest.mark.asyncio
async def test_pull_delta_paginates_objects(db_session: AsyncSession) -> None:
    """compute_pull_delta returns at most _PULL_OBJECTS_PAGE_SIZE objects per call."""
    repo = await factory_create_repo(db_session, slug="pull-pagination-test")

    # Insert _PULL_OBJECTS_PAGE_SIZE + 5 objects so one full page remains over
    for i in range(_PULL_OBJECTS_PAGE_SIZE + 5):
        uri = f"local://{repo.repo_id}/obj-{i:05d}"
        db_session.add(
            db.MusehubObject(
                object_id=f"obj-{i:05d}",
                repo_id=repo.repo_id,
                path=f"file_{i}.mid",
                size_bytes=100,
                disk_path=uri,
                storage_uri=uri,
            )
        )
    await db_session.commit()

    delta = await compute_pull_delta(
        db_session,
        repo_id=repo.repo_id,
        branch="main",
        have_commits=[],
        have_objects=[],
    )
    assert len(delta.objects) == _PULL_OBJECTS_PAGE_SIZE
    assert delta.has_more is True
    assert delta.next_cursor is not None
247
248
@pytest.mark.asyncio
async def test_pull_delta_second_page(db_session: AsyncSession) -> None:
    """The second page contains the remaining objects."""
    repo = await factory_create_repo(db_session, slug="pull-page2-test")

    # One full page plus three stragglers for the second page.
    for i in range(_PULL_OBJECTS_PAGE_SIZE + 3):
        uri = f"local://{repo.repo_id}/p2obj-{i:05d}"
        db_session.add(
            db.MusehubObject(
                object_id=f"p2obj-{i:05d}",
                repo_id=repo.repo_id,
                path=f"file_{i}.mid",
                size_bytes=50,
                disk_path=uri,
                storage_uri=uri,
            )
        )
    await db_session.commit()

    first = await compute_pull_delta(
        db_session,
        repo_id=repo.repo_id,
        branch="main",
        have_commits=[],
        have_objects=[],
    )
    assert first.has_more is True

    second = await compute_pull_delta(
        db_session,
        repo_id=repo.repo_id,
        branch="main",
        have_commits=[],
        have_objects=[],
        cursor=first.next_cursor,
    )
    assert len(second.objects) == 3
    assert second.has_more is False
    assert second.next_cursor is None
287
288
@pytest.mark.asyncio
async def test_pull_delta_no_pagination_when_under_limit(db_session: AsyncSession) -> None:
    """When fewer objects than the page limit exist, has_more is False.

    Consistency fix: the original used unrelated ``local://x/…`` storage URIs
    while every sibling pagination test scopes them to ``repo.repo_id``; the
    URIs are now repo-scoped to match.
    """
    repo = await factory_create_repo(db_session, slug="pull-under-limit-test")

    for i in range(5):
        uri = f"local://{repo.repo_id}/small-{i}"
        db_session.add(
            db.MusehubObject(
                object_id=f"small-{i}",
                repo_id=repo.repo_id,
                path=f"f{i}.mid",
                size_bytes=10,
                disk_path=uri,
                storage_uri=uri,
            )
        )
    await db_session.commit()

    result = await compute_pull_delta(
        db_session,
        repo_id=repo.repo_id,
        branch="main",
        have_commits=[],
        have_objects=[],
    )
    assert len(result.objects) == 5
    assert result.has_more is False
    assert result.next_cursor is None
315
316
317 # ── H5: batched BFS fetch (wire protocol) ────────────────────────────────────
318
@pytest.mark.asyncio
async def test_fetch_batched_bfs_returns_all_commits(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """wire_fetch BFS returns all reachable commits without individual PK queries."""
    repo = await factory_create_repo(db_session, slug="fetch-bfs-batch-test")

    # Build a 3-commit chain: A → B → C (parent chain), table-driven.
    now = datetime.now(timezone.utc)
    chain: list[tuple[str, str, list[str]]] = [
        ("bfs-commit-a", "A", []),
        ("bfs-commit-b", "B", ["bfs-commit-a"]),
        ("bfs-commit-c", "C", ["bfs-commit-b"]),
    ]
    db_session.add_all(
        db.MusehubCommit(
            commit_id=cid,
            repo_id=repo.repo_id,
            branch="main",
            message=msg,
            parent_ids=parents,
            author="tester",
            commit_meta={},
            timestamp=now,
        )
        for cid, msg, parents in chain
    )
    await db_session.commit()

    resp = await client.post(
        f"/{repo.owner}/{repo.slug}/fetch",
        json={"want": ["bfs-commit-c"], "have": []},
    )
    assert resp.status_code == 200
    returned = {c["commit_id"] for c in resp.json().get("commits", [])}
    # Walking back from the tip must surface every ancestor in the chain.
    for cid, _, _ in chain:
        assert cid in returned