Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Bug: with_overrides for cache doesn't work #2975

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@

if name is not None:
self._metadata._name = name
if self.run_entity and hasattr(self.run_entity, "metadata"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not override the run_entity.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suppose you have two tasks in a workflow, like below. The cache version of the second task will be overridden as well.

@workflow
def wf():
  t1().with_override(cache_version="v2")
  t1()

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can pass the node info to local_execute, and use the cache version in the node info to retrieve the data from the cache.

if self.metadata.cache and local_config.cache_enabled:

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pingsutw Could I know how to do this?

"You can pass the node info to local_execute`

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the problem is:

In node.py, self._metadata._cacheable won't be affecting the actual Task entity.

That's why we want to modify the Task's entity's metadata.

But I acknowledge the issue that the two tasks example woulld be problematic. I am not sure how local_config can help with it.

self.run_entity.metadata.name = name

Check warning on line 198 in flytekit/core/node.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/node.py#L198

Added line #L198 was not covered by tests

if task_config is not None:
logger.warning("This override is beta. We may want to revisit this in the future.")
Expand All @@ -212,14 +214,20 @@
if cache is not None:
assert_not_promise(cache, "cache")
self._metadata._cacheable = cache
if hasattr(self.run_entity, "metadata") and self.run_entity.metadata is not None:
self.run_entity.metadata.cache = cache

Check warning on line 218 in flytekit/core/node.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/node.py#L218

Added line #L218 was not covered by tests

if cache_version is not None:
assert_not_promise(cache_version, "cache_version")
self._metadata._cache_version = cache_version
if hasattr(self.run_entity, "metadata") and self.run_entity.metadata is not None:
self.run_entity.metadata.cache_version = cache_version

Check warning on line 224 in flytekit/core/node.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/node.py#L224

Added line #L224 was not covered by tests

if cache_serialize is not None:
assert_not_promise(cache_serialize, "cache_serialize")
self._metadata._cache_serializable = cache_serialize
if hasattr(self.run_entity, "metadata") and self.run_entity.metadata is not None:
self.run_entity.metadata.cache_serialize = cache_serialize

Check warning on line 230 in flytekit/core/node.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/node.py#L230

Added line #L230 was not covered by tests

return self

Expand Down
7 changes: 6 additions & 1 deletion flytekit/tools/serialize_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,14 @@
that are not known to Admin
"""
new_api_serializable_entities = OrderedDict()

# Sort entities to process workflows and launch plans before tasks
# TODO: Clean up the copy() - it's here because we call get_default_launch_plan, which may create a LaunchPlan
# object, which gets added to the FlyteEntities.entities list, which we're iterating over.
for entity in flyte_context.FlyteEntities.entities.copy():
sorted_entities = sorted(

Check warning on line 57 in flytekit/tools/serialize_helpers.py

View check run for this annotation

Codecov / codecov/patch

flytekit/tools/serialize_helpers.py#L57

Added line #L57 was not covered by tests
flyte_context.FlyteEntities.entities.copy(), key=lambda x: 0 if isinstance(x, (WorkflowBase, LaunchPlan)) else 1
)
for entity in sorted_entities:
if isinstance(entity, PythonTask) or isinstance(entity, WorkflowBase) or isinstance(entity, LaunchPlan):
get_serializable(new_api_serializable_entities, ctx.serialization_settings, entity, options=options)

Expand Down
8 changes: 4 additions & 4 deletions tests/flytekit/unit/cli/pyflyte/test_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def test_package_with_fast_registration_and_envvars():
# Uncompress flyte-package.tgz
tarfile.open("flyte-package.tgz", "r:gz").extractall()

# Load the proto message from file 3_core.sample.sum_1.pb
# Load the proto message from file 4_core.sample.sum_1.pb
task_spec = task_pb2.TaskSpec()
task_spec.ParseFromString(open("3_core.sample.sum_1.pb", "rb").read())
task_spec.ParseFromString(open("4_core.sample.sum_1.pb", "rb").read())

assert task_spec.template.container.env[0].key == "abc"
assert task_spec.template.container.env[0].value == "42"
Expand Down Expand Up @@ -148,9 +148,9 @@ def test_package_with_fast_registration_and_envvars():

tarfile.open("flyte-package.tgz", "r:gz").extractall()

# Load the proto message from file 3_core.sample.sum_1.pb
# Load the proto message from file 4_core.sample.sum_1.pb
task_spec = task_pb2.TaskSpec()
task_spec.ParseFromString(open("3_core.sample.sum_1.pb", "rb").read())
task_spec.ParseFromString(open("4_core.sample.sum_1.pb", "rb").read())

assert task_spec.template.container.env[0].key == "k1"
assert task_spec.template.container.env[0].value == "v1"
Expand Down
15 changes: 13 additions & 2 deletions tests/flytekit/unit/core/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,8 +513,19 @@ def my_wf(a: str) -> str:
image_config=ImageConfig(Image(name="name", fqn="image", tag="name")),
env={},
)
wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)

assert wf_spec.template.nodes[0].metadata.cache_serializable
task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
assert not task_spec.template.metadata.discoverable
assert task_spec.template.metadata.discovery_version != "foo"
assert not task_spec.template.metadata.cache_serializable

wf_spec = get_serializable(OrderedDict(), serialization_settings, my_wf)
assert len(wf_spec.template.nodes) == 1
assert wf_spec.template.nodes[0].metadata.cacheable
assert wf_spec.template.nodes[0].metadata.cache_version == "foo"
assert wf_spec.template.nodes[0].metadata.cache_serializable

task_spec = get_serializable(OrderedDict(), serialization_settings, t1)
assert task_spec.template.metadata.discoverable
assert task_spec.template.metadata.discovery_version == "foo"
assert task_spec.template.metadata.cache_serializable
Loading