From 50a5dcb0b6dbe29d15713249a6b1d3afdd5d2b78 Mon Sep 17 00:00:00 2001 From: Jordan Matelsky Date: Wed, 12 Jun 2024 14:52:49 -0400 Subject: [PATCH] Support string labels with CSV entity labels --- grandlite/__init__.py | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/grandlite/__init__.py b/grandlite/__init__.py index 20d1f6c..9622a46 100644 --- a/grandlite/__init__.py +++ b/grandlite/__init__.py @@ -20,6 +20,7 @@ "jsonl": lambda x: x.to_json(sys.stdout, orient="records", lines=True), } + def _guess_delimiter(first_n_lines: list[str]) -> str: """ Guess the delimiter of a CSV file from the first few lines. @@ -38,10 +39,14 @@ def _guess_delimiter(first_n_lines: list[str]) -> str: # Appears at all? if not any(delimiter in line for line in first_n_lines): continue - if all(line.count(delimiter) == first_n_lines[0].count(delimiter) for line in first_n_lines): + if all( + line.count(delimiter) == first_n_lines[0].count(delimiter) + for line in first_n_lines + ): return delimiter raise ValueError("Could not guess delimiter.") + def read_headered_edgelist(filename: str) -> nx.Graph: """ Read a graph from a headered edgelist file. @@ -62,6 +67,7 @@ def read_headered_edgelist(filename: str) -> nx.Graph: src_col = match.group(1) tgt_col = match.group(2) filepath = match.group(3) + # The file has a header row. # Use the CSV reader to read the file: def _without_srctgt(row): @@ -108,7 +114,6 @@ def _infer_graph_filetype_from_contents(filename: str) -> str: if match is not None: return "opencypher" - raise NotImplementedError("Cannot infer graph file type from contents.") @@ -128,6 +133,21 @@ def read_opencypher(paths: str) -> nx.Graph: return opencypher_buffers_to_graph(vertex_paths, edge_paths) +def parse_labels_attribute(labels_str: str) -> set: + """ + Parse a CSV string of labels into a set. + + Arguments: + labels_str: The CSV string of labels. + + Returns: + A set of labels. + """ + if labels_str: + return set(labels_str.split(",")) + return set() + + def detect_and_load_graph(graph_uri: str) -> nx.Graph: """ Read a graph from its URI and return a NetworkX.Graph-compatible API. @@ -184,6 +204,15 @@ def detect_and_load_graph(graph_uri: str) -> nx.Graph: raise ValueError(f"Unknown graph file type for file '{graph_path}'.") host_graph = readers[graph_type](graph_path) + # Parse __labels__ attributes as CSV and convert to set + for nid, node_attrs in host_graph.nodes(data=True): + if "__labels__" in node_attrs: + node_attrs["__labels__"] = parse_labels_attribute(node_attrs["__labels__"]) + + for u, v, edge_attrs in host_graph.edges(data=True): + if "__labels__" in edge_attrs: + edge_attrs["__labels__"] = parse_labels_attribute(edge_attrs["__labels__"]) + return host_graph @@ -312,6 +341,7 @@ def cli(): elif args.output in ["json", "jsonl"]: # Write valid JSON to stdout import json + json.dump(results, sys.stdout) else: writer = csv.DictWriter(sys.stdout, fieldnames=results.keys())