We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a48d68c commit 9aa3291Copy full SHA for 9aa3291
src/transformers/utils/import_utils.py
@@ -1292,6 +1292,17 @@ def is_torch_fx_proxy(x):
1292
return False
1293
1294
1295
+def is_jax_jitting(x):
1296
+ try:
1297
+ import jax
1298
+
1299
+ if not hasattr(x, "jax"):
1300
+ return False
1301
+ return isinstance(x.jax(), jax.core.Tracer)
1302
+ except Exception:
1303
1304
1305
1306
def is_jit_tracing() -> bool:
1307
try:
1308
import torch
@@ -1317,6 +1328,7 @@ def is_tracing(tensor=None) -> bool:
1317
1328
_is_tracing = is_torchdynamo_compiling() or is_jit_tracing() or is_cuda_stream_capturing()
1318
1329
if tensor is not None:
1319
1330
_is_tracing |= is_torch_fx_proxy(tensor)
1331
+ _is_tracing |= is_jax_jitting(tensor)
1320
1332
return _is_tracing
1321
1333
1322
1334
0 commit comments