Skip to content

Commit 0fbd5e6

Browse files
committed
fix groups
1 parent a5aa5f9 commit 0fbd5e6

File tree

3 files changed

+47
-80
lines changed

3 files changed

+47
-80
lines changed

Nbs/04_YaResNet.ipynb

Lines changed: 41 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
},
1212
{
1313
"cell_type": "code",
14-
"execution_count": 4,
14+
"execution_count": 5,
1515
"source": [
1616
"#hide\n",
1717
"# from nbdev.showdoc import *\n",
@@ -36,50 +36,17 @@
3636
},
3737
{
3838
"cell_type": "code",
39-
"execution_count": 5,
39+
"execution_count": 1,
4040
"source": [
41-
"# YaResBlock - former NewResBlock.\n",
42-
"# Yet another ResNet.\n",
43-
"class YaResBlock(nn.Module):\n",
44-
" '''YaResBlock. Reduce by pool instead of stride 2'''\n",
45-
" se_block = SEBlock\n",
46-
"\n",
47-
" def __init__(self, expansion, ni, nh, stride=1,\n",
48-
" conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,\n",
49-
" pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, se=False,\n",
50-
" groups=1, dw=False):\n",
51-
" super().__init__()\n",
52-
" nf, ni = nh * expansion, ni * expansion\n",
53-
" if groups != 1:\n",
54-
" groups = int(nh / groups)\n",
55-
" self.reduce = noop if stride == 1 else pool\n",
56-
" layers = [(\"conv_0\", conv_layer(ni, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,\n",
57-
" groups=nh if dw else groups)),\n",
58-
" (\"conv_1\", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))\n",
59-
" ] if expansion == 1 else [\n",
60-
" (\"conv_0\", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),\n",
61-
" (\"conv_1\", conv_layer(nh, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,\n",
62-
" groups=nh if dw else groups)),\n",
63-
" (\"conv_2\", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))\n",
64-
" ]\n",
65-
" if se:\n",
66-
" layers.append(('se', self.se_block(nf)))\n",
67-
" if sa:\n",
68-
" layers.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))\n",
69-
" self.convs = nn.Sequential(OrderedDict(layers))\n",
70-
" self.idconv = noop if ni == nf else conv_layer(ni, nf, 1, act=False)\n",
71-
" self.merge = act_fn\n",
72-
"\n",
73-
" def forward(self, x):\n",
74-
" o = self.reduce(x)\n",
75-
" return self.merge(self.convs(o) + self.idconv(o))"
41+
"#hide\n",
42+
"from model_constructor.yaresnet import YaResBlock"
7643
],
7744
"outputs": [],
7845
"metadata": {}
7946
},
8047
{
8148
"cell_type": "code",
82-
"execution_count": 6,
49+
"execution_count": 2,
8350
"source": [
8451
"#collapse_output\n",
8552
"bl = YaResBlock(1,64,64,sa=True)\n",
@@ -110,14 +77,14 @@
11077
]
11178
},
11279
"metadata": {},
113-
"execution_count": 6
80+
"execution_count": 2
11481
}
11582
],
11683
"metadata": {}
11784
},
11885
{
11986
"cell_type": "code",
120-
"execution_count": 7,
87+
"execution_count": 6,
12188
"source": [
12289
"#hide\n",
12390
"bs_test = 16\n",
@@ -139,7 +106,7 @@
139106
},
140107
{
141108
"cell_type": "code",
142-
"execution_count": 8,
109+
"execution_count": 7,
143110
"source": [
144111
"#collapse_output\n",
145112
"bl = YaResBlock(1,64,64,se=True)\n",
@@ -176,14 +143,14 @@
176143
]
177144
},
178145
"metadata": {},
179-
"execution_count": 8
146+
"execution_count": 7
180147
}
181148
],
182149
"metadata": {}
183150
},
184151
{
185152
"cell_type": "code",
186-
"execution_count": 9,
153+
"execution_count": 8,
187154
"source": [
188155
"#hide\n",
189156
"bs_test = 16\n",
@@ -205,7 +172,7 @@
205172
},
206173
{
207174
"cell_type": "code",
208-
"execution_count": 10,
175+
"execution_count": 9,
209176
"source": [
210177
"#collapse_output\n",
211178
"bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False)\n",
@@ -243,14 +210,14 @@
243210
]
244211
},
245212
"metadata": {},
246-
"execution_count": 10
213+
"execution_count": 9
247214
}
248215
],
249216
"metadata": {}
250217
},
251218
{
252219
"cell_type": "code",
253-
"execution_count": 11,
220+
"execution_count": 10,
254221
"source": [
255222
"#hide\n",
256223
"bs_test = 16\n",
@@ -272,7 +239,7 @@
272239
},
273240
{
274241
"cell_type": "code",
275-
"execution_count": 12,
242+
"execution_count": 11,
276243
"source": [
277244
"#collapse_output\n",
278245
"bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False, groups=4)\n",
@@ -292,7 +259,7 @@
292259
" (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
293260
" )\n",
294261
" (conv_1): ConvLayer(\n",
295-
" (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)\n",
262+
" (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)\n",
296263
" (act_fn): LeakyReLU(negative_slope=0.01)\n",
297264
" (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
298265
" )\n",
@@ -310,14 +277,14 @@
310277
]
311278
},
312279
"metadata": {},
313-
"execution_count": 12
280+
"execution_count": 11
314281
}
315282
],
316283
"metadata": {}
317284
},
318285
{
319286
"cell_type": "code",
320-
"execution_count": 13,
287+
"execution_count": 12,
321288
"source": [
322289
"#hide\n",
323290
"bs_test = 16\n",
@@ -339,7 +306,7 @@
339306
},
340307
{
341308
"cell_type": "code",
342-
"execution_count": 14,
309+
"execution_count": 13,
343310
"source": [
344311
"#collapse_output\n",
345312
"bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False, dw=True)\n",
@@ -377,14 +344,14 @@
377344
]
378345
},
379346
"metadata": {},
380-
"execution_count": 14
347+
"execution_count": 13
381348
}
382349
],
383350
"metadata": {}
384351
},
385352
{
386353
"cell_type": "code",
387-
"execution_count": 15,
354+
"execution_count": 14,
388355
"source": [
389356
"#hide\n",
390357
"bs_test = 16\n",
@@ -413,7 +380,7 @@
413380
},
414381
{
415382
"cell_type": "code",
416-
"execution_count": 16,
383+
"execution_count": 15,
417384
"source": [
418385
"yaresnet = Net(block=YaResBlock, stem_sizes = [3, 32, 64, 64], name='YaResNet')"
419386
],
@@ -422,7 +389,7 @@
422389
},
423390
{
424391
"cell_type": "code",
425-
"execution_count": 17,
392+
"execution_count": 16,
426393
"source": [
427394
"yaresnet"
428395
],
@@ -441,14 +408,14 @@
441408
]
442409
},
443410
"metadata": {},
444-
"execution_count": 17
411+
"execution_count": 16
445412
}
446413
],
447414
"metadata": {}
448415
},
449416
{
450417
"cell_type": "code",
451-
"execution_count": 18,
418+
"execution_count": 17,
452419
"source": [
453420
"yaresnet.block_sizes, yaresnet.layers"
454421
],
@@ -461,14 +428,14 @@
461428
]
462429
},
463430
"metadata": {},
464-
"execution_count": 18
431+
"execution_count": 17
465432
}
466433
],
467434
"metadata": {}
468435
},
469436
{
470437
"cell_type": "code",
471-
"execution_count": 19,
438+
"execution_count": 18,
472439
"source": [
473440
"#collapse_output\n",
474441
"yaresnet.stem"
@@ -499,14 +466,14 @@
499466
]
500467
},
501468
"metadata": {},
502-
"execution_count": 19
469+
"execution_count": 18
503470
}
504471
],
505472
"metadata": {}
506473
},
507474
{
508475
"cell_type": "code",
509-
"execution_count": 20,
476+
"execution_count": 19,
510477
"source": [
511478
"#hide\n",
512479
"bs_test = 16\n",
@@ -536,7 +503,7 @@
536503
},
537504
{
538505
"cell_type": "code",
539-
"execution_count": 21,
506+
"execution_count": 20,
540507
"source": [
541508
"#collapse_output\n",
542509
"yaresnet.body"
@@ -686,14 +653,14 @@
686653
]
687654
},
688655
"metadata": {},
689-
"execution_count": 21
656+
"execution_count": 20
690657
}
691658
],
692659
"metadata": {}
693660
},
694661
{
695662
"cell_type": "code",
696-
"execution_count": 22,
663+
"execution_count": 21,
697664
"source": [
698665
"#collapse_output\n",
699666
"yaresnet.head"
@@ -711,7 +678,7 @@
711678
]
712679
},
713680
"metadata": {},
714-
"execution_count": 22
681+
"execution_count": 21
715682
}
716683
],
717684
"metadata": {}
@@ -725,7 +692,7 @@
725692
},
726693
{
727694
"cell_type": "code",
728-
"execution_count": 24,
695+
"execution_count": 22,
729696
"source": [
730697
"#collapse_output\n",
731698
"yaresnet.act_fn = Mish()\n",
@@ -902,7 +869,7 @@
902869
]
903870
},
904871
"metadata": {},
905-
"execution_count": 24
872+
"execution_count": 22
906873
}
907874
],
908875
"metadata": {}
@@ -923,7 +890,7 @@
923890
},
924891
{
925892
"cell_type": "code",
926-
"execution_count": 25,
893+
"execution_count": 23,
927894
"source": [
928895
"yaresnet_parameters = {'block': YaResBlock, 'stem_sizes': [3, 32, 64, 64], 'act_fn': Mish(), 'stem_stride_on': 1}\n",
929896
"yaresnet34 = partial(Net, name='YaResnet34', expansion=1, layers=[3, 4, 6, 3], **yaresnet_parameters)\n",
@@ -934,7 +901,7 @@
934901
},
935902
{
936903
"cell_type": "code",
937-
"execution_count": 26,
904+
"execution_count": 24,
938905
"source": [
939906
"model = yaresnet50(c_out=10)"
940907
],
@@ -943,7 +910,7 @@
943910
},
944911
{
945912
"cell_type": "code",
946-
"execution_count": 27,
913+
"execution_count": 25,
947914
"source": [
948915
"model"
949916
],
@@ -962,14 +929,14 @@
962929
]
963930
},
964931
"metadata": {},
965-
"execution_count": 27
932+
"execution_count": 25
966933
}
967934
],
968935
"metadata": {}
969936
},
970937
{
971938
"cell_type": "code",
972-
"execution_count": 28,
939+
"execution_count": 26,
973940
"source": [
974941
"model.c_out, model.layers"
975942
],
@@ -982,14 +949,14 @@
982949
]
983950
},
984951
"metadata": {},
985-
"execution_count": 28
952+
"execution_count": 26
986953
}
987954
],
988955
"metadata": {}
989956
},
990957
{
991958
"cell_type": "code",
992-
"execution_count": 29,
959+
"execution_count": 27,
993960
"source": [
994961
"#hide\n",
995962
"bs_test = 16\n",

model_constructor/net.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def __init__(self, expansion, ni, nh, stride=1,
3232
groups=1, dw=False):
3333
super().__init__()
3434
nf, ni = nh * expansion, ni * expansion
35-
if groups != 1:
36-
groups = int(nh / groups)
35+
# if groups != 1:
36+
# groups = int(nh / groups)
3737
if expansion == 1:
3838
layers = [("conv_0", conv_layer(ni, nh, 3, stride=stride, act_fn=act_fn, bn_1st=bn_1st,
3939
groups=nh if dw else groups)),
@@ -71,8 +71,8 @@ def __init__(self, expansion, ni, nh, stride=1,
7171
groups=1, dw=False):
7272
super().__init__()
7373
nf, ni = nh * expansion, ni * expansion
74-
if groups != 1:
75-
groups = int(nh / groups)
74+
# if groups != 1:
75+
# groups = int(nh / groups)
7676
self.reduce = noop if stride == 1 else pool
7777
if expansion == 1:
7878
layers = [("conv_0", conv_layer(ni, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,

0 commit comments

Comments
 (0)