Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
</div>
<hr/>

**Note:** Torch 2.10 builds are still based on PyTorch release candidates.
Typically, the ABI does not change between the release candidates and the final release. If it does,
you will need to recompile your kernels against the final 2.10.0 release.

[Join us on Discord](https://discord.gg/H6Tkmd88N3) for questions and discussions!

This repo contains a Nix package that can be used to build custom machine learning kernels for PyTorch. The kernels are built using the [PyTorch C++ Frontend](https://pytorch.org/cppdocs/frontend.html) and can be loaded from the Hub with the [kernels](https://github.com/huggingface/kernels)
Expand Down
13 changes: 13 additions & 0 deletions build-variants.json
Original file line number Diff line number Diff line change
@@ -1,20 +1,26 @@
{
"aarch64-darwin": {
"cpu": [
"torch210-cpu-aarch64-darwin",
"torch28-cpu-aarch64-darwin",
"torch29-cpu-aarch64-darwin"
],
"metal": [
"torch210-metal-aarch64-darwin",
"torch28-metal-aarch64-darwin",
"torch29-metal-aarch64-darwin"
]
},
"aarch64-linux": {
"cpu": [
"torch210-cxx11-cpu-aarch64-linux",
"torch28-cxx11-cpu-aarch64-linux",
"torch29-cxx11-cpu-aarch64-linux"
],
"cuda": [
"torch210-cxx11-cu126-aarch64-linux",
"torch210-cxx11-cu128-aarch64-linux",
"torch210-cxx11-cu130-aarch64-linux",
"torch28-cxx11-cu129-aarch64-linux",
"torch29-cxx11-cu126-aarch64-linux",
"torch29-cxx11-cu128-aarch64-linux",
Expand All @@ -23,10 +29,14 @@
},
"x86_64-linux": {
"cpu": [
"torch210-cxx11-cpu-x86_64-linux",
"torch28-cxx11-cpu-x86_64-linux",
"torch29-cxx11-cpu-x86_64-linux"
],
"cuda": [
"torch210-cxx11-cu126-x86_64-linux",
"torch210-cxx11-cu128-x86_64-linux",
"torch210-cxx11-cu130-x86_64-linux",
"torch28-cxx11-cu126-x86_64-linux",
"torch28-cxx11-cu128-x86_64-linux",
"torch28-cxx11-cu129-x86_64-linux",
Expand All @@ -35,12 +45,15 @@
"torch29-cxx11-cu130-x86_64-linux"
],
"rocm": [
"torch210-cxx11-rocm70-x86_64-linux",
"torch210-cxx11-rocm71-x86_64-linux",
"torch28-cxx11-rocm63-x86_64-linux",
"torch28-cxx11-rocm64-x86_64-linux",
"torch29-cxx11-rocm63-x86_64-linux",
"torch29-cxx11-rocm64-x86_64-linux"
],
"xpu": [
"torch210-cxx11-xpu20253-x86_64-linux",
"torch28-cxx11-xpu20251-x86_64-linux",
"torch29-cxx11-xpu20252-x86_64-linux"
]
Expand Down
6 changes: 4 additions & 2 deletions build2cmake/src/templates/xpu/dep-cutlass-sycl.cmake
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
find_package(CutlassSycl)

if(DPCPP_VERSION STREQUAL "2025.2")
if(DPCPP_VERSION STREQUAL "2025.3")
set(CUTLASS_SYCL_REVISION "14055e78510b8776ba739755eb57e592fdceefdb" CACHE STRING "CUTLASS revision to use")
elseif(DPCPP_VERSION STREQUAL "2025.2")
set(CUTLASS_SYCL_REVISION "14055e78510b8776ba739755eb57e592fdceefdb" CACHE STRING "CUTLASS revision to use")
elseif(DPCPP_VERSION STREQUAL "2025.1")
set(CUTLASS_SYCL_REVISION "v3.9-0.3" CACHE STRING "CUTLASS revision to use")
Expand Down Expand Up @@ -67,7 +69,7 @@ endif()
string(REPLACE "-fsycl-targets=spir64_gen,spir64" "-fsycl-targets=spir64" sycl_link_flags "${sycl_link_flags}")
string(REPLACE "-device pvc,xe-lpg,ats-m150" "-device bmg_g21,pvc" sycl_link_flags "${sycl_link_flags}")
string(APPEND sycl_link_flags "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier")
if(DPCPP_VERSION STREQUAL "2025.2" OR CUTLASS_SYCL_REVISION STREQUAL "v0.5")
if(DPCPP_VERSION STREQUAL "2025.2" OR DPCPP_VERSION STREQUAL "2025.3" OR CUTLASS_SYCL_REVISION STREQUAL "v0.5")
string(APPEND sycl_link_flags ",+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate")
endif()
string(REPLACE "-fsycl-targets=spir64_gen,spir64" "-fsycl-targets=spir64" sycl_flags "${sycl_flags}")
Expand Down
2 changes: 1 addition & 1 deletion build2cmake/src/torch/metal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub fn write_torch_ext_metal(

let mut file_set = FileSet::default();

let ops_name = kernel_ops_identifier(&target_dir, &build.general.name, ops_id);
let ops_name = kernel_ops_identifier(&target_dir, &build.general.python_name(), ops_id);

write_cmake(
env,
Expand Down
13 changes: 13 additions & 0 deletions docs/build-variants.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,43 @@ available. This list will be updated as new PyTorch versions are released.

## CPU aarch64-darwin

- `torch210-cpu-aarch64-darwin`
- `torch28-cpu-aarch64-darwin`
- `torch29-cpu-aarch64-darwin`

## Metal aarch64-darwin

- `torch210-metal-aarch64-darwin`
- `torch28-metal-aarch64-darwin`
- `torch29-metal-aarch64-darwin`

## CPU aarch64-linux

- `torch210-cxx11-cpu-aarch64-linux`
- `torch28-cxx11-cpu-aarch64-linux`
- `torch29-cxx11-cpu-aarch64-linux`

## CUDA aarch64-linux

- `torch210-cxx11-cu126-aarch64-linux`
- `torch210-cxx11-cu128-aarch64-linux`
- `torch210-cxx11-cu130-aarch64-linux`
- `torch28-cxx11-cu129-aarch64-linux`
- `torch29-cxx11-cu126-aarch64-linux`
- `torch29-cxx11-cu128-aarch64-linux`
- `torch29-cxx11-cu130-aarch64-linux`

## CPU x86_64-linux

- `torch210-cxx11-cpu-x86_64-linux`
- `torch28-cxx11-cpu-x86_64-linux`
- `torch29-cxx11-cpu-x86_64-linux`

## CUDA x86_64-linux

- `torch210-cxx11-cu126-x86_64-linux`
- `torch210-cxx11-cu128-x86_64-linux`
- `torch210-cxx11-cu130-x86_64-linux`
- `torch28-cxx11-cu126-x86_64-linux`
- `torch28-cxx11-cu128-x86_64-linux`
- `torch28-cxx11-cu129-x86_64-linux`
Expand All @@ -43,13 +53,16 @@ available. This list will be updated as new PyTorch versions are released.

## ROCm x86_64-linux

- `torch210-cxx11-rocm70-x86_64-linux`
- `torch210-cxx11-rocm71-x86_64-linux`
- `torch28-cxx11-rocm63-x86_64-linux`
- `torch28-cxx11-rocm64-x86_64-linux`
- `torch29-cxx11-rocm63-x86_64-linux`
- `torch29-cxx11-rocm64-x86_64-linux`

## XPU x86_64-linux

- `torch210-cxx11-xpu20253-x86_64-linux`
- `torch28-cxx11-xpu20251-x86_64-linux`
- `torch29-cxx11-xpu20252-x86_64-linux`

Expand Down
14 changes: 7 additions & 7 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions flake.nix
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

inputs = {
flake-utils.url = "github:numtide/flake-utils";
# Put back to nixos-unstable-small the next bump. Exact revision is
# to avoid a rebuild during the hf-nix -> kernel-builder transition.
nixpkgs.url = "github:NixOS/nixpkgs/c543a59edf25ada193719764f3bc0c6ba835f94d";
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable-small";
flake-compat.url = "github:edolstra/flake-compat";
};

Expand Down
11 changes: 10 additions & 1 deletion lib/torch-extension/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
rocmSupport ? torch.rocmSupport,
xpuSupport ? torch.xpuSupport,

pkgs,
lib,
callPackage,
stdenv,
Expand Down Expand Up @@ -30,7 +31,15 @@ let
);

cuda_nvcc = cudaPackages.cuda_nvcc.override {
backendStdenv = cudaPackages.backendStdenv.override {
backendStdenv = import ../../pkgs/cuda/backendStdenv {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`backendStdenv` itself is no longer overridable in upstream nixpkgs (`.override` now applies to the actual stdenv rather than to what used to be the wrapper), so I decided the easiest option is to vendor this functionality until I come up with a better idea :).

inherit (pkgs)
_cuda
config
lib
pkgs
stdenvAdapters
;
inherit (cudaPackages) cudaMajorMinorVersion;
stdenv = effectiveStdenv;
};
};
Expand Down
26 changes: 24 additions & 2 deletions overlay.nix
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,22 @@ in

remove-bytecode-hook = prev.callPackage ./pkgs/remove-bytecode-hook { };

stdenvGlibc_2_27 = prev.callPackage ./pkgs/stdenv-glibc-2_27 { };
stdenvGlibc_2_27 = import ./pkgs/stdenv-glibc-2_27 {
# Do not use callPackage, because we want overrides to apply to
# the stdenv itself and not this file.
inherit (final)
config
fetchFromGitHub
overrideCC
wrapBintoolsWith
wrapCCWith
gcc13Stdenv
stdenv
bintools-unwrapped
cudaPackages
libgcc
;
};

ucx = prev.ucx.overrideAttrs (
_: prevAttrs: {
Expand Down Expand Up @@ -107,6 +122,11 @@ in
xpuPackages = final.xpuPackages_2025_2;
};

torch-bin_2_10 = mkTorch {
version = "2.10";
xpuPackages = final.xpuPackages_2025_3;
};

torch_2_8 = callPackage ./pkgs/python-modules/torch/source/2_8 {
xpuPackages = final.xpuPackages_2025_1;
};
Expand Down Expand Up @@ -139,7 +159,8 @@ in
versions = [
"6.3.4"
"6.4.2"
"7.0.1"
"7.0.2"
"7.1.1"
];
newRocmPackages = final.callPackage ./pkgs/rocm-packages { };
in
Expand All @@ -159,6 +180,7 @@ in
xpuVersions = [
"2025.1.3"
"2025.2.1"
"2025.3.1"
];
newXpuPackages = final.callPackage ./pkgs/xpu-packages { };
in
Expand Down
63 changes: 59 additions & 4 deletions pkgs/aotriton/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ let
find "$out" -name .git -print0 | xargs -0 rm -rf
'';
mkImages =
srcs:
version: srcs:
stdenvNoCC.mkDerivation {
name = "images";
name = "images-${version}";

inherit srcs;

Expand Down Expand Up @@ -69,7 +69,7 @@ in
"gfx1201"
];

images = mkImages [
images = mkImages version [
(fetchurl {
url = "https://github.com/ROCm/aotriton/releases/download/0.10b/aotriton-0.10b-manylinux_2_28_x86_64-rocm6.3-shared.tar.gz";
hash = "sha256-hhzZ90ee7JQ5M8J8uGkgJH5bXdE5vHwTdsgYCKu31/4=";
Expand Down Expand Up @@ -107,7 +107,7 @@ in
"gfx1201"
];

images = mkImages [
images = mkImages version [
(fetchurl {
url = "https://github.com/ROCm/aotriton/releases/download/0.11b/aotriton-0.11b-images-amd-gfx90a.tar.gz";
hash = "sha256-wZpByUgFEKsy5vsF5u0KODLWsHY08FC4NrdgIAvvpzU=";
Expand All @@ -132,4 +132,59 @@ in

extraPythonDepends = ps: [ ps.pandas ];
};

# aotriton 0.11.1b: the attribute set below is passed to `generic`.
# `rec` lets `src.rev` and the `mkImages` call refer to `version`.
aotriton_0_11_1 = generic rec {
version = "0.11.1b";

# Source checkout of ROCm/aotriton at the tag named by `version`.
# `leaveDotGit` keeps the .git data during the fetch; `postFetch` is
# inherited from the enclosing scope.
# NOTE(review): presumably that postFetch removes the .git directories again
# so the hash stays stable; confirm against its definition earlier in the file.
src = fetchFromGitHub {
owner = "ROCm";
repo = "aotriton";
rev = version;
hash = "sha256-F7JjyS+6gMdCpOFLldTsNJdVzzVwd6lwW7+V8ZOZfig=";
leaveDotGit = true;
inherit postFetch;
};

patches = [
# Fails with: ld.lld: error: unable to insert .comment after .comment
./v0.11.1b-no-ld-script.diff
];

gpuTargets = [
# aotriton GPU support list:
# https://github.com/ROCm/aotriton/blob/main/v2python/gpu_targets.py
"gfx90a"
"gfx942"
"gfx950"
"gfx1100"
"gfx1151"
"gfx1201"
];

# Image archives downloaded from the 0.11.1b GitHub release and handed to
# `mkImages` together with `version`. Going by the file names, there is one
# archive each for gfx90a, gfx942 and gfx950, plus one each for the gfx11xx
# and gfx120x families.
images = mkImages version [
(fetchurl {
url = "https://github.com/ROCm/aotriton/releases/download/0.11.1b/aotriton-0.11.1b-images-amd-gfx90a.tar.gz";
hash = "sha256-/p8Etmv1KsJ80CXh2Jz9BJdN0/s64HYZL3g2QaTYD98=";
})
(fetchurl {
url = "https://github.com/ROCm/aotriton/releases/download/0.11.1b/aotriton-0.11.1b-images-amd-gfx942.tar.gz";
hash = "sha256-CnvO4Z07ttVIcyJIwyNPe5JzbCq3p6rmUpS4en/WTAY=";
})
(fetchurl {
url = "https://github.com/ROCm/aotriton/releases/download/0.11.1b/aotriton-0.11.1b-images-amd-gfx950.tar.gz";
hash = "sha256-wbo7/oQhf9Z9890fi2fICn97M9CtTXS0HWVnA24DKs4=";
})
(fetchurl {
url = "https://github.com/ROCm/aotriton/releases/download/0.11.1b/aotriton-0.11.1b-images-amd-gfx11xx.tar.gz";
hash = "sha256-ZjIEDEBdgzvm/3ICkknHdoOLr18Do8E7pOjTeoe3p0A=";
})
(fetchurl {
url = "https://github.com/ROCm/aotriton/releases/download/0.11.1b/aotriton-0.11.1b-images-amd-gfx120x.tar.gz";
hash = "sha256-Ck/zJL/9rAwv3oeop/cFY9PISoCtTo8xNF8rQKE4TpU=";
})
];

# Function from a package set `ps` to a list of extra dependencies; here
# only pandas.
# NOTE(review): presumably `ps` is the Python package set; confirm in `generic`.
extraPythonDepends = ps: [ ps.pandas ];
};

}
Loading
Loading