diff --git a/ci/requirements/environment-3.13.yml b/ci/requirements/environment-3.13.yml index dbb446f4454..937cb013711 100644 --- a/ci/requirements/environment-3.13.yml +++ b/ci/requirements/environment-3.13.yml @@ -47,3 +47,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml index 43938880592..364ae03666f 100644 --- a/ci/requirements/environment.yml +++ b/ci/requirements/environment.yml @@ -49,3 +49,5 @@ dependencies: - toolz - typing_extensions - zarr + - pip: + - jax # no way to get cpu-only jaxlib from conda if gpu is present diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 0da34df2c1a..906fd0a25b2 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -21,6 +21,9 @@ v.2024.11.1 (unreleased) New Features ~~~~~~~~~~~~ +- Better support wrapping additional array types (e.g. ``cupy`` or ``jax``) by calling generalized + duck array operations throughout more xarray methods. (:issue:`7848`, :pull:`9798`). + By `Sam Levang `_. Breaking changes diff --git a/xarray/core/array_api_compat.py b/xarray/core/array_api_compat.py index da072de5b69..e1e5d5c5bdc 100644 --- a/xarray/core/array_api_compat.py +++ b/xarray/core/array_api_compat.py @@ -1,5 +1,7 @@ import numpy as np +from xarray.namedarray.pycompat import array_type + def is_weak_scalar_type(t): return isinstance(t, bool | int | float | complex | str | bytes) @@ -42,3 +44,39 @@ def result_type(*arrays_and_dtypes, xp) -> np.dtype: return xp.result_type(*arrays_and_dtypes) else: return _future_array_api_result_type(*arrays_and_dtypes, xp=xp) + + +def get_array_namespace(*values): + def _get_single_namespace(x): + if hasattr(x, "__array_namespace__"): + return x.__array_namespace__() + elif isinstance(x, array_type("cupy")): + # cupy is fully compliant from xarray's perspective, but will not expose + # __array_namespace__ until at least v14. Special case it for now + import cupy as cp + + return cp + else: + return np + + namespaces = {_get_single_namespace(t) for t in values} + non_numpy = namespaces - {np} + + if len(non_numpy) > 1: + names = [module.__name__ for module in non_numpy] + raise TypeError(f"Mixed array types {names} are not supported.") + elif non_numpy: + [xp] = non_numpy + else: + xp = np + + return xp + + +def to_like_array(array, like): + # Mostly for cupy compatibility, because cupy binary ops require all cupy arrays + xp = get_array_namespace(like) + if xp is not np: + return xp.asarray(array) + # avoid casting things like pint quantities to numpy arrays + return array diff --git a/xarray/core/common.py b/xarray/core/common.py index 6f788f408d0..32135996d3c 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -496,7 +496,7 @@ def clip( keep_attrs = _get_keep_attrs(default=True) return apply_ufunc( - np.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed" + duck_array_ops.clip, self, min, max, keep_attrs=keep_attrs, dask="allowed" ) def get_index(self, key: Hashable) -> pd.Index: @@ -1760,7 +1760,7 @@ def _full_like_variable( **from_array_kwargs, ) else: - data = np.full_like(other.data, fill_value, dtype=dtype) + data = duck_array_ops.full_like(other.data, fill_value, dtype=dtype) return Variable(dims=other.dims, data=data, attrs=other.attrs) diff --git a/xarray/core/computation.py b/xarray/core/computation.py index b15ed7f3f34..6e233425e95 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -24,6 +24,7 @@ from xarray.core import dtypes, duck_array_ops, utils from xarray.core.alignment import align, deep_align +from xarray.core.array_api_compat import to_like_array from xarray.core.common import zeros_like from xarray.core.duck_array_ops import datetime_to_numeric from xarray.core.formatting import limit_lines @@ -1702,7 +1703,7 @@ def cross( ) c = apply_ufunc( - np.cross, + duck_array_ops.cross, a, b, input_core_dims=[[dim], [dim]], @@ -2170,13 +2171,14 @@ def _calc_idxminmax( chunks = dict(zip(array.dims, array.chunks, strict=True)) dask_coord = chunkmanager.from_array(array[dim].data, chunks=chunks[dim]) data = dask_coord[duck_array_ops.ravel(indx.data)] - res = indx.copy(data=duck_array_ops.reshape(data, indx.shape)) - # we need to attach back the dim name - res.name = dim else: - res = array[dim][(indx,)] - # The dim is gone but we need to remove the corresponding coordinate. - del res.coords[dim] + arr_coord = to_like_array(array[dim].data, array.data) + data = arr_coord[duck_array_ops.ravel(indx.data)] + + # rebuild like the argmin/max output, and rename as the dim name + data = duck_array_ops.reshape(data, indx.shape) + res = indx.copy(data=data) + res.name = dim if skipna or (skipna is None and array.dtype.kind in na_dtypes): # Put the NaN values back in after removing them diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index e80ce5fa64a..ce8f93a37e5 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -55,6 +55,7 @@ align, ) from xarray.core.arithmetic import DatasetArithmetic +from xarray.core.array_api_compat import to_like_array from xarray.core.common import ( DataWithCoords, _contains_datetime_like_objects, @@ -127,7 +128,7 @@ calculate_dimensions, ) from xarray.namedarray.parallelcompat import get_chunked_array_type, guess_chunkmanager -from xarray.namedarray.pycompat import array_type, is_chunked_array +from xarray.namedarray.pycompat import array_type, is_chunked_array, to_numpy from xarray.plot.accessor import DatasetPlotAccessor from xarray.util.deprecation_helpers import _deprecate_positional_args, deprecate_dims @@ -6622,7 +6623,7 @@ def dropna( array = self._variables[k] if dim in array.dims: dims = [d for d in array.dims if d != dim] - count += np.asarray(array.count(dims)) + count += to_numpy(array.count(dims).data) size += math.prod([self.sizes[d] for d in dims]) if thresh is not None: @@ -8736,16 +8737,17 @@ def _integrate_one(self, coord, datetime_unit=None, cumulative=False): coord_names.add(k) else: if k in self.data_vars and dim in v.dims: + coord_data = to_like_array(coord_var.data, like=v.data) if _contains_datetime_like_objects(v): v = datetime_to_numeric(v, datetime_unit=datetime_unit) if cumulative: integ = duck_array_ops.cumulative_trapezoid( - v.data, coord_var.data, axis=v.get_axis_num(dim) + v.data, coord_data, axis=v.get_axis_num(dim) ) v_dims = v.dims else: integ = duck_array_ops.trapz( - v.data, coord_var.data, axis=v.get_axis_num(dim) + v.data, coord_data, axis=v.get_axis_num(dim) ) v_dims = list(v.dims) v_dims.remove(dim) diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 0b915166279..7e7333fd8ea 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -18,21 +18,16 @@ import pandas as pd from numpy import all as array_all # noqa: F401 from numpy import any as array_any # noqa: F401 -from numpy import concatenate as _concatenate from numpy import ( # noqa: F401 - full_like, - gradient, isclose, - isin, isnat, take, - tensordot, - transpose, unravel_index, ) from pandas.api.types import is_extension_array_dtype from xarray.core import dask_array_compat, dask_array_ops, dtypes, nputils +from xarray.core.array_api_compat import get_array_namespace from xarray.core.options import OPTIONS from xarray.core.utils import is_duck_array, is_duck_dask_array, module_available from xarray.namedarray.parallelcompat import get_chunked_array_type @@ -52,28 +47,6 @@ dask_available = module_available("dask") -def get_array_namespace(*values): - def _get_array_namespace(x): - if hasattr(x, "__array_namespace__"): - return x.__array_namespace__() - else: - return np - - namespaces = {_get_array_namespace(t) for t in values} - non_numpy = namespaces - {np} - - if len(non_numpy) > 1: - raise TypeError( - "cannot deal with more than one type supporting the array API at the same time" - ) - elif non_numpy: - [xp] = non_numpy - else: - xp = np - - return xp - - def einsum(*args, **kwargs): from xarray.core.options import OPTIONS @@ -82,7 +55,23 @@ def einsum(*args, **kwargs): return opt_einsum.contract(*args, **kwargs) else: - return np.einsum(*args, **kwargs) + xp = get_array_namespace(*args) + return xp.einsum(*args, **kwargs) + + +def tensordot(*args, **kwargs): + xp = get_array_namespace(*args) + return xp.tensordot(*args, **kwargs) + + +def cross(*args, **kwargs): + xp = get_array_namespace(*args) + return xp.cross(*args, **kwargs) + + +def gradient(f, *varargs, axis=None, edge_order=1): + xp = get_array_namespace(f) + return xp.gradient(f, *varargs, axis=axis, edge_order=edge_order) def _dask_or_eager_func( @@ -131,15 +120,20 @@ def fail_on_dask_array_input(values, msg=None, func_name=None): "masked_invalid", eager_module=np.ma, dask_module="dask.array.ma" ) -# sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), -# so we need to hand-code this. -sliding_window_view = _dask_or_eager_func( - "sliding_window_view", - eager_module=np.lib.stride_tricks, - dask_module=dask_array_compat, - dask_only_kwargs=("automatic_rechunk",), - numpy_only_kwargs=("subok", "writeable"), -) + +def sliding_window_view(array, window_shape, axis=None, **kwargs): + # TODO: some libraries (e.g. jax) don't have this, implement an alternative? + xp = get_array_namespace(array) + # sliding_window_view will not dispatch arbitrary kwargs (automatic_rechunk), + # so we need to hand-code this. + func = _dask_or_eager_func( + "sliding_window_view", + eager_module=xp.lib.stride_tricks, + dask_module=dask_array_compat, + dask_only_kwargs=("automatic_rechunk",), + numpy_only_kwargs=("subok", "writeable"), + ) + return func(array, window_shape, axis=axis, **kwargs) def round(array): @@ -172,7 +166,9 @@ def isnull(data): ) ): # these types cannot represent missing values - return full_like(data, dtype=bool, fill_value=False) + # bool_ is for backwards compat with numpy<2, and cupy + dtype = xp.bool_ if hasattr(xp, "bool_") else xp.bool + return full_like(data, dtype=dtype, fill_value=False) else: # at this point, array should have dtype=object if isinstance(data, np.ndarray) or is_extension_array_dtype(data): @@ -213,11 +209,23 @@ def cumulative_trapezoid(y, x, axis): # Pad so that 'axis' has same length in result as it did in y pads = [(1, 0) if i == axis else (0, 0) for i in range(y.ndim)] - integrand = np.pad(integrand, pads, mode="constant", constant_values=0.0) + + xp = get_array_namespace(y, x) + integrand = xp.pad(integrand, pads, mode="constant", constant_values=0.0) return cumsum(integrand, axis=axis, skipna=False) +def full_like(a, fill_value, **kwargs): + xp = get_array_namespace(a) + return xp.full_like(a, fill_value, **kwargs) + + +def empty_like(a, **kwargs): + xp = get_array_namespace(a) + return xp.empty_like(a, **kwargs) + + def astype(data, dtype, **kwargs): if hasattr(data, "__array_namespace__"): xp = get_array_namespace(data) @@ -348,7 +356,8 @@ def array_notnull_equiv(arr1, arr2): def count(data, axis=None): """Count the number of non-NA in this array along the given axis or axes""" - return np.sum(np.logical_not(isnull(data)), axis=axis) + xp = get_array_namespace(data) + return xp.sum(xp.logical_not(isnull(data)), axis=axis) def sum_where(data, axis=None, dtype=None, where=None): @@ -363,7 +372,7 @@ def sum_where(data, axis=None, dtype=None, where=None): def where(condition, x, y): """Three argument where() with better dtype promotion rules.""" - xp = get_array_namespace(condition) + xp = get_array_namespace(condition, x, y) return xp.where(condition, *as_shared_dtype([x, y], xp=xp)) @@ -380,15 +389,25 @@ def fillna(data, other): return where(notnull(data), data, other) +def logical_not(data): + xp = get_array_namespace(data) + return xp.logical_not(data) + + +def clip(data, min=None, max=None): + xp = get_array_namespace(data) + return xp.clip(data, min, max) + + def concatenate(arrays, axis=0): """concatenate() with better dtype promotion rules.""" - # TODO: remove the additional check once `numpy` adds `concat` to its array namespace - if hasattr(arrays[0], "__array_namespace__") and not isinstance( - arrays[0], np.ndarray - ): - xp = get_array_namespace(arrays[0]) + # TODO: `concat` is the xp compliant name, but fallback to concatenate for + # older numpy and for cupy + xp = get_array_namespace(*arrays) + if hasattr(xp, "concat"): return xp.concat(as_shared_dtype(arrays, xp=xp), axis=axis) - return _concatenate(as_shared_dtype(arrays), axis=axis) + else: + return xp.concatenate(as_shared_dtype(arrays, xp=xp), axis=axis) def stack(arrays, axis=0): @@ -406,6 +425,26 @@ def ravel(array): return reshape(array, (-1,)) +def transpose(array, axes=None): + xp = get_array_namespace(array) + return xp.transpose(array, axes) + + +def moveaxis(array, source, destination): + xp = get_array_namespace(array) + return xp.moveaxis(array, source, destination) + + +def pad(array, pad_width, **kwargs): + xp = get_array_namespace(array) + return xp.pad(array, pad_width, **kwargs) + + +def quantile(array, q, axis=None, **kwargs): + xp = get_array_namespace(array) + return xp.quantile(array, q, axis=axis, **kwargs) + + @contextlib.contextmanager def _ignore_warnings_if(condition): if condition: @@ -747,6 +786,11 @@ def last(values, axis, skipna=None): return take(values, -1, axis=axis) +def isin(element, test_elements, **kwargs): + xp = get_array_namespace(element, test_elements) + return xp.isin(element, test_elements, **kwargs) + + def least_squares(lhs, rhs, rcond=None, skipna=False): """Return the coefficients and residuals of a least-squares fit.""" if is_duck_dask_array(rhs): diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index 7fbb63068c0..4894cf02be2 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -128,7 +128,7 @@ def nanmean(a, axis=None, dtype=None, out=None): "ignore", r"Mean of empty slice", category=RuntimeWarning ) - return np.nanmean(a, axis=axis, dtype=dtype) + return nputils.nanmean(a, axis=axis, dtype=dtype) def nanmedian(a, axis=None, out=None): diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index bf5dfa1bc32..3211ab296e6 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -7,6 +7,7 @@ import pandas as pd from packaging.version import Version +from xarray.core.array_api_compat import get_array_namespace from xarray.core.utils import is_duck_array, module_available from xarray.namedarray import pycompat @@ -179,6 +180,11 @@ def f(values, axis=None, **kwargs): dtype = kwargs.get("dtype") bn_func = getattr(bn, name, None) + xp = get_array_namespace(values) + if xp is not np: + func = getattr(xp, name, None) + if func is not None: + return func(values, axis=axis, **kwargs) if ( module_available("numbagg") and OPTIONS["use_numbagg"] @@ -229,6 +235,9 @@ def f(values, axis=None, **kwargs): # bottleneck does not take care dtype, min_count kwargs.pop("dtype", None) result = bn_func(values, axis=axis, **kwargs) + # bottleneck returns python scalars for reduction over all axes + if isinstance(result, float): + result = np.float64(result) else: result = getattr(npmodule, name)(values, axis=axis, **kwargs) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index cb16c3723ca..fde87841d32 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -708,6 +708,7 @@ def _array_reduce( ) del kwargs["dim"] + xp = duck_array_ops.get_array_namespace(self.obj.data) if ( OPTIONS["use_numbagg"] and module_available("numbagg") @@ -722,6 +723,7 @@ def _array_reduce( # TODO: we could also allow this, probably as part of a refactoring of this # module, so we can use the machinery in `self.reduce`. and self.ndim == 1 + and xp is np ): import numbagg @@ -744,6 +746,7 @@ def _array_reduce( or module_available("dask", "2024.11.0") ) and self.ndim == 1 + and xp is np ): return self._bottleneck_reduce( bottleneck_move_func, keep_attrs=keep_attrs, **kwargs diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 9f660d0878a..07113d66b5b 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -19,6 +19,7 @@ import xarray as xr # only for Dataset and DataArray from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils from xarray.core.arithmetic import VariableArithmetic +from xarray.core.array_api_compat import to_like_array from xarray.core.common import AbstractArray from xarray.core.extension_array import PandasExtensionArray from xarray.core.indexing import ( @@ -828,7 +829,7 @@ def __getitem__(self, key) -> Self: data = indexing.apply_indexer(indexable, indexer) if new_order: - data = np.moveaxis(data, range(len(new_order)), new_order) + data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) return self._finalize_indexing_result(dims, data) def _finalize_indexing_result(self, dims, data) -> Self: @@ -866,12 +867,15 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): # we need to invert the mask in order to pass data first. This helps # pint to choose the correct unit # TODO: revert after https://github.com/hgrecco/pint/issues/1019 is fixed - data = duck_array_ops.where(np.logical_not(mask), data, fill_value) + mask = to_like_array(mask, data) + data = duck_array_ops.where( + duck_array_ops.logical_not(mask), data, fill_value + ) else: # array cannot be indexed along dimensions of size 0, so just # build the mask directly instead. mask = indexing.create_mask(indexer, self.shape) - data = np.broadcast_to(fill_value, getattr(mask, "shape", ())) + data = duck_array_ops.broadcast_to(fill_value, getattr(mask, "shape", ())) if new_order: data = duck_array_ops.moveaxis(data, range(len(new_order)), new_order) @@ -902,7 +906,7 @@ def __setitem__(self, key, value): if new_order: value = duck_array_ops.asarray(value) value = value[(len(dims) - value.ndim) * (np.newaxis,) + (Ellipsis,)] - value = np.moveaxis(value, new_order, range(len(new_order))) + value = duck_array_ops.moveaxis(value, new_order, range(len(new_order))) indexable = as_indexable(self._data) indexing.set_with_indexer(indexable, index_tuple, value) @@ -1122,7 +1126,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): dim_pad = (width, 0) if count >= 0 else (0, width) pads = [(0, 0) if d != dim else dim_pad for d in self.dims] - data = np.pad( + data = duck_array_ops.pad( duck_array_ops.astype(trimmed_data, dtype), pads, mode="constant", @@ -1268,7 +1272,7 @@ def pad( if reflect_type is not None: pad_option_kwargs["reflect_type"] = reflect_type - array = np.pad( + array = duck_array_ops.pad( duck_array_ops.astype(self.data, dtype, copy=False), pad_width_by_index, mode=mode, @@ -1557,14 +1561,16 @@ def _unstack_once( if is_missing_values: dtype, fill_value = dtypes.maybe_promote(self.dtype) - create_template = partial(np.full_like, fill_value=fill_value) + create_template = partial( + duck_array_ops.full_like, fill_value=fill_value + ) else: dtype = self.dtype fill_value = dtypes.get_fill_value(dtype) - create_template = np.empty_like + create_template = duck_array_ops.empty_like else: dtype = self.dtype - create_template = partial(np.full_like, fill_value=fill_value) + create_template = partial(duck_array_ops.full_like, fill_value=fill_value) if sparse: # unstacking a dense multitindexed array to a sparse array @@ -1654,7 +1660,8 @@ def clip(self, min=None, max=None): """ from xarray.core.computation import apply_ufunc - return apply_ufunc(np.clip, self, min, max, dask="allowed") + xp = duck_array_ops.get_array_namespace(self.data) + return apply_ufunc(xp.clip, self, min, max, dask="allowed") def reduce( # type: ignore[override] self, @@ -1947,7 +1954,7 @@ def quantile( if skipna or (skipna is None and self.dtype.kind in "cfO"): _quantile_func = nputils.nanquantile else: - _quantile_func = np.quantile + _quantile_func = duck_array_ops.quantile if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) @@ -1961,11 +1968,14 @@ def quantile( if utils.is_scalar(dim): dim = [dim] + xp = duck_array_ops.get_array_namespace(self.data) + def _wrapper(npa, **kwargs): # move quantile axis to end. required for apply_ufunc - return np.moveaxis(_quantile_func(npa, **kwargs), 0, -1) + return xp.moveaxis(_quantile_func(npa, **kwargs), 0, -1) - axis = np.arange(-1, -1 * len(dim) - 1, -1) + # jax requires hashable + axis = tuple(range(-1, -1 * len(dim) - 1, -1)) kwargs = {"q": q, "axis": axis, "method": method} diff --git a/xarray/tests/test_duck_array_wrapping.py b/xarray/tests/test_duck_array_wrapping.py new file mode 100644 index 00000000000..59928dce370 --- /dev/null +++ b/xarray/tests/test_duck_array_wrapping.py @@ -0,0 +1,510 @@ +import numpy as np +import pandas as pd +import pytest + +import xarray as xr + +# Don't run cupy in CI because it requires a GPU +NAMESPACE_ARRAYS = { + "cupy": { + "attrs": { + "array": "ndarray", + "constructor": "asarray", + }, + "xfails": {"quantile": "no nanquantile"}, + }, + "dask.array": { + "attrs": { + "array": "Array", + "constructor": "from_array", + }, + "xfails": { + "argsort": "no argsort", + "conjugate": "conj but no conjugate", + "searchsorted": "dask.array.searchsorted but no Array.searchsorted", + }, + }, + "jax.numpy": { + "attrs": { + "array": "ndarray", + "constructor": "asarray", + }, + "xfails": { + "rolling_construct": "no sliding_window_view", + "rolling_reduce": "no sliding_window_view", + "cumulative_construct": "no sliding_window_view", + "cumulative_reduce": "no sliding_window_view", + }, + }, + "pint": { + "attrs": { + "array": "Quantity", + "constructor": "Quantity", + }, + "xfails": { + "all": "returns a bool", + "any": "returns a bool", + "argmax": "returns an int", + "argmin": "returns an int", + "argsort": "returns an int", + "count": "returns an int", + "dot": "no tensordot", + "full_like": "should work, see: https://github.com/hgrecco/pint/pull/1669", + "idxmax": "returns the coordinate", + "idxmin": "returns the coordinate", + "isin": "returns a bool", + "isnull": "returns a bool", + "notnull": "returns a bool", + "rolling_reduce": "no dispatch for numbagg/bottleneck", + "cumulative_reduce": "no dispatch for numbagg/bottleneck", + "searchsorted": "returns an int", + "weighted": "no tensordot", + }, + }, + "sparse": { + "attrs": { + "array": "COO", + "constructor": "COO", + }, + "xfails": { + "cov": "dense output", + "corr": "no nanstd", + "cross": "no cross", + "count": "dense output", + "dot": "fails on some platforms/versions", + "isin": "no isin", + "rolling_construct": "no sliding_window_view", + "rolling_reduce": "no sliding_window_view", + "cumulative_construct": "no sliding_window_view", + "cumulative_reduce": "no sliding_window_view", + "coarsen_construct": "pad constant_values must be fill_value", + "coarsen_reduce": "pad constant_values must be fill_value", + "weighted": "fill_value error", + "coarsen": "pad constant_values must be fill_value", + "quantile": "no non skipping version", + "differentiate": "no gradient", + "argmax": "no nan skipping version", + "argmin": "no nan skipping version", + "idxmax": "no nan skipping version", + "idxmin": "no nan skipping version", + "median": "no nan skipping version", + "std": "no nan skipping version", + "var": "no nan skipping version", + "cumsum": "no cumsum", + "cumprod": "no cumprod", + "argsort": "no argsort", + "conjugate": "no conjugate", + "searchsorted": "no searchsorted", + "shift": "pad constant_values must be fill_value", + "pad": "pad constant_values must be fill_value", + }, + }, +} + + +class _BaseTest: + def setup_for_test(self, request, namespace): + self.namespace = namespace + self.xp = pytest.importorskip(namespace) + self.Array = getattr(self.xp, NAMESPACE_ARRAYS[namespace]["attrs"]["array"]) + self.constructor = getattr( + self.xp, NAMESPACE_ARRAYS[namespace]["attrs"]["constructor"] + ) + xarray_method = request.node.name.split("test_")[1].split("[")[0] + if xarray_method in NAMESPACE_ARRAYS[namespace]["xfails"]: + reason = NAMESPACE_ARRAYS[namespace]["xfails"][xarray_method] + pytest.xfail(f"xfail for {self.namespace}: {reason}") + + def get_test_dataarray(self): + data = np.asarray([[1, 2, 3, np.nan, 5]]) + x = np.arange(5) + data = self.constructor(data) + return xr.DataArray( + data, + dims=["y", "x"], + coords={"y": [1], "x": x}, + name="foo", + ) + + +@pytest.mark.parametrize("namespace", NAMESPACE_ARRAYS) +class TestTopLevelMethods(_BaseTest): + @pytest.fixture(autouse=True) + def setUp(self, request, namespace): + self.setup_for_test(request, namespace) + self.x1 = self.get_test_dataarray() + self.x2 = self.get_test_dataarray().assign_coords(x=np.arange(2, 7)) + + def test_apply_ufunc(self): + func = lambda x: x + 1 + result = xr.apply_ufunc(func, self.x1, dask="parallelized") + assert isinstance(result.data, self.Array) + + def test_align(self): + result = xr.align(self.x1, self.x2) + assert isinstance(result[0].data, self.Array) + assert isinstance(result[1].data, self.Array) + + def test_broadcast(self): + result = xr.broadcast(self.x1, self.x2) + assert isinstance(result[0].data, self.Array) + assert isinstance(result[1].data, self.Array) + + def test_concat(self): + result = xr.concat([self.x1, self.x2], dim="x") + assert isinstance(result.data, self.Array) + + def test_merge(self): + result = xr.merge([self.x1, self.x2], compat="override") + assert isinstance(result.foo.data, self.Array) + + def test_where(self): + x1, x2 = xr.align(self.x1, self.x2, join="inner") + result = xr.where(x1 > 2, x1, x2) + assert isinstance(result.data, self.Array) + + def test_full_like(self): + result = xr.full_like(self.x1, 0) + assert isinstance(result.data, self.Array) + + def test_cov(self): + result = xr.cov(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_corr(self): + result = xr.corr(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_cross(self): + x1, x2 = xr.align(self.x1.squeeze(), self.x2.squeeze(), join="inner") + result = xr.cross(x1, x2, dim="x") + assert isinstance(result.data, self.Array) + + def test_dot(self): + result = xr.dot(self.x1, self.x2) + assert isinstance(result.data, self.Array) + + def test_map_blocks(self): + result = xr.map_blocks(lambda x: x + 1, self.x1) + assert isinstance(result.data, self.Array) + + +@pytest.mark.parametrize("namespace", NAMESPACE_ARRAYS) +class TestDataArrayMethods(_BaseTest): + @pytest.fixture(autouse=True) + def setUp(self, request, namespace): + self.setup_for_test(request, namespace) + self.x = self.get_test_dataarray() + + def test_loc(self): + result = self.x.loc[{"x": slice(1, 3)}] + assert isinstance(result.data, self.Array) + + def test_isel(self): + result = self.x.isel(x=slice(1, 3)) + assert isinstance(result.data, self.Array) + + def test_sel(self): + result = self.x.sel(x=slice(1, 3)) + assert isinstance(result.data, self.Array) + + def test_squeeze(self): + result = self.x.squeeze("y") + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="interp uses numpy and scipy") + def test_interp(self): + # TODO: some cases could be made to work + result = self.x.interp(x=2.5) + assert isinstance(result.data, self.Array) + + def test_isnull(self): + result = self.x.isnull() + assert isinstance(result.data, self.Array) + + def test_notnull(self): + result = self.x.notnull() + assert isinstance(result.data, self.Array) + + def test_count(self): + result = self.x.count() + assert isinstance(result.data, self.Array) + + def test_dropna(self): + result = self.x.dropna(dim="x") + assert isinstance(result.data, self.Array) + + def test_fillna(self): + result = self.x.fillna(0) + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="ffill uses bottleneck or numbagg") + def test_ffill(self): + result = self.x.ffill() + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="bfill uses bottleneck or numbagg") + def test_bfill(self): + result = self.x.bfill() + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="interpolate_na uses numpy and scipy") + def test_interpolate_na(self): + result = self.x.interpolate_na() + assert isinstance(result.data, self.Array) + + def test_where(self): + result = self.x.where(self.x > 2) + assert isinstance(result.data, self.Array) + + def test_isin(self): + test_elements = self.constructor(np.asarray([1])) + result = self.x.isin(test_elements) + assert isinstance(result.data, self.Array) + + def test_groupby(self): + result = self.x.groupby("x").mean() + assert isinstance(result.data, self.Array) + + def test_groupby_bins(self): + result = self.x.groupby_bins("x", bins=[0, 2, 4, 6]).mean() + assert isinstance(result.data, self.Array) + + def test_rolling_iter(self): + result = self.x.rolling(x=3) + elem = next(iter(result))[1] + assert isinstance(elem.data, self.Array) + + def test_rolling_construct(self): + result = self.x.rolling(x=3).construct(x="window") + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_rolling_reduce(self, skipna): + result = self.x.rolling(x=3).mean(skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="rolling_exp uses numbagg") + def test_rolling_exp_reduce(self): + result = self.x.rolling_exp(x=3).mean() + assert isinstance(result.data, self.Array) + + def test_cumulative_iter(self): + result = self.x.cumulative("x") + elem = next(iter(result))[1] + assert isinstance(elem.data, self.Array) + + def test_cumulative_construct(self): + result = self.x.cumulative("x").construct(x="window") + assert isinstance(result.data, self.Array) + + def test_cumulative_reduce(self): + result = self.x.cumulative("x").sum() + assert isinstance(result.data, self.Array) + + def test_weighted(self): + result = self.x.weighted(self.x.fillna(0)).mean() + assert isinstance(result.data, self.Array) + + def test_coarsen_construct(self): + result = self.x.coarsen(x=2, boundary="pad").construct(x=["a", "b"]) + assert isinstance(result.data, self.Array) + + def test_coarsen_reduce(self): + result = self.x.coarsen(x=2, boundary="pad").mean() + assert isinstance(result.data, self.Array) + + def test_resample(self): + time_coord = pd.date_range("2000-01-01", periods=5) + result = self.x.assign_coords(x=time_coord).resample(x="D").mean() + assert isinstance(result.data, self.Array) + + def test_diff(self): + result = self.x.diff("x") + assert isinstance(result.data, self.Array) + + def test_dot(self): + result = self.x.dot(self.x) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_quantile(self, skipna): + result = self.x.quantile(0.5, skipna=skipna) + assert isinstance(result.data, self.Array) + + def test_differentiate(self): + # edge_order is not implemented in jax, and only supports passing None + edge_order = None if self.namespace == "jax.numpy" else 1 + result = self.x.differentiate("x", edge_order=edge_order) + assert isinstance(result.data, self.Array) + + def test_integrate(self): + result = self.x.integrate("x") + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="polyfit uses numpy linalg") + def test_polyfit(self): + # TODO: this could work, there are just a lot of different linalg calls + result = self.x.polyfit("x", 1) + assert isinstance(result.polyfit_coefficients.data, self.Array) + + def test_map_blocks(self): + result = self.x.map_blocks(lambda x: x + 1) + assert isinstance(result.data, self.Array) + + def test_all(self): + result = self.x.all(dim="x") + assert isinstance(result.data, self.Array) + + def test_any(self): + result = self.x.any(dim="x") + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_argmax(self, skipna): + result = self.x.argmax(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_argmin(self, skipna): + result = self.x.argmin(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_idxmax(self, skipna): + result = self.x.idxmax(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_idxmin(self, skipna): + result = self.x.idxmin(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_max(self, skipna): + result = self.x.max(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_min(self, skipna): + result = self.x.min(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_mean(self, skipna): + result = self.x.mean(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_median(self, skipna): + result = self.x.median(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_prod(self, skipna): + result = self.x.prod(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_sum(self, skipna): + result = self.x.sum(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_std(self, skipna): + result = self.x.std(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_var(self, skipna): + result = self.x.var(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_cumsum(self, skipna): + result = self.x.cumsum(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + @pytest.mark.parametrize("skipna", [True, False]) + def test_cumprod(self, skipna): + result = self.x.cumprod(dim="x", skipna=skipna) + assert isinstance(result.data, self.Array) + + def test_argsort(self): + result = self.x.argsort() + assert isinstance(result.data, self.Array) + + def test_astype(self): + result = self.x.astype(int) + assert isinstance(result.data, self.Array) + + def test_clip(self): + result = self.x.clip(min=2.0, max=4.0) + assert isinstance(result.data, self.Array) + + def test_conj(self): + result = self.x.conj() + assert isinstance(result.data, self.Array) + + def test_conjugate(self): + result = self.x.conjugate() + assert isinstance(result.data, self.Array) + + def test_imag(self): + result = self.x.imag + assert isinstance(result.data, self.Array) + + def test_searchsorted(self): + v = self.constructor(np.asarray([3])) + result = self.x.squeeze().searchsorted(v) + assert isinstance(result, self.Array) + + def test_round(self): + result = self.x.round() + assert isinstance(result.data, self.Array) + + def test_real(self): + result = self.x.real + assert isinstance(result.data, self.Array) + + def test_T(self): + result = self.x.T + assert isinstance(result.data, self.Array) + + @pytest.mark.xfail(reason="rank uses bottleneck") + def test_rank(self): + # TODO: scipy has rankdata, as does jax, so this can work + result = self.x.rank() + assert isinstance(result.data, self.Array) + + def test_transpose(self): + result = self.x.transpose() + assert isinstance(result.data, self.Array) + + def test_stack(self): + result = self.x.stack(z=("x", "y")) + assert isinstance(result.data, self.Array) + + def test_unstack(self): + result = self.x.stack(z=("x", "y")).unstack("z") + assert isinstance(result.data, self.Array) + + def test_shift(self): + result = self.x.shift(x=1) + assert isinstance(result.data, self.Array) + + def test_roll(self): + result = self.x.roll(x=1) + assert isinstance(result.data, self.Array) + + def test_pad(self): + result = self.x.pad(x=1) + assert isinstance(result.data, self.Array) + + def test_sortby(self): + result = self.x.sortby("x") + assert isinstance(result.data, self.Array) + + def test_broadcast_like(self): + result = self.x.broadcast_like(self.x) + assert isinstance(result.data, self.Array) diff --git a/xarray/tests/test_strategies.py b/xarray/tests/test_strategies.py index 798f5f732d1..48819333ca2 100644 --- a/xarray/tests/test_strategies.py +++ b/xarray/tests/test_strategies.py @@ -13,6 +13,7 @@ from hypothesis import given from hypothesis.extra.array_api import make_strategies_namespace +from xarray.core.options import set_options from xarray.core.variable import Variable from xarray.testing.strategies import ( attrs, @@ -267,14 +268,14 @@ def test_mean(self, data, var): Test that given a Variable of at least one dimension, the mean of the Variable is always equal to the mean of the underlying array. """ + with set_options(use_numbagg=False): + # specify arbitrary reduction along at least one dimension + reduction_dims = data.draw(unique_subset_of(var.dims, min_size=1)) - # specify arbitrary reduction along at least one dimension - reduction_dims = data.draw(unique_subset_of(var.dims, min_size=1)) + # create expected result (using nanmean because arrays with Nans will be generated) + reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) + expected = np.nanmean(var.data, axis=reduction_axes) - # create expected result (using nanmean because arrays with Nans will be generated) - reduction_axes = tuple(var.get_axis_num(dim) for dim in reduction_dims) - expected = np.nanmean(var.data, axis=reduction_axes) - - # assert property is always satisfied - result = var.mean(dim=reduction_dims).data - npt.assert_equal(expected, result) + # assert property is always satisfied + result = var.mean(dim=reduction_dims).data + npt.assert_equal(expected, result) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index 9c6f50037d3..1461489e731 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -1978,26 +1978,27 @@ def test_reduce_funcs(self): def test_reduce_keepdims(self): v = Variable(["x", "y"], self.d) - assert_identical( - v.mean(keepdims=True), Variable(v.dims, np.mean(self.d, keepdims=True)) - ) - assert_identical( - v.mean(dim="x", keepdims=True), - Variable(v.dims, np.mean(self.d, axis=0, keepdims=True)), - ) - assert_identical( - v.mean(dim="y", keepdims=True), - Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)), - ) - assert_identical( - v.mean(dim=["y", "x"], keepdims=True), - Variable(v.dims, np.mean(self.d, axis=(1, 0), keepdims=True)), - ) + with set_options(use_numbagg=False): + assert_identical( + v.mean(keepdims=True), Variable(v.dims, np.mean(self.d, keepdims=True)) + ) + assert_identical( + v.mean(dim="x", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=0, keepdims=True)), + ) + assert_identical( + v.mean(dim="y", keepdims=True), + Variable(v.dims, np.mean(self.d, axis=1, keepdims=True)), + ) + assert_identical( + v.mean(dim=["y", "x"], keepdims=True), + Variable(v.dims, np.mean(self.d, axis=(1, 0), keepdims=True)), + ) - v = Variable([], 1.0) - assert_identical( - v.mean(keepdims=True), Variable([], np.mean(v.data, keepdims=True)) - ) + v = Variable([], 1.0) + assert_identical( + v.mean(keepdims=True), Variable([], np.mean(v.data, keepdims=True)) + ) @requires_dask def test_reduce_keepdims_dask(self):