File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed
intel_extension_for_pytorch/quantization Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -231,6 +231,10 @@ def _maybe_observe(arg, tensor_info):
231231 # TODO: do not run this twice on input and output
232232 if str (tensor_id ) in self .tensor_id_to_observer :
233233 observer = self .tensor_id_to_observer [str (tensor_id )]
234+ if isinstance (arg , torch .Tensor ) and arg .dtype != torch .float32 :
235+ dtype = arg .dtype
236+ out = observer (arg .float ())
237+ return out .to (dtype )
234238 return observer (arg )
235239 else :
236240 return arg
@@ -290,7 +294,7 @@ def _observer_output(output, tensor_info):
290294 tensor_id = tensor_info .id
291295 if str (tensor_id ) in self .tensor_id_to_observer :
292296 obs = self .tensor_id_to_observer [str (tensor_id )]
293- obs (output )
297+ obs (output . float () )
294298 if isinstance (outputs , torch .Tensor ):
295299 tensor_info = seen_q_op_info .output_tensor_infos [0 ]
296300 _observer_output (outputs , tensor_info )
You can’t perform that action at this time.
0 commit comments