Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use tomlkit to dump updated dependencies #212

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 47 additions & 23 deletions ci_cd/tasks/update_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
parse_ignore_entries,
parse_ignore_rules,
regenerate_requirement,
update_file,
update_specifier_set,
warning_msg,
)
Expand All @@ -44,7 +43,6 @@
# Get logger
LOGGER = logging.getLogger(__name__)


VALID_PACKAGE_NAME_PATTERN = r"^([A-Z0-9]|[A-Z0-9][A-Z0-9._-]*[A-Z0-9])$"
"""
Pattern to validate package names.
Expand All @@ -57,8 +55,44 @@
"""


def _update_pyproject(
original_dependency: str, updated_dependency: str, pyproject: tomlkit.TOMLDocument
) -> None:
"""Update dependency in pyproject data structure.

First, check and update the dependency if it is in the "dependencies" group
Then, check and update if it is in any of the "optional-dependencies" groups.

Essentially, we allow for the original dependency to be in multiple groups.
"""
LOGGER.debug(
"Updating pyproject data structure for %r to %r",
original_dependency,
updated_dependency,
)

if original_dependency in pyproject["project"].get("dependencies", []):
index = pyproject["project"]["dependencies"].index(original_dependency)
pyproject["project"]["dependencies"][index] = updated_dependency.replace(
'"', "'"
)

for extra_name, extra_dependencies in (
pyproject["project"].get("optional-dependencies", {}).items()
):
if original_dependency in extra_dependencies:
index = pyproject["project"]["optional-dependencies"][extra_name].index(
original_dependency
)
pyproject["project"]["optional-dependencies"][extra_name][index] = (
updated_dependency.replace('"', "'")
)


def _format_and_update_dependency(
requirement: Requirement, raw_dependency_line: str, pyproject_path: Path
requirement: Requirement,
raw_dependency_line: str,
pyproject: tomlkit.TOMLDocument = None,
) -> None:
"""Regenerate dependency without changing anything but the formatting.

Expand All @@ -72,12 +106,8 @@ def _format_and_update_dependency(
)
LOGGER.debug("Regenerated dependency: %r", updated_dependency)
if updated_dependency != raw_dependency_line:
# Update pyproject.toml since the dependency formatting has changed
LOGGER.debug("Updating pyproject.toml for %r", requirement.name)
update_file(
pyproject_path,
(re.escape(raw_dependency_line), updated_dependency.replace('"', "'")),
)
# Update pyproject data structure since the dependency formatting has changed
_update_pyproject(raw_dependency_line, updated_dependency, pyproject)


@task(
Expand Down Expand Up @@ -192,7 +222,8 @@ def update_deps(
)

# Build the list of dependencies listed in pyproject.toml
dependencies: list[str] = pyproject.get("project", {}).get("dependencies", [])
dependencies: list[str] = []
dependencies.extend(pyproject.get("project", {}).get("dependencies", []))
for optional_deps in (
pyproject.get("project", {}).get("optional-dependencies", {}).values()
):
Expand Down Expand Up @@ -259,9 +290,7 @@ def update_deps(
LOGGER.info(msg)
print(info_msg(msg), flush=True)

_format_and_update_dependency(
parsed_requirement, dependency, pyproject_path
)
_format_and_update_dependency(parsed_requirement, dependency, pyproject)
already_handled_packages.add(parsed_requirement)
continue

Expand All @@ -278,9 +307,7 @@ def update_deps(
LOGGER.warning(msg)
print(warning_msg(msg), flush=True)

_format_and_update_dependency(
parsed_requirement, dependency, pyproject_path
)
_format_and_update_dependency(parsed_requirement, dependency, pyproject)
already_handled_packages.add(parsed_requirement)
continue

Expand Down Expand Up @@ -469,14 +496,8 @@ def update_deps(
)
LOGGER.debug("Updated dependency: %r", updated_dependency)

pattern_sub_line = re.escape(dependency)
replacement_sub_line = updated_dependency.replace('"', "'")
_update_pyproject(dependency, updated_dependency, pyproject)

LOGGER.debug("pattern_sub_line: %s", pattern_sub_line)
LOGGER.debug("replacement_sub_line: %s", replacement_sub_line)

# Update pyproject.toml
update_file(pyproject_path, (pattern_sub_line, replacement_sub_line))
already_handled_packages.add(parsed_requirement)
updated_packages[parsed_requirement.name] = ",".join(
str(_)
Expand All @@ -492,6 +513,9 @@ def update_deps(
f"{Emoji.CROSS_MARK.value} Errors occurred! See printed statements above."
)

# Update pyproject.toml
pyproject_path.write_text(tomlkit.dumps(pyproject), encoding="utf-8")

if updated_packages:
print(
f"{Emoji.PARTY_POPPER.value} Successfully updated the following "
Expand Down
40 changes: 25 additions & 15 deletions tests/tasks/test_update_deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,22 @@ def test_skip_unnormalized_python_package_names(
r"statements above\.$"
)

# Due to an atomistic approach, if an error occurs, the pyproject.toml file will
# not be updated.
successful_expected_pyproject_file_data = """[project]
name = "{{ cookiecutter.project_slug }}"
requires-python = ">=3.8"

dependencies = []

[project.optional-dependencies]
dev = ["pytest~=7.4"]
all = ["{{ cookiecutter.project_slug }}[dev]"]
"""
erroneous_expected_pyproject_file_data = pyproject_file_data

if skip_unnormalized_python_package_names:
# This should end in success
update_deps(
context,
root_repo_path=str(tmp_path),
Expand All @@ -1140,7 +1155,13 @@ def test_skip_unnormalized_python_package_names(
terminal_error_msg.search(stdouterr.err) is None
), f"{terminal_error_msg!r} unexpectedly found in {stdouterr.err}"

assert (
pyproject_file.read_text(encoding="utf8")
== successful_expected_pyproject_file_data
)

else:
# This should end in failure
with pytest.raises(SystemExit, match=raise_msg):
update_deps(
context,
Expand Down Expand Up @@ -1173,18 +1194,7 @@ def test_skip_unnormalized_python_package_names(
terminal_error_msg.search(stdouterr.err) is not None
), f"{terminal_error_msg!r} not found in {stdouterr.err}"

# In both cases, the pyproject.toml file should be updated for pytest.
# When/if a more atomistic approach is taken, then this should *NOT* be the case
# for runs where an error occurs.
expected_pyproject_file_data = """[project]
name = "{{ cookiecutter.project_slug }}"
requires-python = ">=3.8"

dependencies = []

[project.optional-dependencies]
dev = ["pytest~=7.4"]
all = ["{{ cookiecutter.project_slug }}[dev]"]
"""

assert pyproject_file.read_text(encoding="utf8") == expected_pyproject_file_data
assert (
pyproject_file.read_text(encoding="utf8")
== erroneous_expected_pyproject_file_data
)
Loading