Skip to content

Commit

Permalink
Add numpy as np
Browse files Browse the repository at this point in the history
  • Loading branch information
tanghaibao committed Jun 18, 2024
1 parent 912077c commit 2c0d3db
Showing 1 changed file with 27 additions and 9 deletions.
36 changes: 27 additions & 9 deletions goatools/nt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
import datetime
import collections as cx


def get_dict_w_id2nts(ids, id2nts, flds, dflt_null=""):
"""Return a new dict of namedtuples by combining "dicts" of namedtuples or objects."""
assert len(ids) == len(set(ids)), "NOT ALL IDs ARE UNIQUE: {IDs}".format(IDs=ids)
assert len(flds) == len(set(flds)), "DUPLICATE FIELDS: {IDs}".format(
IDs=cx.Counter(flds).most_common())
IDs=cx.Counter(flds).most_common()
)
usr_id_nt = []
# 1. Instantiate namedtuple object
ntobj = cx.namedtuple("Nt", " ".join(flds))
Expand All @@ -23,6 +25,7 @@ def get_dict_w_id2nts(ids, id2nts, flds, dflt_null=""):
usr_id_nt.append((item_id, ntobj._make(vals)))
return cx.OrderedDict(usr_id_nt)


def get_list_w_id2nts(ids, id2nts, flds, dflt_null=""):
"""Return a new list of namedtuples by combining "dicts" of namedtuples or objects."""
combined_nt_list = []
Expand All @@ -36,48 +39,61 @@ def get_list_w_id2nts(ids, id2nts, flds, dflt_null=""):
combined_nt_list.append(ntobj._make(vals))
return combined_nt_list


def combine_nt_lists(lists, flds, dflt_null=""):
    """Return a new list of namedtuples by zipping "lists" of namedtuples or objects.

    Args:
        lists: Sequence of lists, all of the same length. The items at one
            position across all lists are merged into a single namedtuple.
        flds: Field names of the combined namedtuple, in output order.
        dflt_null: Forwarded unchanged to ``_combine_nt_vals``; presumably the
            placeholder used for a field none of the inputs supplies
            (NOTE(review): confirm against ``_combine_nt_vals``).

    Returns:
        A list of ``Nt`` namedtuples, one per position in the input lists.

    Raises:
        AssertionError: If the lists differ in length, or if ``lists`` itself
            is empty (there is then no single common length).
    """
    # Check that all lists are the same length
    lens = [len(lst) for lst in lists]
    assert len(set(lens)) == 1, "LIST LENGTHS MUST BE EQUAL: {Ls}".format(
        Ls=" ".join(str(num) for num in lens)
    )
    # 1. Instantiate namedtuple object
    ntobj = cx.namedtuple("Nt", " ".join(flds))
    # 2. Loop through the zipped lists, combining the items found at each
    #    position into a single namedtuple (exactly one output row per position)
    return [
        ntobj._make(_combine_nt_vals(lst0_lstn, flds, dflt_null))
        for lst0_lstn in zip(*lists)
    ]


def wr_py_nts(fout_py, nts, docstring=None, varname="nts"):
    """Save namedtuples into a Python module.

    Nothing is written (and no file is created) when ``nts`` is empty.

    Args:
        fout_py: Path of the Python module to write; an existing file is
            overwritten.
        nts: Sequence of namedtuples to save.
        docstring: Text placed in the generated module's docstring. It is
            formatted as-is, so the default ``None`` produces ``\"\"\"None\"\"\"``.
        varname: Name of the list variable in the generated module.
    """
    if not nts:
        return
    # NOTE: the file is opened with the platform's default text encoding.
    with open(fout_py, "w") as prt:
        prt.write('"""{DOCSTRING}"""\n\n'.format(DOCSTRING=docstring))
        prt.write("# Created: {DATE}\n".format(DATE=str(datetime.date.today())))
        prt_nts(prt, nts, varname)
    # Report once, after the module has been completely written and closed.
    # NOTE(review): leading whitespace of this message is as shown in the
    # scraped page, which may have collapsed runs of spaces - confirm.
    sys.stdout.write(
        " {N:7,} items WROTE: {PY}\n".format(N=len(nts), PY=fout_py)
    )

def prt_nts(prt, nts, varname, spc=" "):
    """Print namedtuples into a Python module.

    The emitted text imports ``collections`` and ``numpy``, lists the field
    names in ``NT_FIELDS``, re-creates the namedtuple class under the same
    name as the type of the first item, and assigns a list holding the
    string form of every namedtuple to ``varname``.

    Args:
        prt: Writable text stream (anything with a ``write`` method).
        nts: Non-empty sequence of namedtuples; the field names and the class
            name are taken from the first item.
        varname: Name of the list variable in the generated module.
        spc: Indentation placed before each field name and each namedtuple.
            NOTE(review): the default is as shown in the scraped page, which
            may have collapsed a run of spaces - confirm against the repository.

    Raises:
        IndexError: If ``nts`` is empty.
    """
    first_nt = nts[0]
    nt_name = type(first_nt).__name__
    prt.write("import collections as cx\n\n")
    # numpy is imported by the generated module, presumably so that values
    # whose text form refers to ``np`` can be resolved when it is loaded.
    prt.write("import numpy as np\n\n")
    prt.write("NT_FIELDS = [\n")
    for fld in first_nt._fields:
        prt.write('{SPC}"{F}",\n'.format(SPC=spc, F=fld))
    prt.write("]\n\n")
    # The namedtuple class definition is written exactly once.
    prt.write(
        '{NtName} = cx.namedtuple("{NtName}", " ".join(NT_FIELDS))\n\n'.format(
            NtName=nt_name
        )
    )
    prt.write("# {N:,} items\n".format(N=len(nts)))
    prt.write("# pylint: disable=line-too-long\n")
    prt.write("{VARNAME} = [\n".format(VARNAME=varname))
    for ntup in nts:
        prt.write("{SPC}{NT},\n".format(SPC=spc, NT=ntup))
    prt.write("]\n")


def get_unique_fields(fld_lists):
"""Get unique namedtuple fields, despite potential duplicates in lists of fields."""
flds = []
Expand All @@ -93,6 +109,7 @@ def get_unique_fields(fld_lists):
assert len(flds) == len(fld_set)
return flds


# -- Internal methods ----------------------------------------------------------------
def _combine_nt_vals(lst0_lstn, flds, dflt_null):
"""Given a list of lists of nts, return a single namedtuple."""
Expand All @@ -110,4 +127,5 @@ def _combine_nt_vals(lst0_lstn, flds, dflt_null):
vals.append(dflt_null)
return vals


# Copyright (C) 2016-2018, DV Klopfenstein, H Tang. All rights reserved.

0 comments on commit 2c0d3db

Please sign in to comment.