Skip to content

Commit

Permalink
fix: Make litestar example work again, and implement tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sherbang committed Mar 27, 2024
1 parent c3dba02 commit 2572d66
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 82 deletions.
159 changes: 77 additions & 82 deletions examples/litestar.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,66 @@
from __future__ import annotations

from datetime import date, datetime
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List, Union
from uuid import UUID

from litestar import Litestar
from litestar.controller import Controller
from litestar.di import Provide
from litestar.exceptions import NotFoundException as LiteStarNotFoundException
from litestar.handlers.http_handlers.decorators import delete, get, patch, post
from litestar.pagination import OffsetPagination
from litestar.params import Parameter
from pydantic import BaseModel as _BaseModel
from pydantic import TypeAdapter
from sqlalchemy import ForeignKey, select
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload
from typing_extensions import Annotated

from advanced_alchemy.base import UUIDAuditBase, UUIDBase
from advanced_alchemy.config import AsyncSessionConfig
from advanced_alchemy.exceptions import NotFoundError as AdvancedAlchemyNotFoundError
from advanced_alchemy.extensions.litestar.dto import SQLAlchemyDTO, SQLAlchemyDTOConfig
from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyAsyncConfig, SQLAlchemyPlugin
from advanced_alchemy.filters import LimitOffset
from advanced_alchemy.repository import SQLAlchemyAsyncRepository

if TYPE_CHECKING:
from litestar.dto import DTOData
from sqlalchemy.ext.asyncio import AsyncSession


class BaseModel(_BaseModel):
"""Extend Pydantic's BaseModel to enable ORM mode"""

model_config = {"from_attributes": True}


# the SQLAlchemy base includes a declarative model for you to use in your models.
# The `Base` class includes a `UUID` based primary key (`id`)
class AuthorModel(UUIDBase):
class Author(UUIDBase):
# we can optionally provide the table name instead of auto-generating it
__tablename__ = "author" # type: ignore[assignment]
name: Mapped[str]
dob: Mapped[date | None]
books: Mapped[list[BookModel]] = relationship(back_populates="author", lazy="noload")
dob: Mapped[Union[date, None]] # noqa: UP007 - needed for SQLAlchemy on older python versions
books: Mapped[List[Book]] = relationship(back_populates="author", lazy="noload") # noqa: UP006


# The `AuditBase` class includes the same UUID` based primary key (`id`) and 2
# additional columns: `created` and `updated`. `created` is a timestamp of when the
# record created, and `updated` is the last time the record was modified.
class BookModel(UUIDAuditBase):
class Book(UUIDAuditBase):
__tablename__ = "book" # type: ignore[assignment]
title: Mapped[str]
author_id: Mapped[UUID] = mapped_column(ForeignKey("author.id"))
author: Mapped[AuthorModel] = relationship(lazy="joined", innerjoin=True, viewonly=True)


# we will explicitly define the schema instead of using DTO objects for clarity.

author: Mapped[Author] = relationship(lazy="joined", innerjoin=True, viewonly=True)

class Author(BaseModel):
id: UUID | None
name: str
dob: date | None = None

# DTO objects let us filter certain fields out of our request/response data
# without defining separate models
class AuthorDTO(SQLAlchemyDTO[Author]):
config = SQLAlchemyDTOConfig(exclude={"books"})

class AuthorCreate(BaseModel):
name: str
dob: date | None = None

class AuthorCreateUpdateDTO(SQLAlchemyDTO[Author]):
config = SQLAlchemyDTOConfig(exclude={"id", "books"})

class AuthorUpdate(BaseModel):
name: str | None = None
dob: date | None = None


class AuthorRepository(SQLAlchemyAsyncRepository[AuthorModel]):
class AuthorRepository(SQLAlchemyAsyncRepository[Author]):
"""Author repository."""

model_type = AuthorModel
model_type = Author


async def provide_authors_repo(db_session: AsyncSession) -> AuthorRepository:
Expand All @@ -86,7 +73,7 @@ async def provide_authors_repo(db_session: AsyncSession) -> AuthorRepository:
async def provide_author_details_repo(db_session: AsyncSession) -> AuthorRepository:
"""This provides a simple example demonstrating how to override the join options for the repository."""
return AuthorRepository(
statement=select(AuthorModel).options(selectinload(AuthorModel.books)),
statement=select(Author).options(selectinload(Author.books)),
session=db_session,
)

Expand Down Expand Up @@ -119,68 +106,74 @@ class AuthorController(Controller):

dependencies = {"authors_repo": Provide(provide_authors_repo)}

@get(path="/authors")
@get(path="/authors", return_dto=AuthorDTO)
async def list_authors(
self,
authors_repo: AuthorRepository,
limit_offset: LimitOffset,
) -> OffsetPagination[Author]:
"""List authors."""
results, total = await authors_repo.list_and_count(limit_offset)
type_adapter = TypeAdapter(list[Author])
return OffsetPagination[Author](
items=type_adapter.validate_python(results),
items=results,
total=total,
limit=limit_offset.limit,
offset=limit_offset.offset,
)

@post(path="/authors")
async def create_author(
self,
authors_repo: AuthorRepository,
data: AuthorCreate,
) -> Author:
@post(path="/authors", dto=AuthorCreateUpdateDTO)
async def create_author(self, authors_repo: AuthorRepository, data: DTOData[Author]) -> Author:
"""Create a new author."""
obj = await authors_repo.add(
AuthorModel(**data.model_dump(exclude_unset=True, exclude_none=True)),
)

# Turn the DTO object into an Author instance.
author = data.create_instance()

obj = await authors_repo.add(author)
await authors_repo.session.commit()
return Author.model_validate(obj)
return obj

# we override the authors_repo to use the version that joins the Books in
@get(path="/authors/{author_id:uuid}", dependencies={"authors_repo": Provide(provide_author_details_repo)})
async def get_author(
self,
authors_repo: AuthorRepository,
author_id: UUID = Parameter( # noqa: B008
title="Author ID",
description="The author to retrieve.",
),
author_id: Annotated[
UUID,
Parameter(
title="Author ID",
description="The author to retrieve.",
),
],
) -> Author:
"""Get an existing author."""
obj = await authors_repo.get(author_id)
return Author.model_validate(obj)
try:
return await authors_repo.get(author_id)
except AdvancedAlchemyNotFoundError as e:
msg = f"Author with id {author_id} not found."
raise LiteStarNotFoundException(msg) from e

@patch(
path="/authors/{author_id:uuid}",
dependencies={"authors_repo": Provide(provide_author_details_repo)},
dto=AuthorCreateUpdateDTO,
)
async def update_author(
self,
authors_repo: AuthorRepository,
data: AuthorUpdate,
author_id: UUID = Parameter( # noqa: B008
title="Author ID",
description="The author to update.",
),
data: DTOData[Author],
author_id: Annotated[
UUID,
Parameter(
title="Author ID",
description="The author to update.",
),
],
) -> Author:
"""Update an author."""
raw_obj = data.model_dump(exclude_unset=True, exclude_none=True)
raw_obj.update({"id": author_id})
obj = await authors_repo.update(AuthorModel(**raw_obj))
author = data.create_instance(id=author_id)
obj = await authors_repo.update(author)
await authors_repo.session.commit()
return Author.model_validate(obj)
return obj

@delete(path="/authors/{author_id:uuid}")
async def delete_author(
Expand All @@ -196,24 +189,26 @@ async def delete_author(
await authors_repo.session.commit()


session_config = AsyncSessionConfig(expire_on_commit=False)
sqlalchemy_config = SQLAlchemyAsyncConfig(
connection_string="sqlite+aiosqlite:///test.sqlite",
session_config=session_config,
) # Create 'db_session' dependency.
sqlalchemy_plugin = SQLAlchemyPlugin(config=sqlalchemy_config)


async def on_startup() -> None:
"""Initializes the database."""
async with sqlalchemy_config.get_engine().begin() as conn:
await conn.run_sync(UUIDBase.metadata.create_all)


app = Litestar(
route_handlers=[AuthorController],
on_startup=[on_startup],
plugins=[sqlalchemy_plugin],
dependencies={"limit_offset": Provide(provide_limit_offset_pagination, sync_to_thread=False)},
signature_namespace={"date": date, "datetime": datetime, "UUID": UUID},
)
def init_app(*, sqlalchemy_config: SQLAlchemyAsyncConfig | None = None) -> Litestar:
if not sqlalchemy_config:
# expire_on_commit=False prevents the sqlalchemy models from being invalidated on commit.
session_config = AsyncSessionConfig(expire_on_commit=False)
sqlalchemy_config = SQLAlchemyAsyncConfig(
connection_string="sqlite+aiosqlite:///test.sqlite",
session_config=session_config,
) # Create 'db_session' dependency.

sqlalchemy_plugin = SQLAlchemyPlugin(config=sqlalchemy_config)

async def on_startup() -> None:
"""Initializes the database."""
async with sqlalchemy_config.get_engine().begin() as conn:
await conn.run_sync(UUIDBase.metadata.create_all)

return Litestar(
route_handlers=[AuthorController],
on_startup=[on_startup],
plugins=[sqlalchemy_plugin],
dependencies={"limit_offset": Provide(provide_limit_offset_pagination, sync_to_thread=False)},
signature_namespace={"date": date, "datetime": datetime, "UUID": UUID},
)
Empty file added tests/examples/__init__.py
Empty file.
81 changes: 81 additions & 0 deletions tests/examples/test_litestar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

from collections.abc import AsyncIterator
from typing import TYPE_CHECKING

import pytest
from litestar.testing import AsyncTestClient

from advanced_alchemy.base import UUIDBase
from advanced_alchemy.config import AsyncSessionConfig
from advanced_alchemy.extensions.litestar.plugins import SQLAlchemyAsyncConfig
from examples.litestar import Author, init_app

if TYPE_CHECKING:
from litestar import Litestar


@pytest.fixture()
async def test_client() -> AsyncIterator[AsyncTestClient[Litestar]]:
# Use an in-memory database for testing and create the tables.
engine = SQLAlchemyAsyncConfig.create_engine_callable("sqlite+aiosqlite:///:memory:")
async with engine.begin() as conn:
await conn.run_sync(UUIDBase.metadata.create_all)

sqlalchemy_config = SQLAlchemyAsyncConfig(
# Use the same session instance for all requests so the database doesn't disappear
engine_instance=engine,
session_config=AsyncSessionConfig(expire_on_commit=False),
)

app = init_app(sqlalchemy_config=sqlalchemy_config)
app.debug = True

async with AsyncTestClient(app=app) as client:
yield client


async def test_create_list(test_client: AsyncTestClient[Litestar]) -> None:
author = Author(name="foo")

response = await test_client.post(
"/authors",
json=author.to_dict(),
)
assert response.status_code == 201, response.text
assert response.json()["name"] == author.name

response = await test_client.get("/authors")
assert response.status_code == 200, response.text
assert response.json()["items"][0]["name"] == author.name


async def test_create_get_update_delete(test_client: AsyncTestClient[Litestar]) -> None:
author = Author(name="foo")

response = await test_client.post(
"/authors",
json=author.to_dict(),
)
assert response.status_code == 201, response.text
assert response.json()["name"] == author.name
author_id = response.json()["id"]

response = await test_client.get(f"/authors/{author_id}")
assert response.status_code == 200, response.text
assert response.json()["name"] == author.name
assert response.json()["id"] == author_id

response = await test_client.patch(
f"/authors/{author_id}",
json={"name": "bar"},
)
assert response.status_code == 200, response.text
assert response.json()["name"] == "bar"
assert response.json()["id"] == author_id

response = await test_client.delete(f"/authors/{author_id}")
assert response.status_code == 204, response.text

response = await test_client.get(f"/authors/{author_id}")
assert response.status_code == 404, response.text

0 comments on commit 2572d66

Please sign in to comment.