Skip to content

Commit

Permalink
Add @ operator type hints for Series (#1047)
Browse files Browse the repository at this point in the history
* Add @ operator type hints for Series

* Fix test

* Formatting

* Fix test
  • Loading branch information
loicdiridollou authored Nov 21, 2024
1 parent 92bd9cb commit 03396ef
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
14 changes: 12 additions & 2 deletions pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -800,8 +800,18 @@ class Series(IndexOpsMixin[S1], NDFrame):
def dot(
self, other: ArrayLike | dict[_str, np.ndarray] | Sequence[S1] | Index[S1]
) -> np.ndarray: ...
def __matmul__(self, other): ...
def __rmatmul__(self, other): ...
@overload
def __matmul__(self, other: Series) -> Scalar: ...
@overload
def __matmul__(self, other: DataFrame) -> Series: ...
@overload
def __matmul__(self, other: np.ndarray) -> np.ndarray: ...
@overload
def __rmatmul__(self, other: Series) -> Scalar: ...
@overload
def __rmatmul__(self, other: DataFrame) -> Series: ...
@overload
def __rmatmul__(self, other: np.ndarray) -> np.ndarray: ...
@overload
def searchsorted(
self,
Expand Down
13 changes: 7 additions & 6 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,16 +1238,17 @@ def test_types_as_type() -> None:


def test_types_dot() -> None:
"""Test typing of multiplication methods (dot and @) for Series."""
s1 = pd.Series([0, 1, 2, 3])
s2 = pd.Series([-1, 2, -3, 4])
df1 = pd.DataFrame([[0, 1], [-2, 3], [4, -5], [6, 7]])
n1 = np.array([[0, 1], [1, 2], [-1, -1], [2, 0]])
sc1: Scalar = s1.dot(s2)
sc2: Scalar = s1 @ s2
s3: pd.Series = s1.dot(df1)
s4: pd.Series = s1 @ df1
n2: np.ndarray = s1.dot(n1)
n3: np.ndarray = s1 @ n1
check(assert_type(s1.dot(s2), Scalar), np.int64)
check(assert_type(s1 @ s2, Scalar), np.int64)
check(assert_type(s1.dot(df1), "pd.Series[int]"), pd.Series, np.int64)
check(assert_type(s1 @ df1, pd.Series), pd.Series)
check(assert_type(s1.dot(n1), np.ndarray), np.ndarray)
check(assert_type(s1 @ n1, np.ndarray), np.ndarray)


def test_series_loc_setitem() -> None:
Expand Down

0 comments on commit 03396ef

Please sign in to comment.