from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import use_fused_attn
from .mlp import Mlp
from .weight_init import trunc_normal_tf_


class AttentionPoolLatent(nn.Module):
    """ Attention pooling w/ latent query
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            in_features: int,
            out_features: int = None,
            embed_dim: int = None,
            num_heads: int = 8,
            feat_size: Optional[int] = None,  # required when pos_embed='abs'
            mlp_ratio: float = 4.0,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            latent_len: int = 1,
            latent_dim: int = None,
            pos_embed: str = '',
            pool_type: str = 'token',
            norm_layer: Optional[nn.Module] = None,
            drop: float = 0.0,
    ):
        super().__init__()
        embed_dim = embed_dim or in_features
        out_features = out_features or in_features
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.pool = pool_type
        self.fused_attn = use_fused_attn()

        if pos_embed == 'abs':
            assert feat_size is not None, 'feat_size must be set when pos_embed="abs"'
            spatial_len = feat_size
            self.pos_embed = nn.Parameter(torch.zeros(spatial_len, in_features))
        else:
            self.pos_embed = None

        self.latent_dim = latent_dim or embed_dim
        self.latent_len = latent_len
        # learned latent tokens that act as the pooling queries
        self.latent = nn.Parameter(torch.zeros(1, self.latent_len, embed_dim))

        self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
        self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(drop)

        self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
        self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))

        self.init_weights()

    def init_weights(self):
        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=self.pos_embed.shape[1] ** -0.5)
        trunc_normal_tf_(self.latent, std=self.latent_dim ** -0.5)

    def forward(self, x):
        B, N, C = x.shape

        if self.pos_embed is not None:
            # FIXME interpolate
            x = x + self.pos_embed.unsqueeze(0).to(x.dtype)

        # project the shared latent tokens to queries: (B, num_heads, latent_len, head_dim)
        q_latent = self.latent.expand(B, -1, -1)
        q = self.q(q_latent).reshape(B, self.latent_len, self.num_heads, self.head_dim).transpose(1, 2)

        # keys/values come from the input tokens: each (B, num_heads, N, head_dim)
        kv = self.kv(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(q, k, v)
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            x = attn @ v
        x = x.transpose(1, 2).reshape(B, self.latent_len, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        x = x + self.mlp(self.norm(x))

        # optional pool if latent seq_len > 1 and pooled output is desired
        if self.pool == 'token':
            x = x[:, 0]
        elif self.pool == 'avg':
            x = x.mean(1)
        return x
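

# Minimal usage sketch (not part of the original file): pools a (B, N, C)
# token sequence down to a single (B, C) vector with one learned latent query.
# Shapes and hyper-parameters below are illustrative assumptions; because this
# module uses relative imports, run it as a module within its package
# (e.g. `python -m <package>.attention_pool`) rather than as a standalone script.
if __name__ == '__main__':
    pool = AttentionPoolLatent(
        in_features=384,    # token embedding dim C; must be divisible by num_heads
        num_heads=6,
        latent_len=1,       # single latent query -> single pooled token
        pool_type='token',  # return the first (only) latent token
    )
    tokens = torch.randn(2, 196, 384)   # (B, N, C), e.g. 14x14 patch tokens
    pooled = pool(tokens)
    print(pooled.shape)                 # torch.Size([2, 384])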