|
11 | 11 | }, |
12 | 12 | { |
13 | 13 | "cell_type": "code", |
14 | | - "execution_count": 4, |
| 14 | + "execution_count": 5, |
15 | 15 | "source": [ |
16 | 16 | "#hide\n", |
17 | 17 | "# from nbdev.showdoc import *\n", |
|
36 | 36 | }, |
37 | 37 | { |
38 | 38 | "cell_type": "code", |
39 | | - "execution_count": 5, |
| 39 | + "execution_count": 1, |
40 | 40 | "source": [ |
41 | | - "# YaResBlock - former NewResBlock.\n", |
42 | | - "# Yet another ResNet.\n", |
43 | | - "class YaResBlock(nn.Module):\n", |
44 | | - " '''YaResBlock. Reduce by pool instead of stride 2'''\n", |
45 | | - " se_block = SEBlock\n", |
46 | | - "\n", |
47 | | - " def __init__(self, expansion, ni, nh, stride=1,\n", |
48 | | - " conv_layer=ConvLayer, act_fn=act_fn, zero_bn=True, bn_1st=True,\n", |
49 | | - " pool=nn.AvgPool2d(2, ceil_mode=True), sa=False, sym=False, se=False,\n", |
50 | | - " groups=1, dw=False):\n", |
51 | | - " super().__init__()\n", |
52 | | - " nf, ni = nh * expansion, ni * expansion\n", |
53 | | - " if groups != 1:\n", |
54 | | - " groups = int(nh / groups)\n", |
55 | | - " self.reduce = noop if stride == 1 else pool\n", |
56 | | - " layers = [(\"conv_0\", conv_layer(ni, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,\n", |
57 | | - " groups=nh if dw else groups)),\n", |
58 | | - " (\"conv_1\", conv_layer(nh, nf, 3, zero_bn=zero_bn, act=False, bn_1st=bn_1st))\n", |
59 | | - " ] if expansion == 1 else [\n", |
60 | | - " (\"conv_0\", conv_layer(ni, nh, 1, act_fn=act_fn, bn_1st=bn_1st)),\n", |
61 | | - " (\"conv_1\", conv_layer(nh, nh, 3, stride=1, act_fn=act_fn, bn_1st=bn_1st,\n", |
62 | | - " groups=nh if dw else groups)),\n", |
63 | | - " (\"conv_2\", conv_layer(nh, nf, 1, zero_bn=zero_bn, act=False, bn_1st=bn_1st))\n", |
64 | | - " ]\n", |
65 | | - " if se:\n", |
66 | | - " layers.append(('se', self.se_block(nf)))\n", |
67 | | - " if sa:\n", |
68 | | - " layers.append(('sa', SimpleSelfAttention(nf, ks=1, sym=sym)))\n", |
69 | | - " self.convs = nn.Sequential(OrderedDict(layers))\n", |
70 | | - " self.idconv = noop if ni == nf else conv_layer(ni, nf, 1, act=False)\n", |
71 | | - " self.merge = act_fn\n", |
72 | | - "\n", |
73 | | - " def forward(self, x):\n", |
74 | | - " o = self.reduce(x)\n", |
75 | | - " return self.merge(self.convs(o) + self.idconv(o))" |
| 41 | + "#hide\n", |
| 42 | + "from model_constructor.yaresnet import YaResBlock" |
76 | 43 | ], |
77 | 44 | "outputs": [], |
78 | 45 | "metadata": {} |
79 | 46 | }, |
80 | 47 | { |
81 | 48 | "cell_type": "code", |
82 | | - "execution_count": 6, |
| 49 | + "execution_count": 2, |
83 | 50 | "source": [ |
84 | 51 | "#collapse_output\n", |
85 | 52 | "bl = YaResBlock(1,64,64,sa=True)\n", |
|
110 | 77 | ] |
111 | 78 | }, |
112 | 79 | "metadata": {}, |
113 | | - "execution_count": 6 |
| 80 | + "execution_count": 2 |
114 | 81 | } |
115 | 82 | ], |
116 | 83 | "metadata": {} |
117 | 84 | }, |
118 | 85 | { |
119 | 86 | "cell_type": "code", |
120 | | - "execution_count": 7, |
| 87 | + "execution_count": 6, |
121 | 88 | "source": [ |
122 | 89 | "#hide\n", |
123 | 90 | "bs_test = 16\n", |
|
139 | 106 | }, |
140 | 107 | { |
141 | 108 | "cell_type": "code", |
142 | | - "execution_count": 8, |
| 109 | + "execution_count": 7, |
143 | 110 | "source": [ |
144 | 111 | "#collapse_output\n", |
145 | 112 | "bl = YaResBlock(1,64,64,se=True)\n", |
|
176 | 143 | ] |
177 | 144 | }, |
178 | 145 | "metadata": {}, |
179 | | - "execution_count": 8 |
| 146 | + "execution_count": 7 |
180 | 147 | } |
181 | 148 | ], |
182 | 149 | "metadata": {} |
183 | 150 | }, |
184 | 151 | { |
185 | 152 | "cell_type": "code", |
186 | | - "execution_count": 9, |
| 153 | + "execution_count": 8, |
187 | 154 | "source": [ |
188 | 155 | "#hide\n", |
189 | 156 | "bs_test = 16\n", |
|
205 | 172 | }, |
206 | 173 | { |
207 | 174 | "cell_type": "code", |
208 | | - "execution_count": 10, |
| 175 | + "execution_count": 9, |
209 | 176 | "source": [ |
210 | 177 | "#collapse_output\n", |
211 | 178 | "bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False)\n", |
|
243 | 210 | ] |
244 | 211 | }, |
245 | 212 | "metadata": {}, |
246 | | - "execution_count": 10 |
| 213 | + "execution_count": 9 |
247 | 214 | } |
248 | 215 | ], |
249 | 216 | "metadata": {} |
250 | 217 | }, |
251 | 218 | { |
252 | 219 | "cell_type": "code", |
253 | | - "execution_count": 11, |
| 220 | + "execution_count": 10, |
254 | 221 | "source": [ |
255 | 222 | "#hide\n", |
256 | 223 | "bs_test = 16\n", |
|
272 | 239 | }, |
273 | 240 | { |
274 | 241 | "cell_type": "code", |
275 | | - "execution_count": 12, |
| 242 | + "execution_count": 11, |
276 | 243 | "source": [ |
277 | 244 | "#collapse_output\n", |
278 | 245 | "bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False, groups=4)\n", |
|
292 | 259 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
293 | 260 | " )\n", |
294 | 261 | " (conv_1): ConvLayer(\n", |
295 | | - " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)\n", |
| 262 | + " (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=4, bias=False)\n", |
296 | 263 | " (act_fn): LeakyReLU(negative_slope=0.01)\n", |
297 | 264 | " (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
298 | 265 | " )\n", |
|
310 | 277 | ] |
311 | 278 | }, |
312 | 279 | "metadata": {}, |
313 | | - "execution_count": 12 |
| 280 | + "execution_count": 11 |
314 | 281 | } |
315 | 282 | ], |
316 | 283 | "metadata": {} |
317 | 284 | }, |
318 | 285 | { |
319 | 286 | "cell_type": "code", |
320 | | - "execution_count": 13, |
| 287 | + "execution_count": 12, |
321 | 288 | "source": [ |
322 | 289 | "#hide\n", |
323 | 290 | "bs_test = 16\n", |
|
339 | 306 | }, |
340 | 307 | { |
341 | 308 | "cell_type": "code", |
342 | | - "execution_count": 14, |
| 309 | + "execution_count": 13, |
343 | 310 | "source": [ |
344 | 311 | "#collapse_output\n", |
345 | 312 | "bl = YaResBlock(4,64,128,stride=2,act_fn=nn.LeakyReLU(), bn_1st=False, dw=True)\n", |
|
377 | 344 | ] |
378 | 345 | }, |
379 | 346 | "metadata": {}, |
380 | | - "execution_count": 14 |
| 347 | + "execution_count": 13 |
381 | 348 | } |
382 | 349 | ], |
383 | 350 | "metadata": {} |
384 | 351 | }, |
385 | 352 | { |
386 | 353 | "cell_type": "code", |
387 | | - "execution_count": 15, |
| 354 | + "execution_count": 14, |
388 | 355 | "source": [ |
389 | 356 | "#hide\n", |
390 | 357 | "bs_test = 16\n", |
|
413 | 380 | }, |
414 | 381 | { |
415 | 382 | "cell_type": "code", |
416 | | - "execution_count": 16, |
| 383 | + "execution_count": 15, |
417 | 384 | "source": [ |
418 | 385 | "yaresnet = Net(block=YaResBlock, stem_sizes = [3, 32, 64, 64], name='YaResNet')" |
419 | 386 | ], |
|
422 | 389 | }, |
423 | 390 | { |
424 | 391 | "cell_type": "code", |
425 | | - "execution_count": 17, |
| 392 | + "execution_count": 16, |
426 | 393 | "source": [ |
427 | 394 | "yaresnet" |
428 | 395 | ], |
|
441 | 408 | ] |
442 | 409 | }, |
443 | 410 | "metadata": {}, |
444 | | - "execution_count": 17 |
| 411 | + "execution_count": 16 |
445 | 412 | } |
446 | 413 | ], |
447 | 414 | "metadata": {} |
448 | 415 | }, |
449 | 416 | { |
450 | 417 | "cell_type": "code", |
451 | | - "execution_count": 18, |
| 418 | + "execution_count": 17, |
452 | 419 | "source": [ |
453 | 420 | "yaresnet.block_sizes, yaresnet.layers" |
454 | 421 | ], |
|
461 | 428 | ] |
462 | 429 | }, |
463 | 430 | "metadata": {}, |
464 | | - "execution_count": 18 |
| 431 | + "execution_count": 17 |
465 | 432 | } |
466 | 433 | ], |
467 | 434 | "metadata": {} |
468 | 435 | }, |
469 | 436 | { |
470 | 437 | "cell_type": "code", |
471 | | - "execution_count": 19, |
| 438 | + "execution_count": 18, |
472 | 439 | "source": [ |
473 | 440 | "#collapse_output\n", |
474 | 441 | "yaresnet.stem" |
|
499 | 466 | ] |
500 | 467 | }, |
501 | 468 | "metadata": {}, |
502 | | - "execution_count": 19 |
| 469 | + "execution_count": 18 |
503 | 470 | } |
504 | 471 | ], |
505 | 472 | "metadata": {} |
506 | 473 | }, |
507 | 474 | { |
508 | 475 | "cell_type": "code", |
509 | | - "execution_count": 20, |
| 476 | + "execution_count": 19, |
510 | 477 | "source": [ |
511 | 478 | "#hide\n", |
512 | 479 | "bs_test = 16\n", |
|
536 | 503 | }, |
537 | 504 | { |
538 | 505 | "cell_type": "code", |
539 | | - "execution_count": 21, |
| 506 | + "execution_count": 20, |
540 | 507 | "source": [ |
541 | 508 | "#collapse_output\n", |
542 | 509 | "yaresnet.body" |
|
686 | 653 | ] |
687 | 654 | }, |
688 | 655 | "metadata": {}, |
689 | | - "execution_count": 21 |
| 656 | + "execution_count": 20 |
690 | 657 | } |
691 | 658 | ], |
692 | 659 | "metadata": {} |
693 | 660 | }, |
694 | 661 | { |
695 | 662 | "cell_type": "code", |
696 | | - "execution_count": 22, |
| 663 | + "execution_count": 21, |
697 | 664 | "source": [ |
698 | 665 | "#collapse_output\n", |
699 | 666 | "yaresnet.head" |
|
711 | 678 | ] |
712 | 679 | }, |
713 | 680 | "metadata": {}, |
714 | | - "execution_count": 22 |
| 681 | + "execution_count": 21 |
715 | 682 | } |
716 | 683 | ], |
717 | 684 | "metadata": {} |
|
725 | 692 | }, |
726 | 693 | { |
727 | 694 | "cell_type": "code", |
728 | | - "execution_count": 24, |
| 695 | + "execution_count": 22, |
729 | 696 | "source": [ |
730 | 697 | "#collapse_output\n", |
731 | 698 | "yaresnet.act_fn = Mish()\n", |
|
902 | 869 | ] |
903 | 870 | }, |
904 | 871 | "metadata": {}, |
905 | | - "execution_count": 24 |
| 872 | + "execution_count": 22 |
906 | 873 | } |
907 | 874 | ], |
908 | 875 | "metadata": {} |
|
923 | 890 | }, |
924 | 891 | { |
925 | 892 | "cell_type": "code", |
926 | | - "execution_count": 25, |
| 893 | + "execution_count": 23, |
927 | 894 | "source": [ |
928 | 895 | "yaresnet_parameters = {'block': YaResBlock, 'stem_sizes': [3, 32, 64, 64], 'act_fn': Mish(), 'stem_stride_on': 1}\n", |
929 | 896 | "yaresnet34 = partial(Net, name='YaResnet34', expansion=1, layers=[3, 4, 6, 3], **yaresnet_parameters)\n", |
|
934 | 901 | }, |
935 | 902 | { |
936 | 903 | "cell_type": "code", |
937 | | - "execution_count": 26, |
| 904 | + "execution_count": 24, |
938 | 905 | "source": [ |
939 | 906 | "model = yaresnet50(c_out=10)" |
940 | 907 | ], |
|
943 | 910 | }, |
944 | 911 | { |
945 | 912 | "cell_type": "code", |
946 | | - "execution_count": 27, |
| 913 | + "execution_count": 25, |
947 | 914 | "source": [ |
948 | 915 | "model" |
949 | 916 | ], |
|
962 | 929 | ] |
963 | 930 | }, |
964 | 931 | "metadata": {}, |
965 | | - "execution_count": 27 |
| 932 | + "execution_count": 25 |
966 | 933 | } |
967 | 934 | ], |
968 | 935 | "metadata": {} |
969 | 936 | }, |
970 | 937 | { |
971 | 938 | "cell_type": "code", |
972 | | - "execution_count": 28, |
| 939 | + "execution_count": 26, |
973 | 940 | "source": [ |
974 | 941 | "model.c_out, model.layers" |
975 | 942 | ], |
|
982 | 949 | ] |
983 | 950 | }, |
984 | 951 | "metadata": {}, |
985 | | - "execution_count": 28 |
| 952 | + "execution_count": 26 |
986 | 953 | } |
987 | 954 | ], |
988 | 955 | "metadata": {} |
989 | 956 | }, |
990 | 957 | { |
991 | 958 | "cell_type": "code", |
992 | | - "execution_count": 29, |
| 959 | + "execution_count": 27, |
993 | 960 | "source": [ |
994 | 961 | "#hide\n", |
995 | 962 | "bs_test = 16\n", |
|
0 commit comments