test_musehub_contracts.py
python
| 1 | """Unit tests for musehub/contracts/hash_utils.py. |
| 2 | |
| 3 | The contract hashing module enforces deterministic SHA-256 fingerprinting of |
| 4 | music generation contracts. These tests lock down: |
| 5 | |
| 6 | - canonical_contract_dict field exclusions |
| 7 | - _normalize_value handling of nested types |
| 8 | - hash stability across runs |
| 9 | - contract_hash exclusion prevents circular dependency |
| 10 | """ |
| 11 | from __future__ import annotations |
| 12 | |
| 13 | import dataclasses |
| 14 | import hashlib |
| 15 | import json |
| 16 | |
| 17 | import pytest |
| 18 | |
| 19 | from musehub.contracts.hash_utils import ( |
| 20 | _HASH_EXCLUDED_FIELDS, |
| 21 | _normalize_value, |
| 22 | canonical_contract_dict, |
| 23 | ) |
| 24 | |
| 25 | |
| 26 | # --------------------------------------------------------------------------- |
| 27 | # Minimal dataclass fixtures |
| 28 | # --------------------------------------------------------------------------- |
| 29 | |
@dataclasses.dataclass(frozen=True)
class SimpleContract:
    """Small immutable contract used as a hashing fixture.

    ``tempo`` and ``key`` stand in for real contract content.  The other
    three fields share their names with entries the tests in this module
    expect to find in ``_HASH_EXCLUDED_FIELDS``.
    """

    # Content fields: the tests expect these in the canonical form.
    tempo: int = 120
    key: str = "C major"

    # Advisory fields: the tests expect these to be left out of hashes.
    contract_hash: str = ""
    contract_version: str = "1.0"
    region_name: str = ""
| 38 | |
| 39 | |
@dataclasses.dataclass(frozen=True)
class NestedContract:
    """Immutable contract that embeds a ``SimpleContract``.

    Fixture for checking how canonicalisation treats a dataclass stored
    inside another dataclass.
    """

    name: str = "outer"
    # A fresh default SimpleContract is built whenever ``inner`` is omitted.
    inner: SimpleContract = dataclasses.field(default_factory=SimpleContract)
| 44 | |
| 45 | |
| 46 | # --------------------------------------------------------------------------- |
| 47 | # _HASH_EXCLUDED_FIELDS |
| 48 | # --------------------------------------------------------------------------- |
| 49 | |
class TestHashExcludedFields:
    """Checks on the membership of ``_HASH_EXCLUDED_FIELDS``."""

    def test_advisory_fields_excluded(self) -> None:
        """Each advisory field name must be a member of ``_HASH_EXCLUDED_FIELDS``."""
        advisory_fields = (
            "contract_hash",
            "parent_contract_hash",
            "contract_version",
            "execution_hash",
            "l2_generate_prompt",
            "region_name",
            "gm_guidance",
            "assigned_color",
            "existing_track_id",
        )
        # Checked one at a time so the failure message names the missing field.
        for field in advisory_fields:
            assert field in _HASH_EXCLUDED_FIELDS, f"{field!r} should be excluded"
| 64 | |
| 65 | |
| 66 | # --------------------------------------------------------------------------- |
| 67 | # _normalize_value |
| 68 | # --------------------------------------------------------------------------- |
| 69 | |
class TestNormalizeValue:
    """Behaviour of ``_normalize_value`` for each kind of input covered here."""

    def test_primitives_passthrough(self) -> None:
        """Ints, floats, strings, booleans and ``None`` come back unchanged."""
        for primitive in (42, 3.14, "hello"):
            assert _normalize_value(primitive) == primitive
        # Singletons are compared by identity, not just equality.
        assert _normalize_value(True) is True
        assert _normalize_value(None) is None

    def test_list_normalized(self) -> None:
        """A list keeps its elements in their original order."""
        assert _normalize_value([3, 1, 2]) == [3, 1, 2]

    def test_tuple_becomes_list(self) -> None:
        """A tuple input yields a result equal to the list of its items."""
        assert _normalize_value((1, 2, 3)) == [1, 2, 3]

    def test_dict_keys_sorted(self) -> None:
        """Keys of a normalised dict iterate in sorted order."""
        normalized = _normalize_value({"z": 1, "a": 2})
        assert list(normalized.keys()) == ["a", "z"]  # type: ignore[union-attr]

    def test_dataclass_converted_to_dict(self) -> None:
        """A dataclass becomes a dict holding content fields but not ``contract_hash``."""
        normalized = _normalize_value(SimpleContract(tempo=90, key="D minor"))
        assert isinstance(normalized, dict)
        for kept in ("tempo", "key"):
            assert kept in normalized
        # Excluded fields should not appear
        assert "contract_hash" not in normalized

    def test_unknown_type_stringified(self) -> None:
        """An object of an unrecognised type is replaced by its ``str()`` form."""

        class Weird:
            def __str__(self) -> str:
                return "weird-value"

        assert _normalize_value(Weird()) == "weird-value"
| 105 | |
| 106 | |
| 107 | # --------------------------------------------------------------------------- |
| 108 | # canonical_contract_dict |
| 109 | # --------------------------------------------------------------------------- |
| 110 | |
class TestCanonicalContractDict:
    """Behaviour of ``canonical_contract_dict`` on the fixture dataclasses."""

    def test_excluded_fields_absent(self) -> None:
        """No name from ``_HASH_EXCLUDED_FIELDS`` appears in the canonical dict."""
        canonical = canonical_contract_dict(SimpleContract(tempo=120, key="G major"))
        for excluded in _HASH_EXCLUDED_FIELDS:
            assert excluded not in canonical, f"{excluded!r} should not appear in canonical dict"

    def test_included_fields_present(self) -> None:
        """Content fields are kept, with the values given to the contract."""
        canonical = canonical_contract_dict(SimpleContract(tempo=120, key="A minor"))
        assert "tempo" in canonical
        assert "key" in canonical
        assert canonical["tempo"] == 120
        assert canonical["key"] == "A minor"

    def test_deterministic_across_calls(self) -> None:
        """Two calls on the same object return equal dicts."""
        contract = SimpleContract(tempo=100, key="F major")
        first = canonical_contract_dict(contract)
        second = canonical_contract_dict(contract)
        assert first == second

    def test_same_data_same_json(self) -> None:
        """Serialising the canonical dict twice gives identical JSON text."""
        contract = SimpleContract(tempo=140, key="B♭ major")
        first_json = json.dumps(canonical_contract_dict(contract), sort_keys=True)
        second_json = json.dumps(canonical_contract_dict(contract), sort_keys=True)
        assert first_json == second_json

    def test_nested_dataclass_recursed(self) -> None:
        """A nested dataclass is itself turned into a dict without ``contract_hash``."""
        nested = NestedContract(name="outer", inner=SimpleContract(tempo=70))
        canonical = canonical_contract_dict(nested)
        assert "name" in canonical
        assert "inner" in canonical
        # Narrow once so the remaining checks operate on a known dict.
        inner = canonical["inner"]
        assert isinstance(inner, dict)
        assert inner["tempo"] == 70
        assert "contract_hash" not in inner

    def test_different_values_different_hash(self) -> None:
        """Contracts differing only in ``tempo`` produce different SHA-256 digests."""
        digests = [
            hashlib.sha256(
                json.dumps(canonical_contract_dict(contract), sort_keys=True).encode()
            ).hexdigest()
            for contract in (SimpleContract(tempo=120), SimpleContract(tempo=140))
        ]
        assert digests[0] != digests[1]