gabriel / musehub public
musehub_wire.py python
609 lines 23.0 KB
a8fbbd4a Release: dev → main (#58) Gabriel Cardona <cgcardona@gmail.com> 15h ago
1 """Wire protocol service — bridges Muse CLI push/fetch to MuseHub storage.
2
3 This module translates between:
4 Muse CLI native format (snake_case CommitDict / SnapshotDict / ObjectPayload)
5 MuseHub DB / storage (SQLAlchemy ORM + StorageBackend)
6
7 Entry points:
8 wire_refs(session, repo_id) → WireRefsResponse
9 wire_push(session, repo_id, req, pusher_id) → WirePushResponse
10 wire_fetch(session, repo_id, want, have) → WireFetchResponse
11
12 Design decisions:
13 - Non-fast-forward pushes are rejected by default; ``force=True`` blows
14 away the branch pointer (equivalent to ``git push --force``).
15 - Snapshot manifests are stored verbatim in musehub_snapshots.manifest.
16 - Object bytes are base64-decoded and handed to the StorageBackend.
17 - After a successful push, callers are expected to fire background tasks
18 for Qdrant embedding and event fan-out.
19 """
from __future__ import annotations

import base64
import logging
from collections import deque
from datetime import datetime, timezone

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from musehub.db import musehub_models as db
from musehub.models.wire import (
    WireBundle,
    WireCommit,
    WireFetchRequest,
    WireFetchResponse,
    WireObject,
    WireObjectsRequest,
    WireObjectsResponse,
    WirePushRequest,
    WirePushResponse,
    WireRefsResponse,
    WireSnapshot,
)
from musehub.storage import get_backend
44
45 logger = logging.getLogger(__name__)
46
47
48 # ── helpers ────────────────────────────────────────────────────────────────────
49
50 def _utc_now() -> datetime:
51 return datetime.now(tz=timezone.utc)
52
53
54 def _parse_iso(s: str) -> datetime:
55 """Parse an ISO-8601 string; fall back to now() on failure."""
56 try:
57 return datetime.fromisoformat(s.replace("Z", "+00:00"))
58 except (ValueError, AttributeError):
59 return _utc_now()
60
61
62 def _str_values(d: object) -> dict[str, str]:
63 """Safely coerce a dict with unknown value types to ``dict[str, str]``."""
64 if not isinstance(d, dict):
65 return {}
66 return {str(k): str(v) for k, v in d.items()}
67
68
69 def _str_list(v: object) -> list[str]:
70 """Safely coerce a list with unknown element types to ``list[str]``."""
71 if not isinstance(v, list):
72 return []
73 return [str(x) for x in v]
74
75
76 def _int_safe(v: object, default: int = 0) -> int:
77 """Return *v* as an int when it is numeric; fall back to *default*."""
78 return int(v) if isinstance(v, (int, float)) else default
79
80
def _to_wire_commit(row: db.MusehubCommit) -> WireCommit:
    """Rebuild a WireCommit from its persisted DB row (fetch path).

    ``commit_meta`` is a free-form dict, so every field is coerced
    defensively; rows written before newer keys existed still round-trip.
    """
    meta: dict[str, object] = row.commit_meta if isinstance(row.commit_meta, dict) else {}
    parents: list[str] = row.parent_ids if isinstance(row.parent_ids, list) else []

    def text(key: str, fallback: str = "") -> str:
        # Coerce a meta value to str, treating None/"" (any falsy) as missing.
        return str(meta.get(key) or fallback)

    return WireCommit(
        commit_id=row.commit_id,
        repo_id=row.repo_id,
        branch=row.branch or "",
        snapshot_id=row.snapshot_id,
        message=row.message or "",
        committed_at=row.timestamp.isoformat() if row.timestamp else "",
        parent_commit_id=parents[0] if parents else None,
        parent2_commit_id=parents[1] if len(parents) > 1 else None,
        author=row.author or "",
        metadata=_str_values(meta.get("metadata")),
        structured_delta=meta.get("structured_delta"),
        sem_ver_bump=text("sem_ver_bump", "none"),
        breaking_changes=_str_list(meta.get("breaking_changes")),
        agent_id=text("agent_id"),
        model_id=text("model_id"),
        toolchain_id=text("toolchain_id"),
        prompt_hash=text("prompt_hash"),
        signature=text("signature"),
        signer_key_id=text("signer_key_id"),
        format_version=_int_safe(meta.get("format_version"), default=1),
        reviewed_by=_str_list(meta.get("reviewed_by")),
        test_runs=_int_safe(meta.get("test_runs")),
    )
109
110
def _to_wire_snapshot(row: db.MusehubSnapshot) -> WireSnapshot:
    """Rebuild a WireSnapshot from its DB row for fetch responses."""
    manifest = row.manifest if isinstance(row.manifest, dict) else {}
    created = row.created_at.isoformat() if row.created_at else ""
    return WireSnapshot(
        snapshot_id=row.snapshot_id,
        manifest=manifest,
        created_at=created,
    )
117
118
119 # ── public service functions ────────────────────────────────────────────────────
120
async def wire_refs(
    session: AsyncSession,
    repo_id: str,
) -> WireRefsResponse | None:
    """Return branch heads and repo metadata for the refs endpoint.

    Returns None when the repo does not exist or is soft-deleted.
    """
    repo = await session.get(db.MusehubRepo, repo_id)
    if repo is None or repo.deleted_at is not None:
        return None

    result = await session.execute(
        select(db.MusehubBranch).where(db.MusehubBranch.repo_id == repo_id)
    )
    # Branches with an empty head pointer are omitted from the response.
    heads: dict[str, str] = {
        branch.name: branch.head_commit_id
        for branch in result.scalars().all()
        if branch.head_commit_id
    }

    meta: dict[str, object] = (
        repo.domain_meta if isinstance(repo.domain_meta, dict) else {}
    )

    return WireRefsResponse(
        repo_id=repo_id,
        domain=str(meta.get("domain", "code")),
        default_branch=getattr(repo, "default_branch", None) or "main",
        branch_heads=heads,
    )
157
158
async def wire_push(
    session: AsyncSession,
    repo_id: str,
    req: WirePushRequest,
    pusher_id: str | None = None,
) -> WirePushResponse:
    """Ingest a push bundle from the Muse CLI.

    Steps:
      1. Validate repo exists and pusher has access.
      2. Persist new objects to StorageBackend.
      3. Persist new snapshots to musehub_snapshots.
      4. Persist new commits to musehub_commits.
      5. Update / create branch pointer.
      6. Update repo pushed_at.
      7. Return updated branch_heads.

    Args:
        session: Open async DB session; committed on success, rolled back
            when a non-fast-forward push is rejected.
        repo_id: Target repository primary key.
        req: Push request carrying the bundle, branch name and force flag.
        pusher_id: Authenticated user id; must equal the repo owner.

    Returns:
        WirePushResponse — ``ok=False`` with a reason (repo missing,
        unauthorized, non-fast-forward) or ``ok=True`` plus refreshed
        branch heads and the new remote head.
    """
    repo_row = await session.get(db.MusehubRepo, repo_id)
    if repo_row is None or repo_row.deleted_at is not None:
        return WirePushResponse(ok=False, message="repo not found", branch_heads={})

    # ── Authorization: only the repo owner may push ───────────────────────────
    # Future: expand to a collaborators table with write-permission check.
    if not pusher_id or pusher_id != repo_row.owner_user_id:
        logger.warning(
            "⚠️ Push rejected: pusher=%s is not owner of repo=%s (owner=%s)",
            pusher_id,
            repo_id,
            repo_row.owner_user_id,
        )
        return WirePushResponse(ok=False, message="push rejected: not authorized", branch_heads={})

    backend = get_backend()
    bundle: WireBundle = req.bundle
    branch_name: str = req.branch or "main"

    # Resolve the pusher's public username to use as the author fallback when
    # commits arrive without an --author flag from the CLI.
    _pusher_profile = await session.get(db.MusehubProfile, pusher_id)
    _pusher_username: str = (
        _pusher_profile.username if _pusher_profile is not None else pusher_id or ""
    )

    # ── 1. Objects ────────────────────────────────────────────────────────────
    for wire_obj in bundle.objects:
        # Skip malformed entries rather than failing the whole push.
        if not wire_obj.object_id or not wire_obj.content_b64:
            continue
        existing = await session.get(db.MusehubObject, wire_obj.object_id)
        if existing is not None:
            continue  # already stored — idempotent

        try:
            # Trailing "==" tolerates missing base64 padding; the decoder
            # ignores surplus padding after a complete payload.
            raw = base64.b64decode(wire_obj.content_b64 + "==")
        except Exception as exc:
            # Best-effort: log and skip undecodable objects.
            logger.warning("Failed to decode object %s: %s", wire_obj.object_id, exc)
            continue

        # NOTE(review): bytes written to the backend here are NOT undone by the
        # rollback below on a non-fast-forward rejection — they become orphaned
        # (but content-addressed, so a retry reuses them). Confirm intentional.
        storage_uri = await backend.put(repo_id, wire_obj.object_id, raw)
        obj_row = db.MusehubObject(
            object_id=wire_obj.object_id,
            repo_id=repo_id,
            path=wire_obj.path or "",
            size_bytes=len(raw),
            disk_path=storage_uri.replace("local://", ""),
            storage_uri=storage_uri,
        )
        session.add(obj_row)

    # ── 2. Snapshots ──────────────────────────────────────────────────────────
    for wire_snap in bundle.snapshots:
        if not wire_snap.snapshot_id:
            continue
        existing_snap = await session.get(db.MusehubSnapshot, wire_snap.snapshot_id)
        if existing_snap is not None:
            continue  # idempotent, same as objects
        snap_row = db.MusehubSnapshot(
            snapshot_id=wire_snap.snapshot_id,
            repo_id=repo_id,
            manifest=wire_snap.manifest,
            created_at=_parse_iso(wire_snap.created_at) if wire_snap.created_at else _utc_now(),
        )
        session.add(snap_row)

    # ── 3. Commits ────────────────────────────────────────────────────────────
    # Parents are inserted before children so new_head ends on a tip commit.
    ordered_commits = _topological_sort(bundle.commits)
    new_head: str | None = None

    for wire_commit in ordered_commits:
        if not wire_commit.commit_id:
            continue
        existing_commit = await session.get(db.MusehubCommit, wire_commit.commit_id)
        if existing_commit is not None:
            # Already on the server: still advances new_head so a re-push of
            # known commits can move the branch pointer.
            new_head = wire_commit.commit_id
            continue

        # parent_ids is stored as an ordered list: [first parent, second parent].
        parent_ids: list[str] = []
        if wire_commit.parent_commit_id:
            parent_ids.append(wire_commit.parent_commit_id)
        if wire_commit.parent2_commit_id:
            parent_ids.append(wire_commit.parent2_commit_id)

        # Provenance/review metadata travels in one JSON blob; _to_wire_commit
        # unpacks the same keys on the fetch path.
        commit_meta: dict[str, object] = {
            "metadata": wire_commit.metadata,
            "structured_delta": wire_commit.structured_delta,
            "sem_ver_bump": wire_commit.sem_ver_bump,
            "breaking_changes": wire_commit.breaking_changes,
            "agent_id": wire_commit.agent_id,
            "model_id": wire_commit.model_id,
            "toolchain_id": wire_commit.toolchain_id,
            "prompt_hash": wire_commit.prompt_hash,
            "signature": wire_commit.signature,
            "signer_key_id": wire_commit.signer_key_id,
            "format_version": wire_commit.format_version,
            "reviewed_by": wire_commit.reviewed_by,
            "test_runs": wire_commit.test_runs,
        }

        # Fall back to the pusher's username when the CLI didn't supply --author.
        author = wire_commit.author or _pusher_username
        commit_row = db.MusehubCommit(
            commit_id=wire_commit.commit_id,
            repo_id=repo_id,
            branch=branch_name,
            parent_ids=parent_ids,
            message=wire_commit.message,
            author=author,
            timestamp=_parse_iso(wire_commit.committed_at) if wire_commit.committed_at else _utc_now(),
            snapshot_id=wire_commit.snapshot_id,
            commit_meta=commit_meta,
        )
        session.add(commit_row)
        new_head = wire_commit.commit_id

    # Reached only when every bundle commit had a blank commit_id.
    # NOTE(review): this fallback uses the client's list order, not topological
    # order — confirm the CLI always sends the tip last.
    if new_head is None and bundle.commits:
        new_head = bundle.commits[-1].commit_id

    # ── 4. Branch pointer ─────────────────────────────────────────────────────
    branch_row = (
        await session.execute(
            select(db.MusehubBranch).where(
                db.MusehubBranch.repo_id == repo_id,
                db.MusehubBranch.name == branch_name,
            )
        )
    ).scalar_one_or_none()

    if branch_row is None:
        branch_row = db.MusehubBranch(
            repo_id=repo_id,
            name=branch_name,
            head_commit_id=new_head or "",
        )
        session.add(branch_row)
    else:
        # fast-forward check: existing HEAD must appear in the parent chain of
        # any pushed commit. BFS both parents so merge commits are handled.
        if not req.force and branch_row.head_commit_id:
            if not _is_ancestor_in_bundle(branch_row.head_commit_id, bundle.commits):
                # Drop every row staged above (objects/snapshots/commits);
                # backend bytes already written stay behind (see NOTE above).
                await session.rollback()
                return WirePushResponse(
                    ok=False,
                    message=(
                        f"non-fast-forward push to '{branch_name}' — "
                        "use force=true to overwrite"
                    ),
                    branch_heads={},
                )
        branch_row.head_commit_id = new_head or branch_row.head_commit_id

    # ── 5. Update repo ────────────────────────────────────────────────────────
    repo_row.pushed_at = _utc_now()
    # Only set default_branch on the very first push (no other branches exist).
    # Never overwrite on subsequent pushes — doing so would silently corrupt the
    # published default every time anyone pushes a non-default branch.
    other_branches_q = await session.execute(
        select(db.MusehubBranch).where(
            db.MusehubBranch.repo_id == repo_id,
            db.MusehubBranch.name != branch_name,
        )
    )
    if other_branches_q.scalars().first() is None:
        repo_row.default_branch = branch_name

    await session.commit()

    # Re-fetch branch heads after commit
    branch_rows = (
        await session.execute(
            select(db.MusehubBranch).where(db.MusehubBranch.repo_id == repo_id)
        )
    ).scalars().all()
    branch_heads = {b.name: b.head_commit_id for b in branch_rows if b.head_commit_id}

    return WirePushResponse(
        ok=True,
        message=f"pushed {len(bundle.commits)} commit(s) to '{branch_name}'",
        branch_heads=branch_heads,
        remote_head=new_head or "",
    )
358
359
async def wire_push_objects(
    session: AsyncSession,
    repo_id: str,
    req: WireObjectsRequest,
    pusher_id: str | None = None,
) -> WireObjectsResponse:
    """Pre-upload one batch of objects ahead of a chunked push.

    Phase 1 of a two-phase push: the client splits its objects into batches
    of ≤ MAX_OBJECTS_PER_PUSH and calls this once per batch; Phase 2 is a
    ``wire_push`` whose bundle carries only the (small) commits/snapshots.

    Objects are content-addressed, so re-uploading an existing object is
    harmless — it is skipped and counted in ``skipped``. Authorization
    mirrors ``wire_push``: only the repo owner may upload.

    Raises:
        ValueError: repo missing or soft-deleted.
        PermissionError: pusher is not the repo owner.
    """
    repo = await session.get(db.MusehubRepo, repo_id)
    if repo is None or repo.deleted_at is not None:
        raise ValueError("repo not found")

    if not pusher_id or pusher_id != repo.owner_user_id:
        logger.warning(
            "⚠️ push/objects rejected: pusher=%s is not owner of repo=%s",
            pusher_id,
            repo_id,
        )
        raise PermissionError("push rejected: not authorized")

    backend = get_backend()
    counts = {"stored": 0, "skipped": 0}

    for obj in req.objects:
        # Malformed entries are silently ignored.
        if not obj.object_id or not obj.content_b64:
            continue
        if await session.get(db.MusehubObject, obj.object_id) is not None:
            counts["skipped"] += 1
            continue

        try:
            # Surplus "==" tolerates missing base64 padding.
            raw = base64.b64decode(obj.content_b64 + "==")
        except Exception as exc:
            logger.warning("Failed to decode object %s: %s", obj.object_id, exc)
            continue

        storage_uri = await backend.put(repo_id, obj.object_id, raw)
        session.add(
            db.MusehubObject(
                object_id=obj.object_id,
                repo_id=repo_id,
                path=obj.path or "",
                size_bytes=len(raw),
                disk_path=storage_uri.replace("local://", ""),
                storage_uri=storage_uri,
            )
        )
        counts["stored"] += 1

    await session.commit()
    logger.info(
        "✅ push/objects repo=%s stored=%d skipped=%d",
        repo_id,
        counts["stored"],
        counts["skipped"],
    )
    return WireObjectsResponse(stored=counts["stored"], skipped=counts["skipped"])
424
425
async def _fetch_commit(
    session: AsyncSession,
    commit_id: str,
) -> db.MusehubCommit | None:
    """Look up one commit row via its primary key (no full-table scan)."""
    row = await session.get(db.MusehubCommit, commit_id)
    return row
432
433
async def wire_fetch(
    session: AsyncSession,
    repo_id: str,
    req: WireFetchRequest,
) -> WireFetchResponse | None:
    """Return the minimal set of commits/snapshots/objects to satisfy ``want``.

    BFS from each ``want`` commit toward its ancestors, stopping at any commit
    in ``have`` (already on client) or when a commit is missing (orphan).

    Each BFS level is fetched with one batched ``SELECT … WHERE commit_id IN``
    rather than a full ``SELECT … WHERE repo_id = ?`` table scan, so memory
    stays proportional to the *delta* (commits the client needs) rather than
    the total repository history.

    Returns None when the repo does not exist or is soft-deleted.
    """
    repo_row = await session.get(db.MusehubRepo, repo_id)
    if repo_row is None or repo_row.deleted_at is not None:
        return None

    have_set = set(req.have)
    want_set = set(req.want)

    # BFS from want → collect commits not in have.
    # Each BFS level is fetched in a single batched SELECT … WHERE commit_id IN (…)
    # rather than one query per commit, keeping N round-trips proportional to
    # the *depth* of the delta (number of BFS levels) rather than its *width*.
    needed_rows: dict[str, db.MusehubCommit] = {}
    frontier: set[str] = want_set - have_set
    visited: set[str] = set()

    while frontier:
        # Remove already-visited from this level's batch
        batch = frontier - visited
        if not batch:
            break
        visited.update(batch)

        # One IN query for the entire frontier level
        rows_q = await session.execute(
            select(db.MusehubCommit).where(
                db.MusehubCommit.commit_id.in_(batch),
                db.MusehubCommit.repo_id == repo_id,
            )
        )
        next_frontier: set[str] = set()
        for row in rows_q.scalars().all():
            needed_rows[row.commit_id] = row
            # Parents missing from the DB (orphans) simply never enter
            # needed_rows; the walk dies out there.
            for pid in (row.parent_ids or []):
                if pid not in visited and pid not in have_set:
                    next_frontier.add(pid)
        frontier = next_frontier

    # dict preserves insertion order, so commits go out in BFS discovery order.
    needed_commits = [_to_wire_commit(needed_rows[cid]) for cid in needed_rows]

    # Collect snapshot IDs we need to send
    snap_ids = {c.snapshot_id for c in needed_commits if c.snapshot_id}
    wire_snapshots: list[WireSnapshot] = []
    if snap_ids:
        snap_rows_q = await session.execute(
            select(db.MusehubSnapshot).where(db.MusehubSnapshot.snapshot_id.in_(snap_ids))
        )
        for sr in snap_rows_q.scalars().all():
            wire_snapshots.append(_to_wire_snapshot(sr))

    # Collect object IDs referenced by snapshots.
    # NOTE(review): assumes manifest maps path → object_id (flat str values);
    # if manifest values are ever richer structures this collects the wrong
    # ids silently — confirm against the manifest writer.
    all_obj_ids: set[str] = set()
    for ws in wire_snapshots:
        all_obj_ids.update(ws.manifest.values())

    wire_objects: list[WireObject] = []
    backend = get_backend()
    if all_obj_ids:
        obj_rows_q = await session.execute(
            select(db.MusehubObject).where(
                db.MusehubObject.repo_id == repo_id,
                db.MusehubObject.object_id.in_(all_obj_ids),
            )
        )
        for obj_row in obj_rows_q.scalars().all():
            # One backend round-trip per object; blobs missing from storage
            # are skipped rather than failing the whole fetch.
            raw = await backend.get(repo_id, obj_row.object_id)
            if raw is None:
                continue
            wire_objects.append(WireObject(
                object_id=obj_row.object_id,
                content_b64=base64.b64encode(raw).decode(),
                path=obj_row.path or "",
            ))

    # Current branch heads
    branch_rows_q = await session.execute(
        select(db.MusehubBranch).where(db.MusehubBranch.repo_id == repo_id)
    )
    branch_heads = {
        b.name: b.head_commit_id
        for b in branch_rows_q.scalars().all()
        if b.head_commit_id
    }

    return WireFetchResponse(
        commits=needed_commits,
        snapshots=wire_snapshots,
        objects=wire_objects,
        branch_heads=branch_heads,
    )
538
539
540 # ── private helpers ────────────────────────────────────────────────────────────
541
542 def _topological_sort(commits: list[WireCommit]) -> list[WireCommit]:
543 """Sort commits so parents come before children (Kahn's algorithm)."""
544 if not commits:
545 return []
546 by_id = {c.commit_id: c for c in commits}
547 in_degree: dict[str, int] = {c.commit_id: 0 for c in commits}
548 children: dict[str, list[str]] = {c.commit_id: [] for c in commits}
549
550 for c in commits:
551 for pid in filter(None, [c.parent_commit_id, c.parent2_commit_id]):
552 if pid in by_id:
553 in_degree[c.commit_id] += 1
554 children[pid].append(c.commit_id)
555
556 queue = [cid for cid, deg in in_degree.items() if deg == 0]
557 result: list[WireCommit] = []
558 while queue:
559 cid = queue.pop(0)
560 result.append(by_id[cid])
561 for child_id in children.get(cid, []):
562 in_degree[child_id] -= 1
563 if in_degree[child_id] == 0:
564 queue.append(child_id)
565
566 # Append any commits that topological sort couldn't place (cycles/missing parents)
567 sorted_ids = {c.commit_id for c in result}
568 for c in commits:
569 if c.commit_id not in sorted_ids:
570 result.append(c)
571
572 return result
573
574
575 def _is_ancestor_in_bundle(head_id: str, commits: list[WireCommit]) -> bool:
576 """Return True if ``head_id`` appears in the ancestor graph of any pushed commit.
577
578 BFS both parents so merge commits are handled correctly: a merge commit
579 has parent_commit_id (first parent, the branch being merged into) and
580 parent2_commit_id (second parent, the branch being merged from). The
581 remote HEAD may be either parent.
582
583 The walk stops when it leaves the bundle (parent not found in commit_by_id)
584 since commits outside the bundle already exist on the server — if the
585 remote HEAD is outside the bundle, the client intentionally excluded it
586 via ``have``, which means it IS an ancestor (incremental push).
587 """
588 commit_by_id = {c.commit_id: c for c in commits}
589 visited: set[str] = set()
590 frontier: list[str] = [c.commit_id for c in commits]
591
592 while frontier:
593 cid = frontier.pop()
594 if cid in visited:
595 continue
596 visited.add(cid)
597 if cid == head_id:
598 return True
599 row = commit_by_id.get(cid)
600 if row is None:
601 # This commit is outside the bundle (already on server).
602 # The client used ``have`` to exclude it, meaning they consider
603 # this commit an ancestor they share — keep walking is pointless.
604 continue
605 for pid in filter(None, [row.parent_commit_id, row.parent2_commit_id]):
606 if pid not in visited:
607 frontier.append(pid)
608
609 return False