gabriel / musehub public
test_wire_protocol.py python
447 lines 14.7 KB
cf1a85cf fix: patch all four critical security vulnerabilities (C1–C4) — dev → m… Gabriel Cardona <cgcardona@gmail.com> 2d ago
1 """Wire protocol endpoint tests.
2
3 Covers the three Muse CLI transport endpoints (Git-style URLs):
4 GET /{owner}/{slug}/refs
5 POST /{owner}/{slug}/push
6 POST /{owner}/{slug}/fetch
7
8 And the content-addressed CDN endpoint:
9 GET /o/{object_id}
10
11 Remote URL format (same pattern as Git):
12 muse remote add origin https://musehub.ai/cgcardona/muse
13 """
14 from __future__ import annotations
15
16 import base64
17 import os
18 import uuid
19 from datetime import datetime, timezone
20
21 import pytest
22 from httpx import AsyncClient
23 from sqlalchemy.ext.asyncio import AsyncSession
24
25 from musehub.auth.tokens import create_access_token
26 from musehub.db import musehub_models as db
27 from tests.factories import create_repo as factory_create_repo
28
29
30 # ── helpers ────────────────────────────────────────────────────────────────────
31
def _utc_now() -> datetime:
    """Return the current time as a timezone-aware datetime in UTC."""
    utc = timezone.utc
    return datetime.now(tz=utc)
34
35
def _make_commit(repo_id: str, commit_id: str | None = None, parent: str | None = None) -> dict:
    """Build a wire-format commit dict on branch ``main`` for push payloads.

    A falsy ``commit_id`` is replaced by a fresh dashed UUID string.
    ``snapshot_id`` is a random ``snap_xxxxxxxx`` placeholder that callers
    overwrite when the bundle carries a real snapshot. ``parent`` is passed
    through unchanged as ``parent_commit_id``.
    """
    if not commit_id:
        commit_id = str(uuid.uuid4())
    placeholder_snapshot = "snap_" + uuid.uuid4().hex[:8]
    return {
        "commit_id": commit_id,
        "repo_id": repo_id,
        "branch": "main",
        "snapshot_id": placeholder_snapshot,
        "message": "chore: add test commit",
        "committed_at": _utc_now().isoformat(),
        "parent_commit_id": parent,
        "author": "Test User <test@example.com>",
        "sem_ver_bump": "patch",
    }
48
49
def _make_object(content: bytes = b"hello world") -> dict:
    """Build a wire-format object entry carrying *content* as base64.

    The ``object_id`` is a random 32-character hex string (it is not derived
    from *content*), and ``path`` is always ``README.md``.
    """
    return dict(
        object_id=uuid.uuid4().hex,
        content_b64=base64.b64encode(content).decode(),
        path="README.md",
    )
57
58
def _make_snapshot(snap_id: str, object_id: str) -> dict:
    """Build a wire-format snapshot whose manifest maps ``README.md`` to *object_id*."""
    manifest = {"README.md": object_id}
    created = _utc_now().isoformat()
    return {"snapshot_id": snap_id, "manifest": manifest, "created_at": created}
65
66
@pytest.fixture(autouse=True)
def _tmp_objects_dir(tmp_path: os.PathLike[str], monkeypatch: pytest.MonkeyPatch) -> None:
    """Override object storage to use a per-test temporary directory.

    Creates ``<tmp_path>/objects``, builds a ``LocalBackend`` rooted there,
    and patches ``get_backend`` in both the wire service module and the wire
    route module so they return that backend for the duration of the test.

    Args:
        tmp_path: pytest's per-test temporary directory.
        monkeypatch: pytest fixture used so the patches are undone afterwards.
    """
    import musehub.storage.backends as _backends
    import musehub.services.musehub_wire as _wire_svc
    import musehub.api.routes.wire as _wire_route

    # os.path.join accepts the Path pytest supplies, so no str() + "/..."
    # concatenation (and no type: ignore) is needed.
    obj_dir = os.path.join(tmp_path, "objects")
    os.makedirs(obj_dir, exist_ok=True)
    test_backend = _backends.LocalBackend(objects_dir=obj_dir)
    # Both modules are patched; presumably each binds ``get_backend`` in its
    # own namespace, so patching only one would leave the other untouched.
    monkeypatch.setattr(_wire_svc, "get_backend", lambda: test_backend)
    monkeypatch.setattr(_wire_route, "get_backend", lambda: test_backend)
79
80
@pytest.fixture
def auth_wire_token() -> str:
    """Access token for user ``test-user-wire``, created with a one-hour expiry."""
    token = create_access_token(user_id="test-user-wire", expires_hours=1)
    return token
84
85
@pytest.fixture
def wire_headers(auth_wire_token: str) -> dict[str, str]:
    """JSON request headers carrying ``test-user-wire``'s bearer token."""
    headers: dict[str, str] = {}
    headers["Authorization"] = "Bearer " + auth_wire_token
    headers["Content-Type"] = "application/json"
    return headers
92
93
94 # ── refs endpoint ──────────────────────────────────────────────────────────────
95
@pytest.mark.asyncio
async def test_refs_returns_404_for_unknown_owner_slug(client: AsyncClient) -> None:
    """Refs for an owner/slug pair with no repo must return 404."""
    missing_refs_url = "/no-such-owner/no-such-slug/refs"
    response = await client.get(missing_refs_url)
    assert response.status_code == 404
100
101
@pytest.mark.asyncio
async def test_refs_returns_branch_heads(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """Refs reports repo id, default branch, domain and each branch head."""
    repo = await factory_create_repo(db_session, slug="muse-test", domain_meta={"domain": "code"})
    db_session.add(
        db.MusehubBranch(repo_id=repo.repo_id, name="main", head_commit_id="abc123")
    )
    await db_session.commit()

    response = await client.get(f"/{repo.owner}/{repo.slug}/refs")
    assert response.status_code == 200
    body = response.json()
    assert body["repo_id"] == repo.repo_id
    assert body["default_branch"] == "main"
    assert body["domain"] == "code"
    assert body["branch_heads"]["main"] == "abc123"
125
126
@pytest.mark.asyncio
async def test_refs_url_is_owner_slash_slug(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """Confirm the remote URL pattern matches Git: /{owner}/{slug}/refs — no /wire/ prefix."""
    repo = await factory_create_repo(db_session, slug="git-style-test")
    owner = repo.owner
    slug = repo.slug

    canonical = await client.get(f"/{owner}/{slug}/refs")
    assert canonical.status_code == 200

    # The UUID-based /wire/repos/... form must not be routable.
    legacy = await client.get(f"/wire/repos/{repo.repo_id}/refs")
    assert legacy.status_code == 404
141
142
@pytest.mark.asyncio
async def test_refs_empty_repo_has_empty_branch_heads(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """A repo with no branches reports an empty ``branch_heads`` mapping."""
    repo = await factory_create_repo(db_session, slug="empty-test")
    refs_url = f"/{repo.owner}/{repo.slug}/refs"
    response = await client.get(refs_url)
    assert response.status_code == 200
    assert response.json()["branch_heads"] == {}
153
154
155 # ── push endpoint ──────────────────────────────────────────────────────────────
156
@pytest.mark.asyncio
async def test_push_requires_auth(client: AsyncClient, db_session: AsyncSession) -> None:
    """Pushing without an Authorization header must be refused."""
    repo = await factory_create_repo(db_session, slug="push-auth-test", owner_user_id="test-user-wire")
    empty_push = {"bundle": {"commits": [], "snapshots": [], "objects": []}, "branch": "main"}
    response = await client.post(f"/{repo.owner}/{repo.slug}/push", json=empty_push)
    # Either "unauthenticated" or "forbidden" is acceptable here.
    assert response.status_code in {401, 403}
165
166
@pytest.mark.asyncio
async def test_push_404_for_unknown_repo(
    client: AsyncClient,
    wire_headers: dict,
) -> None:
    """An authenticated push to a nonexistent owner/slug returns 404."""
    empty_push = {"bundle": {"commits": [], "snapshots": [], "objects": []}, "branch": "main"}
    response = await client.post("/nobody/no-such-repo/push", json=empty_push, headers=wire_headers)
    assert response.status_code == 404
178
179
@pytest.mark.asyncio
async def test_push_rejected_for_non_owner(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """Authenticated user who is NOT the repo owner must be rejected."""
    repo = await factory_create_repo(
        db_session,
        slug="push-nonowner-test",
        owner_user_id="someone-else",  # wire_headers authenticates as test-user-wire
    )
    empty_push = {"bundle": {"commits": [], "snapshots": [], "objects": []}, "branch": "main"}
    push_url = f"/{repo.owner}/{repo.slug}/push"
    response = await client.post(push_url, json=empty_push, headers=wire_headers)
    # NOTE(review): 409 Conflict is an unusual status for an authorization
    # failure (403 Forbidden is conventional). This pins what the route
    # currently returns; confirm it is intentional.
    assert response.status_code == 409
    assert "not authorized" in response.json()["detail"]
199
200
@pytest.mark.asyncio
async def test_push_ingests_commit_and_branch(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """Pushing one commit with its snapshot and object moves ``main`` to that commit."""
    repo = await factory_create_repo(db_session, slug="push-ingest-test", owner_user_id="test-user-wire")

    commit_id = uuid.uuid4().hex
    obj = _make_object()
    snap_id = f"snap_{uuid.uuid4().hex[:8]}"
    snap = _make_snapshot(snap_id, obj["object_id"])
    commit = _make_commit(repo.repo_id, commit_id=commit_id)
    # Point the commit at the snapshot that is actually included in the bundle.
    commit["snapshot_id"] = snap_id

    payload = {
        "bundle": {
            "commits": [commit],
            "snapshots": [snap],
            "objects": [obj],
        },
        "branch": "main",
        "force": False,
    }
    resp = await client.post(
        f"/{repo.owner}/{repo.slug}/push",
        json=payload,
        headers=wire_headers,
    )
    assert resp.status_code == 200, resp.text
    data = resp.json()
    assert data["ok"] is True
    # The branch must point at the pushed commit, not merely exist.
    assert data["branch_heads"]["main"] == commit_id
    assert data["remote_head"] == commit_id
235
236
@pytest.mark.asyncio
async def test_push_is_idempotent(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """Pushing the same commit twice must succeed both times and leave ``main`` unchanged."""
    repo = await factory_create_repo(db_session, slug="push-idempotent-test", owner_user_id="test-user-wire")
    commit = _make_commit(repo.repo_id)
    payload = {
        "bundle": {"commits": [commit], "snapshots": [], "objects": []},
        "branch": "main",
    }
    url = f"/{repo.owner}/{repo.slug}/push"
    resp1 = await client.post(url, json=payload, headers=wire_headers)
    assert resp1.status_code == 200, resp1.text
    resp2 = await client.post(url, json=payload, headers=wire_headers)
    assert resp2.status_code == 200, resp2.text

    # Idempotence means identical resulting state, not just two 200s.
    assert resp2.json()["branch_heads"] == resp1.json()["branch_heads"]
    assert resp2.json()["branch_heads"]["main"] == commit["commit_id"]
255
256
@pytest.mark.asyncio
async def test_push_non_fast_forward_rejected(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """An unforced push whose commit does not list the current head as parent is a 409."""
    repo = await factory_create_repo(db_session, slug="push-nff-test", owner_user_id="test-user-wire")
    existing_commit_id = uuid.uuid4().hex
    db_session.add(
        db.MusehubBranch(repo_id=repo.repo_id, name="main", head_commit_id=existing_commit_id)
    )
    await db_session.commit()

    # This commit has no parent, so existing_commit_id is not its ancestor.
    unrelated = _make_commit(repo.repo_id, parent=None)
    payload = {
        "bundle": {"commits": [unrelated], "snapshots": [], "objects": []},
        "branch": "main",
        "force": False,
    }
    push_url = f"/{repo.owner}/{repo.slug}/push"
    response = await client.post(push_url, json=payload, headers=wire_headers)
    assert response.status_code == 409  # Conflict: non-fast-forward
    assert "non-fast-forward" in response.json()["detail"]
283
284
@pytest.mark.asyncio
async def test_push_force_overwrites_branch(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """``force=True`` replaces an unrelated branch head with the pushed commit."""
    repo = await factory_create_repo(db_session, slug="push-force-test", owner_user_id="test-user-wire")
    old_head = uuid.uuid4().hex
    branch = db.MusehubBranch(
        repo_id=repo.repo_id,
        name="main",
        head_commit_id=old_head,
    )
    db_session.add(branch)
    await db_session.commit()

    new_commit = _make_commit(repo.repo_id, parent=None)
    payload = {
        "bundle": {"commits": [new_commit], "snapshots": [], "objects": []},
        "branch": "main",
        "force": True,
    }
    resp = await client.post(f"/{repo.owner}/{repo.slug}/push", json=payload, headers=wire_headers)
    assert resp.status_code == 200, resp.text
    data = resp.json()
    assert data["ok"] is True
    assert data["branch_heads"]["main"] != old_head
    # Pin the exact new head, not just "anything other than old_head".
    assert data["branch_heads"]["main"] == new_commit["commit_id"]
312
313
314 # ── fetch endpoint ─────────────────────────────────────────────────────────────
315
@pytest.mark.asyncio
async def test_fetch_404_for_unknown_repo(client: AsyncClient) -> None:
    """Fetching from a nonexistent owner/slug returns 404."""
    negotiation = {"want": [], "have": []}
    response = await client.post("/nobody/no-such-repo/fetch", json=negotiation)
    assert response.status_code == 404
323
324
@pytest.mark.asyncio
async def test_fetch_empty_want_returns_empty_bundle(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """An empty ``want`` list yields a bundle with no commits, snapshots or objects."""
    repo = await factory_create_repo(db_session, slug="fetch-empty-test")
    fetch_url = f"/{repo.owner}/{repo.slug}/fetch"
    response = await client.post(fetch_url, json={"want": [], "have": []})
    assert response.status_code == 200
    bundle = response.json()
    for section in ("commits", "snapshots", "objects"):
        assert bundle[section] == []
340
341
@pytest.mark.asyncio
async def test_fetch_returns_missing_commits(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """Fetching a wanted commit the client lacks returns it plus the branch head."""
    repo = await factory_create_repo(db_session, slug="fetch-commits-test")
    commit_id = uuid.uuid4().hex
    db_session.add(
        db.MusehubCommit(
            commit_id=commit_id,
            repo_id=repo.repo_id,
            branch="main",
            parent_ids=[],
            message="initial commit",
            author="Test",
            timestamp=_utc_now(),
            snapshot_id=None,
            commit_meta={},
        )
    )
    db_session.add(
        db.MusehubBranch(repo_id=repo.repo_id, name="main", head_commit_id=commit_id)
    )
    await db_session.commit()

    fetch_url = f"/{repo.owner}/{repo.slug}/fetch"
    response = await client.post(fetch_url, json={"want": [commit_id], "have": []})
    assert response.status_code == 200
    bundle = response.json()
    assert len(bundle["commits"]) == 1
    assert bundle["commits"][0]["commit_id"] == commit_id
    assert bundle["branch_heads"]["main"] == commit_id
378
379
380 # ── content-addressed CDN ──────────────────────────────────────────────────────
381
@pytest.mark.asyncio
async def test_object_cdn_returns_404_for_missing(client: AsyncClient) -> None:
    """The content-addressed object URL returns 404 for an unknown object id."""
    unknown_object_url = "/o/nonexistent-sha-12345"
    response = await client.get(unknown_object_url)
    assert response.status_code == 404
386
387
388 # ── unit tests ─────────────────────────────────────────────────────────────────
389
def test_wire_models_parse_correctly() -> None:
    """WireBundle Pydantic parsing mirrors Muse CLI format.

    This only exercises synchronous model construction, so it is a plain
    test rather than a coroutine marked for asyncio.
    """
    from musehub.models.wire import WireBundle, WirePushRequest

    commit_dict = {
        "commit_id": "abc123",
        "message": "feat: add track",
        "committed_at": "2026-03-19T10:00:00+00:00",
        "author": "Gabriel <g@example.com>",
        "sem_ver_bump": "minor",
        "breaking_changes": [],
        "agent_id": "",
        "format_version": 5,
    }
    req = WirePushRequest(
        bundle=WireBundle(commits=[commit_dict], snapshots=[], objects=[]),  # type: ignore[list-item]
        branch="main",
        force=False,
    )
    assert req.bundle.commits[0].commit_id == "abc123"
    assert req.bundle.commits[0].sem_ver_bump == "minor"
    assert req.force is False
413
414
def test_topological_sort_orders_parents_first() -> None:
    """``_topological_sort`` places every commit after its parent, whatever the input order.

    Synchronous: the sort is a pure function, so no event loop is needed.
    """
    from musehub.models.wire import WireCommit
    from musehub.services.musehub_wire import _topological_sort

    c1 = WireCommit(commit_id="parent", message="parent")
    c2 = WireCommit(commit_id="child", message="child", parent_commit_id="parent")
    sorted_ = _topological_sort([c2, c1])
    ids = [c.commit_id for c in sorted_]
    # Nothing dropped or duplicated, and parent precedes child.
    assert sorted(ids) == ["child", "parent"]
    assert ids.index("parent") < ids.index("child")

    # A scrambled linear chain has exactly one valid topological order.
    root = WireCommit(commit_id="root", message="root")
    mid = WireCommit(commit_id="mid", message="mid", parent_commit_id="root")
    tip = WireCommit(commit_id="tip", message="tip", parent_commit_id="mid")
    chain_ids = [c.commit_id for c in _topological_sort([tip, root, mid])]
    assert chain_ids == ["root", "mid", "tip"]
425
426
@pytest.mark.asyncio
async def test_remote_url_format_matches_git_pattern(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """The remote URL is /{owner}/{slug} — no /wire/ prefix, no UUID.

    This mirrors Git:
        git remote add origin https://github.com/owner/repo
    versus UUID-based alternatives like:
        muse remote add origin https://musehub.ai/wire/repos/550e8400-.../
    """
    repo = await factory_create_repo(db_session, slug="url-format-test")

    refs_response = await client.get(f"/{repo.owner}/{repo.slug}/refs")
    assert refs_response.status_code == 200

    # The owner/slug path alone was enough to resolve the right repository.
    resolved_repo_id = refs_response.json()["repo_id"]
    assert resolved_repo_id == repo.repo_id