""" Attention Pool 2D

Implementations of 2D spatial feature pooling using multi-head attention instead of average pool.

Based on idea in CLIP by OpenAI, licensed Apache 2.0
https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

Hacked together by / Copyright 2021 Ross Wightman
"""
import math
from typing import List, Union, Tuple

import torch
import torch.nn as nn

from .helpers import to_2tuple
from .weight_init import trunc_normal_


def rot(x):
    return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)


def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
    return x * cos_emb + rot(x) * sin_emb


def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
    if isinstance(x, torch.Tensor):
        x = [x]
    return [t * cos_emb + rot(t) * sin_emb for t in x]
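

# NOTE: worked example of the rotation, for reference. rot() pairs the last dim as
# (even, odd) and rotates each pair by 90 degrees, so apply_rot_embed() computes
# the standard rotary rotation. For a single pair [x0, x1] with the pairwise
# duplicated values sin = [s, s] and cos = [c, c] (see the repeat_interleave in
# RotaryEmbedding.get_embed() below):
#
#     apply_rot_embed([x0, x1], sin, cos) == [x0*c - x1*s, x1*c + x0*s]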


class RotaryEmbedding(nn.Module):
    """ Rotary position embedding

    NOTE: This is my initial attempt at an impl of rotary embedding for spatial use; it has
    not been well tested and will likely change. It will be moved to its own file.

    The following impls/resources were referenced:
    * https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
    * https://blog.eleuther.ai/rotary-embeddings/
    """
    def __init__(self, dim, max_freq=4):
        super().__init__()
        self.dim = dim
        # dim // 4 geometric frequency bands: 2 spatial dims * (dim // 4) bands * 2 (pair duplication) fills dim
        self.register_buffer('bands', 2 ** torch.linspace(0., max_freq - 1, self.dim // 4), persistent=False)

    def get_embed(self, shape: torch.Size, device: torch.device = None, dtype: torch.dtype = None):
        """
        NOTE: shape arg should include spatial dims only
        """
        device = device or self.bands.device
        dtype = dtype or self.bands.dtype
        if not isinstance(shape, torch.Size):
            shape = torch.Size(shape)
        N = shape.numel()
        # normalized [-1, 1] coordinate grid over the spatial extent
        grid = torch.stack(torch.meshgrid(
            [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in shape]), dim=-1).unsqueeze(-1)
        emb = grid * math.pi * self.bands
        # duplicate each value so sin/cos align with the (even, odd) pairs in rot()
        sin = emb.sin().reshape(N, -1).repeat_interleave(2, -1)
        cos = emb.cos().reshape(N, -1).repeat_interleave(2, -1)
        return sin, cos

    def forward(self, x):
        # assumes a channel-first tensor with spatial dims at index 2 and beyond
        sin_emb, cos_emb = self.get_embed(x.shape[2:])
        return apply_rot_embed(x, sin_emb, cos_emb)
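

# NOTE: minimal usage sketch for reference (shapes are illustrative):
#
#     rope = RotaryEmbedding(dim=64)      # dim is the per-head channel dim
#     sin, cos = rope.get_embed((7, 7))   # each (49, 64), i.e. (N, dim)
#     q = torch.randn(2, 4, 49, 64)       # (B, num_heads, N, head_dim)
#     q = apply_rot_embed(q, sin, cos)    # same shape, now position encoded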


class RotAttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ rotary (relative) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    Adapted from the AttentionPool2d in CLIP w/ rotary embedding instead of learned embed.
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: While this impl does not require a fixed feature size, performance at resolutions differing
    from train varies widely and falls off dramatically. I'm not sure if there is a way around this... -RW
    """
    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pos_embed = RotaryEmbedding(self.head_dim)

        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        if qkv_bias:
            # guard: bias is None when qkv_bias=False
            nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        sin_emb, cos_emb = self.pos_embed.get_embed(x.shape[2:])
        x = x.reshape(B, -1, N).permute(0, 2, 1)

        # prepend the mean token that acts as the pooling query
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)

        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]

        # rotary embed applies to the N spatial tokens only; the mean token is left as-is
        qc, q = q[:, :, :1], q[:, :, 1:]
        q = apply_rot_embed(q, sin_emb, cos_emb)
        q = torch.cat([qc, q], dim=2)

        kc, k = k[:, :, :1], k[:, :, 1:]
        k = apply_rot_embed(k, sin_emb, cos_emb)
        k = torch.cat([kc, k], dim=2)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        # return only the pooled (first) token
        return x[:, 0]
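

# NOTE: minimal usage sketch for reference (parameter values are illustrative):
#
#     pool = RotAttentionPool2d(in_features=2048, out_features=1024, num_heads=8)
#     x = torch.randn(2, 2048, 7, 7)    # (B, C, H, W) feature map
#     y = pool(x)                       # (2, 1024) pooled embedding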


class AttentionPool2d(nn.Module):
    """ Attention based 2D feature pooling w/ learned (absolute) pos embedding.
    This is a multi-head attention based replacement for (spatial) average pooling in NN architectures.

    It was based on the impl in CLIP by OpenAI
    https://github.com/openai/CLIP/blob/3b473b0e682c091a9e53623eebc1ca1657385717/clip/model.py

    NOTE: This requires the feature size upon construction and will prevent adaptive sizing of the network.
    """
    def __init__(
            self,
            in_features: int,
            feat_size: Union[int, Tuple[int, int]],
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 4,
            qkv_bias: bool = True,
    ):
        super().__init__()

        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.feat_size = to_2tuple(feat_size)
        self.qkv = nn.Linear(in_features, embed_dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(embed_dim, out_features)
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # learned absolute position embedding, one entry per spatial token + the mean token
        spatial_dim = self.feat_size[0] * self.feat_size[1]
        self.pos_embed = nn.Parameter(torch.zeros(spatial_dim + 1, in_features))
        trunc_normal_(self.pos_embed, std=in_features ** -0.5)
        trunc_normal_(self.qkv.weight, std=in_features ** -0.5)
        if qkv_bias:
            # guard: bias is None when qkv_bias=False
            nn.init.zeros_(self.qkv.bias)

    def forward(self, x):
        B, _, H, W = x.shape
        N = H * W
        assert self.feat_size[0] == H
        assert self.feat_size[1] == W
        x = x.reshape(B, -1, N).permute(0, 2, 1)
        # prepend the mean token that acts as the pooling query, then add pos embed
        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
        x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        x = self.qkv(x).reshape(B, N + 1, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = x[0], x[1], x[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N + 1, -1)
        x = self.proj(x)
        # return only the pooled (first) token
        return x[:, 0]
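

# NOTE: minimal usage sketch for reference; unlike RotAttentionPool2d above, the
# input feature map must match feat_size exactly (asserted in forward):
#
#     pool = AttentionPool2d(in_features=2048, feat_size=7, out_features=1024)
#     x = torch.randn(2, 2048, 7, 7)    # (B, C, H, W), H and W must equal feat_size
#     y = pool(x)                       # (2, 1024) pooled embedding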