|
8 | 8 | from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession |
9 | 9 | from docx import Document as DocxDocument |
10 | 10 |
|
| 11 | +import app.database as _db_module |
| 12 | +import app.routers.upload as _upload_module |
| 13 | +import app.routers.qa as _qa_module |
| 14 | +import app.routers.flagged as _flagged_module |
11 | 15 | from app.database import Base, get_db |
12 | 16 | import app.models # noqa: F401 — ensure all models are registered with Base.metadata before create_all |
13 | 17 | from app.main import app |
@@ -64,17 +68,45 @@ async def db_session(db_engine): |
64 | 68 |
|
65 | 69 |
|
66 | 70 | @pytest_asyncio.fixture |
67 | | -async def client(db_session): |
68 | | - """HTTP test client with overridden DB dependency.""" |
| 71 | +async def client(db_engine, db_session): |
| 72 | + """HTTP test client with overridden DB dependency. |
| 73 | +
|
| 74 | + Patches both the FastAPI dependency *and* the module-level engine/session |
| 75 | + so that background tasks (which use ``async_session()`` directly) also |
| 76 | + hit the test database. |
| 77 | + """ |
69 | 78 | async def override_get_db(): |
70 | 79 | yield db_session |
71 | 80 |
|
| 81 | + # Patch module-level engine + session factory so background tasks use the test DB |
| 82 | + test_session_factory = async_sessionmaker(db_engine, class_=AsyncSession, expire_on_commit=False) |
| 83 | + |
| 84 | + originals = { |
| 85 | + "db_engine": _db_module.engine, |
| 86 | + "db_session": _db_module.async_session, |
| 87 | + "upload_session": _upload_module.async_session, |
| 88 | + "qa_session": _qa_module.async_session, |
| 89 | + "flagged_session": _flagged_module.async_session, |
| 90 | + } |
| 91 | + _db_module.engine = db_engine |
| 92 | + _db_module.async_session = test_session_factory |
| 93 | + _upload_module.async_session = test_session_factory |
| 94 | + _qa_module.async_session = test_session_factory |
| 95 | + _flagged_module.async_session = test_session_factory |
| 96 | + |
72 | 97 | app.dependency_overrides[get_db] = override_get_db |
73 | 98 | transport = ASGITransport(app=app) |
74 | 99 | async with AsyncClient(transport=transport, base_url="http://test") as c: |
75 | 100 | yield c |
76 | 101 | app.dependency_overrides.clear() |
77 | 102 |
|
| 103 | + # Restore originals |
| 104 | + _db_module.engine = originals["db_engine"] |
| 105 | + _db_module.async_session = originals["db_session"] |
| 106 | + _upload_module.async_session = originals["upload_session"] |
| 107 | + _qa_module.async_session = originals["qa_session"] |
| 108 | + _flagged_module.async_session = originals["flagged_session"] |
| 109 | + |
78 | 110 |
|
79 | 111 | @pytest.fixture |
80 | 112 | def make_docx(tmp_path): |
|
0 commit comments