|
38 | 38 | "# export\n", |
39 | 39 | "import torch.nn as nn\n", |
40 | 40 | "import torch\n", |
41 | | - "from collections import OrderedDict" |
| 41 | + "from collections import OrderedDict\n", |
| 42 | + "from model_constructor.layers import *" |
42 | 43 | ] |
43 | 44 | }, |
44 | 45 | { |
|
47 | 48 | "metadata": {}, |
48 | 49 | "outputs": [], |
49 | 50 | "source": [ |
50 | | - "test_eq(1, 1**2)" |
51 | | - ] |
52 | | - }, |
53 | | - { |
54 | | - "cell_type": "code", |
55 | | - "execution_count": null, |
56 | | - "metadata": {}, |
57 | | - "outputs": [], |
58 | | - "source": [ |
59 | | - "# test_eq(1, 2)" |
60 | | - ] |
61 | | - }, |
62 | | - { |
63 | | - "cell_type": "markdown", |
64 | | - "metadata": {}, |
65 | | - "source": [ |
66 | | - "# Base ConvLayer" |
67 | | - ] |
68 | | - }, |
69 | | - { |
70 | | - "cell_type": "code", |
71 | | - "execution_count": null, |
72 | | - "metadata": {}, |
73 | | - "outputs": [], |
74 | | - "source": [ |
75 | | - "# export\n", |
76 | | - "_act_fn = nn.ReLU(inplace=True)\n", |
| 51 | + "# _act_fn = nn.ReLU(inplace=True)\n", |
77 | 52 | "\n", |
78 | | - "class ConvLayer(nn.Sequential):\n", |
79 | | - " \"\"\"Basic conv layers block\"\"\"\n", |
80 | | - " def __init__(self, ni, nf, ks=3, stride=1, \n", |
81 | | - " act=True, act_fn=_act_fn, \n", |
82 | | - " bn_layer=True, bn_1st=False, zero_bn=False, \n", |
83 | | - " padding=None, bias=True, groups=1):\n", |
| 53 | + "# class ConvLayer(nn.Sequential):\n", |
| 54 | + "# \"\"\"Basic conv layers block\"\"\"\n", |
| 55 | + "# def __init__(self, ni, nf, ks=3, stride=1, \n", |
| 56 | + "# act=True, act_fn=_act_fn, \n", |
| 57 | + "# bn_layer=True, bn_1st=False, zero_bn=False, \n", |
| 58 | + "# padding=None, bias=True, groups=1):\n", |
84 | 59 | "\n", |
85 | | - " self.act = act\n", |
86 | | - " if padding==None: padding = ks//2 \n", |
87 | | - " layers = [('conv', nn.Conv2d(ni, nf, ks, stride=stride, padding=padding, bias=bias, groups=groups))]\n", |
88 | | - " act_bn = [('act_fn', act_fn)] if act else []\n", |
89 | | - " if bn_layer:\n", |
90 | | - " bn = nn.BatchNorm2d(nf)\n", |
91 | | - " nn.init.constant_(bn.weight, 0. if zero_bn else 1.) \n", |
92 | | - " act_bn += [('bn', bn)]\n", |
93 | | - " if bn_1st: act_bn.reverse()\n", |
94 | | - " layers += act_bn\n", |
95 | | - " super().__init__(OrderedDict(layers))" |
| 60 | + "# self.act = act\n", |
| 61 | + "# if padding==None: padding = ks//2 \n", |
| 62 | + "# layers = [('conv', nn.Conv2d(ni, nf, ks, stride=stride, padding=padding, bias=bias, groups=groups))]\n", |
| 63 | + "# act_bn = [('act_fn', act_fn)] if act else []\n", |
| 64 | + "# if bn_layer:\n", |
| 65 | + "# bn = nn.BatchNorm2d(nf)\n", |
| 66 | + "# nn.init.constant_(bn.weight, 0. if zero_bn else 1.) \n", |
| 67 | + "# act_bn += [('bn', bn)]\n", |
| 68 | + "# if bn_1st: act_bn.reverse()\n", |
| 69 | + "# layers += act_bn\n", |
| 70 | + "# super().__init__(OrderedDict(layers))" |
96 | 71 | ] |
97 | 72 | }, |
98 | 73 | { |
99 | 74 | "cell_type": "code", |
100 | 75 | "execution_count": null, |
101 | 76 | "metadata": {}, |
102 | | - "outputs": [ |
103 | | - { |
104 | | - "data": { |
105 | | - "text/plain": [ |
106 | | - "ConvLayer(\n", |
107 | | - " (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", |
108 | | - " (act_fn): ReLU(inplace=True)\n", |
109 | | - " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
110 | | - ")" |
111 | | - ] |
112 | | - }, |
113 | | - "execution_count": null, |
114 | | - "metadata": {}, |
115 | | - "output_type": "execute_result" |
116 | | - } |
117 | | - ], |
| 77 | + "outputs": [], |
118 | 78 | "source": [ |
119 | | - "conv_layer = ConvLayer(32, 64)\n", |
120 | | - "conv_layer" |
| 79 | + "# conv_layer = ConvLayer(32, 64)\n", |
| 80 | + "# conv_layer" |
121 | 81 | ] |
122 | 82 | }, |
123 | 83 | { |
124 | 84 | "cell_type": "code", |
125 | 85 | "execution_count": null, |
126 | 86 | "metadata": {}, |
127 | | - "outputs": [ |
128 | | - { |
129 | | - "data": { |
130 | | - "text/plain": [ |
131 | | - "ConvLayer(\n", |
132 | | - " (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", |
133 | | - " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
134 | | - ")" |
135 | | - ] |
136 | | - }, |
137 | | - "execution_count": null, |
138 | | - "metadata": {}, |
139 | | - "output_type": "execute_result" |
140 | | - } |
141 | | - ], |
| 87 | + "outputs": [], |
142 | 88 | "source": [ |
143 | | - "conv_layer = ConvLayer(32, 64, act=False)\n", |
144 | | - "conv_layer" |
| 89 | + "# conv_layer = ConvLayer(32, 64, act=False)\n", |
| 90 | + "# conv_layer" |
145 | 91 | ] |
146 | 92 | }, |
147 | 93 | { |
148 | 94 | "cell_type": "code", |
149 | 95 | "execution_count": null, |
150 | 96 | "metadata": {}, |
151 | | - "outputs": [ |
152 | | - { |
153 | | - "data": { |
154 | | - "text/plain": [ |
155 | | - "ConvLayer(\n", |
156 | | - " (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", |
157 | | - " (act_fn): ReLU(inplace=True)\n", |
158 | | - ")" |
159 | | - ] |
160 | | - }, |
161 | | - "execution_count": null, |
162 | | - "metadata": {}, |
163 | | - "output_type": "execute_result" |
164 | | - } |
165 | | - ], |
| 97 | + "outputs": [], |
166 | 98 | "source": [ |
167 | | - "conv_layer = ConvLayer(32, 64, bn_layer=False)\n", |
168 | | - "conv_layer" |
| 99 | + "# conv_layer = ConvLayer(32, 64, bn_layer=False)\n", |
| 100 | + "# conv_layer" |
169 | 101 | ] |
170 | 102 | }, |
171 | 103 | { |
172 | 104 | "cell_type": "code", |
173 | 105 | "execution_count": null, |
174 | 106 | "metadata": {}, |
175 | | - "outputs": [ |
176 | | - { |
177 | | - "data": { |
178 | | - "text/plain": [ |
179 | | - "ConvLayer(\n", |
180 | | - " (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n", |
181 | | - " (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", |
182 | | - " (act_fn): ReLU(inplace=True)\n", |
183 | | - ")" |
184 | | - ] |
185 | | - }, |
186 | | - "execution_count": null, |
187 | | - "metadata": {}, |
188 | | - "output_type": "execute_result" |
189 | | - } |
190 | | - ], |
| 107 | + "outputs": [], |
191 | 108 | "source": [ |
192 | | - "conv_layer = ConvLayer(32, 64, bn_1st=True)\n", |
193 | | - "conv_layer" |
| 109 | + "# conv_layer = ConvLayer(32, 64, bn_1st=True)\n", |
| 110 | + "# conv_layer" |
194 | 111 | ] |
195 | 112 | }, |
196 | 113 | { |
|
206 | 123 | "metadata": {}, |
207 | 124 | "outputs": [], |
208 | 125 | "source": [ |
209 | | - "# export\n", |
210 | | - "class Flatten(nn.Module):\n", |
211 | | - " '''flat x to vector'''\n", |
212 | | - " def __init__(self):\n", |
213 | | - " super().__init__()\n", |
214 | | - " def forward(self, x): return x.view(x.size(0), -1)" |
| 126 | + "# class Flatten(nn.Module):\n", |
| 127 | + "# '''Flatten x to a vector'''\n", |
| 128 | + "# def __init__(self):\n", |
| 129 | + "# super().__init__()\n", |
| 130 | + "# def forward(self, x): return x.view(x.size(0), -1)" |
215 | 131 | ] |
216 | 132 | }, |
217 | 133 | { |
|
220 | 136 | "metadata": {}, |
221 | 137 | "outputs": [], |
222 | 138 | "source": [ |
223 | | - "# export\n", |
224 | | - "class Noop(nn.Module): # alternative name Merge\n", |
225 | | - " '''Dummy module for vizualize skip conn'''\n", |
226 | | - " def __init__(self):\n", |
227 | | - " super().__init__()\n", |
| 139 | + "# class Noop(nn.Module): # alternative name Merge\n", |
| 140 | + "# '''Dummy module to visualize skip connections'''\n", |
| 141 | + "# def __init__(self):\n", |
| 142 | + "# super().__init__()\n", |
228 | 143 | " \n", |
229 | | - " def forward(self, x):\n", |
230 | | - " return x" |
| 144 | + "# def forward(self, x):\n", |
| 145 | + "# return x" |
231 | 146 | ] |
232 | 147 | }, |
233 | 148 | { |
|
440 | 355 | "cell_type": "markdown", |
441 | 356 | "metadata": {}, |
442 | 357 | "source": [ |
443 | | - "## Block" |
| 358 | + "## BasicBlock" |
444 | 359 | ] |
445 | 360 | }, |
446 | 361 | { |
|
452 | 367 | "# export\n", |
453 | 368 | "class BasicBlock(nn.Module):\n", |
454 | 369 | " \"\"\"Basic block (simplified) as in pytorch resnet\"\"\"\n", |
455 | | - " expansion = 1\n", |
456 | 370 | " def __init__(self, ni, nf, stride=1, bn_1st=False, zero_bn=False,\n", |
457 | 371 | "# groups=1, base_width=64, dilation=1, norm_layer=None\n", |
458 | | - " **kwargs):\n", |
| 372 | + " expansion = 1, **kwargs):\n", |
459 | 373 | " super().__init__()\n", |
460 | 374 | " self.downsample = not ni==nf or stride==2\n", |
461 | 375 | " self.conv = nn.Sequential(OrderedDict([\n", |
|
464 | 378 | " if self.downsample:\n", |
465 | 379 | " self.downsample = ConvLayer(ni, nf, ks=1, stride=stride, act=False, **kwargs)\n", |
466 | 380 | " self.merge = Noop()\n", |
467 | | - " self.act_conn = _act_fn\n", |
| 381 | + " self.act_conn = act_fn\n", |
468 | 382 | " \n", |
469 | 383 | " def forward(self, x):\n", |
470 | 384 | " identity = x\n", |
|
558 | 472 | { |
559 | 473 | "data": { |
560 | 474 | "text/plain": [ |
561 | | - "torch.Size([64, 64, 32, 32])" |
| 475 | + "torch.Size([64, 64, 16, 16])" |
562 | 476 | ] |
563 | 477 | }, |
564 | 478 | "execution_count": null, |
|
654 | 568 | " body_in=64, body_out=512, \n", |
655 | 569 | " layer_szs=[64,128,256,], blocks=[2,2,2,2],\n", |
656 | 570 | " expansion=1, **kwargs): # Downsample Module as parameter\n", |
657 | | - " layer_szs = [body_in] + layer_szs + [body_out]\n", |
| 571 | + " layer_szs = [body_in//expansion] + layer_szs + [body_out]\n", |
658 | 572 | " num_layers = len(layer_szs)-1\n", |
659 | 573 | " layers = [(f\"layer_{i}\", self._make_layer(block, layer_szs[i], layer_szs[i+1], blocks[i], 1 if i==0 else 2, **kwargs))\n", |
660 | 574 | " for i in range(num_layers)]\n", |
|
971 | 885 | " # block_szs = [64,128,128,256,256,512]\n", |
972 | 886 | " super().__init__(OrderedDict([\n", |
973 | 887 | " ('stem', stem(c_in=c_in,stem_out=body_in, **kwargs)),\n", |
974 | | - " ('body', body(block, body_in, body_out, layer_szs=layer_szs, blocks=blocks, **kwargs)),\n", |
| 888 | + " ('body', body(block, body_in, body_out, \n", |
| 889 | + " layer_szs=layer_szs, blocks=blocks, expansion=expansion, **kwargs)),\n", |
975 | 890 | " ('head', head(body_out*expansion, num_classes, **kwargs))\n", |
976 | 891 | " ]))\n", |
977 | 892 | " init_model(self)" |
|
1919 | 1834 | "cell_type": "markdown", |
1920 | 1835 | "metadata": {}, |
1921 | 1836 | "source": [ |
1922 | | - "# fin" |
| 1837 | + "# model_constructor\n", |
| 1838 | + "by ayasyrev" |
1923 | 1839 | ] |
1924 | 1840 | }, |
1925 | 1841 | { |
|
1939 | 1855 | "output_type": "stream", |
1940 | 1856 | "text": [ |
1941 | 1857 | "Converted 00_constructor.ipynb.\n", |
1942 | | - "Converted 01_resnet.ipynb.\n", |
| 1858 | + "Converted 01_layers.ipynb.\n", |
| 1859 | + "Converted 02_resnet.ipynb.\n", |
1943 | 1860 | "Converted index.ipynb.\n" |
1944 | 1861 | ] |
1945 | 1862 | } |
|
0 commit comments