Skip to content

Commit

Permalink
feat: _mode_ keyword added (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
bagxi authored Oct 20, 2021
1 parent 79db2e6 commit af87664
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 22 deletions.
21 changes: 8 additions & 13 deletions docs/pages/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -252,16 +252,15 @@ Creating ``pd.DataFrame`` from config
# By default, hydra-slayer use partial fit for functions
# (what is useful with activation functions in neural networks).
# But if we want to call ``pandas.read_csv`` function instead,
# then we should pass ``call_meta_factory`` manually.
meta_factory: &call_function
_target_: hydra_slayer.call_meta_factory
# then we should set ``call`` mode manually.
_mode_: call
right:
_target_: pandas.read_csv
filepath_or_buffer: dataset/dataset_part2.csv
meta_factory: *call_function
_mode_: call
how: inner
'on': user
meta_factory: *call_function
_mode_: call
.. code-block:: python
Expand Down Expand Up @@ -319,11 +318,10 @@ Extending configs
# config.yaml
dataset:
_target_: hydra_slayer.get_from_params
# ``yaml.safe_load`` will return dictionary with parameters,
# but to get ``DataLoader`` additional ``hydra_slayer.get_from_params``
# should be used.
_target_: hydra_slayer.get_from_params
kwargs:
# Read dataset from "dataset.yaml", roughly equivalent to
# with open("dataset.yaml") as stream:
Expand All @@ -332,16 +330,13 @@ Extending configs
stream:
_target_: open
file: dataset.yaml
meta_factory: &call_function
_target_: hydra_slayer.call_meta_factory
meta_factory: *call_function
_mode_: call
_mode_: call
model:
_target_: torchvision.models.resnet18
pretrained: true
meta_factory:
_target_: hydra_slayer.call_meta_factory
_mode_: call
criterion:
_target_: torch.nn.CrossEntropyLoss
Expand Down
53 changes: 48 additions & 5 deletions hydra_slayer/factory.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Callable, Mapping, Tuple, Type, Union
import copy
import functools
import inspect

Expand All @@ -7,6 +8,8 @@
Factory = Union[Type, Callable[..., Any]]
MetaFactory = Callable[[Factory, Tuple, Mapping], Any]

DEFAULT_CALL_MODE_KEY = "_mode_"


def call_meta_factory(factory: Factory, args: Tuple, kwargs: Mapping):
"""Creates a new instance from ``factory``.
Expand Down Expand Up @@ -41,10 +44,24 @@ def partial_meta_factory(factory: Factory, args: Tuple, kwargs: Mapping):


def default_meta_factory(factory: Factory, args: Tuple, kwargs: Mapping):
"""
Creates a new instance from ``factory`` if ``factory`` is class
(like :py:func:`call_meta_factory`), else returns a new partial object
(like :py:func:`partial_meta_factory`).
"""Returns a new instance or a new partial object.
* _mode_='auto'
Creates a new instance from ``factory`` if ``factory`` is class
(like :py:func:`call_meta_factory`), else returns a new partial object
(like :py:func:`partial_meta_factory`).
* _mode_='call'
Returns a result of the factory called with the positional arguments
``args`` and keyword arguments ``kwargs``.
* _mode_='partial'
Returns a new partial object which when called will behave like factory
called with the positional arguments ``args`` and keyword arguments
``kwargs``.
Args:
factory: factory to create instance from
Expand All @@ -54,7 +71,33 @@ def default_meta_factory(factory: Factory, args: Tuple, kwargs: Mapping):
Returns:
Instance.
Raises:
ValueError: if mode not in list: ``'auto'``, ``'call'``, ``'partial'``.
Examples:
>>> default_meta_factory(int, (42,))
42
>>> # please note that additional () are used
>>> default_meta_factory(lambda x: x, (42,))()
42
>>> default_meta_factory(int, ('42',), {"base": 16})
66
>>> # please note that additional () are not needed
>>> default_meta_factory(lambda x: x, (42,), {"_mode_": "call"})
42
>>> default_meta_factory(lambda x: x, ('42',), {"_mode_": "partial", "base": 16})()
66
"""
if inspect.isfunction(factory):
# make a copy of kwargs since we don't want to modify them directly
kwargs = copy.copy(kwargs)
mode = kwargs.pop(DEFAULT_CALL_MODE_KEY, "auto")
if mode not in {"auto", "call", "partial"}:
raise ValueError(f"`{mode}` is not a valid call mode")

if mode == "auto" and inspect.isfunction(factory) or mode == "partial":
return partial_meta_factory(factory, args, kwargs)
return call_meta_factory(factory, args, kwargs)
6 changes: 4 additions & 2 deletions hydra_slayer/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,12 @@ def get_from_params(*, shared_params: Optional[Dict[str, Any]] = None, **kwargs)
Creates instance based in configuration dict with ``instantiation_fn``.
Note:
The name of the factory to use should be provided by ``'_target_'`` keyword.
The name of the factory to use should be provided
by ``'_target_'`` keyword.
Args:
shared_params: params to pass on all levels in case of recursive creation
shared_params: params to pass on all levels in case of
recursive creation
**kwargs: named parameters for factory
Returns:
Expand Down
6 changes: 4 additions & 2 deletions hydra_slayer/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def add_from_module(
module: module to scan
prefix: prefix string for all the module's factories.
If prefix is a list, all values will be treated as aliases
ignore_all: if ``True``, ignores ``__all__`` attribute of the module
ignore_all: if ``True``, ignores ``__all__`` attribute
of the module
Raises:
TypeError: if prefix is not a list or a string
Expand Down Expand Up @@ -203,7 +204,8 @@ def get_from_params(
If ``config[name_key]`` is None, ``None`` is returned.
Args:
shared_params: params to pass on all levels in case of recursive creation
shared_params: params to pass on all levels in case of
recursive creation
**kwargs: \*\*kwargs to pass to the factory
Returns:
Expand Down
32 changes: 32 additions & 0 deletions tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,35 @@ def test_default_meta_factory():
res = default_meta_factory(lambda x: x, (42,), {})

assert res() == 42


def test_default_meta_factory_mode():
# `int` is class, so `call_meta_factory` is expected
res = default_meta_factory(int, (42,), {"_mode_": "auto"})

assert res == 42

# `lambda` is function, so `partial_meta_factory` is expected
res = default_meta_factory(lambda x: x, (42,), {"_mode_": "auto"})

assert res() == 42

# _mode_='call', so `call_meta_factory` is expected
res = default_meta_factory(int, (42,), {"_mode_": "call"})

assert res == 42

# _mode_='call', so `call_meta_factory` is expected
res = default_meta_factory(lambda x: x, (42,), {"_mode_": "call"})

assert res == 42

# _mode_='partial', so `partial_meta_factory` is expected
res = default_meta_factory(int, (42,), {"_mode_": "partial"})

assert res() == 42

# _mode_='partial', so `partial_meta_factory` is expected
res = default_meta_factory(lambda x: x, (42,), {"_mode_": "partial"})

assert res() == 42

0 comments on commit af87664

Please sign in to comment.