File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed
Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -1293,6 +1293,23 @@ def is_torch_fx_proxy(x):
12931293
12941294
12951295def is_jax_jitting (x ):
1296+ """returns True if we are inside of `jax.jit` context, False otherwise.
1297+
1298+ When a torch model is being compiled with `jax.jit` using torchax,
1299+ the tensor that goes through the model would be an instance of
1300+ `torchax.tensor.Tensor`, which is a tensor subclass. This tensor has
1301+ a `jax` method to return the inner Jax array
1302+ (https://github.com/google/torchax/blob/13ce870a1d9adb2430333c27bb623469e3aea34e/torchax/tensor.py#L134).
1303+ Here we use ducktyping to detect if the inner jax array is a jax Tracer
1304+ then we are in tracing context. (See more at: https://github.com/jax-ml/jax/discussions/9241)
1305+
1306+ Args:
1307+ x: torch.Tensor
1308+
1309+ Returns:
1310+ bool: whether we are inside of jax jit tracing.
1311+ """
1312+
12961313 if not hasattr (x , "jax" ):
12971314 return False
12981315 try :
You can’t perform that action at this time.
0 commit comments