 )
 from torchvision.transforms.v2 import functional as F
 
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
 IMAGE_FACTOR = 28
 MIN_PIXELS = 4 * 28 * 28
 MAX_PIXELS = 16384 * 28 * 28
@@ -160,9 +165,19 @@ def rescale_and_normalize(
 
         return images
 
+    @torch.inference_mode()
     def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
+        try:
+            return self._preprocess_bydevice(image, device="cuda")
+        except Exception as e:
+            logger.warning(f"Exception during image preprocessing on CUDA: {str(e)}")
+            torch.cuda.current_stream().synchronize()
+            return self._preprocess_bydevice(image, device="cpu")
+
+    def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torch.Tensor]:
         image_arr = np.asarray(image, dtype=np.uint8)
-        image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to("cuda", non_blocking=True)
+        image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True)
+
         grouped_images, grouped_images_index = group_images_by_shape(
             [image_data], disable_grouping=self.disable_grouping
         )
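
The new wrapper tries CUDA first and falls back to CPU: any exception from the CUDA path is logged, the current stream is synchronized to flush work queued by the failed attempt, and the same routine is retried with device="cpu". A minimal standalone sketch of that pattern, with heavy_transform as a hypothetical stand-in for _preprocess_bydevice:

import torch

def heavy_transform(x: torch.Tensor, device: str = "cuda") -> torch.Tensor:
    # Hypothetical stand-in for _preprocess_bydevice: move the input to
    # the requested device and do some device-side arithmetic there.
    x = x.to(device=device, non_blocking=True)
    return x.float() / 255.0

def transform_with_fallback(x: torch.Tensor) -> torch.Tensor:
    try:
        return heavy_transform(x, device="cuda")
    except Exception as e:
        print(f"CUDA preprocessing failed, retrying on CPU: {e}")
        if torch.cuda.is_available():
            # Drain kernels queued by the failed attempt before retrying.
            torch.cuda.current_stream().synchronize()
        return heavy_transform(x, device="cpu")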
@@ -183,27 +198,39 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
                 interpolation=self.interpolation,
             )
             resized_images_grouped[shape] = stacked_images
+
+        grouped_images = None
         resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+        resized_images_grouped = None
 
-        # Group images by size for further processing
-        # Needed in case do_resize is False, or resize returns images with different sizes
         grouped_images, grouped_images_index = group_images_by_shape(
             resized_images, disable_grouping=self.disable_grouping
         )
+        resized_images = None
+
         processed_images_grouped = {}
         processed_grids = {}
+
         for shape, stacked_images in grouped_images.items():
+            stacked_images = stacked_images.to("cuda", non_blocking=True)
+
             resized_height, resized_width = stacked_images.shape[-2:]
-            # Fused rescale and normalize
+
             patches = self.rescale_and_normalize(
-                stacked_images, self.do_rescale, self.rescale_factor, self.do_normalize, self.image_mean, self.image_std
+                stacked_images,
+                self.do_rescale,
+                self.rescale_factor,
+                self.do_normalize,
+                self.image_mean,
+                self.image_std,
             )
             if patches.ndim == 4:
-                # add a temporal dimension if we have images
                 patches = patches.unsqueeze(1)
+
             if patches.shape[1] % self.temporal_patch_size != 0:
                 repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1)
                 patches = torch.cat([patches, repeats], dim=1)
+
             batch_size, grid_t, channel = patches.shape[:3]
             grid_t = grid_t // self.temporal_patch_size
             grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
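
For orientation, the grid arithmetic at the end of this hunk, worked through with the usual Qwen2-VL constants (patch_size=14, temporal_patch_size=2, 3 channels; these values are assumptions, not read from this diff):

patch_size, temporal_patch_size, channel = 14, 2, 3
resized_height, resized_width = 448, 448

# A single image gets a temporal axis of length 1, then is padded up to
# temporal_patch_size frames, so grid_t ends up as 1.
grid_t = 2 // temporal_patch_size            # 1
grid_h = resized_height // patch_size        # 448 // 14 = 32
grid_w = resized_width // patch_size         # 448 // 14 = 32

num_patches = grid_t * grid_h * grid_w                               # 1024
patch_dim = channel * temporal_patch_size * patch_size * patch_size  # 1176
print(num_patches, patch_dim)  # flatten_patches per image: (1, 1024, 1176)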
@@ -224,8 +251,7 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
                 .permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
                 .contiguous()
             )
-            # Reorder dimensions to group grid and patch information for subsequent flattening.
-            # (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w)
+
             flatten_patches = patches.view(
                 batch_size,
                 grid_t * grid_h * grid_w,
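
The permute above regroups a 10-D view of the pixel tensor into (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w) before flattening, per the comment the commit removes. A shape-check sketch; the pre-permute view is reconstructed here from the permute indices and the standard Qwen2-VL layout (merge_size=2), so treat it as an assumption rather than the file's exact code:

import torch

B, C, tps, ps, m = 1, 3, 2, 14, 2     # batch, channels, temporal patch, patch, merge
grid_t, grid_h, grid_w = 1, 32, 32

patches = torch.randn(B, grid_t * tps, C, grid_h * ps, grid_w * ps)
patches = patches.view(
    B, grid_t, tps, C,
    grid_h // m, m, ps,
    grid_w // m, m, ps,
)
patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9).contiguous()
flat = patches.view(B, grid_t * grid_h * grid_w, C * tps * ps * ps)
print(flat.shape)  # torch.Size([1, 1024, 1176])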
@@ -235,9 +261,12 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
             processed_images_grouped[shape] = flatten_patches
             processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
 
+        grouped_images = None
+
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
         processed_grids = reorder_images(processed_grids, grouped_images_index)
-        pixel_values = torch.cat(processed_images, dim=0)  # (num_patches_total, C*T*ps*ps)
+
+        pixel_values = torch.cat(processed_images, dim=0)
         image_grid_thw = torch.as_tensor(processed_grids)
 
         return pixel_values, image_grid_thw
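
End to end, preprocess returns the flattened patches plus one (grid_t, grid_h, grid_w) row per image. A hedged usage sketch; `processor` stands in for an instance of the image-processor class this commit modifies, whose construction is not shown in the diff:

from PIL import Image

image = Image.open("example.jpg").convert("RGB")
pixel_values, image_grid_thw = processor.preprocess(image)

# pixel_values: (num_patches_total, channel * temporal_patch_size * patch_size ** 2)
# image_grid_thw: tensor of [grid_t, grid_h, grid_w] rows, one per input image
print(pixel_values.shape, image_grid_thw.tolist())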