File tree Expand file tree Collapse file tree 1 file changed +17
-0
lines changed
Expand file tree Collapse file tree 1 file changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -1309,6 +1309,23 @@ def is_torch_fx_proxy(x):
13091309
13101310
13111311def is_jax_jitting (x ):
1312+ """returns True if we are inside of `jax.jit` context, False otherwise.
1313+
1314+ When a torch model is being compiled with `jax.jit` using torchax,
1315+ the tensor that goes through the model would be an instance of
1316+ `torchax.tensor.Tensor`, which is a tensor subclass. This tensor has
1317+ a `jax` method to return the inner Jax array
1318+ (https://github.com/google/torchax/blob/13ce870a1d9adb2430333c27bb623469e3aea34e/torchax/tensor.py#L134).
1319+ Here we use ducktyping to detect if the inner jax array is a jax Tracer
1320+ then we are in tracing context. (See more at: https://github.com/jax-ml/jax/discussions/9241)
1321+
1322+ Args:
1323+ x: torch.Tensor
1324+
1325+ Returns:
1326+ bool: whether we are inside of jax jit tracing.
1327+ """
1328+
13121329 if not hasattr (x , "jax" ):
13131330 return False
13141331 try :
You can’t perform that action at this time.
0 commit comments