gabriel / muse public
quantize.py python
148 lines 5.3 KB
00373ad0 feat: migrate CLI from typer to argparse (POSIX-compliant, order-independent) Gabriel Cardona <gabriel@tellurstori.com> 1d ago
1 """muse quantize — snap note onsets to a rhythmic grid.
2
3 Moves every note's start tick to the nearest multiple of the chosen
4 subdivision (16th, 8th, quarter, etc.). Duration is preserved. An
5 essential post-processing step after human-recorded or agent-generated
6 MIDI that needs to be grid-aligned before mixing.
7
8 Usage::
9
10 muse quantize tracks/piano.mid --grid 16th
11 muse quantize tracks/bass.mid --grid 8th --strength 0.5
12 muse quantize tracks/melody.mid --dry-run
13 muse quantize tracks/drums.mid --grid 32nd
14
15 Grid values: whole, half, quarter, 8th, 16th, 32nd, triplet-8th, triplet-16th
16
17 Output::
18
19 ✅ Quantised tracks/piano.mid → 16th-note grid
20 64 notes adjusted · avg shift: 14 ticks · max shift: 58 ticks
21 Run `muse status` to review, then `muse commit`
22 """
23
24 from __future__ import annotations
25
26 import argparse
27 import logging
28 import pathlib
29 import sys
30
31 from muse.core.errors import ExitCode
32 from muse.core.validation import contain_path
33 from muse.core.repo import require_repo
34 from muse.plugins.midi._query import NoteInfo, load_track_from_workdir, notes_to_midi_bytes
35
36 logger = logging.getLogger(__name__)
37
38 _GRID_FRACTIONS: dict[str, float] = {
39 "whole": 4.0,
40 "half": 2.0,
41 "quarter": 1.0,
42 "8th": 0.5,
43 "16th": 0.25,
44 "32nd": 0.125,
45 "triplet-8th": 1 / 3,
46 "triplet-16th": 1 / 6,
47 }
48
49
50 def _grid_ticks(tpb: int, grid_name: str) -> int:
51 fraction = _GRID_FRACTIONS.get(grid_name, 0.25)
52 return max(1, round(tpb * fraction))
53
54
55 def _snap(tick: int, grid: int, strength: float) -> int:
56 """Snap *tick* toward the nearest grid point with *strength* [0, 1]."""
57 nearest = round(tick / grid) * grid
58 return round(tick + (nearest - tick) * strength)
59
60
def register(subparsers: "argparse._SubParsersAction[argparse.ArgumentParser]") -> None:
    """Register the quantize subcommand.

    Adds a ``quantize`` parser to *subparsers* with a positional ``TRACK``
    and the ``--grid``, ``--strength`` and ``--dry-run`` options, and sets
    ``run`` as the handler stored in the parsed namespace's ``func``.
    """
    parser = subparsers.add_parser(
        "quantize",
        help="Snap note onsets to a rhythmic grid.",
        description=__doc__,
        # The module docstring contains pre-formatted usage and output
        # examples; the default formatter would re-wrap them into a single
        # paragraph in ``--help`` output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "track",
        metavar="TRACK",
        help="Workspace-relative path to a .mid file.",
    )
    parser.add_argument(
        "--grid",
        "-g",
        metavar="GRID",
        default="16th",
        help="Quantisation grid: whole, half, quarter, 8th, 16th, 32nd, triplet-8th, triplet-16th.",
    )
    parser.add_argument(
        "--strength",
        "-s",
        metavar="S",
        type=float,
        default=1.0,
        help="Quantisation strength 0.0 (no change) – 1.0 (full snap). Default 1.0.",
    )
    parser.add_argument(
        "--dry-run",
        "-n",
        action="store_true",
        help="Preview without writing.",
    )
    parser.set_defaults(func=run)
69
70
def run(args: argparse.Namespace) -> None:
    """Snap note onsets to a rhythmic grid.

    ``muse quantize`` moves each note's start tick to the nearest multiple of
    the chosen subdivision. Use ``--strength`` < 1.0 for partial quantisation
    that preserves some human feel while tightening the groove.

    After quantising, run ``muse status`` to inspect the structured delta
    (which notes moved) and ``muse commit`` to record the operation.

    Reads ``track``, ``grid``, ``strength`` and ``dry_run`` from *args*.
    Exits with ``ExitCode.USER_ERROR`` (message on stderr) when the grid name
    is unknown, the strength is outside [0, 1], the track cannot be loaded,
    or the track path is rejected by ``contain_path``. With ``dry_run`` set,
    only the statistics are printed and nothing is written.
    """
    track: str = args.track
    grid: str = args.grid
    strength: float = args.strength
    dry_run: bool = args.dry_run

    # Validate the options before touching the repository.
    if grid not in _GRID_FRACTIONS:
        print(
            f"❌ Unknown grid '{grid}'. "
            f"Valid: {', '.join(_GRID_FRACTIONS)}",
            file=sys.stderr,
        )
        raise SystemExit(ExitCode.USER_ERROR)

    if not 0.0 <= strength <= 1.0:
        print("❌ --strength must be between 0.0 and 1.0.", file=sys.stderr)
        raise SystemExit(ExitCode.USER_ERROR)

    root = require_repo()
    result = load_track_from_workdir(root, track)
    if result is None:
        print(f"❌ Track '{track}' not found or not a valid MIDI file.", file=sys.stderr)
        raise SystemExit(ExitCode.USER_ERROR)

    notes, tpb = result
    if not notes:
        print(f" (track '{track}' contains no notes — nothing to quantise)")
        return

    grid_t = _grid_ticks(tpb, grid)

    # Only the onset moves; every other note attribute is carried over as-is.
    new_ticks = [_snap(n.start_tick, grid_t, strength) for n in notes]
    quantised: list[NoteInfo] = [
        NoteInfo(
            pitch=n.pitch,
            velocity=n.velocity,
            start_tick=new_tick,
            duration_ticks=n.duration_ticks,
            channel=n.channel,
            ticks_per_beat=n.ticks_per_beat,
        )
        for n, new_tick in zip(notes, new_ticks)
    ]
    shifts: list[int] = [
        abs(new_tick - n.start_tick) for n, new_tick in zip(notes, new_ticks)
    ]

    # ``notes`` is non-empty here, so ``shifts`` is too: no empty-list guards.
    avg_shift = sum(shifts) / len(shifts)
    max_shift = max(shifts)
    moved = sum(1 for s in shifts if s > 0)

    if dry_run:
        print(f"\n[dry-run] Would quantise {track} → {grid}-note grid (strength={strength:.2f})")
        print(f" Notes adjusted: {moved} / {len(notes)}")
        print(f" Avg tick shift: {avg_shift:.1f} · Max: {max_shift}")
        print(" No changes written (--dry-run).")
        return

    midi_bytes = notes_to_midi_bytes(quantised, tpb)
    try:
        work_path = contain_path(root, track)
    except ValueError as exc:
        # Errors go to stderr, like every other failure path in this command.
        print(f"❌ Invalid track path: {exc}", file=sys.stderr)
        raise SystemExit(ExitCode.USER_ERROR)
    work_path.parent.mkdir(parents=True, exist_ok=True)
    work_path.write_bytes(midi_bytes)

    print(f"\n✅ Quantised {track} → {grid}-note grid")
    print(f" {moved} notes adjusted · avg shift: {avg_shift:.1f} ticks · max shift: {max_shift}")
    print(" Run `muse status` to review, then `muse commit`")