Skip to content

Conversation

@qihqi
Copy link
Contributor

@qihqi qihqi commented Dec 3, 2025

What does this PR do?

transformers current recognizes when the PyTorch model is
being traced instead of executed, so it knows to not use untraceable ops (data-dependent ops).
However, currently it recognizes torchscript, fx trace, dynamo but does not recognize jax.jit via torchax.

This PR adds detecting jax.jit.

Fixes google/torchax#56

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [ x] Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@vasqu @ArthurZucker @Cyrilvallez

@qihqi qihqi force-pushed the qihqi-support-jax-jit branch from 9aa3291 to a041420 Compare December 3, 2025 23:26
Copy link
Member

@Cyrilvallez Cyrilvallez left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Would just like to add it to the docstring as well if you don't mind! 🤗

Copy link
Member

@Rocketknight1 Rocketknight1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, change looks good and after some internal discussion we're happy to approve!

Copy link
Contributor

@tengomucho tengomucho left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution @qihqi !

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Rocketknight1
Copy link
Member

CI issues are unrelated and seem to be transient connection timeouts - @cyril do you want to force merge, or will we just wait and rebase?

@qihqi qihqi force-pushed the qihqi-support-jax-jit branch from 4e56c0b to 544c760 Compare December 5, 2025 18:44
@qihqi
Copy link
Contributor Author

qihqi commented Dec 6, 2025

LGTM! Would just like to add it to the docstring as well if you don't mind! 🤗

done.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

jax.export fails for (Distil)BERT transformers models due to missing tracing signals

5 participants