Skip to content

Commit 4ac7eea

Browse files
committed
Let transformers know when a model is being traced via jax.jit
Move check earler
1 parent 1d86d00 commit 4ac7eea

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

src/transformers/utils/import_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,6 +1308,17 @@ def is_torch_fx_proxy(x):
13081308
return False
13091309

13101310

1311+
def is_jax_jitting(x):
1312+
if not hasattr(x, "jax"):
1313+
return False
1314+
try:
1315+
import jax
1316+
1317+
return isinstance(x.jax(), jax.core.Tracer)
1318+
except Exception:
1319+
return False
1320+
1321+
13111322
def is_jit_tracing() -> bool:
13121323
try:
13131324
import torch
@@ -1333,6 +1344,7 @@ def is_tracing(tensor=None) -> bool:
13331344
_is_tracing = is_torchdynamo_compiling() or is_jit_tracing() or is_cuda_stream_capturing()
13341345
if tensor is not None:
13351346
_is_tracing |= is_torch_fx_proxy(tensor)
1347+
_is_tracing |= is_jax_jitting(tensor)
13361348
return _is_tracing
13371349

13381350

0 commit comments

Comments
 (0)