gabriel / musehub public
test_musehub_contracts.py python
157 lines 5.3 KB
7923a405 test(supercharge): comprehensive test suite overhaul — all 11 points Gabriel Cardona <gabriel@tellurstori.com> 6d ago
1 """Unit tests for musehub/contracts/hash_utils.py.
2
3 The contract hashing module enforces deterministic SHA-256 fingerprinting of
4 music generation contracts. These tests lock down:
5
6 - canonical_contract_dict field exclusions
7 - _normalize_value handling of nested types
8 - hash stability across runs
9 - contract_hash exclusion prevents circular dependency
10 """
11 from __future__ import annotations
12
13 import dataclasses
14 import hashlib
15 import json
16
17 import pytest
18
19 from musehub.contracts.hash_utils import (
20 _HASH_EXCLUDED_FIELDS,
21 _normalize_value,
22 canonical_contract_dict,
23 )
24
25
26 # ---------------------------------------------------------------------------
27 # Minimal dataclass fixtures
28 # ---------------------------------------------------------------------------
29
@dataclasses.dataclass(frozen=True)
class SimpleContract:
    """Flat, immutable stand-in for a contract, used as a hashing fixture.

    ``tempo`` and ``key`` act as the hashed content. The other three fields
    use names that the tests in this file expect to find in
    ``_HASH_EXCLUDED_FIELDS``, so they should never reach a canonical dict.
    """

    # Content fields.
    tempo: int = 120
    key: str = "C major"

    # Advisory fields; the tests expect these to be left out of hashes.
    contract_hash: str = ""
    contract_version: str = "1.0"
    region_name: str = ""
38
39
@dataclasses.dataclass(frozen=True)
class NestedContract:
    """Immutable fixture that embeds a ``SimpleContract`` one level down.

    Used by the tests to check how dataclass-valued fields are handled.
    """

    name: str = "outer"
    # default_factory builds a new SimpleContract for every instance that
    # does not supply its own ``inner``.
    inner: SimpleContract = dataclasses.field(default_factory=SimpleContract)
44
45
46 # ---------------------------------------------------------------------------
47 # _HASH_EXCLUDED_FIELDS
48 # ---------------------------------------------------------------------------
49
class TestHashExcludedFields:
    """Checks on the contents of ``_HASH_EXCLUDED_FIELDS``."""

    def test_advisory_fields_excluded(self) -> None:
        """Every advisory field name must be listed in ``_HASH_EXCLUDED_FIELDS``.

        All expected names are checked before asserting, so one failing run
        reports every missing name instead of stopping at the first.
        """
        expected = (
            "contract_hash",
            "parent_contract_hash",
            "contract_version",
            "execution_hash",
            "l2_generate_prompt",
            "region_name",
            "gm_guidance",
            "assigned_color",
            "existing_track_id",
        )
        missing = [name for name in expected if name not in _HASH_EXCLUDED_FIELDS]
        assert not missing, f"{missing!r} should be excluded"
64
65
66 # ---------------------------------------------------------------------------
67 # _normalize_value
68 # ---------------------------------------------------------------------------
69
class TestNormalizeValue:
    """Expectations for ``_normalize_value`` across the input kinds used here."""

    def test_primitives_passthrough(self) -> None:
        """Numbers and strings come back equal; ``True`` and ``None`` keep identity."""
        for value in (42, 3.14, "hello"):
            assert _normalize_value(value) == value
        assert _normalize_value(True) is True
        assert _normalize_value(None) is None

    def test_list_normalized(self) -> None:
        """A list comes back equal to itself, element order included."""
        assert _normalize_value([3, 1, 2]) == [3, 1, 2]

    def test_tuple_becomes_list(self) -> None:
        """A tuple comes back equal to a list (a tuple would not compare equal)."""
        assert _normalize_value((1, 2, 3)) == [1, 2, 3]

    def test_dict_keys_sorted(self) -> None:
        """Dict keys come back in sorted order, not insertion order."""
        normalized = _normalize_value({"z": 1, "a": 2})
        key_order = list(normalized.keys())  # type: ignore[union-attr]
        assert key_order == ["a", "z"]

    def test_dataclass_converted_to_dict(self) -> None:
        """A dataclass instance becomes a dict without the excluded hash field."""
        normalized = _normalize_value(SimpleContract(tempo=90, key="D minor"))
        assert isinstance(normalized, dict)
        for present in ("tempo", "key"):
            assert present in normalized
        # Excluded fields should not appear
        assert "contract_hash" not in normalized

    def test_unknown_type_stringified(self) -> None:
        """An object of an unrecognised type comes back as its ``str()`` form."""

        class Weird:
            def __str__(self) -> str:
                return "weird-value"

        assert _normalize_value(Weird()) == "weird-value"
105
106
107 # ---------------------------------------------------------------------------
108 # canonical_contract_dict
109 # ---------------------------------------------------------------------------
110
class TestCanonicalContractDict:
    """Expectations for ``canonical_contract_dict`` on the dataclass fixtures."""

    def test_excluded_fields_absent(self) -> None:
        """No name from ``_HASH_EXCLUDED_FIELDS`` may be a key of the result."""
        canonical = canonical_contract_dict(SimpleContract(tempo=120, key="G major"))
        for name in _HASH_EXCLUDED_FIELDS:
            assert name not in canonical, f"{name!r} should not appear in canonical dict"

    def test_included_fields_present(self) -> None:
        """Content fields are kept, with the values the contract was built with."""
        canonical = canonical_contract_dict(SimpleContract(tempo=120, key="A minor"))
        assert "tempo" in canonical
        assert "key" in canonical
        assert canonical["tempo"] == 120
        assert canonical["key"] == "A minor"

    def test_deterministic_across_calls(self) -> None:
        """Two calls on the same object give equal dicts."""
        contract = SimpleContract(tempo=100, key="F major")
        assert canonical_contract_dict(contract) == canonical_contract_dict(contract)

    def test_same_data_same_json(self) -> None:
        """Two calls on the same object serialise to identical JSON text."""
        contract = SimpleContract(tempo=140, key="B♭ major")
        serialized = [
            json.dumps(canonical_contract_dict(contract), sort_keys=True)
            for _ in range(2)
        ]
        assert serialized[0] == serialized[1]

    def test_nested_dataclass_recursed(self) -> None:
        """A dataclass-valued field becomes a nested dict, also without the hash field."""
        contract = NestedContract(name="outer", inner=SimpleContract(tempo=70))
        canonical = canonical_contract_dict(contract)
        assert "name" in canonical
        assert "inner" in canonical
        assert isinstance(canonical["inner"], dict)
        assert canonical["inner"]["tempo"] == 70  # type: ignore[index]
        assert "contract_hash" not in canonical.get("inner", {})  # type: ignore[operator]

    def test_different_values_different_hash(self) -> None:
        """Contracts that differ in ``tempo`` give different SHA-256 digests."""

        def digest(tempo: int) -> str:
            # Same pipeline for both contracts: canonical dict -> sorted JSON -> SHA-256.
            payload = json.dumps(
                canonical_contract_dict(SimpleContract(tempo=tempo)), sort_keys=True
            )
            return hashlib.sha256(payload.encode()).hexdigest()

        assert digest(120) != digest(140)