Skip to content

Commit f1f43a2

Browse files
committed
fix twist
1 parent 47f7d07 commit f1f43a2

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

model_constructor/twist.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
1-
__all__ = ['ConvTwist', 'ConvLayerTwist', 'NewResBlockTwist', 'ResBlockTwist']
2-
3-
# Cell
4-
# from functools import partial
51
from collections import OrderedDict
62
from .layers import ConvLayer, noop, act_fn, SimpleSelfAttention
73

8-
# Cell
94
import torch
105
import torch.nn as nn
116
import torch.nn.functional as F
127
import numpy as np
138

14-
# Cell
9+
10+
__all__ = ['ConvTwist', 'ConvLayerTwist', 'NewResBlockTwist', 'ResBlockTwist']
11+
12+
1513
class ConvTwist(nn.Module):
1614
'''Replacement for Conv2d (kernelsize 3x3)'''
1715
permute = True
@@ -102,12 +100,15 @@ def forward(self, inpt):
102100
def extra_repr(self):
103101
return f"twist: {self.twist}, permute: {self.permute}, same: {self.same}, groups: {self.groups}"
104102

105-
# Cell
103+
106104
class ConvLayerTwist(ConvLayer): # replace Conv2d by Twist
105+
'''Conv layer with ConvTwist'''
107106
Conv2d = ConvTwist
108107

109-
# Cell
108+
110109
class NewResBlockTwist(nn.Module):
110+
'''Resnet block with ConvTwist'''
111+
111112
def __init__(self, expansion, ni, nh, stride=1,
112113
conv_layer=ConvLayer, act_fn=act_fn, bn_1st=True,
113114
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, zero_bn=True, **kvargs):
@@ -131,8 +132,10 @@ def forward(self, x):
131132
o = self.reduce(x)
132133
return self.merge(self.convs(o) + self.idconv(o))
133134

134-
# Cell
135+
135136
class ResBlockTwist(nn.Module):
137+
'''Resnet block with ConvTwist'''
138+
136139
def __init__(self, expansion, ni, nh, stride=1,
137140
conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,
138141
pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, **kvargs):
@@ -153,4 +156,4 @@ def __init__(self, expansion, ni, nh, stride=1,
153156
self.act_fn = act_fn
154157

155158
def forward(self, x):
156-
return self.act_fn(self.convs(x) + self.idconv(self.pool(x)))
159+
return self.act_fn(self.convs(x) + self.idconv(self.pool(x)))

0 commit comments

Comments
 (0)