midi_merge.py
python
| 1 | """MIDI dimension-aware merge for the Muse music plugin. |
| 2 | |
| 3 | This module implements the multidimensional merge that makes Muse meaningfully |
| 4 | different from git. Git treats every file as an opaque byte sequence: any |
| 5 | two-branch change to the same file is a conflict. Muse understands that a |
| 6 | MIDI file has *independent orthogonal axes*, and two collaborators can touch |
| 7 | different axes of the same file without conflicting. |
| 8 | |
| 9 | Dimensions |
| 10 | ---------- |
| 11 | |
| 12 | +---------------+----------------------------------------------------+ |
| 13 | | Dimension | MIDI event types | |
| 14 | +===============+====================================================+ |
| 15 | | ``melodic`` | ``note_on`` / ``note_off`` (pitch + timing) | |
| 16 | +---------------+----------------------------------------------------+ |
| 17 | | ``rhythmic`` | Alias for ``melodic`` — timing is inseparable from | |
| 18 | | | pitch in the MIDI event model; provided as a | |
| 19 | | | user-facing label in ``.museattributes`` rules. | |
| 20 | +---------------+----------------------------------------------------+ |
| 21 | | ``harmonic`` | ``pitchwheel`` events | |
| 22 | +---------------+----------------------------------------------------+ |
| 23 | | ``dynamic`` | ``control_change`` events | |
| 24 | +---------------+----------------------------------------------------+ |
| 25 | | ``structural``| ``set_tempo``, ``time_signature``, ``key_signature``,| |
| 26 | | | ``program_change``, text/sysex meta events | |
| 27 | +---------------+----------------------------------------------------+ |
| 28 | |
| 29 | Merge algorithm |
| 30 | --------------- |
| 31 | |
| 32 | 1. Parse ``base``, ``left``, and ``right`` MIDI bytes into event streams. |
| 33 | 2. Convert to absolute-tick representation and bucket by dimension. |
| 34 | 3. Hash each bucket; compare ``base ↔ left`` and ``base ↔ right`` to detect |
| 35 | per-dimension changes. |
| 36 | 4. For each dimension apply the winning side determined by ``.museattributes`` |
| 37 | strategy (or the standard one-sided-change rule when no conflict exists). |
| 38 | 5. Reconstruct a valid MIDI file by merging winning dimension slices, sorting |
| 39 | by absolute tick, converting back to delta-time, and writing to bytes. |
| 40 | |
| 41 | Public API |
| 42 | ---------- |
| 43 | |
| 44 | - :func:`extract_dimensions` — parse MIDI bytes → :class:`MidiDimensions` (one :class:`DimensionSlice` per internal dimension) |
| 45 | - :func:`merge_midi_dimensions` — three-way dimension merge → ``(merged_bytes, dimension_report)`` or ``None`` |
| 46 | - :func:`dimension_conflict_detail` — per-dimension change report for logging |
| 47 | """ |
| 48 | from __future__ import annotations |
| 49 | |
| 50 | import hashlib |
| 51 | import io |
| 52 | import json |
| 53 | from dataclasses import dataclass, field |
| 54 | |
| 55 | import mido |
| 56 | |
| 57 | from muse.core.attributes import AttributeRule, resolve_strategy |
| 58 | |
| 59 | # --------------------------------------------------------------------------- |
| 60 | # Dimension constants |
| 61 | # --------------------------------------------------------------------------- |
| 62 | |
#: Names of the internal event buckets; used as dict keys across this module.
INTERNAL_DIMS: list[str] = ["notes", "harmonic", "dynamic", "structural"]

#: Maps each user-facing ``.museattributes`` dimension name to its internal
#: bucket.  "melodic" and "rhythmic" share the "notes" bucket because a MIDI
#: note event carries its pitch and its timing in the same structure.
DIM_ALIAS: dict[str, str] = dict(
    melodic="notes",
    rhythmic="notes",
    harmonic="harmonic",
    dynamic="dynamic",
    structural="structural",
)

#: ``DIM_ALIAS`` extended so that the internal name "notes" resolves to itself.
_CANONICAL: dict[str, str] = DIM_ALIAS | {"notes": "notes"}
| 79 | |
| 80 | |
| 81 | # --------------------------------------------------------------------------- |
| 82 | # Data types |
| 83 | # --------------------------------------------------------------------------- |
| 84 | |
| 85 | |
@dataclass
class DimensionSlice:
    """One dimension's worth of events taken from a MIDI file.

    ``events`` holds ``(abs_tick, mido.Message)`` pairs; producers in this
    module supply them sorted by ascending absolute tick.  ``content_hash``
    is the digest returned by ``_hash_events`` for that list and lets two
    slices be compared for changes without comparing the events themselves.
    It is derived at construction time whenever the caller leaves it empty.
    """

    name: str
    events: list[tuple[int, mido.Message]] = field(default_factory=list)
    content_hash: str = ""

    def __post_init__(self) -> None:
        # A caller-supplied hash is kept as-is; only a missing one is derived.
        if self.content_hash:
            return
        self.content_hash = _hash_events(self.events)
| 103 | |
| 104 | |
@dataclass
class MidiDimensions:
    """The complete set of dimension slices for a single MIDI file."""

    ticks_per_beat: int
    file_type: int
    slices: dict[str, DimensionSlice]  # keyed by internal dimension name

    def get(self, user_dim: str) -> DimensionSlice:
        """Look up a slice by user-facing alias or by internal name.

        A name absent from ``_CANONICAL`` is used verbatim as the key, so an
        unknown name surfaces as ``KeyError`` from the ``slices`` lookup.
        """
        key = _CANONICAL[user_dim] if user_dim in _CANONICAL else user_dim
        return self.slices[key]
| 117 | |
| 118 | |
| 119 | # --------------------------------------------------------------------------- |
| 120 | # Internal helpers |
| 121 | # --------------------------------------------------------------------------- |
| 122 | |
| 123 | |
def _classify_event(msg: mido.Message) -> str | None:
    """Map a mido message to an internal dimension bucket, or ``None`` to skip.

    Args:
        msg: Any object exposing a ``type`` string (and optionally
            ``is_meta``), normally a ``mido.Message`` / ``mido.MetaMessage``.

    Returns:
        ``"notes"``, ``"harmonic"``, ``"dynamic"`` or ``"structural"``, or
        ``None`` when the event belongs to no dimension and must be dropped.
    """
    kind = msg.type

    # ``end_of_track`` is a track terminator, not musical content: its tick
    # simply follows the last event of the track.  Bucketing it as structural
    # made the structural hash change whenever a side lengthened the track by
    # editing notes or controllers, producing spurious structural conflicts.
    # Reconstruction discards these markers and writes a fresh one anyway.
    if kind == "end_of_track":
        return None

    if kind in ("note_on", "note_off"):
        return "notes"
    if kind == "pitchwheel":
        return "harmonic"
    if kind == "control_change":
        return "dynamic"
    if kind in (
        "set_tempo",
        "time_signature",
        "key_signature",
        "program_change",
        "sysex",
        "text",
        "copyright",
        "track_name",
        "instrument_name",
        "lyrics",
        "marker",
        "cue_marker",
        "sequencer_specific",
    ):
        return "structural"

    # Unrecognised meta events → structural bucket as a safe default.
    if getattr(msg, "is_meta", False):
        return "structural"

    # Remaining channel messages (anything not matched above) are skipped.
    return None
| 154 | |
| 155 | |
# JSON-compatible value types that ``_msg_to_dict`` stores for a message attribute.
type _MsgVal = int | str | list[int]
| 157 | |
| 158 | |
def _msg_to_dict(msg: mido.Message) -> dict[str, _MsgVal]:
    """Serialise a mido message to a JSON-compatible dict for hashing.

    Only the message ``type`` and a fixed set of content attributes are
    recorded; ``time`` is deliberately left out because callers pair the
    message with its absolute tick separately.

    Args:
        msg: A ``mido.Message`` / ``mido.MetaMessage`` (any object exposing
            ``type`` plus some of the attributes listed below works).

    Returns:
        A dict with ``"type"`` and every listed attribute the message has.
        Byte strings and integer sequences become ``list[int]``; strings and
        ints are stored unchanged; any other value type is omitted.
    """
    d: dict[str, _MsgVal] = {"type": msg.type}
    for attr in ("channel", "note", "velocity", "control", "value",
                 "pitch", "program", "numerator", "denominator",
                 "clocks_per_click", "notated_32nd_notes_per_beat",
                 "tempo", "key", "scale", "text", "name", "data"):
        if not hasattr(msg, attr):
            continue
        raw = getattr(msg, attr)
        # mido exposes sysex / sequencer-specific payloads as a tuple of ints;
        # skipping tuples made edits to that payload invisible to the hash.
        if isinstance(raw, (bytes, bytearray, tuple, list)):
            d[attr] = list(raw)
        elif isinstance(raw, (str, int)):
            d[attr] = raw
        # Other types (float, etc.) are skipped — not present in standard MIDI
    return d
| 176 | |
| 177 | |
def _hash_events(events: list[tuple[int, mido.Message]]) -> str:
    """Return the SHA-256 hex digest of *events* in canonical JSON form.

    Each event is serialised as ``[tick, _msg_to_dict(msg)]``; the JSON is
    emitted with sorted keys and compact separators so equal event lists
    always produce the same digest.
    """
    serialisable = [(tick, _msg_to_dict(msg)) for tick, msg in events]
    canonical = json.dumps(serialisable, sort_keys=True, separators=(",", ":"))
    digest = hashlib.sha256(canonical.encode())
    return digest.hexdigest()
| 186 | |
| 187 | |
def _to_absolute(track: mido.MidiTrack) -> list[tuple[int, mido.Message]]:
    """Pair every message of a delta-time track with its absolute tick.

    The absolute tick of a message is the running sum of the ``time`` values
    of all messages up to and including it.
    """

    def _walk():
        elapsed = 0
        for message in track:
            elapsed += message.time
            yield elapsed, message

    return list(_walk())
| 196 | |
| 197 | |
| 198 | # --------------------------------------------------------------------------- |
| 199 | # Public: extract_dimensions |
| 200 | # --------------------------------------------------------------------------- |
| 201 | |
| 202 | |
def extract_dimensions(midi_bytes: bytes) -> MidiDimensions:
    """Split the events of a MIDI file into per-dimension slices.

    Every track is walked in absolute-tick form and each message is routed
    to the bucket chosen by ``_classify_event``; messages it maps to ``None``
    are dropped.  Events from all tracks end up in the same bucket list.

    Args:
        midi_bytes: Raw bytes of a ``.mid`` file.

    Returns:
        A :class:`MidiDimensions` holding one :class:`DimensionSlice` for
        each name in ``INTERNAL_DIMS``, with events ordered by absolute tick
        and then by message type.

    Raises:
        ValueError: If *midi_bytes* cannot be parsed as a MIDI file.
    """
    try:
        parsed = mido.MidiFile(file=io.BytesIO(midi_bytes))
    except Exception as exc:
        raise ValueError(f"Failed to parse MIDI data: {exc}") from exc

    grouped: dict[str, list[tuple[int, mido.Message]]] = {}
    for dim in INTERNAL_DIMS:
        grouped[dim] = []

    for track in parsed.tracks:
        for timed in _to_absolute(track):
            target = _classify_event(timed[1])
            if target is None:
                continue
            grouped[target].append(timed)

    def _order(item: tuple[int, mido.Message]) -> tuple[int, str]:
        # Tick first; message type breaks ties so same-tick events come out
        # in a repeatable order.
        return item[0], item[1].type

    slices: dict[str, DimensionSlice] = {}
    for dim, dim_events in grouped.items():
        dim_events.sort(key=_order)
        slices[dim] = DimensionSlice(name=dim, events=dim_events)

    return MidiDimensions(
        ticks_per_beat=parsed.ticks_per_beat,
        file_type=parsed.type,
        slices=slices,
    )
| 243 | |
| 244 | |
| 245 | # --------------------------------------------------------------------------- |
| 246 | # Public: dimension_conflict_detail |
| 247 | # --------------------------------------------------------------------------- |
| 248 | |
| 249 | |
def dimension_conflict_detail(
    base: MidiDimensions,
    left: MidiDimensions,
    right: MidiDimensions,
) -> dict[str, str]:
    """Classify, per dimension, which side(s) differ from the common base.

    A side counts as changed when its slice's ``content_hash`` differs from
    the base slice's hash.  The result maps every internal dimension name to:

    - ``"unchanged"`` — neither side differs from base.
    - ``"left_only"`` — only the left (ours) side differs.
    - ``"right_only"`` — only the right (theirs) side differs.
    - ``"both"`` — both sides differ; a dimension-level conflict.

    :func:`merge_midi_dimensions` consumes this report, and it is also
    suitable for human-readable conflict diagnostics.
    """
    # Indexed by (left differs from base, right differs from base).
    labels = {
        (False, False): "unchanged",
        (True, False): "left_only",
        (False, True): "right_only",
        (True, True): "both",
    }

    summary: dict[str, str] = {}
    for dim in INTERNAL_DIMS:
        reference = base.slices[dim].content_hash
        ours_differs = left.slices[dim].content_hash != reference
        theirs_differs = right.slices[dim].content_hash != reference
        summary[dim] = labels[ours_differs, theirs_differs]
    return summary
| 283 | |
| 284 | |
| 285 | # --------------------------------------------------------------------------- |
| 286 | # Reconstruction helpers |
| 287 | # --------------------------------------------------------------------------- |
| 288 | |
| 289 | |
def _events_to_track(
    events: list[tuple[int, mido.Message]],
) -> mido.MidiTrack:
    """Build a delta-time ``mido.MidiTrack`` from absolute-tick events.

    Events are ordered by ``(abs_tick, message type)``.  Each message is
    copied with ``time`` set to the tick distance from the previously emitted
    message, so the caller's message objects are not modified.

    ``end_of_track`` markers found in *events* are never emitted in place:
    ordering by type puts such a marker *before* e.g. a ``note_off`` on the
    same tick, which would leave a terminator in the middle of the track.
    Instead the track is closed by exactly one ``end_of_track``, placed at
    the latest marker tick seen, or right after the last event when there is
    no marker or every marker lies before the last event.

    Args:
        events: ``(abs_tick, message)`` pairs in any order.

    Returns:
        A track whose final message is its only ``end_of_track``.
    """
    track = mido.MidiTrack()
    prev_tick = 0
    end_tick = 0
    for abs_tick, msg in sorted(events, key=lambda x: (x[0], x[1].type)):
        if msg.type == "end_of_track":
            # Remember where the track was meant to end; emit it last.
            end_tick = max(end_tick, abs_tick)
            continue
        track.append(msg.copy(time=abs_tick - prev_tick))
        prev_tick = abs_tick

    # A marker earlier than the last event cannot be honoured; clamp to 0.
    closing_delta = max(end_tick - prev_tick, 0)
    track.append(mido.MetaMessage("end_of_track", time=closing_delta))
    return track
| 306 | |
| 307 | |
def _reconstruct(
    ticks_per_beat: int,
    winning_slices: dict[str, list[tuple[int, mido.Message]]],
) -> bytes:
    """Serialise the winning dimension events as a single-track MIDI file.

    The events of every dimension are pooled into one track and written as a
    type-0 file with the given *ticks_per_beat*.  ``end_of_track`` markers
    carried by the slices are discarded; ``_events_to_track`` terminates the
    track itself.

    Returns:
        The bytes of the saved MIDI file.
    """
    pooled = [
        pair
        for dim_events in winning_slices.values()
        for pair in dim_events
        if pair[1].type != "end_of_track"
    ]
    pooled.sort(key=lambda pair: (pair[0], pair[1].type))

    out = mido.MidiFile(type=0, ticks_per_beat=ticks_per_beat)
    out.tracks.append(_events_to_track(pooled))

    sink = io.BytesIO()
    out.save(file=sink)
    return sink.getvalue()
| 335 | |
| 336 | |
| 337 | # --------------------------------------------------------------------------- |
| 338 | # Public: merge_midi_dimensions |
| 339 | # --------------------------------------------------------------------------- |
| 340 | |
| 341 | |
def _strategy_for_dimension(
    attrs_rules: list[AttributeRule],
    path: str,
    dim: str,
) -> str:
    """Return the ``.museattributes`` strategy that applies to *dim* at *path*.

    User-facing aliases of *dim* are tried first, then the internal name; the
    first answer other than ``"auto"`` wins.  If every lookup yields
    ``"auto"``, the dimension wildcard ``"*"`` is consulted instead.
    """
    candidates = [alias for alias, internal in DIM_ALIAS.items() if internal == dim]
    candidates.append(dim)
    for name in candidates:
        strategy = resolve_strategy(attrs_rules, path, name)
        if strategy != "auto":
            return strategy
    return resolve_strategy(attrs_rules, path, "*")


def merge_midi_dimensions(
    base_bytes: bytes,
    left_bytes: bytes,
    right_bytes: bytes,
    attrs_rules: list[AttributeRule],
    path: str,
) -> tuple[bytes, dict[str, str]] | None:
    """Attempt a dimension-level three-way merge of a MIDI file.

    For each internal dimension:

    - If neither side changed → keep base.
    - If only one side changed → take that side (clean auto-merge).
    - If both sides changed to the *same* content (equal slice hashes) →
      take it; identical edits are not a conflict.
    - If both sides changed differently → consult ``.museattributes``:

      * ``ours`` / ``theirs`` → take the specified side; record in report.
      * ``manual`` / ``auto`` / ``union`` → unresolvable; return ``None``.

    Args:
        base_bytes: MIDI bytes for the common ancestor.
        left_bytes: MIDI bytes for the ours (left) branch.
        right_bytes: MIDI bytes for the theirs (right) branch.
        attrs_rules: Rule list from :func:`muse.core.attributes.load_attributes`.
        path: Workspace-relative POSIX path (used for strategy lookup).

    Returns:
        A ``(merged_bytes, dimension_report)`` tuple when the merge can be
        completed, or ``None`` when at least one dimension conflict has no
        resolvable strategy, or when the three files do not share the same
        ``ticks_per_beat`` (their absolute ticks are then not comparable).

        *dimension_report* maps each internal dimension name to the side
        chosen: ``"base"``, ``"left"``, ``"right"``, or — for a conflict
        settled by strategy — ``"ours (<dim>)"`` / ``"theirs (<dim>)"``.
        A dimension both sides changed identically is reported as ``"left"``.

    Raises:
        ValueError: If any of the byte strings cannot be parsed as MIDI.
    """
    base_dims = extract_dimensions(base_bytes)
    left_dims = extract_dimensions(left_bytes)
    right_dims = extract_dimensions(right_bytes)

    # Slices are mixed tick-for-tick and written at the base resolution, so
    # events from a file with another ticks_per_beat would land at the wrong
    # musical time.  Refuse rather than emit a silently mistimed file.
    if not (
        base_dims.ticks_per_beat
        == left_dims.ticks_per_beat
        == right_dims.ticks_per_beat
    ):
        return None

    detail = dimension_conflict_detail(base_dims, left_dims, right_dims)

    # Outcomes that need no strategy lookup: report label and source file.
    clean_sources = {
        "unchanged": ("base", base_dims),
        "left_only": ("left", left_dims),
        "right_only": ("right", right_dims),
    }

    winning_slices: dict[str, list[tuple[int, mido.Message]]] = {}
    dimension_report: dict[str, str] = {}

    for dim in INTERNAL_DIMS:
        change = detail[dim]

        if change in clean_sources:
            label, source = clean_sources[change]
            winning_slices[dim] = source.slices[dim].events
            dimension_report[dim] = label
            continue

        left_slice = left_dims.slices[dim]
        right_slice = right_dims.slices[dim]

        # Both sides changed, but to the same content: nothing to arbitrate.
        if left_slice.content_hash == right_slice.content_hash:
            winning_slices[dim] = left_slice.events
            dimension_report[dim] = "left"
            continue

        # Genuine conflict — consult .museattributes for this dimension.
        strategy = _strategy_for_dimension(attrs_rules, path, dim)
        if strategy == "ours":
            winning_slices[dim] = left_slice.events
            dimension_report[dim] = f"ours ({dim})"
        elif strategy == "theirs":
            winning_slices[dim] = right_slice.events
            dimension_report[dim] = f"theirs ({dim})"
        else:
            # "auto", "union", "manual" — cannot resolve this dimension.
            return None

    merged_bytes = _reconstruct(base_dims.ticks_per_beat, winning_slices)
    return merged_bytes, dimension_report