@@ -515,52 +515,6 @@ def create_block(block: Union[str, nn.Module], **kwargs):
515515 return _block_registry [block ](** kwargs )
516516
517517
518- # class Stem(nn.Module):
519- #
520- # def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
521- # num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
522- # super().__init__()
523- # assert stride in (2, 4)
524- # if pool:
525- # assert stride == 4
526- # layers = layers or LayerFn()
527- #
528- # if isinstance(out_chs, (list, tuple)):
529- # num_rep = len(out_chs)
530- # stem_chs = out_chs
531- # else:
532- # stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]
533- #
534- # self.stride = stride
535- # stem_strides = [2] + [1] * (num_rep - 1)
536- # if stride == 4 and not pool:
537- # # set last conv in stack to be strided if stride == 4 and no pooling layer
538- # stem_strides[-1] = 2
539- #
540- # num_act = num_rep if num_act is None else num_act
541- # # if num_act < num_rep, first convs in stack won't have bn + act
542- # stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
543- # prev_chs = in_chs
544- # convs = []
545- # for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
546- # layer_fn = layers.conv_norm_act if na else create_conv2d
547- # convs.append(layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
548- # prev_chs = ch
549- # self.conv = nn.Sequential(*convs) if len(convs) > 1 else convs[0]
550- #
551- # if not pool:
552- # self.pool = nn.Identity()
553- # elif 'max' in pool.lower():
554- # self.pool = nn.MaxPool2d(3, 2, 1) if pool else nn.Identity()
555- # else:
556- # assert False, "Unknown pooling type"
557- #
558- # def forward(self, x):
559- # x = self.conv(x)
560- # x = self.pool(x)
561- # return x
562-
563-
564518class Stem (nn .Sequential ):
565519
566520 def __init__ (self , in_chs , out_chs , kernel_size = 3 , stride = 4 , pool = 'maxpool' ,
0 commit comments