@@ -169,7 +169,7 @@ def _check_inner(self, node):
             return False
         self.reduction_depth = 0
         if node.target in [
-            torch.ops.aten.scaled_dot_product_attention,
+            torch.ops.aten.scaled_dot_product_attention.default,
         ]:
             # Attention: input (batch_size, sequence_length, hidden_size)
             # or (batch_size, kv_num_heads, total_sequence_length, head_size)
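Note on the change above: in graphs produced by torch.export or dynamo, node.target of a call_function node is an OpOverload such as aten.scaled_dot_product_attention.default, not the OpOverloadPacket aten.scaled_dot_product_attention, so a membership test against the packet never matches. A minimal sketch (assuming PyTorch 2.x; illustrative, not part of the patch):

import torch

packet = torch.ops.aten.scaled_dot_product_attention            # OpOverloadPacket
overload = torch.ops.aten.scaled_dot_product_attention.default  # what node.target holds
print(overload in [packet])          # False: the old check silently skips this branch
print(overload in [packet.default])  # True: matching on the overload fixes it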
@@ -181,10 +181,10 @@ def _check_inner(self, node):
             )
             self.reduction_depth = hidden_size
         elif node.target in [
-            torch.ops.aten.convolution,
-            torch.ops.aten.conv1d,
-            torch.ops.aten.conv2d,
-            torch.ops.aten.conv3d,
+            torch.ops.aten.convolution.default,
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.conv3d.default,
         ]:
             # Conv: input (N x C x D1 x D2 ... x Dn)
             # weight (out_channels, in_channels, kD1, kD2, ... kDn)
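For context, the conv branch's reduction depth counts the multiply-accumulates behind each output element, i.e. in_channels * kernel_volume. A small illustrative check (shapes are hypothetical, not from the patch):

import torch

conv = torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3)
in_channels = conv.weight.shape[1]                           # 16
kernel_volume = conv.weight.shape[2] * conv.weight.shape[3]  # 3 * 3 = 9
print(in_channels * kernel_volume)                           # 144 MACs per output element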
@@ -201,10 +201,11 @@ def _check_inner(self, node):
             self.reduction_depth = in_channels * kernel_volume
         elif node.target in [
             torch.ops.aten.matmul,
-            torch.ops.aten.dot,
-            torch.ops.aten.mm,
-            torch.ops.aten.mv,
-            torch.ops.aten.bmm,
+            torch.ops.aten.matmul.default,
+            torch.ops.aten.dot.default,
+            torch.ops.aten.mm.default,
+            torch.ops.aten.mv.default,
+            torch.ops.aten.bmm.default,
         ]:
             # GEMM: A (M, K) @ B (K, N) = C (M, N)
             self.reduction_depth = input_0_dims[-1]
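For the GEMM branch, the reduction depth is the shared inner dimension K, read from the last dimension of the first input. A tiny illustrative check (shapes are hypothetical):

import torch

a = torch.randn(8, 64)   # (M, K)
b = torch.randn(64, 32)  # (K, N)
c = a @ b                # each of the M * N outputs reduces over K = 64 products
print(a.shape[-1], c.shape)  # 64 torch.Size([8, 32])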