Skip to content

Commit

Permalink
Merge pull request #197 from mmore500/downsample-tips
Browse files Browse the repository at this point in the history
Optimize downsample-tips, add CLI
  • Loading branch information
mmore500 authored Jan 5, 2025
2 parents bb7e657 + bf8c592 commit 8e6a4b5
Show file tree
Hide file tree
Showing 38 changed files with 392 additions and 66 deletions.
4 changes: 4 additions & 0 deletions hstrat/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
print(f"hstrat v{get_hstrat_version()}")
print()
print("Available commands (stabilized API):")
print("$ python3 -m hstrat.dataframe.surface_build_tree")
print("$ python3 -m hstrat.dataframe.surface_unpack_reconstruct")
print("$ python3 -m hstrat.dataframe.surface_postprocess_trie")
print()
print("Available commands (experimental API):")
print("$ python3 -m hstrat._auxiliary_lib._alifestd_as_newick_asexual")
print(
"$ python3 -m hstrat._auxiliary_lib._alifestd_downsample_tips_asexual"
)
print(
"$ python3 -m hstrat._auxiliary_lib._alifestd_try_add_ancestor_list_col"
)
Expand Down
93 changes: 88 additions & 5 deletions hstrat/_auxiliary_lib/_alifestd_downsample_tips_asexual.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
import random
import argparse
import functools
import logging
import sys
import typing

from joinem._dataframe_cli import _add_parser_base, _run_dataframe_cli
import numpy as np
import pandas as pd

from ._alifestd_find_leaf_ids import alifestd_find_leaf_ids
from ._alifestd_has_contiguous_ids import alifestd_has_contiguous_ids
from ._alifestd_prune_extinct_lineages_asexual import (
alifestd_prune_extinct_lineages_asexual,
)
from ._alifestd_try_add_ancestor_id_col import alifestd_try_add_ancestor_id_col
from ._configure_prod_logging import configure_prod_logging
from ._delegate_polars_implementation import delegate_polars_implementation
from ._format_cli_description import format_cli_description
from ._get_hstrat_version import get_hstrat_version
from ._log_context_duration import log_context_duration
from ._with_rng_state_context import with_rng_state_context


Expand All @@ -17,8 +28,13 @@ def _alifestd_downsample_tips_asexual_impl(
) -> pd.DataFrame:
"""Implementation detail for alifestd_downsample_tips_asexual."""
tips = alifestd_find_leaf_ids(phylogeny_df)
kept = random.sample(tips, min(n_downsample, len(tips)))
phylogeny_df["extant"] = phylogeny_df["id"].isin(kept)
kept = np.random.choice(tips, min(n_downsample, len(tips)), replace=False)
if alifestd_has_contiguous_ids(phylogeny_df):
extant = np.zeros(len(phylogeny_df), dtype=bool)
extant[kept] = True
phylogeny_df["extant"] = extant
else:
phylogeny_df["extant"] = phylogeny_df["id"].isin(kept)

return alifestd_prune_extinct_lineages_asexual(
phylogeny_df, mutate=True
Expand All @@ -31,8 +47,9 @@ def alifestd_downsample_tips_asexual(
mutate: bool = False,
seed: typing.Optional[int] = None,
) -> pd.DataFrame:
"""Subsample phylogeny containing `num_tips` tips. If `num_tips` is greater
than the number of tips in the phylogeny, the whole phylogeny is returned.
"""Create a subsample phylogeny containing `num_tips` tips. If `num_tips`
is greater than the number of tips in the phylogeny, the whole phylogeny is
returned.
Only supports asexual phylogenies.
"""
Expand All @@ -56,3 +73,69 @@ def alifestd_downsample_tips_asexual(
)

return impl(phylogeny_df, n_downsample)


# Raw help text for this module's CLI entrypoint. `_create_parser` passes it
# through `format_cli_description` to produce the argparse description.
# NOTE(review): the text refers to `num_tips`, but the option registered in
# `_create_parser` is `-n` (forwarded as `n_downsample`) -- confirm whether the
# wording should be aligned.
_raw_description = """Create a subsample phylogeny containing `num_tips` tips.
If `num_tips` is greater than the number of tips in the phylogeny, the whole phylogeny is returned.
Data is assumed to be in alife standard format.
Only supports asexual phylogenies.
Additional Notes
================
- Requires 'ancestor_id' column to be present in input DataFrame.
Otherwise, no action is taken.
- Use `--eager-read` if modifying data file inplace.
- This CLI entrypoint is experimental and may be subject to change.
"""


def _create_parser() -> argparse.ArgumentParser:
    """Build the argument parser for the tip-downsampling CLI entrypoint.

    Starts from a help-less `argparse.ArgumentParser` carrying the formatted
    module description, extends it via `_add_parser_base`, then registers the
    two options specific to this command:

    - `-n` (int, default `sys.maxsize`): number of tips to subsample.
    - `--seed` (int, default `None`): seed for deterministic behavior.

    Returns
    -------
    argparse.ArgumentParser
        The parser returned by `_add_parser_base`, with `-n` and `--seed`
        added.
    """
    module_name = "hstrat._auxiliary_lib._alifestd_downsample_tips_asexual"

    # add_help=False -- presumably because the shared base parser setup
    # supplies its own help option; confirm against `_add_parser_base`.
    bare_parser = argparse.ArgumentParser(
        formatter_class=argparse.RawTextHelpFormatter,
        description=format_cli_description(_raw_description),
        add_help=False,
    )
    full_parser = _add_parser_base(
        parser=bare_parser,
        dfcli_module=module_name,
        dfcli_version=get_hstrat_version(),
    )

    full_parser.add_argument(
        "-n",
        type=int,
        default=sys.maxsize,
        help="Number of tips to subsample.",
    )
    full_parser.add_argument(
        "--seed",
        dest="seed",
        type=int,
        default=None,
        help="Integer seed for deterministic behavior.",
    )
    return full_parser


if __name__ == "__main__":
    # CLI entrypoint: parse -n / --seed, then hand a pre-configured
    # downsampling operation to the shared dataframe CLI runner.
    configure_prod_logging()

    cli_parser = _create_parser()
    # parse_known_args tolerates arguments this parser does not declare;
    # the unrecognized remainder is discarded here.
    cli_args, __ = cli_parser.parse_known_args()

    module_name = "hstrat._auxiliary_lib._alifestd_downsample_tips_asexual"
    with log_context_duration(module_name, logging.info):
        # Bind the CLI-provided options so the runner only has to supply
        # the dataframe argument.
        downsample_op = functools.partial(
            alifestd_downsample_tips_asexual,
            n_downsample=cli_args.n,
            seed=cli_args.seed,
        )
        dataframe_op = delegate_polars_implementation()(downsample_op)
        _run_dataframe_cli(
            base_parser=cli_parser,
            output_dataframe_op=dataframe_op,
            overridden_arguments="ignore",  # seed is overridden
        )
15 changes: 6 additions & 9 deletions hstrat/_auxiliary_lib/_alifestd_find_leaf_ids.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import typing

import numpy as np
import ordered_set as ods
import pandas as pd
Expand All @@ -9,7 +7,7 @@
from ._alifestd_try_add_ancestor_id_col import alifestd_try_add_ancestor_id_col


def alifestd_find_leaf_ids(phylogeny_df: pd.DataFrame) -> typing.List[int]:
def alifestd_find_leaf_ids(phylogeny_df: pd.DataFrame) -> np.ndarray:
"""What ids are not listed in any `ancestor_list`?
Input dataframe is not mutated by this operation.
Expand All @@ -20,15 +18,14 @@ def alifestd_find_leaf_ids(phylogeny_df: pd.DataFrame) -> typing.List[int]:
if "ancestor_id" in phylogeny_df:

# root is self ref, but must exclude to handle only-root phylo
internal_node_idxs = phylogeny_df.loc[
phylogeny_df["ancestor_id"] != phylogeny_df["id"],
"ancestor_id",
].to_numpy()
internal_node_idxs = phylogeny_df["ancestor_id"].to_numpy()[
phylogeny_df["ancestor_id"] != phylogeny_df["id"]
]

leaf_pos_filter = np.ones(len(phylogeny_df), dtype=np.bool_)
leaf_pos_filter[internal_node_idxs] = False

return phylogeny_df.loc[leaf_pos_filter, "id"].to_list()
return np.flatnonzero(leaf_pos_filter)

all_ids = ods.OrderedSet(phylogeny_df["id"])
internal_ids = (
Expand All @@ -50,4 +47,4 @@ def alifestd_find_leaf_ids(phylogeny_df: pd.DataFrame) -> typing.List[int]:
]
)
)
return list(all_ids - internal_ids)
return np.fromiter(all_ids - internal_ids, dtype=int)
101 changes: 85 additions & 16 deletions hstrat/_auxiliary_lib/_alifestd_prune_extinct_lineages_asexual.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,70 @@
import numpy as np
import pandas as pd

from ._alifestd_has_contiguous_ids import alifestd_has_contiguous_ids
from ._alifestd_is_topologically_sorted import alifestd_is_topologically_sorted
from ._alifestd_try_add_ancestor_id_col import alifestd_try_add_ancestor_id_col
from ._alifestd_unfurl_lineage_asexual import alifestd_unfurl_lineage_asexual
from ._jit import jit
from ._unfurl_lineage_with_contiguous_ids import (
unfurl_lineage_with_contiguous_ids,
)


def _create_has_extant_descendant_noncontiguous(
    phylogeny_df: pd.DataFrame,
    extant_mask: np.ndarray,
) -> pd.Series:
    """Implementation detail for alifestd_prune_extinct_lineages_asexual.

    Flags every row that is extant or lies on the lineage of an extant row.
    This variant makes no assumption that ids are contiguous; rows are
    addressed through `.loc` using lineage ids as index labels.

    Parameters
    ----------
    phylogeny_df : pd.DataFrame
        Phylogeny table. Mutated in place: a boolean `has_extant_descendant`
        column is added (or overwritten). Lineage ids are looked up with
        `.loc`, so the frame's index must be keyed by id, as the visible
        caller arranges via `phylogeny_df.index = phylogeny_df["id"]`.
    extant_mask : np.ndarray
        Boolean row selector for extant organisms. The visible caller passes
        the mask it later calls `.to_numpy` on, i.e. a pandas object;
        anything accepted by `.loc` as a row mask works.

    Returns
    -------
    pd.Series
        The `has_extant_descendant` column of `phylogeny_df` (the sibling
        contiguous implementations return a NumPy array instead).
    """

    phylogeny_df["has_extant_descendant"] = False
    for extant_id in phylogeny_df.loc[extant_mask, "id"]:
        for lineage_id in alifestd_unfurl_lineage_asexual(
            phylogeny_df,
            int(extant_id),
            mutate=True,
        ):
            # An already-flagged node means an earlier walk flagged the rest
            # of this lineage from here on, so the walk can stop early.
            if phylogeny_df.loc[lineage_id, "has_extant_descendant"]:
                break

            phylogeny_df.loc[lineage_id, "has_extant_descendant"] = True

    return phylogeny_df["has_extant_descendant"]


@jit(nopython=True)
def _create_has_extant_descendant_contiguous(
    ancestor_ids: np.ndarray,
    extant_mask: np.ndarray,
) -> np.ndarray:
    """Implementation detail for alifestd_prune_extinct_lineages_asexual.

    Array-based variant: ids are used directly as positions into
    `ancestor_ids` and the result array. For each position set in
    `extant_mask`, walks the lineage yielded by
    `unfurl_lineage_with_contiguous_ids` and flags each visited id.

    Returns an array shaped and typed like `extant_mask`, truthy at every id
    that was flagged.
    """

    flagged = np.zeros_like(extant_mask)
    extant_positions = np.flatnonzero(extant_mask)
    for start_id in extant_positions:
        lineage = unfurl_lineage_with_contiguous_ids(
            ancestor_ids,
            int(start_id),
        )
        for node_id in lineage:
            # Already flagged: a previous walk covered the remainder of
            # this lineage, so stop here.
            if flagged[node_id]:
                break

            flagged[node_id] = True

    return flagged


@jit(nopython=True)
def _create_has_extant_descendant_contiguous_sorted(
    ancestor_ids: np.ndarray,
    extant_mask: np.ndarray,
) -> np.ndarray:
    """Implementation detail for alifestd_prune_extinct_lineages_asexual.

    Single reverse sweep: starting from a copy of `extant_mask`, visits
    positions from last to first and ORs each position's flag into the flag
    of its ancestor (`ancestor_ids[position]`).

    The visible caller selects this variant only when ids are contiguous and
    the table is topologically sorted.

    Returns the swept copy; `extant_mask` itself is not modified.
    """

    flags = extant_mask.copy()
    num_nodes = len(ancestor_ids)
    for step in range(num_nodes):
        # Count down from the last position to the first.
        child = num_nodes - 1 - step
        parent = ancestor_ids[child]
        flags[parent] |= flags[child]

    return flags


def alifestd_prune_extinct_lineages_asexual(
Expand Down Expand Up @@ -45,7 +107,10 @@ def alifestd_prune_extinct_lineages_asexual(
phylogeny_df = phylogeny_df.copy()

phylogeny_df = alifestd_try_add_ancestor_id_col(phylogeny_df, mutate=True)
phylogeny_df.set_index("id", drop=False, inplace=True)
if alifestd_has_contiguous_ids(phylogeny_df):
phylogeny_df.reset_index(drop=True, inplace=True)
else:
phylogeny_df.index = phylogeny_df["id"]

extant_mask = None
if "extant" in phylogeny_df:
Expand All @@ -58,22 +123,26 @@ def alifestd_prune_extinct_lineages_asexual(
else:
raise ValueError('Need "extant" or "destruction_time" column.')

phylogeny_df["has_extant_descendant"] = False

for extant_id in phylogeny_df.loc[extant_mask, "id"]:
for lineage_id in alifestd_unfurl_lineage_asexual(
if not alifestd_has_contiguous_ids(phylogeny_df):
has_extant_descendant = _create_has_extant_descendant_noncontiguous(
phylogeny_df,
int(extant_id),
mutate=True,
):
if phylogeny_df.loc[lineage_id, "has_extant_descendant"]:
break

phylogeny_df.loc[lineage_id, "has_extant_descendant"] = True
extant_mask,
)
elif not alifestd_is_topologically_sorted(phylogeny_df):
has_extant_descendant = _create_has_extant_descendant_contiguous(
phylogeny_df["ancestor_id"].to_numpy(dtype=np.uint64),
extant_mask.to_numpy(dtype=bool),
)
else:
has_extant_descendant = (
_create_has_extant_descendant_contiguous_sorted(
phylogeny_df["ancestor_id"].to_numpy(dtype=np.uint64),
extant_mask.to_numpy(dtype=bool),
)
)

drop_filter = ~phylogeny_df["has_extant_descendant"]
phylogeny_df = phylogeny_df[has_extant_descendant].reset_index(drop=True)
phylogeny_df.drop(
phylogeny_df.index[drop_filter], inplace=True, axis="rows"
columns="has_extant_descendant", errors="ignore", inplace=True
)
phylogeny_df.drop("has_extant_descendant", inplace=True, axis="columns")
return phylogeny_df.reset_index(drop=True)
return phylogeny_df
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def alifestd_try_add_ancestor_list_col(
def _create_parser() -> argparse.ArgumentParser:
"""Create parser for CLI entrypoint."""
parser = argparse.ArgumentParser(
add_help=False,
description=format_cli_description(_raw_description),
formatter_class=argparse.RawTextHelpFormatter,
)
Expand Down
4 changes: 4 additions & 0 deletions hstrat/_auxiliary_lib/_coerce_to_pandas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import typing

import pandas as pd
import polars as pl

_supported_iterables = tuple, set, list, frozenset
_supported_mappings = dict
Expand All @@ -10,6 +11,9 @@ def coerce_to_pandas(obj: typing.Any, *, recurse: bool = False) -> typing.Any:
"""
If a Polars type is detected, coerce it to corresponding Pandas type.
"""
if isinstance(obj, pl.LazyFrame):
obj = obj.collect()

if hasattr(obj, "__dataframe__"):
return pd.api.interchange.from_dataframe(obj, allow_copy=True)
elif hasattr(obj, "to_pandas"):
Expand Down
10 changes: 6 additions & 4 deletions hstrat/_auxiliary_lib/_delegate_polars_implementation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from ._coerce_to_polars import coerce_to_polars
from ._warn_once import warn_once

DataFrame_T = typing.TypeVar("DataFrame_T", pd.DataFrame, pl.DataFrame)
DataFrame_T = typing.TypeVar(
"DataFrame_T", pd.DataFrame, pl.DataFrame, pl.LazyFrame
)
Series_T = typing.TypeVar("Series_T", pd.Series, pl.Series)


Expand All @@ -25,7 +27,7 @@ def _detect_pandas(arg: typing.Any, recurse: bool) -> bool:
"""
if isinstance(arg, (pd.DataFrame, pd.Series)):
return True
elif isinstance(arg, (pl.DataFrame, pl.Series, str)):
elif isinstance(arg, (pl.DataFrame, pl.LazyFrame, pl.Series, str)):
return False
elif recurse and isinstance(arg, _supported_mappings):
return any(_detect_pandas(v, recurse) for v in arg.values())
Expand All @@ -46,7 +48,7 @@ def _detect_polars(arg: typing.Any, recurse: bool) -> bool:
If `recurse` is True, then this function will recursively check for Polars
members in mappings and iterables.
"""
if isinstance(arg, (pl.DataFrame, pl.Series)):
if isinstance(arg, (pl.DataFrame, pl.LazyFrame, pl.Series)):
return True
elif isinstance(arg, (pd.DataFrame, pd.Series, str)):
return False
Expand Down Expand Up @@ -95,7 +97,7 @@ def delegating_function(*args, **kwargs) -> typing.Any:
any_pandas = any(map(detect_pandas_, (*args, *kwargs.values())))
any_polars = any(map(detect_polars_, (*args, *kwargs.values())))
logging.info("begin delgate_polars_implementation")
logging.info("- detected {any_pandas=} {any_polars=}")
logging.info(f"- detected {any_pandas=} {any_polars=}")

if any_pandas and any_polars:
raise TypeError("mixing pandas and polars types is disallowed")
Expand Down
3 changes: 2 additions & 1 deletion hstrat/dataframe/surface_build_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@

def _create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
add_help=False,
description=format_cli_description(raw_message),
formatter_class=argparse.RawTextHelpFormatter,
)
Expand Down Expand Up @@ -132,7 +133,7 @@ def _create_parser() -> argparse.ArgumentParser:
args, __ = parser.parse_known_args()

logging.info(
f"instantiating trie postprocess functor: "
"instantiating trie postprocess functor: "
f"`{args.trie_postprocessor}`",
)
trie_postprocessor = eval(args.trie_postprocessor, {"hstrat": hstrat})
Expand Down
Loading

0 comments on commit 8e6a4b5

Please sign in to comment.