
Commit 5f12de4

Add initial AttentionPool2d that's being trialed. Fix comment and still trying to improve reliability of sgd test.
1 parent 76881d2 commit 5f12de4

3 files changed: 185 additions, 3 deletions

tests/test_optim.py

Lines changed: 2 additions & 2 deletions
@@ -317,10 +317,10 @@ def test_sgd(optimizer):
        # lambda opt: ReduceLROnPlateau(opt)]
        # )
    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1)
+        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1)
    )
    _test_basic_cases(
-        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=3e-3, momentum=1, weight_decay=.1)
+        lambda weight, bias: create_optimizer_v2([weight, bias], optimizer, lr=1e-3, momentum=1, weight_decay=.1)
    )
    _test_rosenbrock(
        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
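
The only functional change here is dropping the SGD learning rate from 3e-3 to 1e-3 in the momentum=1 cases, to make the test less flaky. As a rough sketch of the equivalent factory call outside the test harness (assuming timm.optim.create_optimizer_v2, with 'sgd' standing in for the test's parametrized optimizer name and plain tensors standing in for its weight/bias pair):

import torch
from timm.optim import create_optimizer_v2

# toy parameters standing in for the (weight, bias) pair the test builds
weight = torch.nn.Parameter(torch.randn(10, 5))
bias = torch.nn.Parameter(torch.randn(10))

# same settings as the updated test case: lower lr with momentum=1
opt = create_optimizer_v2([weight, bias], 'sgd', lr=1e-3, momentum=1)
loss = (weight.sum() + bias.sum()) ** 2
loss.backward()
opt.step()
opt.zero_grad()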

timm/models/byoanet.py

Lines changed: 1 addition & 1 deletion
@@ -246,7 +246,7 @@ def halonet26t(pretrained=False, **kwargs):

@register_model
def sehalonet33ts(pretrained=False, **kwargs):
-    """ HaloNet w/ a ResNet26-t backbone. Halo attention in final two stages
+    """ HaloNet w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU, 1-2 Halo in stage 2,3,4.
    """
    return _create_byoanet('sehalonet33ts', pretrained=pretrained, **kwargs)
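
The docstring now matches what the model actually is (ResNet33-t backbone, SE attention in the non-Halo blocks, SiLU, Halo blocks in stages 2-4). For reference, a hedged example of instantiating the registered model through timm's factory at this checkout (pretrained=False, since this commit does not add weights):

import timm
import torch

model = timm.create_model('sehalonet33ts', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 256, 256))
print(out.shape)  # (1, num_classes), 1000 classes by default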

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
""" Attention Pool 2D

Implementations of 2D spatial feature pooling using multi-head attention instead of average pool.

Based on idea in CLIP by OpenAI, licensed Apache 2.0
https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

Hacked together by / Copyright 2021 Ross Wightman
"""
import math
from typing import List, Union, Tuple

import torch
import torch.nn as nn

from .helpers import to_2tuple
from .weight_init import trunc_normal_


def rot(x):
    # rotate adjacent channel pairs: (x0, x1) -> (-x1, x0), matching the rotary embedding formulation
    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)


def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
    return x * cos_emb + rot(x) * sin_emb


def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
    if isinstance(x, torch.Tensor):
        x = [x]
    return [t * cos_emb + rot(t) * sin_emb for t in x]


class RotaryEmbedding(nn.Module):
    """ Rotary position embedding

    NOTE: This is my initial attempt at implementing rotary embedding for spatial use. It has not
    been well tested and will likely change. It will be moved to its own file.

    The following impl/resources were referenced for this impl:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """
    def __init__(self, dim, max_freq=4):
        super().__init__()
        self.dim = dim
        self.register_buffer('bands', 2 ** torch.linspace(0., max_freq - 1, self.dim // 4), persistent=False)

    def get_embed(self, shape: torch.Size, device: torch.device = None, dtype: torch.dtype = None):
        """
        NOTE: shape arg should include spatial dims only
        """
        device = device or self.bands.device
        dtype = dtype or self.bands.dtype
        if not isinstance(shape, torch.Size):
            shape = torch.Size(shape)
        N = shape.numel()
        # normalized coordinate grid in [-1, 1] for each spatial dim, scaled by the frequency bands
        grid = torch.stack(torch.meshgrid(
            [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in shape]), dim=-1).unsqueeze(-1)
        emb = grid * math.pi * self.bands
        sin = emb.sin().reshape(N, -1).repeat_interleave(2, -1)
        cos = emb.cos().reshape(N, -1).repeat_interleave(2, -1)
        return sin, cos

    def forward(self, x):
        # assuming channel-first tensor where spatial dims are at index >= 2
        sin_emb, cos_emb = self.get_embed(x.shape[2:])
        return apply_rot_embed(x, sin_emb, cos_emb)


class RotAttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ rotary (relative) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: While this impl does not require a fixed feature size, performance at differing resolutions from
    train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
    """
    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pos_embed = RotaryEmbedding(self.head_dim)

        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        sin_emb, cos_emb = self.pos_embed.get_embed(x.shape[2:])
        x = x.reshape(B, -1, N).permute(0, 2, 1)

        # prepend the spatial mean as a query token; the pooled output is read from this position
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)

        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]

        # apply rotary embedding to the spatial tokens only, leaving the mean token untouched
        qc, q = q[:, :, :1], q[:, :, 1:]
        q = apply_rot_embed(q, sin_emb, cos_emb)
        q = torch.cat([qc, q], dim=2)

        kc, k = k[:, :, :1], k[:, :, 1:]
        k = apply_rot_embed(k, sin_emb, cos_emb)
        k = torch.cat([kc, k], dim=2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        return x[:, 0]


class AttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ learned (absolute) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    It was based on impl in CLIP by OpenAI
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: This requires feature size upon construction and will prevent adaptive sizing of the network.
    """
    def __init__(
            self,
            in_features: int,
            feat_size: Union[int, Tuple[int, int]],
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        super().__init__()

        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.feat_size = to_2tuple(feat_size)
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        spatial_dim = self.feat_size[0] * self.feat_size[1]
        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
        trunc_normal_(self.pos_embed, std=in_features ** -0.5)
        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        assert self.feat_size[0] == H
        assert self.feat_size[1] == W
        x = x.reshape(B, -1, N).permute(0, 2, 1)
        # prepend the spatial mean as a query token, then add the learned absolute pos embed
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        return x[:, 0]
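
A minimal usage sketch for the two new pooling modules, assuming they are importable as timm.models.layers.attention_pool2d (an assumed path, inferred from the module's relative imports):

import torch
# hypothetical import path for the new module
from timm.models.layers.attention_pool2d import RotAttentionPool2d, AttentionPool2d

feats = torch.randn(2, 256, 7, 7)  # (B, C, H, W) feature map from a CNN backbone

# rotary variant: no fixed feature size required at construction
rot_pool = RotAttentionPool2d(in_features=256, out_features=512, num_heads=4)
print(rot_pool(feats).shape)  # torch.Size([2, 512])

# learned absolute pos embed variant: feat_size must match (H, W) at forward time
abs_pool = AttentionPool2d(in_features=256, feat_size=7, out_features=512, num_heads=4)
print(abs_pool(feats).shape)  # torch.Size([2, 512])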
