Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

runopts: infer from TypedDict #708

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 9 additions & 0 deletions torchx/schedulers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from enum import Enum
from typing import Generic, Iterable, List, Optional, TypeVar

import typing_inspect

from torchx.specs import (
AppDef,
AppDryRunInfo,
Expand All @@ -23,6 +25,7 @@
runopts,
)
from torchx.workspace.api import WorkspaceMixin
from typing_extensions import final


DAYS_IN_2_WEEKS = 14
Expand Down Expand Up @@ -184,6 +187,7 @@ def submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo:
def _submit_dryrun(self, app: AppDef, cfg: T) -> AppDryRunInfo:
raise NotImplementedError()

@final
def run_opts(self) -> runopts:
"""
Returns the run configuration options expected by the scheduler.
Expand All @@ -195,6 +199,11 @@ def run_opts(self) -> runopts:
return opts

def _run_opts(self) -> runopts:
    """
    Returns the scheduler specific run options.

    The default implementation infers the options from the type argument
    that the concrete scheduler passes to ``Scheduler[...]`` (for example
    ``class MyScheduler(Scheduler[MyOpts])``) by handing that type to
    ``runopts.from_typed_dict``. Subclasses that need hand-written options
    can override this method.

    Returns:
        The ``runopts`` built from the first concrete ``Scheduler[...]``
        parameterization found on the class hierarchy, or an empty
        ``runopts()`` when none is found.
    """
    # Walk the whole MRO instead of only ``self.__class__``: a subclass that
    # declares its own generic bases (e.g. an additional generic mixin) gets
    # its own ``__orig_bases__`` which no longer mentions ``Scheduler[...]``.
    for cls in type(self).__mro__:
        # Read from ``__dict__`` so each class contributes only the bases it
        # declared itself rather than ones inherited through attribute lookup.
        # pyre-fixme[16]: no attribute __orig_bases__
        for base in cls.__dict__.get("__orig_bases__", ()):
            if typing_inspect.get_origin(base) is not Scheduler:
                continue
            args = typing_inspect.get_args(base)
            # ``Scheduler[T]`` with an unresolved type variable carries no
            # option schema; keep looking further up the hierarchy.
            if args and not typing_inspect.is_typevar(args[0]):
                return runopts.from_typed_dict(args[0])

    return runopts()

@abc.abstractmethod
Expand Down
18 changes: 7 additions & 11 deletions torchx/schedulers/docker_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
ReplicaStatus,
Role,
RoleStatus,
runopts,
VolumeMount,
)
from torchx.workspace.docker_workspace import DockerWorkspaceMixin
Expand Down Expand Up @@ -95,6 +94,13 @@ def has_docker() -> bool:


class DockerOpts(TypedDict, total=False):
"""
Attributes
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use Google-style docstrings for these, like everywhere else in torchx?
"""
Attributes:
copy_env: list of glob patterns of environment variables to copy if not set in AppDef. Ex: FOO_*
"""

I validated that runopts.from_typed_dict() parses Google-style docstrings just as well as NumPy-style ones.

----------
copy_env:
list of glob patterns of environment variables to copy if not set in AppDef. Ex: FOO_*
"""

copy_env: Optional[List[str]]


Expand Down Expand Up @@ -341,16 +347,6 @@ def _cancel_existing(self, app_id: str) -> None:
for container in containers:
container.stop()

def _run_opts(self) -> runopts:
    """
    Builds the run options accepted by this scheduler.

    Returns:
        A ``runopts`` holding a single ``copy_env`` option, declared as a
        list of strings with a default of ``None``.
    """
    docker_opts = runopts()
    docker_opts.add(
        "copy_env",
        help="list of glob patterns of environment variables to copy if not set in AppDef. Ex: FOO_*",
        default=None,
        type_=List[str],
    )
    return docker_opts

def _get_app_state(self, container: "Container") -> AppState:
if container.status == "exited":
# docker doesn't have success/failed states -- we have to call
Expand Down
2 changes: 1 addition & 1 deletion torchx/schedulers/gcp_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ def _submit_dryrun(

return AppDryRunInfo(req, repr)

def run_opts(self) -> runopts:
def _run_opts(self) -> runopts:
opts = runopts()
opts.add(
"project",
Expand Down
15 changes: 5 additions & 10 deletions torchx/schedulers/kubernetes_mcad_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

TorchX Kubernetes_MCAD scheduler depends on AppWrapper + MCAD.

Install MCAD:
See deploying Multi-Cluster-Application-Dispatcher guide
Install MCAD:
See deploying Multi-Cluster-Application-Dispatcher guide
https://github.com/project-codeflare/multi-cluster-app-dispatcher/blob/main/doc/deploy/deployment.md

TorchX uses `torch.distributed.run <https://pytorch.org/docs/stable/elastic/run.html>`_ to run distributed training.
Expand Down Expand Up @@ -560,7 +560,7 @@ def app_to_resource(

"""
Create Service:
The selector will have the key 'appwrapper.mcad.ibm.com', and the value will be
The selector will have the key 'appwrapper.mcad.ibm.com', and the value will be
the appwrapper name
"""

Expand Down Expand Up @@ -945,7 +945,7 @@ def _submit_dryrun(
if image_secret is not None and service_account is not None:
msg = """Service Account and Image Secret names are both provided.
Depending on the Service Account configuration, an ImagePullSecret may be defined in your Service Account.
If this is the case, check service account and image secret configurations to understand the expected behavior for
If this is the case, check service account and image secret configurations to understand the expected behavior for
patched image push access."""
warnings.warn(msg)
namespace = cfg.get("namespace")
Expand Down Expand Up @@ -1000,19 +1000,14 @@ def _cancel_existing(self, app_id: str) -> None:
name=name,
)

def run_opts(self) -> runopts:
def _run_opts(self) -> runopts:
opts = runopts()
opts.add(
"namespace",
type_=str,
help="Kubernetes namespace to schedule job in",
default="default",
)
opts.add(
"image_repo",
type_=str,
help="The image repository to use when pushing patched images, must have push access. Ex: example.com/your/container",
)
opts.add(
"service_account",
type_=str,
Expand Down
143 changes: 75 additions & 68 deletions torchx/schedulers/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import unittest
from datetime import datetime
from typing import Iterable, List, Mapping, Optional, TypeVar, Union
from typing import Iterable, List, Mapping, Optional, Union
from unittest.mock import MagicMock, patch

from torchx.schedulers.api import (
Expand All @@ -27,75 +27,81 @@
NULL_RESOURCE,
Resource,
Role,
runopts,
)
from torchx.workspace.api import WorkspaceMixin
from typing_extensions import TypedDict

T = TypeVar("T")

class MockOpts(TypedDict, total=False):
"""
Attributes
----------
foo:
a dummy attribute
"""

class SchedulerTest(unittest.TestCase):
class MockScheduler(Scheduler[T], WorkspaceMixin[None]):
def __init__(self, session_name: str) -> None:
super().__init__("mock", session_name)

def schedule(self, dryrun_info: AppDryRunInfo[None]) -> str:
app = dryrun_info._app
assert app is not None
return app.name

def _submit_dryrun(
self,
app: AppDef,
cfg: Mapping[str, CfgVal],
) -> AppDryRunInfo[None]:
return AppDryRunInfo(None, lambda t: "None")

def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
return None

def _cancel_existing(self, app_id: str) -> None:
pass

def log_iter(
self,
app_id: str,
role_name: str,
k: int = 0,
regex: Optional[str] = None,
since: Optional[datetime] = None,
until: Optional[datetime] = None,
should_tail: bool = False,
streams: Optional[Stream] = None,
) -> Iterable[str]:
return iter([])

def list(self) -> List[ListAppResponse]:
return []

def _run_opts(self) -> runopts:
opts = runopts()
opts.add("foo", type_=str, required=True, help="required option")
return opts

def resolve_resource(self, resource: Union[str, Resource]) -> Resource:
return NULL_RESOURCE

def build_workspace_and_update_role(
self, role: Role, workspace: str, cfg: Mapping[str, CfgVal]
) -> None:
role.image = workspace
foo: str


class MockScheduler(Scheduler[MockOpts], WorkspaceMixin[None]):
    """
    Stub scheduler for the tests in this module.

    It is parameterized with ``MockOpts`` and does not define ``_run_opts``
    itself, so its run options come from whatever the ``Scheduler`` base
    class provides. Every other hook returns a fixed placeholder value.
    """

    def __init__(self, session_name: str) -> None:
        # The backend name is hard-coded to "mock".
        super().__init__("mock", session_name)

    def schedule(self, dryrun_info: AppDryRunInfo[None]) -> str:
        """Returns the name of the app attached to ``dryrun_info``."""
        submitted_app = dryrun_info._app
        assert submitted_app is not None
        return submitted_app.name

    def _submit_dryrun(
        self,
        app: AppDef,
        cfg: MockOpts,
    ) -> AppDryRunInfo[None]:
        """Returns a dry-run info with no request, formatted as ``"None"``."""
        return AppDryRunInfo(None, lambda t: "None")

    def describe(self, app_id: str) -> Optional[DescribeAppResponse]:
        """Always returns ``None``."""
        return None

    def _cancel_existing(self, app_id: str) -> None:
        """No-op."""
        pass

    def log_iter(
        self,
        app_id: str,
        role_name: str,
        k: int = 0,
        regex: Optional[str] = None,
        since: Optional[datetime] = None,
        until: Optional[datetime] = None,
        should_tail: bool = False,
        streams: Optional[Stream] = None,
    ) -> Iterable[str]:
        """Returns an empty iterator regardless of the arguments."""
        return iter([])

    def list(self) -> List[ListAppResponse]:
        """Returns an empty list."""
        return []

    def resolve_resource(self, resource: Union[str, Resource]) -> Resource:
        """Returns ``NULL_RESOURCE`` for any input."""
        return NULL_RESOURCE

    def build_workspace_and_update_role(
        self, role: Role, workspace: str, cfg: Mapping[str, CfgVal]
    ) -> None:
        """Sets ``role.image`` to ``workspace``; ``cfg`` is not read."""
        role.image = workspace


class SchedulerTest(unittest.TestCase):
def test_invalid_run_cfg(self) -> None:
scheduler_mock = SchedulerTest.MockScheduler("test_session")
scheduler_mock = MockScheduler("test_session")
app_mock = MagicMock()

with self.assertRaises(InvalidRunConfigException):
empty_cfg = {}
empty_cfg: MockOpts = {}
scheduler_mock.submit(app_mock, empty_cfg)

with self.assertRaises(InvalidRunConfigException):
bad_type_cfg = {"foo": 100}
# pyre-ignore[55]: expected type str
bad_type_cfg: MockOpts = {"foo": 100}
scheduler_mock.submit(app_mock, bad_type_cfg)

def test_submit_workspace(self) -> None:
Expand All @@ -106,36 +112,37 @@ def test_submit_workspace(self) -> None:
)
app = AppDef(name="test_app", roles=[role])

scheduler_mock = SchedulerTest.MockScheduler("test_session")
scheduler_mock = MockScheduler("test_session")

bad_type_cfg = {"foo": "asdf"}
bad_type_cfg: MockOpts = {"foo": "asdf"}
scheduler_mock.submit(app, bad_type_cfg, workspace="some_workspace")
self.assertEqual(app.roles[0].image, "some_workspace")

def test_invalid_dryrun_cfg(self) -> None:
scheduler_mock = SchedulerTest.MockScheduler("test_session")
scheduler_mock = MockScheduler("test_session")
app_mock = MagicMock()

with self.assertRaises(InvalidRunConfigException):
empty_cfg = {}
empty_cfg: MockOpts = {}
scheduler_mock.submit_dryrun(app_mock, empty_cfg)

with self.assertRaises(InvalidRunConfigException):
bad_type_cfg = {"foo": 100}
# pyre-ignore[55]: expected type str
bad_type_cfg: MockOpts = {"foo": 100}
scheduler_mock.submit_dryrun(app_mock, bad_type_cfg)

def test_role_preproc_called(self) -> None:
scheduler_mock = SchedulerTest.MockScheduler("test_session")
scheduler_mock = MockScheduler("test_session")
app_mock = MagicMock()
app_mock.roles = [MagicMock()]

cfg = {"foo": "bar"}
cfg: MockOpts = {"foo": "bar"}
scheduler_mock.submit_dryrun(app_mock, cfg)
role_mock = app_mock.roles[0]
role_mock.pre_proc.assert_called_once()

def test_validate(self) -> None:
scheduler_mock = SchedulerTest.MockScheduler("test_session")
scheduler_mock = MockScheduler("test_session")
app_mock = MagicMock()
app_mock.roles = [MagicMock()]
app_mock.roles[0].resource = NULL_RESOURCE
Expand All @@ -144,23 +151,23 @@ def test_validate(self) -> None:
scheduler_mock._validate(app_mock, "local")

def test_cancel_not_exists(self) -> None:
scheduler_mock = SchedulerTest.MockScheduler("test_session")
scheduler_mock = MockScheduler("test_session")
with patch.object(scheduler_mock, "_cancel_existing") as cancel_mock:
with patch.object(scheduler_mock, "exists") as exists_mock:
exists_mock.return_value = True
scheduler_mock.cancel("test_id")
cancel_mock.assert_called_once()

def test_cancel_exists(self) -> None:
scheduler_mock = SchedulerTest.MockScheduler("test_session")
scheduler_mock = MockScheduler("test_session")
with patch.object(scheduler_mock, "_cancel_existing") as cancel_mock:
with patch.object(scheduler_mock, "exists") as exists_mock:
exists_mock.return_value = False
scheduler_mock.cancel("test_id")
cancel_mock.assert_not_called()

def test_close_twice(self) -> None:
scheduler_mock = SchedulerTest.MockScheduler("test")
scheduler_mock = MockScheduler("test")
scheduler_mock.close()
scheduler_mock.close()
# nothing to validate explicitly, just that no errors are raised
Expand Down