1- __all__ = ['ConvTwist' , 'ConvLayerTwist' , 'NewResBlockTwist' , 'ResBlockTwist' ]
2-
3- # Cell
4- # from functools import partial
51from collections import OrderedDict
62from .layers import ConvLayer , noop , act_fn , SimpleSelfAttention
73
8- # Cell
94import torch
105import torch .nn as nn
116import torch .nn .functional as F
127import numpy as np
138
14- # Cell
9+
10+ __all__ = ['ConvTwist' , 'ConvLayerTwist' , 'NewResBlockTwist' , 'ResBlockTwist' ]
11+
12+
1513class ConvTwist (nn .Module ):
1614 '''Replacement for Conv2d (kernelsize 3x3)'''
1715 permute = True
@@ -102,12 +100,15 @@ def forward(self, inpt):
102100 def extra_repr (self ):
103101 return f"twist: { self .twist } , permute: { self .permute } , same: { self .same } , groups: { self .groups } "
104102
105- # Cell
103+
106104class ConvLayerTwist (ConvLayer ): # replace Conv2d by Twist
105+ '''Conv layer with ConvTwist'''
107106 Conv2d = ConvTwist
108107
109- # Cell
108+
110109class NewResBlockTwist (nn .Module ):
110+ '''Resnet block with ConvTwist'''
111+
111112 def __init__ (self , expansion , ni , nh , stride = 1 ,
112113 conv_layer = ConvLayer , act_fn = act_fn , bn_1st = True ,
113114 pool = nn .AvgPool2d (2 , ceil_mode = True ), sa = False , sym = False , zero_bn = True , ** kvargs ):
@@ -131,8 +132,10 @@ def forward(self, x):
131132 o = self .reduce (x )
132133 return self .merge (self .convs (o ) + self .idconv (o ))
133134
134- # Cell
135+
135136class ResBlockTwist (nn .Module ):
137+ '''Resnet block with ConvTwist'''
138+
136139 def __init__ (self , expansion , ni , nh , stride = 1 ,
137140 conv_layer = ConvLayer , act_fn = act_fn , zero_bn = True , bn_1st = True ,
138141 pool = nn .AvgPool2d (2 , ceil_mode = True ), sa = False , sym = False , ** kvargs ):
@@ -153,4 +156,4 @@ def __init__(self, expansion, ni, nh, stride=1,
153156 self .act_fn = act_fn
154157
155158 def forward (self , x ):
156- return self .act_fn (self .convs (x ) + self .idconv (self .pool (x )))
159+ return self .act_fn (self .convs (x ) + self .idconv (self .pool (x )))
0 commit comments