gabriel / muse public
forecast.py python
254 lines 9.7 KB
e6786943 feat: upgrade to Python 3.14, drop from __future__ import annotations Gabriel Cardona <cgcardona@gmail.com> 5d ago
1 """muse forecast — predict merge conflicts before they happen.
2
3 Reads all active reservations and intents across branches, then uses the
4 reverse call graph to compute *likely* conflicts before any code is written.
5
6 This turns merge conflict resolution from a reactive ("it broke") problem into
7 a proactive ("we predicted it") workflow — essential when many agents operate
8 on a codebase simultaneously.
9
10 Conflict types detected
11 -----------------------
12 ``address_overlap``
13 Two agents have reserved the same symbol address. Direct collision.
14
15 ``blast_radius_overlap``
16 Agent A's reserved symbol is in the call chain of Agent B's target, or
17 vice versa. A change to A's symbol will affect B's symbol.
18
19 ``operation_conflict``
20 Agent A intends to delete/rename a symbol that Agent B intends to modify.
21 Classic use-after-free / use-after-rename semantic conflict.
22
23 Usage::
24
25 muse forecast
26 muse forecast --branch feature-x
27 muse forecast --json
28
29 Output::
30
31 Conflict forecast — 3 active reservations, 1 intent
32 ──────────────────────────────────────────────────────────────
33
34 ⚠️ address_overlap (confidence 1.00)
35 src/billing.py::compute_total
36 agent-41 (branch: main) ↔ agent-42 (branch: feature/billing)
37
38 ⚠️ blast_radius_overlap (confidence 0.75)
39 agent-42 reserved src/billing.py::compute_total
40 agent-39 reserved src/api.py::process_payment
41 → compute_total is in the call chain of process_payment
42
43 No operation conflicts detected.
44
45 1 high-risk, 1 medium-risk, 0 low-risk conflict(s)
46
47 Flags:
48
49 ``--branch BRANCH``
50 Filter to reservations on a specific branch.
51
52 ``--json``
53 Emit the full forecast as JSON.
54 """
55
56 import json
57 import logging
58 import pathlib
59
60 import typer
61
62 from muse.core.coordination import active_reservations, load_all_intents
63 from muse.core.repo import require_repo
64 from muse.core.store import get_commit_snapshot_manifest, resolve_commit_ref
65 from muse.plugins.code._callgraph import build_reverse_graph, transitive_callers
66
# Module-level logger named after this module; used by ``forecast`` to report
# (at debug level) when the call graph cannot be built.
logger = logging.getLogger(__name__)

# Typer application object; ``forecast`` is registered on it as the callback
# that runs when no sub-command is given.
app = typer.Typer()
70
71
72 class _ConflictPrediction:
73 def __init__(
74 self,
75 conflict_type: str,
76 addresses: list[str],
77 agents: list[str],
78 confidence: float,
79 description: str,
80 ) -> None:
81 self.conflict_type = conflict_type
82 self.addresses = addresses
83 self.agents = agents
84 self.confidence = confidence
85 self.description = description
86
87 def to_dict(self) -> dict[str, str | float | list[str]]:
88 return {
89 "conflict_type": self.conflict_type,
90 "addresses": self.addresses,
91 "agents": self.agents,
92 "confidence": round(self.confidence, 3),
93 "description": self.description,
94 }
95
96
def _read_repo_id(root: pathlib.Path) -> str:
    """Return the ``repo_id`` recorded in ``<root>/.muse/repo.json``.

    The file is decoded explicitly as UTF-8 rather than with the platform's
    locale encoding, so the result does not depend on the machine it runs on.

    Args:
        root: Repository root directory containing the ``.muse`` folder.

    Returns:
        The ``repo_id`` value, coerced to ``str``.

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the JSON object has no ``"repo_id"`` key.
    """
    repo_json = root / ".muse" / "repo.json"
    return str(json.loads(repo_json.read_text(encoding="utf-8"))["repo_id"])
99
100
def _read_branch(root: pathlib.Path) -> str:
    """Return the branch name stored in ``<root>/.muse/HEAD``.

    The file content is stripped of surrounding whitespace and a leading
    ``refs/heads/`` prefix is removed when present.  The file is decoded
    explicitly as UTF-8 so non-ASCII branch names are read the same way on
    every platform instead of depending on the locale encoding.

    Args:
        root: Repository root directory containing the ``.muse`` folder.

    Returns:
        The branch name with no ``refs/heads/`` prefix and no surrounding
        whitespace.

    Raises:
        OSError: If the HEAD file cannot be read.
    """
    head_ref = (root / ".muse" / "HEAD").read_text(encoding="utf-8").strip()
    # Second strip guards against whitespace left directly after the prefix.
    return head_ref.removeprefix("refs/heads/").strip()
104
105
@app.callback(invoke_without_command=True)
def forecast(
    ctx: typer.Context,
    branch_filter: str | None = typer.Option(
        None, "--branch", "-b", metavar="BRANCH",
        help="Restrict to reservations on this branch.",
    ),
    as_json: bool = typer.Option(False, "--json", help="Emit forecast as JSON."),
) -> None:
    """Predict merge conflicts from active reservations and intents.

    Reads ``.muse/coordination/reservations/`` and ``intents/``, then:

    1. Reports direct address-level overlaps (confidence 1.0).
    2. Computes blast-radius overlaps using the Python call graph.
    3. Reports operation-type conflicts (delete vs modify on same address).

    Use ``muse reserve`` and ``muse intent`` to register your agent's work
    plan before this command becomes useful.
    """
    # NOTE: the docstring above doubles as the CLI help text, so parameter
    # documentation lives in these comments instead:
    #   ctx           -- Typer context; accepted but not used in this body.
    #   branch_filter -- when set, only reservations/intents whose ``branch``
    #                    equals this value are considered.
    #   as_json       -- emit a JSON document instead of the text report.
    root = require_repo()
    repo_id = _read_repo_id(root)
    branch = _read_branch(root)

    reservations = active_reservations(root)
    intents = load_all_intents(root)

    if branch_filter:
        reservations = [r for r in reservations if r.branch == branch_filter]
        intents = [i for i in intents if i.branch == branch_filter]

    conflicts: list[_ConflictPrediction] = []

    # ── Direct address overlap ─────────────────────────────────────────────
    # Build: address → list of "run_id@branch" labels that reserved it.
    addr_agents: dict[str, list[str]] = {}
    for res in reservations:
        for addr in res.addresses:
            addr_agents.setdefault(addr, []).append(f"{res.run_id}@{res.branch}")

    for addr, agents in sorted(addr_agents.items()):
        # dict.fromkeys de-duplicates while keeping first-seen order.
        unique_agents = list(dict.fromkeys(agents))
        if len(unique_agents) > 1:
            conflicts.append(_ConflictPrediction(
                conflict_type="address_overlap",
                addresses=[addr],
                agents=unique_agents,
                confidence=1.0,
                description=f"{addr} reserved by {len(unique_agents)} agents simultaneously",
            ))

    # ── Blast-radius overlap ───────────────────────────────────────────────
    # Use the call graph to check whether any two reserved addresses are in
    # each other's transitive call chains.  The graph is built from the
    # snapshot of the current HEAD branch.
    commit = resolve_commit_ref(root, repo_id, branch, None)
    if commit is not None:
        manifest = get_commit_snapshot_manifest(root, commit.commit_id) or {}
        try:
            reverse = build_reverse_graph(root, manifest)
            all_addresses = list(addr_agents)

            # Flatten the per-depth caller lists into one set per address so
            # each unordered pair can be tested in BOTH directions below.
            callers_of: dict[str, set[str]] = {}
            for addr in all_addresses:
                levels = transitive_callers(addr, reverse, max_depth=5)
                callers_of[addr] = {c for lvl in levels.values() for c in lvl}

            for i, addr_a in enumerate(all_addresses):
                for addr_b in all_addresses[i + 1:]:
                    # Checking only one direction would make detection depend
                    # on reservation order; a mutual relationship is reported
                    # once (first matching direction wins).
                    if addr_b in callers_of[addr_a]:
                        callee, caller = addr_a, addr_b
                    elif addr_a in callers_of[addr_b]:
                        callee, caller = addr_b, addr_a
                    else:
                        continue

                    callee_agents = addr_agents[callee]
                    caller_agents = addr_agents[caller]
                    # Identical agent sets on both addresses: nobody else is
                    # affected, so this is not reported.
                    if set(callee_agents) == set(caller_agents):
                        continue

                    conflicts.append(_ConflictPrediction(
                        conflict_type="blast_radius_overlap",
                        addresses=[callee, caller],
                        # Ordered de-duplication keeps output stable between
                        # runs (a set union would not).
                        agents=list(dict.fromkeys(callee_agents + caller_agents)),
                        confidence=0.75,
                        description=(
                            f"{caller} is in the transitive call chain of {callee}"
                        ),
                    ))
        except Exception as exc:  # noqa: BLE001
            # Best effort: the forecast still reports the other conflict
            # types when the call graph cannot be built or walked.
            logger.debug("Call graph unavailable for forecast: %s", exc)

    # ── Operation conflicts ────────────────────────────────────────────────
    # Collect intents by address.
    intent_ops: dict[str, list[str]] = {}  # address → list of operations
    intent_agents: dict[str, list[str]] = {}  # address → list of run_ids
    for it in intents:
        for addr in it.addresses:
            intent_ops.setdefault(addr, []).append(it.operation)
            intent_agents.setdefault(addr, []).append(f"{it.run_id}@{it.branch}")

    for addr, ops in sorted(intent_ops.items()):
        if len(set(ops)) <= 1 and len(set(intent_agents.get(addr, []))) <= 1:
            continue  # Same op by same agent — not a conflict.
        has_delete = "delete" in ops
        has_modify = any(op in ("modify", "rename", "extract") for op in ops)
        if has_delete and has_modify:
            agents = list(dict.fromkeys(intent_agents.get(addr, [])))
            conflicts.append(_ConflictPrediction(
                conflict_type="operation_conflict",
                addresses=[addr],
                agents=agents,
                confidence=0.9,
                description=f"delete vs modify conflict on {addr}",
            ))

    # Risk buckets: high >= 0.9, medium in [0.5, 0.9), low < 0.5.
    high = sum(1 for c in conflicts if c.confidence >= 0.9)
    med = sum(1 for c in conflicts if 0.5 <= c.confidence < 0.9)
    low = sum(1 for c in conflicts if c.confidence < 0.5)

    if as_json:
        typer.echo(json.dumps(
            {
                "schema_version": 1,
                "active_reservations": len(reservations),
                "intents": len(intents),
                "branch_filter": branch_filter,
                "conflicts": [c.to_dict() for c in conflicts],
                "high_risk": high,
                "medium_risk": med,
                "low_risk": low,
            },
            indent=2,
        ))
        return

    typer.echo(
        f"\nConflict forecast — "
        f"{len(reservations)} active reservation(s), {len(intents)} intent(s)"
    )
    typer.echo("─" * 62)

    if not conflicts:
        typer.echo("\n ✅ No conflicts predicted.")
        if not reservations:
            typer.echo(" (no active reservations — run 'muse reserve' first)")
        return

    for c in conflicts:
        icon = "🔴" if c.confidence >= 0.9 else "⚠️ "
        typer.echo(f"\n{icon} {c.conflict_type} (confidence {c.confidence:.2f})")
        for addr in c.addresses:
            typer.echo(f" {addr}")
        for agent in c.agents:
            typer.echo(f" agent: {agent}")
        typer.echo(f" → {c.description}")

    typer.echo(
        f"\n {high} high-risk, {med} medium-risk conflict(s) predicted"
    )
    typer.echo(" Run 'muse plan-merge' for a detailed merge strategy.")
254 typer.echo(" Run 'muse plan-merge' for a detailed merge strategy.")