|
342 | 342 | "assert y.shape == torch.Size([bs_test, 512, 16, 16]), f\"size\"" |
343 | 343 | ] |
344 | 344 | }, |
| 345 | + { |
| 346 | + "cell_type": "markdown", |
| 347 | + "metadata": {}, |
| 348 | + "source": [ |
| 349 | + "# Stem, Body, Head" |
| 350 | + ] |
| 351 | + }, |
| 352 | + { |
| 353 | + "cell_type": "code", |
| 354 | + "execution_count": null, |
| 355 | + "metadata": {}, |
| 356 | + "outputs": [], |
| 357 | + "source": [ |
| 358 | + "# export\n", |
| 359 | + "def _make_stem(self):\n", |
| 360 | + " stem = [(f\"conv_{i}\", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i+1], \n", |
| 361 | + " stride=2 if i==0 else 1, \n", |
| 362 | + " bn_layer=(not self.stem_bn_end) if i==(len(self.stem_sizes)-2) else True,\n", |
| 363 | + " act_fn=self.act_fn, bn_1st=self.bn_1st))\n", |
| 364 | + " for i in range(len(self.stem_sizes)-1)]\n", |
| 365 | + " stem.append(('stem_pool', self.stem_pool))\n", |
| 366 | + " if self.stem_bn_end: stem.append(('norm', self.norm(self.stem_sizes[-1])))\n", |
| 367 | + " return nn.Sequential(OrderedDict(stem))" |
| 368 | + ] |
| 369 | + }, |
| 370 | + { |
| 371 | + "cell_type": "code", |
| 372 | + "execution_count": null, |
| 373 | + "metadata": {}, |
| 374 | + "outputs": [], |
| 375 | + "source": [ |
| 376 | + "# export\n", |
| 377 | + "def _make_layer(self,expansion,ni,nf,blocks,stride,sa):\n", |
| 378 | + " return nn.Sequential(OrderedDict(\n", |
| 379 | + " [(f\"bl_{i}\", self.block(expansion, ni if i==0 else nf, nf, \n", |
| 380 | + " stride if i==0 else 1, sa=sa if i==blocks-1 else False,\n", |
| 381 | + " conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,\n", |
| 382 | + " zero_bn=self.zero_bn, bn_1st=self.bn_1st))\n", |
| 383 | + " for i in range(blocks)]))" |
| 384 | + ] |
| 385 | + }, |
| 386 | + { |
| 387 | + "cell_type": "code", |
| 388 | + "execution_count": null, |
| 389 | + "metadata": {}, |
| 390 | + "outputs": [], |
| 391 | + "source": [ |
| 392 | + "# export\n", |
| 393 | + "def _make_body(self):\n", |
| 394 | + " blocks = [(f\"l_{i}\", self._make_layer(self,self.expansion, \n", |
| 395 | + " self.block_szs[i], self.block_szs[i+1], l, \n", |
| 396 | + " 1 if i==0 else 2, self.sa if i==0 else False))\n", |
| 397 | + " for i,l in enumerate(self.layers)]\n", |
| 398 | + " return nn.Sequential(OrderedDict(blocks))" |
| 399 | + ] |
| 400 | + }, |
| 401 | + { |
| 402 | + "cell_type": "code", |
| 403 | + "execution_count": null, |
| 404 | + "metadata": {}, |
| 405 | + "outputs": [], |
| 406 | + "source": [ |
| 407 | + "# export\n", |
| 408 | + "def _make_head(self):\n", |
| 409 | + " head = [('pool', nn.AdaptiveAvgPool2d(1)),\n", |
| 410 | + " ('flat', Flatten()),\n", |
| 411 | + " ('fc', nn.Linear(self.block_szs[-1]*self.expansion, self.c_out))]\n", |
| 412 | + " return nn.Sequential(OrderedDict(head))" |
| 413 | + ] |
| 414 | + }, |
345 | 415 | { |
346 | 416 | "cell_type": "markdown", |
347 | 417 | "metadata": {}, |
|
372 | 442 | " self.sa=False\n", |
373 | 443 | " self.bn_1st = True\n", |
374 | 444 | " self.zero_bn=True\n", |
375 | | - " self._init_cnn = init_cnn\n", |
376 | 445 | " self.conv_layer = ConvLayer\n", |
| 446 | + " self._init_cnn = init_cnn\n", |
| 447 | + " self._make_stem = _make_stem\n", |
| 448 | + " self._make_layer = _make_layer\n", |
| 449 | + " self._make_body = _make_body\n", |
| 450 | + " self._make_head = _make_head\n", |
| 451 | + " \n", |
377 | 452 | " \n", |
378 | 453 | " @property\n", |
379 | 454 | " def block_szs(self):\n", |
380 | 455 | " return [64//self.expansion,64,128,256,512] +[256]*(len(self.layers)-4) \n", |
381 | 456 | "\n", |
382 | 457 | " @property\n", |
383 | 458 | " def stem(self):\n", |
384 | | - " return self._make_stem()\n", |
| 459 | + " return self._make_stem(self)\n", |
385 | 460 | " @property\n", |
386 | 461 | " def head(self):\n", |
387 | | - " return self._make_head()\n", |
| 462 | + " return self._make_head(self)\n", |
| 463 | + "# @property\n", |
| 464 | + "# def _make_layer(self):\n", |
| 465 | + "# return self.__make_layer(self)\n", |
388 | 466 | " @property\n", |
389 | 467 | " def body(self):\n", |
390 | | - " return self._make_body()\n", |
| 468 | + " return self._make_body(self)\n", |
391 | 469 | " \n", |
392 | | - " def _make_stem(self):\n", |
393 | | - " stem = [(f\"conv_{i}\", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i+1], \n", |
394 | | - " stride=2 if i==0 else 1, \n", |
395 | | - " bn_layer=(not self.stem_bn_end) if i==(len(self.stem_sizes)-2) else True,\n", |
396 | | - " act_fn=self.act_fn, bn_1st=self.bn_1st))\n", |
397 | | - " for i in range(len(self.stem_sizes)-1)]\n", |
398 | | - " stem.append(('stem_pool', self.stem_pool))\n", |
399 | | - " if self.stem_bn_end: stem.append(('norm', self.norm(self.stem_sizes[-1])))\n", |
400 | | - " return nn.Sequential(OrderedDict(stem))\n", |
| 470 | + "# def _make_stem(self):\n", |
| 471 | + "# stem = [(f\"conv_{i}\", self.conv_layer(self.stem_sizes[i], self.stem_sizes[i+1], \n", |
| 472 | + "# stride=2 if i==0 else 1, \n", |
| 473 | + "# bn_layer=(not self.stem_bn_end) if i==(len(self.stem_sizes)-2) else True,\n", |
| 474 | + "# act_fn=self.act_fn, bn_1st=self.bn_1st))\n", |
| 475 | + "# for i in range(len(self.stem_sizes)-1)]\n", |
| 476 | + "# stem.append(('stem_pool', self.stem_pool))\n", |
| 477 | + "# if self.stem_bn_end: stem.append(('norm', self.norm(self.stem_sizes[-1])))\n", |
| 478 | + "# return nn.Sequential(OrderedDict(stem))\n", |
401 | 479 | " \n", |
402 | | - " def _make_head(self):\n", |
403 | | - " head = [('pool', nn.AdaptiveAvgPool2d(1)),\n", |
404 | | - " ('flat', Flatten()),\n", |
405 | | - " ('fc', nn.Linear(self.block_szs[-1]*self.expansion, self.c_out))]\n", |
406 | | - " return nn.Sequential(OrderedDict(head))\n", |
| 480 | + "# def _make_head(self):\n", |
| 481 | + "# head = [('pool', nn.AdaptiveAvgPool2d(1)),\n", |
| 482 | + "# ('flat', Flatten()),\n", |
| 483 | + "# ('fc', nn.Linear(self.block_szs[-1]*self.expansion, self.c_out))]\n", |
| 484 | + "# return nn.Sequential(OrderedDict(head))\n", |
407 | 485 | " \n", |
408 | | - " def _make_body(self):\n", |
409 | | - " blocks = [(f\"l_{i}\", self._make_layer(self.expansion, \n", |
410 | | - " self.block_szs[i], self.block_szs[i+1], l, \n", |
411 | | - " 1 if i==0 else 2, self.sa if i==0 else False))\n", |
412 | | - " for i,l in enumerate(self.layers)]\n", |
413 | | - " return nn.Sequential(OrderedDict(blocks))\n", |
| 486 | + "# def _make_body(self):\n", |
| 487 | + "# blocks = [(f\"l_{i}\", self._make_layer(self.expansion, \n", |
| 488 | + "# self.block_szs[i], self.block_szs[i+1], l, \n", |
| 489 | + "# 1 if i==0 else 2, self.sa if i==0 else False))\n", |
| 490 | + "# for i,l in enumerate(self.layers)]\n", |
| 491 | + "# return nn.Sequential(OrderedDict(blocks))\n", |
414 | 492 | " \n", |
415 | | - " def _make_layer(self,expansion,ni,nf,blocks,stride,sa):\n", |
416 | | - " return nn.Sequential(OrderedDict(\n", |
417 | | - " [(f\"bl_{i}\", self.block(expansion, ni if i==0 else nf, nf, \n", |
418 | | - " stride if i==0 else 1, sa=sa if i==blocks-1 else False,\n", |
419 | | - " conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,\n", |
420 | | - " zero_bn=self.zero_bn, bn_1st=self.bn_1st))\n", |
421 | | - " for i in range(blocks)]))\n", |
| 493 | + "# def _make_layer(self,expansion,ni,nf,blocks,stride,sa):\n", |
| 494 | + "# return nn.Sequential(OrderedDict(\n", |
| 495 | + "# [(f\"bl_{i}\", self.block(expansion, ni if i==0 else nf, nf, \n", |
| 496 | + "# stride if i==0 else 1, sa=sa if i==blocks-1 else False,\n", |
| 497 | + "# conv_layer=self.conv_layer, act_fn=self.act_fn, pool=self.pool,\n", |
| 498 | + "# zero_bn=self.zero_bn, bn_1st=self.bn_1st))\n", |
| 499 | + "# for i in range(blocks)]))\n", |
422 | 500 | " \n", |
423 | 501 | " def __call__(self):\n", |
424 | 502 | " model = nn.Sequential(OrderedDict([\n", |
|
0 commit comments