Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions torchx/schedulers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torchx.schedulers.api import Scheduler
from torchx.util.entrypoints import load_group

DEFAULT_SCHEDULER_MODULES: Mapping[str, str] = {
BUILTIN_SCHEDULER_MODULES: Mapping[str, str] = {
"local_docker": "torchx.schedulers.docker_scheduler",
"local_cwd": "torchx.schedulers.local_scheduler",
"slurm": "torchx.schedulers.slurm_scheduler",
Expand All @@ -39,24 +39,25 @@ def run(*args: object, **kwargs: object) -> Scheduler:
return run


def default_schedulers() -> dict[str, SchedulerFactory]:
    """Build default schedulers (built-in + extras from torchx.schedulers.extra)."""
    # Start with lazily-loaded factories for every built-in scheduler module,
    # then let entrypoints registered under "torchx.schedulers.extra" add to
    # (or, on name collision, override) that set.
    factories: dict[str, SchedulerFactory] = {
        name: _defer_load_scheduler(module)
        for name, module in BUILTIN_SCHEDULER_MODULES.items()
    }
    factories.update(load_group("torchx.schedulers.extra", default={}))
    return factories
Comment on lines +42 to +47
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took a closer look at this, and realized that we don't need this change to get what you want.

Suppose what you want is NeMO's [torchx.schedulers] as well as other ones that you use. Python entrypoint groups are compounding (but the keys in the groups are not). So as long as you don't have a conflict in the scheduler names, you can define your own entrypoint as

[torchx.schedulers]
local_cwd = torchx.schedulers.local_scheduler_fb:create_cwd_scheduler
aws_batch = torchx.schedulers.aws_batch_scheduler:create_scheduler
... others you want ...

and NeMO defines their set as:

[torchx.schedulers]
nemo_sched_1 = ...
nemo_sched_2 = ...

the schedulers torchx ends up loading would be:

local_cwd
aws_batch
nemo_sched_1
nemo_sched_2

However in the absence of your entrypoint [torchx.schedulers] torchx would only load the ones defined in NeMO (nemo_sched_1 and nemo_sched_2).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm aware of this effect @kiukchung but it means there must be some package re-registering the (ideally all) built-in schedulers and keep up with the list (I remember you mentioned there's no plan to extend the list though - I was not aware when this ticket was cut).
E.g. inside a NeMo container where only NeMo-Run and TorchX dependency are installed we only get NeMo-Run schedulers (like you mentioned), i.e. no local_cwd. As a user I would like to get NeMo ones, built-ins and maybe add my own (just like components), but not at the cost of re-registering, if that makes sense.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

are you running on a pre-built/published NeMo docker image (e.g.nvcr.io/nvidia/nemo) directly? you don't have your workspace/project to install?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't include torchx in runtime dependencies to have a nice separation in dependency closures, so no @kiukchung

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the fact that you want to use local_cwd means that you have a direct dep on torchx no? I'm having a hard time understanding the exact use-case to be able to offer you solutions/help. Perhaps a quick catch-up over slack or VC would help?

I realize this is a long standing ask so want to help unblock you.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to use the vanilla one that is already there :) So in order for me to be able to use local_cwd from within the NeMo container I have to either install a custom package that re-registers built-ins (I don't know of such a package, nor does it look tempting to create one), or comment out entrypoints for NeMo-Run @kiukchung



def get_scheduler_factories(
    group: str = "torchx.schedulers",
    skip_defaults: bool = False,
) -> dict[str, SchedulerFactory]:
    """
    get_scheduler_factories returns all the available schedulers names under `group` and the
    method to instantiate them.

    The first scheduler in the dictionary is used as the default scheduler.

    Args:
        group: entrypoint group to load scheduler factories from.
        skip_defaults: if True, return an empty mapping instead of the
            built-in/extra schedulers when nothing is registered under
            ``group``.
    """
    # When entrypoints exist under `group` they fully define the result;
    # the default (built-ins + extras) is only used as a fallback.
    return load_group(group, default={} if skip_defaults else default_schedulers())


def get_default_scheduler_name() -> str:
Expand Down
72 changes: 72 additions & 0 deletions torchx/schedulers/test/registry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,75 @@ def test_get_local_schedulers(self, mock_load_group: MagicMock) -> None:

for scheduler in schedulers.values():
self.assertEqual("test_session", scheduler.session_name)

@patch("torchx.schedulers.load_group")
def test_torchx_schedulers_overrides_all(self, mock_load_group: MagicMock) -> None:
    """torchx.schedulers completely overrides defaults and ignores extras"""
    mock_custom: MagicMock = MagicMock()
    mock_extra: MagicMock = MagicMock()

    # Simulate entrypoints: `torchx.schedulers` has one custom entry, which
    # should win over both built-ins and `torchx.schedulers.extra` entries.
    def fake_load_group(group: str, default: object) -> object:
        if group == "torchx.schedulers":
            return {"custom": mock_custom}
        if group == "torchx.schedulers.extra":
            return {"extra": mock_extra}
        return {}

    mock_load_group.side_effect = fake_load_group

    factories = get_scheduler_factories()

    self.assertEqual(factories, {"custom": mock_custom})
    self.assertNotIn("local_docker", factories)
    self.assertNotIn("extra", factories)

@patch("torchx.schedulers.load_group")
def test_no_custom_returns_defaults_and_extras(
    self, mock_load_group: MagicMock
) -> None:
    """no custom schedulers returns built-in + extras"""
    mock_extra: MagicMock = MagicMock()

    # Only `torchx.schedulers.extra` is registered; every other group falls
    # back to the supplied default mapping.
    def fake_load_group(group: str, default: object) -> object:
        if group == "torchx.schedulers.extra":
            return {"extra": mock_extra}
        return default

    mock_load_group.side_effect = fake_load_group

    factories = get_scheduler_factories()

    for name in ("local_docker", "slurm", "extra"):
        self.assertIn(name, factories)

@patch("torchx.schedulers.load_group")
def test_no_custom_no_extras_returns_builtins(
    self, mock_load_group: MagicMock
) -> None:
    """no custom, no extras returns only built-in schedulers"""

    # No entrypoints at all: every group lookup returns its default.
    def passthrough(group: str, default: object) -> object:
        return default

    mock_load_group.side_effect = passthrough

    factories = get_scheduler_factories()

    self.assertIn("local_docker", factories)
    self.assertIn("slurm", factories)

@patch("torchx.schedulers.load_group")
def test_skip_defaults_returns_empty(self, mock_load_group: MagicMock) -> None:
    """skip_defaults=True with no custom schedulers returns empty"""
    # No entrypoints registered; with skip_defaults the fallback default is {}.
    mock_load_group.side_effect = lambda group, default: default

    self.assertEqual(get_scheduler_factories(skip_defaults=True), {})

@patch("torchx.schedulers.load_group")
def test_custom_scheduler_is_default(self, mock_load_group: MagicMock) -> None:
    """first custom scheduler becomes the default"""
    mock_aws: MagicMock = MagicMock()
    mock_custom: MagicMock = MagicMock()

    # Custom entrypoints fully replace the defaults; the default scheduler
    # name must come from this mapping.
    def fake_load_group(group: str, default: object) -> object:
        if group == "torchx.schedulers":
            return {"aws_batch": mock_aws, "custom_1": mock_custom}
        return {}

    mock_load_group.side_effect = fake_load_group

    default_name = get_default_scheduler_name()

    self.assertIn(default_name, ["aws_batch", "custom_1"])
Loading