Enable annotation documentation locations (#71)
This PR is motivated by the need to reduce cyclic imports and to allow a
resolver to be passed directly into `ResolverKey`.

This PR introduces the `location` argument so you can explicitly state the
import / documentation path for a given resolver. This lets a pre-instantiated
resolver be passed as the second argument to `ResolverKey` and, in some places,
avoids the cyclic imports that `ResolverKey`'s dynamic import functionality
would otherwise cause during documentation builds.
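
As an illustration (not part of the diff), a minimal sketch of how the pieces fit together; the package, classes, and resolver name below are hypothetical, while `ClassResolver`, `ResolverKey`, and `update_docstring_with_resolver_keys` are the real APIs touched here:

from class_resolver import ClassResolver, ResolverKey, update_docstring_with_resolver_keys

class Base: ...
class FooBase(Base): ...
class BarBase(Base): ...

# The new ``location`` keyword records the import / documentation path,
# so sphinx can link to the resolver without dynamically importing it.
example_resolver = ClassResolver(
    [FooBase, BarBase],
    base=Base,
    location="my_package.example_resolver",  # hypothetical path
)

@update_docstring_with_resolver_keys(
    # A pre-instantiated resolver can now be passed instead of an import path string.
    ResolverKey("base", example_resolver),
)
def run(base, base_kwargs=None):
    """Run something configurable.

    :param base: A choice of ``Base`` subclass
    :param base_kwargs: Keyword arguments for the chosen subclass
    """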
cthoyt authored Nov 2, 2024
1 parent c1c5e61 commit ed6e428
Showing 9 changed files with 70 additions and 7 deletions.
3 changes: 3 additions & 0 deletions src/class_resolver/api.py
@@ -95,6 +95,7 @@ def __init__(
synonyms: Mapping[str, type[X]] | None = None,
synonym_attribute: str | None = "synonyms",
base_as_suffix: bool = True,
location: str | None = None,
) -> None:
"""Initialize the resolver.
@@ -107,6 +108,7 @@ def __init__(
:param synonym_attribute: The attribute to look in each class for synonyms. Explicitly set to None
to turn off synonym lookup.
:param base_as_suffix: Should the base class's name be used as the suffix if none is given? Defaults to true.
:param location: The location used to document the resolver in sphinx
"""
self.base = base
self.synonyms_attribute = synonym_attribute
@@ -120,6 +122,7 @@ def __init__(
synonyms=synonyms,
default=default,
suffix=suffix,
location=location,
)

def extract_name(self, element: type[X]) -> str:
12 changes: 12 additions & 0 deletions src/class_resolver/base.py
@@ -91,20 +91,30 @@ class BaseResolver(ABC, Generic[X, Y]):
#: The shared suffix of all classes derived from the base class
suffix: str | None

#: The string used to document the resolver in a sphinx item.
#: For example, a resolver for aggregations is available in class-resolver
#: that can be imported from ``class_resolver.contrib.numpy.aggregation_resolver``.
#: It can be documented with sphinx using ``:data:`class_resolver.contrib.numpy.aggregation_resolver```,
#: which creates this kind of link :data:`class_resolver.contrib.numpy.aggregation_resolver`
#: (assuming you have intersphinx set up properly).
location: str | None

def __init__(
self,
elements: Iterable[X] | None = None,
*,
default: X | None = None,
synonyms: Mapping[str, X] | None = None,
suffix: str | None = None,
location: str | None = None,
):
"""Initialize the resolver.
:param elements: The elements to register
:param default: The optional default element
:param synonyms: The optional synonym dictionary
:param suffix: The optional shared suffix of all instances
:param location: The location used to document the resolver in sphinx
"""
self.default = default
self.synonyms = dict(synonyms or {})
Expand All @@ -114,6 +124,8 @@ def __init__(
for element in elements:
self.register(element)

self.location = location

def __iter__(self) -> Iterator[X]:
"""Iterate over the registered elements."""
return iter(self.lookup_dict.values())
13 changes: 9 additions & 4 deletions src/class_resolver/contrib/numpy.py
@@ -8,10 +8,11 @@
"aggregation_resolver",
]

aggregation_resolver = FunctionResolver([np.sum, np.max, np.min, np.mean, np.median], default=np.mean)
# compat with older numpy versions, where np.min points to np.amin
aggregation_resolver.register(np.min, synonyms={"min"}, raise_on_conflict=False)
aggregation_resolver.register(np.max, synonyms={"max"}, raise_on_conflict=False)
aggregation_resolver = FunctionResolver(
[np.sum, np.max, np.min, np.mean, np.median],
default=np.mean,
location="class_resolver.contrib.numpy.aggregation_resolver",
)
"""A resolver for common aggregation functions in NumPy including the following functions:
- :func:`numpy.sum`
@@ -45,3 +46,7 @@ def first(x):
arr = [1, 2, 3, 10]
assert 1 == func(arr)
"""

# compat with older numpy versions, where np.min points to np.amin
aggregation_resolver.register(np.min, synonyms={"min"}, raise_on_conflict=False)
aggregation_resolver.register(np.max, synonyms={"max"}, raise_on_conflict=False)
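
For reference, a short usage sketch of the resolver defined above (assuming only the standard ``lookup`` method that the module docstring already relies on):

import numpy as np
from class_resolver.contrib.numpy import aggregation_resolver

# Resolve an aggregation function by name; ``np.mean`` is the default above.
func = aggregation_resolver.lookup("max")
assert func(np.array([1, 2, 3, 10])) == 10

# The explicit ``min``/``max`` synonyms registered above keep lookups working
# on older NumPy versions where ``np.min`` points to ``np.amin``.
assert aggregation_resolver.lookup("min")(np.array([4, 2, 9])) == 2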
2 changes: 2 additions & 0 deletions src/class_resolver/contrib/optuna.py
@@ -15,6 +15,7 @@
default=TPESampler,
suffix="Sampler",
exclude_private=False,
location="class_resolver.contrib.optuna.sampler_resolver",
)
"""A resolver for :class:`optuna.samplers.BaseSampler` subclasses.
@@ -49,6 +50,7 @@ def optimize_study(sampler: Hint[BaseSampler] = None):
default=MedianPruner,
suffix="Pruner",
exclude_private=False,
location="class_resolver.contrib.optuna.pruner_resolver",
)
"""A resolver for :class:`optuna.pruners.BasePruner` subclasses.
1 change: 1 addition & 0 deletions src/class_resolver/contrib/sklearn.py
@@ -42,6 +42,7 @@
base=BaseEstimator,
base_as_suffix=False,
default=LogisticRegression,
location="class_resolver.contrib.sklearn.classifier_resolver",
)
"""A resolver for classifiers.
6 changes: 6 additions & 0 deletions src/class_resolver/contrib/torch.py
@@ -34,6 +34,7 @@
Optimizer,
default=Adam,
base_as_suffix=False,
location="class_resolver.contrib.torch.optimizer_resolver",
)
"""A resolver for :class:`torch.optim.Optimizer` classes.
@@ -80,6 +81,7 @@ def train(
base=nn.Module,
default=activation.ReLU,
base_as_suffix=False,
location="class_resolver.contrib.torch.activation_resolver",
)
"""A resolver for :mod:`torch.nn.modules.activation` classes.
@@ -120,6 +122,7 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
hard=nn.ReLU,
soft=nn.Softplus,
),
location="class_resolver.contrib.torch.margin_activation_resolver",
)
"""A resolver for a subset of :mod:`torch.nn.modules.activation` classes.
@@ -131,6 +134,7 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
initializer_resolver = FunctionResolver(
[func for name, func in vars(init).items() if not name.startswith("_") and name.endswith("_")],
default=init.normal_,
location="class_resolver.contrib.torch.initializer_resolver",
)
"""A resolver for :mod:`torch.nn.init` functions.
@@ -168,6 +172,7 @@ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
LRScheduler,
default=ExponentialLR,
suffix="LR",
location="class_resolver.contrib.torch.lr_scheduler_resolver",
)
"""A resolver for learning rate schedulers.
@@ -211,6 +216,7 @@ def train(
aggregation_resolver = FunctionResolver(
[torch.sum, torch.max, torch.min, torch.mean, torch.logsumexp, torch.median],
default=torch.mean,
location="class_resolver.contrib.torch.aggregation_resolver",
)
"""A resolver for common aggregation functions in PyTorch including the following functions:
3 changes: 2 additions & 1 deletion src/class_resolver/contrib/torch_geometric.py
@@ -19,6 +19,7 @@
base=MessagePassing,
suffix="Conv",
default=SimpleConv,
location="class_resolver.contrib.torch_geometric.message_passing_resolver",
)
"""A resolver for message passing layers.
@@ -28,8 +29,8 @@
aggregation_resolver = ClassResolver.from_subclasses(
base=Aggregation,
default=MeanAggregation,
location="class_resolver.contrib.torch_geometric.aggregation_resolver",
)

"""A resolver for aggregation layers.
This includes the following:
7 changes: 6 additions & 1 deletion src/class_resolver/docs.py
@@ -19,7 +19,12 @@


def _get_qualpath_from_object(resolver: BaseResolver) -> str:
raise NotImplementedError
if resolver.location:
return resolver.location
raise NotImplementedError(
"Can not get a qualified name for auto-generation of sphinx documentation "
"for a resolver that doesn't have the `location` variable set"
)


class ResolverKey:
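
A brief sketch of the behavior this enables, mirroring the tests added below; the ``my_pkg.agg_resolver`` path is hypothetical:

from class_resolver import FunctionResolver, ResolverKey

# Without a location, no qualified path can be derived for sphinx docs,
# so constructing the key raises NotImplementedError.
bare = FunctionResolver([])
try:
    ResolverKey("agg", bare)
except NotImplementedError:
    pass

# With a location, a pre-instantiated resolver can be documented just like
# a plain import-path string.
located = FunctionResolver([], location="my_pkg.agg_resolver")
key = ResolverKey("agg", located)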
30 changes: 29 additions & 1 deletion tests/test_docs.py
@@ -7,7 +7,7 @@

from torch import Tensor, nn

from class_resolver import ResolverKey, update_docstring_with_resolver_keys
from class_resolver import FunctionResolver, ResolverKey, update_docstring_with_resolver_keys
from class_resolver.contrib.torch import activation_resolver, aggregation_resolver
from class_resolver.docs import _clean_docstring

@@ -54,6 +54,10 @@
ResolverKey("activation", "class_resolver.contrib.torch.activation_resolver"),
)

TEST_RESOLVER_2 = update_docstring_with_resolver_keys(
ResolverKey("activation", activation_resolver),
)

EXPECTED_FUNCTION_1_DOC = """\
Apply an activation then aggregation.
@@ -193,6 +197,15 @@ def f4(
""".rstrip()


@TEST_RESOLVER_2
def f5(activation, activation_kwargs):
"""Apply an activation then aggregation.
:param activation: An activation function (stateful)
:param activation_kwargs: Keyword arguments for activation function
"""


class DecoratorTests(unittest.TestCase):
"""Decorator tests."""

@@ -239,6 +252,11 @@ def test_clean_docstring(self) -> None:
with self.subTest(docstring=ds):
self.assertEqual(TARGET, _clean_docstring(ds))

def test_bad_type(self):
"""Raise the appropriate error."""
with self.assertRaises(TypeError):
ResolverKey("", None)

def test_no_params(self):
"""Test when no keys are passed."""
with self.assertRaises(ValueError):
@@ -258,6 +276,12 @@ def test_missing_params(self):
def f(x):
"""Do the thing."""

def test_no_location(self):
"""Test when there's no explicit location given."""
r = FunctionResolver([])
with self.assertRaises(NotImplementedError):
ResolverKey("xx", r)

def test_f1(self):
"""Test the correct docstring is produced."""
self.assertEqual(EXPECTED_FUNCTION_1_DOC, f1.__doc__)
@@ -274,6 +298,10 @@ def test_f4(self):
"""Test the correct docstring is produced."""
self.assertEqual(EXPECTED_FUNCTION_4_DOC, f4.__doc__)

def test_f5(self):
"""Test the correct docstring is produced."""
self.assertEqual(EXPECTED_FUNCTION_1_DOC, f5.__doc__)


class TestTable(unittest.TestCase):
"""Test building tables for resolvers."""
