Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve get_node #229

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions thicket/tests/test_get_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,20 @@
def test_get_node(literal_thickets):
tk, _, _ = literal_thickets

# Check error raised
with pytest.raises(KeyError):
tk.get_node("Foo")

baz = tk.get_node("Baz")
# Check case which="first"
qux1 = tk.get_node("Qux", which="first")
assert qux1.frame["name"] == "Qux"
assert qux1._hatchet_nid == 1

# Check node properties
assert baz.frame["name"] == "Baz"
assert baz.frame["type"] == "function"
assert baz._hatchet_nid == 0
# Check case which="last"
qux2 = tk.get_node("Qux", which="last")
assert qux2.frame["name"] == "Qux"
assert qux2._hatchet_nid == 2

# Check case which="all"
qux_all = tk.get_node("Qux", which="all")
assert len(qux_all) == 2
21 changes: 16 additions & 5 deletions thicket/thicket.py
Original file line number Diff line number Diff line change
Expand Up @@ -1727,23 +1727,34 @@ def add_root_node(self, attrs):
# Check Thicket state
validate_nodes(self)

def get_node(self, name):
def get_node(self, name, which="first"):
"""Get a node object in the Thicket by its Node.frame['name']. If more than one
node has the same name, a list of nodes is returned.
node has the same name, use the 'which' argument to specify which node to return.

Arguments:
name (str): name of the node (Node.frame['name']).
which (str, optional): which node to return if multiple nodes have the same
name. Options are "first", "last", or "all". Defaults to "first".

Returns:
(Node or list(Node)): Node object with the given name or list of Node objects
with the given name.
"""
node = [n for n in self.graph.traverse() if n.frame["name"] == name]
nodes = [n for n in self.graph.traverse() if n.frame["name"] == name]

if len(node) == 0:
if len(nodes) == 0:
raise KeyError(f'Node with name "{name}" not found.')

return node[0] if len(node) == 1 else node
if len(nodes) == 1 or which == "first":
return nodes[0]
elif which == "last":
return nodes[-1]
elif which == "all":
return nodes
else:
raise ValueError(
'Invalid value for "which". Options are "first", "last", or "all".'
)

def _sync_profile_components(self, component):
"""Synchronize the Performance DataFrame, Metadata Dataframe, profile and
Expand Down
Loading