1+ """ Image to Patch Hybird Embedding Layer
2+
3+ Hacked together by / Copyright 2020 Ross Wightman
4+ """
import logging
import math
from typing import List, Optional, Tuple, Union

import torch
from torch import nn as nn
import torch.nn.functional as F

from .format import Format, nchw_to
from .helpers import to_2tuple
from .patch_embed import resample_patch_embed


_logger = logging.getLogger(__name__)


class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """
    output_fmt: Format
    dynamic_img_pad: torch.jit.Final[bool]

    def __init__(
            self,
            backbone: nn.Module,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            in_chans: int = 3,
            embed_dim: int = 768,
            bias: bool = True,
            proj: bool = True,
            flatten: bool = True,
            output_fmt: Optional[str] = None,
            strict_img_size: bool = True,
            dynamic_img_pad: bool = False,
    ):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        self.backbone = backbone
        self.in_chans = in_chans
        (
            self.img_size,
            self.patch_size,
            self.feature_size,
            self.feature_ratio,
            self.feature_dim,
            self.grid_size,
            self.num_patches,
        ) = self._init_backbone(
            img_size=img_size,
            patch_size=patch_size,
            feature_size=feature_size,
            feature_ratio=feature_ratio,
        )

        if output_fmt is not None:
            self.flatten = False
            self.output_fmt = Format(output_fmt)
        else:
            # flatten spatial dim and transpose to channels last, kept for bwd compat
            self.flatten = flatten
            self.output_fmt = Format.NCHW
        self.strict_img_size = strict_img_size
        self.dynamic_img_pad = dynamic_img_pad
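        # without dynamic padding, the backbone feature map must tile evenly by patch_size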
        if not dynamic_img_pad:
            assert self.feature_size[0] % self.patch_size[0] == 0 and self.feature_size[1] % self.patch_size[1] == 0

        if proj:
            self.proj = nn.Conv2d(
                self.feature_dim,
                embed_dim,
                kernel_size=patch_size,
                stride=patch_size,
                bias=bias,
            )
        else:
            assert self.feature_dim == embed_dim, \
                f'The feature dim ({self.feature_dim}) must match embed dim ({embed_dim}) when projection disabled.'
            self.proj = nn.Identity()

    def _init_backbone(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            feature_dim: Optional[int] = None,
    ):
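        """Resolve feature map size, reduction ratio, and channel dim,
        probing the backbone with a dummy forward pass when not provided."""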
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        if feature_size is None:
            with torch.no_grad():
                # NOTE Most reliable way of determining output dims is to run forward pass
                training = self.backbone.training
                if training:
                    self.backbone.eval()
                o = self.backbone(torch.zeros(1, self.in_chans, img_size[0], img_size[1]))
                if isinstance(o, (list, tuple)):
                    o = o[-1]  # last feature if backbone outputs list/tuple of features
                feature_size = o.shape[-2:]
                feature_dim = o.shape[1]
                self.backbone.train(training)
            feature_ratio = tuple([s // f for s, f in zip(img_size, feature_size)])
        else:
            feature_size = to_2tuple(feature_size)
            feature_ratio = to_2tuple(feature_ratio or 16)
            if feature_dim is None:
                if hasattr(self.backbone, 'feature_info'):
                    feature_dim = self.backbone.feature_info.channels()[-1]
                else:
                    feature_dim = self.backbone.num_features
        grid_size = tuple([f // p for f, p in zip(feature_size, patch_size)])
        num_patches = grid_size[0] * grid_size[1]
        return img_size, patch_size, feature_size, feature_ratio, feature_dim, grid_size, num_patches

    def set_input_size(
            self,
            img_size: Optional[Union[int, Tuple[int, int]]] = None,
            patch_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            feature_dim: Optional[int] = None,
    ):
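        """Update input image size and/or patch size, resampling the projection
        weights and re-deriving the feature geometry when patch size changes."""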
        assert img_size is not None or patch_size is not None
        img_size = img_size or self.img_size
        new_patch_size = None
        if patch_size is not None:
            new_patch_size = to_2tuple(patch_size)
        if new_patch_size is not None and new_patch_size != self.patch_size:
            assert isinstance(self.proj, nn.Conv2d), 'HybridEmbed must have a projection layer to change patch size.'
            with torch.no_grad():
                new_proj = nn.Conv2d(
                    self.proj.in_channels,
                    self.proj.out_channels,
                    kernel_size=new_patch_size,
                    stride=new_patch_size,
                    bias=self.proj.bias is not None,
                )
                new_proj.weight.copy_(resample_patch_embed(self.proj.weight, new_patch_size, verbose=True))
                if self.proj.bias is not None:
                    new_proj.bias.copy_(self.proj.bias)
                self.proj = new_proj
            patch_size = new_patch_size
        patch_size = patch_size or self.patch_size

        if img_size != self.img_size or patch_size != self.patch_size:
            (
                self.img_size,
                self.patch_size,
                self.feature_size,
                self.feature_ratio,
                self.feature_dim,
                self.grid_size,
                self.num_patches,
            ) = self._init_backbone(
                img_size=img_size,
                patch_size=patch_size,
                feature_size=feature_size,
                feature_ratio=feature_ratio,
                feature_dim=feature_dim,
            )

    def feat_ratio(self, as_scalar=True) -> Union[Tuple[int, int], int]:
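        """Return the total input-to-token reduction: backbone reduction times patch size."""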
        total_reduction = (
            self.feature_ratio[0] * self.patch_size[0],
            self.feature_ratio[1] * self.patch_size[1]
        )
        if as_scalar:
            return max(total_reduction)
        else:
            return total_reduction

    def dynamic_feat_size(self, img_size: Tuple[int, int]) -> Tuple[int, int]:
        """ Get feature grid size taking account of dynamic padding and backbone network feat reduction
        """
        feat_size = (img_size[0] // self.feature_ratio[0], img_size[1] // self.feature_ratio[1])
        if self.dynamic_img_pad:
            return math.ceil(feat_size[0] / self.patch_size[0]), math.ceil(feat_size[1] / self.patch_size[1])
        else:
            return feat_size[0] // self.patch_size[0], feat_size[1] // self.patch_size[1]

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        if hasattr(self.backbone, 'set_grad_checkpointing'):
            self.backbone.set_grad_checkpointing(enable=enable)
        elif hasattr(self.backbone, 'grad_checkpointing'):
            self.backbone.grad_checkpointing = enable

    def forward(self, x):
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        _, _, H, W = x.shape
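        # pad right/bottom so the feature H, W become multiples of patch_size before projection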
        if self.dynamic_img_pad:
            pad_h = (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]
            pad_w = (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]
            x = F.pad(x, (0, pad_w, 0, pad_h))
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # NCHW -> NLC
        elif self.output_fmt != Format.NCHW:
            x = nchw_to(x, self.output_fmt)
        return x


class HybridEmbedWithSize(HybridEmbed):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    Also returns the (H, W) size of the projected feature map.
    """
    def __init__(
            self,
            backbone: nn.Module,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 1,
            feature_size: Optional[Union[int, Tuple[int, int]]] = None,
            feature_ratio: Optional[Union[int, Tuple[int, int]]] = None,
            in_chans: int = 3,
            embed_dim: int = 768,
            bias=True,
            proj=True,
    ):
        super().__init__(
            backbone=backbone,
            img_size=img_size,
            patch_size=patch_size,
            feature_size=feature_size,
            feature_ratio=feature_ratio,
            in_chans=in_chans,
            embed_dim=embed_dim,
            bias=bias,
            proj=proj,
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        if hasattr(self.backbone, 'set_grad_checkpointing'):
            self.backbone.set_grad_checkpointing(enable=enable)
        elif hasattr(self.backbone, 'grad_checkpointing'):
            self.backbone.grad_checkpointing = enable

    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
        x = self.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features
        x = self.proj(x)
        return x.flatten(2).transpose(1, 2), x.shape[-2:]
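

if __name__ == '__main__':
    # Minimal usage sketch (illustrative, not part of the original module). It assumes
    # timm is installed and that this file lives inside the timm.layers package, so it
    # must be run as a module for the relative imports above to resolve. Any nn.Module
    # that returns an NCHW feature map (or a list/tuple of them) can serve as backbone;
    # resnet26d with features_only=True is just one choice, giving a 2048-channel map
    # at 32x reduction for its last feature.
    import timm
    cnn = timm.create_model('resnet26d', features_only=True, pretrained=False)
    embed = HybridEmbed(cnn, img_size=224, patch_size=1, embed_dim=768)
    tokens = embed(torch.randn(1, 3, 224, 224))
    print(tokens.shape)  # torch.Size([1, 49, 768]): 224/32 -> 7x7 grid -> 49 patch tokens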