We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1d86d00 commit 4ac7eeaCopy full SHA for 4ac7eea
src/transformers/utils/import_utils.py
@@ -1308,6 +1308,17 @@ def is_torch_fx_proxy(x):
1308
return False
1309
1310
1311
+def is_jax_jitting(x):
1312
+ if not hasattr(x, "jax"):
1313
+ return False
1314
+ try:
1315
+ import jax
1316
+
1317
+ return isinstance(x.jax(), jax.core.Tracer)
1318
+ except Exception:
1319
1320
1321
1322
def is_jit_tracing() -> bool:
1323
try:
1324
import torch
@@ -1333,6 +1344,7 @@ def is_tracing(tensor=None) -> bool:
1333
1344
_is_tracing = is_torchdynamo_compiling() or is_jit_tracing() or is_cuda_stream_capturing()
1334
1345
if tensor is not None:
1335
1346
_is_tracing |= is_torch_fx_proxy(tensor)
1347
+ _is_tracing |= is_jax_jitting(tensor)
1336
1348
return _is_tracing
1337
1349
1338
1350
0 commit comments