[@parallel on Kubernetes] support for Jobsets #1804

Merged: 1 commit, May 20, 2024
3 changes: 3 additions & 0 deletions metaflow/metaflow_config.py
@@ -331,6 +331,9 @@
ARGO_WORKFLOWS_KUBERNETES_SECRETS = from_conf("ARGO_WORKFLOWS_KUBERNETES_SECRETS", "")
ARGO_WORKFLOWS_ENV_VARS_TO_SKIP = from_conf("ARGO_WORKFLOWS_ENV_VARS_TO_SKIP", "")

KUBERNETES_JOBSET_GROUP = from_conf("KUBERNETES_JOBSET_GROUP", "jobset.x-k8s.io")
KUBERNETES_JOBSET_VERSION = from_conf("KUBERNETES_JOBSET_VERSION", "v1alpha2")

##
# Argo Events Configuration
##
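As a quick orientation, a minimal sketch (assuming Metaflow's usual `from_conf` mapping to `METAFLOW_*` environment variables) of how these two settings compose into the JobSet CRD apiVersion:

from metaflow.metaflow_config import (
    KUBERNETES_JOBSET_GROUP,
    KUBERNETES_JOBSET_VERSION,
)

# With the defaults above this yields "jobset.x-k8s.io/v1alpha2"; both parts
# can be overridden via METAFLOW_KUBERNETES_JOBSET_GROUP and
# METAFLOW_KUBERNETES_JOBSET_VERSION.
api_version = "%s/%s" % (KUBERNETES_JOBSET_GROUP, KUBERNETES_JOBSET_VERSION)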
5 changes: 5 additions & 0 deletions metaflow/plugins/argo/argo_workflows.py
@@ -839,6 +839,11 @@ def _dag_templates(self):
def _visit(
node, exit_node=None, templates=None, dag_tasks=None, parent_foreach=None
):
if node.parallel_foreach:
raise ArgoWorkflowsException(
"Deploying flows with @parallel decorator(s) "
"as Argo Workflows is not supported currently."
)
# Every for-each node results in a separate subDAG and an equivalent
# DAGTemplate rooted at the child of the for-each node. Each DAGTemplate
# has a unique name - the top-level DAGTemplate is named as the name of
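For context, a hedged sketch of the flow shape that sets `node.parallel_foreach` and would now fail fast at deploy time (the `num_parallel` argument to `self.next` is what marks the transition):

from metaflow import FlowSpec, step

class ParallelTrainFlow(FlowSpec):
    @step
    def start(self):
        # num_parallel marks this transition as a parallel foreach,
        # which the guard in argo_workflows.py above rejects.
        self.next(self.train, num_parallel=4)

    @step
    def train(self):
        self.next(self.join)

    @step
    def join(self, inputs):
        self.next(self.end)

    @step
    def end(self):
        pass

if __name__ == "__main__":
    ParallelTrainFlow()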
81 changes: 73 additions & 8 deletions metaflow/plugins/kubernetes/kubernetes.py
@@ -3,9 +3,9 @@
import os
import re
import shlex
import copy
import time
from typing import Dict, List, Optional
import uuid
from uuid import uuid4

from metaflow import current, util
@@ -67,6 +67,12 @@ class KubernetesKilledException(MetaflowException):
headline = "Kubernetes Batch job killed"


def _extract_labels_and_annotations_from_job_spec(job_spec):
annotations = job_spec.template.metadata.annotations
labels = job_spec.template.metadata.labels
return copy.copy(annotations), copy.copy(labels)


class Kubernetes(object):
def __init__(
self,
@@ -141,9 +147,64 @@ def _command(
return shlex.split('bash -c "%s"' % cmd_str)

def launch_job(self, **kwargs):
self._job = self.create_job(**kwargs).execute()
if (
"num_parallel" in kwargs
and kwargs["num_parallel"]
and int(kwargs["num_parallel"]) > 0
):
job = self.create_job_object(**kwargs)
spec = job.create_job_spec()
# `kwargs["step_cli"]` is setting `ubf_context` as control to ALL pods.
# This will be modified by the KubernetesJobSet object
annotations, labels = _extract_labels_and_annotations_from_job_spec(spec)
self._job = self.create_jobset(
job_spec=spec,
run_id=kwargs["run_id"],
step_name=kwargs["step_name"],
task_id=kwargs["task_id"],
namespace=kwargs["namespace"],
env=kwargs["env"],
num_parallel=kwargs["num_parallel"],
port=kwargs["port"],
annotations=annotations,
labels=labels,
).execute()
else:
kwargs["name_pattern"] = "t-{uid}-".format(uid=str(uuid4())[:8])
self._job = self.create_job_object(**kwargs).create().execute()

def create_jobset(
self,
job_spec=None,
run_id=None,
step_name=None,
task_id=None,
namespace=None,
env=None,
num_parallel=None,
port=None,
annotations=None,
labels=None,
):
if env is None:
env = {}

_prefix = str(uuid4())[:6]
js = KubernetesClient().jobset(
name="js-%s" % _prefix,
run_id=run_id,
task_id=task_id,
step_name=step_name,
namespace=namespace,
labels=self._get_labels(labels),
annotations=annotations,
num_parallel=num_parallel,
job_spec=job_spec,
port=port,
)
return js

def create_job_object(
self,
flow_name,
run_id,
@@ -177,14 +238,15 @@ def create_job(
labels=None,
shared_memory=None,
port=None,
name_pattern=None,
num_parallel=None,
):
if env is None:
env = {}

job = (
KubernetesClient()
.job(
generate_name="t-{uid}-".format(uid=str(uuid4())[:8]),
generate_name=name_pattern,
namespace=namespace,
service_account=service_account,
secrets=secrets,
@@ -218,6 +280,7 @@ def create_job(
persistent_volume_claims=persistent_volume_claims,
shared_memory=shared_memory,
port=port,
num_parallel=num_parallel,
)
.environment_variable("METAFLOW_CODE_SHA", code_package_sha)
.environment_variable("METAFLOW_CODE_URL", code_package_url)
@@ -336,6 +399,9 @@ def create_job(
.label("app.kubernetes.io/part-of", "metaflow")
)

return job

def create_k8sjob(self, job):
return job.create()

def wait(self, stdout_location, stderr_location, echo=None):
@@ -370,7 +436,7 @@ def wait_for_launch(job):
t = time.time()
time.sleep(update_delay(time.time() - start_time))

prefix = b"[%s] " % util.to_bytes(self._job.id)
_make_prefix = lambda: b"[%s] " % util.to_bytes(self._job.id)

stdout_tail = get_log_tailer(stdout_location, self._datastore.TYPE)
stderr_tail = get_log_tailer(stderr_location, self._datastore.TYPE)
@@ -380,7 +446,7 @@ def wait_for_launch(job):

# 2) Tail logs until the job has finished
tail_logs(
prefix=prefix,
prefix=_make_prefix(),
stdout_tail=stdout_tail,
stderr_tail=stderr_tail,
echo=echo,
@@ -396,7 +462,6 @@ def wait_for_launch(job):
# exists prior to calling S3Tail and note the user about
# truncated logs if it doesn't.
# TODO : For hard crashes, we can fetch logs from the pod.

if self._job.has_failed:
exit_code, reason = self._job.reason
msg = next(
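To visualize what `create_jobset` ultimately submits, here is a hedged sketch of a JobSet manifest shape. Field names follow the upstream JobSet v1alpha2 API; the control/worker split and the `num_parallel - 1` sizing are assumptions based on this PR's environment variables, not the PR's literal output:

def jobset_manifest_sketch(name, job_spec, num_parallel):
    # Illustrative only: one control replica plus num_parallel - 1 workers.
    return {
        "apiVersion": "jobset.x-k8s.io/v1alpha2",  # KUBERNETES_JOBSET_GROUP/VERSION
        "kind": "JobSet",
        "metadata": {"name": name},
        "spec": {
            "replicatedJobs": [
                {"name": "control", "replicas": 1, "template": job_spec},
                {"name": "worker", "replicas": num_parallel - 1, "template": job_spec},
            ]
        },
    }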
25 changes: 24 additions & 1 deletion metaflow/plugins/kubernetes/kubernetes_cli.py
@@ -7,11 +7,17 @@
from metaflow._vendor import click
from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
from metaflow.metadata.util import sync_local_metadata_from_datastore
from metaflow.unbounded_foreach import UBF_CONTROL, UBF_TASK
from metaflow.metaflow_config import DATASTORE_LOCAL_DIR, KUBERNETES_LABELS
from metaflow.mflog import TASK_LOG_SOURCE
import metaflow.tracing as tracing

from .kubernetes import Kubernetes, KubernetesKilledException, parse_kube_keyvalue_list
from .kubernetes import (
Kubernetes,
KubernetesKilledException,
parse_kube_keyvalue_list,
KubernetesException,
)
from .kubernetes_decorator import KubernetesDecorator


@@ -109,6 +115,15 @@ def kubernetes():
)
@click.option("--shared-memory", default=None, help="Size of shared memory in MiB")
@click.option("--port", default=None, help="Port number to expose from the container")
@click.option(
"--ubf-context", default=None, type=click.Choice([None, UBF_CONTROL, UBF_TASK])
)
@click.option(
"--num-parallel",
default=None,
type=int,
help="Number of parallel nodes to run as a multi-node job.",
)
@click.pass_context
def step(
ctx,
@@ -136,6 +151,7 @@ def step(
tolerations=None,
shared_memory=None,
port=None,
num_parallel=None,
**kwargs
):
def echo(msg, stream="stderr", job_id=None, **kwargs):
@@ -167,6 +183,12 @@ def echo(msg, stream="stderr", job_id=None, **kwargs):
kwargs["input_paths"] = "".join("${%s}" % s for s in split_vars.keys())
env.update(split_vars)

if num_parallel is not None and num_parallel <= 1:
raise KubernetesException(
"Using @parallel with `num_parallel` <= 1 is not supported with Kubernetes. "
"Please set the value of `num_parallel` to be greater than 1."
)

Comment on lines +186 to +191 (Collaborator Author):

Constraint added because JobSet doesn't play nicely with replicas = 0. Once kubernetes-sigs/jobset allows this, we can lift the constraint and add version logic to verify whether the JobSet can be submitted. Currently not supported with JobSet CRD version jobset.x-k8s.io/v1alpha2.
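The arithmetic behind the constraint, assuming the control/worker split sketched earlier:

num_parallel = 1
worker_replicas = num_parallel - 1  # 0 -> a replicated job with replicas = 0,
                                    # which JobSet v1alpha2 refuses to run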

# Set retry policy.
retry_count = int(kwargs.get("retry_count", 0))
retry_deco = [deco for deco in node.decorators if deco.name == "retry"]
@@ -251,6 +273,7 @@ def _sync_metadata():
tolerations=tolerations,
shared_memory=shared_memory,
port=port,
num_parallel=num_parallel,
)
except Exception as e:
traceback.print_exc(chain=False)
5 changes: 4 additions & 1 deletion metaflow/plugins/kubernetes/kubernetes_client.py
@@ -4,7 +4,7 @@

from metaflow.exception import MetaflowException

from .kubernetes_job import KubernetesJob
from .kubernetes_job import KubernetesJob, KubernetesJobSet
Comment (Collaborator Author):

Need this import for Kubernetes clients that may be used via extensions.



CLIENT_REFRESH_INTERVAL_SECONDS = 300
@@ -61,5 +61,8 @@ def get(self):

return self._client

def jobset(self, **kwargs):
return KubernetesJobSet(self, **kwargs)

def job(self, **kwargs):
return KubernetesJob(self, **kwargs)
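A minimal usage sketch of the new hook (keyword names mirror the `create_jobset` call in kubernetes.py above; all values are illustrative):

from metaflow.plugins.kubernetes.kubernetes_client import KubernetesClient

js = KubernetesClient().jobset(
    name="js-abc123",
    run_id="1234",
    task_id="1",
    step_name="train",
    namespace="default",
    labels={"app.kubernetes.io/part-of": "metaflow"},
    annotations={},
    num_parallel=4,
    job_spec=None,  # a fully built Job spec in real usage
    port=None,
)
js.execute()  # submits the JobSet, as launch_job does above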
53 changes: 49 additions & 4 deletions metaflow/plugins/kubernetes/kubernetes_decorator.py
@@ -32,6 +32,8 @@

from ..aws.aws_utils import get_docker_registry, get_ec2_instance_metadata
from .kubernetes import KubernetesException, parse_kube_keyvalue_list
from metaflow.unbounded_foreach import UBF_CONTROL
from .kubernetes_jobsets import TaskIdConstructor

try:
unicode
@@ -239,11 +241,15 @@ def step_init(self, flow, graph, step, decos, environment, flow_datastore, logge
"Kubernetes. Please use one or the other.".format(step=step)
)

for deco in decos:
if getattr(deco, "IS_PARALLEL", False):
raise KubernetesException(
"@kubernetes does not support parallel execution currently."
if any([deco.name == "parallel" for deco in decos]) and any(
[deco.name == "catch" for deco in decos]
):
raise MetaflowException(
"Step *{step}* contains a @parallel decorator "
"with the @catch decorator. @catch is not supported with @parallel on Kubernetes.".format(
step=step
)
)

# Set run time limit for the Kubernetes job.
self.run_time_limit = get_run_time_limit_for_task(decos)
@@ -421,6 +427,10 @@ def task_pre_step(
"METAFLOW_KUBERNETES_SERVICE_ACCOUNT_NAME"
]
meta["kubernetes-node-ip"] = os.environ["METAFLOW_KUBERNETES_NODE_IP"]
if os.environ.get("METAFLOW_KUBERNETES_JOBSET_NAME"):
meta["kubernetes-jobset-name"] = os.environ[
"METAFLOW_KUBERNETES_JOBSET_NAME"
]

# TODO (savin): Introduce equivalent support for Microsoft Azure and
# Google Cloud Platform
@@ -453,6 +463,24 @@ def task_pre_step(
self._save_logs_sidecar = Sidecar("save_logs_periodically")
self._save_logs_sidecar.start()

num_parallel = None
if hasattr(flow, "_parallel_ubf_iter"):
num_parallel = flow._parallel_ubf_iter.num_parallel

if num_parallel and num_parallel >= 1 and ubf_context == UBF_CONTROL:
control_task_id, worker_task_ids = TaskIdConstructor.join_step_task_ids(
num_parallel
)
mapper_task_ids = [control_task_id] + worker_task_ids
flow._control_mapper_tasks = [
"%s/%s/%s" % (run_id, step_name, mapper_task_id)
for mapper_task_id in mapper_task_ids
]
flow._control_task_is_mapper_zero = True

if num_parallel and num_parallel > 1:
_setup_multinode_environment()
Comment on lines +466 to +482 (Collaborator Author):

Needed so that join steps have all the relevant task IDs.
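For instance (IDs are hypothetical; the real scheme lives in `TaskIdConstructor` in kubernetes_jobsets.py):

# With num_parallel = 3, run_id "argo-run-1", and step "train":
# control_task_id, worker_task_ids = TaskIdConstructor.join_step_task_ids(3)
# flow._control_mapper_tasks == [
#     "argo-run-1/train/<control_task_id>",
#     "argo-run-1/train/<worker_task_id_0>",
#     "argo-run-1/train/<worker_task_id_1>",
# ]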


def task_finished(
self, step_name, flow, graph, is_task_ok, retry_count, max_retries
):
@@ -486,3 +514,20 @@ def _save_package_once(cls, flow_datastore, package):
cls.package_url, cls.package_sha = flow_datastore.save_data(
[package.blob], len_hint=1
)[0]


def _setup_multinode_environment():
    import socket

    # Resolve the control pod's IP and the total node count from the
    # JobSet-provided environment.
    os.environ["MF_PARALLEL_MAIN_IP"] = socket.gethostbyname(os.environ["MASTER_ADDR"])
    os.environ["MF_PARALLEL_NUM_NODES"] = os.environ["WORLD_SIZE"]
    # The control pod maps to node index 0; worker replica k maps to k + 1.
    if os.environ.get("CONTROL_INDEX") is not None:
        os.environ["MF_PARALLEL_NODE_INDEX"] = str(0)
    elif os.environ.get("WORKER_REPLICA_INDEX") is not None:
        os.environ["MF_PARALLEL_NODE_INDEX"] = str(
            int(os.environ["WORKER_REPLICA_INDEX"]) + 1
        )
    else:
        raise MetaflowException(
            "JobSet-related environment variables $CONTROL_INDEX or "
            "$WORKER_REPLICA_INDEX not found"
        )
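To make the index mapping concrete, a small illustration (variable values are made up; `CONTROL_INDEX` and `WORKER_REPLICA_INDEX` are assumed to be injected per pod by the JobSet construction in kubernetes_jobsets.py):

import os, socket

os.environ.update(
    {
        "MASTER_ADDR": socket.gethostname(),  # control pod address
        "WORLD_SIZE": "4",                    # total number of pods
        "WORKER_REPLICA_INDEX": "2",          # set only on worker pods
    }
)
_setup_multinode_environment()
print(os.environ["MF_PARALLEL_NUM_NODES"])   # "4"
print(os.environ["MF_PARALLEL_NODE_INDEX"])  # "3" (worker 2 -> index 2 + 1)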