Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support “Reuse[Dep1] | Dep2” #30

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 47 additions & 14 deletions andi/andi.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from collections import OrderedDict, defaultdict
from dataclasses import dataclass
from typing import (
Dict, List, Optional, Type, Callable, Union, Container,
Tuple, MutableMapping, Any, Mapping)
Annotated, Dict, List, Optional, Type, TypeVar, Callable, Union, Container,
Tuple, MutableMapping, Any, Mapping, get_args)

from andi.typeutils import (
get_union_args,
Expand All @@ -11,6 +11,7 @@
get_unannotated_params,
get_callable_func_obj,
get_type_hints_with_extras,
is_typing_annotated,
strip_annotated,
)
from andi.errors import (
Expand All @@ -21,6 +22,11 @@
)


_T = TypeVar("T")
_REUSE_ANNOTATION = object()
Reuse = Annotated[_T, _REUSE_ANNOTATION]


def inspect(class_or_func: Callable) -> Dict[str, List[Optional[Type]]]:
"""
For each argument of the ``class_or_func`` return a list of possible types.
Expand Down Expand Up @@ -308,15 +314,26 @@ class we want to override. If ``recursive_overrides`` is True, then
overrides = overrides or _empty_overrides
class_or_func, overrides = _may_override(class_or_func, overrides, recursive_overrides)

plan, _ = _plan(class_or_func,
is_injectable=is_injectable,
externally_provided=externally_provided,
full_final_kwargs=full_final_kwargs,
dependency_stack=None,
overrides=overrides,
recursive_overrides=recursive_overrides,
custom_builder_fn=custom_builder_fn,
)
plan_deps = set()
while not plan_deps or plan_deps != last_plan_deps:
last_plan_deps = plan_deps
plan, _ = _plan(class_or_func,
is_injectable=is_injectable,
externally_provided=externally_provided,
full_final_kwargs=full_final_kwargs,
dependency_stack=None,
overrides=overrides,
recursive_overrides=recursive_overrides,
custom_builder_fn=custom_builder_fn,
last_plan_deps=last_plan_deps,
)
plan_deps = {item[0] for item in plan or []}

# TODO: Remove logging here.
from logging import getLogger
logger = getLogger(__name__)
logger.error(plan_deps)

return plan


Expand All @@ -334,7 +351,8 @@ def _plan(class_or_func: Callable, *,
overrides: Callable[[Callable], Optional[Callable]],
recursive_overrides: bool = False,
custom_builder_fn: Callable[[Callable], Optional[Callable]] = lambda _: None,
custom_builder_result: Optional[Callable] = None
custom_builder_result: Optional[Callable] = None,
last_plan_deps: Optional[Plan] = None,
) -> Tuple[Plan, List[Tuple]]:
dependency_stack = dependency_stack or []
is_root_call = not dependency_stack # For better code reading
Expand Down Expand Up @@ -363,7 +381,7 @@ def _plan(class_or_func: Callable, *,
for argname, types in arguments.items():
sel_cls, arg_overrides = _select_type(
types, is_injectable, externally_provided, overrides, recursive_overrides,
custom_builder_fn
custom_builder_fn, last_plan_deps
)
if sel_cls is not None:
errors = [] # type: List[Tuple]
Expand Down Expand Up @@ -432,6 +450,7 @@ def _select_type(types,
overrides: Callable[[Callable], Optional[Callable]],
recursive_overrides: bool,
custom_builder_fn: Callable[[Callable], Optional[Callable]] = lambda _: None,
last_plan_deps: Plan = None,
) -> Tuple[Optional[Callable], OverrideFn]:
"""
Choose the first type that can be provided. None otherwise. Also return
Expand All @@ -440,12 +459,26 @@ def _select_type(types,
for candidate in types:
candidate, new_overrides = _may_override(
candidate, overrides, recursive_overrides)
candidate_stripped = strip_annotated(candidate)

if is_typing_annotated(candidate):
candidate_stripped, annotation1, *_ = get_args(candidate)
if annotation1 == _REUSE_ANNOTATION:
if not last_plan_deps or candidate_stripped not in last_plan_deps:
continue
candidate = candidate_stripped
else:
candidate_stripped = candidate
if (
is_injectable(candidate_stripped)
or externally_provided(candidate_stripped)
or custom_builder_fn(candidate_stripped) is not None
):

# TODO: Remove logging here.
from logging import getLogger
logger = getLogger(__name__)
logger.error(f"{types} → {candidate}")

return candidate, new_overrides
return None, overrides

Expand Down
Loading
Loading