Skip to content

Commit c8ab747

Browse files
committed
BEiT-V2 checkpoints didn't strip the 'module' prefix from their weights; adapt the checkpoint filter to unwrap it
1 parent 73049dc commit c8ab747

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

timm/models/beit.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,14 +384,21 @@ def forward(self, x):
384384
return x
385385

386386

387+
def _beit_checkpoint_filter_fn(state_dict, model):
    """Unwrap BEiT v2 checkpoints before applying the standard filter.

    BEiT v2 checkpoints nest the actual weights under a 'module' key
    (they were saved without stripping the DataParallel wrapper), so
    unwrap that level first, then delegate to ``checkpoint_filter_fn``.
    """
    # beit v2 didn't strip module; fall back to the dict itself when absent
    return checkpoint_filter_fn(state_dict.get('module', state_dict), model)
387394
def _create_beit(variant, pretrained=False, **kwargs):
    """Instantiate a Beit model for the given variant name.

    Raises RuntimeError if ``features_only`` is requested, which Beit
    does not support.
    """
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Beit models.')

    # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
    return build_model_with_cfg(
        Beit,
        variant,
        pretrained,
        pretrained_filter_fn=_beit_checkpoint_filter_fn,
        **kwargs,
    )
397404

0 commit comments

Comments
 (0)