Skip to content

Commit 0c96685

Browse files
author
ayasyrev
committed
v 0.0.3 Fixed Blocks, models, rearranged modules
1 parent a488754 commit 0c96685

File tree

6 files changed

+0
-370
lines changed

6 files changed

+0
-370
lines changed

docs/constructor.html

Lines changed: 0 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -334,112 +334,6 @@ <h1 id="Body">Body<a class="anchor-link" href="#Body">&#182;</a></h1>
334334
<h2 id="BasicBlock">BasicBlock<a class="anchor-link" href="#BasicBlock">&#182;</a></h2>
335335
</div>
336336
</div>
337-
</div>
338-
<div class="cell border-box-sizing code_cell rendered">
339-
<div class="input">
340-
341-
<div class="inner_cell">
342-
<div class="input_area">
343-
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># class BasicBlock(nn.Module):</span>
344-
<span class="c1"># &quot;&quot;&quot;Basic block (simplified) as in pytorch resnet&quot;&quot;&quot;</span>
345-
<span class="c1"># def __init__(self, ni, nf, expansion=1, stride=1,</span>
346-
<span class="c1"># bn_1st=True, zero_bn=False, </span>
347-
<span class="c1"># # groups=1, base_width=64, dilation=1, norm_layer=None</span>
348-
<span class="c1"># conv_layer=ConvLayer, **kwargs):</span>
349-
<span class="c1"># super().__init__()</span>
350-
<span class="c1"># self.downsample = not ni==nf or stride==2</span>
351-
<span class="c1"># self.conv = nn.Sequential(OrderedDict([</span>
352-
<span class="c1"># (&#39;conv_0&#39;, conv_layer(ni, nf, stride=stride, bn_1st=bn_1st, **kwargs)),</span>
353-
<span class="c1"># (&#39;conv_1&#39;, conv_layer(nf, nf, zero_bn=zero_bn, bn_1st=bn_1st, **kwargs))]))</span>
354-
<span class="c1"># if self.downsample:</span>
355-
<span class="c1"># self.downsample = conv_layer(ni, nf, ks=1, stride=stride, act=False, **kwargs)</span>
356-
<span class="c1"># self.merge = Noop()</span>
357-
<span class="c1"># self.act_conn = act_fn</span>
358-
359-
<span class="c1"># def forward(self, x):</span>
360-
<span class="c1"># identity = x</span>
361-
<span class="c1"># out = self.conv(x)</span>
362-
<span class="c1"># if self.downsample:</span>
363-
<span class="c1"># identity = self.downsample(x)</span>
364-
<span class="c1"># return self.act_conn(self.merge(out + identity))</span>
365-
</pre></div>
366-
367-
</div>
368-
</div>
369-
</div>
370-
371-
</div>
372-
<div class="cell border-box-sizing code_cell rendered">
373-
<div class="input">
374-
375-
<div class="inner_cell">
376-
<div class="input_area">
377-
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># b_block = BasicBlock(64,64)</span>
378-
<span class="c1"># b_block</span>
379-
</pre></div>
380-
381-
</div>
382-
</div>
383-
</div>
384-
385-
</div>
386-
<div class="cell border-box-sizing code_cell rendered">
387-
<div class="input">
388-
389-
<div class="inner_cell">
390-
<div class="input_area">
391-
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># b_block = BasicBlock(64,64, stride=2)</span>
392-
<span class="c1"># b_block</span>
393-
</pre></div>
394-
395-
</div>
396-
</div>
397-
</div>
398-
399-
</div>
400-
<div class="cell border-box-sizing code_cell rendered">
401-
<div class="input">
402-
403-
<div class="inner_cell">
404-
<div class="input_area">
405-
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># xb = torch.randn(64, 64, 32, 32)</span>
406-
<span class="c1"># y = b_block(xb)</span>
407-
<span class="c1"># y.shape</span>
408-
</pre></div>
409-
410-
</div>
411-
</div>
412-
</div>
413-
414-
</div>
415-
<div class="cell border-box-sizing code_cell rendered">
416-
<div class="input">
417-
418-
<div class="inner_cell">
419-
<div class="input_area">
420-
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># b_block = BasicBlock(64,128, stride=2)</span>
421-
<span class="c1"># b_block</span>
422-
</pre></div>
423-
424-
</div>
425-
</div>
426-
</div>
427-
428-
</div>
429-
<div class="cell border-box-sizing code_cell rendered">
430-
<div class="input">
431-
432-
<div class="inner_cell">
433-
<div class="input_area">
434-
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># xb = torch.randn(64, 64, 32, 32)</span>
435-
<span class="c1"># y = b_block(xb)</span>
436-
<span class="c1"># y.shape</span>
437-
</pre></div>
438-
439-
</div>
440-
</div>
441-
</div>
442-
443337
</div>
444338
<div class="cell border-box-sizing text_cell rendered"><div class="inner_cell">
445339
<div class="text_cell_render border-box-sizing rendered_html">

docs/resnet.html

Lines changed: 0 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -37,70 +37,6 @@ <h1 id="Bottleneck">Bottleneck<a class="anchor-link" href="#Bottleneck">&#182;</
3737
<div class="cell border-box-sizing code_cell rendered">
3838
<div class="input">
3939

40-
<div class="inner_cell">
41-
<div class="input_area">
42-
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># class Bottleneck(nn.Module):</span>
43-
<span class="c1"># &#39;&#39;&#39;Bottleneck block for resnet models&#39;&#39;&#39;</span>
44-
<span class="c1"># def __init__(self, ni, nh, expansion=4, stride=1, </span>
45-
<span class="c1"># bn_1st=False, zero_bn=False, **kwargs):</span>
46-
<span class="c1"># # groups=1, base_width=64, dilation=1, norm_layer=None</span>
47-
<span class="c1"># super().__init__()</span>
48-
<span class="c1"># self.downsample = not ni==nh or stride==2</span>
49-
<span class="c1"># ni = ni*expansion</span>
50-
<span class="c1"># nf = nh*expansion</span>
51-
<span class="c1"># self.conv = nn.Sequential(OrderedDict([</span>
52-
<span class="c1"># (&#39;conv_0&#39;, ConvLayer(ni, nh, ks=1, bn_1st=bn_1st, **kwargs)),</span>
53-
<span class="c1"># (&#39;conv_1&#39;, ConvLayer(nh, nh, stride=stride, bn_1st=bn_1st, **kwargs)),</span>
54-
<span class="c1"># (&#39;conv_2&#39;, ConvLayer(nh, nf, ks=1, zero_bn=zero_bn, bn_1st=bn_1st, **kwargs))]))</span>
55-
<span class="c1"># if self.downsample:</span>
56-
<span class="c1"># self.downsample = ConvLayer(ni, nf, ks=1, stride=stride, act=False, **kwargs)</span>
57-
<span class="c1"># self.merge = Noop()</span>
58-
<span class="c1"># self.act_conn = act_fn</span>
59-
60-
<span class="c1"># def forward(self, x):</span>
61-
<span class="c1"># identity = x</span>
62-
<span class="c1"># out = self.conv(x)</span>
63-
<span class="c1"># if self.downsample:</span>
64-
<span class="c1"># identity = self.downsample(x)</span>
65-
<span class="c1"># return self.act_conn(self.merge(out + identity))</span>
66-
</pre></div>
67-
68-
</div>
69-
</div>
70-
</div>
71-
72-
</div>
73-
<div class="cell border-box-sizing code_cell rendered">
74-
<div class="input">
75-
76-
<div class="inner_cell">
77-
<div class="input_area">
78-
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># b_block = Bottleneck(16,64)</span>
79-
<span class="c1"># b_block</span>
80-
</pre></div>
81-
82-
</div>
83-
</div>
84-
</div>
85-
86-
</div>
87-
<div class="cell border-box-sizing code_cell rendered">
88-
<div class="input">
89-
90-
<div class="inner_cell">
91-
<div class="input_area">
92-
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># b_block = Bottleneck(64,64)</span>
93-
<span class="c1"># b_block</span>
94-
</pre></div>
95-
96-
</div>
97-
</div>
98-
</div>
99-
100-
</div>
101-
<div class="cell border-box-sizing code_cell rendered">
102-
<div class="input">
103-
10440
<div class="inner_cell">
10541
<div class="input_area">
10642
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">body</span> <span class="o">=</span> <span class="n">Body</span><span class="p">(</span><span class="n">Bottleneck</span><span class="p">,</span> <span class="n">expansion</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>

docs/xresnet.html

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -31,41 +31,6 @@
3131
<div class="cell border-box-sizing code_cell rendered">
3232
<div class="input">
3333

34-
<div class="inner_cell">
35-
<div class="input_area">
36-
<div class=" highlight hl-ipython3"><pre><span></span><span class="c1"># class ResBlock(nn.Module):</span>
37-
<span class="c1"># def __init__(self, ni, nh, expansion=1, stride=1, conv_layer=ConvLayer, </span>
38-
<span class="c1"># act_fn=act_fn, **kwargs):</span>
39-
<span class="c1"># super().__init__()</span>
40-
<span class="c1"># # print(f&quot;ni {ni}, nh {nh}, exp {expansion}, stride {stride}&quot;)</span>
41-
<span class="c1"># nf,ni = nh*expansion,ni*expansion</span>
42-
<span class="c1"># # print(f&quot;new ni {ni}, nf {nf}&quot;)</span>
43-
<span class="c1"># layers = [(&#39;conv_0&#39;, conv_layer(ni, nh, 3, stride=stride)),</span>
44-
<span class="c1"># (&#39;conv_1&#39;, conv_layer(nh, nf, 3, zero_bn=True, act=False))</span>
45-
<span class="c1"># ] if expansion == 1 else [</span>
46-
<span class="c1"># (&#39;conv_0&#39;, conv_layer(ni, nh, 1)),</span>
47-
<span class="c1"># (&#39;conv_1&#39;, conv_layer(nh, nh, 3, stride=stride)),</span>
48-
<span class="c1"># (&#39;conv_2&#39;, conv_layer(nh, nf, 1, zero_bn=True, act=False))</span>
49-
<span class="c1"># ]</span>
50-
<span class="c1"># self.convs = nn.Sequential(OrderedDict(layers))</span>
51-
<span class="c1"># # TODO: check whether act=True works better</span>
52-
<span class="c1"># identity = [] if stride==1 else [(&#39;pool&#39;, nn.AvgPool2d(2, ceil_mode=True))]</span>
53-
<span class="c1"># identity += [] if ni==nf else [(&#39;idconv&#39;, conv_layer(ni, nf, 1, act=False))]</span>
54-
<span class="c1"># self.identity = Noop() if identity==[] else nn.Sequential(OrderedDict(identity))</span>
55-
<span class="c1"># self.merge = Noop() # use it to visualize the residual connection in repr</span>
56-
<span class="c1"># self.act_fn = act_fn</span>
57-
58-
<span class="c1"># def forward(self, x): return self.act_fn(self.merge(self.convs(x) + self.identity(x)))</span>
59-
</pre></div>
60-
61-
</div>
62-
</div>
63-
</div>
64-
65-
</div>
66-
<div class="cell border-box-sizing code_cell rendered">
67-
<div class="input">
68-
6934
<div class="inner_cell">
7035
<div class="input_area">
7136
<div class=" highlight hl-ipython3"><pre><span></span><span class="n">body</span> <span class="o">=</span> <span class="n">Body</span><span class="p">(</span><span class="n">ResBlock</span><span class="p">,</span> <span class="n">expansion</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span>

nbs/00_constructor.ipynb

Lines changed: 0 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -291,88 +291,6 @@
291291
"## BasicBlock"
292292
]
293293
},
294-
{
295-
"cell_type": "code",
296-
"execution_count": null,
297-
"metadata": {},
298-
"outputs": [],
299-
"source": [
300-
"# class BasicBlock(nn.Module):\n",
301-
"# \"\"\"Basic block (simplified) as in pytorch resnet\"\"\"\n",
302-
"# def __init__(self, ni, nf, expansion=1, stride=1,\n",
303-
"# bn_1st=True, zero_bn=False, \n",
304-
"# # groups=1, base_width=64, dilation=1, norm_layer=None\n",
305-
"# conv_layer=ConvLayer, **kwargs):\n",
306-
"# super().__init__()\n",
307-
"# self.downsample = not ni==nf or stride==2\n",
308-
"# self.conv = nn.Sequential(OrderedDict([\n",
309-
"# ('conv_0', conv_layer(ni, nf, stride=stride, bn_1st=bn_1st, **kwargs)),\n",
310-
"# ('conv_1', conv_layer(nf, nf, zero_bn=zero_bn, bn_1st=bn_1st, **kwargs))]))\n",
311-
"# if self.downsample:\n",
312-
"# self.downsample = conv_layer(ni, nf, ks=1, stride=stride, act=False, **kwargs)\n",
313-
"# self.merge = Noop()\n",
314-
"# self.act_conn = act_fn\n",
315-
" \n",
316-
"# def forward(self, x):\n",
317-
"# identity = x\n",
318-
"# out = self.conv(x)\n",
319-
"# if self.downsample:\n",
320-
"# identity = self.downsample(x)\n",
321-
"# return self.act_conn(self.merge(out + identity))"
322-
]
323-
},
324-
{
325-
"cell_type": "code",
326-
"execution_count": null,
327-
"metadata": {},
328-
"outputs": [],
329-
"source": [
330-
"# b_block = BasicBlock(64,64)\n",
331-
"# b_block"
332-
]
333-
},
334-
{
335-
"cell_type": "code",
336-
"execution_count": null,
337-
"metadata": {},
338-
"outputs": [],
339-
"source": [
340-
"# b_block = BasicBlock(64,64, stride=2)\n",
341-
"# b_block"
342-
]
343-
},
344-
{
345-
"cell_type": "code",
346-
"execution_count": null,
347-
"metadata": {},
348-
"outputs": [],
349-
"source": [
350-
"# xb = torch.randn(64, 64, 32, 32)\n",
351-
"# y = b_block(xb)\n",
352-
"# y.shape"
353-
]
354-
},
355-
{
356-
"cell_type": "code",
357-
"execution_count": null,
358-
"metadata": {},
359-
"outputs": [],
360-
"source": [
361-
"# b_block = BasicBlock(64,128, stride=2)\n",
362-
"# b_block"
363-
]
364-
},
365-
{
366-
"cell_type": "code",
367-
"execution_count": null,
368-
"metadata": {},
369-
"outputs": [],
370-
"source": [
371-
"# xb = torch.randn(64, 64, 32, 32)\n",
372-
"# y = b_block(xb)\n",
373-
"# y.shape"
374-
]
375-
},
376294
{
377295
"cell_type": "markdown",
378296
"metadata": {},

nbs/02_resnet.ipynb

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -59,58 +59,6 @@
5959
"# Bottleneck"
6060
]
6161
},
62-
{
63-
"cell_type": "code",
64-
"execution_count": null,
65-
"metadata": {},
66-
"outputs": [],
67-
"source": [
68-
"# class Bottleneck(nn.Module):\n",
69-
"# '''Bottleneck block for resnet models'''\n",
70-
"# def __init__(self, ni, nh, expansion=4, stride=1, \n",
71-
"# bn_1st=False, zero_bn=False, **kwargs):\n",
72-
"# # groups=1, base_width=64, dilation=1, norm_layer=None\n",
73-
"# super().__init__()\n",
74-
"# self.downsample = not ni==nh or stride==2\n",
75-
"# ni = ni*expansion\n",
76-
"# nf = nh*expansion\n",
77-
"# self.conv = nn.Sequential(OrderedDict([\n",
78-
"# ('conv_0', ConvLayer(ni, nh, ks=1, bn_1st=bn_1st, **kwargs)),\n",
79-
"# ('conv_1', ConvLayer(nh, nh, stride=stride, bn_1st=bn_1st, **kwargs)),\n",
80-
"# ('conv_2', ConvLayer(nh, nf, ks=1, zero_bn=zero_bn, bn_1st=bn_1st, **kwargs))]))\n",
81-
"# if self.downsample:\n",
82-
"# self.downsample = ConvLayer(ni, nf, ks=1, stride=stride, act=False, **kwargs)\n",
83-
"# self.merge = Noop()\n",
84-
"# self.act_conn = act_fn\n",
85-
"\n",
86-
"# def forward(self, x):\n",
87-
"# identity = x\n",
88-
"# out = self.conv(x)\n",
89-
"# if self.downsample:\n",
90-
"# identity = self.downsample(x)\n",
91-
"# return self.act_conn(self.merge(out + identity))"
92-
]
93-
},
94-
{
95-
"cell_type": "code",
96-
"execution_count": null,
97-
"metadata": {},
98-
"outputs": [],
99-
"source": [
100-
"# b_block = Bottleneck(16,64)\n",
101-
"# b_block"
102-
]
103-
},
104-
{
105-
"cell_type": "code",
106-
"execution_count": null,
107-
"metadata": {},
108-
"outputs": [],
109-
"source": [
110-
"# b_block = Bottleneck(64,64)\n",
111-
"# b_block"
112-
]
113-
},
11462
{
11563
"cell_type": "code",
11664
"execution_count": null,

0 commit comments

Comments
 (0)