diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 3558fe3cadc66..b117a524e3e47 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -16,6 +16,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - +### Deprecated + +- Deprecated `to_torchscript` method due to deprecation of TorchScript in PyTorch ([#21397](https://github.com/Lightning-AI/pytorch-lightning/pull/21397)) + ### Removed - diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 37b07f025f8e9..bae7f876c8211 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -64,7 +64,7 @@ from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_6, _TORCHMETRICS_GREATER_EQUAL_0_9_1 from lightning.pytorch.utilities.model_helpers import _restricted_classmethod -from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn +from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_deprecation, rank_zero_warn from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature from lightning.pytorch.utilities.types import ( _METRIC, @@ -1498,6 +1498,11 @@ def to_torchscript( scripted you should override this method. In case you want to return multiple modules, we recommend using a dictionary. + .. deprecated:: + ``LightningModule.to_torchscript`` has been deprecated in v2.7 and will be removed in v2.8. + TorchScript is deprecated in PyTorch. Use ``torch.export.export()`` for model exporting instead. + See https://pytorch.org/docs/stable/export.html for more information. + Args: file_path: Path where to save the torchscript. Default: None (no file saved). method: Whether to use TorchScript's script or trace method. Default: 'script' @@ -1536,6 +1541,11 @@ def forward(self, x): defined or not. """ + rank_zero_deprecation( + "`LightningModule.to_torchscript` has been deprecated in v2.7 and will be removed in v2.8. " + "TorchScript is deprecated in PyTorch. Use `torch.export.export()` for model exporting instead. " + "See https://pytorch.org/docs/stable/export.html for more information." + ) mode = self.training if method == "script": diff --git a/tests/tests_pytorch/helpers/test_models.py b/tests/tests_pytorch/helpers/test_models.py index 721641ae8343a..2c25289ea1dde 100644 --- a/tests/tests_pytorch/helpers/test_models.py +++ b/tests/tests_pytorch/helpers/test_models.py @@ -46,7 +46,8 @@ def test_models(tmp_path, data_class, model_class): if dm is not None: trainer.test(model, datamodule=dm) - model.to_torchscript() + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + model.to_torchscript() if data_class: model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample) diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py index 29f251044c0b5..9657d0bcc6534 100644 --- a/tests/tests_pytorch/models/test_torchscript.py +++ b/tests/tests_pytorch/models/test_torchscript.py @@ -21,6 +21,7 @@ from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.rank_zero import LightningDeprecationWarning from lightning.pytorch.core.module import LightningModule from lightning.pytorch.demos.boring_classes import BoringModel from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN @@ -36,7 +37,8 @@ def test_torchscript_input_output(modelclass): if isinstance(model, BoringModel): model.example_input_array = torch.randn(5, 32) - script = model.to_torchscript() + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript() assert isinstance(script, torch.jit.ScriptModule) model.eval() @@ -59,7 +61,8 @@ def test_torchscript_example_input_output_trace(modelclass): if isinstance(model, BoringModel): model.example_input_array = torch.randn(5, 32) - script = model.to_torchscript(method="trace") + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript(method="trace") assert isinstance(script, torch.jit.ScriptModule) model.eval() @@ -74,7 +77,8 @@ def test_torchscript_input_output_trace(): """Test that traced LightningModule forward works with example_inputs.""" model = BoringModel() example_inputs = torch.randn(1, 32) - script = model.to_torchscript(example_inputs=example_inputs, method="trace") + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript(example_inputs=example_inputs, method="trace") assert isinstance(script, torch.jit.ScriptModule) model.eval() @@ -99,7 +103,8 @@ def test_torchscript_device(device_str): model = BoringModel().to(device) model.example_input_array = torch.randn(5, 32) - script = model.to_torchscript() + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript() assert next(script.parameters()).device == device script_output = script(model.example_input_array.to(device)) assert script_output.device == device @@ -121,7 +126,8 @@ def test_torchscript_device_with_check_inputs(device_str): check_inputs = torch.rand(5, 32) - script = model.to_torchscript(method="trace", check_inputs=check_inputs) + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript(method="trace", check_inputs=check_inputs) assert isinstance(script, torch.jit.ScriptModule) @@ -129,11 +135,13 @@ def test_torchscript_retain_training_state(): """Test that torchscript export does not alter the training mode of original model.""" model = BoringModel() model.train(True) - script = model.to_torchscript() + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript() assert model.training assert not script.training model.train(False) - _ = model.to_torchscript() + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + _ = model.to_torchscript() assert not model.training assert not script.training @@ -142,7 +150,8 @@ def test_torchscript_retain_training_state(): def test_torchscript_properties(modelclass): """Test that scripted LightningModule has unnecessary methods removed.""" model = modelclass() - script = model.to_torchscript() + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript() assert not hasattr(model, "batch_size") or hasattr(script, "batch_size") assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate") assert not callable(getattr(script, "training_step", None)) @@ -153,7 +162,8 @@ def test_torchscript_save_load(tmp_path, modelclass): """Test that scripted LightningModule is correctly saved and can be loaded.""" model = modelclass() output_file = str(tmp_path / "model.pt") - script = model.to_torchscript(file_path=output_file) + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript(file_path=output_file) loaded_script = torch.jit.load(output_file) assert torch.allclose(next(script.parameters()), next(loaded_script.parameters())) @@ -170,7 +180,8 @@ class DummyFileSystem(LocalFileSystem): ... model = modelclass() output_file = os.path.join(_DUMMY_PRFEIX, _PREFIX_SEPARATOR, tmp_path, "model.pt") - script = model.to_torchscript(file_path=output_file) + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript(file_path=output_file) fs = get_filesystem(output_file) with fs.open(output_file, "rb") as f: @@ -184,7 +195,10 @@ def test_torchcript_invalid_method(): model = BoringModel() model.train(True) - with pytest.raises(ValueError, match="only supports 'script' or 'trace'"): + with ( + pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"), + pytest.raises(ValueError, match="only supports 'script' or 'trace'"), + ): model.to_torchscript(method="temp") @@ -193,7 +207,10 @@ def test_torchscript_with_no_input(): model = BoringModel() model.example_input_array = None - with pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"): + with ( + pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"), + pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"), + ): model.to_torchscript(method="trace") @@ -224,6 +241,17 @@ def forward(self, inputs): lm = Parent() assert not lm._jit_is_scripting - script = lm.to_torchscript(method="script") + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = lm.to_torchscript(method="script") assert not lm._jit_is_scripting assert isinstance(script, torch.jit.RecursiveScriptModule) + + +def test_to_torchscript_deprecation(): + """Test that to_torchscript raises a deprecation warning.""" + model = BoringModel() + model.example_input_array = torch.randn(5, 32) + + with pytest.warns(LightningDeprecationWarning, match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript() + assert isinstance(script, torch.jit.ScriptModule)