Commit 44bf04c

justusschock, deependujha, and bhimrazy authored and committed
Internal Refactor: Reroute Implementations (#21354)
* forward xla impl
* forward logger implementation
* forward logger implementation: mlflow
* update neptune logger
* forward kubeflow implementation
* forward lsf env
* move torchelastic
* update xla env
* forward bitsandbytes
* forward deepspeed precision
* forward transformer engine
* forward XLA precision
* forward deepspeed strategy fabric
* integrate xla strategies
* update pytorch deepspeed precision
* forward trainer xla single device
* XLA ddp trainer
* update (×14)
* update fabric tests
* fabric tests
* tests
* update version
* update (×6)
* fix doc issue
* fix mypy issue
* fix readthedocs and ci cpu tests
* update (×6)
* fix deepspeed assertion
* update
* fix transformer engine mock
* update
* logger mocks
* add tpu mocks
* update (×4)
* fix docmake
* update (×2)
* fix loggers error
* update (×4)
* pin cuda version
* update
* try with removing libnccl downloading
* undo cuda pinning
* update (×2)
* correctly handle model property
* update error types and add property forwarding
* update (×3)
* meow meow
* claymore!!!
* remove todo
* remove todos + version
* retrigger-ci to fix ple release issue
* fix mocks xla

---------

Co-authored-by: Deependu Jha <deependujha21@gmail.com>
Co-authored-by: Bhimraj Yadav <bhimrajyadav977@gmail.com>
1 parent 2e2c8f6 commit 44bf04c
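
The refactor follows one pattern throughout: each open-source class keeps its public API but constructs and delegates to its counterpart in `pytorch-lightning-enterprise`, guarded by `_raise_enterprise_not_available()`. The sketch below illustrates that delegation; it is not the upstream code, and the error-raising behavior of the guard helper is an assumption (the enterprise import path is taken from the xla.py diff further down).

```python
# Illustrative sketch of the forwarding pattern in this commit (not the actual upstream code).
# The enterprise import path mirrors the xla.py diff below; the guard's exact behavior is assumed.
from typing import Any


def _raise_enterprise_not_available() -> None:
    # Assumption: the real helper raises a descriptive error when the enterprise package is absent.
    try:
        import pytorch_lightning_enterprise  # noqa: F401
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError("This feature requires `pytorch-lightning-enterprise >=2.6.0`.") from err


class ForwardingXLAAccelerator:
    """Open-source facade that reroutes all behavior to the enterprise implementation."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        _raise_enterprise_not_available()
        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator

        self.accelerator_impl = EnterpriseXLAAccelerator(*args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        # Attributes not defined on the facade are forwarded to the enterprise object
        # (the real classes forward each member explicitly, as the diffs below show).
        return getattr(self.accelerator_impl, name)
```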

File tree

65 files changed (+1096, -4184 lines)


.github/workflows/docs-build.yml

Lines changed: 4 additions & 1 deletion

@@ -125,7 +125,10 @@ jobs:
         working-directory: ./docs/source-${{ matrix.pkg-name }}
         # allow failing link check and doctest if you run with dispatch
         continue-on-error: ${{ (matrix.target == 'doctest' || matrix.target == 'linkcheck') && github.event_name == 'workflow_dispatch' }}
-        run: make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"
+        run: |
+          # temp fix: https://github.com/Lightning-AI/pytorch-lightning/actions/runs/19440502586/job/55622388642?pr=21354#step:11:4596
+          uv pip install -U fastapi
+          make ${{ matrix.target }} --debug --jobs $(nproc) SPHINXOPTS="$BUILD_SPHINX_OPTS"
 
       - name: Keep artifact
         if: github.event_name == 'pull_request'

docs/source-pytorch/conf.py

Lines changed: 1 addition & 4 deletions

@@ -604,10 +604,7 @@ def package_list_from_file(file):
 from lightning.pytorch.cli import _JSONARGPARSE_SIGNATURES_AVAILABLE as _JSONARGPARSE_AVAILABLE
 from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE
 from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE, _TENSORBOARDX_AVAILABLE
-from lightning.pytorch.loggers.neptune import _NEPTUNE_AVAILABLE
-from lightning.pytorch.loggers.comet import _COMET_AVAILABLE
-from lightning.pytorch.loggers.mlflow import _MLFLOW_AVAILABLE
-from lightning.pytorch.loggers.wandb import _WANDB_AVAILABLE
+from lightning.fabric.utilities.imports import _COMET_AVAILABLE, _MLFLOW_AVAILABLE, _NEPTUNE_AVAILABLE, _WANDB_AVAILABLE
 """
 coverage_skip_undoc_in_source = True
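
With the logger availability flags consolidated under `lightning.fabric.utilities.imports`, downstream code (the docs config included) can pull them from a single module. A minimal sketch, assuming the flags evaluate truthy/falsy like the `RequirementCache`-style flags used elsewhere in the codebase:

```python
# Minimal sketch: reading the consolidated availability flags from their new location.
# Assumes they can be evaluated as booleans, like RequirementCache-style flags.
from lightning.fabric.utilities.imports import (
    _COMET_AVAILABLE,
    _MLFLOW_AVAILABLE,
    _NEPTUNE_AVAILABLE,
    _WANDB_AVAILABLE,
)

flags = {
    "comet": _COMET_AVAILABLE,
    "mlflow": _MLFLOW_AVAILABLE,
    "neptune": _NEPTUNE_AVAILABLE,
    "wandb": _WANDB_AVAILABLE,
}
for name, available in flags.items():
    print(f"{name} logger extras {'found' if available else 'missing'}")
```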

requirements/fabric/base.txt

Lines changed: 1 addition & 0 deletions

@@ -6,3 +6,4 @@ fsspec[http] >=2022.5.0, <2025.11.0
 packaging >=20.0, <=25.0
 typing-extensions >4.5.0, <4.16.0
 lightning-utilities >=0.10.0, <0.16.0
+pytorch-lightning-enterprise >=2.6.0

requirements/pytorch/base.txt

Lines changed: 1 addition & 0 deletions

@@ -9,3 +9,4 @@ torchmetrics >0.7.0, <1.9.0
 packaging >=20.0, <=25.0
 typing-extensions >4.5.0, <4.16.0
 lightning-utilities >=0.10.0, <0.16.0
+pytorch-lightning-enterprise >=2.6.0

src/lightning/fabric/accelerators/xla.py

Lines changed: 27 additions & 96 deletions

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
+import warnings
 from typing import Any, Union
 
 import torch
@@ -20,7 +21,11 @@
 
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
-from lightning.fabric.utilities.device_parser import _check_data_type
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
+
+_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
+_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
+_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
 
 
 class XLAAccelerator(Accelerator):
@@ -31,38 +36,38 @@ class XLAAccelerator(Accelerator):
     """
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        if not _XLA_AVAILABLE:
-            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-        if not _using_pjrt():
-            raise RuntimeError("The XLA XRT runtime is not supported anymore.")
+        _raise_enterprise_not_available()
         super().__init__(*args, **kwargs)
 
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+        self.accelerator_impl = EnterpriseXLAAccelerator(*args, **kwargs)
+
     @override
     def setup_device(self, device: torch.device) -> None:
-        pass
+        return self.accelerator_impl.setup_device(device)
 
     @override
     def teardown(self) -> None:
-        pass
+        return self.accelerator_impl.teardown()
 
     @staticmethod
     @override
     def parse_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
         """Accelerator device parsing logic."""
-        return _parse_tpu_devices(devices)
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+        return EnterpriseXLAAccelerator.parse_devices(devices)
 
     @staticmethod
     @override
     def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
         """Gets parallel devices for the Accelerator."""
-        devices = _parse_tpu_devices(devices)
-        if isinstance(devices, int):
-            return [torch.device("xla", i) for i in range(devices)]
-        # list of devices is not supported, just a specific index, fine to access [0]
-        return [torch.device("xla", devices[0])]
-        # we cannot create `xla_device` here because processes have not been spawned yet (this is called in the
-        # accelerator connector init). However, there doesn't seem to be a problem with instantiating `torch.device`.
-        # it will be replaced with `xla_device` (also a torch.device`, but with extra logic) in the strategy
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
+
+        return EnterpriseXLAAccelerator.get_parallel_devices(devices)
 
     @staticmethod
     @override
@@ -71,16 +76,10 @@ def get_parallel_devices(devices: Union[int, list[int]]) -> list[torch.device]:
     @functools.lru_cache(maxsize=1)
     def auto_device_count() -> int:
         """Get the devices when set to auto."""
-        if not _XLA_AVAILABLE:
-            return 0
-        if _XLA_GREATER_EQUAL_2_1:
-            from torch_xla._internal import tpu
-
-            return tpu.num_available_devices()
-        from torch_xla.experimental import tpu
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.accelerators.xla import XLAAccelerator as EnterpriseXLAAccelerator
 
-        device_count_on_version = {2: 8, 3: 8, 4: 4}
-        return device_count_on_version.get(tpu.version(), 8)
+        return EnterpriseXLAAccelerator.auto_device_count()
 
     @staticmethod
     @override
@@ -92,6 +91,9 @@ def is_available() -> bool:
             # XLA may raise these exceptions if it's not properly configured. This needs to be avoided for the cases
             # when `torch_xla` is imported but not used
             return False
+        except ModuleNotFoundError as e:
+            warnings.warn(str(e))
+            return False
 
     @staticmethod
     @override
@@ -106,74 +108,3 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
             cls,
             description=cls.__name__,
         )
-
-
-# PJRT support requires this minimum version
-_XLA_AVAILABLE = RequirementCache("torch_xla>=1.13", "torch_xla")
-_XLA_GREATER_EQUAL_2_1 = RequirementCache("torch_xla>=2.1")
-_XLA_GREATER_EQUAL_2_5 = RequirementCache("torch_xla>=2.5")
-
-
-def _using_pjrt() -> bool:
-    # `using_pjrt` is removed in torch_xla 2.5
-    if _XLA_GREATER_EQUAL_2_5:
-        from torch_xla import runtime as xr
-
-        return xr.device_type() is not None
-    # delete me when torch_xla 2.2 is the min supported version, where XRT support has been dropped.
-    if _XLA_GREATER_EQUAL_2_1:
-        from torch_xla import runtime as xr
-
-        return xr.using_pjrt()
-
-    from torch_xla.experimental import pjrt
-
-    return pjrt.using_pjrt()
-
-
-def _parse_tpu_devices(devices: Union[int, str, list[int]]) -> Union[int, list[int]]:
-    """Parses the TPU devices given in the format as accepted by the
-    :class:`~lightning.pytorch.trainer.trainer.Trainer` and :class:`~lightning.fabric.Fabric`.
-
-    Args:
-        devices: An int of 1 or string '1' indicates that 1 core with multi-processing should be used
-            An int 8 or string '8' indicates that all 8 cores with multi-processing should be used
-            A single element list of int or string can be used to indicate the specific TPU core to use.
-
-    Returns:
-        A list of tpu cores to be used.
-
-    """
-    _check_data_type(devices)
-    if isinstance(devices, str):
-        devices = _parse_tpu_devices_str(devices)
-    _check_tpu_devices_valid(devices)
-    return devices
-
-
-def _check_tpu_devices_valid(devices: object) -> None:
-    device_count = XLAAccelerator.auto_device_count()
-    if (
-        # support number of devices
-        isinstance(devices, int)
-        and devices in {1, device_count}
-        # support picking a specific device
-        or isinstance(devices, (list, tuple))
-        and len(devices) == 1
-        and 0 <= devices[0] <= device_count - 1
-    ):
-        return
-    raise ValueError(
-        f"`devices` can only be 'auto', 1, {device_count} or [<0-{device_count - 1}>] for TPUs. Got {devices!r}"
-    )
-
-
-def _parse_tpu_devices_str(devices: str) -> Union[int, list[int]]:
-    devices = devices.strip()
-    try:
-        return int(devices)
-    except ValueError:
-        try:
-            return [int(x.strip()) for x in devices.split(",") if len(x) > 0]
-        except ValueError:
-            raise ValueError(f"Could not parse the selected TPU devices: {devices!r}")
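
After this change the open-source `XLAAccelerator` is a thin facade over the enterprise implementation; construction and the static helpers go through `_raise_enterprise_not_available()`. A hedged usage sketch (the class and method names come from the diff above; the exact exception raised when the enterprise package is missing is an assumption):

```python
# Hedged usage sketch of the rerouted XLAAccelerator; the exception types caught below are assumptions.
from lightning.fabric.accelerators.xla import XLAAccelerator

if XLAAccelerator.is_available():
    # These static helpers now forward to pytorch_lightning_enterprise.accelerators.xla.XLAAccelerator.
    devices = XLAAccelerator.parse_devices("8")
    parallel_devices = XLAAccelerator.get_parallel_devices(devices)
    print(f"Running on {len(parallel_devices)} XLA device(s)")
else:
    try:
        XLAAccelerator()
    except (ModuleNotFoundError, ImportError, RuntimeError) as err:
        # Without torch_xla and pytorch-lightning-enterprise, construction reports what is missing.
        print(f"XLA support unavailable: {err}")
```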

src/lightning/fabric/plugins/environments/kubeflow.py

Lines changed: 18 additions & 10 deletions

@@ -13,11 +13,11 @@
 # limitations under the License.
 
 import logging
-import os
 
 from typing_extensions import override
 
 from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment
+from lightning.fabric.utilities.imports import _raise_enterprise_not_available
 
 log = logging.getLogger(__name__)
 
@@ -33,20 +33,28 @@ class KubeflowEnvironment(ClusterEnvironment):
 
     """
 
+    def __init__(self) -> None:
+        _raise_enterprise_not_available()
+        from pytorch_lightning_enterprise.plugins.environments.kubeflow import (
+            KubeflowEnvironment as EnterpriseKubeflowEnvironment,
+        )
+
+        self.kubeflow_impl = EnterpriseKubeflowEnvironment()
+
     @property
     @override
     def creates_processes_externally(self) -> bool:
-        return True
+        return self.kubeflow_impl.creates_processes_externally
 
     @property
     @override
     def main_address(self) -> str:
-        return os.environ["MASTER_ADDR"]
+        return self.kubeflow_impl.main_address
 
     @property
     @override
    def main_port(self) -> int:
-        return int(os.environ["MASTER_PORT"])
+        return self.kubeflow_impl.main_port
 
     @staticmethod
     @override
@@ -55,24 +63,24 @@ def detect() -> bool:
 
     @override
     def world_size(self) -> int:
-        return int(os.environ["WORLD_SIZE"])
+        return self.kubeflow_impl.world_size()
 
     @override
     def set_world_size(self, size: int) -> None:
-        log.debug("KubeflowEnvironment.set_world_size was called, but setting world size is not allowed. Ignored.")
+        return self.kubeflow_impl.set_world_size(size)
 
     @override
     def global_rank(self) -> int:
-        return int(os.environ["RANK"])
+        return self.kubeflow_impl.global_rank()
 
     @override
     def set_global_rank(self, rank: int) -> None:
-        log.debug("KubeflowEnvironment.set_global_rank was called, but setting global rank is not allowed. Ignored.")
+        return self.kubeflow_impl.set_global_rank(rank)
 
     @override
     def local_rank(self) -> int:
-        return 0
+        return self.kubeflow_impl.local_rank()
 
     @override
     def node_rank(self) -> int:
-        return self.global_rank()
+        return self.kubeflow_impl.node_rank()
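
The `KubeflowEnvironment` facade now builds the enterprise implementation in `__init__` and forwards both properties and methods to it. A sketch of querying it inside a Kubeflow `PyTorchJob` pod, assuming the enterprise implementation keeps the env-var behavior of the removed open-source code (`MASTER_ADDR`, `MASTER_PORT`, `WORLD_SIZE`, `RANK`):

```python
# Hedged sketch: querying the rerouted KubeflowEnvironment. The env-var behavior is an
# assumption carried over from the removed open-source implementation shown above.
import os

from lightning.fabric.plugins.environments.kubeflow import KubeflowEnvironment

# Variables a Kubeflow PyTorchJob pod would normally inject (set here only for the demo).
os.environ.setdefault("MASTER_ADDR", "trainer-master-0")
os.environ.setdefault("MASTER_PORT", "23456")
os.environ.setdefault("WORLD_SIZE", "4")
os.environ.setdefault("RANK", "1")

env = KubeflowEnvironment()  # raises if pytorch-lightning-enterprise is not installed
print(env.main_address, env.main_port)      # forwarded properties
print(env.world_size(), env.global_rank())  # forwarded methods
```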
