Skip to content

Commit a8e3405

Browse files
committed
Unbreak gamma remap impacting beit checkpoint load, version bump to 0.6.4
1 parent 1ccce50 commit a8e3405

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

timm/models/deit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
"""
1111
# Copyright (c) 2015-present, Facebook, Inc.
1212
# All rights reserved.
13+
from functools import partial
14+
1315
import torch
1416
from torch import nn as nn
1517

@@ -177,7 +179,7 @@ def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
177179
model_cls = VisionTransformerDistilled if distilled else VisionTransformer
178180
model = build_model_with_cfg(
179181
model_cls, variant, pretrained,
180-
pretrained_filter_fn=checkpoint_filter_fn,
182+
pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
181183
**kwargs)
182184
return model
183185

timm/models/vision_transformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -626,7 +626,7 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
626626
return posemb
627627

628628

629-
def checkpoint_filter_fn(state_dict, model):
629+
def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
630630
""" convert patch embedding weight from manual patchify + linear proj to conv"""
631631
import re
632632
out_dict = {}
@@ -647,7 +647,7 @@ def checkpoint_filter_fn(state_dict, model):
647647
getattr(model, 'num_prefix_tokens', 1),
648648
model.patch_embed.grid_size
649649
)
650-
elif 'gamma_' in k:
650+
elif adapt_layer_scale and 'gamma_' in k:
651651
# remap layer-scale gamma into sub-module (deit3 models)
652652
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
653653
elif 'pre_logits' in k:

timm/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.6.3.dev0'
1+
__version__ = '0.6.4'

0 commit comments

Comments
 (0)