@@ -37,8 +37,8 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
+from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
+    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -377,95 +377,6 @@ def forward(self, x):
         return self._forward(x)
 
 
-class AttentionPoolLatent(nn.Module):
-    """ Attention pooling w/ latent query
-    """
-    def __init__(
-            self,
-            in_features: int,
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 8,
-            mlp_ratio: float = 4.0,
-            qkv_bias: bool = True,
-            qk_norm: bool = False,
-            latent_len: int = 1,
-            latent_dim: int = None,
-            pos_embed: str = '',
-            pool_type: str = 'token',
-            norm_layer: Optional[nn.Module] = None,
-            drop: float = 0.0,
-    ):
-        super().__init__()
-        embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        assert embed_dim % num_heads == 0
-        self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
-        self.scale = self.head_dim ** -0.5
-        self.pool = pool_type
-        self.fused_attn = use_fused_attn()
-
-        if pos_embed == 'abs':
-            spatial_len = self.feat_size
-            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
-        else:
-            self.pos_embed = None
-
-        self.latent_dim = latent_dim or embed_dim
-        self.latent_len = latent_len
-        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
-
-        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
-        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.proj = nn.Linear(embed_dim, embed_dim)
-        self.proj_drop = nn.Dropout(drop)
-
-        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
-        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
-
-    def init_weights(self):
-        if self.pos_embed is not None:
-            trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
-
-    def forward(self, x):
-        B, N, C = x.shape
-
-        if self.pos_embed is not None:
-            # FIXME interpolate
-            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
-
-        q_latent = self.latent.expand(B, -1, -1)
-        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        k, v = kv.unbind(0)
-
-        q, k = self.q_norm(q), self.k_norm(k)
-
-        if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v)
-        else:
-            q = q * self.scale
-            attn = q @ k.transpose(-2, -1)
-            attn = attn.softmax(dim=-1)
-            x = attn @ v
-        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-
-        x = x + self.mlp(self.norm(x))
-
-        # optional pool if latent seq_len > 1 and pooled output is desired
-        if self.pool == 'token':
-            x = x[:, 0]
-        elif self.pool == 'avg':
-            x = x.mean(1)
-        return x
-
-
 class VisionTransformer(nn.Module):
     """ Vision Transformer
 
@@ -1072,6 +983,12 @@ def checkpoint_filter_fn(
     if "encoder" in state_dict:
         state_dict = _convert_ijepa(state_dict, model)
 
+    if 'visual.trunk.pos_embed' in state_dict:
+        # convert an OpenCLIP model with timm vision encoder
+        prefix = 'visual.trunk.'
+        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
+        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
+
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
             O, I, H, W = model.patch_embed.proj.weight.shape
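The new branch above simply strips the OpenCLIP 'visual.trunk.' prefix so the remaining keys line up with timm's own ViT parameter names, dropping everything outside the vision trunk (text tower, logit scale, etc.). A minimal sketch of that remap with made-up keys, not part of the commit:

# Hypothetical OpenCLIP-style checkpoint keys, for illustration only.
state_dict = {
    'visual.trunk.pos_embed': ...,
    'visual.trunk.blocks.0.attn.qkv.weight': ...,
    'text.token_embedding.weight': ...,  # not under visual.trunk, so it is dropped
}
prefix = 'visual.trunk.'
state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
# -> {'pos_embed': ..., 'blocks.0.attn.qkv.weight': ...}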
@@ -1634,48 +1551,42 @@ def _cfg(url='', **kwargs):
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 
-    'vit_base_patch16_siglip_224': _cfg(
-        file='/data/n/temp/siglip/webli_en_b16_224_63724782.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_base_patch16_siglip_224.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0),
-    'vit_base_patch16_siglip_256': _cfg(
-        file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
+    'vit_base_patch16_siglip_256.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_base_patch16_siglip_384': _cfg(
-        file='',
-        custom_load=True,
+    'vit_base_patch16_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_base_patch16_siglip_512': _cfg(
-        file='',
-        custom_load=True,
+    'vit_base_patch16_siglip_512.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-512',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 512, 512),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_large_patch16_siglip_256': _cfg(
-        custom_load=True,
+    'vit_large_patch16_siglip_256.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_large_patch16_siglip_384': _cfg(
-        custom_load=True,
+    'vit_large_patch16_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_so400m_patch14_siglip_224': _cfg(
-        # file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_so400m_patch14_siglip_224.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0),
-    'vit_so400m_patch14_siglip_384': _cfg(
-        #file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_so400m_patch14_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
         num_classes=0),
 })
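With the hub ids wired into the default cfgs above, the remapped SigLIP vision towers should load like any other pretrained timm model. A minimal usage sketch, assuming the referenced 'timm/ViT-B-16-SigLIP' weights are published on the HF hub and huggingface_hub is installed:

import torch
import timm

# Pulls the OpenCLIP checkpoint referenced in the cfg above and remaps it
# through checkpoint_filter_fn (the visual.trunk. prefix strip shown earlier).
model = timm.create_model('vit_base_patch16_siglip_224.webli', pretrained=True)
model.eval()

x = torch.randn(1, 3, 224, 224)  # dummy batch at the cfg input size
with torch.no_grad():
    feats = model(x)  # num_classes=0, so the output is the pooled image embedding
print(feats.shape)  # expected: torch.Size([1, 768]) for the ViT-B/16 tower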