Commit fed42b1

[release/2.0] Optimize INT8 LSTM weight scales calculation (#1566)
* optimize int8 weight scales calculation
* add check on scale sizes
1 parent: 76dd768

csrc/cpu/aten/RNN.cpp

Lines changed: 8 additions & 11 deletions
@@ -217,25 +217,22 @@ at::ScalarType get_bias_dtype(
 std::vector<float> get_mkldnn_weight_scales_of_lstm(
     const at::Tensor& weight_ih,
     const at::Tensor& weight_hh) {
-  std::vector<float> weight_scales;
-
   at::Tensor weight_ih_scales_tensor =
       int8::utils::get_weight_scale_tensor(weight_ih);
   at::Tensor weight_hh_scales_tensor =
       int8::utils::get_weight_scale_tensor(weight_hh);
   TORCH_CHECK(
-      weight_ih_scales_tensor.sizes() == weight_hh_scales_tensor.sizes(),
-      "scales of weight_ih and weight_hh should be of same size");
+      weight_ih_scales_tensor.sizes() == weight_hh_scales_tensor.sizes() == 1,
+      "Expect scales of LSTM weight_ih and weight_hh to be of size 1");
 
   // PyTorch scale: (max - min) / (qmax - qmin)
   // oneDNN scale: (qmax - qmin) / (max - min)
-  for (size_t i = 0; i < weight_ih_scales_tensor.sizes()[0]; i++) {
-    weight_scales.push_back(
-        1. /
-        std::max(
-            weight_ih_scales_tensor[i].item().toFloat(),
-            weight_hh_scales_tensor[i].item().toFloat()));
-  }
+  auto scale_tensor = 1. /
+      torch::maximum(weight_ih_scales_tensor, weight_hh_scales_tensor)
+          .to(c10::kFloat);
+  scale_tensor = scale_tensor.contiguous();
+  auto s_ptr = scale_tensor.data_ptr<float>();
+  std::vector<float> weight_scales{s_ptr, s_ptr + scale_tensor.size(0)};
   return weight_scales;
 }
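
For context, both scale tensors come back in the PyTorch convention, scale = (max - min) / (qmax - qmin), while oneDNN expects the reciprocal, (qmax - qmin) / (max - min). With int8 (qmin = -128, qmax = 127) and a weight range of [-1, 1], for example, the PyTorch scale is 2/255, roughly 0.00784, and the oneDNN scale is 255/2 = 127.5. The optimization replaces a loop that extracts one scalar per channel via Tensor::item() with a single torch::maximum kernel, one reciprocal, and one bulk copy out of a contiguous buffer. Below is a minimal standalone sketch of the same pattern; the function and variable names are illustrative, not taken from the commit.

#include <torch/torch.h>

#include <iostream>
#include <vector>

// Vectorized PyTorch -> oneDNN scale conversion: element-wise max of the
// two weight-scale tensors, reciprocal, then one bulk copy of the dense
// float buffer into a std::vector (no per-element .item() calls).
std::vector<float> to_onednn_scales(
    const at::Tensor& ih_scales,
    const at::Tensor& hh_scales) {
  at::Tensor inv =
      1. / torch::maximum(ih_scales, hh_scales).to(c10::kFloat);
  inv = inv.contiguous();  // data_ptr() requires a dense layout
  const float* p = inv.data_ptr<float>();
  return std::vector<float>(p, p + inv.size(0));
}

int main() {
  // Hypothetical size-1 scale tensors, matching the new TORCH_CHECK.
  auto ih = torch::tensor({0.02f});
  auto hh = torch::tensor({0.04f});
  std::cout << to_onednn_scales(ih, hh)[0] << "\n";  // 1 / 0.04 = 25
}

The win is that Tensor::item() dispatches a full scalar extraction on every iteration, whereas the tensorized form issues a fixed number of kernel calls plus one memcpy-style copy no matter how many scales there are.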
