import torch
import numpy as np
from PIL import Image
+ from collections import defaultdict
from typing import List, Optional, Union, Tuple

- from transformers.image_processing_utils import BaseImageProcessor
- from transformers.image_transforms import (
-     convert_to_rgb,
-     resize,
-     to_channel_dimension_format,
+ from transformers.image_processing_utils_fast import (
+     BaseImageProcessorFast,
+     group_images_by_shape,
+     reorder_images,
)
+ from torchvision.transforms import InterpolationMode
+
from transformers.image_utils import (
    OPENAI_CLIP_MEAN,
    OPENAI_CLIP_STD,
    ChannelDimension,
    PILImageResampling,
-     get_image_size,
-     infer_channel_dimension_format,
-     to_numpy_array,
+     SizeDict,
)
+ from torchvision.transforms.v2 import functional as F

IMAGE_FACTOR = 28
MIN_PIXELS = 4 * 28 * 28
@@ -35,9 +36,9 @@ def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:

-     if max(height, width) / min(height, width) > 200:
+     if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
-             f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
+             f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
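
For reference, a small worked example of the rounding above (illustrative only, and assuming the min/max pixel clamp elided from this hunk does not trigger for these sizes):

    # Each side is rounded to the nearest multiple of `factor` (28 by default).
    smart_resize(1000, 700)   # -> (1008, 700): round(1000 / 28) * 28 == 1008, round(700 / 28) * 28 == 700
    smart_resize(640, 480)    # -> (644, 476)
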
@@ -54,7 +55,7 @@ def smart_resize(

def resize_image(
    image_file: Image.Image, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
- ) -> tuple[Image.Image, int, int]:
+ ) -> Image.Image:

    image = image_file.convert("RGB")
    width, height = image.size
@@ -71,7 +72,7 @@ def resize_image(
    return image


- class Qwen2VLImageProcessor(BaseImageProcessor):
+ class Qwen2VLImageProcessor(BaseImageProcessorFast):
    def __init__(
        self,
        do_resize: bool = True,
@@ -87,6 +88,8 @@ def __init__(
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        merge_size: int = 2,
+         disable_grouping: Optional[bool] = None,
+         interpolation: Optional["F.InterpolationMode"] = InterpolationMode.BICUBIC,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
@@ -103,63 +106,138 @@ def __init__(
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.merge_size = merge_size
+         self.disable_grouping = disable_grouping
+         self.interpolation = interpolation
        self.data_format = ChannelDimension.FIRST
+         self._fused_cache = {}  # key: (do_norm, do_rescale, rescale_factor, device)
+
+     def _get_fused_mean_std(
+         self,
+         do_normalize: bool,
+         image_mean: Union[float, list[float]],
+         image_std: Union[float, list[float]],
+         do_rescale: bool,
+         rescale_factor: float,
+         device: Optional["torch.device"],
+     ) -> tuple[torch.Tensor, torch.Tensor, bool]:
+         key = (bool(do_normalize), bool(do_rescale), float(rescale_factor), str(device))
+         if key not in self._fused_cache:
+             if do_rescale and do_normalize:
+                 mean = torch.tensor(image_mean) * (1.0 / rescale_factor)
+                 std = torch.tensor(image_std) * (1.0 / rescale_factor)
+                 do_rescale = False
+             else:
+                 mean = torch.tensor(image_mean)
+                 std = torch.tensor(image_std)
+             self._fused_cache[key] = (mean.to(device=device), std.to(device=device), do_rescale)
+         return self._fused_cache[key]
+
+     def rescale_and_normalize(
+         self,
+         images: "torch.Tensor",
+         do_rescale: bool,
+         rescale_factor: float,
+         do_normalize: bool,
+         image_mean: Union[float, list[float]],
+         image_std: Union[float, list[float]],
+     ) -> "torch.Tensor":
+         """
+         Rescale and normalize images, fusing both steps into a single normalization when both are enabled.
+         """
+         image_mean, image_std, do_rescale = self._get_fused_mean_std(
+             do_normalize=do_normalize,
+             image_mean=image_mean,
+             image_std=image_std,
+             do_rescale=do_rescale,
+             rescale_factor=rescale_factor,
+             device=images.device,
+         )
+         # if/elif: when both flags are set, the rescale has already been folded into the fused mean/std above
+         if do_normalize:
+             images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
+         elif do_rescale:
+             images = self.rescale(images, rescale_factor)
+
+         return images
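
As a quick sanity check of that fusion (a standalone sketch, not part of the processor): rescaling by `rescale_factor` and then normalizing with `(mean, std)` gives the same result as normalizing the raw values with `(mean / rescale_factor, std / rescale_factor)`.

    import torch

    x = torch.randint(0, 256, (3, 4, 4)).float()   # stand-in for uint8 image data
    s, mean, std = 1 / 255.0, 0.481, 0.269          # arbitrary rescale factor and stats
    two_step = (x * s - mean) / std                 # rescale, then normalize
    fused = (x - mean / s) / (std / s)              # single normalize with folded mean/std
    assert torch.allclose(two_step, fused, atol=1e-5)
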

    def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
-         if self.do_convert_rgb:
-             image = convert_to_rgb(image)
-         image = to_numpy_array(image)
-         input_data_format = infer_channel_dimension_format(image)
-         height, width = get_image_size(image, channel_dim=input_data_format)
-
-         resized_height, resized_width = height, width
-         if self.do_resize:
-             resized_height, resized_width = smart_resize(
-                 height,
-                 width,
-                 factor=self.patch_size * self.merge_size,
-                 min_pixels=self.min_pixels,
-                 max_pixels=self.max_pixels,
+         image_arr = np.asarray(image, dtype=np.uint8)
+         image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to("cuda", non_blocking=True)
+         grouped_images, grouped_images_index = group_images_by_shape(
+             [image_data], disable_grouping=self.disable_grouping
+         )
+         resized_images_grouped = {}
+         for shape, stacked_images in grouped_images.items():
+             height, width = stacked_images.shape[-2:]
+             if self.do_resize:
+                 resized_height, resized_width = smart_resize(
+                     height,
+                     width,
+                     factor=self.patch_size * self.merge_size,
+                     min_pixels=self.min_pixels,
+                     max_pixels=self.max_pixels,
+                 )
+                 stacked_images = self.resize(
+                     image=stacked_images,
+                     size=SizeDict(height=resized_height, width=resized_width),
+                     interpolation=self.interpolation,
+                 )
+             resized_images_grouped[shape] = stacked_images
+         resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+         # Group images by size for further processing
+         # Needed in case do_resize is False, or resize returns images with different sizes
+         grouped_images, grouped_images_index = group_images_by_shape(
+             resized_images, disable_grouping=self.disable_grouping
+         )
+         processed_images_grouped = {}
+         processed_grids = {}
+         for shape, stacked_images in grouped_images.items():
+             resized_height, resized_width = stacked_images.shape[-2:]
+             # Fused rescale and normalize
+             patches = self.rescale_and_normalize(
+                 stacked_images, self.do_rescale, self.rescale_factor, self.do_normalize, self.image_mean, self.image_std
            )
-             image = resize(
-                 image, size=(resized_height, resized_width), resample=self.resample, input_data_format=input_data_format
+             if patches.ndim == 4:
+                 # add a temporal dimension if we have images
+                 patches = patches.unsqueeze(1)
+             if patches.shape[1] % self.temporal_patch_size != 0:
+                 repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1)
+                 patches = torch.cat([patches, repeats], dim=1)
+             batch_size, grid_t, channel = patches.shape[:3]
+             grid_t = grid_t // self.temporal_patch_size
+             grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+
+             patches = (
+                 patches.view(
+                     batch_size,
+                     grid_t,
+                     self.temporal_patch_size,
+                     channel,
+                     grid_h // self.merge_size,
+                     self.merge_size,
+                     self.patch_size,
+                     grid_w // self.merge_size,
+                     self.merge_size,
+                     self.patch_size,
+                 )
+                 .permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
+                 .contiguous()
            )
-
-         if self.do_rescale:
-             image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format)
-
-         if self.do_normalize:
-             image = self.normalize(
-                 image=image, mean=self.image_mean, std=self.image_std, input_data_format=input_data_format
+             # The permute above groups grid and patch dimensions so they can be flattened together below:
+             # (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temporal_patch_size, patch_h, patch_w)
+             flatten_patches = patches.view(
+                 batch_size,
+                 grid_t * grid_h * grid_w,
+                 channel * self.temporal_patch_size * self.patch_size * self.patch_size,
            )

-         image = to_channel_dimension_format(image, self.data_format, input_channel_dim=input_data_format)
-
-         patches = np.array([image])
-
-         if patches.shape[0] == 1:
-             # why to copy image 2 times. use self.temporal_patch_size = 2.
-             patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
-         channel = patches.shape[1]
-         grid_t = patches.shape[0] // self.temporal_patch_size
-         grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
-         patches = patches.reshape(
-             grid_t,
-             self.temporal_patch_size,
-             channel,
-             grid_h // self.merge_size,
-             self.merge_size,
-             self.patch_size,
-             grid_w // self.merge_size,
-             self.merge_size,
-             self.patch_size,
-         )
-         patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
-         flatten_patches = patches.reshape(
-             grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
-         )
-         image_grid_thw = (grid_t, grid_h, grid_w)
-         pixel_values = torch.as_tensor(flatten_patches)
-         grid_thw = torch.as_tensor([image_grid_thw])
+             processed_images_grouped[shape] = flatten_patches
+             processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
+
+         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+         processed_grids = reorder_images(processed_grids, grouped_images_index)
+         pixel_values = torch.cat(processed_images, dim=0)  # (num_patches_total, C * T * ps * ps)
+         image_grid_thw = torch.as_tensor(processed_grids)

-         return pixel_values, grid_thw
+         return pixel_values, image_grid_thw
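
A minimal usage sketch of the new fast path (assumptions: a CUDA device is available, since `preprocess` moves pixel data to `"cuda"`; the constructor defaults elided from this diff populate `do_resize`, `do_rescale`, `do_normalize`, `image_mean`, `image_std`, `min_pixels`, and `max_pixels`; and `"example.jpg"` is a placeholder path):

    from PIL import Image

    processor = Qwen2VLImageProcessor()
    img = Image.open("example.jpg").convert("RGB")   # any RGB image; preprocess resizes it internally
    pixel_values, image_grid_thw = processor.preprocess(img)
    # For an image that smart_resize maps to 1008 x 700 (patch_size=14, merge_size=2, temporal_patch_size=2):
    #   grid_t, grid_h, grid_w = 1, 72, 50               -> 3600 patches
    #   pixel_values.shape     == (3600, 3 * 2 * 14 * 14) == (3600, 1176)
    #   image_grid_thw         == tensor([[1, 72, 50]])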