Discrepancies Between Test and FastAPI App Data
When testing the FastAPI application with two different async sessions to the database, it's possible to get the wrong test result.
When testing a FastAPI application with two different async sessions to the database, the following error may occur:
- In the test, an object is created in the database (the test session).
- A request is made to the application itself in which this object is changed (the application session).
- The object is loaded from the database in the test, but it does not contain the expected changes (the test session).
Let’s find out what’s going on.
Most often, we use two different sessions in the application and in the test.
Moreover, in the test, we usually wrap the session in a fixture that prepares the database for tests, and after the tests, everything is cleaned up.
Below is an example of the application.
A file with a database connection app/database.py:
""" Database settings file """
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import declarative_base

DATABASE_URL = "postgresql+asyncpg://user:password@host:5432/dbname"

engine = create_async_engine(DATABASE_URL, echo=True, future=True)
async_session = async_sessionmaker(bind=engine, class_=AsyncSession, expire_on_commit=False)


async def get_session() -> AsyncGenerator:
    """ Returns async session """
    async with async_session() as session:
        yield session


Base = declarative_base()
A file with a model description app/models.py:
""" Model file """
from sqlalchemy import Integer, String
from sqlalchemy.orm import Mapped, mapped_column

from .database import Base


class Lamp(Base):
    """ Lamp model """
    __tablename__ = 'lamps'

    id: Mapped[int] = mapped_column(Integer, primary_key=True, index=True)
    status: Mapped[str] = mapped_column(String, default="off")
A file with an endpoint description app/main.py:
""" Main file """
import logging

from fastapi import FastAPI, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from .database import get_session
from .models import Lamp

app = FastAPI()


@app.post("/lamps/{lamp_id}/on")
async def check_lamp(
    lamp_id: int,
    session: AsyncSession = Depends(get_session)
) -> dict:
    """ Lamp on endpoint """
    results = await session.execute(select(Lamp).where(Lamp.id == lamp_id))
    lamp = results.scalar_one_or_none()
    if lamp:
        logging.error("Status before update: %s", lamp.status)
        lamp.status = "on"
        session.add(lamp)
        await session.commit()
        await session.refresh(lamp)
        logging.error("Status after update: %s", lamp.status)
    return {}
I deliberately added logging and a few extra queries to the example to make the sequence of events easier to follow.
Here, a session is created using Depends.
Below is the file with a test example tests/test_lamp.py:
""" Test lamp """
import logging
from typing import AsyncGenerator

import pytest
import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.database import Base, engine
from app.main import app, Lamp


@pytest_asyncio.fixture(scope="function", name="test_session")
async def test_session_fixture() -> AsyncGenerator:
    """ Async session fixture """
    async_session = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
    async with async_session() as session:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield session
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
    await engine.dispose()


@pytest.mark.asyncio
async def test_lamp_on(test_session):
    """ Test lamp switch on """
    lamp = Lamp()
    test_session.add(lamp)
    await test_session.commit()
    await test_session.refresh(lamp)
    logging.error("New client status: %s", lamp.status)
    assert lamp.status == "off"

    async with AsyncClient(app=app, base_url="http://testserver") as async_client:
        response = await async_client.post(f"/lamps/{lamp.id}/on")
        assert response.status_code == 200

    results = await test_session.execute(select(Lamp).where(Lamp.id == lamp.id))
    new_lamp = results.scalar_one_or_none()
    logging.error("Updated status: %s", new_lamp.status)
    assert new_lamp.status == "on"
This is a regular pytest test that gets a database session from a fixture. The fixture creates all tables before yielding the session and drops them after the test has finished.
Please note again that the test uses the session from the test_session fixture, while the main code uses the session from the app/database.py file. Even though both use the same engine, different sessions are created. This is important.
The expected sequence of database requests
status = on should be returned from the database.
In the test, I first create an object in the database. This is an ordinary INSERT through the test session. Let's call it Session 1. At this moment, only this session is connected to the database. The application session is not connected yet.
After creating the object, I perform a refresh. This is a SELECT of the newly created object that updates the instance via Session 1. As a result, I make sure that the object was created correctly and that the status field holds the expected value, off.
Then, I perform a POST request to the /lamps/1/on endpoint, which turns on the lamp. To keep the example short, I don't use a fixture for the client. As soon as the request starts being processed, a new session to the database is created. Let's call it Session 2. With this session, I load the needed object from the database and output its status to the log. It is off. After that, I update the status and save the change in the database. The following statements are sent to the database:
BEGIN (implicit)
UPDATE lamps SET status=$1::VARCHAR WHERE lamps.id = $2::INTEGER
parameters: ('on', 1)
COMMIT
Note that the COMMIT command is also present. Even though the transaction was begun implicitly, its result is available to other sessions immediately after COMMIT.
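The visibility of a committed change to a second connection can be checked in isolation. The sketch below is only an illustration and makes a loud assumption: it uses Python's built-in sqlite3 module with a temporary file as a stand-in for the PostgreSQL database from the article, and the names test_conn and app_conn are hypothetical stand-ins for Session 1 and Session 2. No ORM is involved, so the value read back is exactly what the database holds.

```python
import os
import sqlite3
import tempfile

# Assumption: a throwaway SQLite file stands in for the PostgreSQL database.
db_path = os.path.join(tempfile.mkdtemp(), "lamps.db")

# Two independent connections play the roles of Session 1 and Session 2.
test_conn = sqlite3.connect(db_path)
app_conn = sqlite3.connect(db_path)

# The "test" creates the table and the lamp, then commits.
test_conn.execute("CREATE TABLE lamps (id INTEGER PRIMARY KEY, status TEXT)")
test_conn.execute("INSERT INTO lamps (id, status) VALUES (1, 'off')")
test_conn.commit()

# The "application" switches the lamp on and commits.
app_conn.execute("UPDATE lamps SET status = 'on' WHERE id = 1")
app_conn.commit()

# A plain SELECT on the test connection reads the committed row.
status_seen_by_test = test_conn.execute(
    "SELECT status FROM lamps WHERE id = 1"
).fetchone()[0]
print(status_seen_by_test)

test_conn.close()
app_conn.close()
```

At the driver level, the second connection sees the committed value, which matches what the logs in the article show. Any stale value in the test therefore has to come from a layer above the driver.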
Next, I reload the updated object from the database using refresh and output its status. Its value is now on.
It would seem that everything should work. The endpoint finishes, closes Session 2, and returns control to the test.
In the test, I make an ordinary query from Session 1 to get the modified object. But in the status field, I see the value off.
Below is the scheme of the sequence of actions in the code.
At the same time, according to all the logs, the last SELECT query was executed against the database and returned status = on. The value in the database is definitely on at this moment. This is the value that the asyncpg driver receives in response to the SELECT query.
So, what happened?
Here is what happened.
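It turned out that the query made to get a new object did not update the current one but found and reused the existing instance. At the beginning, I added a lamp object using the ORM. Then it was changed in another session. When the change was made, the current session knew nothing about it, and the commit made in Session 2 did not trigger the expire_all method in Session 1. Because the instance held by Session 1 was never expired (the session is created with expire_on_commit=False), its already loaded attribute values were kept instead of being overwritten by the row that had just been fetched.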
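The effect can be illustrated without a database. The following is a deliberately simplified sketch in plain Python and not SQLAlchemy's actual implementation: ToyLamp, ToySession, and DATABASE are hypothetical names invented for this illustration. It only models the idea that a session keeps one object per primary key and returns that object again on a later lookup.

```python
# Toy model only: hypothetical classes, not SQLAlchemy code.
DATABASE = {1: {"status": "off"}}  # stands in for the lamps table


class ToyLamp:
    def __init__(self, lamp_id, status):
        self.id = lamp_id
        self.status = status


class ToySession:
    """Keeps one Python object per primary key, like an identity map."""

    def __init__(self):
        self.identity_map = {}

    def get(self, lamp_id):
        row = DATABASE[lamp_id]  # the "SELECT" always reads the current row
        cached = self.identity_map.get(lamp_id)
        if cached is not None:
            # Already loaded and not expired: the fetched row is not copied in.
            return cached
        lamp = ToyLamp(lamp_id, row["status"])
        self.identity_map[lamp_id] = lamp
        return lamp


session_1 = ToySession()
lamp = session_1.get(1)

# "Session 2" commits a change directly to the shared storage.
DATABASE[1]["status"] = "on"

new_lamp = session_1.get(1)
print(new_lamp is lamp)   # the very same object is returned
print(new_lamp.status)    # the value loaded earlier, not the stored one
```

In this model, the second lookup reads the up-to-date row but hands back the object it already holds, which is the same symptom the test shows.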
To fix this, you can do one of the following:
- Use a shared session for the test and application.
- Refresh the instance rather than trying to get it from the database again.
- Forcibly expire the instance.
- Close the session.
Dependency Overrides
To use the same session, you can simply override the session in the application with the one I created in the test. It’s easy.
To do this, we need to add the following code to the test:
async def _override_get_db():
    yield test_session

app.dependency_overrides[get_session] = _override_get_db
If you want, you can wrap this part into a fixture to use it in all tests.
The resulting algorithm will be as follows:
Below is the test code with session substitution:
from app.database import get_session  # needed as the key for the override


@pytest.mark.asyncio
async def test_lamp_on(test_session):
    """ Test lamp switch on """
    async def _override_get_db():
        yield test_session

    # The application now receives the test session instead of its own
    app.dependency_overrides[get_session] = _override_get_db

    lamp = Lamp()
    test_session.add(lamp)
    await test_session.commit()
    await test_session.refresh(lamp)
    logging.error("New client status: %s", lamp.status)
    assert lamp.status == "off"

    async with AsyncClient(app=app, base_url="http://testserver") as async_client:
        response = await async_client.post(f"/lamps/{lamp.id}/on")
        assert response.status_code == 200

    results = await test_session.execute(select(Lamp).where(Lamp.id == 1))
    new_lamp = results.scalar_one_or_none()
    logging.error("Updated status: %s", new_lamp.status)
    assert new_lamp.status == "on"
However, if the application uses multiple sessions (which is possible), this may not be the best approach. Also, if commit or rollback is not called in the tested function, this will not help.
Refresh
The second solution is the simplest and most logical. We should not issue a new query to get the object. To update it, it is enough to call refresh immediately after the request to the endpoint has been processed. Internally, refresh expires the instance's attributes, so the cached values are not reused and the data is loaded anew.
await test_session.refresh(lamp)
After that, you do not need to load the new_lamp object again; it is enough to check the same lamp.
Below is the code scheme using refresh.
Below is the test code with the update.
@pytest.mark.asyncio
async def test_lamp_on(test_session):
    """ Test lamp switch on """
    lamp = Lamp()
    test_session.add(lamp)
    await test_session.commit()
    await test_session.refresh(lamp)
    logging.error("New client status: %s", lamp.status)
    assert lamp.status == "off"

    async with AsyncClient(app=app, base_url="http://testserver") as async_client:
        response = await async_client.post(f"/lamps/{lamp.id}/on")
        assert response.status_code == 200

    await test_session.refresh(lamp)
    logging.error("Updated status: %s", lamp.status)
    assert lamp.status == "on"
Expire
But if we change a lot of objects, it might be better to call expire_all. Then, all instances will be reloaded from the database the next time they are queried, and consistency will not be broken.
test_session.expire_all()
You can also call expire on a particular instance and even on individual attributes of an instance.
test_session.expire(lamp)
After these calls, you will have to read the objects from the database manually.
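The difference between marking data as stale and actually reloading it can be sketched with a small model. This is a simplified illustration in plain Python, not SQLAlchemy internals: ToyInstance, expire, load, and ROW are hypothetical names made up for the example. The toy expire only marks attributes as stale, and the toy load overwrites stale attributes only, mimicking a later query.

```python
# Toy model only: hypothetical names, not SQLAlchemy internals.
ROW = {"id": 1, "status": "on"}  # the current committed row in the "database"


class ToyInstance:
    def __init__(self, **values):
        self.values = dict(values)  # attribute values held in memory
        self.expired = set()        # attribute names that must be reloaded


def expire(instance, attribute_names=None):
    """Mark some attributes (or all of them) as stale; nothing is read here."""
    names = attribute_names or list(instance.values)
    instance.expired.update(names)


def load(instance, row):
    """Simulate a query reaching this instance: only stale attributes change."""
    for name in list(instance.expired):
        instance.values[name] = row[name]
    instance.expired.clear()


lamp = ToyInstance(id=1, status="off")  # stale in-memory copy

load(lamp, ROW)
status_without_expire = lamp.values["status"]

expire(lamp, ["status"])
status_right_after_expire = lamp.values["status"]  # expire alone reads nothing
load(lamp, ROW)
status_after_reload = lamp.values["status"]

print(status_without_expire, status_right_after_expire, status_after_reload)
```

In this model, a query without a preceding expire leaves the old value in place, and expire by itself changes nothing until the next read, which is why the test still has to query the object after expiring it.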
Below is the sequence of steps in the code when using expire.
Below is the test code with expire.
@pytest.mark.asyncio
async def test_lamp_on(test_session):
    """ Test lamp switch on """
    lamp = Lamp()
    test_session.add(lamp)
    await test_session.commit()
    await test_session.refresh(lamp)
    logging.error("New client status: %s", lamp.status)
    assert lamp.status == "off"

    async with AsyncClient(app=app, base_url="http://testserver") as async_client:
        response = await async_client.post(f"/lamps/{lamp.id}/on")
        assert response.status_code == 200

    test_session.expire_all()
    # OR:
    # test_session.expire(lamp)

    # The literal id is used here: once expired, reading lamp.id would itself need a reload
    results = await test_session.execute(select(Lamp).where(Lamp.id == 1))
    new_lamp = results.scalar_one_or_none()
    logging.error("Updated status: %s", new_lamp.status)
    assert new_lamp.status == "on"
Close
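The last approach, closing the session, has a similar effect. Strictly speaking, close does not expire the instances but removes them from the session (it expunges all objects), and the session can still be used afterward. When we then read the data again, we get up-to-date objects.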
await test_session.close()
This should be called immediately after the request for the application is completed and before the checks begin.
Below are the steps in the code when using close.
Below is the test code with session closure.
@pytest.mark.asyncio
async def test_lamp_on(test_session):
    """ Test lamp switch on """
    lamp = Lamp()
    test_session.add(lamp)
    await test_session.commit()
    await test_session.refresh(lamp)
    logging.error("New client status: %s", lamp.status)
    assert lamp.status == "off"

    async with AsyncClient(app=app, base_url="http://testserver") as async_client:
        response = await async_client.post(f"/lamps/{lamp.id}/on")
        assert response.status_code == 200

    await test_session.close()

    results = await test_session.execute(select(Lamp).where(Lamp.id == 1))
    new_lamp = results.scalar_one_or_none()
    logging.error("Updated status: %s", new_lamp.status)
    assert new_lamp.status == "on"
Calling rollback() will help as well. It also calls expire_all, but it explicitly rolls back the transaction. If the transaction needs to be committed instead, commit() also executes expire_all, provided that expire_on_commit is left at its default of True (the sessions in this example set it to False). But in this example, neither rollback nor commit is relevant, since the transaction in the test has already been completed, and the transaction in the application does not affect the session from the test.
In fact, this behavior only shows up in SQLAlchemy ORM in async mode within transactions. Still, it seems illogical that an explicit query for a new object returns the cached instance rather than the data just received from the database. This is a bit confusing when debugging the code. But when used correctly, this is how it should be.
Conclusion
When working in async mode with SQLAlchemy ORM, you have to keep track of transactions and threads in parallel sessions. If all this seems too difficult, use SQLAlchemy ORM in synchronous mode. Everything is much simpler there.