import bitblas
import torch
# enabling debug output
bitblas.set_log_level("Debug")
model = bitblas.Linear(
in_features=1024,
out_features=1024,
bias=False,
A_dtype="float16", # activation A dtype
W_dtype="uint4", # weight W dtype
accum_dtype="float32", # accumulation dtype
out_dtype="float16", # output dtype
# configs for weight only quantization
group_size=None, # setting for grouped quantization
with_scaling=True, # setting for scaling factor
with_zeros=False, # setting for zeros
zeros_mode=None, # setting for how to calculate zeros
# Target optimization var for dynamic symbolic.
# For detailed information please checkout docs/PythonAPI.md
# By default, the optimization var is [1, 16, 32, 64, 128, 256, 512]
opt_M=[1, 16, 32, 64, 128],
)
# Create an integer weight tensor
intweight = torch.randint(0, 15, (1024, 1024), dtype=torch.int8).cuda()
# Load and transform weights into the BitBLAS linear module
model.load_and_transform_weight(intweight)
model.scales.uniform_(0.1, 0.2)
# Set the model to evaluation mode
model.eval()
dummy_input = torch.randn(1, 1024, dtype=torch.float16).cuda()
print(model.qweight, model.scales, dummy_input)
output = model(dummy_input)
print(output)
output
2025-03-10 16:58:41 [BitBLAS:INFO]: Loaded 10 operators from database.
BitBLAS Operator found in global_operator_cache.
tensor([[ 38, 112, 41, ..., 112, -22, 56],
[ 37, 122, -88, ..., 121, 74, 5],
[ 96, 105, -72, ..., 71, 19, -50],
...,
[ 41, -114, -27, ..., -107, -87, -63],
[ -68, 26, -61, ..., 17, -34, 62],
[ -69, 104, 10, ..., -92, 90, 118]], device='cuda:0',
dtype=torch.int8) tensor([[0.1353],
[0.1842],
[0.1820],
...,
[0.1693],
[0.1836],
[0.1566]], dtype=torch.float16) tensor([[ 0.2061, -1.6357, 1.1240, ..., -0.5273, 1.7285, 2.1289]],
device='cuda:0', dtype=torch.float16)
error
File ~/anaconda3/envs/profile/lib/python3.10/site-packages/torch/_tensor_str.py:145, in _Formatter.__init__(self, tensor)
142 self.max_width = max(self.max_width, len(value_str))
144 else:
--> 145 nonzero_finite_vals = torch.masked_select(
146 tensor_view, torch.isfinite(tensor_view) & tensor_view.ne(0)
147 )
149 if nonzero_finite_vals.numel() == 0:
150 # no valid number, do nothing
151 return
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
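As the traceback notes, CUDA errors are reported asynchronously, so the line shown (torch's tensor printing code) is probably not the real faulting kernel. Below is a minimal sketch of how one might rerun the reproduction with synchronous error reporting, following the CUDA_LAUNCH_BLOCKING hint from the message above; the snippet is an illustration of applying that hint, not part of the original report.
# Sketch: set CUDA_LAUNCH_BLOCKING before any CUDA work so the illegal
# memory access is raised at the launching call instead of a later API call.
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # must be set before CUDA initializes
import torch
import bitblas
# ...construct the same bitblas.Linear model and dummy_input as above, then:
# output = model(dummy_input)
# torch.cuda.synchronize()  # forces the error to surface at this point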