Skip to content

Commit

Permalink
Fix Series.rename_axis and DataFrame.rename_axis typing (#1048)
Browse files Browse the repository at this point in the history
* Add `hdf5` requirement to setup instructions

* Fix `Series.rename_axis` and `DataFrame.rename_axis` typing

* Fix comment

* Use `check(assert_type(...))` framework in tests

* Combine loops

* Restrict `hdf5` instruction to macOS

* Inline `check(assert_type(...))` calls

* Check invalid usage of `columns`
  • Loading branch information
dpoznik authored Nov 22, 2024
1 parent 47fc9b6 commit d289ef5
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 13 deletions.
1 change: 1 addition & 0 deletions docs/setup.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
## Set Up Environment

- Make sure you have `python >= 3.10` installed.
- If using macOS, you may need to install `hdf5` (e.g., via `brew install hdf5`).
- Install poetry: `pip install 'poetry>=1.8'`
- Install the project dependencies: `poetry update`
- Enter the virtual environment: `poetry shell`
Expand Down
16 changes: 10 additions & 6 deletions pandas-stubs/core/frame.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2070,40 +2070,44 @@ class DataFrame(NDFrame, OpsMixin):
limit: int | None = ...,
tolerance=...,
) -> DataFrame: ...
# Rename axis with `mapper`, `axis`, and `inplace=True`
@overload
def rename_axis(
self,
mapper=...,
mapper: Scalar | ListLike | None = ...,
*,
axis: Axis | None = ...,
copy: _bool = ...,
*,
inplace: Literal[True],
) -> None: ...
# Rename axis with `mapper`, `axis`, and `inplace=False`
@overload
def rename_axis(
self,
mapper=...,
mapper: Scalar | ListLike | None = ...,
*,
axis: Axis | None = ...,
copy: _bool = ...,
*,
inplace: Literal[False] = ...,
) -> DataFrame: ...
# Rename axis with `index` and/or `columns` and `inplace=True`
@overload
def rename_axis(
self,
*,
index: _str | Sequence[_str] | dict[_str | int, _str] | Callable | None = ...,
columns: _str | Sequence[_str] | dict[_str | int, _str] | Callable | None = ...,
copy: _bool = ...,
*,
inplace: Literal[True],
) -> None: ...
# Rename axis with `index` and/or `columns` and `inplace=False`
@overload
def rename_axis(
self,
*,
index: _str | Sequence[_str] | dict[_str | int, _str] | Callable | None = ...,
columns: _str | Sequence[_str] | dict[_str | int, _str] | Callable | None = ...,
copy: _bool = ...,
*,
inplace: Literal[False] = ...,
) -> DataFrame: ...
def rfloordiv(
Expand Down
29 changes: 23 additions & 6 deletions pandas-stubs/core/series.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2076,24 +2076,41 @@ class Series(IndexOpsMixin[S1], NDFrame):
numeric_only: _bool = ...,
**kwargs,
) -> Scalar: ...
# Rename axis with `mapper`, `axis`, and `inplace=True`
@overload
def rename_axis(
self,
mapper: Scalar | ListLike = ...,
index: Scalar | ListLike | Callable | dict | None = ...,
columns: Scalar | ListLike | Callable | dict | None = ...,
mapper: Scalar | ListLike | None = ...,
*,
axis: AxisIndex | None = ...,
copy: _bool = ...,
inplace: Literal[True],
) -> None: ...
# Rename axis with `mapper`, `axis`, and `inplace=False`
@overload
def rename_axis(
self,
mapper: Scalar | ListLike | None = ...,
*,
axis: AxisIndex | None = ...,
copy: _bool = ...,
inplace: Literal[False] = ...,
) -> Self: ...
# Rename axis with `index` and `inplace=True`
@overload
def rename_axis(
self,
*,
index: Scalar | ListLike | Callable | dict | None = ...,
copy: _bool = ...,
inplace: Literal[True],
) -> None: ...
# Rename axis with `index` and `inplace=False`
@overload
def rename_axis(
self,
mapper: Scalar | ListLike = ...,
*,
index: Scalar | ListLike | Callable | dict | None = ...,
columns: Scalar | ListLike | Callable | dict | None = ...,
axis: AxisIndex | None = ...,
copy: _bool = ...,
inplace: Literal[False] = ...,
) -> Self: ...
Expand Down
42 changes: 42 additions & 0 deletions tests/test_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1955,6 +1955,48 @@ def test_types_rename() -> None:
df.rename(columns=lambda s: s.upper())


def test_types_rename_axis() -> None:
df = pd.DataFrame({"col_name": [1, 2, 3]})
df.index.name = "a"
df.columns.name = "b"

# Rename axes with `mapper` and `axis`
check(assert_type(df.rename_axis("A"), pd.DataFrame), pd.DataFrame)
check(assert_type(df.rename_axis(["A"]), pd.DataFrame), pd.DataFrame)
check(assert_type(df.rename_axis(None), pd.DataFrame), pd.DataFrame)
check(assert_type(df.rename_axis("B", axis=1), pd.DataFrame), pd.DataFrame)
check(assert_type(df.rename_axis(["B"], axis=1), pd.DataFrame), pd.DataFrame)
check(assert_type(df.rename_axis(None, axis=1), pd.DataFrame), pd.DataFrame)

# Rename axes with `index` and `columns`
check(
assert_type(df.rename_axis(index="A", columns="B"), pd.DataFrame),
pd.DataFrame,
)
check(
assert_type(df.rename_axis(index=["A"], columns=["B"]), pd.DataFrame),
pd.DataFrame,
)
check(
assert_type(df.rename_axis(index={"a": "A"}, columns={"b": "B"}), pd.DataFrame),
pd.DataFrame,
)
check(
assert_type(
df.rename_axis(
index=lambda name: name.upper(),
columns=lambda name: name.upper(),
),
pd.DataFrame,
),
pd.DataFrame,
)
check(
assert_type(df.rename_axis(index=None, columns=None), pd.DataFrame),
pd.DataFrame,
)


def test_types_eq() -> None:
df1 = pd.DataFrame([[1, 2], [8, 9]], columns=["A", "B"])
res1: pd.DataFrame = df1 == 1
Expand Down
21 changes: 20 additions & 1 deletion tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1121,7 +1121,26 @@ def test_types_eq() -> None:


def test_types_rename_axis() -> None:
s: pd.Series = pd.Series([1, 2, 3]).rename_axis("A")
s = pd.Series([1, 2, 3])
s.index.name = "a"

# Rename index with `mapper`
check(assert_type(s.rename_axis("A"), "pd.Series[int]"), pd.Series)
check(assert_type(s.rename_axis(["A"]), "pd.Series[int]"), pd.Series)
check(assert_type(s.rename_axis(None), "pd.Series[int]"), pd.Series)

# Rename index with `index`
check(assert_type(s.rename_axis(index="A"), "pd.Series[int]"), pd.Series)
check(assert_type(s.rename_axis(index=["A"]), "pd.Series[int]"), pd.Series)
check(assert_type(s.rename_axis(index={"a": "A"}), "pd.Series[int]"), pd.Series)
check(
assert_type(s.rename_axis(index=lambda name: name.upper()), "pd.Series[int]"),
pd.Series,
)
check(assert_type(s.rename_axis(index=None), "pd.Series[int]"), pd.Series)

if TYPE_CHECKING_INVALID_USAGE:
s.rename_axis(columns="A") # type: ignore[call-overload] # pyright: ignore[reportCallIssue]


def test_types_values() -> None:
Expand Down

0 comments on commit d289ef5

Please sign in to comment.