diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 59e26ad09..ca568f72d 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -7,8 +7,8 @@
 import torch

-from bitsandbytes.cextension import BNB_HIP_VERSION
 import bitsandbytes.functional as F
+from bitsandbytes.gpu_specs import get_compute_capabilities


 # math.prod not compatible with python < 3.8
@@ -224,8 +224,8 @@ def supports_igemmlt(device: torch.device) -> bool:
     if device == torch.device("cpu"):
         return True
     if torch.version.hip:
-        return False if BNB_HIP_VERSION < 601 else True
-    if torch.cuda.get_device_capability(device=device) < (7, 5):
+        return False if get_compute_capabilities() < (6, 1) else True
+    if get_compute_capabilities() < (7, 5):
         return False
     device_name = torch.cuda.get_device_name(device=device)
     nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660")  # https://en.wikipedia.org/wiki/GeForce_16_series
diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py
index ad478431c..53edc94ad 100644
--- a/bitsandbytes/backends/cuda.py
+++ b/bitsandbytes/backends/cuda.py
@@ -24,7 +24,7 @@
 from .base import Backend

-if lib and lib.compiled_with_cuda:
+if lib and lib.compiled_with_gpu:
     """C FUNCTIONS FOR OPTIMIZERS"""
     str2optimizer32bit = {
         "adam": (
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index cc5d8deff..d863ad41e 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -24,29 +24,27 @@
 import torch

 from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
-from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_rocm_gpu_arch
+from bitsandbytes.gpu_specs import GPUSpecs, get_gpu_specs, get_rocm_gpu_arch

 logger = logging.getLogger(__name__)


-def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
+def get_gpu_bnb_library_path(gpu_specs: GPUSpecs) -> Path:
     """
-    Get the disk path to the CUDA BNB native library specified by the
-    given CUDA specs, taking into account the `BNB_CUDA_VERSION` override environment variable.
+    Get the disk path to the GPU BNB native library specified by the
+    given GPU specs, taking into account the `BNB_CUDA_VERSION` override environment variable.

     The library is not guaranteed to exist at the returned path.
     """
-    if torch.version.hip:
-        if BNB_HIP_VERSION < 601:
-            return PACKAGE_DIR / f"libbitsandbytes_rocm{BNB_HIP_VERSION_SHORT}_nohipblaslt{DYNAMIC_LIBRARY_SUFFIX}"
-        else:
-            return PACKAGE_DIR / f"libbitsandbytes_rocm{BNB_HIP_VERSION_SHORT}{DYNAMIC_LIBRARY_SUFFIX}"
-    library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}"
-    if not cuda_specs.has_cublaslt:
-        # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
-        library_name += "_nocublaslt"
+    library_name = f"libbitsandbytes_{gpu_specs.gpu_backend}{gpu_specs.backend_version_string}"
+    if not gpu_specs.enable_blaslt:
+        # if BLASLt is unsupported (CUDA CC < 7.5 or ROCm < 6.1), fall back to the no-BLASLt build
+        if gpu_specs.gpu_backend == "rocm":
+            library_name += "_nohipblaslt"
+        else:
+            library_name += "_nocublaslt"
     library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}"

     override_value = os.environ.get("BNB_CUDA_VERSION")
     if override_value:
         library_name_stem, _, library_name_ext = library_name.rpartition(".")
@@ -69,7 +68,7 @@
 class BNBNativeLibrary:
     _lib: ct.CDLL
-    compiled_with_cuda = False
+    compiled_with_gpu = False

     def __init__(self, lib: ct.CDLL):
         self._lib = lib
@@ -78,8 +77,8 @@ def __getattr__(self, item):
         return getattr(self._lib, item)


-class CudaBNBNativeLibrary(BNBNativeLibrary):
-    compiled_with_cuda = True
+class GpuBNBNativeLibrary(BNBNativeLibrary):
+    compiled_with_gpu = True

     def __init__(self, lib: ct.CDLL):
         super().__init__(lib)
@@ -93,18 +92,18 @@ def __init__(self, lib: ct.CDLL):
 def get_native_library() -> BNBNativeLibrary:
     binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}"
-    cuda_specs = get_cuda_specs()
-    if cuda_specs:
-        cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)
-        if cuda_binary_path.exists():
-            binary_path = cuda_binary_path
+    gpu_specs = get_gpu_specs()
+    if gpu_specs:
+        gpu_binary_path = get_gpu_bnb_library_path(gpu_specs)
+        if gpu_binary_path.exists():
+            binary_path = gpu_binary_path
         else:
-            logger.warning("Could not find the bitsandbytes %s binary at %r", BNB_BACKEND, cuda_binary_path)
+            logger.warning("Could not find the bitsandbytes %s binary at %r", gpu_specs.gpu_backend, gpu_binary_path)
     logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
     dll = ct.cdll.LoadLibrary(str(binary_path))
     if hasattr(dll, "get_context"):  # only a CUDA-built library exposes this
-        return CudaBNBNativeLibrary(dll)
+        return GpuBNBNativeLibrary(dll)
     return BNBNativeLibrary(dll)


@@ -113,15 +112,11 @@
 try:
     if torch.version.hip:
-        hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2])
-        HIP_ENVIRONMENT, BNB_HIP_VERSION = True, hip_major * 100 + hip_minor
-        BNB_HIP_VERSION_SHORT = f"{hip_major}{hip_minor}"
         BNB_BACKEND = "ROCm"
+        HIP_ENVIRONMENT = True
     else:
-        HIP_ENVIRONMENT, BNB_HIP_VERSION = False, 0
-        BNB_HIP_VERSION_SHORT = ""
         BNB_BACKEND = "CUDA"
-
+        HIP_ENVIRONMENT = False
     lib = get_native_library()
 except Exception as e:
     lib = None
diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/gpu.py
similarity index 81%
rename from bitsandbytes/diagnostics/cuda.py
rename to bitsandbytes/diagnostics/gpu.py
index 014b753a9..e9ae0c71e 100644
--- a/bitsandbytes/diagnostics/cuda.py
+++ b/bitsandbytes/diagnostics/gpu.py
@@ -5,14 +5,14 @@
 import torch

-from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
+from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path
 from bitsandbytes.consts import NONPYTORCH_DOC_URL
-from bitsandbytes.cuda_specs import CUDASpecs
 from bitsandbytes.diagnostics.utils import print_dedented
+from bitsandbytes.gpu_specs import GPUSpecs

-CUDART_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH")
+GPU_RT_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH")

-CUDART_PATH_IGNORED_ENVVARS = {
+GPU_RT_PATH_IGNORED_ENVVARS = {
     "DBUS_SESSION_BUS_ADDRESS",  # hardware related
     "GOOGLE_VM_CONFIG_LOCK_FILE",  # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks
     "HOME",  # Linux shell default
@@ -46,7 +46,7 @@ def get_runtime_lib_patterns() -> tuple:
     )


-def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]:
+def find_gpu_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]:
     for dir_string in paths_list_candidate.split(os.pathsep):
         if not dir_string:
             continue
@@ -70,10 +70,10 @@ def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]
 def is_relevant_candidate_env_var(env_var: str, value: str) -> bool:
     return (
-        env_var in CUDART_PATH_PREFERRED_ENVVARS  # is a preferred location
+        env_var in GPU_RT_PATH_PREFERRED_ENVVARS  # is a preferred location
         or (
             os.sep in value  # might contain a path
-            and env_var not in CUDART_PATH_IGNORED_ENVVARS  # not ignored
+            and env_var not in GPU_RT_PATH_IGNORED_ENVVARS  # not ignored
             and "CONDA" not in env_var  # not another conda envvar
             and "BASH_FUNC" not in env_var  # not a bash function defined via envvar
             and "\n" not in value  # likely e.g. a script or something?
@@ -85,7 +85,7 @@ def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]:
     return {env_var: value for env_var, value in os.environ.items() if is_relevant_candidate_env_var(env_var, value)}


-def find_cudart_libraries() -> Iterator[Path]:
+def find_gpu_rt_libraries() -> Iterator[Path]:
     """
     Searches for a cuda installations, in the following order of priority:
     1. active conda env
@@ -99,23 +99,23 @@
     """
     candidate_env_vars = get_potentially_lib_path_containing_env_vars()

-    for envvar in CUDART_PATH_PREFERRED_ENVVARS:
+    for envvar in GPU_RT_PATH_PREFERRED_ENVVARS:
         if envvar in candidate_env_vars:
             directory = candidate_env_vars[envvar]
-            yield from find_cuda_libraries_in_path_list(directory)
+            yield from find_gpu_libraries_in_path_list(directory)
             candidate_env_vars.pop(envvar)

     for env_var, value in candidate_env_vars.items():
-        yield from find_cuda_libraries_in_path_list(value)
+        yield from find_gpu_libraries_in_path_list(value)


-def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
+def _print_cuda_diagnostics(gpu_specs: GPUSpecs) -> None:
     print(
-        f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, "
-        f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.",
+        f"PyTorch settings found: CUDA_VERSION={gpu_specs.backend_version_string}, "
+        f"Highest Compute Capability: {gpu_specs.highest_compute_capability}.",
     )

-    binary_path = get_cuda_bnb_library_path(cuda_specs)
+    binary_path = get_gpu_bnb_library_path(gpu_specs)
     if not binary_path.exists():
         print_dedented(
             f"""
@@ -128,7 +128,7 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
             """,
         )

-    cuda_major, cuda_minor = cuda_specs.cuda_version_tuple
+    cuda_major, cuda_minor = gpu_specs.backend_version_tuple
     if cuda_major < 11:
         print_dedented(
             """
@@ -140,7 +140,7 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
     print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")

     # 7.5 is the minimum CC for cublaslt
-    if not cuda_specs.has_cublaslt:
+    if not gpu_specs.enable_blaslt:
         print_dedented(
             """
             WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!
@@ -154,10 +154,10 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
     # (2) Multiple CUDA versions installed


-def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
-    print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}")
+def _print_hip_diagnostics(gpu_specs: GPUSpecs) -> None:
+    print(f"PyTorch settings found: ROCM_VERSION={gpu_specs.backend_version_string}")

-    binary_path = get_cuda_bnb_library_path(cuda_specs)
+    binary_path = get_gpu_bnb_library_path(gpu_specs)
     if not binary_path.exists():
         print_dedented(
             f"""
@@ -168,7 +168,7 @@ def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
             """,
         )

-    hip_major, hip_minor = cuda_specs.cuda_version_tuple
+    hip_major, hip_minor = gpu_specs.backend_version_tuple
     if (hip_major, hip_minor) < (6, 1):
         print_dedented(
             """
@@ -177,18 +177,18 @@ def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
     )


-def print_diagnostics(cuda_specs: CUDASpecs) -> None:
+def print_diagnostics(gpu_specs: GPUSpecs) -> None:
     if HIP_ENVIRONMENT:
-        _print_hip_diagnostics(cuda_specs)
+        _print_hip_diagnostics(gpu_specs)
     else:
-        _print_cuda_diagnostics(cuda_specs)
+        _print_cuda_diagnostics(gpu_specs)


 def _print_cuda_runtime_diagnostics() -> None:
-    cudart_paths = list(find_cudart_libraries())
-    if not cudart_paths:
+    gpu_rt_paths = list(find_gpu_rt_libraries())
+    if not gpu_rt_paths:
         print("WARNING! CUDA runtime files not found in any environmental path.")
-    elif len(cudart_paths) > 1:
+    elif len(gpu_rt_paths) > 1:
         print_dedented(
             f"""
             Found duplicate CUDA runtime files (see below).
@@ -207,15 +207,15 @@ def _print_cuda_runtime_diagnostics() -> None:
             export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2,
             """,
         )
-        for pth in cudart_paths:
+        for pth in gpu_rt_paths:
             print(f"* Found CUDA runtime at: {pth}")


 def _print_hip_runtime_diagnostics() -> None:
-    cudart_paths = list(find_cudart_libraries())
-    if not cudart_paths:
+    gpu_rt_paths = list(find_gpu_rt_libraries())
+    if not gpu_rt_paths:
         print("WARNING! ROCm runtime files not found in any environmental path.")
-    elif len(cudart_paths) > 1:
+    elif len(gpu_rt_paths) > 1:
         print_dedented(
             f"""
             Found duplicate ROCm runtime files (see below).
@@ -230,7 +230,7 @@ def _print_hip_runtime_diagnostics() -> None:
             """,
         )
-        for pth in cudart_paths:
+        for pth in gpu_rt_paths:
             print(f"* Found ROCm runtime at: {pth}")


diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py
index 8dc43ed2a..a47ce9f7e 100644
--- a/bitsandbytes/diagnostics/main.py
+++ b/bitsandbytes/diagnostics/main.py
@@ -5,12 +5,12 @@
 from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT
 from bitsandbytes.consts import PACKAGE_GITHUB_URL
-from bitsandbytes.cuda_specs import get_cuda_specs
-from bitsandbytes.diagnostics.cuda import (
+from bitsandbytes.diagnostics.gpu import (
     print_diagnostics,
     print_runtime_diagnostics,
 )
 from bitsandbytes.diagnostics.utils import print_dedented, print_header
+from bitsandbytes.gpu_specs import get_gpu_specs


 def sanity_check():
@@ -50,20 +50,20 @@ def main():
     print_header("")
     print_header("OTHER")
-    cuda_specs = get_cuda_specs()
+    gpu_specs = get_gpu_specs()
     if HIP_ENVIRONMENT:
-        rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}',"
-        rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}"
+        rocm_specs = f" rocm_version_string='{gpu_specs.backend_version_string}',"
+        rocm_specs += f" rocm_version_tuple={gpu_specs.backend_version_tuple}"
         print(f"{BNB_BACKEND} specs:{rocm_specs}")
     else:
-        print(f"{BNB_BACKEND} specs:{cuda_specs}")
+        print(f"{BNB_BACKEND} specs:{gpu_specs}")
     if not torch.cuda.is_available():
         print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:")
         print(f"1. {BNB_BACKEND} driver not installed")
         print(f"2. {BNB_BACKEND} not installed")
         print(f"3. You have multiple conflicting {BNB_BACKEND} libraries")
-    if cuda_specs:
-        print_diagnostics(cuda_specs)
+    if gpu_specs:
+        print_diagnostics(gpu_specs)
     print_runtime_diagnostics()
     print_header("")
     print_header("DEBUG INFO END")
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 6cf64df28..6fa74d5aa 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -25,7 +25,7 @@ def prod(iterable):
 name2qmap = {}

-if lib and lib.compiled_with_cuda:
+if lib and lib.compiled_with_gpu:
     """C FUNCTIONS FOR OPTIMIZERS"""
     str2optimizer32bit = {
         "adam": (
diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/gpu_specs.py
similarity index 52%
rename from bitsandbytes/cuda_specs.py
rename to bitsandbytes/gpu_specs.py
index 0afecd3ea..2df48441b 100644
--- a/bitsandbytes/cuda_specs.py
+++ b/bitsandbytes/gpu_specs.py
@@ -2,27 +2,44 @@
 import logging
 import re
 import subprocess
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch


 @dataclasses.dataclass(frozen=True)
-class CUDASpecs:
+class GPUSpecs:
+    gpu_backend: str
     highest_compute_capability: Tuple[int, int]
-    cuda_version_string: str
-    cuda_version_tuple: Tuple[int, int]
+    backend_version_string: str
+    backend_version_tuple: Tuple[int, int]

     @property
-    def has_cublaslt(self) -> bool:
-        return self.highest_compute_capability >= (7, 5)
+    def enable_blaslt(self) -> bool:
+        if torch.version.hip:
+            return self.highest_compute_capability >= (6, 1)
+        else:
+            return self.highest_compute_capability >= (7, 5)
+
+
+def get_gpu_backend() -> str:
+    if torch.version.hip:
+        return "rocm"
+    else:
+        return "cuda"


-def get_compute_capabilities() -> List[Tuple[int, int]]:
-    return sorted(torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count()))
+def get_compute_capabilities() -> Tuple[int, int]:
+    if torch.version.hip:
+        hip_major, hip_minor = get_backend_version_tuple()
+        return (hip_major, hip_minor)
+    else:
+        return sorted(
+            torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count())
+        )[-1]


-def get_cuda_version_tuple() -> Tuple[int, int]:
+def get_backend_version_tuple() -> Tuple[int, int]:
     # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION
     if torch.version.cuda:
         major, minor = map(int, torch.version.cuda.split("."))
@@ -31,19 +48,20 @@
     return major, minor


-def get_cuda_version_string() -> str:
-    major, minor = get_cuda_version_tuple()
+def get_backend_version_string() -> str:
+    major, minor = get_backend_version_tuple()
     return f"{major}{minor}"


-def get_cuda_specs() -> Optional[CUDASpecs]:
+def get_gpu_specs() -> Optional[GPUSpecs]:
     if not torch.cuda.is_available():
         return None

-    return CUDASpecs(
-        highest_compute_capability=(get_compute_capabilities()[-1]),
-        cuda_version_string=(get_cuda_version_string()),
-        cuda_version_tuple=get_cuda_version_tuple(),
+    return GPUSpecs(
+        gpu_backend=get_gpu_backend(),
+        highest_compute_capability=(get_compute_capabilities()),
+        backend_version_string=(get_backend_version_string()),
+        backend_version_tuple=get_backend_version_tuple(),
     )
diff --git a/tests/test_autograd.py b/tests/test_autograd.py
index eafa01f0e..49d5368a6 100644
--- a/tests/test_autograd.py
+++ b/tests/test_autograd.py
@@ -4,7 +4,7 @@
 import torch

 import bitsandbytes as bnb
-from bitsandbytes.cextension import BNB_HIP_VERSION
+from bitsandbytes.gpu_specs import get_compute_capabilities
 from tests.helpers import (
     BOOLEAN_TRIPLES,
     BOOLEAN_TUPLES,
@@ -199,7 +199,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool
     assert (idx == 0).sum().item() < n * 0.02


-@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1")
+@pytest.mark.skipif((0, 0) < get_compute_capabilities() < (6, 1), reason="this test is supported on ROCm from 6.1")
 @pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3"))
diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py
index 53dd25044..7612d68af 100644
--- a/tests/test_cuda_setup_evaluator.py
+++ b/tests/test_cuda_setup_evaluator.py
@@ -1,41 +1,43 @@
 import pytest

-from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
-from bitsandbytes.cuda_specs import CUDASpecs
+from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path
+from bitsandbytes.gpu_specs import GPUSpecs


 @pytest.fixture
-def cuda120_spec() -> CUDASpecs:
-    return CUDASpecs(
-        cuda_version_string="120",
+def cuda120_spec() -> GPUSpecs:
+    return GPUSpecs(
+        gpu_backend="cuda",
+        backend_version_string="120",
         highest_compute_capability=(8, 6),
-        cuda_version_tuple=(12, 0),
+        backend_version_tuple=(12, 0),
     )


 @pytest.fixture
-def cuda111_noblas_spec() -> CUDASpecs:
-    return CUDASpecs(
-        cuda_version_string="111",
+def cuda111_noblas_spec() -> GPUSpecs:
+    return GPUSpecs(
+        gpu_backend="cuda",
+        backend_version_string="111",
         highest_compute_capability=(7, 2),
-        cuda_version_tuple=(11, 1),
+        backend_version_tuple=(11, 1),
     )


 @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
-def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
+def test_get_gpu_bnb_library_path(monkeypatch, cuda120_spec):
     monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
-    assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"
+    assert get_gpu_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"


 @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
-def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
+def test_get_gpu_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
     monkeypatch.setenv("BNB_CUDA_VERSION", "110")
-    assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
+    assert get_gpu_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
     assert "BNB_CUDA_VERSION" in caplog.text  # did we get the warning?


 @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
-def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
+def test_get_gpu_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec):
     monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
-    assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"
+    assert get_gpu_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt"
diff --git a/tests/test_functional.py b/tests/test_functional.py
index 35187db78..2e3e8c90c 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -11,7 +11,8 @@
 import bitsandbytes as bnb
 from bitsandbytes import functional as F
-from bitsandbytes.cextension import BNB_HIP_VERSION, HIP_ENVIRONMENT, ROCM_GPU_ARCH
+from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH
+from bitsandbytes.gpu_specs import get_compute_capabilities
 from tests.helpers import (
     BOOLEAN_TUPLES,
     TRUE_FALSE,
@@ -512,7 +513,7 @@ def test_vector_quant(dim1, dim2, dim3):
     assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002))


-@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1")
+@pytest.mark.skipif((0, 0) < get_compute_capabilities() < (6, 1), reason="this test is supported on ROCm from 6.1")
 @pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1"))
 @pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2"))
 @pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3"))
@@ -1817,7 +1818,7 @@ def quant_zp(x):
     print(err1, err2, err3, err4, err5, err6)


-@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1")
+@pytest.mark.skipif((0, 0) < get_compute_capabilities() < (6, 1), reason="this test is supported on ROCm from 6.1")
 @pytest.mark.parametrize("device", ["cuda", "cpu"])
 def test_extract_outliers(device):
     for i in range(k):