Skip to content

Commit dd01078

Browse files
author
ayasyrev
committed
move out _make_ funcs from Net class
1 parent 89ba107 commit dd01078

File tree

2 files changed

+187
-66
lines changed

2 files changed

+187
-66
lines changed

model_constructor/net.py

Lines changed: 78 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,41 @@ def forward(self, x):
6666
o = self.reduce(x)
6767
return self.merge(self.convs(o) + self.idconv(o))
6868

69+
# Cell
70+
def _make_stem(self):
71+
stem = [(f"conv_{i}", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i+1],
72+
stride=2 if i==0 else 1,
73+
bn_layer=(not self.stem_bn_end) if i==(len(self.stem_sizes)-2) else True,
74+
act_fn=self.act_fn, bn_1st=self.bn_1st))
75+
for i in range(len(self.stem_sizes)-1)]
76+
stem.append(('stem_pool', self.stem_pool))
77+
if self.stem_bn_end: stem.append(('norm', self.norm(self.stem_sizes[-1])))
78+
return nn.Sequential(OrderedDict(stem))
79+
80+
# Cell
81+
def _make_layer(self,expansion,ni,nf,blocks,stride,sa):
82+
return nn.Sequential(OrderedDict(
83+
[(f"bl_{i}", self.block(expansion, ni if i==0 else nf, nf,
84+
stride if i==0 else 1, sa=sa if i==blocks-1 else False,
85+
conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,
86+
zero_bn=self.zero_bn, bn_1st=self.bn_1st))
87+
for i in range(blocks)]))
88+
89+
# Cell
90+
def _make_body(self):
91+
blocks = [(f"l_{i}", self._make_layer(self,self.expansion,
92+
self.block_szs[i], self.block_szs[i+1], l,
93+
1 if i==0 else 2, self.sa if i==0 else False))
94+
for i,l in enumerate(self.layers)]
95+
return nn.Sequential(OrderedDict(blocks))
96+
97+
# Cell
98+
def _make_head(self):
99+
head = [('pool', nn.AdaptiveAvgPool2d(1)),
100+
('flat', Flatten()),
101+
('fc', nn.Linear(self.block_szs[-1]*self.expansion, self.c_out))]
102+
return nn.Sequential(OrderedDict(head))
103+
69104
# Cell
70105
# v8
71106
class Net():
@@ -83,53 +118,61 @@ def __init__(self, expansion=1, layers=[2,2,2,2], c_in=3, c_out=1000, name='Net'
83118
self.sa=False
84119
self.bn_1st = True
85120
self.zero_bn=True
86-
self._init_cnn = init_cnn
87121
self.conv_layer = ConvLayer
122+
self._init_cnn = init_cnn
123+
self._make_stem = _make_stem
124+
self._make_layer = _make_layer
125+
self._make_body = _make_body
126+
self._make_head = _make_head
127+
88128

89129
@property
90130
def block_szs(self):
91131
return [64//self.expansion,64,128,256,512] +[256]*(len(self.layers)-4)
92132

93133
@property
94134
def stem(self):
95-
return self._make_stem()
135+
return self._make_stem(self)
96136
@property
97137
def head(self):
98-
return self._make_head()
138+
return self._make_head(self)
139+
# @property
140+
# def _make_layer(self):
141+
# return self.__make_layer(self)
99142
@property
100143
def body(self):
101-
return self._make_body()
102-
103-
def _make_stem(self):
104-
stem = [(f"conv_{i}", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i+1],
105-
stride=2 if i==0 else 1,
106-
bn_layer=(not self.stem_bn_end) if i==(len(self.stem_sizes)-2) else True,
107-
act_fn=self.act_fn, bn_1st=self.bn_1st))
108-
for i in range(len(self.stem_sizes)-1)]
109-
stem.append(('stem_pool', self.stem_pool))
110-
if self.stem_bn_end: stem.append(('norm', self.norm(self.stem_sizes[-1])))
111-
return nn.Sequential(OrderedDict(stem))
112-
113-
def _make_head(self):
114-
head = [('pool', nn.AdaptiveAvgPool2d(1)),
115-
('flat', Flatten()),
116-
('fc', nn.Linear(self.block_szs[-1]*self.expansion, self.c_out))]
117-
return nn.Sequential(OrderedDict(head))
118-
119-
def _make_body(self):
120-
blocks = [(f"l_{i}", self._make_layer(self.expansion,
121-
self.block_szs[i], self.block_szs[i+1], l,
122-
1 if i==0 else 2, self.sa if i==0 else False))
123-
for i,l in enumerate(self.layers)]
124-
return nn.Sequential(OrderedDict(blocks))
125-
126-
def _make_layer(self,expansion,ni,nf,blocks,stride,sa):
127-
return nn.Sequential(OrderedDict(
128-
[(f"bl_{i}", self.block(expansion, ni if i==0 else nf, nf,
129-
stride if i==0 else 1, sa=sa if i==blocks-1 else False,
130-
conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,
131-
zero_bn=self.zero_bn, bn_1st=self.bn_1st))
132-
for i in range(blocks)]))
144+
return self._make_body(self)
145+
146+
# def _make_stem(self):
147+
# stem = [(f"conv_{i}", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i+1],
148+
# stride=2 if i==0 else 1,
149+
# bn_layer=(not self.stem_bn_end) if i==(len(self.stem_sizes)-2) else True,
150+
# act_fn=self.act_fn, bn_1st=self.bn_1st))
151+
# for i in range(len(self.stem_sizes)-1)]
152+
# stem.append(('stem_pool', self.stem_pool))
153+
# if self.stem_bn_end: stem.append(('norm', self.norm(self.stem_sizes[-1])))
154+
# return nn.Sequential(OrderedDict(stem))
155+
156+
# def _make_head(self):
157+
# head = [('pool', nn.AdaptiveAvgPool2d(1)),
158+
# ('flat', Flatten()),
159+
# ('fc', nn.Linear(self.block_szs[-1]*self.expansion, self.c_out))]
160+
# return nn.Sequential(OrderedDict(head))
161+
162+
# def _make_body(self):
163+
# blocks = [(f"l_{i}", self._make_layer(self.expansion,
164+
# self.block_szs[i], self.block_szs[i+1], l,
165+
# 1 if i==0 else 2, self.sa if i==0 else False))
166+
# for i,l in enumerate(self.layers)]
167+
# return nn.Sequential(OrderedDict(blocks))
168+
169+
# def _make_layer(self,expansion,ni,nf,blocks,stride,sa):
170+
# return nn.Sequential(OrderedDict(
171+
# [(f"bl_{i}", self.block(expansion, ni if i==0 else nf, nf,
172+
# stride if i==0 else 1, sa=sa if i==blocks-1 else False,
173+
# conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,
174+
# zero_bn=self.zero_bn, bn_1st=self.bn_1st))
175+
# for i in range(blocks)]))
133176

134177
def __call__(self):
135178
model = nn.Sequential(OrderedDict([

nbs/04_Net.ipynb

Lines changed: 109 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,76 @@
342342
"assert y.shape == torch.Size([bs_test, 512, 16, 16]), f\"size\""
343343
]
344344
},
345+
{
346+
"cell_type": "markdown",
347+
"metadata": {},
348+
"source": [
349+
"# Stem, Body, Head"
350+
]
351+
},
352+
{
353+
"cell_type": "code",
354+
"execution_count": null,
355+
"metadata": {},
356+
"outputs": [],
357+
"source": [
358+
"# export\n",
359+
"def _make_stem(self):\n",
360+
" stem = [(f\"conv_{i}\", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i+1], \n",
361+
" stride=2 if i==0 else 1, \n",
362+
" bn_layer=(not self.stem_bn_end) if i==(len(self.stem_sizes)-2) else True,\n",
363+
" act_fn=self.act_fn, bn_1st=self.bn_1st))\n",
364+
" for i in range(len(self.stem_sizes)-1)]\n",
365+
" stem.append(('stem_pool', self.stem_pool))\n",
366+
" if self.stem_bn_end: stem.append(('norm', self.norm(self.stem_sizes[-1])))\n",
367+
" return nn.Sequential(OrderedDict(stem))"
368+
]
369+
},
370+
{
371+
"cell_type": "code",
372+
"execution_count": null,
373+
"metadata": {},
374+
"outputs": [],
375+
"source": [
376+
"# export\n",
377+
"def _make_layer(self,expansion,ni,nf,blocks,stride,sa):\n",
378+
" return nn.Sequential(OrderedDict(\n",
379+
" [(f\"bl_{i}\", self.block(expansion, ni if i==0 else nf, nf, \n",
380+
" stride if i==0 else 1, sa=sa if i==blocks-1 else False,\n",
381+
" conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,\n",
382+
" zero_bn=self.zero_bn, bn_1st=self.bn_1st))\n",
383+
" for i in range(blocks)]))"
384+
]
385+
},
386+
{
387+
"cell_type": "code",
388+
"execution_count": null,
389+
"metadata": {},
390+
"outputs": [],
391+
"source": [
392+
"# export\n",
393+
"def _make_body(self):\n",
394+
" blocks = [(f\"l_{i}\", self._make_layer(self,self.expansion, \n",
395+
" self.block_szs[i], self.block_szs[i+1], l, \n",
396+
" 1 if i==0 else 2, self.sa if i==0 else False))\n",
397+
" for i,l in enumerate(self.layers)]\n",
398+
" return nn.Sequential(OrderedDict(blocks))"
399+
]
400+
},
401+
{
402+
"cell_type": "code",
403+
"execution_count": null,
404+
"metadata": {},
405+
"outputs": [],
406+
"source": [
407+
"# export\n",
408+
"def _make_head(self):\n",
409+
" head = [('pool', nn.AdaptiveAvgPool2d(1)),\n",
410+
" ('flat', Flatten()),\n",
411+
" ('fc', nn.Linear(self.block_szs[-1]*self.expansion, self.c_out))]\n",
412+
" return nn.Sequential(OrderedDict(head))"
413+
]
414+
},
345415
{
346416
"cell_type": "markdown",
347417
"metadata": {},
@@ -372,53 +442,61 @@
372442
" self.sa=False\n",
373443
" self.bn_1st = True\n",
374444
" self.zero_bn=True\n",
375-
" self._init_cnn = init_cnn\n",
376445
" self.conv_layer = ConvLayer\n",
446+
" self._init_cnn = init_cnn\n",
447+
" self._make_stem = _make_stem\n",
448+
" self._make_layer = _make_layer\n",
449+
" self._make_body = _make_body\n",
450+
" self._make_head = _make_head\n",
451+
" \n",
377452
" \n",
378453
" @property\n",
379454
" def block_szs(self):\n",
380455
" return [64//self.expansion,64,128,256,512] +[256]*(len(self.layers)-4) \n",
381456
"\n",
382457
" @property\n",
383458
" def stem(self):\n",
384-
" return self._make_stem()\n",
459+
" return self._make_stem(self)\n",
385460
" @property\n",
386461
" def head(self):\n",
387-
" return self._make_head()\n",
462+
" return self._make_head(self)\n",
463+
"# @property\n",
464+
"# def _make_layer(self):\n",
465+
"# return self.__make_layer(self)\n",
388466
" @property\n",
389467
" def body(self):\n",
390-
" return self._make_body()\n",
468+
" return self._make_body(self)\n",
391469
" \n",
392-
" def _make_stem(self):\n",
393-
" stem = [(f\"conv_{i}\", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i+1], \n",
394-
" stride=2 if i==0 else 1, \n",
395-
" bn_layer=(not self.stem_bn_end) if i==(len(self.stem_sizes)-2) else True,\n",
396-
" act_fn=self.act_fn, bn_1st=self.bn_1st))\n",
397-
" for i in range(len(self.stem_sizes)-1)]\n",
398-
" stem.append(('stem_pool', self.stem_pool))\n",
399-
" if self.stem_bn_end: stem.append(('norm', self.norm(self.stem_sizes[-1])))\n",
400-
" return nn.Sequential(OrderedDict(stem))\n",
470+
"# def _make_stem(self):\n",
471+
"# stem = [(f\"conv_{i}\", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i+1], \n",
472+
"# stride=2 if i==0 else 1, \n",
473+
"# bn_layer=(not self.stem_bn_end) if i==(len(self.stem_sizes)-2) else True,\n",
474+
"# act_fn=self.act_fn, bn_1st=self.bn_1st))\n",
475+
"# for i in range(len(self.stem_sizes)-1)]\n",
476+
"# stem.append(('stem_pool', self.stem_pool))\n",
477+
"# if self.stem_bn_end: stem.append(('norm', self.norm(self.stem_sizes[-1])))\n",
478+
"# return nn.Sequential(OrderedDict(stem))\n",
401479
" \n",
402-
" def _make_head(self):\n",
403-
" head = [('pool', nn.AdaptiveAvgPool2d(1)),\n",
404-
" ('flat', Flatten()),\n",
405-
" ('fc', nn.Linear(self.block_szs[-1]*self.expansion, self.c_out))]\n",
406-
" return nn.Sequential(OrderedDict(head))\n",
480+
"# def _make_head(self):\n",
481+
"# head = [('pool', nn.AdaptiveAvgPool2d(1)),\n",
482+
"# ('flat', Flatten()),\n",
483+
"# ('fc', nn.Linear(self.block_szs[-1]*self.expansion, self.c_out))]\n",
484+
"# return nn.Sequential(OrderedDict(head))\n",
407485
" \n",
408-
" def _make_body(self):\n",
409-
" blocks = [(f\"l_{i}\", self._make_layer(self.expansion, \n",
410-
" self.block_szs[i], self.block_szs[i+1], l, \n",
411-
" 1 if i==0 else 2, self.sa if i==0 else False))\n",
412-
" for i,l in enumerate(self.layers)]\n",
413-
" return nn.Sequential(OrderedDict(blocks))\n",
486+
"# def _make_body(self):\n",
487+
"# blocks = [(f\"l_{i}\", self._make_layer(self.expansion, \n",
488+
"# self.block_szs[i], self.block_szs[i+1], l, \n",
489+
"# 1 if i==0 else 2, self.sa if i==0 else False))\n",
490+
"# for i,l in enumerate(self.layers)]\n",
491+
"# return nn.Sequential(OrderedDict(blocks))\n",
414492
" \n",
415-
" def _make_layer(self,expansion,ni,nf,blocks,stride,sa):\n",
416-
" return nn.Sequential(OrderedDict(\n",
417-
" [(f\"bl_{i}\", self.block(expansion, ni if i==0 else nf, nf, \n",
418-
" stride if i==0 else 1, sa=sa if i==blocks-1 else False,\n",
419-
" conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,\n",
420-
" zero_bn=self.zero_bn, bn_1st=self.bn_1st))\n",
421-
" for i in range(blocks)]))\n",
493+
"# def _make_layer(self,expansion,ni,nf,blocks,stride,sa):\n",
494+
"# return nn.Sequential(OrderedDict(\n",
495+
"# [(f\"bl_{i}\", self.block(expansion, ni if i==0 else nf, nf, \n",
496+
"# stride if i==0 else 1, sa=sa if i==blocks-1 else False,\n",
497+
"# conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,\n",
498+
"# zero_bn=self.zero_bn, bn_1st=self.bn_1st))\n",
499+
"# for i in range(blocks)]))\n",
422500
" \n",
423501
" def __call__(self):\n",
424502
" model = nn.Sequential(OrderedDict([\n",

0 commit comments

Comments
 (0)