From 4095693fc781f195d8c3823020445468951e09d4 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 1 Dec 2025 12:46:44 +0100 Subject: [PATCH 1/7] deprecate method --- src/lightning/pytorch/core/module.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 37b07f025f8e9..ce2e336bf80c0 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -64,7 +64,7 @@ from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_6, _TORCHMETRICS_GREATER_EQUAL_0_9_1 from lightning.pytorch.utilities.model_helpers import _restricted_classmethod -from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn +from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_deprecation, rank_zero_warn from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature from lightning.pytorch.utilities.types import ( _METRIC, @@ -1498,6 +1498,11 @@ def to_torchscript( scripted you should override this method. In case you want to return multiple modules, we recommend using a dictionary. + .. deprecated:: + ``LightningModule.to_torchscript`` has been deprecated in v2.7 and will be removed in v2.8. + TorchScript is deprecated in PyTorch. Use ``torch.export.export()`` for model exporting instead. + See https://pytorch.org/docs/stable/export.html for more information. + Args: file_path: Path where to save the torchscript. Default: None (no file saved). method: Whether to use TorchScript's script or trace method. Default: 'script' @@ -1536,6 +1541,11 @@ def forward(self, x): defined or not. """ + rank_zero_deprecation( + "`LightningModule.to_torchscript` has been deprecated in v2.5 and will be removed in v2.7. " + "TorchScript is deprecated in PyTorch. Use `torch.export.export()` for model exporting instead. " + "See https://pytorch.org/docs/stable/export.html for more information." + ) mode = self.training if method == "script": From 6b9db2cf89d2e69623741002aa32785570b6a4b7 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 1 Dec 2025 12:48:17 +0100 Subject: [PATCH 2/7] deprecate method --- src/lightning/pytorch/core/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index ce2e336bf80c0..bae7f876c8211 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1542,7 +1542,7 @@ def forward(self, x): """ rank_zero_deprecation( - "`LightningModule.to_torchscript` has been deprecated in v2.5 and will be removed in v2.7. " + "`LightningModule.to_torchscript` has been deprecated in v2.7 and will be removed in v2.8. " "TorchScript is deprecated in PyTorch. Use `torch.export.export()` for model exporting instead. " "See https://pytorch.org/docs/stable/export.html for more information." ) From 8e42016b389844f391ee11d8308aafe63bed5b9e Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 1 Dec 2025 12:50:37 +0100 Subject: [PATCH 3/7] add deprecation to tests --- tests/tests_pytorch/helpers/test_models.py | 3 +- .../tests_pytorch/models/test_torchscript.py | 54 ++++++++++++++----- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/tests/tests_pytorch/helpers/test_models.py b/tests/tests_pytorch/helpers/test_models.py index 721641ae8343a..2c25289ea1dde 100644 --- a/tests/tests_pytorch/helpers/test_models.py +++ b/tests/tests_pytorch/helpers/test_models.py @@ -46,7 +46,8 @@ def test_models(tmp_path, data_class, model_class): if dm is not None: trainer.test(model, datamodule=dm) - model.to_torchscript() + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + model.to_torchscript() if data_class: model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample) diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py index 29f251044c0b5..9657d0bcc6534 100644 --- a/tests/tests_pytorch/models/test_torchscript.py +++ b/tests/tests_pytorch/models/test_torchscript.py @@ -21,6 +21,7 @@ from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.rank_zero import LightningDeprecationWarning from lightning.pytorch.core.module import LightningModule from lightning.pytorch.demos.boring_classes import BoringModel from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN @@ -36,7 +37,8 @@ def test_torchscript_input_output(modelclass): if isinstance(model, BoringModel): model.example_input_array = torch.randn(5, 32) - script = model.to_torchscript() + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript() assert isinstance(script, torch.jit.ScriptModule) model.eval() @@ -59,7 +61,8 @@ def test_torchscript_example_input_output_trace(modelclass): if isinstance(model, BoringModel): model.example_input_array = torch.randn(5, 32) - script = model.to_torchscript(method="trace") + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript(method="trace") assert isinstance(script, torch.jit.ScriptModule) model.eval() @@ -74,7 +77,8 @@ def test_torchscript_input_output_trace(): """Test that traced LightningModule forward works with example_inputs.""" model = BoringModel() example_inputs = torch.randn(1, 32) - script = model.to_torchscript(example_inputs=example_inputs, method="trace") + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript(example_inputs=example_inputs, method="trace") assert isinstance(script, torch.jit.ScriptModule) model.eval() @@ -99,7 +103,8 @@ def test_torchscript_device(device_str): model = BoringModel().to(device) model.example_input_array = torch.randn(5, 32) - script = model.to_torchscript() + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript() assert next(script.parameters()).device == device script_output = script(model.example_input_array.to(device)) assert script_output.device == device @@ -121,7 +126,8 @@ def test_torchscript_device_with_check_inputs(device_str): check_inputs = torch.rand(5, 32) - script = model.to_torchscript(method="trace", check_inputs=check_inputs) + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript(method="trace", check_inputs=check_inputs) assert isinstance(script, torch.jit.ScriptModule) @@ -129,11 +135,13 @@ def test_torchscript_retain_training_state(): """Test that torchscript export does not alter the training mode of original model.""" model = BoringModel() model.train(True) - script = model.to_torchscript() + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript() assert model.training assert not script.training model.train(False) - _ = model.to_torchscript() + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + _ = model.to_torchscript() assert not model.training assert not script.training @@ -142,7 +150,8 @@ def test_torchscript_retain_training_state(): def test_torchscript_properties(modelclass): """Test that scripted LightningModule has unnecessary methods removed.""" model = modelclass() - script = model.to_torchscript() + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript() assert not hasattr(model, "batch_size") or hasattr(script, "batch_size") assert not hasattr(model, "learning_rate") or hasattr(script, "learning_rate") assert not callable(getattr(script, "training_step", None)) @@ -153,7 +162,8 @@ def test_torchscript_save_load(tmp_path, modelclass): """Test that scripted LightningModule is correctly saved and can be loaded.""" model = modelclass() output_file = str(tmp_path / "model.pt") - script = model.to_torchscript(file_path=output_file) + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript(file_path=output_file) loaded_script = torch.jit.load(output_file) assert torch.allclose(next(script.parameters()), next(loaded_script.parameters())) @@ -170,7 +180,8 @@ class DummyFileSystem(LocalFileSystem): ... model = modelclass() output_file = os.path.join(_DUMMY_PRFEIX, _PREFIX_SEPARATOR, tmp_path, "model.pt") - script = model.to_torchscript(file_path=output_file) + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript(file_path=output_file) fs = get_filesystem(output_file) with fs.open(output_file, "rb") as f: @@ -184,7 +195,10 @@ def test_torchcript_invalid_method(): model = BoringModel() model.train(True) - with pytest.raises(ValueError, match="only supports 'script' or 'trace'"): + with ( + pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"), + pytest.raises(ValueError, match="only supports 'script' or 'trace'"), + ): model.to_torchscript(method="temp") @@ -193,7 +207,10 @@ def test_torchscript_with_no_input(): model = BoringModel() model.example_input_array = None - with pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"): + with ( + pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"), + pytest.raises(ValueError, match="requires either `example_inputs` or `model.example_input_array`"), + ): model.to_torchscript(method="trace") @@ -224,6 +241,17 @@ def forward(self, inputs): lm = Parent() assert not lm._jit_is_scripting - script = lm.to_torchscript(method="script") + with pytest.deprecated_call(match="has been deprecated in v2.7 and will be removed in v2.8"): + script = lm.to_torchscript(method="script") assert not lm._jit_is_scripting assert isinstance(script, torch.jit.RecursiveScriptModule) + + +def test_to_torchscript_deprecation(): + """Test that to_torchscript raises a deprecation warning.""" + model = BoringModel() + model.example_input_array = torch.randn(5, 32) + + with pytest.warns(LightningDeprecationWarning, match="has been deprecated in v2.7 and will be removed in v2.8"): + script = model.to_torchscript() + assert isinstance(script, torch.jit.ScriptModule) From adbcae7a907b3ebce46b58f3c24dbd9e63615801 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 1 Dec 2025 12:52:26 +0100 Subject: [PATCH 4/7] remove example from readme --- README.md | 14 ++++++++++++-- src/pytorch_lightning/README.md | 8 -------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 35afe7477f343..f221f7d09f843 100644 --- a/README.md +++ b/README.md @@ -324,12 +324,22 @@ trainer = Trainer(callbacks=[checkpointing])
- Export to torchscript (JIT) (production use) + Export to torchscript (JIT) (production use) - DEPRECATED + +> **⚠️ Deprecated**: `to_torchscript()` is deprecated in PyTorch Lightning v2.5 and will be removed in v2.7. +> TorchScript is deprecated in PyTorch. Use `torch.export.export()` instead. +> See [PyTorch Export Documentation](https://pytorch.org/docs/stable/export.html) for more information. ```python -# torchscript +# torchscript (deprecated) autoencoder = LitAutoEncoder() torch.jit.save(autoencoder.to_torchscript(), "model.pt") + +# Recommended alternative using torch.export +import torch +autoencoder = LitAutoEncoder() +exported_program = torch.export.export(autoencoder, (torch.randn(1, 64),)) +torch.export.save(exported_program, "model.pt2") ```
diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md index 86176359f6231..f4e6131a2fbef 100644 --- a/src/pytorch_lightning/README.md +++ b/src/pytorch_lightning/README.md @@ -271,14 +271,6 @@ checkpointing = ModelCheckpoint(monitor="val_loss") trainer = Trainer(callbacks=[checkpointing]) ``` -Export to torchscript (JIT) (production use) - -```python -# torchscript -autoencoder = LitAutoEncoder() -torch.jit.save(autoencoder.to_torchscript(), "model.pt") -``` - Export to ONNX (production use) ```python From c0e1f195cef8773f57756a21fbf40a0f8350b267 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Mon, 1 Dec 2025 12:52:39 +0100 Subject: [PATCH 5/7] remove example from readme --- README.md | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/README.md b/README.md index f221f7d09f843..e84c025406ba4 100644 --- a/README.md +++ b/README.md @@ -323,26 +323,6 @@ trainer = Trainer(callbacks=[checkpointing]) -
- Export to torchscript (JIT) (production use) - DEPRECATED - -> **⚠️ Deprecated**: `to_torchscript()` is deprecated in PyTorch Lightning v2.5 and will be removed in v2.7. -> TorchScript is deprecated in PyTorch. Use `torch.export.export()` instead. -> See [PyTorch Export Documentation](https://pytorch.org/docs/stable/export.html) for more information. - -```python -# torchscript (deprecated) -autoencoder = LitAutoEncoder() -torch.jit.save(autoencoder.to_torchscript(), "model.pt") - -# Recommended alternative using torch.export -import torch -autoencoder = LitAutoEncoder() -exported_program = torch.export.export(autoencoder, (torch.randn(1, 64),)) -torch.export.save(exported_program, "model.pt2") -``` - -
Export to ONNX (production use) From 4fc3c63c56f2693d726eb0d0aa1e38fd9b7f8af0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 2 Dec 2025 09:12:51 +0100 Subject: [PATCH 6/7] changelog --- src/lightning/pytorch/CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 3558fe3cadc66..b117a524e3e47 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -16,6 +16,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - +### Deprecated + +- Deprecated `to_torchscript` method due to deprecation of TorchScript in PyTorch ([#21397](https://github.com/Lightning-AI/pytorch-lightning/pull/21397)) + ### Removed - From feee20f4a4b2a68f918a4e3f04c897602d58c027 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 9 Dec 2025 08:05:17 +0100 Subject: [PATCH 7/7] remove readme changes --- README.md | 10 ++++++++++ src/pytorch_lightning/README.md | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/README.md b/README.md index e84c025406ba4..35afe7477f343 100644 --- a/README.md +++ b/README.md @@ -323,6 +323,16 @@ trainer = Trainer(callbacks=[checkpointing])
+
+ Export to torchscript (JIT) (production use) + +```python +# torchscript +autoencoder = LitAutoEncoder() +torch.jit.save(autoencoder.to_torchscript(), "model.pt") +``` + +
Export to ONNX (production use) diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md index f4e6131a2fbef..86176359f6231 100644 --- a/src/pytorch_lightning/README.md +++ b/src/pytorch_lightning/README.md @@ -271,6 +271,14 @@ checkpointing = ModelCheckpoint(monitor="val_loss") trainer = Trainer(callbacks=[checkpointing]) ``` +Export to torchscript (JIT) (production use) + +```python +# torchscript +autoencoder = LitAutoEncoder() +torch.jit.save(autoencoder.to_torchscript(), "model.pt") +``` + Export to ONNX (production use) ```python