Skip to content

Commit 544c760

Browse files
committed
Add docstring
1 parent 4ac7eea commit 544c760

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

src/transformers/utils/import_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1309,6 +1309,23 @@ def is_torch_fx_proxy(x):
13091309

13101310

13111311
def is_jax_jitting(x):
1312+
"""returns True if we are inside of `jax.jit` context, False otherwise.
1313+
1314+
When a torch model is being compiled with `jax.jit` using torchax,
1315+
the tensor that goes through the model would be an instance of
1316+
`torchax.tensor.Tensor`, which is a tensor subclass. This tensor has
1317+
a `jax` method to return the inner Jax array
1318+
(https://github.com/google/torchax/blob/13ce870a1d9adb2430333c27bb623469e3aea34e/torchax/tensor.py#L134).
1319+
Here we use ducktyping to detect if the inner jax array is a jax Tracer
1320+
then we are in tracing context. (See more at: https://github.com/jax-ml/jax/discussions/9241)
1321+
1322+
Args:
1323+
x: torch.Tensor
1324+
1325+
Returns:
1326+
bool: whether we are inside of jax jit tracing.
1327+
"""
1328+
13121329
if not hasattr(x, "jax"):
13131330
return False
13141331
try:

0 commit comments

Comments
 (0)