Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: 100% coverage #45

Merged
merged 5 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion hydra_slayer/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _get_instance(

# assume that name of the factory can be provided as first argument
# or directly by keyword
name, args, kwarg = _extract_factory_name_arg(
name, args, kwargs = _extract_factory_name_arg(
factory_key=factory_key, args=args, kwargs=kwargs
)
if name is None:
Expand Down
6 changes: 3 additions & 3 deletions hydra_slayer/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _get_factory_name(f, provided_name: str = None) -> str:
if not provided_name:
provided_name = getattr(f, "__name__", None)
if not provided_name:
raise ValueError(f"Factory {f} has no '__name__' and no name was provided")
raise ValueError(f"Factory '{f}' has no '__name__' and no name was provided")
if provided_name == "<lambda>":
raise ValueError("Name for lambda factories must be provided")
return provided_name
Expand Down Expand Up @@ -153,9 +153,9 @@ def add_from_module(
prefix = [prefix]
elif isinstance(prefix, list):
if any((not isinstance(p, str)) for p in prefix):
raise TypeError("All prefix in list must be strings.")
raise TypeError("All prefix in list must be strings")
else:
raise TypeError(f"Prefix must be a list or a string, got {type(prefix)}.")
raise TypeError(f"Prefix must be a list or a string, got {type(prefix)}")

to_add = {f"{p}{name}": factories[name] for p in prefix for name in names_to_add}
self.add(**to_add)
Expand Down
16 changes: 15 additions & 1 deletion tests/foobar.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# flake8: noqa
from typing import Any

__all__ = ["foo"]


Expand Down Expand Up @@ -27,8 +29,11 @@ def quuz(**params):


class grault:
def __init__(self, a=1, b=2):
def __init__(self, a: Any = 1, b: int = 2):
self.a = a

if not isinstance(b, int):
raise ValueError
self.b = b

@staticmethod
Expand All @@ -37,3 +42,12 @@ def garply(a, b):

def waldo(self):
return {"a": self.a, "b": self.b}


class fred:
def __init__(self, a):
self.a = a

@classmethod
def get_from_params(cls, a):
return cls(a)
19 changes: 15 additions & 4 deletions tests/test_factory.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# flake8: noqa
import string

import pytest

from hydra_slayer import factory
from . import foobar


def test_call_meta_factory():
Expand Down Expand Up @@ -43,9 +42,9 @@ def test_default_meta_factory():


def test_fail_get_factory():
with pytest.raises(ValueError) as e_ifo:
error_msg = "factory '.+' is not callable"
with pytest.raises(ValueError, match=error_msg):
factory.default_meta_factory(5, tuple(), {})
assert hasattr(e_ifo.value, "__cause__")


def test_metafactory_factory_meta_factory_arg():
Expand Down Expand Up @@ -100,3 +99,15 @@ def test_metafactory_factory_modes():
res = factory.metafactory_factory(lambda x: x, (42,), {"_mode_": "partial"})

assert res() == 42


def test_fail_metafactory_factory_modes():
error_msg = "'.+' is not a valid call mode"
with pytest.raises(ValueError, match=error_msg):
factory.metafactory_factory(int, (42,), {"_mode_": "foo"})


def test_metafactory_from_params():
res = factory.metafactory_factory(foobar.fred, (), {"a": 42})

assert res.a == 42
42 changes: 32 additions & 10 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
# flake8: noqa
import string

import pytest

from hydra_slayer import functional as F
Expand All @@ -25,9 +23,9 @@ def test_get_factories():


def test_fail_get_factory():
with pytest.raises(LookupError) as e_ifo:
error_msg = "No factory with name '.+' was registered"
with pytest.raises(LookupError, match=error_msg):
F.get_factory("tests.foobar.corge")()
assert hasattr(e_ifo.value, "__cause__")


def test_instantiations():
Expand All @@ -48,17 +46,25 @@ def test_instantiations():


def test_fail_instantiation():
with pytest.raises(LookupError) as e_ifo:
error_msg = "No factory with name '.+' was registered"
with pytest.raises(LookupError, match=error_msg):
F.get_instance("tests.foobar.corge")()
assert hasattr(e_ifo.value, "__cause__")

with pytest.raises(TypeError) as e_ifo:
error_msg = r"get_instance\(\) missing at least 1 required argument: '.+'"
with pytest.raises(TypeError, match=error_msg):
F.get_instance(a=1, b=2)()
assert hasattr(e_ifo.value, "__cause__")

with pytest.raises(TypeError) as e_ifo:
error_msg = ".+ got an unexpected keyword argument '.+'"
with pytest.raises(TypeError, match=error_msg):
F.get_instance("tests.foobar.foo", c=1)()
assert hasattr(e_ifo.value, "__cause__")

error_msg = "Factory '.+' call failed: args=.+ kwargs=.+"
with pytest.raises(RuntimeError, match=error_msg):
F.get_instance("tests.foobar.grault", b=1.0)()

warn_msg = r"No signature found for `.+`, \*args and \*\*kwargs arguments cannot be extracted"
with pytest.warns(UserWarning, match=warn_msg):
F.get_instance("int", 1)


def test_from_params():
Expand Down Expand Up @@ -358,3 +364,19 @@ def test_get_from_params_var_method_with_params():
},
)
assert res["b"] == {"a": 1, "b": 2}


def test_fail_get_from_params_on_exclusive_keywords():
error_msg = "`.+` and `.+` \(in get mode\) keywords are exclusive"
with pytest.raises(ValueError, match=error_msg):
F.get_from_params(
**{
"_target_": "tests.foobar.foo",
"a": [
{"_target_": "tests.foobar.foo", "a": 1, "b": 2, "_var_": "x"},
{"_target_": "tests.foobar.foo", "a": 3, "b": 4, "_var_": "x"},
],
"b": 5,
},
shared_params={"_meta_factory_": call_meta_factory},
)
Loading
Loading