gabriel / musehub public
test_wire_protocol.py python
457 lines 15.4 KB
4a63ff4c feat(mwp): full msgpack wire protocol — replace JSON+base64 on all push… Gabriel Cardona <cgcardona@gmail.com> 12h ago
1 """Wire protocol endpoint tests.
2
3 Covers the three Muse CLI transport endpoints (Git-style URLs):
4 GET /{owner}/{slug}/refs
5 POST /{owner}/{slug}/push
6 POST /{owner}/{slug}/fetch
7
8 And the content-addressed CDN endpoint:
9 GET /o/{object_id}
10
11 Remote URL format (same pattern as Git):
12 muse remote add origin https://musehub.ai/cgcardona/muse
13 """
14 from __future__ import annotations
15
16 import os
17 import uuid
18 from datetime import datetime, timezone
19
20 import msgpack
21 import pytest
22 from httpx import AsyncClient
23 from sqlalchemy.ext.asyncio import AsyncSession
24
25 from musehub.auth.tokens import create_access_token
26 from musehub.db import musehub_models as db
27 from tests.factories import create_repo as factory_create_repo
28
29
30 # ── helpers ────────────────────────────────────────────────────────────────────
31
32 def _utc_now() -> datetime:
33 return datetime.now(tz=timezone.utc)
34
35
def _make_commit(repo_id: str, commit_id: str | None = None, parent: str | None = None) -> dict:
    """Build a commit payload dict for ``repo_id`` on branch ``main``.

    Args:
        repo_id: Repository the commit is attributed to.
        commit_id: Explicit commit id; a random UUID string is generated when
            this is omitted (or empty).
        parent: Value stored as ``parent_commit_id``; ``None`` means no parent.

    Returns:
        A dict with a freshly generated ``snap_<8 hex chars>`` snapshot id,
        a fixed message, author and ``patch`` bump, and the current UTC time
        in ISO-8601 form as ``committed_at``.
    """
    resolved_commit_id = commit_id if commit_id else str(uuid.uuid4())
    snapshot_ref = f"snap_{uuid.uuid4().hex[:8]}"
    return dict(
        commit_id=resolved_commit_id,
        repo_id=repo_id,
        branch="main",
        snapshot_id=snapshot_ref,
        message="chore: add test commit",
        committed_at=_utc_now().isoformat(),
        parent_commit_id=parent,
        author="Test User <test@example.com>",
        sem_ver_bump="patch",
    )
48
49
50 def _make_object(content: bytes = b"hello world") -> dict:
51 oid = uuid.uuid4().hex
52 return {
53 "object_id": oid,
54 "content": content,
55 "path": "README.md",
56 }
57
58
def _make_snapshot(snap_id: str, object_id: str) -> dict:
    """Build a snapshot payload dict whose manifest maps ``README.md`` to ``object_id``.

    ``created_at`` is the current UTC time in ISO-8601 form.
    """
    manifest = {"README.md": object_id}
    return dict(
        snapshot_id=snap_id,
        manifest=manifest,
        created_at=_utc_now().isoformat(),
    )
65
66
@pytest.fixture(autouse=True)
def _tmp_objects_dir(tmp_path: os.PathLike[str], monkeypatch: pytest.MonkeyPatch) -> None:
    """Override object storage to use a temp directory in tests.

    Applied automatically to every test in this module.  A ``LocalBackend``
    rooted at ``<tmp_path>/objects`` is created and ``get_backend`` is replaced
    on both the wire service module and the wire route module so that each
    returns that backend.

    Args:
        tmp_path: pytest's per-test temporary directory.
        monkeypatch: pytest fixture used to install the ``get_backend``
            replacements.
    """
    import musehub.storage.backends as _backends
    import musehub.services.musehub_wire as _wire_svc
    import musehub.api.routes.wire as _wire_route

    # os.path.join accepts path-like objects and returns a str, so no manual
    # str() + "/" concatenation (or type-ignore) is needed.
    obj_dir = os.path.join(tmp_path, "objects")
    os.makedirs(obj_dir, exist_ok=True)
    test_backend = _backends.LocalBackend(objects_dir=obj_dir)

    # Each module exposes its own ``get_backend`` attribute, so both are patched.
    for module in (_wire_svc, _wire_route):
        monkeypatch.setattr(module, "get_backend", lambda: test_backend)
79
80
@pytest.fixture
def auth_wire_token() -> str:
    """Create an access token for user ``test-user-wire`` with ``expires_hours=1``."""
    token = create_access_token(user_id="test-user-wire", expires_hours=1)
    return token
84
85
@pytest.fixture
def wire_headers(auth_wire_token: str) -> dict[str, str]:
    """Request headers carrying a Bearer token and msgpack ``Content-Type``/``Accept``."""
    headers = {"Authorization": f"Bearer {auth_wire_token}"}
    for field in ("Content-Type", "Accept"):
        headers[field] = "application/x-msgpack"
    return headers
93
94
def _mp(data: object) -> bytes:
    """Serialize ``data`` to msgpack bytes for use as a test request body."""
    encoded = msgpack.packb(data, use_bin_type=True)
    return encoded
98
99
100 # ── refs endpoint ──────────────────────────────────────────────────────────────
101
@pytest.mark.asyncio
async def test_refs_returns_404_for_unknown_owner_slug(client: AsyncClient) -> None:
    """GET refs for an owner/slug pair this test never created must return 404."""
    unknown_refs_url = "/no-such-owner/no-such-slug/refs"
    response = await client.get(unknown_refs_url)
    assert response.status_code == 404
106
107
@pytest.mark.asyncio
async def test_refs_returns_branch_heads(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """Refs for a repo with one stored branch report repo metadata and that branch's head."""
    repo = await factory_create_repo(db_session, slug="muse-test", domain_meta={"domain": "code"})
    db_session.add(
        db.MusehubBranch(
            repo_id=repo.repo_id,
            name="main",
            head_commit_id="abc123",
        )
    )
    await db_session.commit()

    refs_url = f"/{repo.owner}/{repo.slug}/refs"
    response = await client.get(refs_url)
    assert response.status_code == 200

    body = response.json()
    expected_fields = {
        "repo_id": repo.repo_id,
        "default_branch": "main",
        "domain": "code",
    }
    for field, expected in expected_fields.items():
        assert body[field] == expected
    assert body["branch_heads"]["main"] == "abc123"
131
132
@pytest.mark.asyncio
async def test_refs_url_is_owner_slash_slug(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """Refs are served at the Git-style /{owner}/{slug}/refs path, with no /wire/ prefix."""
    repo = await factory_create_repo(db_session, slug="git-style-test")

    git_style_response = await client.get(f"/{repo.owner}/{repo.slug}/refs")
    assert git_style_response.status_code == 200

    # The /wire/-prefixed, repo-id-based path must not resolve.
    wire_prefixed_url = f"/wire/repos/{repo.repo_id}/refs"
    wire_prefixed_response = await client.get(wire_prefixed_url)
    assert wire_prefixed_response.status_code == 404
147
148
@pytest.mark.asyncio
async def test_refs_empty_repo_has_empty_branch_heads(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """A repo with no branch rows reports an empty ``branch_heads`` mapping."""
    repo = await factory_create_repo(db_session, slug="empty-test")
    refs_url = f"/{repo.owner}/{repo.slug}/refs"

    response = await client.get(refs_url)

    assert response.status_code == 200
    assert response.json()["branch_heads"] == {}
159
160
161 # ── push endpoint ──────────────────────────────────────────────────────────────
162
@pytest.mark.asyncio
async def test_push_requires_auth(client: AsyncClient, db_session: AsyncSession) -> None:
    """A push sent without an Authorization header is refused with 401 or 403."""
    repo = await factory_create_repo(db_session, slug="push-auth-test", owner_user_id="test-user-wire")
    empty_bundle = {section: [] for section in ("commits", "snapshots", "objects")}
    push_body = _mp({"bundle": empty_bundle, "branch": "main"})

    response = await client.post(
        f"/{repo.owner}/{repo.slug}/push",
        content=push_body,
        # Only the content type is sent; no credentials.
        headers={"Content-Type": "application/x-msgpack"},
    )

    assert response.status_code in {401, 403}
172
173
@pytest.mark.asyncio
async def test_push_404_for_unknown_repo(
    client: AsyncClient,
    wire_headers: dict,
) -> None:
    """An authenticated push to an owner/slug this test never created returns 404."""
    empty_bundle = {section: [] for section in ("commits", "snapshots", "objects")}
    push_body = _mp({"bundle": empty_bundle, "branch": "main"})

    response = await client.post(
        "/nobody/no-such-repo/push",
        content=push_body,
        headers=wire_headers,
    )

    assert response.status_code == 404
185
186
@pytest.mark.asyncio
async def test_push_rejected_for_non_owner(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """An authenticated user who does not own the repo cannot push to it.

    The repo is created for ``someone-else`` while the request is sent with
    ``wire_headers``; the response must be 409 with a "not authorized" detail.
    """
    # Owner deliberately differs from the user the request authenticates as.
    repo = await factory_create_repo(
        db_session,
        slug="push-nonowner-test",
        owner_user_id="someone-else",
    )
    empty_bundle = {section: [] for section in ("commits", "snapshots", "objects")}
    push_body = _mp({"bundle": empty_bundle, "branch": "main"})

    response = await client.post(
        f"/{repo.owner}/{repo.slug}/push",
        content=push_body,
        headers=wire_headers,
    )

    assert response.status_code == 409
    assert "not authorized" in response.json()["detail"]
206
207
@pytest.mark.asyncio
async def test_push_ingests_commit_and_branch(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """Pushing one commit with its snapshot and object succeeds and sets the remote head.

    The msgpack response must report ``ok``, list ``main`` among the branch
    heads, and give the pushed commit id as ``remote_head``.
    """
    repo = await factory_create_repo(db_session, slug="push-ingest-test", owner_user_id="test-user-wire")

    pushed_commit_id = uuid.uuid4().hex
    blob = _make_object()
    snapshot_id = f"snap_{uuid.uuid4().hex[:8]}"
    snapshot = _make_snapshot(snapshot_id, blob["object_id"])
    # Point the generated commit at the snapshot included in this bundle.
    commit = {
        **_make_commit(repo.repo_id, commit_id=pushed_commit_id),
        "snapshot_id": snapshot_id,
    }

    bundle = {"commits": [commit], "snapshots": [snapshot], "objects": [blob]}
    push_body = _mp({"bundle": bundle, "branch": "main", "force": False})
    response = await client.post(
        f"/{repo.owner}/{repo.slug}/push",
        content=push_body,
        headers=wire_headers,
    )

    assert response.status_code == 200, response.text
    result = msgpack.unpackb(response.content, raw=False)
    assert result["ok"] is True
    assert "main" in result["branch_heads"]
    assert result["remote_head"] == pushed_commit_id
242
243
@pytest.mark.asyncio
async def test_push_is_idempotent(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """Sending the identical push request twice must return 200 both times."""
    repo = await factory_create_repo(db_session, slug="push-idempotent-test", owner_user_id="test-user-wire")
    bundle = {"commits": [_make_commit(repo.repo_id)], "snapshots": [], "objects": []}
    push_body = _mp({"bundle": bundle, "branch": "main"})
    push_url = f"/{repo.owner}/{repo.slug}/push"

    # First iteration is the initial push, second is the exact repeat.
    for _attempt in range(2):
        response = await client.post(push_url, content=push_body, headers=wire_headers)
        assert response.status_code == 200
262
263
@pytest.mark.asyncio
async def test_push_non_fast_forward_rejected(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """A non-forced push whose commit does not build on the stored head gets 409."""
    repo = await factory_create_repo(db_session, slug="push-nff-test", owner_user_id="test-user-wire")
    stored_head = uuid.uuid4().hex
    db_session.add(
        db.MusehubBranch(
            repo_id=repo.repo_id,
            name="main",
            head_commit_id=stored_head,
        )
    )
    await db_session.commit()

    # The pushed commit has no parent, so it does not have stored_head as its parent.
    unrelated_commit = _make_commit(repo.repo_id, parent=None)
    bundle = {"commits": [unrelated_commit], "snapshots": [], "objects": []}
    push_body = _mp({"bundle": bundle, "branch": "main", "force": False})

    response = await client.post(f"/{repo.owner}/{repo.slug}/push", content=push_body, headers=wire_headers)

    assert response.status_code == 409  # 409 Conflict for non-fast-forward
    assert "non-fast-forward" in response.json()["detail"]
290
291
@pytest.mark.asyncio
async def test_push_force_overwrites_branch(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """A forced push replaces the stored head of ``main`` with the pushed commit.

    The branch starts at a random ``old_head``; the pushed commit has no
    parent, so it does not build on that head.  With ``force`` set the push
    must succeed and the reported head must be the pushed commit, not merely
    something other than ``old_head``.
    """
    repo = await factory_create_repo(db_session, slug="push-force-test", owner_user_id="test-user-wire")
    old_head = uuid.uuid4().hex
    branch = db.MusehubBranch(
        repo_id=repo.repo_id,
        name="main",
        head_commit_id=old_head,
    )
    db_session.add(branch)
    await db_session.commit()

    new_commit = _make_commit(repo.repo_id, parent=None)
    payload = {
        "bundle": {"commits": [new_commit], "snapshots": [], "objects": []},
        "branch": "main",
        "force": True,
    }
    resp = await client.post(f"/{repo.owner}/{repo.slug}/push", content=_mp(payload), headers=wire_headers)
    # Include the body in the failure message to make rejections diagnosable.
    assert resp.status_code == 200, resp.text
    data = msgpack.unpackb(resp.content, raw=False)
    assert data["ok"] is True
    assert data["branch_heads"]["main"] != old_head
    # "!= old_head" alone would accept any wrong head; pin it to the pushed commit.
    assert data["branch_heads"]["main"] == new_commit["commit_id"]
    assert data["remote_head"] == new_commit["commit_id"]
319
320
321 # ── fetch endpoint ─────────────────────────────────────────────────────────────
322
@pytest.mark.asyncio
async def test_fetch_404_for_unknown_repo(client: AsyncClient) -> None:
    """A fetch against an owner/slug this test never created returns 404."""
    fetch_body = _mp({"want": [], "have": []})
    msgpack_request_headers = {"Content-Type": "application/x-msgpack"}

    response = await client.post(
        "/nobody/no-such-repo/fetch",
        content=fetch_body,
        headers=msgpack_request_headers,
    )

    assert response.status_code == 404
331
332
@pytest.mark.asyncio
async def test_fetch_empty_want_returns_empty_bundle(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """Fetching with empty ``want`` and ``have`` lists yields a bundle with no entries."""
    repo = await factory_create_repo(db_session, slug="fetch-empty-test")
    msgpack_headers = {"Content-Type": "application/x-msgpack", "Accept": "application/x-msgpack"}

    response = await client.post(
        f"/{repo.owner}/{repo.slug}/fetch",
        content=_mp({"want": [], "have": []}),
        headers=msgpack_headers,
    )

    assert response.status_code == 200
    bundle = msgpack.unpackb(response.content, raw=False)
    for section in ("commits", "snapshots", "objects"):
        assert bundle[section] == []
349
350
@pytest.mark.asyncio
async def test_fetch_returns_missing_commits(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """Fetching a wanted commit with an empty ``have`` list returns exactly that commit.

    A single commit row and a ``main`` branch pointing at it are stored first;
    the response must contain that one commit and report it as the head of
    ``main``.
    """
    repo = await factory_create_repo(db_session, slug="fetch-commits-test")
    wanted_id = uuid.uuid4().hex

    db_session.add(
        db.MusehubCommit(
            commit_id=wanted_id,
            repo_id=repo.repo_id,
            branch="main",
            parent_ids=[],
            message="initial commit",
            author="Test",
            timestamp=_utc_now(),
            snapshot_id=None,
            commit_meta={},
        )
    )
    db_session.add(
        db.MusehubBranch(
            repo_id=repo.repo_id,
            name="main",
            head_commit_id=wanted_id,
        )
    )
    await db_session.commit()

    msgpack_headers = {"Content-Type": "application/x-msgpack", "Accept": "application/x-msgpack"}
    response = await client.post(
        f"/{repo.owner}/{repo.slug}/fetch",
        content=_mp({"want": [wanted_id], "have": []}),
        headers=msgpack_headers,
    )

    assert response.status_code == 200
    bundle = msgpack.unpackb(response.content, raw=False)
    returned_commits = bundle["commits"]
    assert len(returned_commits) == 1
    assert returned_commits[0]["commit_id"] == wanted_id
    assert bundle["branch_heads"]["main"] == wanted_id
388
389
390 # ── content-addressed CDN ──────────────────────────────────────────────────────
391
@pytest.mark.asyncio
async def test_object_cdn_returns_404_for_missing(client: AsyncClient) -> None:
    """Requesting an object id that was never stored from ``/o/`` returns 404."""
    missing_object_url = "/o/nonexistent-sha-12345"
    response = await client.get(missing_object_url)
    assert response.status_code == 404
396
397
398 # ── unit tests ─────────────────────────────────────────────────────────────────
399
@pytest.mark.asyncio
async def test_wire_models_parse_correctly() -> None:
    """A plain commit dict placed in a ``WireBundle`` is exposed through model attributes.

    Builds a ``WirePushRequest`` around the bundle and checks the commit's
    ``commit_id`` and ``sem_ver_bump`` plus the request's ``force`` flag.
    """
    from musehub.models.wire import WireBundle, WirePushRequest

    raw_commit = dict(
        commit_id="abc123",
        message="feat: add track",
        committed_at="2026-03-19T10:00:00+00:00",
        author="Gabriel <g@example.com>",
        sem_ver_bump="minor",
        breaking_changes=[],
        agent_id="",
        format_version=5,
    )
    bundle = WireBundle(commits=[raw_commit], snapshots=[], objects=[])  # type: ignore[list-item]
    request = WirePushRequest(bundle=bundle, branch="main", force=False)

    parsed_commit = request.bundle.commits[0]
    assert parsed_commit.commit_id == "abc123"
    assert parsed_commit.sem_ver_bump == "minor"
    assert request.force is False
423
424
@pytest.mark.asyncio
async def test_topological_sort_orders_parents_first() -> None:
    """``_topological_sort`` places a parent commit before its child even when given child-first."""
    from musehub.models.wire import WireCommit
    from musehub.services.musehub_wire import _topological_sort

    parent_commit = WireCommit(commit_id="parent", message="parent")
    child_commit = WireCommit(commit_id="child", message="child", parent_commit_id="parent")

    # Input is intentionally child-first.
    ordered_ids = [commit.commit_id for commit in _topological_sort([child_commit, parent_commit])]

    assert ordered_ids.index("parent") < ordered_ids.index("child")
435
436
@pytest.mark.asyncio
async def test_remote_url_format_matches_git_pattern(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """The remote URL is /{owner}/{slug}: no /wire/ prefix and no UUID in the path.

    Same shape as a Git remote:
        git remote add origin https://github.com/owner/repo
    as opposed to a UUID-based form such as:
        muse remote add origin https://musehub.ai/wire/repos/550e8400-.../
    """
    repo = await factory_create_repo(db_session, slug="url-format-test")
    refs_url = f"/{repo.owner}/{repo.slug}/refs"

    # The owner/slug path alone must resolve.
    response = await client.get(refs_url)
    assert response.status_code == 200

    # The body names the repo that was resolved, so the URL needs no UUID.
    assert response.json()["repo_id"] == repo.repo_id