Skip to content

Commit 22ddb7d

Browse files
committed
Update base for Update on "Reduce allocation overhead in quantized sdpa"
For small models, dequantizing portions of the v cache causes extra allocation overhead. Probably a better way to handle this is to dequantize the entire v cache outside the model. There isn't a significant perf advantage from this yet, but subsequent diffs will use a caching allocator, where this refactor helps. Differential Revision: [D85532077](https://our.internmc.facebook.com/intern/diff/D85532077/) [ghstack-poisoned]
2 parents 6dbcca4 + 8af8252 commit 22ddb7d

File tree

12 files changed

+797
-33
lines changed

12 files changed

+797
-33
lines changed

backends/cadence/aot/quantizer/quantizer.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,17 @@ def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
372372
# Add 16-bit quantizers for LinearPattern
373373
quantizers.append(CadenceAtenQuantizer(LinearPattern(), qconfig_A16))
374374
super().__init__(quantizers)
375+
376+
377+
class CadenceWith16BitConvActivationsQuantizer(CadenceQuantizer):
378+
"""
379+
Quantizer including A16 conv
380+
"""
381+
382+
def __init__(self, quantizers: Optional[list[Quantizer]] = None) -> None:
383+
if quantizers is None:
384+
quantizers = []
385+
# Add 16-bit quantizers for Conv patterns
386+
quantizers.append(CadenceAtenQuantizer(Conv1dPattern(), qconfig_A16))
387+
quantizers.append(CadenceAtenQuantizer(Conv2dPattern(), qconfig_A16))
388+
super().__init__(quantizers)

backends/cadence/hifi/operators/op_quantized_conv2d_nchw_out.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
1010
#include <executorch/backends/cadence/hifi/operators/operators.h>
1111
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/op_quantized_conv2d.h>
1213

1314
#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1)))
1415

@@ -532,6 +533,30 @@ void quantized_conv2d_nchw_out(
532533
__ET_UNUSED const Tensor& out_multiplier,
533534
__ET_UNUSED const Tensor& out_shift,
534535
Tensor& out) {
536+
// Handle W8A16 heterogeneous type (int16_t activations, int8_t weights)
537+
if (out.scalar_type() == ::executorch::aten::ScalarType::Short &&
538+
input.scalar_type() == ::executorch::aten::ScalarType::Short &&
539+
weight.scalar_type() == ::executorch::aten::ScalarType::Char) {
540+
::impl::generic::native::quantized_conv2d_nchw_out(
541+
ctx,
542+
input,
543+
weight,
544+
bias,
545+
stride,
546+
padding,
547+
dilation,
548+
groups,
549+
in_zero_point,
550+
weight_zero_point,
551+
bias_scale,
552+
output_scale,
553+
output_zero_point,
554+
out_multiplier,
555+
out_shift,
556+
out);
557+
return;
558+
}
559+
535560
const float bias_scale_float = bias_scale.const_data_ptr<float>()[0];
536561
const int32_t weight_zero_point_int =
537562
weight_zero_point.const_data_ptr<int32_t>()[0];
@@ -596,6 +621,30 @@ void quantized_conv2d_nchw_per_tensor_out(
596621
__ET_UNUSED int64_t out_multiplier,
597622
__ET_UNUSED int64_t out_shift,
598623
Tensor& out) {
624+
// Handle W8A16 heterogeneous type (int16_t activations, int8_t weights)
625+
if (out.scalar_type() == ::executorch::aten::ScalarType::Short &&
626+
input.scalar_type() == ::executorch::aten::ScalarType::Short &&
627+
weight.scalar_type() == ::executorch::aten::ScalarType::Char) {
628+
::impl::generic::native::quantized_conv2d_nchw_per_tensor_out(
629+
ctx,
630+
input,
631+
weight,
632+
bias,
633+
stride,
634+
padding,
635+
dilation,
636+
groups,
637+
in_zero_point,
638+
weight_zero_point,
639+
bias_scale,
640+
output_scale,
641+
output_zero_point,
642+
out_multiplier,
643+
out_shift,
644+
out);
645+
return;
646+
}
647+
599648
bool optimized = 0;
600649

601650
if ((input.scalar_type() == ScalarType::Char) ||

backends/cadence/hifi/operators/op_quantized_conv2d_nhwc_out.cpp

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
1010
#include <executorch/backends/cadence/hifi/operators/operators.h>
1111
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/op_quantized_conv2d.h>
1213

1314
#define ALIGN_PTR(x, bytes) ((((unsigned)(x)) + (bytes - 1)) & (~(bytes - 1)))
1415

@@ -435,9 +436,32 @@ void quantized_conv2d_nhwc_out(
435436
const Tensor& bias_scale,
436437
double output_scale,
437438
int64_t output_zero_point,
438-
__ET_UNUSED const Tensor& out_multiplier,
439-
__ET_UNUSED const Tensor& out_shift,
439+
const Tensor& out_multiplier,
440+
const Tensor& out_shift,
440441
Tensor& out) {
442+
// Handle W8A16 heterogeneous type (int16_t activations, int8_t weights)
443+
if (out.scalar_type() == ::executorch::aten::ScalarType::Short &&
444+
input.scalar_type() == ::executorch::aten::ScalarType::Short &&
445+
weight.scalar_type() == ::executorch::aten::ScalarType::Char) {
446+
::impl::generic::native::quantized_conv2d_nhwc_out(
447+
ctx,
448+
input,
449+
weight,
450+
bias,
451+
stride,
452+
padding,
453+
dilation,
454+
groups,
455+
in_zero_point,
456+
weight_zero_point,
457+
bias_scale,
458+
output_scale,
459+
output_zero_point,
460+
out_multiplier,
461+
out_shift,
462+
out);
463+
return;
464+
}
441465
const float bias_scale_float = bias_scale.const_data_ptr<float>()[0];
442466
const int32_t weight_zero_point_int =
443467
weight_zero_point.const_data_ptr<int32_t>()[0];
@@ -502,8 +526,31 @@ void quantized_conv2d_nhwc_per_tensor_out(
502526
__ET_UNUSED int64_t out_multiplier,
503527
__ET_UNUSED int64_t out_shift,
504528
Tensor& out) {
505-
bool optimized = 0;
529+
// Handle W8A16 heterogeneous type (int16_t activations, int8_t weights)
530+
if (out.scalar_type() == ::executorch::aten::ScalarType::Short &&
531+
input.scalar_type() == ::executorch::aten::ScalarType::Short &&
532+
weight.scalar_type() == ::executorch::aten::ScalarType::Char) {
533+
::impl::generic::native::quantized_conv2d_nhwc_per_tensor_out(
534+
ctx,
535+
input,
536+
weight,
537+
bias,
538+
stride,
539+
padding,
540+
dilation,
541+
groups,
542+
in_zero_point,
543+
weight_zero_point,
544+
bias_scale,
545+
output_scale,
546+
output_zero_point,
547+
out_multiplier,
548+
out_shift,
549+
out);
550+
return;
551+
}
506552

553+
bool optimized = 0;
507554
if ((input.scalar_type() == ScalarType::Char) ||
508555
(input.scalar_type() == ScalarType::Byte))
509556
optimized = 1;

backends/cadence/hifi/operators/op_quantized_linear_out.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
1010
#include <executorch/backends/cadence/hifi/operators/operators.h>
1111
#include <executorch/runtime/kernel/kernel_includes.h>
12+
#include <on_device_ai/Assistant/Jarvis/min_runtime/operators/generic/op_quantized_linear.h>
1213
#include <xa_nnlib_kernels_api.h>
1314
#include <xtensa/tie/xt_datacache.h>
1415
#include <algorithm>
@@ -218,7 +219,22 @@ void quantized_linear_out(
218219
int64_t out_zero_point,
219220
__ET_UNUSED const optional<Tensor>& offset,
220221
Tensor& out) {
221-
if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
222+
if (out.scalar_type() == ::executorch::aten::ScalarType::Short &&
223+
in.scalar_type() == ::executorch::aten::ScalarType::Short &&
224+
weight.scalar_type() == ::executorch::aten::ScalarType::Char) {
225+
::impl::generic::native::quantized_linear_out(
226+
ctx,
227+
in,
228+
weight,
229+
bias,
230+
in_zero_point,
231+
weight_zero_point,
232+
out_multiplier,
233+
out_shift,
234+
out_zero_point,
235+
offset,
236+
out);
237+
} else if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
222238
_quantized_linear_asym8u(
223239
in,
224240
weight,
@@ -260,7 +276,22 @@ void quantized_linear_per_tensor_out(
260276
int64_t out_zero_point,
261277
__ET_UNUSED const optional<Tensor>& offset,
262278
Tensor& out) {
263-
if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
279+
if (out.scalar_type() == ::executorch::aten::ScalarType::Short &&
280+
in.scalar_type() == ::executorch::aten::ScalarType::Short &&
281+
weight.scalar_type() == ::executorch::aten::ScalarType::Char) {
282+
::impl::generic::native::quantized_linear_per_tensor_out(
283+
ctx,
284+
in,
285+
weight,
286+
bias,
287+
in_zero_point,
288+
weight_zero_point,
289+
out_multiplier,
290+
out_shift,
291+
out_zero_point,
292+
offset,
293+
out);
294+
} else if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
264295
_quantized_linear_per_tensor_asym8u(
265296
in,
266297
weight,

backends/cadence/hifi/operators/targets.bzl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ OPERATORS = [
6565
"ne",
6666
"permute_copy",
6767
"pow",
68-
"quantized_conv2d_nchw_out",
6968
"quantized_conv2d_nchw_asym8sxsym8s_asym8s_per_tensor_out",
7069
"quantized_conv2d_nchw_asym8uxsym8u_asym8u_per_tensor_out",
7170
"quantized_conv1d_ncl_asym8sxsym8s_asym8s_per_tensor_out",
@@ -74,7 +73,6 @@ OPERATORS = [
7473
"quantized_conv2d_nchw_depthwise_asym8uxsym8u_asym8u_per_tensor_out",
7574
"quantized_conv2d_nchw_dilated_asym8sxsym8s_asym8s_per_tensor_out",
7675
"quantized_conv2d_nchw_dilated_asym8uxsym8u_asym8u_per_tensor_out",
77-
"quantized_conv2d_nhwc_out",
7876
"quantized_conv2d_nhwc_asym8sxsym8s_asym8s_per_tensor_out",
7977
"quantized_conv2d_nhwc_asym8uxsym8u_asym8u_per_tensor_out",
8078
"quantized_conv1d_nlc_asym8sxsym8s_asym8s_per_tensor_out",
@@ -87,7 +85,6 @@ OPERATORS = [
8785
"quantized_fully_connected_asym8sxasym8s_asym8s_per_tensor_out",
8886
"quantized_fully_connected_asym8uxasym8u_asym8u_per_tensor_out",
8987
"quantized_layer_norm",
90-
"quantized_linear_out",
9188
"quantized_linear_asym8sxasym8s_asym8s_per_tensor_out",
9289
"quantized_linear_asym8uxasym8u_asym8u_per_tensor_out",
9390
"quantized_matmul_out",
@@ -122,3 +119,11 @@ def define_common_targets():
122119
# Define build targets for all operators registered in the tables above.
123120
for op in OPERATORS:
124121
define_operator(op)
122+
123+
# quantized_linear_out and quantized_linear_per_tensor_out needs additional dependency for int16 support
124+
define_operator("quantized_linear_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators/generic:op_quantized_linear"])
125+
define_operator("quantized_linear_per_tensor_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators/generic:op_quantized_linear"])
126+
127+
# quantized_conv2d_nchw_out and quantized_conv2d_nhwc_out need additional dependency for int16 support
128+
define_operator("quantized_conv2d_nchw_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators/generic:op_quantized_conv2d"])
129+
define_operator("quantized_conv2d_nhwc_out", deps=["fbcode//on_device_ai/Assistant/Jarvis/min_runtime/operators/generic:op_quantized_conv2d"])

0 commit comments

Comments
 (0)