gabriel / musehub public
conftest.py python
134 lines 4.3 KB
d4eb1c39 Theme overhaul: domains, new-repo, MCP docs, copy icons; legacy CSS rem… Gabriel Cardona <cgcardona@gmail.com> 3d ago
1 """Pytest configuration and fixtures."""
2 from __future__ import annotations
3
4 import logging
5 import os
6 from collections.abc import AsyncGenerator, Generator
7
# Set before any musehub imports so the Settings lru_cache picks up the value.
# A falsy check is used deliberately instead of os.environ.setdefault: an
# empty-string value loaded from the .env file (ACCESS_TOKEN_SECRET=) must not
# silently win over the test fallback.
for _env_name, _env_fallback in (
    ("ACCESS_TOKEN_SECRET", "test-secret-for-unit-tests-do-not-use-in-prod"),
    ("MUSE_ENV", "test"),
):
    if not os.environ.get(_env_name):
        os.environ[_env_name] = _env_fallback
# Drop the loop variables so they do not linger in the conftest namespace.
del _env_name, _env_fallback
15
16 import pytest
17 import pytest_asyncio
18 from httpx import AsyncClient, ASGITransport
19 from sqlalchemy.ext.asyncio import (
20 AsyncSession,
21 async_sessionmaker,
22 create_async_engine,
23 )
24 from sqlalchemy.pool import StaticPool
25
26 from musehub.db import database
27 from musehub.db.database import Base, get_db
28 from musehub.db.models import User
29 from musehub.main import app
30
31
def pytest_configure(config: pytest.Config) -> None:
    """Default pytest-asyncio to auto mode and silence httpcore logging.

    Auto mode keeps async fixtures working when pyproject.toml is not in the
    current working directory (e.g. in Docker). A mode that was set explicitly
    (non-None) is left untouched, as is a config with no ``asyncio_mode``
    option registered at all.
    """
    options = config.option
    # A non-None default makes "option not registered" fall through exactly
    # like "option explicitly set": only an unset (None) mode is overridden.
    if getattr(options, "asyncio_mode", "") is None:
        options.asyncio_mode = "auto"
    logging.getLogger("httpcore").setLevel(logging.CRITICAL)
37
38
@pytest.fixture
def anyio_backend() -> str:
    """Return the backend name (``"asyncio"``) for tests requesting ``anyio_backend``."""
    backend_name = "asyncio"
    return backend_name
42
43
@pytest.fixture(autouse=True)
def _reset_variation_store() -> Generator[None, None, None]:
    """Reset the singleton VariationStore after each test to prevent cross-test pollution.

    The reset runs during teardown (after the ``yield``). It gracefully no-ops
    when the variation module itself (or one of its parent packages) has been
    removed (MuseHub extraction). A ``ModuleNotFoundError`` for any other
    module, or one raised by the reset call itself, is re-raised instead of
    being silently swallowed.
    """
    yield
    module_path = "musehub.variation.storage.variation_store"
    try:
        from musehub.variation.storage.variation_store import reset_variation_store
    except ModuleNotFoundError as exc:
        missing = exc.name or ""
        # Only the variation store (or a parent package) being absent is an
        # expected condition; a missing transitive dependency is a real error.
        if missing and (module_path == missing or module_path.startswith(missing + ".")):
            return
        raise
    # Called outside the try so failures inside the reset are never hidden.
    reset_variation_store()
56
57
@pytest_asyncio.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session bound to a fresh in-memory SQLite test database.

    Setup:
      * creates an aiosqlite in-memory engine using ``StaticPool``;
      * creates all tables from ``Base.metadata``;
      * swaps ``database._engine`` / ``database._async_session_factory`` for
        the test engine and session factory;
      * overrides the app's ``get_db`` dependency to yield this same session.

    Teardown runs in ``finally`` blocks, so each step happens even if an
    earlier one fails:
      * clears ``app.dependency_overrides``;
      * restores the original engine and factory;
      * drops all tables;
      * disposes the engine, including when ``create_all`` or ``drop_all`` raises.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        connect_args={"check_same_thread": False},
        # StaticPool keeps a single connection so the in-memory DB persists
        # across every session created from this engine.
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        async_session_factory = async_sessionmaker(
            bind=engine,
            class_=AsyncSession,
            expire_on_commit=False,
        )
        old_engine = database._engine
        old_factory = database._async_session_factory
        database._engine = engine
        database._async_session_factory = async_session_factory
        try:
            async with async_session_factory() as session:

                async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
                    yield session

                app.dependency_overrides[get_db] = override_get_db
                try:
                    yield session
                finally:
                    app.dependency_overrides.clear()
        finally:
            database._engine = old_engine
            database._async_session_factory = old_factory
            async with engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
    finally:
        # Always release the engine's connection, even if schema setup/teardown failed.
        await engine.dispose()
91
92
@pytest_asyncio.fixture
async def client(db_session: AsyncSession) -> AsyncGenerator[AsyncClient, None]:
    """Yield an httpx ``AsyncClient`` that sends requests straight to the ASGI app.

    Requests go through ``ASGITransport`` against base URL ``http://test``.
    Depending on ``db_session`` ensures the test-database override for
    ``get_db`` is installed first, so the auth revocation check uses the test DB.
    """
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as http_client:
        yield http_client
100
101
102 # -----------------------------------------------------------------------------
103 # Auth fixtures for API contract and integration tests
104 # -----------------------------------------------------------------------------
105
@pytest_asyncio.fixture
async def test_user(db_session: AsyncSession) -> User:
    """Persist and return a ``User`` with a fixed id, for authenticated route tests.

    The row is committed and then refreshed so its attributes are reloaded
    from the database.
    """
    fixed_user_id = "550e8400-e29b-41d4-a716-446655440000"
    new_user = User(id=fixed_user_id)
    db_session.add(new_user)
    await db_session.commit()
    await db_session.refresh(new_user)
    return new_user
117
118
@pytest.fixture
def auth_token(test_user: User) -> str:
    """Return an access token for ``test_user`` that expires after one hour."""
    from musehub.auth.tokens import create_access_token

    lifetime_hours = 1
    return create_access_token(user_id=test_user.id, expires_hours=lifetime_hours)
125
126
@pytest.fixture
def auth_headers(auth_token: str) -> dict[str, str]:
    """Return request headers carrying ``auth_token`` as a Bearer credential plus a JSON content type."""
    bearer_value = f"Bearer {auth_token}"
    headers: dict[str, str] = {}
    headers["Authorization"] = bearer_value
    headers["Content-Type"] = "application/json"
    return headers