@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import copy
 import logging
 
@@ -72,10 +73,13 @@ def __init__(self, dense_module):
         self.padding = dense_module.padding
         self.dilation = dense_module.dilation
         self.groups = dense_module.groups
+        self.padding_mode = dense_module.padding_mode
+        self._reversed_padding_repeated_twice = dense_module._reversed_padding_repeated_twice
         self.prepack_input_shape = dense_module.input_shape if hasattr(dense_module, "input_shape") else []
         self.weight_channels_last = dense_module.weight.is_contiguous(memory_format=torch.channels_last) \
             or dense_module.weight.is_contiguous(memory_format=torch.channels_last_3d)
         self.weight_size = dense_module.weight.size()
+        self._real_padding = self.padding if self.padding_mode == 'zeros' else tuple([0] * (len(self.weight_size) - 2))
 
         # TODO: ".clone()" will make weight shared by multiple module not shared anymore
         # related issues: https://github.com/intel-innersource/frameworks.ai.pytorch.ipex-cpu/issues/65
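
For context (an illustrative note, not part of the diff): _reversed_padding_repeated_twice is the private attribute that recent torch.nn.modules.conv._ConvNd versions keep for non-zero padding modes, holding the per-side padding amounts in the order F.pad expects. The sketch below, under that assumption, shows the two values the constructor picks up for a 2-D convolution and why _real_padding collapses to zeros once the padding is going to be applied explicitly:

    import torch.nn as nn

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=(1, 2), padding_mode='reflect')
    # padding (1, 2) reversed and repeated per side -> (2, 2, 1, 1), i.e. (left, right, top, bottom) for F.pad
    print(conv._reversed_padding_repeated_twice)
    # the prepacked kernel then runs with zero "real" padding, since F.pad has already been applied
    real_padding = conv.padding if conv.padding_mode == 'zeros' else tuple([0] * (conv.weight.dim() - 2))
    print(real_padding)   # (0, 0)
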
@@ -91,7 +95,7 @@ def __init__(self, dense_module):
             self.register_parameter('bias', None)
         # create conv op context
         self.ctx = torch.ops.ipex_prepack.convolution_prepack(
-            dense_module.weight, self.bias, self.stride, self.padding,
+            dense_module.weight, self.bias, self.stride, self._real_padding,
             self.dilation, self.groups,
             self.weight_channels_last, self.prepack_input_shape
         )
@@ -117,14 +121,32 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
         with torch.no_grad():
             loaded_weight, loaded_bias, fp32_loaded_weight, weight_trail = _load_from_state_dict_pre_hook(self, state_dict, prefix)
             loaded_ctx = torch.ops.ipex_prepack.convolution_prepack(
-                loaded_weight, loaded_bias, self.stride, self.padding,
+                loaded_weight, loaded_bias, self.stride, self._real_padding,
                 self.dilation, self.groups,
                 self.weight_channels_last, self.prepack_input_shape
             )
             _load_from_state_dict_post_hook(self, loaded_ctx, fp32_loaded_weight, weight_trail)
 
     def forward(self, x):
-        return torch.ops.torch_ipex.convolution_forward(x, self.weight, self.bias, self.ctx.get_data_handle(), self.weight_size, self.padding, self.stride, self.dilation)
+        if self.padding_mode != 'zeros':
+            return torch.ops.torch_ipex.convolution_forward(
+                F.pad(x, self._reversed_padding_repeated_twice, mode=self.padding_mode),
+                self.weight,
+                self.bias,
+                self.ctx.get_data_handle(),
+                self.weight_size,
+                self._real_padding,
+                self.stride,
+                self.dilation)
+        return torch.ops.torch_ipex.convolution_forward(
+            x,
+            self.weight,
+            self.bias,
+            self.ctx.get_data_handle(),
+            self.weight_size,
+            self._real_padding,
+            self.stride,
+            self.dilation)
 
 class _IPEXConv1d(_IPEXConvNd):
     def __init__(self, dense_module):
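
A sanity check of the rewritten forward path (an illustrative sketch, not part of the diff; it uses the stock F.conv2d rather than the torch_ipex kernel): convolving with a non-zero padding_mode is equivalent to padding the input explicitly with F.pad and then convolving with zero padding, which is the same rewrite torch.nn.Conv2d applies internally.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode='circular')
    x = torch.randn(1, 3, 16, 16)

    # pad explicitly with the requested mode, then convolve with zero "real" padding
    padded = F.pad(x, conv._reversed_padding_repeated_twice, mode=conv.padding_mode)
    y_manual = F.conv2d(padded, conv.weight, conv.bias, conv.stride,
                        padding=0, dilation=conv.dilation, groups=conv.groups)

    assert torch.allclose(y_manual, conv(x), atol=1e-6)
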
@@ -457,10 +479,13 @@ def record_input_shape_for_prepack(module, sample_input):
 
     def hook_function(self, input):
         # input for linear/conv/transpose conv received here will be Tuple[Tensor]
-        self.input_shape = input[0].shape
+        if type(self) in [torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d] and self.padding_mode != 'zeros':
+            self.input_shape = F.pad(input[0], self._reversed_padding_repeated_twice, mode=self.padding_mode).shape
+        else:
+            self.input_shape = input[0].shape
 
     def register_hook_function(module):
-        if type(module) in [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.ConvTranspose2d]:
+        if type(module) in [torch.nn.Linear, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose2d]:
             module.register_forward_pre_hook(hook_function)
 
     def register_hook_function_rec(module):
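
Finally, a minimal standalone illustration of what the pre-hook records (hypothetical name record_shape_hook; the real code registers hook_function above): for a convolution with a non-zero padding_mode the prepack path later sees the already-padded tensor, so the hook stores the padded shape rather than the raw input shape.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def record_shape_hook(module, inputs):
        # forward pre-hooks receive the positional inputs as a tuple
        x = inputs[0]
        if type(module) in [nn.Conv1d, nn.Conv2d, nn.Conv3d] and module.padding_mode != 'zeros':
            x = F.pad(x, module._reversed_padding_repeated_twice, mode=module.padding_mode)
        module.input_shape = x.shape

    conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode='replicate')
    conv.register_forward_pre_hook(record_shape_hook)
    conv(torch.randn(1, 3, 16, 16))
    print(conv.input_shape)   # torch.Size([1, 3, 18, 18]) -- the padded shape, not the raw 16x16 input
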