Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support string labels with CSV entity labels #11

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions grandlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"jsonl": lambda x: x.to_json(sys.stdout, orient="records", lines=True),
}


def _guess_delimiter(first_n_lines: list[str]) -> str:
"""
Guess the delimiter of a CSV file from the first few lines.
Expand All @@ -38,10 +39,14 @@ def _guess_delimiter(first_n_lines: list[str]) -> str:
# Appears at all?
if not any(delimiter in line for line in first_n_lines):
continue
if all(line.count(delimiter) == first_n_lines[0].count(delimiter) for line in first_n_lines):
if all(
line.count(delimiter) == first_n_lines[0].count(delimiter)
for line in first_n_lines
):
return delimiter
raise ValueError("Could not guess delimiter.")


def read_headered_edgelist(filename: str) -> nx.Graph:
"""
Read a graph from a headered edgelist file.
Expand All @@ -62,6 +67,7 @@ def read_headered_edgelist(filename: str) -> nx.Graph:
src_col = match.group(1)
tgt_col = match.group(2)
filepath = match.group(3)

# The file has a header row.
# Use the CSV reader to read the file:
def _without_srctgt(row):
Expand Down Expand Up @@ -108,7 +114,6 @@ def _infer_graph_filetype_from_contents(filename: str) -> str:
if match is not None:
return "opencypher"


raise NotImplementedError("Cannot infer graph file type from contents.")


Expand All @@ -128,6 +133,21 @@ def read_opencypher(paths: str) -> nx.Graph:
return opencypher_buffers_to_graph(vertex_paths, edge_paths)


def parse_labels_attribute(labels_str: str) -> set:
    """
    Parse a CSV string of labels into a set.

    Whitespace around each label is stripped and empty entries (produced by
    doubled or trailing commas) are dropped, so "a, b," yields {"a", "b"}.

    Non-string values are tolerated so the function can be applied to
    attributes that were already parsed: any falsy value (None, "", an empty
    collection) gives an empty set, and any other iterable of labels is
    copied into a new set unchanged.

    Arguments:
        labels_str: The CSV string of labels.

    Returns:
        A set of labels.
    """
    if not labels_str:
        return set()
    if isinstance(labels_str, str):
        stripped = (label.strip() for label in labels_str.split(","))
        return {label for label in stripped if label}
    # Already a collection of labels: copy it rather than calling
    # str.split on something that is not a string.
    return set(labels_str)


def detect_and_load_graph(graph_uri: str) -> nx.Graph:
"""
Read a graph from its URI and return a NetworkX.Graph-compatible API.
Expand Down Expand Up @@ -184,6 +204,15 @@ def detect_and_load_graph(graph_uri: str) -> nx.Graph:
raise ValueError(f"Unknown graph file type for file '{graph_path}'.")

host_graph = readers[graph_type](graph_path)
# Parse __labels__ attributes as CSV and convert to set
for nid, node_attrs in host_graph.nodes(data=True):
if "__labels__" in node_attrs:
node_attrs["__labels__"] = parse_labels_attribute(node_attrs["__labels__"])

for u, v, edge_attrs in host_graph.edges(data=True):
if "__labels__" in edge_attrs:
edge_attrs["__labels__"] = parse_labels_attribute(edge_attrs["__labels__"])

return host_graph


Expand Down Expand Up @@ -312,6 +341,7 @@ def cli():
elif args.output in ["json", "jsonl"]:
# Write valid JSON to stdout
import json

json.dump(results, sys.stdout)
else:
writer = csv.DictWriter(sys.stdout, fieldnames=results.keys())
Expand Down