diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 3c10269c44a5..35581ed28a24 100644
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -1308,6 +1308,34 @@ def is_torch_fx_proxy(x):
     return False
 
 
+def is_jax_jitting(x):
+    """Returns True if we are inside a `jax.jit` tracing context, False otherwise.
+
+    When a torch model is compiled with `jax.jit` through torchax, the tensors
+    that flow through the model are instances of
+    `torchax.tensor.Tensor`, a tensor subclass. This subclass exposes
+    a `jax` method that returns the inner JAX array
+    (https://github.com/google/torchax/blob/13ce870a1d9adb2430333c27bb623469e3aea34e/torchax/tensor.py#L134).
+    We use duck typing: if that inner JAX array is a JAX Tracer, we are in a
+    tracing context. (See https://github.com/jax-ml/jax/discussions/9241)
+
+    Args:
+        x: torch.Tensor
+
+    Returns:
+        bool: whether we are inside `jax.jit` tracing.
+    """
+
+    if not hasattr(x, "jax"):
+        return False
+    try:
+        import jax
+
+        return isinstance(x.jax(), jax.core.Tracer)
+    except Exception:
+        return False
+
+
 def is_jit_tracing() -> bool:
     try:
         import torch
@@ -1333,6 +1361,7 @@ def is_tracing(tensor=None) -> bool:
     _is_tracing = is_torchdynamo_compiling() or is_jit_tracing() or is_cuda_stream_capturing()
     if tensor is not None:
         _is_tracing |= is_torch_fx_proxy(tensor)
+        _is_tracing |= is_jax_jitting(tensor)
     return _is_tracing
 
 
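
To see the duck-typing check in action, here is a minimal, self-contained sketch. `FakeTorchaxTensor` is a hypothetical stand-in for `torchax.tensor.Tensor` that only mimics the `.jax()` accessor described in the docstring, so the example needs nothing beyond `jax` installed; with real torchax the check would work the same way on the actual tensor subclass.

```python
import jax
import jax.numpy as jnp


class FakeTorchaxTensor:
    """Hypothetical stand-in for torchax.tensor.Tensor: wraps a JAX array
    and exposes it through a .jax() method, mirroring the real subclass."""

    def __init__(self, jax_array):
        self._elem = jax_array

    def jax(self):
        return self._elem


def is_jax_jitting(x):
    # Same duck-typing check as the patched helper above.
    if not hasattr(x, "jax"):
        return False
    try:
        import jax

        return isinstance(x.jax(), jax.core.Tracer)
    except Exception:
        return False


# Outside jax.jit the wrapped value is a concrete array, not a Tracer.
eager = FakeTorchaxTensor(jnp.ones((2, 2)))
print(is_jax_jitting(eager))  # False


@jax.jit
def traced_fn(arr):
    # During jax.jit tracing `arr` is a jax.core.Tracer, so the wrapper
    # is reported as being inside a jitting context.
    wrapped = FakeTorchaxTensor(arr)
    print("inside jit trace:", is_jax_jitting(wrapped))  # True (printed at trace time)
    return arr + 1


traced_fn(jnp.ones((2, 2)))
```

Note that the print inside `traced_fn` fires only while the function is being traced, not on cached re-executions, which matches the intent of the helper: it answers "are we tracing right now?" for the tensor it is given.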