Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Improve directional diagram algo #113

Merged
merged 3 commits into from
Aug 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 28 additions & 76 deletions kda/diagrams.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,68 +132,6 @@ def _get_flux_path_edges(target, unique_edges):
return list(set(path_edges))


def _collect_sources(G):
"""
Finds all nodes in a diagram with a single neighbor. Used
to find the leaf nodes in a spanning tree (partial diagram).
Parameters
----------
G : ``NetworkX.Graph``
A partial diagram
Returns
-------
sources: list of int
List of nodes with a single neighbor.
"""
sources = []
for n in G.nodes:
if len(list(G.neighbors(n))) == 1:
# sources should only have a single neighbor, but
# this may not be true for more advanced cases
sources.append(n)
return sources


def _get_directional_path_edges(G, target):
"""
Collects edges for all paths leading to a
target state for an input partial diagram.
Parameters
----------
G : ``NetworkX.MultiDiGraph``
A kinetic diagram
target : int
Target state.
Returns
-------
path_edges : list
List of edge tuples (e.g. ``[(0, 1, 0), (1, 2, 0), ...]``).
"""
sources = _collect_sources(G)
# purge target from source list
sources = [n for n in sources if n != target]
# NetworkX function allows for multiple target states but not
# multiple sources. So instead of iterating over each available
# source, we can instead flip the direction of the paths to avoid
# a for loop
paths = list(nx.all_simple_edge_paths(G, source=target, target=sources))
# flatten the path edges and remove redundant edge tuples
path_edges = np.unique([edge for path in paths for edge in path], axis=0)
# flip edge tuples to account for our targets and sources being flipped
path_edges = np.fliplr(path_edges)
# add in the zero column for now
# TODO: change downstream functions so we
# don't have to keep these unnecessary zeros
path_edges = np.column_stack((path_edges, np.zeros(path_edges.shape[0])))
return path_edges


def _construct_cycle_edges(cycle):
"""
Constucts edge tuples in a cycle using the node indices in the cycle. It
Expand Down Expand Up @@ -431,7 +369,8 @@ def generate_partial_diagrams(G, return_edges=False):

def generate_directional_diagrams(G, return_edges=False):
"""
Generates all directional diagrams for a kinetic diagram.
Generates all directional diagrams for a kinetic diagram
using depth-first-search algorithm.
Parameters
----------
Expand All @@ -444,11 +383,10 @@ def generate_directional_diagrams(G, return_edges=False):
Returns
-------
directional_diagrams : ndarray of ``NetworkX.MultiDiGraph``
Array of all directional diagrams for ``G``.
directional_diagram_edges : ndarray
Array of edges (made from 2-tuples) for valid directional
diagrams.
directional_diagrams : ndarray or ndarray of ``NetworkX.DiGraph``
Array of all directional diagram edges made from 3-tuples
(``return_edges=True``) or array of all directional
diagrams (``return_edges=False``) for ``G``.
"""
partial_diagrams = generate_partial_diagrams(G, return_edges=False)

Expand All @@ -461,22 +399,36 @@ def generate_directional_diagrams(G, return_edges=False):
else:
directional_diagrams = np.empty((n_dir_diags,), dtype=object)

# get the set of target nodes in ascending order
# so all directional diagrams for each state are
# generated in order
targets = np.sort(list(G.nodes))
for i, target in enumerate(targets):
for j, partial_diagram in enumerate(partial_diagrams):
# get directional edges from partial diagram edges
dir_edges = _get_directional_path_edges(partial_diagram, target)
for j, G_partial in enumerate(partial_diagrams):
# apply depth-first-search to partial diagram to create
# a directed spanning tree where the edges are directed
# from the target node to the leaf nodes
G_dfs = nx.dfs_tree(G_partial, source=target)
if return_edges:
# collect the edges from the directed spanning tree
# and reverse the direction of the edges to get the correct
# edges for a directional diagram
dir_edges = np.fliplr(np.asarray(G_dfs.edges(), dtype=np.int32))
# add in the zero column for now
# TODO: change downstream functions so we
# don't have to keep these unnecessary zeros
dir_edges = np.column_stack((dir_edges, np.zeros(dir_edges.shape[0])))
directional_diagrams[j + i*n_partials] = dir_edges
else:
directional_diagram = nx.MultiDiGraph()
directional_diagram.add_edges_from(dir_edges)
# make a copy of the `nx.DiGraph` with reversed
# edges to get the directional diagram
G_directional = G_dfs.reverse(copy=True)
# set "is_target" to False for all nodes
nx.set_node_attributes(directional_diagram, False, "is_target")
nx.set_node_attributes(G_directional, False, "is_target")
# set target node to True
directional_diagram.nodes[target]["is_target"] = True
G_directional.nodes[target]["is_target"] = True
# add to array of directional diagrams
directional_diagrams[j + i*n_partials] = directional_diagram
directional_diagrams[j + i*n_partials] = G_directional

return directional_diagrams

Expand Down
Loading