File tree Expand file tree Collapse file tree 1 file changed +8
-1
lines changed
Expand file tree Collapse file tree 1 file changed +8
-1
lines changed Original file line number Diff line number Diff line change @@ -384,14 +384,21 @@ def forward(self, x):
384384 return x
385385
386386
387+ def _beit_checkpoint_filter_fn (state_dict , model ):
388+ if 'module' in state_dict :
389+ # beit v2 didn't strip module
390+ state_dict = state_dict ['module' ]
391+ return checkpoint_filter_fn (state_dict , model )
392+
393+
387394def _create_beit (variant , pretrained = False , ** kwargs ):
388395 if kwargs .get ('features_only' , None ):
389396 raise RuntimeError ('features_only not implemented for Beit models.' )
390397
391398 model = build_model_with_cfg (
392399 Beit , variant , pretrained ,
393400 # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
394- pretrained_filter_fn = checkpoint_filter_fn ,
401+ pretrained_filter_fn = _beit_checkpoint_filter_fn ,
395402 ** kwargs )
396403 return model
397404
You can’t perform that action at this time.
0 commit comments