@@ -116,7 +116,8 @@ at::Tensor convolution_kernel(
     at::IntArrayRef padding,
     at::IntArrayRef dilation,
     int64_t groups,
-    const ideep::attr_t& attr) {
+    const ideep::attr_t& attr,
+    at::MemoryFormat memory_format) {
   // Base convolution kernel. This kernel will not change the input's format,
   // so make sure the input's format has been handled before calling this
   // function; the output will have the same format as the input.
@@ -132,9 +133,8 @@ at::Tensor convolution_kernel(
 
   at::Tensor output;
   if (input.dim() != 3) {
-    output = at::empty(
-        output_sizes,
-        input.options().memory_format(input.suggest_memory_format()));
+    output =
+        at::empty(output_sizes, input.options().memory_format(memory_format));
   } else {
     // This is a temporary workaround before channels last 1D is formally
     // supported in PyTorch. We will force the output to be nwc.
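With this change `convolution_kernel` no longer infers the output layout from `input.suggest_memory_format()`; the caller passes `memory_format` explicitly. A minimal sketch of how a caller might resolve that value, mirroring the selection rules this commit adds to `convolution_forward_meta` further down (the helper name is hypothetical):

```cpp
#include <ATen/ATen.h>

// Hypothetical helper, mirroring the memory-format selection logic that
// this commit adds to convolution_forward_meta.
at::MemoryFormat resolve_conv_memory_format(
    const at::Tensor& input,
    bool weight_channels_last) {
  bool use_channels_last =
      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d ||
      weight_channels_last;
  if (use_channels_last) {
    if (input.dim() == 4) {
      return at::MemoryFormat::ChannelsLast; // NHWC
    } else if (input.dim() == 5) {
      return at::MemoryFormat::ChannelsLast3d; // NDHWC
    }
  }
  return at::MemoryFormat::Contiguous;
}
```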
@@ -164,7 +164,8 @@ at::Tensor convolution_forward_impl(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
 #if defined(IPEX_DISP_OP)
   printf("torch_ipex::convolution_forward_impl\n");
 #endif
@@ -385,7 +386,8 @@ at::Tensor IPEXConvolutionOp::_forward(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   at::AutoDispatchBelowADInplaceOrView g;
   RECORD_FUNCTION(
       "IPEXConvolutionOp::_forward", c10::ArrayRef<c10::IValue>({}));
@@ -401,7 +403,8 @@ at::Tensor IPEXConvolutionOp::_forward(
       kernel_size,
       padding,
       stride,
-      dilation);
+      dilation,
+      weight_channels_last);
 }
 
 at::Tensor IPEXConvolutionOp::forward(
@@ -413,7 +416,8 @@ at::Tensor IPEXConvolutionOp::forward(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   RECORD_FUNCTION("IPEXConvolutionOp::forward", c10::ArrayRef<c10::IValue>({}));
 
   at::AutoDispatchBelowADInplaceOrView g;
@@ -432,7 +436,8 @@ at::Tensor IPEXConvolutionOp::forward(
       kernel_size,
       padding,
       stride,
-      dilation);
+      dilation,
+      weight_channels_last);
 }
 
 torch::autograd::variable_list IPEXConvolutionOp::backward(
@@ -463,6 +468,7 @@ torch::autograd::variable_list IPEXConvolutionOp::backward(
       at::Tensor(),
       at::Tensor(),
       at::Tensor(),
+      at::Tensor(),
       at::Tensor()};
 }
 
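The extra `at::Tensor()` in `backward` is required because `torch::autograd::Function` must return exactly one gradient slot per `forward` input; the new `weight_channels_last` argument is non-differentiable, so its slot is an undefined tensor. A self-contained toy example of the same pattern (names are illustrative, not part of this change):

```cpp
#include <torch/torch.h>

using torch::autograd::AutogradContext;
using torch::autograd::variable_list;

// Toy op with one tensor input and one non-differentiable bool input.
struct ScaleOp : public torch::autograd::Function<ScaleOp> {
  static at::Tensor forward(
      AutogradContext* ctx,
      const at::Tensor& x,
      bool double_it) {
    ctx->saved_data["double_it"] = double_it;
    return double_it ? x * 2 : x.clone();
  }

  static variable_list backward(
      AutogradContext* ctx,
      variable_list grad_outputs) {
    bool double_it = ctx->saved_data["double_it"].toBool();
    auto gx = double_it ? grad_outputs[0] * 2 : grad_outputs[0];
    // One slot per forward() input: the bool gets an undefined tensor,
    // exactly like the extra at::Tensor() appended in the diff above.
    return {gx, at::Tensor()};
  }
};
```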
@@ -474,7 +480,8 @@ at::Tensor convolution_forward(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   if (at::GradMode::is_enabled()) {
     return IPEXConvolutionOp::apply(
         input,
@@ -484,7 +491,8 @@ at::Tensor convolution_forward(
         kernel_size,
         padding,
         stride,
-        dilation);
+        dilation,
+        weight_channels_last);
   }
   return IPEXConvolutionOp::_forward(
       input,
@@ -494,7 +502,8 @@ at::Tensor convolution_forward(
       kernel_size,
       padding,
       stride,
-      dilation);
+      dilation,
+      weight_channels_last);
 }
 
 at::Tensor convolution_forward_meta(
@@ -505,11 +514,12 @@ at::Tensor convolution_forward_meta(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   TORCH_CHECK(
       kernel_size.has_value() && padding.has_value() && stride.has_value() &&
-          dilation.has_value(),
-      "kernel_size, padding, stride and dilation must have value for convolution_forward_meta");
+          dilation.has_value() && weight_channels_last.has_value(),
+      "kernel_size, padding, stride, dilation and weight_channels_last must have value for convolution_forward_meta");
   auto input_size = input.sym_sizes();
   c10::SymDimVector output_sizes = calc_conv_output_size(
       input_size,
@@ -518,6 +528,28 @@ at::Tensor convolution_forward_meta(
       stride.value(),
       dilation.value());
   auto output = at::empty_symint(output_sizes, input.options());
+
+  bool use_channels_last =
+      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
+      input.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d ||
+      weight_channels_last.value();
+
+  auto memory_format = at::MemoryFormat::Contiguous;
+  if (use_channels_last) {
+    if (input.dim() == 4) {
+      memory_format = at::MemoryFormat::ChannelsLast;
+    } else if (input.dim() == 5) {
+      memory_format = at::MemoryFormat::ChannelsLast3d;
+    }
+  }
+
+  if (!is_channels_last_1d(output)) {
+    output = output.contiguous(memory_format);
+    if (input.dim() == 3) {
+      output = to_channels_last_1d(output);
+    }
+  }
+
   return output;
 }
 
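The meta kernel now reproduces the layout behavior of the real kernel, including the nwc workaround for 3-D inputs: PyTorch has no `at::MemoryFormat::ChannelsLast1d`, hence the `is_channels_last_1d` / `to_channels_last_1d` helpers. A rough sketch of what such a conversion could look like, as an assumption for illustration rather than the helpers' actual implementation:

```cpp
#include <ATen/ATen.h>

// Assumed semantics of the channels-last-1d helpers: for an (N, C, W)
// tensor, "channels last 1d" (nwc) means C is the fastest-moving dim.
bool looks_channels_last_1d(const at::Tensor& t) {
  return t.dim() == 3 && t.stride(1) == 1 &&
      t.stride(2) == t.size(1) &&
      t.stride(0) == t.size(1) * t.size(2);
}

at::Tensor make_channels_last_1d(const at::Tensor& t) {
  // Materialize nwc strides: permute (N, C, W) -> (N, W, C), make that
  // contiguous, then permute the view back to (N, C, W).
  return t.permute({0, 2, 1}).contiguous().permute({0, 2, 1});
}
```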
@@ -535,7 +567,8 @@ at::Tensor convolution_forward(
     c10::optional<at::IntArrayRef> kernel_size,
     c10::optional<at::IntArrayRef> padding,
     c10::optional<at::IntArrayRef> stride,
-    c10::optional<at::IntArrayRef> dilation) {
+    c10::optional<at::IntArrayRef> dilation,
+    c10::optional<bool> weight_channels_last) {
   c10::impl::ExcludeDispatchKeyGuard no_autocastCPU(DispatchKey::AutocastCPU);
   static auto op = torch::Dispatcher::singleton()
                        .findSchemaOrThrow("torch_ipex::convolution_forward", "")
@@ -551,7 +584,8 @@ at::Tensor convolution_forward(
       kernel_size,
       padding,
       stride,
-      dilation);
+      dilation,
+      weight_channels_last);
 }
 
 } // namespace autocast
@@ -562,7 +596,7 @@ namespace {
 TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
   m.def(
       "convolution_forward(Tensor input, Tensor weight, Tensor? bias, "
-      "Tensor W_prepack, int[]? kernel_size, int[]? padding, int[]? stride, int[]? dilation) -> Tensor");
+      "Tensor W_prepack, int[]? kernel_size, int[]? padding, int[]? stride, int[]? dilation, bool? weight_channels_last) -> Tensor");
   m.impl(
       "convolution_forward",
       c10::DispatchKey::Autograd,
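On the schema side, `bool? weight_channels_last` surfaces in C++ as `c10::optional<bool>`. A self-contained sketch of an optional bool flowing through a registered op (namespace, op name, and body here are made up for illustration):

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Toy op: `bool?` in the schema string maps to c10::optional<bool>.
at::Tensor toy_forward(
    const at::Tensor& input,
    c10::optional<bool> weight_channels_last) {
  // Assuming a 4-D input; default to contiguous when the flag is absent.
  auto fmt = weight_channels_last.value_or(false)
      ? at::MemoryFormat::ChannelsLast
      : at::MemoryFormat::Contiguous;
  return input.contiguous(fmt);
}

TORCH_LIBRARY_FRAGMENT(toy_ns, m) {
  m.def("toy_forward(Tensor input, bool? weight_channels_last) -> Tensor");
  m.impl("toy_forward", c10::DispatchKey::CPU, TORCH_FN(toy_forward));
}
```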