Skip to content

Commit 4e56c0b

Browse files
committed
Add docstring
1 parent 266adfa commit 4e56c0b

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

src/transformers/utils/import_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,6 +1293,23 @@ def is_torch_fx_proxy(x):
12931293

12941294

12951295
def is_jax_jitting(x):
1296+
"""returns True if we are inside of `jax.jit` context, False otherwise.
1297+
1298+
When a torch model is being compiled with `jax.jit` using torchax,
1299+
the tensor that goes through the model would be an instance of
1300+
`torchax.tensor.Tensor`, which is a tensor subclass. This tensor has
1301+
a `jax` method to return the inner Jax array
1302+
(https://github.com/google/torchax/blob/13ce870a1d9adb2430333c27bb623469e3aea34e/torchax/tensor.py#L134).
1303+
Here we use ducktyping to detect if the inner jax array is a jax Tracer
1304+
then we are in tracing context. (See more at: https://github.com/jax-ml/jax/discussions/9241)
1305+
1306+
Args:
1307+
x: torch.Tensor
1308+
1309+
Returns:
1310+
bool: whether we are inside of jax jit tracing.
1311+
"""
1312+
12961313
if not hasattr(x, "jax"):
12971314
return False
12981315
try:

0 commit comments

Comments
 (0)