Commit 053b990

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

1 parent cc60700 commit 053b990

14 files changed: +37 −26 lines changed

docs/source-pytorch/accelerators/musa.rst (2 additions, 2 deletions; trailing whitespace stripped, text otherwise unchanged)

@@ -10,8 +10,8 @@ MUSA training (Advanced)
 
 MUSAAccelerator Overview
 --------------------
-torch_musa is an extended Python package based on PyTorch that enables full utilization of MooreThreads graphics cards'
-super computing power. Combined with PyTorch, users can take advantage of the strong power of MooreThreads graphics cards
+torch_musa is an extended Python package based on PyTorch that enables full utilization of MooreThreads graphics cards'
+super computing power. Combined with PyTorch, users can take advantage of the strong power of MooreThreads graphics cards
 through torch_musa.
 
 PyTorch Lightning automatically finds these weights and ties them after the modules are moved to the
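The change above is whitespace-only; for orientation, using the accelerator this page documents would look roughly like the sketch below. The registry name "musa" is an assumption, mirroring how "cuda" and "mps" are exposed; the exact string is not shown in this commit.

import lightning.pytorch as pl

# Assumption: the accelerator registers under the name "musa",
# mirroring "cuda" and "mps"; the exact string is not in this diff.
trainer = pl.Trainer(accelerator="musa", devices=1)

# "auto" goes through _select_auto_accelerator() (see the
# device_parser.py hunks below), which now also imports MUSAAccelerator.
trainer = pl.Trainer(accelerator="auto")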

src/lightning/fabric/accelerators/__init__.py (1 addition, 1 deletion)

@@ -16,9 +16,9 @@
 from lightning.fabric.accelerators.cpu import CPUAccelerator  # noqa: F401
 from lightning.fabric.accelerators.cuda import CUDAAccelerator, find_usable_cuda_devices  # noqa: F401
 from lightning.fabric.accelerators.mps import MPSAccelerator  # noqa: F401
+from lightning.fabric.accelerators.musa import MUSAAccelerator  # noqa: F401
 from lightning.fabric.accelerators.registry import _AcceleratorRegistry
 from lightning.fabric.accelerators.xla import XLAAccelerator  # noqa: F401
-from lightning.fabric.accelerators.musa import MUSAAccelerator  # noqa: F401
 from lightning.fabric.utilities.registry import _register_classes
 
 ACCELERATOR_REGISTRY = _AcceleratorRegistry()
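MUSAAccelerator is re-exported from lightning.fabric.accelerators (the # noqa: F401 marks an intentional re-export), so availability can be probed like any other backend. A minimal sketch, using only the is_available() classmethod that the _runif.py and device_parser.py hunks below confirm:

from lightning.fabric.accelerators import MUSAAccelerator

# Assumption: is_available() is False unless torch_musa is installed
# and a MooreThreads GPU is visible to the process.
if MUSAAccelerator.is_available():
    print("MooreThreads GPU detected")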

src/lightning/fabric/connector.py (1 addition, 1 deletion)

@@ -23,8 +23,8 @@
 from lightning.fabric.accelerators.accelerator import Accelerator
 from lightning.fabric.accelerators.cuda import CUDAAccelerator
 from lightning.fabric.accelerators.mps import MPSAccelerator
-from lightning.fabric.accelerators.xla import XLAAccelerator
 from lightning.fabric.accelerators.musa import MUSAAccelerator
+from lightning.fabric.accelerators.xla import XLAAccelerator
 from lightning.fabric.plugins import (
     BitsandbytesPrecision,
     CheckpointIO,

src/lightning/fabric/utilities/device_parser.py (15 additions, 6 deletions)

@@ -86,15 +86,18 @@ def _parse_gpu_ids(
     # We know the user requested GPUs therefore if some of the
     # requested GPUs are not available an exception is thrown.
     gpus = _normalize_parse_gpu_string_input(gpus)
-    gpus = _normalize_parse_gpu_input_to_list(gpus, include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa)
+    gpus = _normalize_parse_gpu_input_to_list(
+        gpus, include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa
+    )
     if not gpus:
         raise MisconfigurationException("GPUs requested but none are available.")
 
     if (
         torch.distributed.is_available()
         and torch.distributed.is_torchelastic_launched()
         and len(gpus) != 1
-        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa)) == 1
+        and len(_get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa))
+        == 1
     ):
         # Omit sanity check on torchelastic because by default it shows one visible GPU per process
         return gpus

@@ -115,7 +118,9 @@ def _normalize_parse_gpu_string_input(s: Union[int, str, list[int]]) -> Union[in
         return int(s.strip())
 
 
-def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps: bool = False, include_musa: bool = False) -> list[int]:
+def _sanitize_gpu_ids(
+    gpus: list[int], include_cuda: bool = False, include_mps: bool = False, include_musa: bool = False
+) -> list[int]:
     """Checks that each of the GPUs in the list is actually available. Raises a MisconfigurationException if any of the
     GPUs is not available.
 

@@ -132,7 +137,9 @@ def _sanitize_gpu_ids(gpus: list[int], include_cuda: bool = False, include_mps:
     """
     if sum((include_cuda, include_mps, include_musa)) == 0:
         raise ValueError("At least one gpu type should be specified!")
-    all_available_gpus = _get_all_available_gpus(include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa)
+    all_available_gpus = _get_all_available_gpus(
+        include_cuda=include_cuda, include_mps=include_mps, include_musa=include_musa
+    )
     for gpu in gpus:
         if gpu not in all_available_gpus:
             raise MisconfigurationException(

@@ -157,7 +164,9 @@ def _normalize_parse_gpu_input_to_list(
     return list(range(gpus))
 
 
-def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False, include_musa: bool = False) -> list[int]:
+def _get_all_available_gpus(
+    include_cuda: bool = False, include_mps: bool = False, include_musa: bool = False
+) -> list[int]:
     """
     Returns:
         A list of all available GPUs

@@ -214,8 +223,8 @@ def _select_auto_accelerator() -> str:
     """Choose the accelerator type (str) based on availability."""
     from lightning.fabric.accelerators.cuda import CUDAAccelerator
     from lightning.fabric.accelerators.mps import MPSAccelerator
-    from lightning.fabric.accelerators.xla import XLAAccelerator
     from lightning.fabric.accelerators.musa import MUSAAccelerator
+    from lightning.fabric.accelerators.xla import XLAAccelerator
 
     if XLAAccelerator.is_available():
         return "tpu"

src/lightning/fabric/utilities/testing/_runif.py (1 addition, 1 deletion; the removed line was blank except for trailing whitespace)

@@ -110,7 +110,7 @@ def _runif_reasons(
         reasons.append("MPS")
     elif not mps and MPSAccelerator.is_available():
         reasons.append("not MPS")
-
+
     if musa is not None:
         if musa and not MUSAAccelerator.is_available():
             reasons.append("MUSA")

src/lightning/pytorch/accelerators/__init__.py (1 addition, 1 deletion)

@@ -30,8 +30,8 @@
 from lightning.pytorch.accelerators.cpu import CPUAccelerator
 from lightning.pytorch.accelerators.cuda import CUDAAccelerator
 from lightning.pytorch.accelerators.mps import MPSAccelerator
-from lightning.pytorch.accelerators.xla import XLAAccelerator
 from lightning.pytorch.accelerators.musa import MUSAAccelerator
+from lightning.pytorch.accelerators.xla import XLAAccelerator
 
 AcceleratorRegistry = _AcceleratorRegistry()
 _register_classes(AcceleratorRegistry, "register_accelerators", sys.modules[__name__], Accelerator)

src/lightning/pytorch/accelerators/musa.py (0 additions, 2 deletions; the shutil and subprocess imports were unused)

@@ -13,8 +13,6 @@
 # limitations under the License.
 import logging
 import os
-import shutil
-import subprocess
 from typing import Any, Optional, Union
 
 import torch

src/lightning/pytorch/core/saving.py (1 addition, 1 deletion)

@@ -34,7 +34,7 @@
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.fabric.utilities.data import AttributeDict
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
-from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator, MUSAAccelerator
+from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, MUSAAccelerator, XLAAccelerator
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
 from lightning.pytorch.utilities.migration import pl_legacy_patch
 from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint

src/lightning/pytorch/trainer/connectors/accelerator_connector.py (1 addition, 1 deletion)

@@ -35,8 +35,8 @@
 from lightning.pytorch.accelerators.accelerator import Accelerator
 from lightning.pytorch.accelerators.cuda import CUDAAccelerator
 from lightning.pytorch.accelerators.mps import MPSAccelerator
-from lightning.pytorch.accelerators.xla import XLAAccelerator
 from lightning.pytorch.accelerators.musa import MUSAAccelerator
+from lightning.pytorch.accelerators.xla import XLAAccelerator
 from lightning.pytorch.plugins import (
     _PLUGIN_INPUT,
     BitsandbytesPrecision,

src/lightning/pytorch/trainer/setup.py (1 addition, 1 deletion)

@@ -18,7 +18,7 @@
 
 import lightning.pytorch as pl
 from lightning.fabric.utilities.warnings import PossibleUserWarning
-from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, XLAAccelerator, MUSAAccelerator
+from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator, MUSAAccelerator, XLAAccelerator
 from lightning.pytorch.loggers.logger import DummyLogger
 from lightning.pytorch.profilers import (
     AdvancedProfiler,
