Skip to content
forked from pydata/xarray

Commit

Permalink
GroupBy(chunked-array)
Browse files Browse the repository at this point in the history
Closes pydata#757
Closes pydata#2852
  • Loading branch information
dcherian committed Sep 19, 2024
1 parent 3c74509 commit 95f4802
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 20 deletions.
102 changes: 83 additions & 19 deletions xarray/groupers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@

from xarray.coding.cftime_offsets import BaseCFTimeOffset, _new_to_legacy_freq
from xarray.core import duck_array_ops
from xarray.core.computation import apply_ufunc
from xarray.core.coordinates import Coordinates
from xarray.core.dataarray import DataArray
from xarray.core.duck_array_ops import isnull
from xarray.core.groupby import T_Group, _DummyGroup
from xarray.core.indexes import safe_cast_to_index
from xarray.core.resample_cftime import CFTimeGrouper
Expand All @@ -29,6 +31,7 @@
SideOptions,
)
from xarray.core.variable import Variable
from xarray.namedarray.pycompat import is_chunked_array

__all__ = [
"EncodedGroups",
Expand Down Expand Up @@ -96,7 +99,7 @@ def __init__(
assert isinstance(full_index, pd.Index)
self.full_index = full_index

if group_indices is None:
if group_indices is None and not is_chunked_array(codes.data):
self.group_indices = tuple(
g
for g in _codes_to_group_indices(codes.data.ravel(), len(full_index))
Expand Down Expand Up @@ -155,10 +158,17 @@ class UniqueGrouper(Grouper):
"""Grouper object for grouping by a categorical variable."""

_group_as_index: pd.Index | None = field(default=None, repr=False)
labels: np.ndarray | None = field(default=None)

def __post_init__(self) -> None:
    # Normalize user-supplied labels to sorted order up front, presumably so
    # that codes/group order downstream is deterministic regardless of the
    # order the caller passed labels in.
    # NOTE(review): assumes callers do not rely on their original label
    # ordering being preserved in the output -- confirm.
    if self.labels is not None:
        self.labels = np.sort(self.labels)

@property
def group_as_index(self) -> pd.Index:
"""Caches the group DataArray as a pandas Index."""
if is_chunked_array(self.group):
raise ValueError("Please call compute manually.")
if self._group_as_index is None:
if self.group.ndim == 1:
self._group_as_index = self.group.to_index()
Expand All @@ -169,6 +179,11 @@ def group_as_index(self) -> pd.Index:
def factorize(self, group: T_Group) -> EncodedGroups:
self.group = group

if is_chunked_array(group.data) and self.labels is None:
raise ValueError("When grouping by a dask array, `labels` must be passed.")
if self.labels is not None:
return self._factorize_given_labels(group)

index = self.group_as_index
is_unique_and_monotonic = isinstance(self.group, _DummyGroup) or (
index.is_unique
Expand All @@ -182,6 +197,24 @@ def factorize(self, group: T_Group) -> EncodedGroups:
else:
return self._factorize_unique()

def _factorize_given_labels(self, group: T_Group) -> EncodedGroups:
    """Factorize ``group`` against the user-supplied ``self.labels``.

    Delegates the element-wise lookup to the module-level
    ``_factorize_given_labels`` helper via ``apply_ufunc`` with
    ``dask="parallelized"``, so a chunked group array stays lazy.
    """
    labels = self.labels
    codes = apply_ufunc(
        _factorize_given_labels,
        group,
        kwargs={"labels": labels},
        dask="parallelized",
        output_dtypes=[np.int64],
    )
    # The unique coordinate is exactly the provided labels, carried on the
    # codes' dimension with the original group's attributes.
    unique_coord = Variable(
        dims=codes.name,
        data=labels,
        attrs=self.group.attrs,
    )
    return EncodedGroups(
        codes=codes,
        full_index=pd.Index(labels),
        unique_coord=unique_coord,
    )

def _factorize_unique(self) -> EncodedGroups:
# look through group to find the unique values
sort = not isinstance(self.group_as_index, pd.MultiIndex)
Expand Down Expand Up @@ -291,13 +324,9 @@ def __post_init__(self) -> None:
if duck_array_ops.isnull(self.bins).all():
raise ValueError("All bin edges are NaN.")

def factorize(self, group: T_Group) -> EncodedGroups:
from xarray.core.dataarray import DataArray

data = np.asarray(group.data) # Cast _DummyGroup data to array

binned, self.bins = pd.cut( # type: ignore [call-overload]
data.ravel(),
def _cut(self, data):
return pd.cut( # type: ignore [call-overload]
np.asarray(data).ravel(),
bins=self.bins,
right=self.right,
labels=self.labels,
Expand All @@ -307,23 +336,43 @@ def factorize(self, group: T_Group) -> EncodedGroups:
retbins=True,
)

binned_codes = binned.codes
if (binned_codes == -1).all():
def _factorize_lazy(self, group: T_Group) -> DataArray:
    # Bin ``group`` into integer codes. With a chunked group array,
    # apply_ufunc(dask="parallelized") defers ``_wrapper`` to compute time;
    # with an in-memory array it runs immediately.
    def _wrapper(data, **kwargs):
        binned, bins = self._cut(data)
        if isinstance(self.bins, int):
            # we are running eagerly, update self.bins with actual edges instead
            self.bins = bins
        # pd.cut works on the raveled data; restore the original shape.
        return binned.codes.reshape(data.shape)

    return apply_ufunc(_wrapper, group, dask="parallelized")

def factorize(self, group: T_Group) -> EncodedGroups:
if isinstance(group, _DummyGroup):
group = DataArray(group.data, dims=group.dims, name=group.name)
by_is_chunked = is_chunked_array(group.data)
if isinstance(self.bins, int) and by_is_chunked:
raise ValueError(
f"Bin edges must be provided when grouping by chunked arrays. Received {self.bins=!r} instead"
)
codes = self._factorize_lazy(group)
if not by_is_chunked and (codes == -1).all():
raise ValueError(
f"None of the data falls within bins with edges {self.bins!r}"
)

new_dim_name = f"{group.name}_bins"
codes.name = new_dim_name

# This seems silly, but it lets us have Pandas handle the complexity
# of labels, precision, and include_lowest, even when group is a chunked array
dummy, _ = self._cut(np.array([1, 2, 3]).astype(group.dtype))
full_index = dummy.categories
if not by_is_chunked:
uniques = np.sort(pd.unique(codes.data.ravel()))
unique_values = full_index[uniques[uniques != -1]]
else:
unique_values = full_index

full_index = binned.categories
uniques = np.sort(pd.unique(binned_codes))
unique_values = full_index[uniques[uniques != -1]]

codes = DataArray(
binned_codes.reshape(group.shape),
getattr(group, "coords", None),
name=new_dim_name,
)
unique_coord = Variable(
dims=new_dim_name, data=unique_values, attrs=group.attrs
)
Expand Down Expand Up @@ -461,6 +510,21 @@ def factorize(self, group: T_Group) -> EncodedGroups:
)


def _factorize_given_labels(data: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Return, for each element of ``data``, its position in ``labels``.

    Elements not present in ``labels`` (including NaN/null values) map to -1.
    ``labels`` are honored in the order provided. Adapted from flox.
    """
    order = np.argsort(labels)
    # Positions within the *sorted* view of labels; values past the end of
    # the sorted array come back as len(labels).
    codes = np.searchsorted(labels, data, sorter=order)
    past_end = codes == len(labels)
    invalid = ~np.isin(data, labels) | isnull(data) | past_end
    # searchsorted indexed into the sorted order; clamp the overflow slot,
    # then translate back to positions in the original (unsorted) labels.
    codes[(past_end,)] = -1
    codes = order[(codes,)]
    codes[invalid] = -1
    return codes


def unique_value_groups(
ar, sort: bool = True
) -> tuple[np.ndarray | pd.Index, np.ndarray]:
Expand Down
4 changes: 3 additions & 1 deletion xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2583,7 +2583,9 @@ def test_groupby_math_auto_chunk() -> None:
sub = xr.DataArray(
InaccessibleArray(np.array([1, 2])), dims="label", coords={"label": [1, 2]}
)
actual = da.chunk(x=1, y=2).groupby("label") - sub
chunked = da.chunk(x=1, y=2)
chunked.label.load()
actual = chunked.groupby("label") - sub
assert actual.chunksizes == {"x": (1, 1, 1), "y": (2, 1)}


Expand Down

0 comments on commit 95f4802

Please sign in to comment.