From fd6b339c3f51e83d5f9deb12e837f822cd51a2f7 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 11:03:25 -0500 Subject: [PATCH 01/24] lots more duck array compat, plus tests --- xarray/core/array_api_compat.py | 28 ++ xarray/core/common.py | 5 +- xarray/core/computation.py | 8 +- xarray/core/dataset.py | 12 +- xarray/core/duck_array_ops.py | 122 +++++--- xarray/core/nanops.py | 2 +- xarray/core/nputils.py | 6 + xarray/core/rolling.py | 3 + xarray/core/variable.py | 39 ++- xarray/tests/test_duck_array_wrapping.py | 371 +++++++++++++++++++++++ xarray/tests/test_strategies.py | 19 +- xarray/tests/test_variable.py | 39 +-- 12 files changed, 563 insertions(+), 91 deletions(-) create mode 100644 xarray/tests/test_duck_array_wrapping.py diff --git a/xarray/core/array_api_compat.py b/xarray/core/array_api_compat.py index da072de5b69..28d671cc349 100644 --- a/xarray/core/array_api_compat.py +++ b/xarray/core/array_api_compat.py @@ -1,5 +1,7 @@ import numpy as np +from xarray.namedarray.pycompat import array_type + def is_weak_scalar_type(t): return isinstance(t, bool | int | float | complex | str | bytes) @@ -42,3 +44,29 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype: return xp.result_type(*arrays_and_dtypes) else: return _future_array_api_result_type(*arrays_and_dtypes, xp=xp) + + +def get_array_namespace(*values): + def _get_single_namespace(x): + if hasattr(x, "__array_namespace__"): + return x.__array_namespace__() + elif isinstance(x, array_type("cupy")): + # special case cupy for now + import cupy as cp + + return cp + else: + return np + + namespaces = {_get_single_namespace(t) for t in values} + non_numpy = namespaces - {np} + + if len(non_numpy) > 1: + names = [module.__name__ for module in non_numpy] + raise TypeError(f"Mixed array types {names} are not supported.") + elif non_numpy: + [xp] = non_numpy + else: + xp = np + + return xp diff --git a/xarray/core/common.py b/xarray/core/common.py index 6f788f408d0..8aaa153c1a8 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -496,7 +496,7 @@ def clip( keep_attrs = _get_keep_attrs(default=True) return apply_ufunc( - np.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed" + duck_array_ops.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed" ) def get_index(self, key: Hashable) -> pd.Index: @@ -1760,7 +1760,8 @@ def _full_like_variable( **from_array_kwargs, ) else: - data = np.full_like(other.data, fill_value, dtype=dtype) + xp = duck_array_ops.get_array_namespace(other.data) + data = xp.full_like(other.data, fill_value, dtype=dtype) return Variable(dims=other.dims, data=data, attrs=other.attrs) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index b15ed7f3f34..0bfe21642f7 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -34,7 +34,7 @@ from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set, result_name from xarray.core.variable import Variable from xarray.namedarray.parallelcompat import get_chunked_array_type -from xarray.namedarray.pycompat import is_chunked_array +from xarray.namedarray.pycompat import is_chunked_array, to_numpy from xarray.util.deprecation_helpers import deprecate_dims if TYPE_CHECKING: @@ -1702,7 +1702,7 @@ def cross( ) c = apply_ufunc( - np.cross, + duck_array_ops.cross, a, b, input_core_dims=[[dim], [dim]], @@ -2174,9 +2174,13 @@ def _calc_idxminmax( # we need to attach back the dim name res.name = dim else: + indx.data = to_numpy(indx.data) res = array[dim][(indx,)] # The dim is gone but we need to remove the corresponding coordinate. del res.coords[dim] + # Cast to array namespace + xp = duck_array_ops.get_array_namespace(array.data) + res.data = xp.asarray(res.data) if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index cc34a8cc04b..6e5a8e163b8 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -127,7 +127,7 @@ calculate_dimensions, ) from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager -from xarray.namedarray.pycompat import array_type, is_chunked_array +from xarray.namedarray.pycompat import array_type, is_chunked_array, to_numpy from xarray.plot.accessor import DatasetPlotAccessor from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims @@ -6564,7 +6564,7 @@ def dropna( array = self._variables[k] if dim in array.dims: dims = [d for d in array.dims if d != dim] - count += np.asarray(array.count(dims)) + count += to_numpy(array.count(dims).data) size += math.prod([self.sizes[d] for d in dims]) if thresh is not None: @@ -8678,16 +8678,20 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): coord_names.add(k) else: if k in self.data_vars and dim in v.dims: + # cast coord data to duck array if needed + coord_data = duck_array_ops.get_array_namespace(v.data).asarray( + coord_var.data + ) if _contains_datetime_like_objects(v): v = datetime_to_numeric(v, datetime_unit=datetime_unit) if cumulative: integ = duck_array_ops.cumulative_trapezoid( - v.data, coord_var.data, axis=v.get_axis_num(dim) + v.data, coord_data, axis=v.get_axis_num(dim) ) v_dims = v.dims else: integ = duck_array_ops.trapz( - v.data, coord_var.data, axis=v.get_axis_num(dim) + v.data, coord_data, axis=v.get_axis_num(dim) ) v_dims = list(v.dims) v_dims.remove(dim) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 77e62e4c71e..d67f8d17207 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -18,23 +18,17 @@ import pandas as pd from numpy import all as array_all # noqa: F401 from numpy import any as array_any # noqa: F401 -from numpy import concatenate as _concatenate from numpy import ( # noqa: F401 - full_like, - gradient, isclose, - isin, isnat, take, - tensordot, - transpose, unravel_index, ) -from numpy.lib.stride_tricks import sliding_window_view # noqa: F401 from packaging.version import Version from pandas.api.types import is_extension_array_dtype from xarray.core import dask_array_ops, dtypes, nputils +from xarray.core.array_api_compat import get_array_namespace from xarray.core.options import OPTIONS from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available from xarray.namedarray import pycompat @@ -55,28 +49,6 @@ dask_available = module_available("dask") -def get_array_namespace(*values): - def _get_array_namespace(x): - if hasattr(x, "__array_namespace__"): - return x.__array_namespace__() - else: - return np - - namespaces = {_get_array_namespace(t) for t in values} - non_numpy = namespaces - {np} - - if len(non_numpy) > 1: - raise TypeError( - "cannot deal with more than one type supporting the array API at the same time" - ) - elif non_numpy: - [xp] = non_numpy - else: - xp = np - - return xp - - def einsum(*args, **kwargs): from xarray.core.options import OPTIONS @@ -85,7 +57,23 @@ def einsum(*args, **kwargs): return opt_einsum.contract(*args, **kwargs) else: - return np.einsum(*args, **kwargs) + xp = get_array_namespace(*args) + return xp.einsum(*args, **kwargs) + + +def tensordot(*args, **kwargs): + xp = get_array_namespace(*args) + return xp.tensordot(*args, **kwargs) + + +def cross(*args, **kwargs): + xp = get_array_namespace(*args) + return xp.cross(*args, **kwargs) + + +def gradient(f, *varargs, axis=None, edge_order=1): + xp = get_array_namespace(f) + return xp.gradient(f, *varargs, axis=axis, edge_order=edge_order) def _dask_or_eager_func( @@ -153,7 +141,7 @@ def isnull(data): ) ): # these types cannot represent missing values - return full_like(data, dtype=bool, fill_value=False) + return full_like(data, dtype=xp.bool, fill_value=False) else: # at this point, array should have dtype=object if isinstance(data, np.ndarray) or is_extension_array_dtype(data): @@ -200,11 +188,23 @@ def cumulative_trapezoid(y, x, axis): # Pad so that 'axis' has same length in result as it did in y pads = [(1, 0) if i == axis else (0, 0) for i in range(y.ndim)] - integrand = np.pad(integrand, pads, mode="constant", constant_values=0.0) + + xp = get_array_namespace(y, x) + integrand = xp.pad(integrand, pads, mode="constant", constant_values=0.0) return cumsum(integrand, axis=axis, skipna=False) +def full_like(a, fill_value, **kwargs): + xp = get_array_namespace(a) + return xp.full_like(a, fill_value, **kwargs) + + +def empty_like(a, **kwargs): + xp = get_array_namespace(a) + return xp.empty_like(a, **kwargs) + + def astype(data, dtype, **kwargs): if hasattr(data, "__array_namespace__"): xp = get_array_namespace(data) @@ -335,7 +335,8 @@ def array_notnull_equiv(arr1, arr2): def count(data, axis=None): """Count the number of non-NA in this array along the given axis or axes""" - return np.sum(np.logical_not(isnull(data)), axis=axis) + xp = get_array_namespace(data) + return xp.sum(xp.logical_not(isnull(data)), axis=axis) def sum_where(data, axis=None, dtype=None, where=None): @@ -350,7 +351,7 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" - xp = get_array_namespace(condition) + xp = get_array_namespace(condition, x, y) return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) @@ -367,15 +368,25 @@ def fillna(data, other): return where(notnull(data), data, other) +def logical_not(data): + xp = get_array_namespace(data) + return xp.logical_not(data) + + +def clip(data, min=None, max=None): + xp = get_array_namespace(data) + return xp.clip(data, min, max) + + def concatenate(arrays, axis=0): """concatenate() with better dtype promotion rules.""" - # TODO: remove the additional check once `numpy` adds `concat` to its array namespace - if hasattr(arrays[0], "__array_namespace__") and not isinstance( - arrays[0], np.ndarray - ): - xp = get_array_namespace(arrays[0]) + # TODO: `concat` is the xp compliant name, but fallback to concatenate for + # older numpy and for cupy + xp = get_array_namespace(*arrays) + if hasattr(xp, "concat"): return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) - return _concatenate(as_shared_dtype(arrays), axis=axis) + else: + return xp.concatenate(as_shared_dtype(arrays, xp=xp), axis=axis) def stack(arrays, axis=0): @@ -393,6 +404,32 @@ def ravel(array): return reshape(array, (-1,)) +def transpose(array, axes=None): + xp = get_array_namespace(array) + return xp.transpose(array, axes) + + +def moveaxis(array, source, destination): + xp = get_array_namespace(array) + return xp.moveaxis(array, source, destination) + + +def pad(array, pad_width, **kwargs): + xp = get_array_namespace(array) + return xp.pad(array, pad_width, **kwargs) + + +def sliding_window_view(array, window_shape, axis=None): + # TODO: some array libraries don't support this, implement an alternative? + xp = get_array_namespace(array) + return xp.lib.stride_tricks.sliding_window_view(array, window_shape, axis=axis) + + +def quantile(array, q, axis=None, **kwargs): + xp = get_array_namespace(array) + return xp.quantile(array, q, axis=axis, **kwargs) + + @contextlib.contextmanager def _ignore_warnings_if(condition): if condition: @@ -734,6 +771,11 @@ def last(values, axis, skipna=None): return take(values, -1, axis=axis) +def isin(element, test_elements, **kwargs): + xp = get_array_namespace(element, test_elements) + return xp.isin(element, test_elements, **kwargs) + + def least_squares(lhs, rhs, rcond=None, skipna=False): """Return the coefficients and residuals of a least-squares fit.""" if is_duck_dask_array(rhs): diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 7fbb63068c0..4894cf02be2 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -128,7 +128,7 @@ def nanmean(a, axis=None, dtype=None, out=None): "ignore", r"Mean of empty slice", category=RuntimeWarning ) - return np.nanmean(a, axis=axis, dtype=dtype) + return nputils.nanmean(a, axis=axis, dtype=dtype) def nanmedian(a, axis=None, out=None): diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index cd20dbccd87..24d6b1dda72 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -7,6 +7,7 @@ import pandas as pd from packaging.version import Version +from xarray.core.array_api_compat import get_array_namespace from xarray.core.utils import is_duck_array, module_available from xarray.namedarray import pycompat @@ -179,6 +180,11 @@ def f(values, axis=None, **kwargs): dtype = kwargs.get("dtype") bn_func = getattr(bn, name, None) + xp = get_array_namespace(values) + if xp is not np: + func = getattr(xp, name, None) + if func is not None: + return func(values, axis=axis, **kwargs) if ( module_available("numbagg") and pycompat.mod_version("numbagg") >= Version("0.5.0") diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 781550207ff..dfeab5e409e 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -639,6 +639,7 @@ def _array_reduce( ) del kwargs["dim"] + xp = duck_array_ops.get_array_namespace(self.obj.data) if ( OPTIONS["use_numbagg"] and module_available("numbagg") @@ -654,6 +655,7 @@ def _array_reduce( # TODO: we could also allow this, probably as part of a refactoring of this # module, so we can use the machinery in `self.reduce`. and self.ndim == 1 + and xp is np ): import numbagg @@ -676,6 +678,7 @@ def _array_reduce( or module_available("dask", "2024.11.0") ) and self.ndim == 1 + and xp is np ): return self._bottleneck_reduce( bottleneck_move_func, keep_attrs=keep_attrs, **kwargs diff --git a/xarray/core/variable.py b/xarray/core/variable.py index a6ea44b1ee5..f4db3fa6b1d 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -822,7 +822,7 @@ def __getitem__(self, key) -> Self: data = indexing.apply_indexer(indexable, indexer) if new_order: - data = np.moveaxis(data, range(len(new_order)), new_order) + data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) def _finalize_indexing_result(self, dims, data) -> Self: @@ -860,12 +860,17 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): # we need to invert the mask in order to pass data first. This helps # pint to choose the correct unit # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed - data = duck_array_ops.where(np.logical_not(mask), data, fill_value) + # cast mask to any duck array type + if not is_duck_dask_array(mask): + mask = duck_array_ops.get_array_namespace(data).asarray(mask) + data = duck_array_ops.where( + duck_array_ops.logical_not(mask), data, fill_value + ) else: # array cannot be indexed along dimensions of size 0, so just # build the mask directly instead. mask = indexing.create_mask(indexer, self.shape) - data = np.broadcast_to(fill_value, getattr(mask, "shape", ())) + data = duck_array_ops.broadcast_to(fill_value, getattr(mask, "shape", ())) if new_order: data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) @@ -896,7 +901,7 @@ def __setitem__(self, key, value): if new_order: value = duck_array_ops.asarray(value) value = value[(len(dims) - value.ndim) * (np.newaxis,) + (Ellipsis,)] - value = np.moveaxis(value, new_order, range(len(new_order))) + value = duck_array_ops.moveaxis(value, new_order, range(len(new_order))) indexable = as_indexable(self._data) indexing.set_with_indexer(indexable, index_tuple, value) @@ -1098,7 +1103,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): dim_pad = (width, 0) if count >= 0 else (0, width) pads = [(0, 0) if d != dim else dim_pad for d in self.dims] - data = np.pad( + data = duck_array_ops.pad( duck_array_ops.astype(trimmed_data, dtype), pads, mode="constant", @@ -1244,7 +1249,7 @@ def pad( if reflect_type is not None: pad_option_kwargs["reflect_type"] = reflect_type - array = np.pad( + array = duck_array_ops.pad( duck_array_ops.astype(self.data, dtype, copy=False), pad_width_by_index, mode=mode, @@ -1533,14 +1538,16 @@ def _unstack_once( if is_missing_values: dtype, fill_value = dtypes.maybe_promote(self.dtype) - create_template = partial(np.full_like, fill_value=fill_value) + create_template = partial( + duck_array_ops.full_like, fill_value=fill_value + ) else: dtype = self.dtype fill_value = dtypes.get_fill_value(dtype) - create_template = np.empty_like + create_template = duck_array_ops.empty_like else: dtype = self.dtype - create_template = partial(np.full_like, fill_value=fill_value) + create_template = partial(duck_array_ops.full_like, fill_value=fill_value) if sparse: # unstacking a dense multitindexed array to a sparse array @@ -1630,7 +1637,8 @@ def clip(self, min=None, max=None): """ from xarray.core.computation import apply_ufunc - return apply_ufunc(np.clip, self, min, max, dask="allowed") + xp = duck_array_ops.get_array_namespace(self.data) + return apply_ufunc(xp.clip, self, min, max, dask="parallelized") def reduce( # type: ignore[override] self, @@ -1923,13 +1931,15 @@ def quantile( if skipna or (skipna is None and self.dtype.kind in "cfO"): _quantile_func = nputils.nanquantile else: - _quantile_func = np.quantile + _quantile_func = duck_array_ops.quantile if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) + xp = duck_array_ops.get_array_namespace(self.data) + scalar = utils.is_scalar(q) - q = np.atleast_1d(np.asarray(q, dtype=np.float64)) + q = xp.atleast_1d(xp.asarray(q, dtype=float)) if dim is None: dim = self.dims @@ -1939,9 +1949,10 @@ def quantile( def _wrapper(npa, **kwargs): # move quantile axis to end. required for apply_ufunc - return np.moveaxis(_quantile_func(npa, **kwargs), 0, -1) + return xp.moveaxis(_quantile_func(npa, **kwargs), 0, -1) - axis = np.arange(-1, -1 * len(dim) - 1, -1) + # jax requires hashable + axis = tuple(range(-1, -1 * len(dim) - 1, -1)) kwargs = {"q": q, "axis": axis, "method": method} diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py new file mode 100644 index 00000000000..1fc8ece98dd --- /dev/null +++ b/xarray/tests/test_duck_array_wrapping.py @@ -0,0 +1,371 @@ +import numpy as np +import pandas as pd +import pytest + +import xarray as xr + +# TODO: how to test these in CI? +jnp = pytest.importorskip("jax.numpy") +cp = pytest.importorskip("cupy") + + +def get_test_dataarray(xp): + return xr.DataArray( + xp.asarray([[1, 2, 3, np.nan, 5]]), + dims=["y", "x"], + coords={"y": [1], "x": np.arange(5)}, + name="foo", + ) + + +@pytest.mark.parametrize("xp", [cp, jnp]) +class TestTopLevelMethods: + @pytest.fixture(autouse=True) + def setUp(self, xp): + self.xp = xp + self.Array = xp.ndarray + self.x1 = get_test_dataarray(xp) + self.x2 = get_test_dataarray(xp).assign_coords(x=np.arange(2, 7)) + + def test_apply_ufunc(self): + func = lambda x: x + 1 + result = xr.apply_ufunc(func, self.x1) + assert isinstance(result.data, self.Array) + + def test_align(self): + result = xr.align(self.x1, self.x2) + assert isinstance(result[0].data, self.Array) + assert isinstance(result[1].data, self.Array) + + def test_broadcast(self): + result = xr.broadcast(self.x1, self.x2) + assert isinstance(result[0].data, self.Array) + assert isinstance(result[1].data, self.Array) + + def test_concat(self): + result = xr.concat([self.x1, self.x2], dim="x") + assert isinstance(result.data, self.Array) + + def test_merge(self): + result = xr.merge([self.x1, self.x2], compat="override") + assert isinstance(result.foo.data, self.Array) + + def test_where(self): + x1, x2 = xr.align(self.x1, self.x2, join="inner") + result = xr.where(x1 > 2, x1, x2) + assert isinstance(result.data, self.Array) + + def test_full_like(self): + result = xr.full_like(self.x1, 0) + assert isinstance(result.data, self.Array) + + def test_cov(self): + result = xr.cov(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_corr(self): + result = xr.corr(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_cross(self): + x1, x2 = xr.align(self.x1.squeeze(), self.x2.squeeze(), join="inner") + result = xr.cross(x1, x2, dim="x") + assert isinstance(result.data, self.Array) + + def test_dot(self): + result = xr.dot(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_map_blocks(self): + result = xr.map_blocks(lambda x: x + 1, self.x1) + assert isinstance(result.data, self.Array) + + +@pytest.mark.parametrize("xp", [cp, jnp]) +class TestDataArrayMethods: + @pytest.fixture(autouse=True) + def setUp(self, xp): + self.xp = xp + self.Array = xp.ndarray + self.x = get_test_dataarray(xp) + + def test_loc(self): + result = self.x.loc[{"x": slice(1, 3)}] + assert isinstance(result.data, self.Array) + + def test_isel(self): + result = self.x.isel(x=slice(1, 3)) + assert isinstance(result.data, self.Array) + + def test_sel(self): + result = self.x.sel(x=slice(1, 3)) + assert isinstance(result.data, self.Array) + + def test_squeeze(self): + result = self.x.squeeze("y") + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="interp is not namespace aware") + def test_interp(self): + result = self.x.interp(x=2.5) + assert isinstance(result.data, self.Array) + + def test_isnull(self): + result = self.x.isnull() + assert isinstance(result.data, self.Array) + + def test_notnull(self): + result = self.x.notnull() + assert isinstance(result.data, self.Array) + + def test_count(self): + result = self.x.count() + assert isinstance(result.data, self.Array) + + def test_dropna(self): + result = self.x.dropna(dim="x") + assert isinstance(result.data, self.Array) + + def test_fillna(self): + result = self.x.fillna(0) + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="ffill is not namespace aware") + def test_ffill(self): + result = self.x.ffill() + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="bfill is not namespace aware") + def test_bfill(self): + result = self.x.bfill() + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="interpolate_na is not namespace aware") + def test_interpolate_na(self): + result = self.x.interpolate_na() + assert isinstance(result.data, self.Array) + + def test_where(self): + result = self.x.where(self.x > 2) + assert isinstance(result.data, self.Array) + + def test_isin(self): + result = self.x.isin(self.xp.asarray([1])) + assert isinstance(result.data, self.Array) + + def test_groupby(self): + result = self.x.groupby("x").mean() + assert isinstance(result.data, self.Array) + + def test_rolling(self): + if self.xp is jnp: + pytest.xfail("no sliding_window_view in jax") + result = self.x.rolling(x=3).mean() + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="rolling_exp is not namespace aware") + def test_rolling_exp(self): + result = self.x.rolling_exp(x=3).mean() + assert isinstance(result.data, self.Array) + + def test_weighted(self): + result = self.x.weighted(self.x.fillna(0)).mean() + assert isinstance(result.data, self.Array) + + def test_coarsen(self): + result = self.x.coarsen(x=2, boundary="pad").mean() + assert isinstance(result.data, self.Array) + + def test_resample(self): + time_coord = pd.date_range("2000-01-01", periods=5) + result = self.x.assign_coords(x=time_coord).resample(x="D").mean() + assert isinstance(result.data, self.Array) + + def test_diff(self): + result = self.x.diff("x") + assert isinstance(result.data, self.Array) + + def test_dot(self): + result = self.x.dot(self.x) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_quantile(self, skipna): + if self.xp is cp and skipna: + pytest.xfail("no nanquantile in cupy") + result = self.x.quantile(0.5, skipna=skipna) + assert isinstance(result.data, self.Array) + + def test_differentiate(self): + if self.xp is jnp: + pytest.xfail("edge_order kwarg") + result = self.x.differentiate("x") + assert isinstance(result.data, self.Array) + + def test_integrate(self): + result = self.x.integrate("x") + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="polyfit is not namespace aware") + def test_polyfit(self): + result = self.x.polyfit("x", 1) + assert isinstance(result.polyfit_coefficients.data, self.Array) + + def test_map_blocks(self): + result = self.x.map_blocks(lambda x: x + 1) + assert isinstance(result.data, self.Array) + + def test_all(self): + result = self.x.all(dim="x") + assert isinstance(result.data, self.Array) + + def test_any(self): + result = self.x.any(dim="x") + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_argmax(self, skipna): + result = self.x.argmax(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_argmin(self, skipna): + result = self.x.argmin(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_idxmax(self, skipna): + result = self.x.idxmax(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_idxmin(self, skipna): + result = self.x.idxmin(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_max(self, skipna): + result = self.x.max(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_min(self, skipna): + result = self.x.min(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_mean(self, skipna): + result = self.x.mean(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_median(self, skipna): + result = self.x.median(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_prod(self, skipna): + result = self.x.prod(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_sum(self, skipna): + result = self.x.sum(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_std(self, skipna): + if self.xp is cp and not skipna: + pytest.xfail("ddof/correction kwarg mismatch") + result = self.x.std(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_var(self, skipna): + if self.xp is cp and not skipna: + pytest.xfail("ddof/correction kwarg mismatch") + result = self.x.var(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_cumsum(self, skipna): + result = self.x.cumsum(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_cumprod(self, skipna): + result = self.x.cumprod(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + def test_argsort(self): + result = self.x.argsort() + assert isinstance(result.data, self.Array) + + def test_clip(self): + result = self.x.clip(min=2.0, max=4.0) + assert isinstance(result.data, self.Array) + + def test_conj(self): + result = self.x.conj() + assert isinstance(result.data, self.Array) + + def test_conjugate(self): + result = self.x.conjugate() + assert isinstance(result.data, self.Array) + + def test_imag(self): + result = self.x.imag + assert isinstance(result.data, self.Array) + + def test_searchsorted(self): + result = self.x.squeeze().searchsorted(self.xp.asarray(3)) + assert isinstance(result, self.Array) + + def test_round(self): + result = self.x.round() + assert isinstance(result.data, self.Array) + + def test_real(self): + result = self.x.real + assert isinstance(result.data, self.Array) + + def test_T(self): + result = self.x.T + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="rank is not namespace aware") + def test_rank(self): + result = self.x.rank() + assert isinstance(result.data, self.Array) + + def test_transpose(self): + result = self.x.transpose() + assert isinstance(result.data, self.Array) + + def test_stack(self): + result = self.x.stack(z=("x", "y")) + assert isinstance(result.data, self.Array) + + def test_unstack(self): + result = self.x.stack(z=("x", "y")).unstack("z") + assert isinstance(result.data, self.Array) + + def test_shift(self): + result = self.x.shift(x=1) + assert isinstance(result.data, self.Array) + + def test_roll(self): + result = self.x.roll(x=1) + assert isinstance(result.data, self.Array) + + def test_pad(self): + result = self.x.pad(x=1) + assert isinstance(result.data, self.Array) + + def test_sortby(self): + result = self.x.sortby("x") + assert isinstance(result.data, self.Array) + + def test_broadcast_like(self): + result = self.x.broadcast_like(self.x) + assert isinstance(result.data, self.Array) diff --git a/xarray/tests/test_strategies.py b/xarray/tests/test_strategies.py index 798f5f732d1..48819333ca2 100644 --- a/xarray/tests/test_strategies.py +++ b/xarray/tests/test_strategies.py @@ -13,6 +13,7 @@ from hypothesis import given from hypothesis.extra.array_api import make_strategies_namespace +from xarray.core.options import set_options from xarray.core.variable import Variable from xarray.testing.strategies import ( attrs, @@ -267,14 +268,14 @@ def test_mean(self, data, var): Test that given a Variable of at least one dimension, the mean of the Variable is always equal to the mean of the underlying array. """ + with set_options(use_numbagg=False): + # specify arbitrary reduction along at least one dimension + reduction_dims = data.draw(unique_subset_of(var.dims, min_size=1)) - # specify arbitrary reduction along at least one dimension - reduction_dims = data.draw(unique_subset_of(var.dims, min_size=1)) + # create expected result (using nanmean because arrays with Nans will be generated) + reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) + expected = np.nanmean(var.data, axis=reduction_axes) - # create expected result (using nanmean because arrays with Nans will be generated) - reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) - expected = np.nanmean(var.data, axis=reduction_axes) - - # assert property is always satisfied - result = var.mean(dim=reduction_dims).data - npt.assert_equal(expected, result) + # assert property is always satisfied + result = var.mean(dim=reduction_dims).data + npt.assert_equal(expected, result) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 9c6f50037d3..1461489e731 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1978,26 +1978,27 @@ def test_reduce_funcs(self): def test_reduce_keepdims(self): v = Variable(["x", "y"], self.d) - assert_identical( - v.mean(keepdims=True), Variable(v.dims, np.mean(self.d, keepdims=True)) - ) - assert_identical( - v.mean(dim="x", keepdims=True), - Variable(v.dims, np.mean(self.d, axis=0, keepdims=True)), - ) - assert_identical( - v.mean(dim="y", keepdims=True), - Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)), - ) - assert_identical( - v.mean(dim=["y", "x"], keepdims=True), - Variable(v.dims, np.mean(self.d, axis=(1, 0), keepdims=True)), - ) + with set_options(use_numbagg=False): + assert_identical( + v.mean(keepdims=True), Variable(v.dims, np.mean(self.d, keepdims=True)) + ) + assert_identical( + v.mean(dim="x", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=0, keepdims=True)), + ) + assert_identical( + v.mean(dim="y", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)), + ) + assert_identical( + v.mean(dim=["y", "x"], keepdims=True), + Variable(v.dims, np.mean(self.d, axis=(1, 0), keepdims=True)), + ) - v = Variable([], 1.0) - assert_identical( - v.mean(keepdims=True), Variable([], np.mean(v.data, keepdims=True)) - ) + v = Variable([], 1.0) + assert_identical( + v.mean(keepdims=True), Variable([], np.mean(v.data, keepdims=True)) + ) @requires_dask def test_reduce_keepdims_dask(self): From f7866ce78bd71e604a8e05d14cc8080dd7dd9ec4 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 15:15:27 -0500 Subject: [PATCH 02/24] merge sliding_window_view --- xarray/core/duck_array_ops.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index cca07a44f52..f994fec7ae8 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -122,21 +122,20 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): "masked_invalid", eager_module=np.ma, dask_module="dask.array.ma" ) -# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), -# so we need to hand-code this. -sliding_window_view = _dask_or_eager_func( - "sliding_window_view", - eager_module=np.lib.stride_tricks, - dask_module=dask_array_compat, - dask_only_kwargs=("automatic_rechunk",), - numpy_only_kwargs=("subok", "writeable"), -) - -# def sliding_window_view(array, window_shape, axis=None): -# # TODO: some array libraries don't support this, implement an alternative? -# xp = get_array_namespace(array) -# return xp.lib.stride_tricks.sliding_window_view(array, window_shape, axis=axis) +def sliding_window_view(array, window_shape, axis=None, **kwargs): + # TODO: some libraries (e.g. jax) don't have this, implement an alternative? + xp = get_array_namespace(array) + # sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), + # so we need to hand-code this. + func = _dask_or_eager_func( + "sliding_window_view", + eager_module=xp.lib.stride_tricks, + dask_module=dask_array_compat, + dask_only_kwargs=("automatic_rechunk",), + numpy_only_kwargs=("subok", "writeable"), + ) + return func(array, window_shape, axis=axis, **kwargs) def round(array): From 90037fe8e883cef727f8a4893541f94a46c583cc Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 15:25:57 -0500 Subject: [PATCH 03/24] namespaces constant --- xarray/tests/test_duck_array_wrapping.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py index 1fc8ece98dd..c58c62bf84b 100644 --- a/xarray/tests/test_duck_array_wrapping.py +++ b/xarray/tests/test_duck_array_wrapping.py @@ -8,6 +8,8 @@ jnp = pytest.importorskip("jax.numpy") cp = pytest.importorskip("cupy") +NAMESPACES = [cp, jnp] + def get_test_dataarray(xp): return xr.DataArray( @@ -18,7 +20,7 @@ def get_test_dataarray(xp): ) -@pytest.mark.parametrize("xp", [cp, jnp]) +@pytest.mark.parametrize("xp", NAMESPACES) class TestTopLevelMethods: @pytest.fixture(autouse=True) def setUp(self, xp): @@ -81,7 +83,7 @@ def test_map_blocks(self): assert isinstance(result.data, self.Array) -@pytest.mark.parametrize("xp", [cp, jnp]) +@pytest.mark.parametrize("xp", NAMESPACES) class TestDataArrayMethods: @pytest.fixture(autouse=True) def setUp(self, xp): From 5ba1a2f81280007b1e6ac1089aa9934e76c5c83e Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 18 Nov 2024 15:32:03 -0500 Subject: [PATCH 04/24] revert dask allowed --- xarray/core/variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 1597e4bbe66..dd67b290cf2 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1638,7 +1638,7 @@ def clip(self, min=None, max=None): from xarray.core.computation import apply_ufunc xp = duck_array_ops.get_array_namespace(self.data) - return apply_ufunc(xp.clip, self, min, max, dask="parallelized") + return apply_ufunc(xp.clip, self, min, max, dask="allowed") def reduce( # type: ignore[override] self, From 6225ae3a70d785047259decda130513d1e5df03a Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Tue, 19 Nov 2024 15:20:12 -0500 Subject: [PATCH 05/24] fix up some tests --- xarray/core/dataset.py | 9 ++++++--- xarray/core/duck_array_ops.py | 3 ++- xarray/core/nputils.py | 3 +++ xarray/core/variable.py | 6 +++--- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 6e5a8e163b8..b7ecacf98b3 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -8679,9 +8679,12 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): else: if k in self.data_vars and dim in v.dims: # cast coord data to duck array if needed - coord_data = duck_array_ops.get_array_namespace(v.data).asarray( - coord_var.data - ) + if isinstance(v.data, array_type("cupy")): + coord_data = duck_array_ops.get_array_namespace(v.data).asarray( + coord_var.data + ) + else: + coord_data = coord_var.data if _contains_datetime_like_objects(v): v = datetime_to_numeric(v, datetime_unit=datetime_unit) if cumulative: diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index f994fec7ae8..746004c630d 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -168,7 +168,8 @@ def isnull(data): ) ): # these types cannot represent missing values - return full_like(data, dtype=xp.bool, fill_value=False) + dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool + return full_like(data, dtype=dtype, fill_value=False) else: # at this point, array should have dtype=object if isinstance(data, np.ndarray) or is_extension_array_dtype(data): diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index 24d6b1dda72..b5f399debab 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -236,6 +236,9 @@ def f(values, axis=None, **kwargs): # bottleneck does not take care dtype, min_count kwargs.pop("dtype", None) result = bn_func(values, axis=axis, **kwargs) + # bottleneck returns python scalars for reduction over all axes + if isinstance(result, float): + result = np.float64(result) else: result = getattr(npmodule, name)(values, axis=axis, **kwargs) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index dd67b290cf2..dd6feb4f07d 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1936,10 +1936,8 @@ def quantile( if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) - xp = duck_array_ops.get_array_namespace(self.data) - scalar = utils.is_scalar(q) - q = xp.atleast_1d(xp.asarray(q, dtype=float)) + q = np.atleast_1d(np.asarray(q, dtype=np.float64)) if dim is None: dim = self.dims @@ -1947,6 +1945,8 @@ def quantile( if utils.is_scalar(dim): dim = [dim] + xp = duck_array_ops.get_array_namespace(self.data) + def _wrapper(npa, **kwargs): # move quantile axis to end. required for apply_ufunc return xp.moveaxis(_quantile_func(npa, **kwargs), 0, -1) From e2911c2810a628a5b1a7becec103c2c8528d1c37 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Tue, 19 Nov 2024 16:16:19 -0500 Subject: [PATCH 06/24] backwards compat sparse mask --- xarray/core/variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index dd6feb4f07d..833b15d7993 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -861,7 +861,7 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): # pint to choose the correct unit # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed # cast mask to any duck array type - if not is_duck_dask_array(mask): + if type(mask) is not type(data): mask = duck_array_ops.get_array_namespace(data).asarray(mask) data = duck_array_ops.where( duck_array_ops.logical_not(mask), data, fill_value From 2ac37f9769236225d6d6692e2e2ce60414d5e9d0 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 20 Nov 2024 22:13:22 -0500 Subject: [PATCH 07/24] add as_array methods --- xarray/core/dataarray.py | 22 ++++++++++++++++++++++ xarray/core/dataset.py | 26 ++++++++++++++++++++++++++ xarray/namedarray/core.py | 4 ++++ xarray/tests/test_dataarray.py | 13 +++++++++++++ xarray/tests/test_dataset.py | 15 +++++++++++++++ 5 files changed, 80 insertions(+) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 52ce2463d51..2b19863e35e 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -842,6 +842,28 @@ def as_numpy(self) -> Self: coords = {k: v.as_numpy() for k, v in self._coords.items()} return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes) + def as_array(self, asarray: Callable[[ArrayLike, ...], Any], **kwargs) -> Self: + """ + Coerces wrapped data into a specific array type. + + `asarray` should output an object that supports the Array API Standard. + This method does not convert index coordinates, which can't generally be + represented as arbitrary array types. + + Parameters + ---------- + asarray : Callable + Function that converts an array-like object to the desired array type. + For example, `cupy.asarray`, `jax.numpy.asarray`, or `sparse.COO.from_numpy`. + **kwargs : dict + Additional keyword arguments passed to the `asarray` function. + + Returns + ------- + DataArray + """ + return self._replace(self.variable.as_array(asarray, **kwargs)) + @property def _in_memory(self) -> bool: return self.variable._in_memory diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index b7ecacf98b3..700d733a543 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1434,6 +1434,32 @@ def as_numpy(self) -> Self: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) + def as_array(self, asarray: Callable[[ArrayLike, ...], Any], **kwargs) -> Self: + """ + Converts wrapped data into a specific array type. + + `asarray` should output an object that supports the Array API Standard. + This method does not convert index coordinates, which can't generally be + represented as arbitrary array types. + + Parameters + ---------- + asarray : Callable + Function that converts an array-like object to the desired array type. + For example, `cupy.asarray`, `jax.numpy.asarray`, or `sparse.COO.from_numpy`. + **kwargs : dict + Additional keyword arguments passed to the `asarray` function. + + Returns + ------- + Dataset + """ + array_variables = { + k: v.as_array(asarray, **kwargs) if k not in self._indexes else v + for k, v in self.variables.items() + } + return self._replace(variables=array_variables) + def _copy_listed(self, names: Iterable[Hashable]) -> Self: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 98d96c73e91..8ae17ebce13 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -860,6 +860,10 @@ def as_numpy(self) -> Self: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) + def as_array(self, asarray: Callable[[ArrayLike, ...], Any], **kwargs) -> Self: + """Coerces wrapped data into a specific array type, returning a Variable.""" + return self._replace(data=asarray(self._data, **kwargs)) + def reduce( self, func: Callable[..., Any], diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index b5ecc9517d9..aa6cb5e1721 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -39,6 +39,7 @@ from xarray.core.utils import is_scalar from xarray.testing import _assert_internal_invariants from xarray.tests import ( + DuckArrayWrapper, InaccessibleArray, ReturnItem, assert_allclose, @@ -7165,6 +7166,18 @@ def test_from_pint_wrapping_dask(self) -> None: np.testing.assert_equal(da.to_numpy(), arr) +def test_as_array() -> None: + da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [4, 5, 6]}) + + def as_duck_array(arr): + return DuckArrayWrapper(arr) + + result = da.as_array(as_duck_array) + + assert isinstance(result.data, DuckArrayWrapper) + assert isinstance(result.x.data, np.ndarray) + + class TestStackEllipsis: # https://github.com/pydata/xarray/issues/6051 def test_result_as_expected(self) -> None: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index be82655515d..bbfc2df3fd7 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7639,6 +7639,21 @@ def test_from_pint_wrapping_dask(self) -> None: assert_identical(result, expected) +def test_as_array() -> None: + ds = xr.Dataset( + {"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6]), "x": [7, 8, 9]} + ) + + def as_duck_array(arr): + return DuckArrayWrapper(arr) + + result = ds.as_array(as_duck_array) + + assert isinstance(result.a.data, DuckArrayWrapper) + assert isinstance(result.lat.data, DuckArrayWrapper) + assert isinstance(result.x.data, np.ndarray) + + def test_string_keys_typing() -> None: """Tests that string keys to `variables` are permitted by mypy""" From 1cc344ba46164fd2294c2d6fdaf8bc1e7c273afd Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 20 Nov 2024 22:42:54 -0500 Subject: [PATCH 08/24] to_like_array helper --- xarray/core/array_api_compat.py | 6 ++++++ xarray/core/computation.py | 5 ++--- xarray/core/dataset.py | 9 ++------- xarray/core/variable.py | 5 ++--- 4 files changed, 12 insertions(+), 13 deletions(-) diff --git a/xarray/core/array_api_compat.py b/xarray/core/array_api_compat.py index 28d671cc349..1845d6eddcc 100644 --- a/xarray/core/array_api_compat.py +++ b/xarray/core/array_api_compat.py @@ -70,3 +70,9 @@ def _get_single_namespace(x): xp = np return xp + + +def to_like_array(array, like): + # Mostly for cupy compatibility, because cupy binary ops require all cupy arrays + xp = get_array_namespace(like) + return xp.asarray(array) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 0bfe21642f7..0945f4638f6 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -24,6 +24,7 @@ from xarray.core import dtypes, duck_array_ops, utils from xarray.core.alignment import align, deep_align +from xarray.core.array_api_compat import to_like_array from xarray.core.common import zeros_like from xarray.core.duck_array_ops import datetime_to_numeric from xarray.core.formatting import limit_lines @@ -2178,9 +2179,7 @@ def _calc_idxminmax( res = array[dim][(indx,)] # The dim is gone but we need to remove the corresponding coordinate. del res.coords[dim] - # Cast to array namespace - xp = duck_array_ops.get_array_namespace(array.data) - res.data = xp.asarray(res.data) + res.data = to_like_array(res.data, array.data) if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 700d733a543..eded1d89d05 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -55,6 +55,7 @@ align, ) from xarray.core.arithmetic import DatasetArithmetic +from xarray.core.array_api_compat import to_like_array from xarray.core.common import ( DataWithCoords, _contains_datetime_like_objects, @@ -8704,13 +8705,7 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): coord_names.add(k) else: if k in self.data_vars and dim in v.dims: - # cast coord data to duck array if needed - if isinstance(v.data, array_type("cupy")): - coord_data = duck_array_ops.get_array_namespace(v.data).asarray( - coord_var.data - ) - else: - coord_data = coord_var.data + coord_data = to_like_array(coord_var.data, like=v.data) if _contains_datetime_like_objects(v): v = datetime_to_numeric(v, datetime_unit=datetime_unit) if cumulative: diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 833b15d7993..a472e809876 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -19,6 +19,7 @@ import xarray as xr # only for Dataset and DataArray from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from xarray.core.arithmetic import VariableArithmetic +from xarray.core.array_api_compat import to_like_array from xarray.core.common import AbstractArray from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ( @@ -860,9 +861,7 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): # we need to invert the mask in order to pass data first. This helps # pint to choose the correct unit # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed - # cast mask to any duck array type - if type(mask) is not type(data): - mask = duck_array_ops.get_array_namespace(data).asarray(mask) + mask = to_like_array(mask, data) data = duck_array_ops.where( duck_array_ops.logical_not(mask), data, fill_value ) From 372439ce144f97fdf1e2fdee1e0d5c5f05ae919e Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 20 Nov 2024 22:58:45 -0500 Subject: [PATCH 09/24] only cast non-numpy --- xarray/core/array_api_compat.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/xarray/core/array_api_compat.py b/xarray/core/array_api_compat.py index 1845d6eddcc..e7424325de8 100644 --- a/xarray/core/array_api_compat.py +++ b/xarray/core/array_api_compat.py @@ -75,4 +75,7 @@ def _get_single_namespace(x): def to_like_array(array, like): # Mostly for cupy compatibility, because cupy binary ops require all cupy arrays xp = get_array_namespace(like) - return xp.asarray(array) + if xp is not np: + return xp.asarray(array) + # avoid casting things like pint quantities to numpy arrays + return array From 0eef2cbe2d0cae5fd80c8d3e510e54e1d2978df3 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 20 Nov 2024 23:27:43 -0500 Subject: [PATCH 10/24] better idxminmax approach --- xarray/core/computation.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index 0945f4638f6..6e233425e95 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -35,7 +35,7 @@ from xarray.core.utils import is_dict_like, is_scalar, parse_dims_as_set, result_name from xarray.core.variable import Variable from xarray.namedarray.parallelcompat import get_chunked_array_type -from xarray.namedarray.pycompat import is_chunked_array, to_numpy +from xarray.namedarray.pycompat import is_chunked_array from xarray.util.deprecation_helpers import deprecate_dims if TYPE_CHECKING: @@ -2171,15 +2171,14 @@ def _calc_idxminmax( chunks = dict(zip(array.dims, array.chunks, strict=True)) dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim]) data = dask_coord[duck_array_ops.ravel(indx.data)] - res = indx.copy(data=duck_array_ops.reshape(data, indx.shape)) - # we need to attach back the dim name - res.name = dim else: - indx.data = to_numpy(indx.data) - res = array[dim][(indx,)] - # The dim is gone but we need to remove the corresponding coordinate. - del res.coords[dim] - res.data = to_like_array(res.data, array.data) + arr_coord = to_like_array(array[dim].data, array.data) + data = arr_coord[duck_array_ops.ravel(indx.data)] + + # rebuild like the argmin/max output, and rename as the dim name + data = duck_array_ops.reshape(data, indx.shape) + res = indx.copy(data=data) + res.name = dim if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them From 6739504fc7a7f2f4646ff7c06aa7b2653840c00d Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Wed, 20 Nov 2024 23:58:38 -0500 Subject: [PATCH 11/24] fix mypy --- xarray/core/dataarray.py | 2 +- xarray/core/dataset.py | 2 +- xarray/namedarray/core.py | 6 +++++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 2f8a6ce620b..ff9880bf2da 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -842,7 +842,7 @@ def as_numpy(self) -> Self: coords = {k: v.as_numpy() for k, v in self._coords.items()} return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes) - def as_array(self, asarray: Callable[[ArrayLike, ...], Any], **kwargs) -> Self: + def as_array(self, asarray: Callable, **kwargs) -> Self: """ Coerces wrapped data into a specific array type. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index eded1d89d05..3711045f8c9 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1435,7 +1435,7 @@ def as_numpy(self) -> Self: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) - def as_array(self, asarray: Callable[[ArrayLike, ...], Any], **kwargs) -> Self: + def as_array(self, asarray: Callable, **kwargs) -> Self: """ Converts wrapped data into a specific array type. diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 8ae17ebce13..e80a15fdc3f 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -860,7 +860,11 @@ def as_numpy(self) -> Self: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) - def as_array(self, asarray: Callable[[ArrayLike, ...], Any], **kwargs) -> Self: + def as_array( + self, + asarray: Callable[[duckarray[Any, _DType_co]], duckarray[Any, _DType_co]], + **kwargs: Any, + ) -> Self: """Coerces wrapped data into a specific array type, returning a Variable.""" return self._replace(data=asarray(self._data, **kwargs)) From 9e6d6f8155a467f1c941de8496255bcca2b4ddbf Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 21 Nov 2024 09:27:02 -0500 Subject: [PATCH 12/24] naming, add is_array_type --- xarray/core/dataarray.py | 22 +++++++++++++++++++--- xarray/core/dataset.py | 26 +++++++++++++++++++++++--- xarray/namedarray/core.py | 31 +++++++++++++++++++++++++++++-- xarray/tests/test_dataarray.py | 7 +++++-- xarray/tests/test_dataset.py | 7 +++++-- 5 files changed, 81 insertions(+), 12 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ff9880bf2da..7796904d897 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -842,7 +842,7 @@ def as_numpy(self) -> Self: coords = {k: v.as_numpy() for k, v in self._coords.items()} return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes) - def as_array(self, asarray: Callable, **kwargs) -> Self: + def as_array_type(self, asarray: Callable, **kwargs) -> Self: """ Coerces wrapped data into a specific array type. @@ -854,7 +854,8 @@ def as_array(self, asarray: Callable, **kwargs) -> Self: ---------- asarray : Callable Function that converts an array-like object to the desired array type. - For example, `cupy.asarray`, `jax.numpy.asarray`, or `sparse.COO.from_numpy`. + For example, `cupy.asarray`, `jax.numpy.asarray`, `sparse.COO.from_numpy`, + or any `from_dlpack` method. **kwargs : dict Additional keyword arguments passed to the `asarray` function. @@ -862,7 +863,22 @@ def as_array(self, asarray: Callable, **kwargs) -> Self: ------- DataArray """ - return self._replace(self.variable.as_array(asarray, **kwargs)) + return self._replace(self.variable.as_array_type(asarray, **kwargs)) + + def is_array_type(self, array_type: type) -> bool: + """ + Check if the wrapped data is of a specific array type. + + Parameters + ---------- + array_type : type + The array type to check for. + + Returns + ------- + bool + """ + return self.variable.is_array_type(array_type) @property def _in_memory(self) -> bool: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 3711045f8c9..32ea1b98308 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1435,7 +1435,7 @@ def as_numpy(self) -> Self: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) - def as_array(self, asarray: Callable, **kwargs) -> Self: + def as_array_type(self, asarray: Callable, **kwargs) -> Self: """ Converts wrapped data into a specific array type. @@ -1447,7 +1447,8 @@ def as_array(self, asarray: Callable, **kwargs) -> Self: ---------- asarray : Callable Function that converts an array-like object to the desired array type. - For example, `cupy.asarray`, `jax.numpy.asarray`, or `sparse.COO.from_numpy`. + For example, `cupy.asarray`, `jax.numpy.asarray`, `sparse.COO.from_numpy`, + or any `from_dlpack` method. **kwargs : dict Additional keyword arguments passed to the `asarray` function. @@ -1456,11 +1457,30 @@ def as_array(self, asarray: Callable, **kwargs) -> Self: Dataset """ array_variables = { - k: v.as_array(asarray, **kwargs) if k not in self._indexes else v + k: v.as_array_type(asarray, **kwargs) if k not in self._indexes else v for k, v in self.variables.items() } return self._replace(variables=array_variables) + def is_array_type(self, array_type: type) -> bool: + """ + Check if all data variables and non-index coordinates are of a specific array type. + + Parameters + ---------- + array_type : type + The array type to check for. + + Returns + ------- + bool + """ + return all( + v.is_array_type(array_type) + for k, v in self.variables.items() + if k not in self._indexes + ) + def _copy_listed(self, names: Iterable[Hashable]) -> Self: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index e80a15fdc3f..2558ecec9c7 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -860,14 +860,41 @@ def as_numpy(self) -> Self: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) - def as_array( + def as_array_type( self, asarray: Callable[[duckarray[Any, _DType_co]], duckarray[Any, _DType_co]], **kwargs: Any, ) -> Self: - """Coerces wrapped data into a specific array type, returning a Variable.""" + """Converts wrapped data into a specific array type. + + Parameters + ---------- + asarray : callable + Function that converts the data into a specific array type. + **kwargs : dict + Additional keyword arguments passed on to `asarray`. + + Returns + ------- + array : NamedArray + Array with the same data, but converted into a specific array type + """ return self._replace(data=asarray(self._data, **kwargs)) + def is_array_type(self, array_type: type) -> bool: + """Check if the data is an instance of a specific array type. + + Parameters + ---------- + array_type : type + Array type to check against. + + Returns + ------- + is_array_type : bool + """ + return isinstance(self._data, array_type) + def reduce( self, func: Callable[..., Any], diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 6e1efe85185..8bc63f3bf4b 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -7166,16 +7166,19 @@ def test_from_pint_wrapping_dask(self) -> None: np.testing.assert_equal(da.to_numpy(), arr) -def test_as_array() -> None: +def test_as_array_type_is_array_type() -> None: da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [4, 5, 6]}) + assert da.is_array_type(np.ndarray) + def as_duck_array(arr): return DuckArrayWrapper(arr) - result = da.as_array(as_duck_array) + result = da.as_array_type(as_duck_array) assert isinstance(result.data, DuckArrayWrapper) assert isinstance(result.x.data, np.ndarray) + assert result.is_array_type(DuckArrayWrapper) class TestStackEllipsis: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 13917e28225..edca2a02c93 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7639,19 +7639,22 @@ def test_from_pint_wrapping_dask(self) -> None: assert_identical(result, expected) -def test_as_array() -> None: +def test_as_array_type_is_array_type() -> None: ds = xr.Dataset( {"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6]), "x": [7, 8, 9]} ) + # lat is a PandasIndex here + assert ds.drop_vars("lat").is_array_type(np.ndarray) def as_duck_array(arr): return DuckArrayWrapper(arr) - result = ds.as_array(as_duck_array) + result = ds.as_array_type(as_duck_array) assert isinstance(result.a.data, DuckArrayWrapper) assert isinstance(result.lat.data, DuckArrayWrapper) assert isinstance(result.x.data, np.ndarray) + assert result.is_array_type(DuckArrayWrapper) def test_string_keys_typing() -> None: From e72101155ed172a217b90d5fc36c95a495545049 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 21 Nov 2024 09:31:45 -0500 Subject: [PATCH 13/24] add public doc and whats new --- doc/api.rst | 4 ++++ doc/whats-new.rst | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/doc/api.rst b/doc/api.rst index 0c30ddc4c20..e5517eaf07e 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -117,6 +117,8 @@ Dataset contents Dataset.convert_calendar Dataset.interp_calendar Dataset.get_index + Dataset.as_array_type + Dataset.is_array_type Comparisons ----------- @@ -315,6 +317,8 @@ DataArray contents DataArray.get_index DataArray.astype DataArray.item + DataArray.as_array_type + DataArray.is_array_type Indexing -------- diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3a04467d483..8084cc17780 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -64,6 +64,10 @@ New Features underlying array's backend. Provides better support for certain wrapped array types like ``jax.numpy.ndarray``. (:issue:`7848`, :pull:`9776`). By `Sam Levang `_. +- Make more xarray methods fully compatible with duck array types, and introduce new + ``as_array_type`` and ``is_array_type`` methods for converting wrapped data to other + duck array types. (:issue:`7848`, :pull:`9798`). + By `Sam Levang `_. Breaking changes ~~~~~~~~~~~~~~~~ From 1fe41316b2063db6b131828c0fcd26e1d3926abc Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 21 Nov 2024 10:19:28 -0500 Subject: [PATCH 14/24] update comments --- xarray/core/array_api_compat.py | 3 ++- xarray/core/dataarray.py | 2 +- xarray/tests/test_duck_array_wrapping.py | 27 ++++++++++++------------ 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/xarray/core/array_api_compat.py b/xarray/core/array_api_compat.py index e7424325de8..e1e5d5c5bdc 100644 --- a/xarray/core/array_api_compat.py +++ b/xarray/core/array_api_compat.py @@ -51,7 +51,8 @@ def _get_single_namespace(x): if hasattr(x, "__array_namespace__"): return x.__array_namespace__() elif isinstance(x, array_type("cupy")): - # special case cupy for now + # cupy is fully compliant from xarray's perspective, but will not expose + # __array_namespace__ until at least v14. Special case it for now import cupy as cp return cp diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 7796904d897..ead13663cb8 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -844,7 +844,7 @@ def as_numpy(self) -> Self: def as_array_type(self, asarray: Callable, **kwargs) -> Self: """ - Coerces wrapped data into a specific array type. + Converts wrapped data into a specific array type. `asarray` should output an object that supports the Array API Standard. This method does not convert index coordinates, which can't generally be diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py index c58c62bf84b..ffaa1440e1a 100644 --- a/xarray/tests/test_duck_array_wrapping.py +++ b/xarray/tests/test_duck_array_wrapping.py @@ -107,8 +107,9 @@ def test_squeeze(self): result = self.x.squeeze("y") assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="interp is not namespace aware") + @pytest.mark.xfail(reason="interp uses numpy and scipy") def test_interp(self): + # TODO: some cases could be made to work result = self.x.interp(x=2.5) assert isinstance(result.data, self.Array) @@ -132,17 +133,17 @@ def test_fillna(self): result = self.x.fillna(0) assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="ffill is not namespace aware") + @pytest.mark.xfail(reason="ffill uses bottleneck or numbagg") def test_ffill(self): result = self.x.ffill() assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="bfill is not namespace aware") + @pytest.mark.xfail(reason="bfill uses bottleneck or numbagg") def test_bfill(self): result = self.x.bfill() assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="interpolate_na is not namespace aware") + @pytest.mark.xfail(reason="interpolate_na uses numpy and scipy") def test_interpolate_na(self): result = self.x.interpolate_na() assert isinstance(result.data, self.Array) @@ -165,7 +166,7 @@ def test_rolling(self): result = self.x.rolling(x=3).mean() assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="rolling_exp is not namespace aware") + @pytest.mark.xfail(reason="rolling_exp uses numbagg") def test_rolling_exp(self): result = self.x.rolling_exp(x=3).mean() assert isinstance(result.data, self.Array) @@ -199,17 +200,18 @@ def test_quantile(self, skipna): assert isinstance(result.data, self.Array) def test_differentiate(self): - if self.xp is jnp: - pytest.xfail("edge_order kwarg") - result = self.x.differentiate("x") + # edge_order is not implemented in jax, and only supports passing None + edge_order = None if self.xp is jnp else 1 + result = self.x.differentiate("x", edge_order=edge_order) assert isinstance(result.data, self.Array) def test_integrate(self): result = self.x.integrate("x") assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="polyfit is not namespace aware") + @pytest.mark.xfail(reason="polyfit uses numpy linalg") def test_polyfit(self): + # TODO: this could work, there are just a lot of different linalg calls result = self.x.polyfit("x", 1) assert isinstance(result.polyfit_coefficients.data, self.Array) @@ -277,15 +279,11 @@ def test_sum(self, skipna): @pytest.mark.parametrize("skipna", [True, False]) def test_std(self, skipna): - if self.xp is cp and not skipna: - pytest.xfail("ddof/correction kwarg mismatch") result = self.x.std(dim="x", skipna=skipna) assert isinstance(result.data, self.Array) @pytest.mark.parametrize("skipna", [True, False]) def test_var(self, skipna): - if self.xp is cp and not skipna: - pytest.xfail("ddof/correction kwarg mismatch") result = self.x.var(dim="x", skipna=skipna) assert isinstance(result.data, self.Array) @@ -335,8 +333,9 @@ def test_T(self): result = self.x.T assert isinstance(result.data, self.Array) - @pytest.mark.xfail(reason="rank is not namespace aware") + @pytest.mark.xfail(reason="rank uses bottleneck") def test_rank(self): + # TODO: scipy has rankdata, as does jax, so this can work result = self.x.rank() assert isinstance(result.data, self.Array) From 205c1995703d0f259cfce907337558c6256c0a43 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Thu, 21 Nov 2024 10:57:26 -0500 Subject: [PATCH 15/24] add support for chunked arrays in as_array_type --- xarray/core/dataarray.py | 2 ++ xarray/core/dataset.py | 2 ++ xarray/namedarray/core.py | 14 +++++++++++--- xarray/tests/test_dataarray.py | 19 +++++++++++++++---- xarray/tests/test_dataset.py | 24 ++++++++++++++++++++---- 5 files changed, 50 insertions(+), 11 deletions(-) diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index ead13663cb8..021e9d85474 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -846,6 +846,8 @@ def as_array_type(self, asarray: Callable, **kwargs) -> Self: """ Converts wrapped data into a specific array type. + If the data is a chunked array, the conversion is applied to each block. + `asarray` should output an object that supports the Array API Standard. This method does not convert index coordinates, which can't generally be represented as arbitrary array types. diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 32ea1b98308..038d503e682 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1439,6 +1439,8 @@ def as_array_type(self, asarray: Callable, **kwargs) -> Self: """ Converts wrapped data into a specific array type. + If the data is a chunked array, the conversion is applied to each block. + `asarray` should output an object that supports the Array API Standard. This method does not convert index coordinates, which can't generally be represented as arbitrary array types. diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index 2558ecec9c7..ab4b3bc1820 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -40,8 +40,8 @@ _SupportsImag, _SupportsReal, ) -from xarray.namedarray.parallelcompat import guess_chunkmanager -from xarray.namedarray.pycompat import to_numpy +from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager +from xarray.namedarray.pycompat import is_chunked_array, to_numpy from xarray.namedarray.utils import ( either_dict_or_kwargs, infix_dims, @@ -867,6 +867,8 @@ def as_array_type( ) -> Self: """Converts wrapped data into a specific array type. + If the data is a chunked array, the conversion is applied to each block. + Parameters ---------- asarray : callable @@ -879,7 +881,13 @@ def as_array_type( array : NamedArray Array with the same data, but converted into a specific array type """ - return self._replace(data=asarray(self._data, **kwargs)) + if is_chunked_array(self._data): + chunkmanager = get_chunked_array_type(self._data) + new_data = chunkmanager.map_blocks(asarray, self._data, **kwargs) + else: + new_data = asarray(self._data, **kwargs) + + return self._replace(data=new_data) def is_array_type(self, array_type: type) -> bool: """Check if the data is an instance of a specific array type. diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 8bc63f3bf4b..b4af9d37e35 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -7171,16 +7171,27 @@ def test_as_array_type_is_array_type() -> None: assert da.is_array_type(np.ndarray) - def as_duck_array(arr): - return DuckArrayWrapper(arr) - - result = da.as_array_type(as_duck_array) + result = da.as_array_type(lambda x: DuckArrayWrapper(x)) assert isinstance(result.data, DuckArrayWrapper) assert isinstance(result.x.data, np.ndarray) assert result.is_array_type(DuckArrayWrapper) +@requires_dask +def test_as_array_type_dask() -> None: + import dask.array + + da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [4, 5, 6]}).chunk() + + result = da.as_array_type(lambda x: DuckArrayWrapper(x)) + + assert isinstance(result.data, dask.array.Array) + assert isinstance(result.data._meta, DuckArrayWrapper) + assert isinstance(result.x.data, np.ndarray) + assert result.is_array_type(dask.array.Array) + + class TestStackEllipsis: # https://github.com/pydata/xarray/issues/6051 def test_result_as_expected(self) -> None: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index edca2a02c93..b8dbcabf3ce 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7646,10 +7646,7 @@ def test_as_array_type_is_array_type() -> None: # lat is a PandasIndex here assert ds.drop_vars("lat").is_array_type(np.ndarray) - def as_duck_array(arr): - return DuckArrayWrapper(arr) - - result = ds.as_array_type(as_duck_array) + result = ds.as_array_type(lambda x: DuckArrayWrapper(x)) assert isinstance(result.a.data, DuckArrayWrapper) assert isinstance(result.lat.data, DuckArrayWrapper) @@ -7657,6 +7654,25 @@ def as_duck_array(arr): assert result.is_array_type(DuckArrayWrapper) +@requires_dask +def test_as_array_type_dask() -> None: + import dask.array + + ds = xr.Dataset( + {"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6]), "x": [7, 8, 9]} + ).chunk() + + assert ds.is_array_type(dask.array.Array) + + result = ds.as_array_type(lambda x: DuckArrayWrapper(x)) + + assert isinstance(result.a.data, dask.array.Array) + assert isinstance(result.a.data._meta, DuckArrayWrapper) + assert isinstance(result.lat.data, dask.array.Array) + assert isinstance(result.lat.data._meta, DuckArrayWrapper) + assert isinstance(result.x.data, np.ndarray) + + def test_string_keys_typing() -> None: """Tests that string keys to `variables` are permitted by mypy""" From c8d4e5ec713358f05a0def3789b38f778e346ad5 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 22 Nov 2024 14:19:25 -0500 Subject: [PATCH 16/24] revert array_type methods --- doc/api.rst | 4 --- doc/whats-new.rst | 5 ++-- xarray/core/dataarray.py | 40 ---------------------------- xarray/core/dataset.py | 48 ---------------------------------- xarray/namedarray/core.py | 47 ++------------------------------- xarray/tests/test_dataarray.py | 27 ------------------- xarray/tests/test_dataset.py | 34 ------------------------ 7 files changed, 4 insertions(+), 201 deletions(-) diff --git a/doc/api.rst b/doc/api.rst index 7a596fdaa2d..85ef46ca6ba 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -117,8 +117,6 @@ Dataset contents Dataset.convert_calendar Dataset.interp_calendar Dataset.get_index - Dataset.as_array_type - Dataset.is_array_type Comparisons ----------- @@ -317,8 +315,6 @@ DataArray contents DataArray.get_index DataArray.astype DataArray.item - DataArray.as_array_type - DataArray.is_array_type Indexing -------- diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 8084cc17780..3801075a310 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -64,9 +64,8 @@ New Features underlying array's backend. Provides better support for certain wrapped array types like ``jax.numpy.ndarray``. (:issue:`7848`, :pull:`9776`). By `Sam Levang `_. -- Make more xarray methods fully compatible with duck array types, and introduce new - ``as_array_type`` and ``is_array_type`` methods for converting wrapped data to other - duck array types. (:issue:`7848`, :pull:`9798`). +- Better support wrapping additional array types (e.g. ``cupy`` or ``jax``) by calling generalized + duck array operations throughout more xarray methods. (:issue:`7848`, :pull:`9798`). By `Sam Levang `_. Breaking changes diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index e45aaac5836..eae11c0c491 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -844,46 +844,6 @@ def as_numpy(self) -> Self: coords = {k: v.as_numpy() for k, v in self._coords.items()} return self._replace(self.variable.as_numpy(), coords, indexes=self._indexes) - def as_array_type(self, asarray: Callable, **kwargs) -> Self: - """ - Converts wrapped data into a specific array type. - - If the data is a chunked array, the conversion is applied to each block. - - `asarray` should output an object that supports the Array API Standard. - This method does not convert index coordinates, which can't generally be - represented as arbitrary array types. - - Parameters - ---------- - asarray : Callable - Function that converts an array-like object to the desired array type. - For example, `cupy.asarray`, `jax.numpy.asarray`, `sparse.COO.from_numpy`, - or any `from_dlpack` method. - **kwargs : dict - Additional keyword arguments passed to the `asarray` function. - - Returns - ------- - DataArray - """ - return self._replace(self.variable.as_array_type(asarray, **kwargs)) - - def is_array_type(self, array_type: type) -> bool: - """ - Check if the wrapped data is of a specific array type. - - Parameters - ---------- - array_type : type - The array type to check for. - - Returns - ------- - bool - """ - return self.variable.is_array_type(array_type) - @property def _in_memory(self) -> bool: return self.variable._in_memory diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index dd2df1c77c1..b305e4b51de 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -1437,54 +1437,6 @@ def as_numpy(self) -> Self: numpy_variables = {k: v.as_numpy() for k, v in self.variables.items()} return self._replace(variables=numpy_variables) - def as_array_type(self, asarray: Callable, **kwargs) -> Self: - """ - Converts wrapped data into a specific array type. - - If the data is a chunked array, the conversion is applied to each block. - - `asarray` should output an object that supports the Array API Standard. - This method does not convert index coordinates, which can't generally be - represented as arbitrary array types. - - Parameters - ---------- - asarray : Callable - Function that converts an array-like object to the desired array type. - For example, `cupy.asarray`, `jax.numpy.asarray`, `sparse.COO.from_numpy`, - or any `from_dlpack` method. - **kwargs : dict - Additional keyword arguments passed to the `asarray` function. - - Returns - ------- - Dataset - """ - array_variables = { - k: v.as_array_type(asarray, **kwargs) if k not in self._indexes else v - for k, v in self.variables.items() - } - return self._replace(variables=array_variables) - - def is_array_type(self, array_type: type) -> bool: - """ - Check if all data variables and non-index coordinates are of a specific array type. - - Parameters - ---------- - array_type : type - The array type to check for. - - Returns - ------- - bool - """ - return all( - v.is_array_type(array_type) - for k, v in self.variables.items() - if k not in self._indexes - ) - def _copy_listed(self, names: Iterable[Hashable]) -> Self: """Create a new Dataset with the listed variables from this dataset and the all relevant coordinates. Skips all validation. diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index ab4b3bc1820..98d96c73e91 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -40,8 +40,8 @@ _SupportsImag, _SupportsReal, ) -from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager -from xarray.namedarray.pycompat import is_chunked_array, to_numpy +from xarray.namedarray.parallelcompat import guess_chunkmanager +from xarray.namedarray.pycompat import to_numpy from xarray.namedarray.utils import ( either_dict_or_kwargs, infix_dims, @@ -860,49 +860,6 @@ def as_numpy(self) -> Self: """Coerces wrapped data into a numpy array, returning a Variable.""" return self._replace(data=self.to_numpy()) - def as_array_type( - self, - asarray: Callable[[duckarray[Any, _DType_co]], duckarray[Any, _DType_co]], - **kwargs: Any, - ) -> Self: - """Converts wrapped data into a specific array type. - - If the data is a chunked array, the conversion is applied to each block. - - Parameters - ---------- - asarray : callable - Function that converts the data into a specific array type. - **kwargs : dict - Additional keyword arguments passed on to `asarray`. - - Returns - ------- - array : NamedArray - Array with the same data, but converted into a specific array type - """ - if is_chunked_array(self._data): - chunkmanager = get_chunked_array_type(self._data) - new_data = chunkmanager.map_blocks(asarray, self._data, **kwargs) - else: - new_data = asarray(self._data, **kwargs) - - return self._replace(data=new_data) - - def is_array_type(self, array_type: type) -> bool: - """Check if the data is an instance of a specific array type. - - Parameters - ---------- - array_type : type - Array type to check against. - - Returns - ------- - is_array_type : bool - """ - return isinstance(self._data, array_type) - def reduce( self, func: Callable[..., Any], diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index b4af9d37e35..c8b438948de 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -39,7 +39,6 @@ from xarray.core.utils import is_scalar from xarray.testing import _assert_internal_invariants from xarray.tests import ( - DuckArrayWrapper, InaccessibleArray, ReturnItem, assert_allclose, @@ -7166,32 +7165,6 @@ def test_from_pint_wrapping_dask(self) -> None: np.testing.assert_equal(da.to_numpy(), arr) -def test_as_array_type_is_array_type() -> None: - da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [4, 5, 6]}) - - assert da.is_array_type(np.ndarray) - - result = da.as_array_type(lambda x: DuckArrayWrapper(x)) - - assert isinstance(result.data, DuckArrayWrapper) - assert isinstance(result.x.data, np.ndarray) - assert result.is_array_type(DuckArrayWrapper) - - -@requires_dask -def test_as_array_type_dask() -> None: - import dask.array - - da = xr.DataArray([1, 2, 3], dims=["x"], coords={"x": [4, 5, 6]}).chunk() - - result = da.as_array_type(lambda x: DuckArrayWrapper(x)) - - assert isinstance(result.data, dask.array.Array) - assert isinstance(result.data._meta, DuckArrayWrapper) - assert isinstance(result.x.data, np.ndarray) - assert result.is_array_type(dask.array.Array) - - class TestStackEllipsis: # https://github.com/pydata/xarray/issues/6051 def test_result_as_expected(self) -> None: diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index b8dbcabf3ce..67d38aac0fe 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -7639,40 +7639,6 @@ def test_from_pint_wrapping_dask(self) -> None: assert_identical(result, expected) -def test_as_array_type_is_array_type() -> None: - ds = xr.Dataset( - {"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6]), "x": [7, 8, 9]} - ) - # lat is a PandasIndex here - assert ds.drop_vars("lat").is_array_type(np.ndarray) - - result = ds.as_array_type(lambda x: DuckArrayWrapper(x)) - - assert isinstance(result.a.data, DuckArrayWrapper) - assert isinstance(result.lat.data, DuckArrayWrapper) - assert isinstance(result.x.data, np.ndarray) - assert result.is_array_type(DuckArrayWrapper) - - -@requires_dask -def test_as_array_type_dask() -> None: - import dask.array - - ds = xr.Dataset( - {"a": ("x", [1, 2, 3])}, coords={"lat": ("x", [4, 5, 6]), "x": [7, 8, 9]} - ).chunk() - - assert ds.is_array_type(dask.array.Array) - - result = ds.as_array_type(lambda x: DuckArrayWrapper(x)) - - assert isinstance(result.a.data, dask.array.Array) - assert isinstance(result.a.data._meta, DuckArrayWrapper) - assert isinstance(result.lat.data, dask.array.Array) - assert isinstance(result.lat.data._meta, DuckArrayWrapper) - assert isinstance(result.x.data, np.ndarray) - - def test_string_keys_typing() -> None: """Tests that string keys to `variables` are permitted by mypy""" From f306768fe78d4751e6f264ff992dff09e20453a8 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 22 Nov 2024 14:20:21 -0500 Subject: [PATCH 17/24] fix up whats new --- doc/whats-new.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ccaac9e7263..e1fb12269ed 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -64,11 +64,11 @@ New Features underlying array's backend. Provides better support for certain wrapped array types like ``jax.numpy.ndarray``. (:issue:`7848`, :pull:`9776`). By `Sam Levang `_. +- Speed up loading of large zarr stores using dask arrays. (:issue:`8902`) + By `Deepak Cherian `_. - Better support wrapping additional array types (e.g. ``cupy`` or ``jax``) by calling generalized duck array operations throughout more xarray methods. (:issue:`7848`, :pull:`9798`). By `Sam Levang `_. -- Speed up loading of large zarr stores using dask arrays. (:issue:`8902`) - By `Deepak Cherian `_. Breaking changes ~~~~~~~~~~~~~~~~ From 18ebdcdb29bda39395d254be4f7cb3c3f88b6e16 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Fri, 22 Nov 2024 17:22:06 -0500 Subject: [PATCH 18/24] comment about bool_ --- xarray/core/duck_array_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 59f3da2c8f7..7e7333fd8ea 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -166,6 +166,7 @@ def isnull(data): ) ): # these types cannot represent missing values + # bool_ is for backwards compat with numpy<2, and cupy dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool return full_like(data, dtype=dtype, fill_value=False) else: From 121af9e5b1a12d2759a2f846dd7207afa3100bcb Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sat, 23 Nov 2024 11:09:25 -0500 Subject: [PATCH 19/24] add jax to complete ci envs --- ci/requirements/environment-3.13.yml | 2 ++ ci/requirements/environment-windows-3.13.yml | 2 ++ ci/requirements/environment-windows.yml | 2 ++ ci/requirements/environment.yml | 2 ++ 4 files changed, 8 insertions(+) diff --git a/ci/requirements/environment-3.13.yml b/ci/requirements/environment-3.13.yml index dbb446f4454..937cb013711 100644 --- a/ci/requirements/environment-3.13.yml +++ b/ci/requirements/environment-3.13.yml @@ -47,3 +47,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/ci/requirements/environment-windows-3.13.yml b/ci/requirements/environment-windows-3.13.yml index 448e3f70c0c..0d32fd13a96 100644 --- a/ci/requirements/environment-windows-3.13.yml +++ b/ci/requirements/environment-windows-3.13.yml @@ -42,3 +42,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index 3b2e6dc62e6..a9a53d0c1b1 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -42,3 +42,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 43938880592..364ae03666f 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -49,3 +49,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present From 472ae7e7e1fc499adc3598511e96475e8d7ab045 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sat, 23 Nov 2024 11:10:17 -0500 Subject: [PATCH 20/24] add pint and sparse to tests --- xarray/core/common.py | 3 +- xarray/tests/test_duck_array_wrapping.py | 154 +++++++++++++++++------ 2 files changed, 118 insertions(+), 39 deletions(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 8aaa153c1a8..32135996d3c 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -1760,8 +1760,7 @@ def _full_like_variable( **from_array_kwargs, ) else: - xp = duck_array_ops.get_array_namespace(other.data) - data = xp.full_like(other.data, fill_value, dtype=dtype) + data = duck_array_ops.full_like(other.data, fill_value, dtype=dtype) return Variable(dims=other.dims, data=data, attrs=other.attrs) diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py index ffaa1440e1a..05c0ab68bea 100644 --- a/xarray/tests/test_duck_array_wrapping.py +++ b/xarray/tests/test_duck_array_wrapping.py @@ -4,30 +4,107 @@ import xarray as xr -# TODO: how to test these in CI? -jnp = pytest.importorskip("jax.numpy") -cp = pytest.importorskip("cupy") - -NAMESPACES = [cp, jnp] - - -def get_test_dataarray(xp): - return xr.DataArray( - xp.asarray([[1, 2, 3, np.nan, 5]]), - dims=["y", "x"], - coords={"y": [1], "x": np.arange(5)}, - name="foo", - ) - - -@pytest.mark.parametrize("xp", NAMESPACES) -class TestTopLevelMethods: +# Don't run cupy in CI because it requires a GPU +NAMESPACE_ARRAYS = { + "jax.numpy": { + "array": "ndarray", + "constructor": "asarray", + "xfails": { + "rolling": "no sliding_window_view", + "rolling_mean": "no sliding_window_view", + }, + }, + "cupy": { + "array": "ndarray", + "constructor": "asarray", + "xfails": {"quantile": "no nanquantile"}, + }, + "pint": { + "array": "Quantity", + "constructor": "Quantity", + "xfails": { + "all": "returns a bool", + "any": "returns a bool", + "argmax": "returns an int", + "argmin": "returns an int", + "argsort": "returns an int", + "count": "returns an int", + "dot": "no tensordot", + "full_like": "should work, see: https://github.com/hgrecco/pint/pull/1669", + "idxmax": "returns the coordinate", + "idxmin": "returns the coordinate", + "isin": "returns a bool", + "isnull": "returns a bool", + "notnull": "returns a bool", + "rolling_mean": "no dispatch for numbagg/bottleneck", + "searchsorted": "returns an int", + "weighted": "no tensordot", + }, + }, + "sparse": { + "array": "COO", + "constructor": "COO", + "xfails": { + "cov": "dense output", + "corr": "no nanstd", + "cross": "no cross", + "count": "dense output", + "isin": "no isin", + "rolling": "no sliding_window_view", + "rolling_mean": "no sliding_window_view", + "weighted": "fill_value error", + "coarsen": "pad constant_values must be fill_value", + "quantile": "no non skipping version", + "differentiate": "no gradient", + "argmax": "no nan skipping version", + "argmin": "no nan skipping version", + "idxmax": "no nan skipping version", + "idxmin": "no nan skipping version", + "median": "no nan skipping version", + "std": "no nan skipping version", + "var": "no nan skipping version", + "cumsum": "no cumsum", + "cumprod": "no cumprod", + "argsort": "no argsort", + "conjugate": "no conjugate", + "searchsorted": "no searchsorted", + "shift": "pad constant_values must be fill_value", + "pad": "pad constant_values must be fill_value", + }, + }, +} + + +class _BaseTest: + def setup_for_test(self, request, namespace): + self.namespace = namespace + self.xp = pytest.importorskip(namespace) + self.Array = getattr(self.xp, NAMESPACE_ARRAYS[namespace]["array"]) + self.constructor = getattr(self.xp, NAMESPACE_ARRAYS[namespace]["constructor"]) + xarray_method = request.node.name.split("test_")[1].split("[")[0] + if xarray_method in NAMESPACE_ARRAYS[namespace]["xfails"]: + reason = NAMESPACE_ARRAYS[namespace]["xfails"][xarray_method] + pytest.xfail(f"xfail for {self.namespace}: {reason}") + + def get_test_dataarray(self): + data = np.asarray([[1, 2, 3, np.nan, 5]]) + x = np.arange(5) + data = self.constructor(data) + return xr.DataArray( + data, + dims=["y", "x"], + coords={"y": [1], "x": x}, + name="foo", + ) + + +@pytest.mark.parametrize("namespace", NAMESPACE_ARRAYS) +class TestTopLevelMethods(_BaseTest): @pytest.fixture(autouse=True) - def setUp(self, xp): - self.xp = xp - self.Array = xp.ndarray - self.x1 = get_test_dataarray(xp) - self.x2 = get_test_dataarray(xp).assign_coords(x=np.arange(2, 7)) + def setUp(self, request, namespace): + self.setup_for_test(request, namespace) + self.x1 = self.get_test_dataarray() + self.x2 = self.get_test_dataarray().assign_coords(x=np.arange(2, 7)) def test_apply_ufunc(self): func = lambda x: x + 1 @@ -83,13 +160,12 @@ def test_map_blocks(self): assert isinstance(result.data, self.Array) -@pytest.mark.parametrize("xp", NAMESPACES) -class TestDataArrayMethods: +@pytest.mark.parametrize("namespace", NAMESPACE_ARRAYS) +class TestDataArrayMethods(_BaseTest): @pytest.fixture(autouse=True) - def setUp(self, xp): - self.xp = xp - self.Array = xp.ndarray - self.x = get_test_dataarray(xp) + def setUp(self, request, namespace): + self.setup_for_test(request, namespace) + self.x = self.get_test_dataarray() def test_loc(self): result = self.x.loc[{"x": slice(1, 3)}] @@ -153,7 +229,8 @@ def test_where(self): assert isinstance(result.data, self.Array) def test_isin(self): - result = self.x.isin(self.xp.asarray([1])) + test_elements = self.constructor(np.asarray([1])) + result = self.x.isin(test_elements) assert isinstance(result.data, self.Array) def test_groupby(self): @@ -161,9 +238,13 @@ def test_groupby(self): assert isinstance(result.data, self.Array) def test_rolling(self): - if self.xp is jnp: - pytest.xfail("no sliding_window_view in jax") - result = self.x.rolling(x=3).mean() + result = self.x.rolling(x=3) + elem = next(iter(result))[1] + assert isinstance(elem.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_rolling_mean(self, skipna): + result = self.x.rolling(x=3).mean(skipna=skipna) assert isinstance(result.data, self.Array) @pytest.mark.xfail(reason="rolling_exp uses numbagg") @@ -194,14 +275,12 @@ def test_dot(self): @pytest.mark.parametrize("skipna", [True, False]) def test_quantile(self, skipna): - if self.xp is cp and skipna: - pytest.xfail("no nanquantile in cupy") result = self.x.quantile(0.5, skipna=skipna) assert isinstance(result.data, self.Array) def test_differentiate(self): # edge_order is not implemented in jax, and only supports passing None - edge_order = None if self.xp is jnp else 1 + edge_order = None if self.namespace == "jax.numpy" else 1 result = self.x.differentiate("x", edge_order=edge_order) assert isinstance(result.data, self.Array) @@ -318,7 +397,8 @@ def test_imag(self): assert isinstance(result.data, self.Array) def test_searchsorted(self): - result = self.x.squeeze().searchsorted(self.xp.asarray(3)) + v = self.constructor(np.asarray([3])) + result = self.x.squeeze().searchsorted(v) assert isinstance(result, self.Array) def test_round(self): From 5aa4a392b314544750ce0395395492b02dbddae6 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sat, 23 Nov 2024 11:19:25 -0500 Subject: [PATCH 21/24] remove from windows --- ci/requirements/environment-windows-3.13.yml | 2 -- ci/requirements/environment-windows.yml | 2 -- 2 files changed, 4 deletions(-) diff --git a/ci/requirements/environment-windows-3.13.yml b/ci/requirements/environment-windows-3.13.yml index 0d32fd13a96..448e3f70c0c 100644 --- a/ci/requirements/environment-windows-3.13.yml +++ b/ci/requirements/environment-windows-3.13.yml @@ -42,5 +42,3 @@ dependencies: - toolz - typing_extensions - zarr - - pip: - - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml index a9a53d0c1b1..3b2e6dc62e6 100644 --- a/ci/requirements/environment-windows.yml +++ b/ci/requirements/environment-windows.yml @@ -42,5 +42,3 @@ dependencies: - toolz - typing_extensions - zarr - - pip: - - jax # no way to get cpu-only jaxlib from conda if gpu is present From 390df6f7715b46d557cb64fd32a21e6567e64e21 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Sat, 23 Nov 2024 12:40:31 -0500 Subject: [PATCH 22/24] mypy, xfail one more sparse --- xarray/tests/test_duck_array_wrapping.py | 31 ++++++++++++++++-------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py index 05c0ab68bea..63413aba1a3 100644 --- a/xarray/tests/test_duck_array_wrapping.py +++ b/xarray/tests/test_duck_array_wrapping.py @@ -7,21 +7,27 @@ # Don't run cupy in CI because it requires a GPU NAMESPACE_ARRAYS = { "jax.numpy": { - "array": "ndarray", - "constructor": "asarray", + "attrs": { + "array": "ndarray", + "constructor": "asarray", + }, "xfails": { "rolling": "no sliding_window_view", "rolling_mean": "no sliding_window_view", }, }, "cupy": { - "array": "ndarray", - "constructor": "asarray", + "attrs": { + "array": "ndarray", + "constructor": "asarray", + }, "xfails": {"quantile": "no nanquantile"}, }, "pint": { - "array": "Quantity", - "constructor": "Quantity", + "attrs": { + "array": "Quantity", + "constructor": "Quantity", + }, "xfails": { "all": "returns a bool", "any": "returns a bool", @@ -42,13 +48,16 @@ }, }, "sparse": { - "array": "COO", - "constructor": "COO", + "attrs": { + "array": "COO", + "constructor": "COO", + }, "xfails": { "cov": "dense output", "corr": "no nanstd", "cross": "no cross", "count": "dense output", + "dot": "fails on some platforms/versions", "isin": "no isin", "rolling": "no sliding_window_view", "rolling_mean": "no sliding_window_view", @@ -79,8 +88,10 @@ class _BaseTest: def setup_for_test(self, request, namespace): self.namespace = namespace self.xp = pytest.importorskip(namespace) - self.Array = getattr(self.xp, NAMESPACE_ARRAYS[namespace]["array"]) - self.constructor = getattr(self.xp, NAMESPACE_ARRAYS[namespace]["constructor"]) + self.Array = getattr(self.xp, NAMESPACE_ARRAYS[namespace]["attrs"]["array"]) + self.constructor = getattr( + self.xp, NAMESPACE_ARRAYS[namespace]["attrs"]["constructor"] + ) xarray_method = request.node.name.split("test_")[1].split("[")[0] if xarray_method in NAMESPACE_ARRAYS[namespace]["xfails"]: reason = NAMESPACE_ARRAYS[namespace]["xfails"][xarray_method] From f6074d2fa3b9c2d3900cda25c4cb322f2c698bd1 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 25 Nov 2024 10:01:59 -0500 Subject: [PATCH 23/24] add dask and a few other methods --- xarray/tests/test_duck_array_wrapping.py | 73 +++++++++++++++++++----- 1 file changed, 60 insertions(+), 13 deletions(-) diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py index 63413aba1a3..59928dce370 100644 --- a/xarray/tests/test_duck_array_wrapping.py +++ b/xarray/tests/test_duck_array_wrapping.py @@ -6,22 +6,35 @@ # Don't run cupy in CI because it requires a GPU NAMESPACE_ARRAYS = { - "jax.numpy": { + "cupy": { "attrs": { "array": "ndarray", "constructor": "asarray", }, + "xfails": {"quantile": "no nanquantile"}, + }, + "dask.array": { + "attrs": { + "array": "Array", + "constructor": "from_array", + }, "xfails": { - "rolling": "no sliding_window_view", - "rolling_mean": "no sliding_window_view", + "argsort": "no argsort", + "conjugate": "conj but no conjugate", + "searchsorted": "dask.array.searchsorted but no Array.searchsorted", }, }, - "cupy": { + "jax.numpy": { "attrs": { "array": "ndarray", "constructor": "asarray", }, - "xfails": {"quantile": "no nanquantile"}, + "xfails": { + "rolling_construct": "no sliding_window_view", + "rolling_reduce": "no sliding_window_view", + "cumulative_construct": "no sliding_window_view", + "cumulative_reduce": "no sliding_window_view", + }, }, "pint": { "attrs": { @@ -42,7 +55,8 @@ "isin": "returns a bool", "isnull": "returns a bool", "notnull": "returns a bool", - "rolling_mean": "no dispatch for numbagg/bottleneck", + "rolling_reduce": "no dispatch for numbagg/bottleneck", + "cumulative_reduce": "no dispatch for numbagg/bottleneck", "searchsorted": "returns an int", "weighted": "no tensordot", }, @@ -59,8 +73,12 @@ "count": "dense output", "dot": "fails on some platforms/versions", "isin": "no isin", - "rolling": "no sliding_window_view", - "rolling_mean": "no sliding_window_view", + "rolling_construct": "no sliding_window_view", + "rolling_reduce": "no sliding_window_view", + "cumulative_construct": "no sliding_window_view", + "cumulative_reduce": "no sliding_window_view", + "coarsen_construct": "pad constant_values must be fill_value", + "coarsen_reduce": "pad constant_values must be fill_value", "weighted": "fill_value error", "coarsen": "pad constant_values must be fill_value", "quantile": "no non skipping version", @@ -119,7 +137,7 @@ def setUp(self, request, namespace): def test_apply_ufunc(self): func = lambda x: x + 1 - result = xr.apply_ufunc(func, self.x1) + result = xr.apply_ufunc(func, self.x1, dask="parallelized") assert isinstance(result.data, self.Array) def test_align(self): @@ -248,26 +266,51 @@ def test_groupby(self): result = self.x.groupby("x").mean() assert isinstance(result.data, self.Array) - def test_rolling(self): + def test_groupby_bins(self): + result = self.x.groupby_bins("x", bins=[0, 2, 4, 6]).mean() + assert isinstance(result.data, self.Array) + + def test_rolling_iter(self): result = self.x.rolling(x=3) elem = next(iter(result))[1] assert isinstance(elem.data, self.Array) + def test_rolling_construct(self): + result = self.x.rolling(x=3).construct(x="window") + assert isinstance(result.data, self.Array) + @pytest.mark.parametrize("skipna", [True, False]) - def test_rolling_mean(self, skipna): + def test_rolling_reduce(self, skipna): result = self.x.rolling(x=3).mean(skipna=skipna) assert isinstance(result.data, self.Array) @pytest.mark.xfail(reason="rolling_exp uses numbagg") - def test_rolling_exp(self): + def test_rolling_exp_reduce(self): result = self.x.rolling_exp(x=3).mean() assert isinstance(result.data, self.Array) + def test_cumulative_iter(self): + result = self.x.cumulative("x") + elem = next(iter(result))[1] + assert isinstance(elem.data, self.Array) + + def test_cumulative_construct(self): + result = self.x.cumulative("x").construct(x="window") + assert isinstance(result.data, self.Array) + + def test_cumulative_reduce(self): + result = self.x.cumulative("x").sum() + assert isinstance(result.data, self.Array) + def test_weighted(self): result = self.x.weighted(self.x.fillna(0)).mean() assert isinstance(result.data, self.Array) - def test_coarsen(self): + def test_coarsen_construct(self): + result = self.x.coarsen(x=2, boundary="pad").construct(x=["a", "b"]) + assert isinstance(result.data, self.Array) + + def test_coarsen_reduce(self): result = self.x.coarsen(x=2, boundary="pad").mean() assert isinstance(result.data, self.Array) @@ -391,6 +434,10 @@ def test_argsort(self): result = self.x.argsort() assert isinstance(result.data, self.Array) + def test_astype(self): + result = self.x.astype(int) + assert isinstance(result.data, self.Array) + def test_clip(self): result = self.x.clip(min=2.0, max=4.0) assert isinstance(result.data, self.Array) From bfd6aebbb0cb2f91f86a71a82e7593ae2b9365e3 Mon Sep 17 00:00:00 2001 From: Sam Levang Date: Mon, 25 Nov 2024 10:03:45 -0500 Subject: [PATCH 24/24] move whats new --- doc/whats-new.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ab1cba8d9a6..906fd0a25b2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,9 @@ v.2024.11.1 (unreleased) New Features ~~~~~~~~~~~~ +- Better support wrapping additional array types (e.g. ``cupy`` or ``jax``) by calling generalized + duck array operations throughout more xarray methods. (:issue:`7848`, :pull:`9798`). + By `Sam Levang `_. Breaking changes @@ -85,9 +88,6 @@ New Features By `Sam Levang `_. - Speed up loading of large zarr stores using dask arrays. (:issue:`8902`) By `Deepak Cherian `_. -- Better support wrapping additional array types (e.g. ``cupy`` or ``jax``) by calling generalized - duck array operations throughout more xarray methods. (:issue:`7848`, :pull:`9798`). - By `Sam Levang `_. Breaking Changes ~~~~~~~~~~~~~~~~