gabriel / muse public
forecast.py python
256 lines 9.8 KB
bda49bdb feat: redesign .museignore as TOML with domain-scoped sections (#100) Gabriel Cardona <cgcardona@gmail.com> 5d ago
1 """muse forecast — predict merge conflicts before they happen.
2
3 Reads all active reservations and intents across branches, then uses the
4 reverse call graph to compute *likely* conflicts before any code is written.
5
6 This turns merge conflict resolution from a reactive ("it broke") problem into
7 a proactive ("we predicted it") workflow — essential when many agents operate
8 on a codebase simultaneously.
9
10 Conflict types detected
11 -----------------------
12 ``address_overlap``
13 Two agents have reserved the same symbol address. Direct collision.
14
15 ``blast_radius_overlap``
16 Agent A's reserved symbol is in the call chain of Agent B's target, or
17 vice versa. A change to A's symbol will affect B's symbol.
18
19 ``operation_conflict``
20 Agent A intends to delete/rename a symbol that Agent B intends to modify.
21 Classic use-after-free / use-after-rename semantic conflict.
22
23 Usage::
24
25 muse forecast
26 muse forecast --branch feature-x
27 muse forecast --json
28
29 Output::
30
31 Conflict forecast — 3 active reservations, 1 intent
32 ──────────────────────────────────────────────────────────────
33
34 ⚠️ address_overlap (confidence 1.00)
35 src/billing.py::compute_total
36 agent-41 (branch: main) ↔ agent-42 (branch: feature/billing)
37
38 ⚠️ blast_radius_overlap (confidence 0.75)
39 agent-42 reserved src/billing.py::compute_total
40 agent-39 reserved src/api.py::process_payment
41 → compute_total is in the call chain of process_payment
42
43 No operation conflicts detected.
44
45 1 high-risk, 1 medium-risk, 0 low-risk conflict(s)
46
47 Flags:
48
49 ``--branch BRANCH``
50 Filter to reservations on a specific branch.
51
52 ``--json``
53 Emit the full forecast as JSON.
54 """
55
56 from __future__ import annotations
57
58 import json
59 import logging
60 import pathlib
61
62 import typer
63
64 from muse.core.coordination import active_reservations, load_all_intents
65 from muse.core.repo import require_repo
66 from muse.core.store import get_commit_snapshot_manifest, resolve_commit_ref
67 from muse.plugins.code._callgraph import build_reverse_graph, transitive_callers
68
# Module-level logger named after this module; used for debug diagnostics
# when the call graph cannot be built.
logger = logging.getLogger(__name__)

# Typer application object; ``forecast`` is registered on it as the callback.
app = typer.Typer()
72
73
class _ConflictPrediction:
    """One predicted conflict between agents working concurrently.

    Instances are plain records: the conflict category, the symbol addresses
    involved, the agent labels involved, a confidence score and a
    human-readable description.
    """

    def __init__(
        self,
        conflict_type: str,
        addresses: list[str],
        agents: list[str],
        confidence: float,
        description: str,
    ) -> None:
        """Store every argument unchanged on the instance."""
        (
            self.conflict_type,
            self.addresses,
            self.agents,
            self.confidence,
            self.description,
        ) = (conflict_type, addresses, agents, confidence, description)

    def to_dict(self) -> dict[str, str | float | list[str]]:
        """Return the prediction as a mapping suitable for ``json.dumps``.

        The confidence is rounded to three decimal places; all other fields
        are passed through as stored.
        """
        payload: dict[str, str | float | list[str]] = {}
        payload["conflict_type"] = self.conflict_type
        payload["addresses"] = self.addresses
        payload["agents"] = self.agents
        payload["confidence"] = round(self.confidence, 3)
        payload["description"] = self.description
        return payload
97
98
def _read_repo_id(root: pathlib.Path) -> str:
    """Return the ``repo_id`` recorded in ``<root>/.muse/repo.json``.

    The file is decoded explicitly as UTF-8 (the JSON interchange encoding)
    instead of the platform's locale encoding, so the result does not depend
    on the machine the command runs on.

    Args:
        root: Repository root directory containing the ``.muse`` folder.

    Returns:
        The value stored under ``"repo_id"``, converted to ``str``.

    Raises:
        OSError: If ``repo.json`` cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If the JSON object has no ``"repo_id"`` key.
    """
    repo_json = root / ".muse" / "repo.json"
    metadata = json.loads(repo_json.read_text(encoding="utf-8"))
    return str(metadata["repo_id"])
101
102
def _read_branch(root: pathlib.Path) -> str:
    """Return the branch name stored in ``<root>/.muse/HEAD``.

    Surrounding whitespace is stripped from the file content, a leading
    ``refs/heads/`` is dropped when present, and the remainder is stripped
    once more before being returned.
    """
    head_path = root / ".muse" / "HEAD"
    contents = head_path.read_text().strip()

    prefix = "refs/heads/"
    if contents.startswith(prefix):
        contents = contents[len(prefix):]
    return contents.strip()
106
107
def _address_overlaps(
    addr_agents: dict[str, list[str]],
) -> list[_ConflictPrediction]:
    """Return one ``address_overlap`` per address reserved by several agents."""
    found: list[_ConflictPrediction] = []
    for addr, agents in sorted(addr_agents.items()):
        # De-duplicate while keeping first-seen order so output is stable.
        unique_agents = list(dict.fromkeys(agents))
        if len(unique_agents) > 1:
            found.append(_ConflictPrediction(
                conflict_type="address_overlap",
                addresses=[addr],
                agents=unique_agents,
                confidence=1.0,
                description=f"{addr} reserved by {len(unique_agents)} agents simultaneously",
            ))
    return found


def _blast_radius_overlaps(
    addr_agents: dict[str, list[str]],
    reverse,
) -> list[_ConflictPrediction]:
    """Return ``blast_radius_overlap`` predictions for reserved address pairs.

    A pair is reported when either address appears among the transitive
    callers (depth <= 5) of the other and the two addresses are not reserved
    by exactly the same set of agents. Both directions are checked, so the
    result does not depend on the order in which addresses were collected.

    ``reverse`` is the reverse call graph as returned by
    ``build_reverse_graph``; it is passed through to ``transitive_callers``.
    """
    addresses = list(addr_agents)

    # Compute each address's transitive caller set once, up front.
    callers_of: dict[str, set[str]] = {}
    for addr in addresses:
        levels = transitive_callers(addr, reverse, max_depth=5)
        callers_of[addr] = {c for lvl in levels.values() for c in lvl}

    found: list[_ConflictPrediction] = []
    for i, addr_a in enumerate(addresses):
        for addr_b in addresses[i + 1:]:
            if addr_b in callers_of[addr_a]:
                callee, caller = addr_a, addr_b
            elif addr_a in callers_of[addr_b]:
                callee, caller = addr_b, addr_a
            else:
                continue

            agents_a = addr_agents[addr_a]
            agents_b = addr_agents[addr_b]
            if set(agents_a) == set(agents_b):
                continue  # Same agents hold both symbols — nothing to flag.

            found.append(_ConflictPrediction(
                conflict_type="blast_radius_overlap",
                addresses=[callee, caller],
                # Ordered de-duplication: a set union would make the agent
                # order vary between runs under hash randomisation.
                agents=list(dict.fromkeys(agents_a + agents_b)),
                confidence=0.75,
                description=(
                    f"{caller} is in the transitive call chain of {callee}"
                ),
            ))
    return found


def _operation_conflicts(intents) -> list[_ConflictPrediction]:
    """Return ``operation_conflict`` predictions from declared intents.

    An address is reported when its intents include a ``delete`` together
    with at least one of ``modify``, ``rename`` or ``extract``.
    """
    ops_by_addr: dict[str, list[str]] = {}
    agents_by_addr: dict[str, list[str]] = {}
    for it in intents:
        for addr in it.addresses:
            ops_by_addr.setdefault(addr, []).append(it.operation)
            agents_by_addr.setdefault(addr, []).append(f"{it.run_id}@{it.branch}")

    found: list[_ConflictPrediction] = []
    for addr, ops in sorted(ops_by_addr.items()):
        has_delete = "delete" in ops
        has_modify = any(op in ("modify", "rename", "extract") for op in ops)
        # Both flags being true already implies two distinct operations, so
        # no separate "same op by same agent" guard is needed.
        # NOTE(review): a single agent declaring both operations on one
        # address is reported too — confirm whether that is intended.
        if has_delete and has_modify:
            found.append(_ConflictPrediction(
                conflict_type="operation_conflict",
                addresses=[addr],
                agents=list(dict.fromkeys(agents_by_addr[addr])),
                confidence=0.9,
                description=f"delete vs modify conflict on {addr}",
            ))
    return found


def _risk_counts(conflicts: list[_ConflictPrediction]) -> tuple[int, int, int]:
    """Return ``(high, medium, low)`` counts: >= 0.9, [0.5, 0.9), < 0.5."""
    high = sum(1 for c in conflicts if c.confidence >= 0.9)
    medium = sum(1 for c in conflicts if 0.5 <= c.confidence < 0.9)
    low = sum(1 for c in conflicts if c.confidence < 0.5)
    return high, medium, low


@app.callback(invoke_without_command=True)
def forecast(
    ctx: typer.Context,
    branch_filter: str | None = typer.Option(
        None, "--branch", "-b", metavar="BRANCH",
        help="Restrict to reservations on this branch.",
    ),
    as_json: bool = typer.Option(False, "--json", help="Emit forecast as JSON."),
) -> None:
    """Predict merge conflicts from active reservations and intents.

    Reads ``.muse/coordination/reservations/`` and ``intents/``, then:

    1. Reports direct address-level overlaps (confidence 1.0).
    2. Computes blast-radius overlaps using the Python call graph.
    3. Reports operation-type conflicts (delete vs modify on same address).

    Use ``muse reserve`` and ``muse intent`` to register your agent's work
    plan before this command becomes useful.
    """
    # The docstring above doubles as the command's help text, so extra
    # implementation notes live in comments rather than in the docstring.
    root = require_repo()
    repo_id = _read_repo_id(root)
    branch = _read_branch(root)

    reservations = active_reservations(root)
    intents = load_all_intents(root)

    if branch_filter:
        reservations = [r for r in reservations if r.branch == branch_filter]
        intents = [i for i in intents if i.branch == branch_filter]

    # address → list of "run_id@branch" labels that reserved it.
    addr_agents: dict[str, list[str]] = {}
    for res in reservations:
        for addr in res.addresses:
            addr_agents.setdefault(addr, []).append(f"{res.run_id}@{res.branch}")

    # ── Direct address overlap ─────────────────────────────────────────────
    conflicts: list[_ConflictPrediction] = _address_overlaps(addr_agents)

    # ── Blast-radius overlap ───────────────────────────────────────────────
    # The call graph is built from the snapshot of the current branch's
    # commit. This step is best-effort: if the graph cannot be built or
    # queried, the forecast continues without blast-radius predictions.
    commit = resolve_commit_ref(root, repo_id, branch, None)
    if commit is not None:
        manifest = get_commit_snapshot_manifest(root, commit.commit_id) or {}
        try:
            reverse = build_reverse_graph(root, manifest)
            conflicts.extend(_blast_radius_overlaps(addr_agents, reverse))
        except Exception as exc:  # noqa: BLE001
            logger.debug("Call graph unavailable for forecast: %s", exc)

    # ── Operation conflicts ────────────────────────────────────────────────
    conflicts.extend(_operation_conflicts(intents))

    high, med, low = _risk_counts(conflicts)

    if as_json:
        typer.echo(json.dumps(
            {
                "schema_version": 1,
                "active_reservations": len(reservations),
                "intents": len(intents),
                "branch_filter": branch_filter,
                "conflicts": [c.to_dict() for c in conflicts],
                "high_risk": high,
                "medium_risk": med,
                "low_risk": low,
            },
            indent=2,
        ))
        return

    typer.echo(
        f"\nConflict forecast — "
        f"{len(reservations)} active reservation(s), {len(intents)} intent(s)"
    )
    typer.echo("─" * 62)

    if not conflicts:
        typer.echo("\n ✅ No conflicts predicted.")
        if not reservations:
            typer.echo(" (no active reservations — run 'muse reserve' first)")
        return

    for c in conflicts:
        icon = "🔴" if c.confidence >= 0.9 else "⚠️ "
        typer.echo(f"\n{icon} {c.conflict_type} (confidence {c.confidence:.2f})")
        for addr in c.addresses:
            typer.echo(f" {addr}")
        for agent in c.agents:
            typer.echo(f" agent: {agent}")
        typer.echo(f" → {c.description}")

    typer.echo(
        f"\n {high} high-risk, {med} medium-risk conflict(s) predicted"
    )
    typer.echo(" Run 'muse plan-merge' for a detailed merge strategy.")