""" Conditional Convolution

Hacked together by Ross Wightman
"""

import math
from functools import partial
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F

from .conv2d_same import get_padding_value, conv2d_same
from .conv_helpers import tup_pair


| 16 | + |
| 17 | +def get_condconv_initializer(initializer, num_experts, expert_shape): |
| 18 | + def condconv_initializer(weight): |
| 19 | + """CondConv initializer function.""" |
| 20 | + num_params = np.prod(expert_shape) |
| 21 | + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or |
| 22 | + weight.shape[1] != num_params): |
| 23 | + raise (ValueError( |
| 24 | + 'CondConv variables must have shape [num_experts, num_params]')) |
| 25 | + for i in range(num_experts): |
| 26 | + initializer(weight[i].view(expert_shape)) |
| 27 | + return condconv_initializer |
| 28 | + |
| 29 | + |
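# Illustrative use of get_condconv_initializer outside of CondConv2d (not part of the
# original module; the shapes below are made up for demonstration). Each row of the
# flattened parameter is initialized as if it were a single expert's kernel:
#
#   num_experts, expert_shape = 4, (8, 4, 3, 3)   # 8 out ch, 4 in ch, 3x3 kernel
#   weight = torch.nn.Parameter(torch.Tensor(num_experts, int(np.prod(expert_shape))))
#   init_fn = get_condconv_initializer(
#       partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), num_experts, expert_shape)
#   init_fn(weight)

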
class CondConv2d(nn.Module):
    """ Conditional Convolution
    Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py

    Grouped convolution hackery for parallel execution of the per-sample kernel filters, inspired by this discussion:
    https://github.com/pytorch/pytorch/issues/17983
    """
    __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding']

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4):
        super(CondConv2d, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = tup_pair(kernel_size)
        self.stride = tup_pair(stride)
        padding_val, is_padding_dynamic = get_padding_value(
            padding, kernel_size, stride=stride, dilation=dilation)
        self.dynamic_padding = is_padding_dynamic  # stored as an attribute so the branch in forward() works with torchscript
        self.padding = tup_pair(padding_val)
        self.dilation = tup_pair(dilation)
        self.groups = groups
        self.num_experts = num_experts

        # each expert's kernel is stored flattened so it can be mixed with a single matmul in forward()
        self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size
        weight_num_param = 1
        for wd in self.weight_shape:
            weight_num_param *= wd
        self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param))

        if bias:
            self.bias_shape = (self.out_channels,)
            self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
    def reset_parameters(self):
        init_weight = get_condconv_initializer(
            partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape)
        init_weight(self.weight)
        if self.bias is not None:
            fan_in = np.prod(self.weight_shape[1:])
            bound = 1 / math.sqrt(fan_in)
            init_bias = get_condconv_initializer(
                partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape)
            init_bias(self.bias)

    def forward(self, x, routing_weights):
        B, C, H, W = x.shape
        # mix the flattened expert kernels per sample: (B, num_experts) x (num_experts, num_params)
        weight = torch.matmul(routing_weights, self.weight)
        new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size
        weight = weight.view(new_weight_shape)
        bias = None
        if self.bias is not None:
            bias = torch.matmul(routing_weights, self.bias)
            bias = bias.view(B * self.out_channels)
        # fold the batch dim into the channel dim so that, with groups multiplied by B,
        # each batch element is convolved with its own mixed kernel in a single conv call
        x = x.view(1, B * C, H, W)
        if self.dynamic_padding:
            out = conv2d_same(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
        else:
            out = F.conv2d(
                x, weight, bias, stride=self.stride, padding=self.padding,
                dilation=self.dilation, groups=self.groups * B)
        out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1])

        # Literal port (from TF definition) using a per-sample loop instead of the grouped conv above
        # x = torch.split(x, 1, 0)
        # weight = torch.split(weight, 1, 0)
        # if self.bias is not None:
        #     bias = torch.matmul(routing_weights, self.bias)
        #     bias = torch.split(bias, 1, 0)
        # else:
        #     bias = [None] * B
        # out = []
        # for xi, wi, bi in zip(x, weight, bias):
        #     wi = wi.view(*self.weight_shape)
        #     if bi is not None:
        #         bi = bi.view(*self.bias_shape)
        #     out.append(self.conv_fn(
        #         xi, wi, bi, stride=self.stride, padding=self.padding,
        #         dilation=self.dilation, groups=self.groups))
        # out = torch.cat(out, 0)
        return out
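

if __name__ == '__main__':
    # Minimal smoke test (not part of the original module). Because of the relative imports
    # above, this assumes the file lives inside a package and is run as a module, e.g.
    # `python -m <package>.<module>`. The routing weights here are a softmax over random
    # logits, standing in for whatever routing network produces a (B, num_experts) tensor
    # in a real model.
    torch.manual_seed(0)
    B, C_in, C_out, E = 2, 8, 16, 4
    m = CondConv2d(C_in, C_out, kernel_size=3, stride=1, num_experts=E, bias=True)
    x = torch.randn(B, C_in, 32, 32)
    routing = torch.softmax(torch.randn(B, E), dim=1)
    out = m(x, routing)
    print(out.shape)  # expected: torch.Size([2, 16, 32, 32])

    # Reference computation: mix the expert kernels per sample and convolve each sample
    # separately; this should match the grouped-convolution trick used in forward().
    ref = []
    for i in range(B):
        w_i = torch.matmul(routing[i], m.weight).view(m.weight_shape)
        b_i = torch.matmul(routing[i], m.bias).view(m.bias_shape)
        ref.append(F.conv2d(
            x[i:i + 1], w_i, b_i, stride=m.stride, padding=m.padding,
            dilation=m.dilation, groups=m.groups))
    ref = torch.cat(ref, 0)
    print(torch.allclose(out, ref, atol=1e-5))  # expected: True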