cgcardona / muse public
test_core_store.py python
248 lines 9.0 KB
8d5137ed fix(security): full surface hardening — validation, path containment, p… Gabriel Cardona <cgcardona@gmail.com> 10h ago
1 """Tests for muse.core.store — file-based commit and snapshot storage."""
2
3 import datetime
4 import json
5 import pathlib
6
7 import pytest
8
9 from muse.core.store import (
10 CommitDict,
11 CommitRecord,
12 SnapshotRecord,
13 TagRecord,
14 find_commits_by_prefix,
15 get_all_commits,
16 get_all_tags,
17 get_commits_for_branch,
18 get_head_commit_id,
19 get_head_snapshot_id,
20 get_head_snapshot_manifest,
21 get_tags_for_commit,
22 read_commit,
23 read_snapshot,
24 update_commit_metadata,
25 write_commit,
26 write_snapshot,
27 write_tag,
28 )
29
30
@pytest.fixture
def repo(tmp_path: pathlib.Path) -> pathlib.Path:
    """Lay out a bare-bones ``.muse/`` tree under *tmp_path* and return *tmp_path*.

    The tree holds empty ``commits/`` and ``snapshots/`` directories, a
    ``repo.json`` whose ``repo_id`` is ``"test-repo"``, a ``HEAD`` file naming
    ``refs/heads/main``, and an empty ``main`` ref file (a branch with no
    commits yet).
    """
    dot_muse = tmp_path / ".muse"
    heads_dir = dot_muse / "refs" / "heads"
    for directory in (dot_muse / "commits", dot_muse / "snapshots", heads_dir):
        directory.mkdir(parents=True)

    (dot_muse / "repo.json").write_text(json.dumps({"repo_id": "test-repo"}))
    (dot_muse / "HEAD").write_text("refs/heads/main\n")
    (heads_dir / "main").write_text("")
    return tmp_path
42
43
def _make_commit(
    root: pathlib.Path,
    commit_id: str,
    snapshot_id: str,
    message: str,
    parent: str | None = None,
) -> CommitRecord:
    """Create a ``main``-branch commit in ``test-repo``, persist it, and return it.

    The commit is stamped with the current UTC time and stored under *root*
    via ``write_commit``. *parent* is recorded as ``parent_commit_id``.
    """
    now_utc = datetime.datetime.now(datetime.timezone.utc)
    record = CommitRecord(
        commit_id=commit_id,
        repo_id="test-repo",
        branch="main",
        snapshot_id=snapshot_id,
        message=message,
        committed_at=now_utc,
        parent_commit_id=parent,
    )
    write_commit(root, record)
    return record
56
57
def _make_snapshot(root: pathlib.Path, snapshot_id: str, manifest: dict[str, str]) -> SnapshotRecord:
    """Store a snapshot with the given *manifest* under *root* and return the record."""
    record = SnapshotRecord(snapshot_id=snapshot_id, manifest=manifest)
    write_snapshot(root, record)
    return record
62
63
class TestFormatVersion:
    """CommitRecord.format_version tracks schema evolution."""

    # Schema version stamped on newly created commits. It is kept in one place
    # so a future format bump means updating this value (and the ``_5`` test
    # names) instead of hunting for scattered literals.
    CURRENT_FORMAT_VERSION = 5

    def test_new_commit_has_format_version_5(self, repo: pathlib.Path) -> None:
        c = _make_commit(repo, "abc123", "s" * 64, "msg")
        assert c.format_version == self.CURRENT_FORMAT_VERSION

    def test_format_version_round_trips_through_json_v5(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abc123", "s" * 64, "msg")
        loaded = read_commit(repo, "abc123")
        assert loaded is not None
        assert loaded.format_version == self.CURRENT_FORMAT_VERSION

    def test_format_version_in_serialised_dict(self) -> None:
        c = CommitRecord(
            commit_id="x",
            repo_id="r",
            branch="main",
            snapshot_id="s",
            message="m",
            committed_at=datetime.datetime.now(datetime.timezone.utc),
        )
        d = c.to_dict()
        assert "format_version" in d
        assert d["format_version"] == self.CURRENT_FORMAT_VERSION

    def test_missing_format_version_defaults_to_1(self) -> None:
        """Existing JSON without format_version field deserialises as version 1."""
        raw = CommitDict(
            commit_id="abc",
            repo_id="r",
            branch="main",
            snapshot_id="s",
            message="old record",
            committed_at="2025-01-01T00:00:00+00:00",
        )
        c = CommitRecord.from_dict(raw)
        assert c.format_version == 1

    def test_explicit_format_version_preserved(self) -> None:
        """An explicit format_version in the stored dict is kept rather than upgraded."""
        raw = CommitDict(
            commit_id="abc",
            repo_id="r",
            branch="main",
            snapshot_id="s",
            message="versioned record",
            committed_at="2025-01-01T00:00:00+00:00",
            format_version=2,
        )
        c = CommitRecord.from_dict(raw)
        assert c.format_version == 2

    def test_format_version_field_is_integer(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abc123", "s" * 64, "msg")
        loaded = read_commit(repo, "abc123")
        assert loaded is not None
        assert isinstance(loaded.format_version, int)
121
122
class TestWriteReadCommit:
    """Round-tripping CommitRecord through write_commit/read_commit."""

    def test_roundtrip(self, repo: pathlib.Path) -> None:
        written = _make_commit(repo, "abc123", "s" * 64, "Initial commit")
        loaded = read_commit(repo, "abc123")
        assert loaded is not None
        assert loaded.commit_id == "abc123"
        assert loaded.message == "Initial commit"
        assert loaded.repo_id == "test-repo"
        # Compare the remaining identifying fields against the record that was
        # actually written, so a field dropped during serialisation is caught.
        assert loaded.branch == written.branch
        assert loaded.snapshot_id == written.snapshot_id
        assert loaded.parent_commit_id == written.parent_commit_id

    def test_read_missing_returns_none(self, repo: pathlib.Path) -> None:
        assert read_commit(repo, "nonexistent") is None

    def test_idempotent_write(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abc123", "s" * 64, "First")
        _make_commit(repo, "abc123", "s" * 64, "Second")  # Should not overwrite
        loaded = read_commit(repo, "abc123")
        assert loaded is not None
        assert loaded.message == "First"

    def test_metadata_preserved(self, repo: pathlib.Path) -> None:
        c = CommitRecord(
            commit_id="abc123",
            repo_id="test-repo",
            branch="main",
            snapshot_id="s" * 64,
            message="With metadata",
            committed_at=datetime.datetime.now(datetime.timezone.utc),
            metadata={"section": "chorus", "emotion": "joyful"},
        )
        write_commit(repo, c)
        loaded = read_commit(repo, "abc123")
        assert loaded is not None
        assert loaded.metadata["section"] == "chorus"
        assert loaded.metadata["emotion"] == "joyful"
157
158
class TestUpdateCommitMetadata:
    """update_commit_metadata sets one metadata key on an already-stored commit."""

    def test_set_key(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abc123", "s" * 64, "msg")
        assert update_commit_metadata(repo, "abc123", "tempo_bpm", 120.0) is True

        reloaded = read_commit(repo, "abc123")
        assert reloaded is not None
        assert reloaded.metadata["tempo_bpm"] == 120.0

    def test_missing_commit_returns_false(self, repo: pathlib.Path) -> None:
        outcome = update_commit_metadata(repo, "missing", "k", "v")
        assert outcome is False
170
171
class TestWriteReadSnapshot:
    """Round-tripping SnapshotRecord through write_snapshot/read_snapshot."""

    def test_roundtrip(self, repo: pathlib.Path) -> None:
        written = _make_snapshot(repo, "s" * 64, {"tracks/drums.mid": "deadbeef"})
        loaded = read_snapshot(repo, "s" * 64)
        assert loaded is not None
        # Check the id as well as the manifest, using the record that was written.
        assert loaded.snapshot_id == written.snapshot_id
        assert loaded.manifest == {"tracks/drums.mid": "deadbeef"}

    def test_read_missing_returns_none(self, repo: pathlib.Path) -> None:
        assert read_snapshot(repo, "nonexistent") is None
181
182
class TestHeadQueries:
    """Reading a branch tip from its ref file, and the snapshot behind it."""

    @staticmethod
    def _point_main_at(root: pathlib.Path, commit_id: str) -> None:
        """Overwrite the ``main`` ref file so it contains *commit_id*."""
        (root / ".muse" / "refs" / "heads" / "main").write_text(commit_id)

    def _seed_single_commit(self, root: pathlib.Path) -> None:
        """Store commit ``abc123`` with snapshot ``"s" * 64`` and make it main's tip."""
        _make_commit(root, "abc123", "s" * 64, "msg")
        _make_snapshot(root, "s" * 64, {"f.mid": "hash1"})
        self._point_main_at(root, "abc123")

    def test_get_head_commit_id_empty_branch(self, repo: pathlib.Path) -> None:
        assert get_head_commit_id(repo, "main") is None

    def test_get_head_commit_id(self, repo: pathlib.Path) -> None:
        self._point_main_at(repo, "abc123")
        assert get_head_commit_id(repo, "main") == "abc123"

    def test_get_head_snapshot_id(self, repo: pathlib.Path) -> None:
        self._seed_single_commit(repo)
        assert get_head_snapshot_id(repo, "test-repo", "main") == "s" * 64

    def test_get_head_snapshot_manifest(self, repo: pathlib.Path) -> None:
        self._seed_single_commit(repo)
        assert get_head_snapshot_manifest(repo, "test-repo", "main") == {"f.mid": "hash1"}
203
204
class TestGetCommitsForBranch:
    """Walking a branch's history from its tip back to the root commit."""

    def test_chain(self, repo: pathlib.Path) -> None:
        # (commit_id, snapshot_id, message, parent), written oldest first.
        history = [
            ("root", "snap0", "Root", None),
            ("child", "s" * 64, "Child", "root"),
            ("grandchild", "snap2", "Grandchild", "child"),
        ]
        for commit_id, snapshot_id, message, parent in history:
            _make_commit(repo, commit_id, snapshot_id, message, parent=parent)
        (repo / ".muse" / "refs" / "heads" / "main").write_text("grandchild")

        walked = [record.commit_id for record in get_commits_for_branch(repo, "test-repo", "main")]
        assert walked == ["grandchild", "child", "root"]

    def test_empty_branch(self, repo: pathlib.Path) -> None:
        history = get_commits_for_branch(repo, "test-repo", "main")
        assert history == []
217
218
class TestFindByPrefix:
    """Resolving abbreviated commit ids to stored commits."""

    def test_finds_match(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abcdef1234", "s" * 64, "msg")
        # A decoy that does NOT share the prefix. Without it, an implementation
        # that ignored the prefix and returned every commit would also pass.
        _make_commit(repo, "fedcba9876", "s" * 64, "other")
        results = find_commits_by_prefix(repo, "abcdef")
        assert len(results) == 1
        assert results[0].commit_id == "abcdef1234"

    def test_no_match(self, repo: pathlib.Path) -> None:
        assert find_commits_by_prefix(repo, "zzz") == []
228
229
class TestTags:
    """Writing TagRecords and querying them per commit and per repo."""

    def test_write_and_read(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abc123", "s" * 64, "msg")
        write_tag(repo, TagRecord(
            tag_id="tag1",
            repo_id="test-repo",
            commit_id="abc123",
            tag="emotion:joyful",
        ))
        tags = get_tags_for_commit(repo, "test-repo", "abc123")
        assert len(tags) == 1
        assert tags[0].tag == "emotion:joyful"
        assert tags[0].commit_id == "abc123"

    def test_get_all_tags(self, repo: pathlib.Path) -> None:
        _make_commit(repo, "abc123", "s" * 64, "msg")
        write_tag(repo, TagRecord(tag_id="t1", repo_id="test-repo", commit_id="abc123", tag="stage:rough-mix"))
        write_tag(repo, TagRecord(tag_id="t2", repo_id="test-repo", commit_id="abc123", tag="key:Am"))
        all_tags = get_all_tags(repo, "test-repo")
        assert len(all_tags) == 2
        # The count alone would accept duplicated or wrong tags; check the
        # contents. Sorting makes the check independent of storage order.
        assert sorted(t.tag for t in all_tags) == ["key:Am", "stage:rough-mix"]