@@ -217,25 +217,22 @@ at::ScalarType get_bias_dtype(
 std::vector<float> get_mkldnn_weight_scales_of_lstm(
     const at::Tensor& weight_ih,
     const at::Tensor& weight_hh) {
-  std::vector<float> weight_scales;
-
   at::Tensor weight_ih_scales_tensor =
       int8::utils::get_weight_scale_tensor(weight_ih);
   at::Tensor weight_hh_scales_tensor =
       int8::utils::get_weight_scale_tensor(weight_hh);
   TORCH_CHECK(
-      weight_ih_scales_tensor.sizes() == weight_hh_scales_tensor.sizes(),
-      "scales of weight_ih and weight_hh should be of same size");
+      weight_ih_scales_tensor.sizes() == weight_hh_scales_tensor.sizes() == 1,
+      "Expect scales of LSTM weight_ih and weight_hh to be of size 1");
 
   // PyTorch scale: (max - min) / (qmax - qmin)
   // oneDNN scale: (qmax - qmin) / (max - min)
-  for (size_t i = 0; i < weight_ih_scales_tensor.sizes()[0]; i++) {
-    weight_scales.push_back(
-        1. /
-        std::max(
-            weight_ih_scales_tensor[i].item().toFloat(),
-            weight_hh_scales_tensor[i].item().toFloat()));
-  }
+  auto scale_tensor = 1. /
+      torch::maximum(weight_ih_scales_tensor, weight_hh_scales_tensor)
+          .to(c10::kFloat);
+  scale_tensor = scale_tensor.contiguous();
+  auto s_ptr = scale_tensor.data_ptr<float>();
+  std::vector<float> weight_scales{s_ptr, s_ptr + scale_tensor.size(0)};
   return weight_scales;
 }
 
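For reference, the added block computes the same values as the removed loop: oneDNN expects the reciprocal of the PyTorch quantization scale, taken elementwise over the larger of the weight_ih / weight_hh scales. Below is a minimal standalone sketch of that conversion, assuming libtorch; the scale values and variable names are made up for illustration and are not taken from the commit.

#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical scales as PyTorch observers would report them:
  // pytorch_scale = (max - min) / (qmax - qmin)
  auto ih_scales = torch::tensor({0.02f, 0.03f});
  auto hh_scales = torch::tensor({0.05f, 0.01f});

  // oneDNN scale = (qmax - qmin) / (max - min), i.e. the reciprocal of the
  // PyTorch scale; take the larger (more conservative) PyTorch scale first.
  auto onednn_scales = (1. / torch::maximum(ih_scales, hh_scales))
                           .to(c10::kFloat)
                           .contiguous();

  // Copy the contiguous buffer out into the std::vector oneDNN consumes.
  auto* ptr = onednn_scales.data_ptr<float>();
  std::vector<float> weight_scales(ptr, ptr + onednn_scales.size(0));
  std::cout << weight_scales[0] << std::endl;  // 20 = 1 / max(0.02, 0.05)
  return 0;
}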