
Commit 7136516

Add SigLIP weights
1 parent 42daa3b commit 7136516

3 files changed: +137 −123 lines


timm/layers/attention_pool.py (new file: 103 additions, 0 deletions)
@@ -0,0 +1,103 @@
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import use_fused_attn
from .mlp import Mlp
from .weight_init import trunc_normal_tf_


class AttentionPoolLatent(nn.Module):
    """ Attention pooling w/ latent query
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 8,
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            latent_len: int = 1,
            latent_dim: int = None,
            pos_embed: str = '',
            pool_type: str = 'token',
            norm_layer: Optional[nn.Module] = None,
            drop: float = 0.0,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()

        if pos_embed == 'abs':
            spatial_len = self.feat_size
            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)

        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))

        self.init_weights()

    def init_weights(self):
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)

    def forward(self, x):
        B, N, C = x.shape

        if self.pos_embed is not None:
            # FIXME interpolate
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        q_latent = self.latent.expand(B, -1, -1)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))

        # optional pool if latent seq_len > 1 and pooled output is desired
        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(1)
        return x
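For reference, a minimal usage sketch of the new module (assuming it is exported from timm.layers, as the updated vision_transformer.py import further below suggests); the dimensions are illustrative only:

import torch
from timm.layers import AttentionPoolLatent

# pool a (B, N, C) sequence of patch tokens down to one embedding per image
pool = AttentionPoolLatent(in_features=768, num_heads=12)  # defaults: latent_len=1, pool_type='token'
tokens = torch.randn(2, 196, 768)                          # e.g. 14x14 patch tokens from a ViT-B/16
pooled = pool(tokens)                                       # the latent query attends over all tokens
print(pooled.shape)                                          # torch.Size([2, 768])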

timm/models/_hub.py (2 additions, 2 deletions)
@@ -376,7 +376,7 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
     """
     if filename == HF_WEIGHTS_NAME:
         yield HF_SAFE_WEIGHTS_NAME
-    # if filename == HF_OPEN_CLIP_WEIGHTS_NAME:  # FIXME tracking safetensors yet
-    #     yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
+    if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
+        yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
     if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
         yield filename[:-4] + ".safetensors"
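Not part of the commit, just a hedged sketch of how a caller might consume this generator to prefer safetensors files on the Hub; the download helper below is illustrative, not timm's actual loader:

from huggingface_hub import hf_hub_download

def download_preferring_safetensors(repo_id: str, filename: str) -> str:
    # try each safetensors alternative yielded by _get_safe_alternatives first,
    # then fall back to the originally requested file (e.g. open_clip_pytorch_model.bin)
    for candidate in _get_safe_alternatives(filename):
        try:
            return hf_hub_download(repo_id, candidate)
        except Exception:
            continue
    return hf_hub_download(repo_id, filename)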

timm/models/vision_transformer.py (32 additions, 121 deletions)
@@ -37,8 +37,8 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
-from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
-    resample_abs_pos_embed, RmsNorm, PatchDropout, use_fused_attn, SwiGLUPacked
+from timm.layers import PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, SwiGLUPacked, \
+    trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn
 from ._builder import build_model_with_cfg
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
@@ -377,95 +377,6 @@ def forward(self, x):
         return self._forward(x)
 
 
-class AttentionPoolLatent(nn.Module):
-    """ Attention pooling w/ latent query
-    """
-    def __init__(
-            self,
-            in_features: int,
-            out_features: int = None,
-            embed_dim: int = None,
-            num_heads: int = 8,
-            mlp_ratio: float = 4.0,
-            qkv_bias: bool = True,
-            qk_norm: bool = False,
-            latent_len: int = 1,
-            latent_dim: int = None,
-            pos_embed: str = '',
-            pool_type: str = 'token',
-            norm_layer: Optional[nn.Module] = None,
-            drop: float = 0.0,
-    ):
-        super().__init__()
-        embed_dim = embed_dim or in_features
-        out_features = out_features or in_features
-        assert embed_dim % num_heads == 0
-        self.num_heads = num_heads
-        self.head_dim = embed_dim // num_heads
-        self.scale = self.head_dim ** -0.5
-        self.pool = pool_type
-        self.fused_attn = use_fused_attn()
-
-        if pos_embed == 'abs':
-            spatial_len = self.feat_size
-            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
-        else:
-            self.pos_embed = None
-
-        self.latent_dim = latent_dim or embed_dim
-        self.latent_len = latent_len
-        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))
-
-        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
-        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
-        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
-        self.proj = nn.Linear(embed_dim, embed_dim)
-        self.proj_drop = nn.Dropout(drop)
-
-        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
-        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
-
-    def init_weights(self):
-        if self.pos_embed is not None:
-            trunc_normal_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
-
-    def forward(self, x):
-        B, N, C = x.shape
-
-        if self.pos_embed is not None:
-            # FIXME interpolate
-            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)
-
-        q_latent = self.latent.expand(B, -1, -1)
-        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        k, v = kv.unbind(0)
-
-        q, k = self.q_norm(q), self.k_norm(k)
-
-        if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v)
-        else:
-            q = q * self.scale
-            attn = q @ k.transpose(-2, -1)
-            attn = attn.softmax(dim=-1)
-            x = attn @ v
-        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-
-        x = x + self.mlp(self.norm(x))
-
-        # optional pool if latent seq_len > 1 and pooled output is desired
-        if self.pool == 'token':
-            x = x[:, 0]
-        elif self.pool == 'avg':
-            x = x.mean(1)
-        return x
-
 class VisionTransformer(nn.Module):
     """ Vision Transformer
 
@@ -1072,6 +983,12 @@ def checkpoint_filter_fn(
     if "encoder" in state_dict:
         state_dict = _convert_ijepa(state_dict, model)
 
+    if 'visual.trunk.pos_embed' in state_dict:
+        # convert an OpenCLIP model with timm vision encoder
+        prefix = 'visual.trunk.'
+        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
+        # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
+
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
             O, I, H, W = model.patch_embed.proj.weight.shape
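A toy illustration of the remapping added above, using hypothetical checkpoint keys, showing how an OpenCLIP state dict's vision tower is reduced to plain timm keys while everything else is dropped:

# hypothetical OpenCLIP checkpoint keys (values elided)
state_dict = {
    'visual.trunk.pos_embed': ...,                      # vision tower, kept
    'visual.trunk.blocks.0.attn.qkv.weight': ...,       # vision tower, kept
    'visual.head.proj.weight': ...,                      # outside the trunk, dropped (see FIXME above)
    'text.transformer.resblocks.0.ln_1.weight': ...,     # text tower, dropped
}
prefix = 'visual.trunk.'
state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
# -> remaining keys: 'pos_embed', 'blocks.0.attn.qkv.weight'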
@@ -1634,48 +1551,42 @@ def _cfg(url='', **kwargs):
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 
-    'vit_base_patch16_siglip_224': _cfg(
-        file='/data/n/temp/siglip/webli_en_b16_224_63724782.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_base_patch16_siglip_224.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0),
-    'vit_base_patch16_siglip_256': _cfg(
-        file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
+    'vit_base_patch16_siglip_256.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_base_patch16_siglip_384': _cfg(
-        file='',
-        custom_load=True,
+    'vit_base_patch16_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_base_patch16_siglip_512': _cfg(
-        file='',
-        custom_load=True,
+    'vit_base_patch16_siglip_512.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-512',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 512, 512),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_large_patch16_siglip_256': _cfg(
-        custom_load=True,
+    'vit_large_patch16_siglip_256.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 256, 256),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_large_patch16_siglip_384': _cfg(
-        custom_load=True,
+    'vit_large_patch16_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
-        # hf_hub_id='timm/',
         num_classes=0),
-    'vit_so400m_patch14_siglip_224': _cfg(
-        # file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_so400m_patch14_siglip_224.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=0),
-    'vit_so400m_patch14_siglip_384': _cfg(
-        #file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
-        custom_load=True,
-        # hf_hub_id='timm/',
+    'vit_so400m_patch14_siglip_384.webli': _cfg(
+        hf_hub_id='timm/ViT-SO400M-14-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 384, 384),
         num_classes=0),
 })
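With these cfgs pointing at the OpenCLIP weight files on the Hub, the SigLIP image towers should load through the usual timm entrypoint; a hedged example (the pooled feature width depends on the variant chosen):

import timm
import torch

# pulls 'open_clip_pytorch_model.bin' (or its safetensors alternative, per the _hub.py change)
# from the timm/ViT-B-16-SigLIP repo referenced in the cfg above
model = timm.create_model('vit_base_patch16_siglip_224.webli', pretrained=True)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    features = model(x)  # num_classes=0 in the cfg, so pooled features are returned instead of logits
print(features.shape)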
