cgcardona / muse public
quantize.py python
147 lines 5.1 KB
dfaf1b77 refactor: rename muse-work/ → state/ Gabriel Cardona <gabriel@tellurstori.com> 8h ago
1 """muse quantize — snap note onsets to a rhythmic grid.
2
Moves every note's start tick to the nearest multiple of the chosen
subdivision (16th, 8th, quarter, etc.) — or, with ``--strength`` below
1.0, part of the way toward it. Duration is preserved. An
5 essential post-processing step after human-recorded or agent-generated
6 MIDI that needs to be grid-aligned before mixing.
7
8 Usage::
9
10 muse quantize tracks/piano.mid --grid 16th
11 muse quantize tracks/bass.mid --grid 8th --strength 0.5
12 muse quantize tracks/melody.mid --dry-run
13 muse quantize tracks/drums.mid --grid 32nd
14
15 Grid values: whole, half, quarter, 8th, 16th, 32nd, triplet-8th, triplet-16th
16
17 Output::
18
19 ✅ Quantised tracks/piano.mid → 16th-note grid
    64 notes adjusted · avg shift: 14.0 ticks · max shift: 58
21 Run `muse status` to review, then `muse commit`
22 """
23
from __future__ import annotations

import logging
import math
import pathlib

import typer

from muse.core.errors import ExitCode
from muse.core.repo import require_repo
from muse.core.validation import contain_path
from muse.plugins.midi._query import NoteInfo, load_track_from_workdir, notes_to_midi_bytes
35
# Logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Typer application object; the ``quantize`` command in this module is
# attached to it via ``@app.callback``.
app: typer.Typer = typer.Typer()
38
39 _GRID_FRACTIONS: dict[str, float] = {
40 "whole": 4.0,
41 "half": 2.0,
42 "quarter": 1.0,
43 "8th": 0.5,
44 "16th": 0.25,
45 "32nd": 0.125,
46 "triplet-8th": 1 / 3,
47 "triplet-16th": 1 / 6,
48 }
49
50
def _grid_ticks(tpb: int, grid_name: str) -> int:
    """Return the spacing of *grid_name* in whole ticks at *tpb* ticks per beat.

    An unrecognised *grid_name* falls back to a 16th-note step (0.25 beats).
    The step is rounded with the built-in ``round`` and clamped to at least
    one tick.
    """
    if grid_name in _GRID_FRACTIONS:
        beats_per_step = _GRID_FRACTIONS[grid_name]
    else:
        beats_per_step = 0.25
    step = round(tpb * beats_per_step)
    return step if step >= 1 else 1
54
55
56 def _snap(tick: int, grid: int, strength: float) -> int:
57 """Snap *tick* toward the nearest grid point with *strength* [0, 1]."""
58 nearest = round(tick / grid) * grid
59 return round(tick + (nearest - tick) * strength)
60
61
@app.callback(invoke_without_command=True)
def quantize(
    ctx: typer.Context,
    track: str = typer.Argument(..., metavar="TRACK", help="Workspace-relative path to a .mid file."),
    grid: str = typer.Option(
        "16th", "--grid", "-g", metavar="GRID",
        help="Quantisation grid: whole, half, quarter, 8th, 16th, 32nd, triplet-8th, triplet-16th.",
    ),
    strength: float = typer.Option(
        1.0, "--strength", "-s", metavar="S",
        help="Quantisation strength 0.0 (no change) – 1.0 (full snap). Default 1.0.",
    ),
    dry_run: bool = typer.Option(False, "--dry-run", "-n", help="Preview without writing."),
) -> None:
    """Snap note onsets to a rhythmic grid.

    ``muse quantize`` moves each note's start tick to the nearest multiple of
    the chosen subdivision. Use ``--strength`` < 1.0 for partial quantisation
    that preserves some human feel while tightening the groove.

    After quantising, run ``muse status`` to inspect the structured delta
    (which notes moved) and ``muse commit`` to record the operation.
    """
    # NOTE: the docstring above doubles as the command's --help text, so
    # implementation notes live in comments instead.

    # --- Validate options -------------------------------------------------
    if grid not in _GRID_FRACTIONS:
        typer.echo(
            f"❌ Unknown grid '{grid}'. "
            f"Valid: {', '.join(_GRID_FRACTIONS)}",
            err=True,
        )
        raise typer.Exit(code=ExitCode.USER_ERROR)

    # The chained comparison is also False for NaN, so NaN is rejected too.
    if not 0.0 <= strength <= 1.0:
        typer.echo("❌ --strength must be between 0.0 and 1.0.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    root = require_repo()

    # Check that the user-supplied path stays inside state/ before reading
    # anything, so an escaping path is rejected even under --dry-run.
    try:
        work_path = contain_path(root / "state", track)
    except ValueError as exc:
        typer.echo(f"❌ Invalid track path: {exc}", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR) from None

    # --- Load -------------------------------------------------------------
    result = load_track_from_workdir(root, track)
    if result is None:
        typer.echo(f"❌ Track '{track}' not found or not a valid MIDI file.", err=True)
        raise typer.Exit(code=ExitCode.USER_ERROR)

    notes, tpb = result
    if not notes:
        typer.echo(f" (track '{track}' contains no notes — nothing to quantise)")
        return

    # --- Quantise ---------------------------------------------------------
    # Keep the grid step fractional rather than rounding it to whole ticks.
    # A rounded step accumulates error across the track: e.g. with tpb=100
    # a triplet-8th step of 33 ticks would pull a note sitting exactly on
    # tick 1600 (a downbeat) to 1584. Each snapped onset is rounded to a
    # whole tick individually by _snap. The 1-tick floor matches _grid_ticks.
    grid_step = max(1.0, tpb * _GRID_FRACTIONS[grid])

    quantised: list[NoteInfo] = []
    shifts: list[int] = []
    for n in notes:
        new_tick = _snap(n.start_tick, grid_step, strength)
        shifts.append(abs(new_tick - n.start_tick))
        # Only the onset changes; every other attribute is copied through.
        quantised.append(NoteInfo(
            pitch=n.pitch,
            velocity=n.velocity,
            start_tick=new_tick,
            duration_ticks=n.duration_ticks,
            channel=n.channel,
            ticks_per_beat=n.ticks_per_beat,
        ))

    # `notes` is non-empty here, so `shifts` is too.
    avg_shift = sum(shifts) / len(shifts)
    max_shift = max(shifts)
    moved = sum(1 for s in shifts if s > 0)

    # --- Report / write ---------------------------------------------------
    if dry_run:
        typer.echo(f"\n[dry-run] Would quantise {track} → {grid}-note grid (strength={strength:.2f})")
        typer.echo(f" Notes adjusted: {moved} / {len(notes)}")
        typer.echo(f" Avg tick shift: {avg_shift:.1f} · Max: {max_shift}")
        typer.echo(" No changes written (--dry-run).")
        return

    midi_bytes = notes_to_midi_bytes(quantised, tpb)
    work_path.parent.mkdir(parents=True, exist_ok=True)
    work_path.write_bytes(midi_bytes)

    typer.echo(f"\n✅ Quantised {track} → {grid}-note grid")
    typer.echo(f" {moved} notes adjusted · avg shift: {avg_shift:.1f} ticks · max shift: {max_shift}")
    typer.echo(" Run `muse status` to review, then `muse commit`")