Skip to content

Commit

Permalink
test: 100% coverage (#45)
Browse files Browse the repository at this point in the history
* test: additional tests for functional.py added

* test: additional tests for the factory.py added

* test: additional tests for test_registry.py added

* fix: extraction of kwargs from factory name is now working
  • Loading branch information
bagxi committed Jan 4, 2024
1 parent e599f11 commit 41c2153
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 36 deletions.
2 changes: 1 addition & 1 deletion hydra_slayer/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _get_instance(

# assume that name of the factory can be provided as first argument
# or directly by keyword
name, args, kwarg = _extract_factory_name_arg(
name, args, kwargs = _extract_factory_name_arg(
factory_key=factory_key, args=args, kwargs=kwargs
)
if name is None:
Expand Down
6 changes: 3 additions & 3 deletions hydra_slayer/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _get_factory_name(f, provided_name: str = None) -> str:
if not provided_name:
provided_name = getattr(f, "__name__", None)
if not provided_name:
raise ValueError(f"Factory {f} has no '__name__' and no name was provided")
raise ValueError(f"Factory '{f}' has no '__name__' and no name was provided")
if provided_name == "<lambda>":
raise ValueError("Name for lambda factories must be provided")
return provided_name
Expand Down Expand Up @@ -153,9 +153,9 @@ def add_from_module(
prefix = [prefix]
elif isinstance(prefix, list):
if any((not isinstance(p, str)) for p in prefix):
raise TypeError("All prefix in list must be strings.")
raise TypeError("All prefix in list must be strings")
else:
raise TypeError(f"Prefix must be a list or a string, got {type(prefix)}.")
raise TypeError(f"Prefix must be a list or a string, got {type(prefix)}")

to_add = {f"{p}{name}": factories[name] for p in prefix for name in names_to_add}
self.add(**to_add)
Expand Down
16 changes: 15 additions & 1 deletion tests/foobar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# flake8: noqa
from typing import Any

__all__ = ["foo"]


Expand Down Expand Up @@ -27,8 +29,11 @@ def quuz(**params):


class grault:
def __init__(self, a=1, b=2):
def __init__(self, a: Any = 1, b: int = 2):
self.a = a

if not isinstance(b, int):
raise ValueError
self.b = b

@staticmethod
Expand All @@ -37,3 +42,12 @@ def garply(a, b):

def waldo(self):
return {"a": self.a, "b": self.b}


class fred:
def __init__(self, a):
self.a = a

@classmethod
def get_from_params(cls, a):
return cls(a)
19 changes: 15 additions & 4 deletions tests/test_factory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# flake8: noqa
import string

import pytest

from hydra_slayer import factory
from . import foobar


def test_call_meta_factory():
Expand Down Expand Up @@ -43,9 +42,9 @@ def test_default_meta_factory():


def test_fail_get_factory():
with pytest.raises(ValueError) as e_ifo:
error_msg = "factory '.+' is not callable"
with pytest.raises(ValueError, match=error_msg):
factory.default_meta_factory(5, tuple(), {})
assert hasattr(e_ifo.value, "__cause__")


def test_metafactory_factory_meta_factory_arg():
Expand Down Expand Up @@ -100,3 +99,15 @@ def test_metafactory_factory_modes():
res = factory.metafactory_factory(lambda x: x, (42,), {"_mode_": "partial"})

assert res() == 42


def test_fail_metafactory_factory_modes():
error_msg = "'.+' is not a valid call mode"
with pytest.raises(ValueError, match=error_msg):
factory.metafactory_factory(int, (42,), {"_mode_": "foo"})


def test_metafactory_from_params():
res = factory.metafactory_factory(foobar.fred, (), {"a": 42})

assert res.a == 42
42 changes: 32 additions & 10 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# flake8: noqa
import string

import pytest

from hydra_slayer import functional as F
Expand All @@ -25,9 +23,9 @@ def test_get_factories():


def test_fail_get_factory():
with pytest.raises(LookupError) as e_ifo:
error_msg = "No factory with name '.+' was registered"
with pytest.raises(LookupError, match=error_msg):
F.get_factory("tests.foobar.corge")()
assert hasattr(e_ifo.value, "__cause__")


def test_instantiations():
Expand All @@ -48,17 +46,25 @@ def test_instantiations():


def test_fail_instantiation():
with pytest.raises(LookupError) as e_ifo:
error_msg = "No factory with name '.+' was registered"
with pytest.raises(LookupError, match=error_msg):
F.get_instance("tests.foobar.corge")()
assert hasattr(e_ifo.value, "__cause__")

with pytest.raises(TypeError) as e_ifo:
error_msg = r"get_instance\(\) missing at least 1 required argument: '.+'"
with pytest.raises(TypeError, match=error_msg):
F.get_instance(a=1, b=2)()
assert hasattr(e_ifo.value, "__cause__")

with pytest.raises(TypeError) as e_ifo:
error_msg = ".+ got an unexpected keyword argument '.+'"
with pytest.raises(TypeError, match=error_msg):
F.get_instance("tests.foobar.foo", c=1)()
assert hasattr(e_ifo.value, "__cause__")

error_msg = "Factory '.+' call failed: args=.+ kwargs=.+"
with pytest.raises(RuntimeError, match=error_msg):
F.get_instance("tests.foobar.grault", b=1.0)()

warn_msg = r"No signature found for `.+`, \*args and \*\*kwargs arguments cannot be extracted"
with pytest.warns(UserWarning, match=warn_msg):
F.get_instance("int", 1)


def test_from_params():
Expand Down Expand Up @@ -358,3 +364,19 @@ def test_get_from_params_var_method_with_params():
},
)
assert res["b"] == {"a": 1, "b": 2}


def test_fail_get_from_params_on_exclusive_keywords():
error_msg = "`.+` and `.+` \(in get mode\) keywords are exclusive"
with pytest.raises(ValueError, match=error_msg):
F.get_from_params(
**{
"_target_": "tests.foobar.foo",
"a": [
{"_target_": "tests.foobar.foo", "a": 1, "b": 2, "_var_": "x"},
{"_target_": "tests.foobar.foo", "a": 3, "b": 4, "_var_": "x"},
],
"b": 5,
},
shared_params={"_meta_factory_": call_meta_factory},
)
Loading

0 comments on commit 41c2153

Please sign in to comment.