Skip to content

Commit 47e6958

Browse files
committed
Add hierdet_small (non sam) model def
1 parent 9fcbf39 commit 47e6958

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

timm/models/hieradet_sam2.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ def prune_intermediate_layers(
500500
take_indices, max_index = feature_take_indices(len(self.blocks), indices)
501501
self.blocks = self.blocks[:max_index + 1] # truncate blocks
502502
if prune_head:
503-
self.head.reset(0, reset_other=True)
503+
self.head.reset(0, reset_other=prune_norm)
504504
return take_indices
505505

506506
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
@@ -556,6 +556,10 @@ def _cfg(url='', **kwargs):
556556
min_input_size=(3, 256, 256),
557557
input_size=(3, 1024, 1024), pool_size=(32, 32),
558558
),
559+
"hieradet_small.untrained": _cfg(
560+
num_classes=1000,
561+
input_size=(3, 256, 256), pool_size=(8, 8),
562+
),
559563
})
560564

561565

@@ -604,12 +608,6 @@ def sam2_hiera_small(pretrained=False, **kwargs):
604608
return _create_hiera_det('sam2_hiera_small', pretrained=pretrained, **dict(model_args, **kwargs))
605609

606610

607-
# @register_model
608-
# def sam2_hiera_base(pretrained=False, **kwargs):
609-
# model_args = dict()
610-
# return _create_hiera_det('sam2_hiera_base', pretrained=pretrained, **dict(model_args, **kwargs))
611-
612-
613611
@register_model
614612
def sam2_hiera_base_plus(pretrained=False, **kwargs):
615613
model_args = dict(embed_dim=112, num_heads=2, global_pos_size=(14, 14))
@@ -626,3 +624,15 @@ def sam2_hiera_large(pretrained=False, **kwargs):
626624
window_spec=(8, 4, 16, 8),
627625
)
628626
return _create_hiera_det('sam2_hiera_large', pretrained=pretrained, **dict(model_args, **kwargs))
627+
628+
629+
@register_model
630+
def hieradet_small(pretrained=False, **kwargs):
631+
model_args = dict(stages=(1, 2, 11, 2), global_att_blocks=(7, 10, 13), window_spec=(8, 4, 16, 8))
632+
return _create_hiera_det('hieradet_small', pretrained=pretrained, **dict(model_args, **kwargs))
633+
634+
635+
# @register_model
636+
# def hieradet_base(pretrained=False, **kwargs):
637+
# model_args = dict(window_spec=(8, 4, 16, 8))
638+
# return _create_hiera_det('hieradet_base', pretrained=pretrained, **dict(model_args, **kwargs))

0 commit comments

Comments
 (0)