diff --git a/hydra_slayer/functional.py b/hydra_slayer/functional.py index 3f220bc..5e16c20 100644 --- a/hydra_slayer/functional.py +++ b/hydra_slayer/functional.py @@ -105,7 +105,7 @@ def _get_instance( # assume that name of the factory can be provided as first argument # or directly by keyword - name, args, kwarg = _extract_factory_name_arg( + name, args, kwargs = _extract_factory_name_arg( factory_key=factory_key, args=args, kwargs=kwargs ) if name is None: diff --git a/hydra_slayer/registry.py b/hydra_slayer/registry.py index f9c390d..2bfa3a1 100644 --- a/hydra_slayer/registry.py +++ b/hydra_slayer/registry.py @@ -43,7 +43,7 @@ def _get_factory_name(f, provided_name: str = None) -> str: if not provided_name: provided_name = getattr(f, "__name__", None) if not provided_name: - raise ValueError(f"Factory {f} has no '__name__' and no name was provided") + raise ValueError(f"Factory '{f}' has no '__name__' and no name was provided") if provided_name == "": raise ValueError("Name for lambda factories must be provided") return provided_name @@ -153,9 +153,9 @@ def add_from_module( prefix = [prefix] elif isinstance(prefix, list): if any((not isinstance(p, str)) for p in prefix): - raise TypeError("All prefix in list must be strings.") + raise TypeError("All prefix in list must be strings") else: - raise TypeError(f"Prefix must be a list or a string, got {type(prefix)}.") + raise TypeError(f"Prefix must be a list or a string, got {type(prefix)}") to_add = {f"{p}{name}": factories[name] for p in prefix for name in names_to_add} self.add(**to_add) diff --git a/tests/foobar.py b/tests/foobar.py index 455769e..8f2e4ed 100644 --- a/tests/foobar.py +++ b/tests/foobar.py @@ -1,4 +1,6 @@ # flake8: noqa +from typing import Any + __all__ = ["foo"] @@ -27,8 +29,11 @@ def quuz(**params): class grault: - def __init__(self, a=1, b=2): + def __init__(self, a: Any = 1, b: int = 2): self.a = a + + if not isinstance(b, int): + raise ValueError self.b = b @staticmethod @@ -37,3 +42,12 @@ def garply(a, b): def waldo(self): return {"a": self.a, "b": self.b} + + +class fred: + def __init__(self, a): + self.a = a + + @classmethod + def get_from_params(cls, a): + return cls(a) diff --git a/tests/test_factory.py b/tests/test_factory.py index 73325dd..1b03c9b 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -1,9 +1,8 @@ # flake8: noqa -import string - import pytest from hydra_slayer import factory +from . import foobar def test_call_meta_factory(): @@ -43,9 +42,9 @@ def test_default_meta_factory(): def test_fail_get_factory(): - with pytest.raises(ValueError) as e_ifo: + error_msg = "factory '.+' is not callable" + with pytest.raises(ValueError, match=error_msg): factory.default_meta_factory(5, tuple(), {}) - assert hasattr(e_ifo.value, "__cause__") def test_metafactory_factory_meta_factory_arg(): @@ -100,3 +99,15 @@ def test_metafactory_factory_modes(): res = factory.metafactory_factory(lambda x: x, (42,), {"_mode_": "partial"}) assert res() == 42 + + +def test_fail_metafactory_factory_modes(): + error_msg = "'.+' is not a valid call mode" + with pytest.raises(ValueError, match=error_msg): + factory.metafactory_factory(int, (42,), {"_mode_": "foo"}) + + +def test_metafactory_from_params(): + res = factory.metafactory_factory(foobar.fred, (), {"a": 42}) + + assert res.a == 42 diff --git a/tests/test_functional.py b/tests/test_functional.py index 9d4396d..93f9a00 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,6 +1,4 @@ # flake8: noqa -import string - import pytest from hydra_slayer import functional as F @@ -25,9 +23,9 @@ def test_get_factories(): def test_fail_get_factory(): - with pytest.raises(LookupError) as e_ifo: + error_msg = "No factory with name '.+' was registered" + with pytest.raises(LookupError, match=error_msg): F.get_factory("tests.foobar.corge")() - assert hasattr(e_ifo.value, "__cause__") def test_instantiations(): @@ -48,17 +46,25 @@ def test_instantiations(): def test_fail_instantiation(): - with pytest.raises(LookupError) as e_ifo: + error_msg = "No factory with name '.+' was registered" + with pytest.raises(LookupError, match=error_msg): F.get_instance("tests.foobar.corge")() - assert hasattr(e_ifo.value, "__cause__") - with pytest.raises(TypeError) as e_ifo: + error_msg = r"get_instance\(\) missing at least 1 required argument: '.+'" + with pytest.raises(TypeError, match=error_msg): F.get_instance(a=1, b=2)() - assert hasattr(e_ifo.value, "__cause__") - with pytest.raises(TypeError) as e_ifo: + error_msg = ".+ got an unexpected keyword argument '.+'" + with pytest.raises(TypeError, match=error_msg): F.get_instance("tests.foobar.foo", c=1)() - assert hasattr(e_ifo.value, "__cause__") + + error_msg = "Factory '.+' call failed: args=.+ kwargs=.+" + with pytest.raises(RuntimeError, match=error_msg): + F.get_instance("tests.foobar.grault", b=1.0)() + + warn_msg = r"No signature found for `.+`, \*args and \*\*kwargs arguments cannot be extracted" + with pytest.warns(UserWarning, match=warn_msg): + F.get_instance("int", 1) def test_from_params(): @@ -358,3 +364,19 @@ def test_get_from_params_var_method_with_params(): }, ) assert res["b"] == {"a": 1, "b": 2} + + +def test_fail_get_from_params_on_exclusive_keywords(): + error_msg = "`.+` and `.+` \(in get mode\) keywords are exclusive" + with pytest.raises(ValueError, match=error_msg): + F.get_from_params( + **{ + "_target_": "tests.foobar.foo", + "a": [ + {"_target_": "tests.foobar.foo", "a": 1, "b": 2, "_var_": "x"}, + {"_target_": "tests.foobar.foo", "a": 3, "b": 4, "_var_": "x"}, + ], + "b": 5, + }, + shared_params={"_meta_factory_": call_meta_factory}, + ) diff --git a/tests/test_registry.py b/tests/test_registry.py index b394ef1..fb15816 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -2,7 +2,7 @@ import pytest from hydra_slayer.registry import Registry -from .foobar import foo +from .foobar import bar, foo from . import foobar as module @@ -22,13 +22,24 @@ def test_add_function_name_override(): assert "bar" in r._factories -def test_add_lambda_fail(): +def test_add_fail_on_lambda(): r = Registry() - with pytest.raises(ValueError): + error_msg = "Name for lambda factories must be provided" + with pytest.raises(ValueError, match=error_msg): r.add(lambda x: x) +def test_add_fail_on_no_name(): + r = Registry() + + obj = 42 + + error_msg = "Factory '.+' has no '__name__' and no name was provided" + with pytest.raises(ValueError, match=error_msg): + r.add(obj, name=None) + + def test_add_lambda_override(): r = Registry() @@ -40,31 +51,75 @@ def test_add_lambda_override(): def test_fail_multiple_with_name(): r = Registry() - with pytest.raises(ValueError): + error_msg = "Multiple factories with single name are not allowed" + with pytest.raises(ValueError, match=error_msg): r.add(foo, foo, name="bar") def test_fail_double_add_different(): r = Registry() - r.add(foo) - with pytest.raises(LookupError): - - def bar(): - pass + r.add(foo) + error_msg = "Factory with name '.+' is already present\nAlready registered: '.+'\nNew: '.+'" + with pytest.raises(LookupError, match=error_msg): r.add(foo=bar) def test_double_add_same_nofail(): r = Registry() + r.add(foo) + # It's ok to add same twice, forced by python relative import # implementation # https://github.com/catalyst-team/catalyst/issues/135 r.add(foo) +def test_add_args_support(): + r = Registry() + + r.add(foo, bar) + + assert "foo" in r._factories and "bar" in r._factories + + +def test_add_kwargs_support(): + r = Registry() + + r.add(foo=foo) + + assert "foo" in r._factories + + +def test_add_warns_on_empty_kwargs(): + r = Registry() + + warn_msg = "No factories were provided!" + with pytest.warns(UserWarning, match=warn_msg): + r.add(**{}) + + +def test_get_empty(): + r = Registry() + + res = r.get(None) + assert res is None + + +def test_get_if_str(): + r = Registry() + + r.add(foo=foo) + + res = r.get_if_str("foo") + assert res == foo + + res = r.get_if_str(42) + assert res == 42 + + def test_instantiations(): r = Registry() @@ -86,13 +141,12 @@ def test_instantiations(): def test_fail_instantiation(): r = Registry() - r.add(foo) + assert r.add(foo) is not None - with pytest.raises((RuntimeError, TypeError)) as e_ifo: + error_msg = ".+ got an unexpected keyword argument '.+'" + with pytest.raises((RuntimeError, TypeError), match=error_msg): r.get_instance("foo", c=1)() - assert hasattr(e_ifo.value, "__cause__") - def test_decorator(): r = Registry() @@ -101,7 +155,7 @@ def test_decorator(): def bar(): pass - r.get("bar") + assert r.get("bar") is not None def test_kwargs(): @@ -109,7 +163,20 @@ def test_kwargs(): r.add(bar=foo) - r.get("bar") + assert r.get("bar") is not None + + +def test_late_add(): + def callback(registry: Registry) -> None: + registry.add(foo) + + r = Registry() + + r.late_add(callback) + + assert r._factories == {} + + assert r.all() == ("foo",) def test_add_module(): @@ -117,12 +184,45 @@ def test_add_module(): r.add_from_module(module) - r.get("foo") + assert r.get("foo") is not None - with pytest.raises(LookupError): + error_msg = "No factory with name '.+' was registered" + with pytest.raises(LookupError, match=error_msg): r.get_instance("bar") +def test_add_module_adds_all(): + r = Registry() + + r.add_from_module(module, ignore_all=True) + + assert "foo" in r._factories and "bar" in r._factories + + +def test_add_module_prefix_support(): + r = Registry() + + r.add_from_module(module, prefix="m.") + + r.get("m.foo") + + error_msg = "No factory with name '.+' was registered" + with pytest.raises(LookupError, match=error_msg): + r.get_instance("foo") + + +def test_add_from_module_fails_on_invalid_prefix(): + r = Registry() + + error_msg = "All prefix in list must be strings" + with pytest.raises(TypeError, match=error_msg): + r.add_from_module(module, prefix=["42", 42]) + + error_msg = "Prefix must be a list or a string, got .+" + with pytest.raises(TypeError, match=error_msg): + r.add_from_module(module, prefix=42) + + def test_from_config(): r = Registry() @@ -462,3 +562,102 @@ def test_get_from_params_vars_dict(): }, ) assert res["b"] == 4 + + +def test_all_magic_method(): + r = Registry() + + r.add(foo) + + res = r.all() + assert res == ("foo",) + + r.add(bar) + + res = r.all() + assert res == ("foo", "bar") + + +def test_str_magic_method(): + r = Registry() + + r.add(foo) + + res = r.__str__() + assert res == "('foo',)" + + r.add(bar) + + res = r.__str__() + assert res == "('foo', 'bar')" + + +def test_repr_magic_method(): + r = Registry() + + r.add(foo) + + res = r.__repr__() + assert res == "('foo',)" + + r.add(bar) + + res = r.__repr__() + assert res == "('foo', 'bar')" + + +def test_len_magic_method(): + r = Registry() + + r.add(foo) + + res = len(r) + assert res == 1 + + r.add(bar) + + res = len(r) + assert res == 2 + + +def test_getitem_magic_method(): + r = Registry() + + r.add(foo) + + res = r["foo"] + assert res == foo + + +def test_iter_magic_method(): + r = Registry() + + r.add(foo) + + res = next(iter(r)) + assert res == "foo" + + +def test_contains_magic_method(): + r = Registry() + + r.add(foo) + + assert "foo" in r + + +def test_setitem_magic_method(): + r = Registry() + + r["bar"] = foo + + assert "bar" in r._factories + + +def test_delitem_magic_method(): + r = Registry() + + r.add(foo) + + del r["foo"] + assert r._factories == {}