From 4ac7eea516a00966a710720fcf66a6566d71a296 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Wed, 3 Dec 2025 15:25:59 -0800 Subject: [PATCH 1/2] Let transformers know when a model is being traced via jax.jit Move check earler --- src/transformers/utils/import_utils.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 3c10269c44a5..e2e817af3e86 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1308,6 +1308,17 @@ def is_torch_fx_proxy(x): return False +def is_jax_jitting(x): + if not hasattr(x, "jax"): + return False + try: + import jax + + return isinstance(x.jax(), jax.core.Tracer) + except Exception: + return False + + def is_jit_tracing() -> bool: try: import torch @@ -1333,6 +1344,7 @@ def is_tracing(tensor=None) -> bool: _is_tracing = is_torchdynamo_compiling() or is_jit_tracing() or is_cuda_stream_capturing() if tensor is not None: _is_tracing |= is_torch_fx_proxy(tensor) + _is_tracing |= is_jax_jitting(tensor) return _is_tracing From 544c7609cfacb0c94f449b2ebb939298a32ac1b7 Mon Sep 17 00:00:00 2001 From: Han Qi Date: Fri, 5 Dec 2025 10:42:31 -0800 Subject: [PATCH 2/2] Add docstring --- src/transformers/utils/import_utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index e2e817af3e86..35581ed28a24 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1309,6 +1309,23 @@ def is_torch_fx_proxy(x): def is_jax_jitting(x): + """returns True if we are inside of `jax.jit` context, False otherwise. + + When a torch model is being compiled with `jax.jit` using torchax, + the tensor that goes through the model would be an instance of + `torchax.tensor.Tensor`, which is a tensor subclass. This tensor has + a `jax` method to return the inner Jax array + (https://github.com/google/torchax/blob/13ce870a1d9adb2430333c27bb623469e3aea34e/torchax/tensor.py#L134). + Here we use ducktyping to detect if the inner jax array is a jax Tracer + then we are in tracing context. (See more at: https://github.com/jax-ml/jax/discussions/9241) + + Args: + x: torch.Tensor + + Returns: + bool: whether we are inside of jax jit tracing. + """ + if not hasattr(x, "jax"): return False try: