Skip to content

Commit a041420

Browse files
committed
Let transformers know when a model is being traced via jax.jit
Move check earler
1 parent a48d68c commit a041420

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

src/transformers/utils/import_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1292,6 +1292,17 @@ def is_torch_fx_proxy(x):
12921292
return False
12931293

12941294

1295+
def is_jax_jitting(x):
1296+
if not hasattr(x, "jax"):
1297+
return False
1298+
try:
1299+
import jax
1300+
1301+
return isinstance(x.jax(), jax.core.Tracer)
1302+
except Exception:
1303+
return False
1304+
1305+
12951306
def is_jit_tracing() -> bool:
12961307
try:
12971308
import torch
@@ -1317,6 +1328,7 @@ def is_tracing(tensor=None) -> bool:
13171328
_is_tracing = is_torchdynamo_compiling() or is_jit_tracing() or is_cuda_stream_capturing()
13181329
if tensor is not None:
13191330
_is_tracing |= is_torch_fx_proxy(tensor)
1331+
_is_tracing |= is_jax_jitting(tensor)
13201332
return _is_tracing
13211333

13221334

0 commit comments

Comments
 (0)