Commit 5f2ce96

[bugfix]: deepgemm online quant (#1130)
1 parent f756420 commit 5f2ce96

File tree

5 files changed: +21, -9 lines changed


lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 6 additions & 2 deletions
@@ -422,8 +422,12 @@ def _fuse(self):
         inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1]
         w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size)
         if not self.quantized_weight and self.quant_method is not None:
-            self.w1 = self.quant_method.quantize(w1)
-            self.w2 = self.quant_method.quantize(w2)
+            qw1, qw1_scale, qw1_zero_point = self.quant_method.quantize(w1)
+            qw2, qw2_scale, qw2_zero_point = self.quant_method.quantize(w2)
+            self.w1[0] = qw1
+            self.w1[1] = qw1_scale
+            self.w2[0] = qw2
+            self.w2[1] = qw2_scale
         else:
             self.w1[0] = self._cuda(w1)
             self.w2[0] = self._cuda(w2)
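Why this pattern: quantize() now returns a (qweight, scale, zero_point) triple (see the deepgemm_quant.py hunk below), and the fused MoE code keeps weight and scale in a two-slot container that the kernels index. Rebinding self.w1 to the raw return value, as the old code did, replaced that container. The same fix repeats in the next two files. A minimal sketch of the slot convention, with illustrative names and a stand-in quantizer (torch.float8_e4m3fn assumes PyTorch 2.1+):

import torch

def fake_quantize(w: torch.Tensor):
    # Stand-in for quant_method.quantize: returns (qweight, scale, zero_point).
    scale = w.abs().amax().clamp(min=1e-12)
    return (w / scale).to(torch.float8_e4m3fn), scale, None

w1 = [None, None]  # two-slot container the MoE kernels read: [qweight, scales]

# Old, buggy pattern: rebinding w1 replaces the container with a raw tuple.
# w1 = fake_quantize(torch.randn(4, 4))

# Fixed pattern: unpack the triple and fill the existing slots.
qw1, qw1_scale, _zero_point = fake_quantize(torch.randn(4, 4))
w1[0], w1[1] = qw1, qw1_scale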

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep_redundancy.py

Lines changed: 6 additions & 3 deletions
@@ -102,12 +102,15 @@ def _fuse(self):
         inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1]
         w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size)
         if not self._ep_w.quantized_weight and self._ep_w.quant_method is not None:
-            self.w1 = self._ep_w.quant_method.quantize(w1)
-            self.w2 = self._ep_w.quant_method.quantize(w2)
+            qw1, qw1_scale, qw1_zero_point = self._ep_w.quant_method.quantize(w1)
+            qw2, qw2_scale, qw2_zero_point = self._ep_w.quant_method.quantize(w2)
+            self.w1[0] = qw1
+            self.w1[1] = qw1_scale
+            self.w2[0] = qw2
+            self.w2[1] = qw2_scale
         else:
             self.w1[0] = w1
             self.w2[0] = w2
-
         delattr(self, "w2_list")
         delattr(self, "experts_up_projs")
         delattr(self, "experts_gate_projs")

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py

Lines changed: 6 additions & 2 deletions
@@ -182,8 +182,12 @@ def _fuse(self):
         inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1]
         w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size)
         if not self.quantized_weight and self.quant_method is not None:
-            self.w1 = self.quant_method.quantize(w1)
-            self.w2 = self.quant_method.quantize(w2)
+            qw1, qw1_scale, qw1_zero_point = self.quant_method.quantize(w1)
+            qw2, qw2_scale, qw2_zero_point = self.quant_method.quantize(w2)
+            self.w1[0] = qw1
+            self.w1[1] = qw1_scale
+            self.w2[0] = qw2
+            self.w2[1] = qw2_scale
         else:
             self.w1[0] = self._cuda(w1)
             self.w2[0] = self._cuda(w2)

lightllm/common/quantization/deepgemm_quant.py

Lines changed: 2 additions & 1 deletion
@@ -63,7 +63,8 @@ def method_name(self):
     def quantize(self, weight: torch.Tensor):
         from lightllm.common.quantization.triton_quant.fp8.fp8w8a8_block_quant_kernel import weight_quant

-        return weight_quant(weight, self.block_size)
+        weight, scale = weight_quant(weight, self.block_size)
+        return weight, scale, None

     def apply(
         self,
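Returning a fixed (qweight, scale, zero_point) triple gives every quant method one unpacking path; symmetric schemes such as this FP8 block quant simply return None for the zero point. A hedged sketch of that interface, with illustrative class names rather than lightllm's actual ones:

from typing import Optional, Tuple
import torch

class QuantMethodBase:
    # Illustrative interface: quantize() always yields (qweight, scale, zero_point).
    def quantize(
        self, weight: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        raise NotImplementedError

class SymmetricFp8Method(QuantMethodBase):
    # Symmetric quantization has no zero point, so the third slot is None.
    def quantize(self, weight: torch.Tensor):
        scale = weight.abs().amax().clamp(min=1e-12)
        qweight = (weight / scale).to(torch.float8_e4m3fn)
        return qweight, scale, None

# Callers unpack all three fields unconditionally, whatever the scheme:
qw, s, zp = SymmetricFp8Method().quantize(torch.randn(8, 8))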

lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_quant_kernel.py

Lines changed: 1 addition & 1 deletion
@@ -55,4 +55,4 @@ def weight_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor,
         return y_quant, s_scales
     else:
         y_quant, s_scales = mm_weight_quant(x, block_size)
-        return y_quant.t(), s_scales.t()
+        return y_quant, s_scales
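For context, block-wise weight quantization emits one scale per block_size x block_size tile, and this fix stops transposing both outputs before returning them. A minimal plain-torch sketch of the idea (the real weight_quant is a Triton kernel; 448.0 is the float8_e4m3 maximum, and the shapes assume dimensions divisible by block_size):

import torch

def block_weight_quant(x: torch.Tensor, block_size: int = 128):
    # One scale per block_size x block_size tile of the weight matrix.
    n, k = x.shape  # assumed divisible by block_size
    xb = x.view(n // block_size, block_size, k // block_size, block_size)
    # Per-tile max-abs scale, mapping each tile into the FP8 e4m3 range.
    scales = xb.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12) / 448.0
    y_quant = (xb / scales).to(torch.float8_e4m3fn).view(n, k)
    return y_quant, scales.view(n // block_size, k // block_size)

# The removed .t() calls had handed the online-quant DeepGEMM path
# transposed weight and scale tensors, which is what commit #1130 fixes.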
