gabriel / musehub public
musehub_pull_requests.py python
530 lines 17.0 KB
cd448303 Initial extraction of MuseHub from maestro monorepo. Gabriel Cardona <gabriel@tellurstori.com> 7d ago
1 """Muse Hub pull request persistence adapter — single point of DB access for PRs.
2
3 This module is the ONLY place that touches the ``musehub_pull_requests`` table.
4 Route handlers delegate here; no business logic lives in routes.
5
6 Boundary rules:
7 - Must NOT import state stores, SSE queues, or LLM clients.
8 - Must NOT import musehub.core.* modules.
9 - May import ORM models from musehub.db.musehub_models.
10 - May import Pydantic response models from musehub.models.musehub.
11
12 Merge strategy
13 --------------
14 ``merge_commit`` is the only strategy at MVP. It creates a new commit on
15 ``to_branch`` whose parent_ids are [to_branch head, from_branch head], then
16 updates the ``to_branch`` head pointer and marks the PR as merged.
17
18 If either branch has no commits yet (no head commit), the merge is rejected with
19 a ``ValueError`` — there is nothing to merge.
20 """
21 from __future__ import annotations
22
23 import logging
24 import uuid
25 from datetime import datetime, timezone
26
27 from sqlalchemy import select
28 from sqlalchemy.ext.asyncio import AsyncSession
29
30 from musehub.db import musehub_models as db
31 from musehub.models.musehub import (
32 PRCommentListResponse,
33 PRCommentResponse,
34 PRResponse,
35 PRReviewListResponse,
36 PRReviewResponse,
37 )
38
39 logger = logging.getLogger(__name__)
40
41
42 def _utc_now() -> datetime:
43 return datetime.now(tz=timezone.utc)
44
45
def _to_pr_response(row: db.MusehubPullRequest) -> PRResponse:
    """Convert a pull-request ORM row into its ``PRResponse`` wire model.

    Each listed attribute is copied verbatim from the row onto the response
    field of the same name.
    """
    field_names = (
        "pr_id",
        "title",
        "body",
        "state",
        "from_branch",
        "to_branch",
        "merge_commit_id",
        "merged_at",
        "author",
        "created_at",
    )
    return PRResponse(**{name: getattr(row, name) for name in field_names})
59
60
async def _get_branch(
    session: AsyncSession, repo_id: str, branch_name: str
) -> db.MusehubBranch | None:
    """Look up a branch by repository ID and branch name.

    Returns the matching ``MusehubBranch`` row, or ``None`` when no branch
    with that name exists in the repo.
    """
    branch_model = db.MusehubBranch
    result = await session.execute(
        select(branch_model).where(
            branch_model.repo_id == repo_id,
            branch_model.name == branch_name,
        )
    )
    return result.scalar_one_or_none()
70
71
async def create_pr(
    session: AsyncSession,
    *,
    repo_id: str,
    title: str,
    from_branch: str,
    to_branch: str,
    body: str = "",
    author: str = "",
) -> PRResponse:
    """Open a new pull request and return its wire representation.

    The PR is persisted in the ``open`` state. ``author`` identifies whoever
    opens it — typically the JWT ``sub`` claim from the request token, or a
    display name from the seed script. ``to_branch`` is not validated here.

    Raises:
        ValueError: ``from_branch`` does not exist in the repo; the caller
            should surface this as HTTP 404.
    """
    if await _get_branch(session, repo_id, from_branch) is None:
        raise ValueError(f"Branch '{from_branch}' not found in repo {repo_id}")

    new_pr = db.MusehubPullRequest(
        repo_id=repo_id,
        title=title,
        body=body,
        state="open",
        from_branch=from_branch,
        to_branch=to_branch,
        author=author,
    )
    session.add(new_pr)
    await session.flush()
    # Reload the row from the DB so DB-populated columns are present before
    # it is serialised.
    await session.refresh(new_pr)
    logger.info("✅ Created PR '%s' (%s → %s) in repo %s", title, from_branch, to_branch, repo_id)
    return _to_pr_response(new_pr)
108
109
async def list_prs(
    session: AsyncSession,
    repo_id: str,
    *,
    state: str = "all",
) -> list[PRResponse]:
    """List a repo's pull requests, oldest first (by ``created_at``).

    ``state`` may be ``"open"``, ``"merged"``, ``"closed"``, or ``"all"``.
    Any value other than ``"all"`` is used as an exact-match filter on the
    PR state.
    """
    pr_model = db.MusehubPullRequest
    filters = [pr_model.repo_id == repo_id]
    if state != "all":
        filters.append(pr_model.state == state)

    query = select(pr_model).where(*filters).order_by(pr_model.created_at)
    result = await session.execute(query)
    return [_to_pr_response(row) for row in result.scalars()]
128
129
async def get_pr(
    session: AsyncSession,
    repo_id: str,
    pr_id: str,
) -> PRResponse | None:
    """Fetch one PR of ``repo_id`` by ``pr_id``; ``None`` if there is no match."""
    pr_model = db.MusehubPullRequest
    result = await session.execute(
        select(pr_model).where(
            pr_model.repo_id == repo_id,
            pr_model.pr_id == pr_id,
        )
    )
    found = result.scalar_one_or_none()
    return None if found is None else _to_pr_response(found)
144
145
async def merge_pr(
    session: AsyncSession,
    repo_id: str,
    pr_id: str,
    *,
    merge_strategy: str = "merge_commit",
) -> PRResponse:
    """Merge an open PR using the given strategy.

    Creates a merge commit on ``to_branch`` whose ``parent_ids`` are
    ``[to_branch head, from_branch head]``. The ``to_branch`` head is omitted
    when that branch has no commits yet. The function then advances, or
    creates, the ``to_branch`` head pointer and marks the PR as ``merged``.

    ``merged_at`` and the merge commit's ``timestamp`` share a single UTC
    instant. This lets the timeline overlay position the merge marker at the
    actual merge time rather than at the PR creation date.

    ``merge_strategy`` is accepted for forward compatibility. Only
    ``merge_commit`` is implemented, and the value is not otherwise inspected.

    Raises:
        ValueError: The PR is not found, or ``from_branch`` does not exist or
            has no commits (there is nothing to merge).
        RuntimeError: The PR is already merged or closed (caller surfaces as
            409).
    """
    stmt = select(db.MusehubPullRequest).where(
        db.MusehubPullRequest.repo_id == repo_id,
        db.MusehubPullRequest.pr_id == pr_id,
    )
    pr = (await session.execute(stmt)).scalar_one_or_none()
    if pr is None:
        raise ValueError(f"Pull request {pr_id} not found in repo {repo_id}")

    if pr.state != "open":
        raise RuntimeError(f"Pull request {pr_id} is already {pr.state}")

    # The source branch must exist and have a head commit; otherwise there is
    # nothing to merge. A missing or empty to_branch is allowed; it is
    # created or initialised below.
    from_b = await _get_branch(session, repo_id, pr.from_branch)
    if from_b is None:
        raise ValueError(
            f"Cannot merge: branch '{pr.from_branch}' not found in repo {repo_id}"
        )
    if from_b.head_commit_id is None:
        raise ValueError(f"Cannot merge: '{pr.from_branch}' has no commits")

    to_b = await _get_branch(session, repo_id, pr.to_branch)

    # Parent order is [to_branch head, from_branch head].
    parent_ids: list[str] = []
    if to_b is not None and to_b.head_commit_id is not None:
        parent_ids.append(to_b.head_commit_id)
    parent_ids.append(from_b.head_commit_id)

    # One instant for both the commit and the PR, so they never disagree.
    merged_at = _utc_now()

    # Create the merge commit on to_branch (32-char hex id, no dashes).
    merge_commit_id = uuid.uuid4().hex
    merge_commit = db.MusehubCommit(
        commit_id=merge_commit_id,
        repo_id=repo_id,
        branch=pr.to_branch,
        parent_ids=parent_ids,
        message=f"Merge '{pr.from_branch}' into '{pr.to_branch}' — PR: {pr.title}",
        author="musehub-server",
        timestamp=merged_at,
    )
    session.add(merge_commit)

    # Advance (or create) the to_branch head pointer.
    if to_b is None:
        to_b = db.MusehubBranch(
            repo_id=repo_id,
            name=pr.to_branch,
            head_commit_id=merge_commit_id,
        )
        session.add(to_b)
    else:
        to_b.head_commit_id = merge_commit_id

    # Mark the PR as merged and record the exact merge timestamp.
    pr.state = "merged"
    pr.merge_commit_id = merge_commit_id
    pr.merged_at = merged_at

    await session.flush()
    await session.refresh(pr)
    logger.info(
        "✅ Merged PR %s ('%s' → '%s') in repo %s, merge commit %s",
        pr_id,
        pr.from_branch,
        pr.to_branch,
        repo_id,
        merge_commit_id,
    )
    return _to_pr_response(pr)
231
232
233 # ---------------------------------------------------------------------------
234 # PR review comments
235 # ---------------------------------------------------------------------------
236
237
def _to_comment_response(row: db.MusehubPRComment) -> PRCommentResponse:
    """Convert a PR review-comment ORM row into its ``PRCommentResponse`` wire model.

    The response's ``replies`` field is not populated here.
    ``list_pr_comments`` fills it when assembling threads.
    """
    field_names = (
        "comment_id",
        "pr_id",
        "author",
        "body",
        "target_type",
        "target_track",
        "target_beat_start",
        "target_beat_end",
        "target_note_pitch",
        "parent_comment_id",
        "created_at",
    )
    return PRCommentResponse(**{name: getattr(row, name) for name in field_names})
252
253
async def create_pr_comment(
    session: AsyncSession,
    *,
    pr_id: str,
    repo_id: str,
    author: str,
    body: str,
    target_type: str = "general",
    target_track: str | None = None,
    target_beat_start: float | None = None,
    target_beat_end: float | None = None,
    target_note_pitch: int | None = None,
    parent_comment_id: str | None = None,
) -> PRCommentResponse:
    """Store a new review comment on a PR and return its wire representation.

    ``author`` is the JWT ``sub`` claim of the reviewer. For a threaded reply,
    ``parent_comment_id`` must name an existing top-level comment on the same
    PR. That constraint is the caller's responsibility and is not checked here.

    Raises:
        ValueError: The PR does not exist in the given repo.
    """
    await _assert_pr_exists(session, repo_id, pr_id)

    new_comment = db.MusehubPRComment(
        pr_id=pr_id,
        repo_id=repo_id,
        author=author,
        body=body,
        target_type=target_type,
        target_track=target_track,
        target_beat_start=target_beat_start,
        target_beat_end=target_beat_end,
        target_note_pitch=target_note_pitch,
        parent_comment_id=parent_comment_id,
    )
    session.add(new_comment)
    await session.flush()
    # Reload so DB-populated columns (the id used in the log line below) are set.
    await session.refresh(new_comment)
    logger.info("✅ Created PR comment %s on PR %s by %s", new_comment.comment_id, pr_id, author)
    return _to_comment_response(new_comment)
302
303
async def list_pr_comments(
    session: AsyncSession,
    pr_id: str,
    repo_id: str,
) -> PRCommentListResponse:
    """Return a PR's review comments arranged as a two-level thread tree.

    Comments whose ``parent_comment_id`` is ``None`` are the roots. Every other
    comment is appended to its parent's ``replies`` list. Roots and replies
    both follow ``created_at`` ascending order. A reply whose parent is not
    among the fetched comments is omitted from the tree. Grandchildren are not
    supported; callers should reply to the original top-level comment.

    ``total`` counts every fetched comment, at all levels.
    """
    comment_model = db.MusehubPRComment
    query = (
        select(comment_model)
        .where(
            comment_model.pr_id == pr_id,
            comment_model.repo_id == repo_id,
        )
        .order_by(comment_model.created_at)
    )
    rows = (await session.execute(query)).scalars().all()

    # Convert every row up front so replies can reference any parent,
    # whatever order the rows arrive in.
    responses = {row.comment_id: _to_comment_response(row) for row in rows}

    for row in rows:
        if row.parent_comment_id is None:
            continue
        parent = responses.get(row.parent_comment_id)
        if parent is not None:
            parent.replies.append(responses[row.comment_id])

    roots = [responses[row.comment_id] for row in rows if row.parent_comment_id is None]
    return PRCommentListResponse(comments=roots, total=len(rows))
344
345
346 # ---------------------------------------------------------------------------
347 # PR reviews (reviewer assignment + approval workflow)
348 # ---------------------------------------------------------------------------
349
350
def _to_review_response(row: db.MusehubPRReview) -> PRReviewResponse:
    """Convert a PR review ORM row into its ``PRReviewResponse`` wire model."""
    field_names = (
        "id",
        "pr_id",
        "reviewer_username",
        "state",
        "body",
        "submitted_at",
        "created_at",
    )
    return PRReviewResponse(**{name: getattr(row, name) for name in field_names})
361
362
async def _assert_pr_exists(session: AsyncSession, repo_id: str, pr_id: str) -> None:
    """Ensure pull request ``pr_id`` exists in repo ``repo_id``.

    Raises:
        ValueError: No matching pull-request row was found.
    """
    pr_model = db.MusehubPullRequest
    query = select(pr_model).where(
        pr_model.pr_id == pr_id,
        pr_model.repo_id == repo_id,
    )
    found = (await session.execute(query)).scalar_one_or_none()
    if found is not None:
        return
    raise ValueError(f"Pull request {pr_id} not found in repo {repo_id}")
372
373
async def request_reviewers(
    session: AsyncSession,
    *,
    repo_id: str,
    pr_id: str,
    reviewers: list[str],
) -> PRReviewListResponse:
    """Add reviewer assignments to a PR, creating a ``pending`` row for each.

    This is idempotent. If a reviewer already has a row (in any state), that
    row is left unchanged, so a re-request never resets a submitted approval.
    Duplicate usernames in ``reviewers`` are collapsed; first occurrence wins
    the ordering.

    Raises:
        ValueError: The PR does not exist in the repo.

    Returns:
        The full, updated review list for the PR.
    """
    await _assert_pr_exists(session, repo_id, pr_id)

    # De-duplicate while keeping caller order, so a single call can never
    # queue two rows for the same reviewer.
    requested = list(dict.fromkeys(reviewers))

    if requested:
        # One round trip for all requested names instead of one SELECT per
        # reviewer.
        existing_stmt = select(db.MusehubPRReview.reviewer_username).where(
            db.MusehubPRReview.pr_id == pr_id,
            db.MusehubPRReview.reviewer_username.in_(requested),
        )
        already_assigned = set((await session.execute(existing_stmt)).scalars().all())

        for username in requested:
            if username in already_assigned:
                continue
            session.add(
                db.MusehubPRReview(pr_id=pr_id, reviewer_username=username, state="pending")
            )
            logger.info("✅ Requested review from '%s' on PR %s", username, pr_id)

    await session.flush()
    return await list_reviews(session, repo_id=repo_id, pr_id=pr_id)
405
406
async def remove_reviewer(
    session: AsyncSession,
    *,
    repo_id: str,
    pr_id: str,
    username: str,
) -> PRReviewListResponse:
    """Withdraw the pending review request for ``username`` on a PR.

    Only rows still in the ``pending`` state can be deleted. Submitted reviews
    are kept immutable to preserve the audit trail.

    Raises:
        ValueError: The PR does not exist, ``username`` was never requested, or
            that reviewer has already submitted a non-pending review.

    Returns:
        The updated review list for the PR.
    """
    await _assert_pr_exists(session, repo_id, pr_id)

    review_model = db.MusehubPRReview
    lookup = select(review_model).where(
        review_model.pr_id == pr_id,
        review_model.reviewer_username == username,
    )
    review = (await session.execute(lookup)).scalar_one_or_none()

    if review is None:
        raise ValueError(f"Reviewer '{username}' was not requested on PR {pr_id}")
    if review.state != "pending":
        raise ValueError(
            f"Cannot remove reviewer '{username}': review already submitted (state={review.state})"
        )

    await session.delete(review)
    await session.flush()
    logger.info("✅ Removed review request for '%s' from PR %s", username, pr_id)
    return await list_reviews(session, repo_id=repo_id, pr_id=pr_id)
442
443
async def list_reviews(
    session: AsyncSession,
    *,
    repo_id: str,
    pr_id: str,
    state: str | None = None,
) -> PRReviewListResponse:
    """List a PR's reviews in ``created_at`` order, optionally filtered by state.

    ``state`` may be one of ``pending``, ``approved``, ``changes_requested``,
    or ``dismissed``. When it is ``None``, every review is returned.

    Raises:
        ValueError: The PR does not exist in the repo.
    """
    await _assert_pr_exists(session, repo_id, pr_id)

    review_model = db.MusehubPRReview
    criteria = [review_model.pr_id == pr_id]
    if state is not None:
        criteria.append(review_model.state == state)

    query = select(review_model).where(*criteria).order_by(review_model.created_at)
    result = await session.execute(query)
    reviews = list(map(_to_review_response, result.scalars()))
    return PRReviewListResponse(reviews=reviews, total=len(reviews))
467
468
# Review ``event`` -> state recorded on a newly created review row.
# An existing row keeps its current state on a ``comment`` event
# (see submit_review).
_REVIEW_EVENT_TO_STATE: dict[str, str] = {
    "approve": "approved",
    "request_changes": "changes_requested",
    "comment": "pending",
}


async def submit_review(
    session: AsyncSession,
    *,
    repo_id: str,
    pr_id: str,
    reviewer_username: str,
    event: str,
    body: str = "",
) -> PRReviewResponse:
    """Submit or update a formal review for ``reviewer_username`` on a PR.

    ``event`` selects the outcome:

    - ``approve`` → state ``approved``, ``submitted_at`` set to now
    - ``request_changes`` → state ``changes_requested``, ``submitted_at`` set to now
    - ``comment`` → body-only, with no verdict change. An existing review keeps
      its current state and ``submitted_at``; a brand-new row starts as
      ``pending`` with no ``submitted_at``.

    An existing row for this reviewer is updated in place. If none exists
    (the reviewer was not formally requested), a new row is created, so
    ad-hoc reviews are allowed. In both cases ``body`` replaces the stored
    body; an empty string is stored as ``None``.

    Raises:
        ValueError: The PR does not exist in the repo, or ``event`` is not
            one of ``approve``, ``request_changes``, or ``comment``.
    """
    await _assert_pr_exists(session, repo_id, pr_id)

    new_state = _REVIEW_EVENT_TO_STATE.get(event)
    if new_state is None:
        raise ValueError(
            f"Unsupported review event '{event}'; expected one of: "
            + ", ".join(_REVIEW_EVENT_TO_STATE)
        )
    # Only approve / request_changes carry a verdict; a comment must never
    # overwrite an approval or change request that was already submitted.
    is_verdict = event != "comment"

    stmt = select(db.MusehubPRReview).where(
        db.MusehubPRReview.pr_id == pr_id,
        db.MusehubPRReview.reviewer_username == reviewer_username,
    )
    row = (await session.execute(stmt)).scalar_one_or_none()

    now = _utc_now()
    if row is None:
        row = db.MusehubPRReview(
            pr_id=pr_id,
            reviewer_username=reviewer_username,
            state=new_state,
            body=body or None,
            submitted_at=now if is_verdict else None,
        )
        session.add(row)
    else:
        row.body = body or None
        if is_verdict:
            row.state = new_state
            row.submitted_at = now

    await session.flush()
    await session.refresh(row)
    logger.info(
        "✅ Review submitted by '%s' on PR %s: event=%s state=%s",
        reviewer_username,
        pr_id,
        event,
        row.state,
    )
    return _to_review_response(row)