Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide ParameterCollection.where for efficient conditional iteration of parameters #1899

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
33 changes: 30 additions & 3 deletions armi/reactor/parameters/parameterCollections.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import copy
import pickle
from typing import Any, Optional, List, Set
from typing import Any, Optional, List, Set, Iterator, Callable
import sys

import numpy as np
Expand Down Expand Up @@ -359,7 +359,7 @@ def __contains__(self, name):
else:
return name in self._hist

def __eq__(self, other):
def __eq__(self, other: "ParameterCollection"):
if not isinstance(other, self.__class__):
return False

Expand All @@ -374,7 +374,8 @@ def __eq__(self, other):

return True

def __iter__(self):
def __iter__(self) -> Iterator[str]:
"""Iterate over names of assigned parameters define on this collection."""
return (
pd.name
for pd in self.paramDefs
Expand Down Expand Up @@ -493,6 +494,32 @@ def restoreBackup(self, paramsToApply):
pd.assigned = SINCE_ANYTHING
self.assigned = SINCE_ANYTHING

def where(
self, f: Callable[[parameterDefinitions.Parameter], bool]
) -> Iterator[parameterDefinitions.Parameter]:
"""Produce an iterator over parameters that meet some criteria.

Parameters
----------
f : callable function f(parameter) -> bool
Function to check if a parameter should be fetched during the iteration.

Returns
-------
iterator of :class:`armi.reactor.parameters.Parameter`
Iterator, **not** list or tuple, that produces each parameter that
meets ``f(parameter) == True``.

Examples
--------
>>> block = r.core[0][0]
>>> pdef = block.p.paramDefs
>>> for param in pdef.where(lambda pd: pd.atLocation(ParamLocation.EDGES)):
... print(param.name, block.p[param.name])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UNIMPORTANT

I see you put these empty lines at the end of a about half your docstrings. I just checked, the RST renders fine without this.

"""
return filter(f, self.paramDefs)


def collectPluginParameters(pm):
"""Apply parameters from plugins to their respective object classes."""
Expand Down
4 changes: 4 additions & 0 deletions armi/reactor/parameters/parameterDefinitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,10 @@ def atLocation(self, loc):
"""True if parameter is defined at location."""
return self.location and self.location & loc

def hasCategory(self, category: str) -> bool:
"""True if a parameter has a specific category."""
return category in self.categories


class ParameterDefinitionCollection:
"""
Expand Down
92 changes: 92 additions & 0 deletions armi/reactor/tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Tests of the Parameters class."""
import copy
import typing
import unittest

from armi.reactor import parameters
Expand Down Expand Up @@ -456,10 +457,18 @@ class MockPC(parameters.ParameterCollection):
self.assertEqual(p2.categories, set(["awesome", "stuff", "bacon"]))
self.assertEqual(p3.categories, set(["bacon"]))

for p in [p1, p2, p3]:
self._testCategoryConsistency(p)

self.assertEqual(set(pc.paramDefs.inCategory("awesome")), set([p1, p2]))
self.assertEqual(set(pc.paramDefs.inCategory("stuff")), set([p1, p2]))
self.assertEqual(set(pc.paramDefs.inCategory("bacon")), set([p2, p3]))

def _testCategoryConsistency(self, p: parameters.Parameter):
for category in p.categories:
self.assertTrue(p.hasCategory(category))
self.assertFalse(p.hasCategory("this_shouldnot_exist"))

def test_parameterCollectionsHave__slots__(self):
"""Tests we prevent accidental creation of attributes."""
self.assertEqual(
Expand Down Expand Up @@ -502,3 +511,86 @@ class MockPCChild(MockPC):
pcc = MockPCChild()
with self.assertRaises(AssertionError):
pcc.whatever = 33


class ParamCollectionWhere(unittest.TestCase):
"""Tests for ParameterCollection.where."""

class ScopeParamCollection(parameters.ParameterCollection):
pDefs = parameters.ParameterDefinitionCollection()
with pDefs.createBuilder() as pb:
pb.defParam(
name="empty",
description="Bare",
location=None,
categories=None,
units="",
)
pb.defParam(
name="keff",
description="keff",
location=parameters.ParamLocation.VOLUME_INTEGRATED,
categories=[parameters.Category.neutronics],
units="",
)
pb.defParam(
name="cornerFlux",
description="corner flux",
location=parameters.ParamLocation.CORNERS,
categories=[
parameters.Category.neutronics,
],
units="",
)
pb.defParam(
name="edgeTemperature",
description="edge temperature",
location=parameters.ParamLocation.EDGES,
categories=[parameters.Category.thermalHydraulics],
units="",
)

pc: typing.ClassVar[parameters.ParameterCollection]

Comment on lines +553 to +554
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
pc: typing.ClassVar[parameters.ParameterCollection]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Up to you, but I believe this line is unneeded. The setupClass() method handles this for you.

@classmethod
def setUpClass(cls) -> None:
"""Define a couple useful parameters with categories, locations, etc."""
cls.pc = cls.ScopeParamCollection()

def test_onCategory(self):
"""Test the use of Parameter.hasCategory on filtering."""
names = {"keff", "cornerFlux"}
for p in self.pc.where(
lambda pd: pd.hasCategory(parameters.Category.neutronics)
):
self.assertTrue(p.hasCategory(parameters.Category.neutronics), msg=p)
names.remove(p.name)
self.assertFalse(names, msg=f"{names=} should be empty!")

def test_onLocation(self):
"""Test the use of Parameter.atLocation in filtering."""
names = {
"edgeTemperature",
}
for p in self.pc.where(
lambda pd: pd.atLocation(parameters.ParamLocation.EDGES)
):
self.assertTrue(p.atLocation(parameters.ParamLocation.EDGES), msg=p)
names.remove(p.name)
self.assertFalse(names, msg=f"{names=} should be empty!")

def test_complicated(self):
"""Test a multi-condition filter."""
names = {
"cornerFlux",
}

def check(p: parameters.Parameter) -> bool:
return p.atLocation(parameters.ParamLocation.CORNERS) and p.hasCategory(
parameters.Category.neutronics
)

for p in self.pc.where(check):
self.assertTrue(check(p), msg=p)
names.remove(p.name)
self.assertFalse(names, msg=f"{names=} should be empty")
4 changes: 4 additions & 0 deletions doc/release/0.4.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ New Features
#. Adding ``--skip-inspection`` flag to ``CompareCases`` CLI. (`PR#1842 <https://github.com/terrapower/armi/pull/1842>`_)
#. Provide utilities for determining location of a rotated object in a hexagonal lattice (``getIndexOfRotatedCell``). (`PR#1846 <https://github.com/terrapower/armi/1846`)
#. Allow merging a component with zero area into another component. (`PR#1858 <https://github.com/terrapower/armi/pull/1858>`_)
#. Provide ``Parameter.hasCategory`` for quickly checking if a parameter is defined with a given category.
(`PR#1899 <https://github.com/terrapower/armi/pull/1899>`_)
#. Provide ``ParameterCollection.where`` for efficient iteration over parameters who's definition
matches a given condition. (`PR#1899 <https://github.com/terrapower/armi/pull/1899>`_)
#. TBD

API Changes
Expand Down
Loading