gabriel / musehub public
musehub_wire.py python
754 lines 28.2 KB
4a63ff4c feat(mwp): full msgpack wire protocol — replace JSON+base64 on all push… Gabriel Cardona <cgcardona@gmail.com> 12h ago
1 """Wire protocol service — bridges Muse CLI push/fetch to MuseHub storage.
2
3 This module translates between:
4 Muse CLI native format (snake_case CommitDict / SnapshotDict / ObjectPayload)
5 MuseHub DB / storage (SQLAlchemy ORM + StorageBackend)
6
7 Entry points:
8 wire_refs(session, repo_id) → WireRefsResponse
9 wire_push(session, repo_id, req, pusher_id) → WirePushResponse
10 wire_fetch(session, repo_id, want, have) → WireFetchResponse
11
12 Design decisions:
13 - Non-fast-forward pushes are rejected by default; ``force=True`` blows
14 away the branch pointer (equivalent to ``git push --force``).
15 - Snapshot manifests are stored verbatim in musehub_snapshots.manifest.
- Object bytes arrive as raw binary on the msgpack wire and are handed to the StorageBackend.
17 - After a successful push, callers are expected to fire background tasks
18 for Qdrant embedding and event fan-out.
19 """
20 from __future__ import annotations
21
22 import logging
23 from datetime import datetime, timezone
24
25 from sqlalchemy import select
26 from sqlalchemy.ext.asyncio import AsyncSession
27
28 from musehub.db import musehub_models as db
29 from musehub.models.wire import (
30 WireCommit,
31 WireBundle,
32 WireFetchRequest,
33 WireFetchResponse,
34 WireFilterRequest,
35 WireFilterResponse,
36 WireNegotiateRequest,
37 WireNegotiateResponse,
38 WireObject,
39 WireObjectsRequest,
40 WireObjectsResponse,
41 WirePresignRequest,
42 WirePresignResponse,
43 WirePushRequest,
44 WirePushResponse,
45 WireRefsResponse,
46 WireSnapshot,
47 )
48 from musehub.storage import get_backend
49
50 logger = logging.getLogger(__name__)
51
52
53 # ── helpers ────────────────────────────────────────────────────────────────────
54
55 def _utc_now() -> datetime:
56 return datetime.now(tz=timezone.utc)
57
58
59 def _parse_iso(s: str) -> datetime:
60 """Parse an ISO-8601 string; fall back to now() on failure."""
61 try:
62 return datetime.fromisoformat(s.replace("Z", "+00:00"))
63 except (ValueError, AttributeError):
64 return _utc_now()
65
66
67 def _str_values(d: object) -> dict[str, str]:
68 """Safely coerce a dict with unknown value types to ``dict[str, str]``."""
69 if not isinstance(d, dict):
70 return {}
71 return {str(k): str(v) for k, v in d.items()}
72
73
74 def _str_list(v: object) -> list[str]:
75 """Safely coerce a list with unknown element types to ``list[str]``."""
76 if not isinstance(v, list):
77 return []
78 return [str(x) for x in v]
79
80
81 def _int_safe(v: object, default: int = 0) -> int:
82 """Return *v* as an int when it is numeric; fall back to *default*."""
83 return int(v) if isinstance(v, (int, float)) else default
84
85
def _to_wire_commit(row: db.MusehubCommit) -> WireCommit:
    """Rehydrate a stored commit row into the CLI-facing WireCommit shape."""
    meta: dict[str, object] = row.commit_meta if isinstance(row.commit_meta, dict) else {}
    parents: list[str] = row.parent_ids if isinstance(row.parent_ids, list) else []

    def meta_str(key: str, fallback: str = "") -> str:
        # Missing or falsy meta values collapse to the fallback string.
        return str(meta.get(key) or fallback)

    return WireCommit(
        commit_id=row.commit_id,
        repo_id=row.repo_id,
        branch=row.branch or "",
        snapshot_id=row.snapshot_id,
        message=row.message or "",
        committed_at=row.timestamp.isoformat() if row.timestamp else "",
        parent_commit_id=parents[0] if parents else None,
        parent2_commit_id=parents[1] if len(parents) > 1 else None,
        author=row.author or "",
        metadata=_str_values(meta.get("metadata")),
        structured_delta=meta.get("structured_delta"),
        sem_ver_bump=meta_str("sem_ver_bump", "none"),
        breaking_changes=_str_list(meta.get("breaking_changes")),
        agent_id=meta_str("agent_id"),
        model_id=meta_str("model_id"),
        toolchain_id=meta_str("toolchain_id"),
        prompt_hash=meta_str("prompt_hash"),
        signature=meta_str("signature"),
        signer_key_id=meta_str("signer_key_id"),
        format_version=_int_safe(meta.get("format_version"), default=1),
        reviewed_by=_str_list(meta.get("reviewed_by")),
        test_runs=_int_safe(meta.get("test_runs")),
    )
114
115
def _to_wire_snapshot(row: db.MusehubSnapshot) -> WireSnapshot:
    """Rehydrate a stored snapshot row into the wire format."""
    manifest = row.manifest
    if not isinstance(manifest, dict):
        manifest = {}
    created = row.created_at.isoformat() if row.created_at else ""
    return WireSnapshot(
        snapshot_id=row.snapshot_id,
        manifest=manifest,
        created_at=created,
    )
122
123
124 # ── public service functions ────────────────────────────────────────────────────
125
async def wire_refs(
    session: AsyncSession,
    repo_id: str,
) -> WireRefsResponse | None:
    """Build the refs-endpoint payload: branch heads plus repo metadata.

    Returns ``None`` when the repo does not exist or is soft-deleted.
    """
    repo = await session.get(db.MusehubRepo, repo_id)
    if repo is None or repo.deleted_at is not None:
        return None

    result = await session.execute(
        select(db.MusehubBranch).where(db.MusehubBranch.repo_id == repo_id)
    )
    # Branches without a head commit (never pushed) are omitted.
    heads: dict[str, str] = {}
    for branch in result.scalars().all():
        if branch.head_commit_id:
            heads[branch.name] = branch.head_commit_id

    meta: dict[str, object] = (
        repo.domain_meta if isinstance(repo.domain_meta, dict) else {}
    )
    return WireRefsResponse(
        repo_id=repo_id,
        domain=str(meta.get("domain", "code")),
        default_branch=getattr(repo, "default_branch", None) or "main",
        branch_heads=heads,
    )
162
163
async def wire_push(
    session: AsyncSession,
    repo_id: str,
    req: WirePushRequest,
    pusher_id: str | None = None,
) -> WirePushResponse:
    """Ingest a push bundle from the Muse CLI.

    Steps:
        1. Validate repo exists and pusher has access.
        2. Persist new objects to StorageBackend.
        3. Persist new snapshots to musehub_snapshots.
        4. Persist new commits to musehub_commits.
        5. Update / create branch pointer.
        6. Update repo pushed_at.
        7. Return updated branch_heads.

    Args:
        session: open async DB session; committed on success, rolled back
            when a non-fast-forward push is rejected.
        repo_id: target repository primary key.
        req: push request carrying branch name, force flag, and the bundle.
        pusher_id: authenticated user id; must equal the repo owner.

    Returns:
        WirePushResponse — ``ok=False`` with a reason message on any
        rejection (missing repo, not authorized, non-fast-forward), else
        ``ok=True`` with refreshed branch heads and the new remote head.
    """
    repo_row = await session.get(db.MusehubRepo, repo_id)
    if repo_row is None or repo_row.deleted_at is not None:
        return WirePushResponse(ok=False, message="repo not found", branch_heads={})

    # ── Authorization: only the repo owner may push ───────────────────────────
    # Future: expand to a collaborators table with write-permission check.
    if not pusher_id or pusher_id != repo_row.owner_user_id:
        logger.warning(
            "⚠️ Push rejected: pusher=%s is not owner of repo=%s (owner=%s)",
            pusher_id,
            repo_id,
            repo_row.owner_user_id,
        )
        return WirePushResponse(ok=False, message="push rejected: not authorized", branch_heads={})

    backend = get_backend()
    bundle: WireBundle = req.bundle
    branch_name: str = req.branch or "main"

    # Resolve the pusher's public username to use as the author fallback when
    # commits arrive without an --author flag from the CLI.
    _pusher_profile = await session.get(db.MusehubProfile, pusher_id)
    _pusher_username: str = (
        _pusher_profile.username if _pusher_profile is not None else pusher_id or ""
    )

    # ── 1. Objects ────────────────────────────────────────────────────────────
    # Objects are content-addressed: an existing row implies identical bytes,
    # so re-pushed objects are skipped (idempotent).
    for wire_obj in bundle.objects:
        if not wire_obj.object_id or not wire_obj.content:
            continue
        existing = await session.get(db.MusehubObject, wire_obj.object_id)
        if existing is not None:
            continue  # already stored — idempotent

        # NOTE(review): backend.put runs before the DB transaction commits, so
        # blob writes are NOT undone by a later rollback. Presumably harmless
        # because IDs are content-addressed — confirm.
        raw = wire_obj.content
        storage_uri = await backend.put(repo_id, wire_obj.object_id, raw)
        obj_row = db.MusehubObject(
            object_id=wire_obj.object_id,
            repo_id=repo_id,
            path=wire_obj.path or "",
            size_bytes=len(raw),
            disk_path=storage_uri.replace("local://", ""),
            storage_uri=storage_uri,
        )
        session.add(obj_row)

    # ── 2. Snapshots ──────────────────────────────────────────────────────────
    # Manifests are stored verbatim; timestamps fall back to now() when the
    # CLI omits (or garbles) created_at.
    for wire_snap in bundle.snapshots:
        if not wire_snap.snapshot_id:
            continue
        existing_snap = await session.get(db.MusehubSnapshot, wire_snap.snapshot_id)
        if existing_snap is not None:
            continue
        snap_row = db.MusehubSnapshot(
            snapshot_id=wire_snap.snapshot_id,
            repo_id=repo_id,
            manifest=wire_snap.manifest,
            created_at=_parse_iso(wire_snap.created_at) if wire_snap.created_at else _utc_now(),
        )
        session.add(snap_row)

    # ── 3. Commits ────────────────────────────────────────────────────────────
    # Insert parents before children; the last commit processed becomes the
    # candidate branch head (already-known commits still advance the head).
    ordered_commits = _topological_sort(bundle.commits)
    new_head: str | None = None

    for wire_commit in ordered_commits:
        if not wire_commit.commit_id:
            continue
        existing_commit = await session.get(db.MusehubCommit, wire_commit.commit_id)
        if existing_commit is not None:
            new_head = wire_commit.commit_id
            continue

        # First parent = branch merged into; second parent = branch merged from.
        parent_ids: list[str] = []
        if wire_commit.parent_commit_id:
            parent_ids.append(wire_commit.parent_commit_id)
        if wire_commit.parent2_commit_id:
            parent_ids.append(wire_commit.parent2_commit_id)

        # CLI-native fields with no dedicated column ride along in the
        # commit_meta JSON blob; _to_wire_commit unpacks them on fetch.
        commit_meta: dict[str, object] = {
            "metadata": wire_commit.metadata,
            "structured_delta": wire_commit.structured_delta,
            "sem_ver_bump": wire_commit.sem_ver_bump,
            "breaking_changes": wire_commit.breaking_changes,
            "agent_id": wire_commit.agent_id,
            "model_id": wire_commit.model_id,
            "toolchain_id": wire_commit.toolchain_id,
            "prompt_hash": wire_commit.prompt_hash,
            "signature": wire_commit.signature,
            "signer_key_id": wire_commit.signer_key_id,
            "format_version": wire_commit.format_version,
            "reviewed_by": wire_commit.reviewed_by,
            "test_runs": wire_commit.test_runs,
        }

        # Fall back to the pusher's username when the CLI didn't supply --author.
        author = wire_commit.author or _pusher_username
        commit_row = db.MusehubCommit(
            commit_id=wire_commit.commit_id,
            repo_id=repo_id,
            branch=branch_name,
            parent_ids=parent_ids,
            message=wire_commit.message,
            author=author,
            timestamp=_parse_iso(wire_commit.committed_at) if wire_commit.committed_at else _utc_now(),
            snapshot_id=wire_commit.snapshot_id,
            commit_meta=commit_meta,
        )
        session.add(commit_row)
        new_head = wire_commit.commit_id

    # Edge case: every bundled commit lacked a usable commit_id — fall back to
    # the client's last commit as the head candidate.
    if new_head is None and bundle.commits:
        new_head = bundle.commits[-1].commit_id

    # ── 4. Branch pointer ─────────────────────────────────────────────────────
    branch_row = (
        await session.execute(
            select(db.MusehubBranch).where(
                db.MusehubBranch.repo_id == repo_id,
                db.MusehubBranch.name == branch_name,
            )
        )
    ).scalar_one_or_none()

    if branch_row is None:
        # First push to this branch — create the pointer outright.
        branch_row = db.MusehubBranch(
            repo_id=repo_id,
            name=branch_name,
            head_commit_id=new_head or "",
        )
        session.add(branch_row)
    else:
        # fast-forward check: existing HEAD must appear in the parent chain of
        # any pushed commit. BFS both parents so merge commits are handled.
        if not req.force and branch_row.head_commit_id:
            if not _is_ancestor_in_bundle(branch_row.head_commit_id, bundle.commits):
                # Reject and discard all staged rows from steps 1-3.
                await session.rollback()
                return WirePushResponse(
                    ok=False,
                    message=(
                        f"non-fast-forward push to '{branch_name}' — "
                        "use force=true to overwrite"
                    ),
                    branch_heads={},
                )
        branch_row.head_commit_id = new_head or branch_row.head_commit_id

    # ── 5. Update repo ────────────────────────────────────────────────────────
    repo_row.pushed_at = _utc_now()
    # Only set default_branch on the very first push (no other branches exist).
    # Never overwrite on subsequent pushes — doing so would silently corrupt the
    # published default every time anyone pushes a non-default branch.
    other_branches_q = await session.execute(
        select(db.MusehubBranch).where(
            db.MusehubBranch.repo_id == repo_id,
            db.MusehubBranch.name != branch_name,
        )
    )
    if other_branches_q.scalars().first() is None:
        repo_row.default_branch = branch_name

    await session.commit()

    # Re-fetch branch heads after commit
    branch_rows = (
        await session.execute(
            select(db.MusehubBranch).where(db.MusehubBranch.repo_id == repo_id)
        )
    ).scalars().all()
    branch_heads = {b.name: b.head_commit_id for b in branch_rows if b.head_commit_id}

    return WirePushResponse(
        ok=True,
        message=f"pushed {len(bundle.commits)} commit(s) to '{branch_name}'",
        branch_heads=branch_heads,
        remote_head=new_head or "",
    )
358
359
async def wire_push_objects(
    session: AsyncSession,
    repo_id: str,
    req: WireObjectsRequest,
    pusher_id: str | None = None,
) -> WireObjectsResponse:
    """Phase 1 of a chunked push: persist one batch of pre-uploaded objects.

    The client splits a large push into batches of ≤ MAX_OBJECTS_PER_PUSH
    and calls this endpoint once per batch; Phase 2 is ``wire_push`` with an
    empty ``bundle.objects`` list carrying only (small) commits/snapshots.

    Objects are content-addressed, so duplicate uploads are harmless —
    already-present objects are skipped and reported via ``skipped``.
    Authorization mirrors ``wire_push``: only the repo owner may upload.

    Raises:
        ValueError: the repo does not exist or is soft-deleted.
        PermissionError: the pusher is not the repo owner.
    """
    repo = await session.get(db.MusehubRepo, repo_id)
    if repo is None or repo.deleted_at is not None:
        raise ValueError("repo not found")

    if not pusher_id or pusher_id != repo.owner_user_id:
        logger.warning(
            "⚠️ push/objects rejected: pusher=%s is not owner of repo=%s",
            pusher_id,
            repo_id,
        )
        raise PermissionError("push rejected: not authorized")

    backend = get_backend()
    stored_count = 0
    skipped_count = 0

    for obj in req.objects:
        # Skip malformed entries outright; count duplicates as skipped.
        if not obj.object_id or not obj.content:
            continue
        if await session.get(db.MusehubObject, obj.object_id) is not None:
            skipped_count += 1
            continue

        payload = obj.content
        uri = await backend.put(repo_id, obj.object_id, payload)
        session.add(
            db.MusehubObject(
                object_id=obj.object_id,
                repo_id=repo_id,
                path=obj.path or "",
                size_bytes=len(payload),
                disk_path=uri.replace("local://", ""),
                storage_uri=uri,
            )
        )
        stored_count += 1

    await session.commit()
    logger.info(
        "✅ push/objects repo=%s stored=%d skipped=%d", repo_id, stored_count, skipped_count
    )
    return WireObjectsResponse(stored=stored_count, skipped=skipped_count)
419
420
async def _fetch_commit(
    session: AsyncSession,
    commit_id: str,
) -> db.MusehubCommit | None:
    """Primary-key lookup for a single commit; ``None`` when absent."""
    row = await session.get(db.MusehubCommit, commit_id)
    return row
427
428
async def wire_fetch(
    session: AsyncSession,
    repo_id: str,
    req: WireFetchRequest,
) -> WireFetchResponse | None:
    """Return the minimal set of commits/snapshots/objects to satisfy ``want``.

    BFS from each ``want`` commit toward its ancestors, stopping at any commit
    in ``have`` (already on client) or when a commit is missing (orphan).

    Commits are loaded level-by-level with batched ``IN (…)`` selects rather
    than a full ``SELECT … WHERE repo_id = ?`` table scan. This keeps memory
    proportional to the *delta* (commits the client needs) rather than the
    total repository history.

    Returns:
        None when the repo is missing or soft-deleted; otherwise a
        WireFetchResponse bundling the delta commits, their snapshots, the
        objects those snapshots reference, and current branch heads.
    """
    repo_row = await session.get(db.MusehubRepo, repo_id)
    if repo_row is None or repo_row.deleted_at is not None:
        return None

    have_set = set(req.have)
    want_set = set(req.want)

    # BFS from want → collect commits not in have.
    # Each BFS level is fetched in a single batched SELECT … WHERE commit_id IN (…)
    # rather than one query per commit, keeping N round-trips proportional to
    # the *depth* of the delta (number of BFS levels) rather than its *width*.
    needed_rows: dict[str, db.MusehubCommit] = {}
    frontier: set[str] = want_set - have_set
    visited: set[str] = set()

    while frontier:
        # Remove already-visited from this level's batch
        batch = frontier - visited
        if not batch:
            break
        visited.update(batch)

        # One IN query for the entire frontier level
        rows_q = await session.execute(
            select(db.MusehubCommit).where(
                db.MusehubCommit.commit_id.in_(batch),
                db.MusehubCommit.repo_id == repo_id,
            )
        )
        next_frontier: set[str] = set()
        for row in rows_q.scalars().all():
            needed_rows[row.commit_id] = row
            for pid in (row.parent_ids or []):
                # Parents the client already holds terminate the walk; IDs
                # missing from the DB simply never enter needed_rows.
                if pid not in visited and pid not in have_set:
                    next_frontier.add(pid)
        frontier = next_frontier

    needed_commits = [_to_wire_commit(needed_rows[cid]) for cid in needed_rows]

    # Collect snapshot IDs we need to send
    snap_ids = {c.snapshot_id for c in needed_commits if c.snapshot_id}
    wire_snapshots: list[WireSnapshot] = []
    if snap_ids:
        snap_rows_q = await session.execute(
            select(db.MusehubSnapshot).where(db.MusehubSnapshot.snapshot_id.in_(snap_ids))
        )
        for sr in snap_rows_q.scalars().all():
            wire_snapshots.append(_to_wire_snapshot(sr))

    # Collect object IDs referenced by snapshots.
    # NOTE(review): assumes each manifest maps path → object_id so .values()
    # are object IDs — consistent with manifests being stored verbatim on
    # push, but confirm against the CLI's manifest schema.
    all_obj_ids: set[str] = set()
    for ws in wire_snapshots:
        all_obj_ids.update(ws.manifest.values())

    wire_objects: list[WireObject] = []
    backend = get_backend()
    if all_obj_ids:
        obj_rows_q = await session.execute(
            select(db.MusehubObject).where(
                db.MusehubObject.repo_id == repo_id,
                db.MusehubObject.object_id.in_(all_obj_ids),
            )
        )
        for obj_row in obj_rows_q.scalars().all():
            raw = await backend.get(repo_id, obj_row.object_id)
            if raw is None:
                # Blob missing from storage — skip rather than fail the fetch.
                continue
            wire_objects.append(WireObject(
                object_id=obj_row.object_id,
                content=raw,
                path=obj_row.path or "",
            ))

    # Current branch heads
    branch_rows_q = await session.execute(
        select(db.MusehubBranch).where(db.MusehubBranch.repo_id == repo_id)
    )
    branch_heads = {
        b.name: b.head_commit_id
        for b in branch_rows_q.scalars().all()
        if b.head_commit_id
    }

    return WireFetchResponse(
        commits=needed_commits,
        snapshots=wire_snapshots,
        objects=wire_objects,
        branch_heads=branch_heads,
    )
533
534
535 # ── MWP/2 service functions ────────────────────────────────────────────────────
536
537
async def wire_filter_objects(
    session: AsyncSession,
    repo_id: str,
    req: WireFilterRequest,
) -> WireFilterResponse:
    """Tell the client which of *req.object_ids* it still needs to upload.

    One ``WHERE object_id IN (…)`` query finds the IDs the server already
    holds; the complement comes back as ``missing`` so the client sends only
    the delta. This is the highest-impact MWP/2 change: incremental pushes
    become proportional to the *change*, not the full history.
    """
    requested = req.object_ids
    if not requested:
        return WireFilterResponse(missing=[])

    rows = await session.execute(
        select(db.MusehubObject.object_id).where(
            db.MusehubObject.repo_id == repo_id,
            db.MusehubObject.object_id.in_(requested),
        )
    )
    held: set[str] = {r[0] for r in rows}
    delta = [oid for oid in requested if oid not in held]
    logger.info(
        "filter-objects repo=%s total=%d missing=%d",
        repo_id,
        len(requested),
        len(delta),
    )
    return WireFilterResponse(missing=delta)
568
569
async def wire_presign(
    session: AsyncSession,
    repo_id: str,
    req: WirePresignRequest,
    pusher_id: str | None = None,
) -> WirePresignResponse:
    """Hand out presigned PUT/GET URLs so large objects bypass the API server.

    Backends without presign support (e.g. the local backend) get every ID
    returned in ``inline``, signalling the client to fall back to the normal
    pack upload path.

    Raises:
        ValueError: the repo does not exist or is soft-deleted.
        PermissionError: a PUT presign was requested by a non-owner.
    """
    repo = await session.get(db.MusehubRepo, repo_id)
    if repo is None or repo.deleted_at is not None:
        raise ValueError("repo not found")

    wants_upload = req.direction == "put"
    if wants_upload and (not pusher_id or pusher_id != repo.owner_user_id):
        raise PermissionError("presign rejected: not authorized")

    backend = get_backend()
    # Local backends do not support presigned URLs — return all as inline.
    can_presign = hasattr(backend, "presign_put") and hasattr(backend, "presign_get")
    if not can_presign:
        return WirePresignResponse(presigned={}, inline=list(req.object_ids))

    urls: dict[str, str] = {}
    for oid in req.object_ids:
        if wants_upload:
            urls[oid] = await backend.presign_put(
                repo_id, oid, ttl_seconds=req.ttl_seconds
            )
        else:
            urls[oid] = await backend.presign_get(
                repo_id, oid, ttl_seconds=req.ttl_seconds
            )

    return WirePresignResponse(presigned=urls, inline=[])
608
609
async def wire_negotiate(
    session: AsyncSession,
    repo_id: str,
    req: WireNegotiateRequest,
) -> WireNegotiateResponse:
    """Multi-round commit negotiation (MWP/2 Phase 5).

    The client sends a depth-limited list of ``have`` commit IDs it already
    holds and the branch tips it ``want``s. The server responds with which
    ``have`` IDs it recognises (``ack``), a shared ancestor found by walking
    back from ``want`` (``common_base`` — the first acked parent the
    breadth-first walk encounters, i.e. the shallowest such ancestor, which
    is sufficient to anchor the delta), and whether that is enough to compute
    the delta without another round (``ready``).

    A single round is almost always sufficient for incremental pulls. Full
    clones of large repos may require 2-3 rounds with increasing depth.

    Raises:
        ValueError: the repo does not exist or is soft-deleted.
    """
    repo_row = await session.get(db.MusehubRepo, repo_id)
    if repo_row is None or repo_row.deleted_at is not None:
        raise ValueError("repo not found")

    have_set = set(req.have)
    want_set = set(req.want)

    # Which of the client's have-IDs does the server recognise?
    # Single IN query; the repo_id guard prevents cross-repo ID matches.
    if have_set:
        ack_q = await session.execute(
            select(db.MusehubCommit.commit_id).where(
                db.MusehubCommit.repo_id == repo_id,
                db.MusehubCommit.commit_id.in_(have_set),
            )
        )
        ack = [row[0] for row in ack_q]
    else:
        ack = []

    # Find an acknowledged ancestor reachable from want.
    # BFS from want, stop when we hit an acked have-ID.
    ack_set = set(ack)
    common_base: str | None = None

    if ack_set and want_set:
        # Level-by-level walk; each level costs one batched SELECT.
        frontier: set[str] = want_set - ack_set
        visited: set[str] = set()
        found = False
        while frontier and not found:
            batch = frontier - visited
            if not batch:
                break
            visited.update(batch)
            rows_q = await session.execute(
                select(db.MusehubCommit).where(
                    db.MusehubCommit.commit_id.in_(batch),
                    db.MusehubCommit.repo_id == repo_id,
                )
            )
            next_frontier: set[str] = set()
            for row in rows_q.scalars().all():
                for pid in (row.parent_ids or []):
                    if pid in ack_set:
                        # First acked parent wins — double break out of both
                        # the parent loop and the row loop via `found`.
                        common_base = pid
                        found = True
                        break
                    if pid not in visited:
                        next_frontier.add(pid)
                if found:
                    break
            frontier = next_frontier

    # ready = True when we know the common base or the client has no have-IDs
    # (full clone — server just sends everything from want).
    ready = common_base is not None or not have_set

    return WireNegotiateResponse(ack=ack, common_base=common_base, ready=ready)
683
684
685 # ── private helpers ────────────────────────────────────────────────────────────
686
687 def _topological_sort(commits: list[WireCommit]) -> list[WireCommit]:
688 """Sort commits so parents come before children (Kahn's algorithm)."""
689 if not commits:
690 return []
691 by_id = {c.commit_id: c for c in commits}
692 in_degree: dict[str, int] = {c.commit_id: 0 for c in commits}
693 children: dict[str, list[str]] = {c.commit_id: [] for c in commits}
694
695 for c in commits:
696 for pid in filter(None, [c.parent_commit_id, c.parent2_commit_id]):
697 if pid in by_id:
698 in_degree[c.commit_id] += 1
699 children[pid].append(c.commit_id)
700
701 queue = [cid for cid, deg in in_degree.items() if deg == 0]
702 result: list[WireCommit] = []
703 while queue:
704 cid = queue.pop(0)
705 result.append(by_id[cid])
706 for child_id in children.get(cid, []):
707 in_degree[child_id] -= 1
708 if in_degree[child_id] == 0:
709 queue.append(child_id)
710
711 # Append any commits that topological sort couldn't place (cycles/missing parents)
712 sorted_ids = {c.commit_id for c in result}
713 for c in commits:
714 if c.commit_id not in sorted_ids:
715 result.append(c)
716
717 return result
718
719
720 def _is_ancestor_in_bundle(head_id: str, commits: list[WireCommit]) -> bool:
721 """Return True if ``head_id`` appears in the ancestor graph of any pushed commit.
722
723 BFS both parents so merge commits are handled correctly: a merge commit
724 has parent_commit_id (first parent, the branch being merged into) and
725 parent2_commit_id (second parent, the branch being merged from). The
726 remote HEAD may be either parent.
727
728 The walk stops when it leaves the bundle (parent not found in commit_by_id)
729 since commits outside the bundle already exist on the server — if the
730 remote HEAD is outside the bundle, the client intentionally excluded it
731 via ``have``, which means it IS an ancestor (incremental push).
732 """
733 commit_by_id = {c.commit_id: c for c in commits}
734 visited: set[str] = set()
735 frontier: list[str] = [c.commit_id for c in commits]
736
737 while frontier:
738 cid = frontier.pop()
739 if cid in visited:
740 continue
741 visited.add(cid)
742 if cid == head_id:
743 return True
744 row = commit_by_id.get(cid)
745 if row is None:
746 # This commit is outside the bundle (already on server).
747 # The client used ``have`` to exclude it, meaning they consider
748 # this commit an ancestor they share — keep walking is pointless.
749 continue
750 for pid in filter(None, [row.parent_commit_id, row.parent2_commit_id]):
751 if pid not in visited:
752 frontier.append(pid)
753
754 return False