Skip to content

Commit 7c97e66

Browse files
committed
Remove commented code, add more consistent seed fn
1 parent 364dd6a commit 7c97e66

File tree

4 files changed

+11
-47
lines changed

4 files changed

+11
-47
lines changed

timm/models/byobnet.py

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -515,52 +515,6 @@ def create_block(block: Union[str, nn.Module], **kwargs):
515515
return _block_registry[block](**kwargs)
516516

517517

518-
# class Stem(nn.Module):
519-
#
520-
# def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
521-
# num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
522-
# super().__init__()
523-
# assert stride in (2, 4)
524-
# if pool:
525-
# assert stride == 4
526-
# layers = layers or LayerFn()
527-
#
528-
# if isinstance(out_chs, (list, tuple)):
529-
# num_rep = len(out_chs)
530-
# stem_chs = out_chs
531-
# else:
532-
# stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]
533-
#
534-
# self.stride = stride
535-
# stem_strides = [2] + [1] * (num_rep - 1)
536-
# if stride == 4 and not pool:
537-
# # set last conv in stack to be strided if stride == 4 and no pooling layer
538-
# stem_strides[-1] = 2
539-
#
540-
# num_act = num_rep if num_act is None else num_act
541-
# # if num_act < num_rep, first convs in stack won't have bn + act
542-
# stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
543-
# prev_chs = in_chs
544-
# convs = []
545-
# for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
546-
# layer_fn = layers.conv_norm_act if na else create_conv2d
547-
# convs.append(layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
548-
# prev_chs = ch
549-
# self.conv = nn.Sequential(*convs) if len(convs) > 1 else convs[0]
550-
#
551-
# if not pool:
552-
# self.pool = nn.Identity()
553-
# elif 'max' in pool.lower():
554-
# self.pool = nn.MaxPool2d(3, 2, 1) if pool else nn.Identity()
555-
# else:
556-
# assert False, "Unknown pooling type"
557-
#
558-
# def forward(self, x):
559-
# x = self.conv(x)
560-
# x = self.pool(x)
561-
# return x
562-
563-
564518
class Stem(nn.Sequential):
565519

566520
def __init__(self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',

timm/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@
99
from .misc import natural_key, add_bool_arg
1010
from .model import unwrap_model, get_state_dict
1111
from .model_ema import ModelEma, ModelEmaV2
12+
from .random import random_seed
1213
from .summary import update_summary, get_outdir

timm/utils/random.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import random
2+
import numpy as np
3+
import torch
4+
5+
6+
def random_seed(seed=42, rank=0):
7+
torch.manual_seed(seed + rank)
8+
np.random.seed(seed + rank)
9+
random.seed(seed + rank)

train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def main():
329329
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
330330
"Install NVIDA apex or upgrade to PyTorch 1.6")
331331

332-
torch.manual_seed(args.seed + args.rank)
332+
random_seed(args.seed, args.rank)
333333

334334
model = create_model(
335335
args.model,

0 commit comments

Comments
 (0)