Skip to content

Commit

Permalink
🐛 fix(sqla): Mapped[Model]
Browse files Browse the repository at this point in the history
  • Loading branch information
ProgramRipper committed Oct 19, 2023
1 parent ffcb953 commit 4f1a235
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 22 deletions.
7 changes: 3 additions & 4 deletions nonebot_plugin_orm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from . import migrate
from .model import Model
from .config import Config, plugin_config
from .config import Config
from .utils import LoguruHandler, StreamToLogger

if sys.version_info >= (3, 10):
Expand Down Expand Up @@ -221,9 +220,9 @@ def _init_logger():
l.setLevel(level)


_init_logger()

from .sql import *
from .model import *
from .config import *
from .migrate import *

_init_logger()
59 changes: 41 additions & 18 deletions nonebot_plugin_orm/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import sys
from inspect import Parameter, Signature
from typing import TYPE_CHECKING, Any, ClassVar
from typing_extensions import Self, Unpack, Annotated

from nonebot.params import Depends
from nonebot import get_plugin_by_module_name
from sqlalchemy import Table, MetaData, select
from pydantic.typing import get_args, get_origin
from sqlalchemy.orm import Mapped, DeclarativeBase
from sqlalchemy.orm.decl_api import DeclarativeAttributeIntercept

from .utils import DependsInner, get_annotations

Expand All @@ -30,13 +30,47 @@
}


class Model(DeclarativeBase):
class ModelMeta(DeclarativeAttributeIntercept):
if TYPE_CHECKING:
__signature__: Signature

def __new__(
mcs,
name: str,
bases: tuple[type, ...],
namespace: dict[str, Any],
**kwargs: Any,
) -> ModelMeta:
from . import async_scoped_session

cls: ModelMeta = super().__new__(mcs, name, bases, namespace, **kwargs)

if not (signature := getattr(cls, "__signature__", None)):
return cls

async def dependency(
*, __session: async_scoped_session, **kwargs: Any
) -> ModelMeta | None:
return await __session.scalar(select(cls).filter_by(**kwargs))

dependency.__signature__ = Signature(
(
Parameter(
"_ModelMeta__session",
Parameter.KEYWORD_ONLY,
annotation=async_scoped_session,
),
*signature.parameters.values(),
)
)

return Annotated[cls, Depends(dependency)]


class Model(DeclarativeBase, metaclass=ModelMeta):
metadata = MetaData(naming_convention=NAMING_CONVENTION)

if TYPE_CHECKING:
__args__: ClassVar[tuple[type[Self], Unpack[tuple[Any, ...]]]]
__origin__: type[Annotated]

__table__: ClassVar[Table]
__bind_key__: ClassVar[str]

Expand All @@ -58,11 +92,7 @@ def _setup_di(cls: type[Model]) -> None:
"""
from . import async_scoped_session

parameters: list[Parameter] = [
Parameter(
"__session__", Parameter.KEYWORD_ONLY, annotation=async_scoped_session
)
]
parameters: list[Parameter] = []

annotations: dict[str, Any] = {}
for base in reversed(cls.__mro__):
Expand Down Expand Up @@ -101,14 +131,7 @@ def _setup_di(cls: type[Model]) -> None:
if default is not Signature.empty and not isinstance(default, Mapped):
delattr(cls, name)

async def dependency(
*, __session__: async_scoped_session, **kwargs: Any
) -> Model | None:
return await __session__.scalar(select(cls).filter_by(**kwargs))

dependency.__signature__ = Signature(parameters)
cls.__args__ = (Model, Depends(dependency))
cls.__origin__ = Annotated
cls.__signature__ = Signature(parameters)


def _setup_tablename(cls: type[Model]) -> None:
Expand Down

0 comments on commit 4f1a235

Please sign in to comment.