-
Notifications
You must be signed in to change notification settings - Fork 588
Solve pytorch-triton and triton package contention #2540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
…for jax Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
Signed-off-by: tdophung <tdophung@nvidia.com>
for more information, see https://pre-commit.ci
| num_ctas, # arg2: num_ctas (int) | ||
| compiled.metadata.shared, # arg3: shared_mem_bytes (int) | ||
| compiled.asm["ptx"], # arg4: ptx (str) | ||
| "", # arg5: ttir (str) - empty |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will soon be the same as main. as this change is made here in: #1921, to be merged. it is just in this PR so I can test triton calls locally with the nitghtly jax container without running into errors because of jax 0.8.2+
Greptile SummaryThis PR resolves package conflicts between
Confidence Score: 4/5
Important Files Changed
Sequence DiagramsequenceDiagram
participant User
participant setup.py
participant build_tools/jax.py
participant build_tools/pytorch.py
participant triton_extensions/utils.py
Note over User,triton_extensions/utils.py: Package Installation Flow
User->>setup.py: pip install transformer-engine[pytorch]
setup.py->>build_tools/pytorch.py: install_requirements()
build_tools/pytorch.py-->>setup.py: ["pytorch-triton", ...]
User->>setup.py: pip install transformer-engine[jax]
setup.py->>build_tools/jax.py: test_requirements()
build_tools/jax.py->>build_tools/jax.py: Check NVTE_USE_PYTORCH_TRITON
alt NVTE_USE_PYTORCH_TRITON=1
build_tools/jax.py-->>setup.py: ["pytorch-triton", ...]
else Default
build_tools/jax.py-->>setup.py: ["triton", ...]
end
Note over User,triton_extensions/utils.py: Runtime Detection Flow
User->>triton_extensions/utils.py: import triton_extensions
triton_extensions/utils.py->>triton_extensions/utils.py: _detect_triton_package()
triton_extensions/utils.py->>triton_extensions/utils.py: _check_triton_compatibility()
alt Placeholder package detected
triton_extensions/utils.py-->>User: ImportError
else pytorch-triton without env var
triton_extensions/utils.py-->>User: UserWarning
else Valid triton
triton_extensions/utils.py-->>User: Success
end
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (1)
-
transformer_engine/jax/triton_extensions/utils.py, line 322 (link)syntax: Typo:
compile.nameshould becompiled.name. The variablecompileis not defined in this scope - onlycompiledexists from line 300. This will cause aNameErrorat runtime for JAX versions < 0.8.2.
4 files reviewed, 1 comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additional Comments (1)
-
build_tools/pytorch.py, line 21 (link)style: Placeholder text
<version??>should be replaced with an actual version (e.g.,cu121orcu124) or made generic.
4 files reviewed, 1 comment
Description
pytorch-tritonandtritonpackages install to the same location at site-packages/triton, andtritondoes not work for pytorch's torch.compile() call as there are a few things pytorch has added onto their version of triton (creatingpytorch-tritonto make it work and validated it with the release of torch). Howeverpytorch-tritonshould in theory (and experimented) still be compatible with how jax uses it*. Pending more explanation from XLA: https://nvidia.slack.com/archives/C03L7BHTNEM/p1766452411154419Fixes # (issue)
Type of change
Changes
Checklist: