gabriel / musehub public
test_wire_protocol.py python
387 lines 12.5 KB
cb3d85e8 feat: wire protocol, storage abstraction, unified identities, Qdrant pi… Gabriel Cardona <cgcardona@gmail.com> 4d ago
1 """Wire protocol endpoint tests.
2
3 Covers the three Muse CLI transport endpoints:
4 GET /wire/repos/{repo_id}/refs
5 POST /wire/repos/{repo_id}/push
6 POST /wire/repos/{repo_id}/fetch
7
8 And the content-addressed CDN endpoint:
9 GET /o/{object_id}
10 """
from __future__ import annotations

import base64
import json
import os
import tempfile
import uuid
from datetime import datetime, timezone
from pathlib import Path

import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession

from musehub.auth.tokens import create_access_token
from musehub.db import musehub_models as db
from tests.factories import create_repo as factory_create_repo, create_branch as factory_create_branch
29
30
31 # ── helpers ────────────────────────────────────────────────────────────────────
32
33 def _utc_now() -> datetime:
34 return datetime.now(tz=timezone.utc)
35
36
def _make_commit(repo_id: str, commit_id: str | None = None, parent: str | None = None) -> dict:
    """Build a wire-format commit dict on branch ``main`` for push payloads.

    When *commit_id* is falsy a fresh ``str(uuid4())`` is used instead. The
    ``snapshot_id`` is a random ``snap_``-prefixed placeholder; tests that
    ship a real snapshot overwrite it. *parent* becomes ``parent_commit_id``
    (``None`` produces a root commit).
    """
    resolved_id = commit_id if commit_id else str(uuid.uuid4())
    placeholder_snapshot = "snap_" + uuid.uuid4().hex[:8]
    return dict(
        commit_id=resolved_id,
        repo_id=repo_id,
        branch="main",
        snapshot_id=placeholder_snapshot,
        message="chore: add test commit",
        committed_at=_utc_now().isoformat(),
        parent_commit_id=parent,
        author="Test User <test@example.com>",
        sem_ver_bump="patch",
    )
49
50
51 def _make_object(content: bytes = b"hello world") -> dict:
52 oid = uuid.uuid4().hex
53 return {
54 "object_id": oid,
55 "content_b64": base64.b64encode(content).decode(),
56 "path": "README.md",
57 }
58
59
def _make_snapshot(snap_id: str, object_id: str) -> dict:
    """Build a wire-format snapshot whose manifest maps ``README.md`` to *object_id*."""
    manifest = {"README.md": object_id}
    created = _utc_now().isoformat()
    return dict(snapshot_id=snap_id, manifest=manifest, created_at=created)
66
67
@pytest.fixture(autouse=True)
def _tmp_objects_dir(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
    """Route wire object storage to a per-test temporary directory.

    Creates ``<tmp_path>/objects`` and wraps it in a ``LocalBackend``. It then
    patches ``get_backend`` in both the wire service module and the wire route
    module to return that backend. ``autouse=True`` applies this to every
    test in the module, so pushed objects never reach the real store.

    Args:
        tmp_path: pytest's per-test temporary directory.
        monkeypatch: pytest's monkeypatch fixture; patches are undone after
            the test.
    """
    # Imported lazily (and in the original order) so the patch targets are
    # the already-loaded module objects the app under test uses.
    import musehub.storage.backends as _backends
    import musehub.services.musehub_wire as _wire_svc
    import musehub.api.routes.wire as _wire_route

    obj_dir = tmp_path / "objects"
    obj_dir.mkdir(parents=True, exist_ok=True)
    test_backend = _backends.LocalBackend(objects_dir=str(obj_dir))
    # Each module resolves ``get_backend`` through its own namespace
    # (presumably a ``from ... import get_backend``), so both need patching.
    monkeypatch.setattr(_wire_svc, "get_backend", lambda: test_backend)
    monkeypatch.setattr(_wire_route, "get_backend", lambda: test_backend)
80
81
@pytest.fixture
def auth_wire_token() -> str:
    """Access token for user ``test-user-wire``, issued with a one-hour expiry."""
    token = create_access_token(user_id="test-user-wire", expires_hours=1)
    return token
85
86
@pytest.fixture
def wire_headers(auth_wire_token: str) -> dict[str, str]:
    """JSON request headers carrying the wire test token as a Bearer credential."""
    headers: dict[str, str] = {}
    headers["Authorization"] = "Bearer " + auth_wire_token
    headers["Content-Type"] = "application/json"
    return headers
93
94
95 # ── refs endpoint ──────────────────────────────────────────────────────────────
96
@pytest.mark.asyncio
async def test_refs_returns_404_for_unknown_repo(client: AsyncClient) -> None:
    """Refs for a repo id that was never created answer 404."""
    response = await client.get("/wire/repos/nonexistent-repo-id/refs")
    assert response.status_code == 404
101
102
@pytest.mark.asyncio
async def test_refs_returns_branch_heads(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """Refs report repo id, default branch, domain, and each branch's head."""
    repo = await factory_create_repo(db_session, slug="muse-test", domain_meta={"domain": "code"})
    db_session.add(
        db.MusehubBranch(repo_id=repo.repo_id, name="main", head_commit_id="abc123")
    )
    await db_session.commit()

    response = await client.get(f"/wire/repos/{repo.repo_id}/refs")
    assert response.status_code == 200
    body = response.json()
    expected = {
        "repo_id": repo.repo_id,
        "default_branch": "main",
        "domain": "code",
    }
    for key, value in expected.items():
        assert body[key] == value
    assert body["branch_heads"]["main"] == "abc123"
124
125
@pytest.mark.asyncio
async def test_refs_empty_repo_has_empty_branch_heads(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """A repo with no branch rows reports an empty ``branch_heads`` mapping."""
    repo = await factory_create_repo(db_session, slug="empty-test")
    response = await client.get(f"/wire/repos/{repo.repo_id}/refs")
    assert response.status_code == 200
    assert response.json()["branch_heads"] == {}
136
137
138 # ── push endpoint ──────────────────────────────────────────────────────────────
139
@pytest.mark.asyncio
async def test_push_requires_auth(client: AsyncClient, db_session: AsyncSession) -> None:
    """A push sent without an Authorization header is refused."""
    repo = await factory_create_repo(db_session, slug="push-auth-test")
    empty_push = {"bundle": {"commits": [], "snapshots": [], "objects": []}, "branch": "main"}
    response = await client.post(f"/wire/repos/{repo.repo_id}/push", json=empty_push)
    # Either "unauthenticated" or "forbidden" is an acceptable refusal here.
    assert response.status_code in {401, 403}
148
149
@pytest.mark.asyncio
async def test_push_404_for_unknown_repo(
    client: AsyncClient,
    wire_headers: dict,
) -> None:
    """An authenticated push to a repo id that does not exist is rejected.

    NOTE(review): despite the name, the asserted status is 422, not 404. The
    original inline comment attributes this to the wire service returning
    ``ok=False``, whereas the refs and fetch endpoints answer 404 for unknown
    repos. Confirm whether that asymmetry is intended.
    """
    empty_bundle = {"commits": [], "snapshots": [], "objects": []}
    response = await client.post(
        "/wire/repos/does-not-exist/push",
        json={"bundle": empty_bundle, "branch": "main"},
        headers=wire_headers,
    )
    # wire service returns ok=False, surfaced as 422
    assert response.status_code == 422
161
162
@pytest.mark.asyncio
async def test_push_ingests_commit_and_branch(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """Push one commit with its snapshot and object; ``main`` must point at it.

    Asserts the push succeeds (``ok`` is true) and that both the reported
    ``branch_heads["main"]`` and ``remote_head`` equal the pushed commit id.
    """
    repo = await factory_create_repo(db_session, slug="push-ingest-test")

    commit_id = uuid.uuid4().hex
    obj = _make_object()
    snap_id = f"snap_{uuid.uuid4().hex[:8]}"
    snap = _make_snapshot(snap_id, obj["object_id"])
    commit = _make_commit(repo.repo_id, commit_id=commit_id)
    # Point the commit at the snapshot actually shipped in this bundle rather
    # than the random placeholder _make_commit generates.
    commit["snapshot_id"] = snap_id

    payload = {
        "bundle": {
            "commits": [commit],
            "snapshots": [snap],
            "objects": [obj],
        },
        "branch": "main",
        "force": False,
    }
    resp = await client.post(
        f"/wire/repos/{repo.repo_id}/push",
        json=payload,
        headers=wire_headers,
    )
    assert resp.status_code == 200, resp.text
    data = resp.json()
    assert data["ok"] is True
    # Checking only that "main" is a key would pass even if the server
    # recorded the wrong head; pin the exact value.
    assert data["branch_heads"].get("main") == commit_id
    assert data["remote_head"] == commit_id
197
198
@pytest.mark.asyncio
async def test_push_is_idempotent(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """Pushing the same commit twice must succeed both times."""
    repo = await factory_create_repo(db_session, slug="push-idempotent-test")
    payload = {
        "bundle": {"commits": [_make_commit(repo.repo_id)], "snapshots": [], "objects": []},
        "branch": "main",
    }
    url = f"/wire/repos/{repo.repo_id}/push"
    # Identical payload both times: the second push re-sends an already-known commit.
    for _ in range(2):
        response = await client.post(url, json=payload, headers=wire_headers)
        assert response.status_code == 200
217
218
@pytest.mark.asyncio
async def test_push_non_fast_forward_rejected(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """A non-forced push whose commit does not descend from the head gets 422."""
    repo = await factory_create_repo(db_session, slug="push-nff-test")
    current_head = uuid.uuid4().hex
    db_session.add(
        db.MusehubBranch(repo_id=repo.repo_id, name="main", head_commit_id=current_head)
    )
    await db_session.commit()

    # A root commit (no parent) cannot be a descendant of current_head.
    orphan = _make_commit(repo.repo_id, parent=None)
    response = await client.post(
        f"/wire/repos/{repo.repo_id}/push",
        json={
            "bundle": {"commits": [orphan], "snapshots": [], "objects": []},
            "branch": "main",
            "force": False,
        },
        headers=wire_headers,
    )
    assert response.status_code == 422
    detail = response.json()["detail"]
    assert "non-fast-forward" in detail
245
246
@pytest.mark.asyncio
async def test_push_force_overwrites_branch(
    client: AsyncClient,
    db_session: AsyncSession,
    wire_headers: dict,
) -> None:
    """A forced push of an unrelated root commit replaces the ``main`` head.

    ``main`` is seeded with an arbitrary head. A parentless commit is then
    pushed with ``force=True``, and ``main`` must now point at that commit.
    """
    repo = await factory_create_repo(db_session, slug="push-force-test")
    old_head = uuid.uuid4().hex
    branch = db.MusehubBranch(
        repo_id=repo.repo_id,
        name="main",
        head_commit_id=old_head,
    )
    db_session.add(branch)
    await db_session.commit()

    new_commit = _make_commit(repo.repo_id, parent=None)
    payload = {
        "bundle": {"commits": [new_commit], "snapshots": [], "objects": []},
        "branch": "main",
        "force": True,
    }
    resp = await client.post(f"/wire/repos/{repo.repo_id}/push", json=payload, headers=wire_headers)
    assert resp.status_code == 200
    data = resp.json()
    assert data["ok"] is True
    assert data["branch_heads"]["main"] != old_head
    # "!= old_head" alone accepts any other value; require the forced commit.
    assert data["branch_heads"]["main"] == new_commit["commit_id"]
274
275
276 # ── fetch endpoint ─────────────────────────────────────────────────────────────
277
@pytest.mark.asyncio
async def test_fetch_404_for_unknown_repo(client: AsyncClient) -> None:
    """Fetching from a repo id that was never created answers 404."""
    negotiation = {"want": [], "have": []}
    response = await client.post("/wire/repos/no-such-repo/fetch", json=negotiation)
    assert response.status_code == 404
285
286
@pytest.mark.asyncio
async def test_fetch_empty_want_returns_empty_bundle(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """Wanting nothing yields a bundle with no commits, snapshots or objects."""
    repo = await factory_create_repo(db_session, slug="fetch-empty-test")

    response = await client.post(
        f"/wire/repos/{repo.repo_id}/fetch",
        json={"want": [], "have": []},
    )
    assert response.status_code == 200
    bundle = response.json()
    for section in ("commits", "snapshots", "objects"):
        assert bundle[section] == []
303
304
@pytest.mark.asyncio
async def test_fetch_returns_missing_commits(
    client: AsyncClient,
    db_session: AsyncSession,
) -> None:
    """Fetch returns a wanted commit that the server already stores.

    The commit and ``main`` branch rows are inserted directly into the
    database (no push is performed). Fetch is called with that commit in
    ``want`` and an empty ``have``. The response must contain exactly that
    commit and report it as the head of ``main``.
    """
    repo = await factory_create_repo(db_session, slug="fetch-commits-test")
    commit_id = uuid.uuid4().hex
    commit_row = db.MusehubCommit(
        commit_id=commit_id,
        repo_id=repo.repo_id,
        branch="main",
        parent_ids=[],
        message="initial commit",
        author="Test",
        timestamp=_utc_now(),
        snapshot_id=None,
        commit_meta={},
    )
    branch_row = db.MusehubBranch(
        repo_id=repo.repo_id,
        name="main",
        head_commit_id=commit_id,
    )
    db_session.add(commit_row)
    db_session.add(branch_row)
    await db_session.commit()

    resp = await client.post(
        f"/wire/repos/{repo.repo_id}/fetch",
        json={"want": [commit_id], "have": []},
    )
    assert resp.status_code == 200
    data = resp.json()
    assert len(data["commits"]) == 1
    assert data["commits"][0]["commit_id"] == commit_id
    assert data["branch_heads"]["main"] == commit_id
342
343
344 # ── content-addressed CDN ──────────────────────────────────────────────────────
345
@pytest.mark.asyncio
async def test_object_cdn_returns_404_for_missing(client: AsyncClient) -> None:
    """The ``/o/{object_id}`` endpoint answers 404 for an id with no stored object."""
    missing_id = "nonexistent-sha-12345"
    response = await client.get(f"/o/{missing_id}")
    assert response.status_code == 404
350
351
def test_wire_models_parse_correctly() -> None:
    """Unit test: WireBundle Pydantic parsing from dict mirrors Muse CLI format.

    Pure model validation with no I/O, so it runs synchronously; it does not
    need an event loop or the ``asyncio`` marker.
    """
    from musehub.models.wire import WireBundle, WirePushRequest

    commit_dict = {
        "commit_id": "abc123",
        "message": "feat: add track",
        "committed_at": "2026-03-19T10:00:00+00:00",
        "author": "Gabriel <g@example.com>",
        "sem_ver_bump": "minor",
        "breaking_changes": [],
        "agent_id": "",
        "format_version": 5,
    }
    # A raw dict is passed on purpose: the point is that Pydantic coerces
    # CLI-shaped JSON into WireCommit, hence the type-checker suppression.
    req = WirePushRequest(
        bundle=WireBundle(commits=[commit_dict], snapshots=[], objects=[]),  # type: ignore[list-item]
        branch="main",
        force=False,
    )
    assert req.bundle.commits[0].commit_id == "abc123"
    assert req.bundle.commits[0].sem_ver_bump == "minor"
    assert req.force is False
375
376
def test_topological_sort_orders_parents_first() -> None:
    """``_topological_sort`` places a parent before its child regardless of input order.

    Pure in-memory unit test with no I/O, so it runs synchronously without
    the ``asyncio`` marker.
    """
    from musehub.models.wire import WireCommit
    from musehub.services.musehub_wire import _topological_sort

    c1 = WireCommit(commit_id="parent", message="parent")
    c2 = WireCommit(commit_id="child", message="child", parent_commit_id="parent")
    # Pass in reverse (child-first) order so a no-op sort would fail.
    sorted_ = _topological_sort([c2, c1])
    ids = [c.commit_id for c in sorted_]
    assert ids.index("parent") < ids.index("child")