Skip to content

Commit

Permalink
Simple check for numpy 2.0 compatbility (#570)
Browse files Browse the repository at this point in the history
Enable compatibility with numpy 2.0

* Fixed some copy keyword arguments to adhere to new numpy standard
* Fixed np.ptp, which does no longer exist in numpy 2
* Remove restriction to numpy<2
* Replace `as_strided`, which is not available in numpy 2.0
* Fixed type issues because we support numpy 1 and 2
  • Loading branch information
david-zwicker authored Jun 17, 2024
1 parent 244ba62 commit 9de5764
Show file tree
Hide file tree
Showing 16 changed files with 50 additions and 28 deletions.
2 changes: 1 addition & 1 deletion pde/fields/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def __init__(
dof += len(this_data)

# initialize the data from the individual fields
data_arr = number_array(fields_data, dtype=dtype, copy=False)
data_arr = number_array(fields_data, dtype=dtype, copy=None)

# initialize the class
super().__init__(grid, data_arr, label=label)
Expand Down
4 changes: 2 additions & 2 deletions pde/fields/datafield_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def __init__(
# use full data without copying (unless necessary)
if data is None or isinstance(data, str):
raise ValueError("`data` must be supplied if with_ghost_cells==True")
data_arr = number_array(data, dtype=dtype, copy=False)
data_arr = number_array(data, dtype=dtype, copy=None)
super().__init__(grid, data=data_arr, label=label)

else:
Expand Down Expand Up @@ -108,7 +108,7 @@ def __init__(

else:
# initialize empty data and set the valid data
data_arr = number_array(data, dtype=dtype, copy=False)
data_arr = number_array(data, dtype=dtype, copy=None)
empty_data = np.empty(full_shape, dtype=data_arr.dtype)
super().__init__(grid, data=empty_data, label=label)
self.data = data_arr
Expand Down
7 changes: 3 additions & 4 deletions pde/pdes/cahn_hilliard.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,9 @@ def _make_pde_rhs_numba( # type: ignore
An example for the state defining the grid and data types
Returns:
A function with signature `(state_data, t)`, which can be called
with an instance of :class:`~numpy.ndarray` of the state data and
the time to obtained an instance of :class:`~numpy.ndarray` giving
the evolution rate.
A function with signature `(state_data, t)`, which can be called with an
instance of :class:`~numpy.ndarray` of the state data and the time to
obtained an instance of :class:`~numpy.ndarray` giving the evolution rate.
"""
arr_type = nb.typeof(state.data)
signature = arr_type(arr_type, nb.double)
Expand Down
2 changes: 1 addition & 1 deletion pde/tools/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self, x, y, sigma: float | None = None):
self.y = self.y[idx]

if sigma is None:
self.sigma = float(self.sigma_auto_scale * self.x.ptp() / len(self.x))
self.sigma = float(self.sigma_auto_scale * np.ptp(self.x) / len(self.x))
else:
self.sigma = sigma

Expand Down
23 changes: 14 additions & 9 deletions pde/tools/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def preserve_scalars(method: TFunc) -> TFunc:

@functools.wraps(method)
def wrapper(self, *args):
args = [number_array(arg, copy=False) for arg in args]
args = [number_array(arg, copy=None) for arg in args]
if args[0].ndim == 0:
args = [arg[None] for arg in args]
return method(self, *args)[0]
Expand Down Expand Up @@ -362,37 +362,42 @@ def get_common_dtype(*args):


def number_array(
data: ArrayLike, dtype: DTypeLike = None, copy: bool = True
data: ArrayLike, dtype: DTypeLike = None, copy: bool | None = None
) -> np.ndarray:
"""convert an array with arbitrary dtype either to np.double or np.cdouble
"""convert data into an array, assuming float numbers if no dtype is given
Args:
data (:class:`~numpy.ndarray`):
The data that needs to be converted to a float array. This can also be any
The data that needs to be converted to a number array. This can also be any
iterable of numbers.
dtype (numpy dtype):
The data type of the field. All the numpy dtypes are supported. If omitted,
it will be determined from `data` automatically.
it will be :class:`~numpy.double` unless `data` contains complex numbers in
which case it will be :class:`~numpy.cdouble`.
copy (bool):
Whether the data must be copied (in which case the original array is left
untouched). Note that data will always be copied when changing the dtype.
untouched). The default `None` implies that data is only copied if
necessary, e.g., when changing the dtype.
Returns:
:class:`~numpy.ndarray`: An array with the correct dtype
"""
if np.__version__.startswith("1") and copy is None:
copy = False # fall-back for numpy 1

if dtype is None:
# dtype needs to be determined automatically
try:
# convert the result to a numpy array with the given dtype
result = np.array(data, dtype=get_common_dtype(data), copy=copy)
result = np.array(data, dtype=get_common_dtype(data), copy=copy) # type: ignore
except TypeError:
# Conversion can fail when `data` contains a complex sympy number, i.e.,
# sympy.I. In this case, we simply try to convert the expression using a
# complex dtype
result = np.array(data, dtype=np.cdouble, copy=copy)
result = np.array(data, dtype=np.cdouble, copy=copy) # type: ignore

else:
# a specific dtype is requested
result = np.array(data, dtype=np.dtype(dtype), copy=copy)
result = np.array(data, dtype=np.dtype(dtype), copy=copy) # type: ignore

return result
2 changes: 1 addition & 1 deletion pde/tools/numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def array_constructor() -> np.ndarray:
"""helper that reconstructs the array from the pointer and structural info"""
data: np.ndarray = nb.carray(address_as_void_pointer(data_addr), shape, dtype)
if strides is not None:
data = np.lib.index_tricks.as_strided(data, shape, strides) # type: ignore
data = np.lib.stride_tricks.as_strided(data, shape, strides)
return data

return array_constructor # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion pde/tools/resources/requirements_basic.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# These are the basic requirements for the package
matplotlib>=3.1
numba>=0.59
numpy>=1.22,<2
numpy>=1.22
scipy>=1.10
sympy>=1.9
tqdm>=4.66
2 changes: 1 addition & 1 deletion pde/tools/resources/requirements_full.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ ffmpeg-python>=0.2
h5py>=2.10
matplotlib>=3.1
numba>=0.59
numpy>=1.22,<2
numpy>=1.22
pandas>=2
py-modelrunner>=0.18
rocket-fft>=0.2.4
Expand Down
2 changes: 1 addition & 1 deletion pde/tools/resources/requirements_mpi.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ matplotlib>=3.1
mpi4py>=3
numba>=0.59
numba-mpi>=0.22
numpy>=1.22,<2
numpy>=1.22
pandas>=2
scipy>=1.10
sympy>=1.9
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ classifiers = [
]

# Requirements for setuptools
dependencies = ["matplotlib>=3.1", "numba>=0.59", "numpy>=1.22,<2", "scipy>=1.10", "sympy>=1.9", "tqdm>=4.66"]
dependencies = ["matplotlib>=3.1", "numba>=0.59", "numpy>=1.22", "scipy>=1.10", "sympy>=1.9", "tqdm>=4.66"]

[project.optional-dependencies]
io = ["h5py>=2.10", "pandas>=2", "ffmpeg-python>=0.2"]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
matplotlib>=3.1
numba>=0.59
numpy>=1.22,<2
numpy>=1.22
scipy>=1.10
sympy>=1.9
tqdm>=4.66
2 changes: 1 addition & 1 deletion scripts/create_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def line(self, relation: str = ">=") -> str:
),
Requirement(
name="numpy",
version_min="1.22,<2",
version_min="1.22",
usage="Handling numerical data",
essential=True,
),
Expand Down
2 changes: 1 addition & 1 deletion tests/requirements_full.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ ffmpeg-python>=0.2
h5py>=2.10
matplotlib>=3.1
numba>=0.59
numpy>=1.22,<2
numpy>=1.22
pandas>=2
py-modelrunner>=0.18
rocket-fft>=0.2.4
Expand Down
2 changes: 1 addition & 1 deletion tests/requirements_min.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# These are the minimal requirements used to test compatibility
matplotlib~=3.1
numba~=0.59
numpy~=1.22,<2
numpy~=1.22
scipy~=1.10
sympy~=1.9
tqdm~=4.66
2 changes: 1 addition & 1 deletion tests/requirements_mpi.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ matplotlib>=3.1
mpi4py>=3
numba>=0.59
numba-mpi>=0.22
numpy>=1.22,<2
numpy>=1.22
pandas>=2
scipy>=1.10
sympy>=1.9
Expand Down
20 changes: 19 additions & 1 deletion tests/tools/test_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,15 @@
"""

import numpy as np
import pytest

from pde.tools.numba import Counter, flat_idx, jit, numba_environment
from pde.tools.numba import (
Counter,
flat_idx,
jit,
make_array_constructor,
numba_environment,
)


def test_environment():
Expand Down Expand Up @@ -45,3 +52,14 @@ def test_counter():
c2 = Counter(3)
assert c1 is not c2
assert c1 == c2


@pytest.mark.parametrize(
"arr", [np.arange(5), np.linspace(0, 1, 3), np.arange(12).reshape(3, 4)[1:, 2:]]
)
def test_make_array_constructor(arr):
"""test implementation to create array"""
constructor = jit(make_array_constructor(arr))
arr2 = constructor()
np.testing.assert_equal(arr, arr2)
assert np.shares_memory(arr, arr2)

0 comments on commit 9de5764

Please sign in to comment.