Skip to content

Commit

Permalink
Migrate to DaskIndexingAdapter
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Nov 21, 2024
1 parent e2547f1 commit a4ba2bc
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 39 deletions.
38 changes: 6 additions & 32 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2126,18 +2126,6 @@ def to_floatable(x: DataArray) -> DataArray:
return to_floatable(data)


def _apply_vectorized_indexer(indices, coord):
from xarray.core.indexing import (
VectorizedIndexer,
apply_indexer,
as_indexable,
)

return apply_indexer(
as_indexable(coord), VectorizedIndexer((indices.squeeze(axis=-1),))
)


def _calc_idxminmax(
*,
array,
Expand Down Expand Up @@ -2182,28 +2170,14 @@ def _calc_idxminmax(
indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)

# Handle chunked arrays (e.g. dask).
coord = array[dim]._variable.to_base_variable()
if is_chunked_array(array.data):
chunkmanager = get_chunked_array_type(array.data)
chunked_coord = chunkmanager.from_array(array[dim].data, chunks=((-1,),))

if indx.ndim == 0:
out = chunked_coord[indx.data]
else:
out = chunkmanager.map_blocks(
_apply_vectorized_indexer,
indx.data[..., np.newaxis],
chunked_coord,
chunks=indx.data.chunks,
drop_axis=-1,
dtype=chunked_coord.dtype,
)
res = indx.copy(data=out)
# we need to attach back the dim name
res.name = dim
else:
res = array[dim][(indx,)]
# The dim is gone but we need to remove the corresponding coordinate.
del res.coords[dim]
coord_array = chunkmanager.from_array(
array[dim].data, chunks=((array.sizes[dim],),)
)
coord = coord.copy(data=coord_array)
res = indx._replace(coord[(indx.variable,)]).rename(dim)

if skipna or (skipna is None and array.dtype.kind in na_dtypes):
# Put the NaN values back in after removing them
Expand Down
49 changes: 42 additions & 7 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import enum
import functools
import math
import operator
from collections import Counter, defaultdict
from collections.abc import Callable, Hashable, Iterable, Mapping
Expand Down Expand Up @@ -472,12 +473,12 @@ def __init__(self, key: tuple[slice | np.ndarray[Any, np.dtype[np.generic]], ...
for k in key:
if isinstance(k, slice):
k = as_integer_slice(k)
elif is_duck_dask_array(k):
raise ValueError(
"Vectorized indexing with Dask arrays is not supported. "
"Please pass a numpy array by calling ``.compute``. "
"See https://github.com/dask/dask/issues/8958."
)
# elif is_duck_dask_array(k):
# raise ValueError(
# "Vectorized indexing with Dask arrays is not supported. "
# "Please pass a numpy array by calling ``.compute``. "
# "See https://github.com/dask/dask/issues/8958."
# )
elif is_duck_array(k):
if not np.issubdtype(k.dtype, np.integer):
raise TypeError(
Expand Down Expand Up @@ -1607,6 +1608,18 @@ def transpose(self, order):
return xp.permute_dims(self.array, order)


def _apply_vectorized_indexer_dask_wrapper(indices, coord):
from xarray.core.indexing import (
VectorizedIndexer,
apply_indexer,
as_indexable,
)

return apply_indexer(
as_indexable(coord), VectorizedIndexer((indices.squeeze(axis=-1),))
)


class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a dask array to support explicit indexing."""

Expand All @@ -1630,7 +1643,29 @@ def _oindex_get(self, indexer: OuterIndexer):
return value

def _vindex_get(self, indexer: VectorizedIndexer):
return self.array.vindex[indexer.tuple]
try:
return self.array.vindex[indexer.tuple]
except IndexError as e:
# TODO: upstream to dask
has_dask = any(is_duck_dask_array(i) for i in indexer.tuple)
if not has_dask or (has_dask and len(indexer.tuple) > 1):
raise e
if math.prod(self.array.numblocks) > 1 or self.array.ndim > 1:
raise e
(idxr,) = indexer.tuple
if idxr.ndim == 0:
return self.array[idxr.data]
else:
import dask.array

return dask.array.map_blocks(
_apply_vectorized_indexer_dask_wrapper,
idxr[..., np.newaxis],
self.array,
chunks=idxr.chunks,
drop_axis=-1,
dtype=self.array.dtype,
)

def __getitem__(self, indexer: ExplicitIndexer):
self._check_and_raise_if_non_basic_indexer(indexer)
Expand Down

0 comments on commit a4ba2bc

Please sign in to comment.