From 61a1fcbd0aab22f0075a1438e0977af32ed1918f Mon Sep 17 00:00:00 2001 From: Nik Awtrey <46797896+nawtrey@users.noreply.github.com> Date: Fri, 16 Aug 2024 19:59:45 -0700 Subject: [PATCH] ENH: Improve directional diagram algo (#113) * Changes directional diagram algorithm to use depth-first-search (i.e. `nx.dfs_tree`) to create the directional diagrams. Also uses the `nx.DiGraph.reverse` method to build the directional diagram with reversed edges and changes directional diagram return type for `return_edges=False` to a `nx.DiGraph`. * Removes private functions `diagrams._collect_sources` and `diagrams.get_directional_path_edges` * Move array-based edge flipping code into the `return_edges=True` code path * Addresses the directional diagram algo improvement portion of #22 --- kda/diagrams.py | 104 +++++++++++++----------------------------------- 1 file changed, 28 insertions(+), 76 deletions(-) diff --git a/kda/diagrams.py b/kda/diagrams.py index 622712a..54e038a 100644 --- a/kda/diagrams.py +++ b/kda/diagrams.py @@ -132,68 +132,6 @@ def _get_flux_path_edges(target, unique_edges): return list(set(path_edges)) -def _collect_sources(G): - """ - Finds all nodes in a diagram with a single neighbor. Used - to find the leaf nodes in a spanning tree (partial diagram). - - Parameters - ---------- - G : ``NetworkX.Graph`` - A partial diagram - - Returns - ------- - sources: list of int - List of nodes with a single neighbor. - - """ - sources = [] - for n in G.nodes: - if len(list(G.neighbors(n))) == 1: - # sources should only have a single neighbor, but - # this may not be true for more advanced cases - sources.append(n) - return sources - - -def _get_directional_path_edges(G, target): - """ - Collects edges for all paths leading to a - target state for an input partial diagram. - - Parameters - ---------- - G : ``NetworkX.MultiDiGraph`` - A kinetic diagram - target : int - Target state. - - Returns - ------- - path_edges : list - List of edge tuples (e.g. ``[(0, 1, 0), (1, 2, 0), ...]``). - - """ - sources = _collect_sources(G) - # purge target from source list - sources = [n for n in sources if n != target] - # NetworkX function allows for multiple target states but not - # multiple sources. So instead of iterating over each available - # source, we can instead flip the direction of the paths to avoid - # a for loop - paths = list(nx.all_simple_edge_paths(G, source=target, target=sources)) - # flatten the path edges and remove redundant edge tuples - path_edges = np.unique([edge for path in paths for edge in path], axis=0) - # flip edge tuples to account for our targets and sources being flipped - path_edges = np.fliplr(path_edges) - # add in the zero column for now - # TODO: change downstream functions so we - # don't have to keep these unnecessary zeros - path_edges = np.column_stack((path_edges, np.zeros(path_edges.shape[0]))) - return path_edges - - def _construct_cycle_edges(cycle): """ Constucts edge tuples in a cycle using the node indices in the cycle. It @@ -431,7 +369,8 @@ def generate_partial_diagrams(G, return_edges=False): def generate_directional_diagrams(G, return_edges=False): """ - Generates all directional diagrams for a kinetic diagram. + Generates all directional diagrams for a kinetic diagram + using depth-first-search algorithm. Parameters ---------- @@ -444,11 +383,10 @@ def generate_directional_diagrams(G, return_edges=False): Returns ------- - directional_diagrams : ndarray of ``NetworkX.MultiDiGraph`` - Array of all directional diagrams for ``G``. - directional_diagram_edges : ndarray - Array of edges (made from 2-tuples) for valid directional - diagrams. + directional_diagrams : ndarray or ndarray of ``NetworkX.DiGraph`` + Array of all directional diagram edges made from 3-tuples + (``return_edges=True``) or array of all directional + diagrams (``return_edges=False``) for ``G``. """ partial_diagrams = generate_partial_diagrams(G, return_edges=False) @@ -461,22 +399,36 @@ def generate_directional_diagrams(G, return_edges=False): else: directional_diagrams = np.empty((n_dir_diags,), dtype=object) + # get the set of target nodes in ascending order + # so all directional diagrams for each state are + # generated in order targets = np.sort(list(G.nodes)) for i, target in enumerate(targets): - for j, partial_diagram in enumerate(partial_diagrams): - # get directional edges from partial diagram edges - dir_edges = _get_directional_path_edges(partial_diagram, target) + for j, G_partial in enumerate(partial_diagrams): + # apply depth-first-search to partial diagram to create + # a directed spanning tree where the edges are directed + # from the target node to the leaf nodes + G_dfs = nx.dfs_tree(G_partial, source=target) if return_edges: + # collect the edges from the directed spanning tree + # and reverse the direction of the edges to get the correct + # edges for a directional diagram + dir_edges = np.fliplr(np.asarray(G_dfs.edges(), dtype=np.int32)) + # add in the zero column for now + # TODO: change downstream functions so we + # don't have to keep these unnecessary zeros + dir_edges = np.column_stack((dir_edges, np.zeros(dir_edges.shape[0]))) directional_diagrams[j + i*n_partials] = dir_edges else: - directional_diagram = nx.MultiDiGraph() - directional_diagram.add_edges_from(dir_edges) + # make a copy of the `nx.DiGraph` with reversed + # edges to get the directional diagram + G_directional = G_dfs.reverse(copy=True) # set "is_target" to False for all nodes - nx.set_node_attributes(directional_diagram, False, "is_target") + nx.set_node_attributes(G_directional, False, "is_target") # set target node to True - directional_diagram.nodes[target]["is_target"] = True + G_directional.nodes[target]["is_target"] = True # add to array of directional diagrams - directional_diagrams[j + i*n_partials] = directional_diagram + directional_diagrams[j + i*n_partials] = G_directional return directional_diagrams