
Commit 96e2a1d

opti-qwen2-vl-pre-process (#1094)
1 parent e2e9fab commit 96e2a1d

File tree: 6 files changed, +151 -71 lines


lightllm/models/qwen2_5_vl/qwen2_5_visual.py

Lines changed: 0 additions & 1 deletion
@@ -383,7 +383,6 @@ def encode(self, images: List[ImageItem]):
             uuids.append(img.uuid)
             image_data = read_shm(get_shm_name_data(img.uuid))
             image_data = Image.open(BytesIO(image_data))
-            image_data = resize_image(image_data)
             pixel_values, image_grid_thw = self.processor.preprocess(image_data)
             img_tensors.append(pixel_values)
             img_grids.append(image_grid_thw)

lightllm/models/qwen2_vl/qwen2_visual.py

Lines changed: 0 additions & 6 deletions
@@ -311,12 +311,6 @@ def encode(self, images: List[ImageItem]):
             uuids.append(img.uuid)
             image_data = read_shm(get_shm_name_data(img.uuid))
             image_data = Image.open(BytesIO(image_data))
-            image_data = resize_image(
-                image_file=image_data,
-                factor=self.processor.patch_size * self.processor.merge_size,
-                min_pixels=self.processor.min_pixels,
-                max_pixels=self.processor.max_pixels,
-            )
             pixel_values, image_grid_thw = self.processor.preprocess(image_data)
             img_tensors.append(pixel_values)
             img_grids.append(image_grid_thw)

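For context: the resize_image call deleted here passed factor = patch_size * merge_size together with min_pixels / max_pixels, the same parameters that smart_resize still receives inside Qwen2VLImageProcessor.preprocess (see vision_process.py below), so the target-resolution policy is unchanged; only the resize moves from CPU-side PIL onto the GPU tensor. A rough worked example of the rounding step, assuming the defaults visible in this commit (patch_size 14, merge_size 2, hence factor 28); the min/max-pixel clamping branches are outside the diff and omitted:

# Hypothetical worked example of smart_resize's rounding (sketch, not repo code).
height, width, factor = 1013, 759, 28          # factor = patch_size (14) * merge_size (2)

h_bar = round(height / factor) * factor        # round(36.18) * 28 = 1008
w_bar = round(width / factor) * factor         # round(27.11) * 28 = 756

grid_h, grid_w = h_bar // 14, w_bar // 14      # 72 x 54 patches of 14x14 pixels
merged_tokens = (grid_h // 2) * (grid_w // 2)  # 36 * 27 = 972 visual tokens after 2x2 merge
print(h_bar, w_bar, grid_h, grid_w, merged_tokens)
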
lightllm/models/qwen2_vl/vision_process.py

Lines changed: 142 additions & 64 deletions
@@ -3,23 +3,24 @@
 import torch
 import numpy as np
 from PIL import Image
+from collections import defaultdict
 from typing import List, Optional, Union, Tuple

-from transformers.image_processing_utils import BaseImageProcessor
-from transformers.image_transforms import (
-    convert_to_rgb,
-    resize,
-    to_channel_dimension_format,
+from transformers.image_processing_utils_fast import (
+    BaseImageProcessorFast,
+    group_images_by_shape,
+    reorder_images,
 )
+from torchvision.transforms import InterpolationMode
+
 from transformers.image_utils import (
     OPENAI_CLIP_MEAN,
     OPENAI_CLIP_STD,
     ChannelDimension,
     PILImageResampling,
-    get_image_size,
-    infer_channel_dimension_format,
-    to_numpy_array,
+    SizeDict,
 )
+from torchvision.transforms.v2 import functional as F

 IMAGE_FACTOR = 28
 MIN_PIXELS = 4 * 28 * 28
@@ -35,9 +36,9 @@ def smart_resize(
     height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
 ) -> tuple[int, int]:

-    if max(height, width) / min(height, width) > 200:
+    if max(height, width) / min(height, width) > MAX_RATIO:
         raise ValueError(
-            f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
+            f"absolute aspect ratio must be smaller than MAX_RATIO, got {max(height, width) / min(height, width)}"
         )
     h_bar = round(height / factor) * factor
     w_bar = round(width / factor) * factor
@@ -54,7 +55,7 @@ def smart_resize(

 def resize_image(
     image_file: Image.Image, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
-) -> tuple[Image.Image, int, int]:
+) -> tuple[Image.Image]:

     image = image_file.convert("RGB")
     width, height = image.size
@@ -71,7 +72,7 @@ def resize_image(
     return image


-class Qwen2VLImageProcessor(BaseImageProcessor):
+class Qwen2VLImageProcessor(BaseImageProcessorFast):
     def __init__(
         self,
         do_resize: bool = True,
@@ -87,6 +88,8 @@ def __init__(
         patch_size: int = 14,
         temporal_patch_size: int = 2,
         merge_size: int = 2,
+        disable_grouping: Optional[bool] = None,
+        interpolation: Optional["F.InterpolationMode"] = InterpolationMode.BICUBIC,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -103,63 +106,138 @@ def __init__(
         self.patch_size = patch_size
         self.temporal_patch_size = temporal_patch_size
         self.merge_size = merge_size
+        self.disable_grouping = disable_grouping
+        self.interpolation = interpolation
         self.data_format = ChannelDimension.FIRST
+        self._fused_cache = {}  # key: (do_norm, do_rescale, rescale_factor, device)
+
+    def _get_fused_mean_std(
+        self,
+        do_normalize: bool,
+        image_mean: Union[float, list[float]],
+        image_std: Union[float, list[float]],
+        do_rescale: bool,
+        rescale_factor: float,
+        device: Optional["torch.device"],
+    ) -> tuple[torch.Tensor, torch.Tensor, bool]:
+        key = (bool(do_normalize), bool(do_rescale), float(rescale_factor), str(device))
+        if key not in self._fused_cache:
+            if do_rescale and do_normalize:
+                mean = torch.tensor(image_mean) * (1.0 / rescale_factor)
+                std = torch.tensor(image_std) * (1.0 / rescale_factor)
+                do_rescale = False
+            else:
+                mean = torch.tensor(image_mean)
+                std = torch.tensor(image_std)
+            self._fused_cache[key] = (mean.to(device=device), std.to(device=device), do_rescale)
+        return self._fused_cache[key]
+
+    def rescale_and_normalize(
+        self,
+        images: "torch.Tensor",
+        do_rescale: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Union[float, list[float]],
+        image_std: Union[float, list[float]],
+    ) -> "torch.Tensor":
+        """
+        Rescale and normalize images.
+        """
+        image_mean, image_std, do_rescale = self._get_fused_mean_std(
+            do_normalize=do_normalize,
+            image_mean=image_mean,
+            image_std=image_std,
+            do_rescale=do_rescale,
+            rescale_factor=rescale_factor,
+            device=images.device,
+        )
+        # if/elif as we use fused rescale and normalize if both are set to True
+        if do_normalize:
+            images = self.normalize(images.to(dtype=torch.float32), image_mean, image_std)
+        elif do_rescale:
+            images = self.rescale(images, rescale_factor)
+
+        return images

     def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
-        if self.do_convert_rgb:
-            image = convert_to_rgb(image)
-        image = to_numpy_array(image)
-        input_data_format = infer_channel_dimension_format(image)
-        height, width = get_image_size(image, channel_dim=input_data_format)
-
-        resized_height, resized_width = height, width
-        if self.do_resize:
-            resized_height, resized_width = smart_resize(
-                height,
-                width,
-                factor=self.patch_size * self.merge_size,
-                min_pixels=self.min_pixels,
-                max_pixels=self.max_pixels,
+        image_arr = np.asarray(image, dtype=np.uint8)
+        image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to("cuda", non_blocking=True)
+        grouped_images, grouped_images_index = group_images_by_shape(
+            [image_data], disable_grouping=self.disable_grouping
+        )
+        resized_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            height, width = stacked_images.shape[-2:]
+            if self.do_resize:
+                resized_height, resized_width = smart_resize(
+                    height,
+                    width,
+                    factor=self.patch_size * self.merge_size,
+                    min_pixels=self.min_pixels,
+                    max_pixels=self.max_pixels,
+                )
+                stacked_images = self.resize(
+                    image=stacked_images,
+                    size=SizeDict(height=resized_height, width=resized_width),
+                    interpolation=self.interpolation,
+                )
+            resized_images_grouped[shape] = stacked_images
+        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+        # Group images by size for further processing
+        # Needed in case do_resize is False, or resize returns images with different sizes
+        grouped_images, grouped_images_index = group_images_by_shape(
+            resized_images, disable_grouping=self.disable_grouping
+        )
+        processed_images_grouped = {}
+        processed_grids = {}
+        for shape, stacked_images in grouped_images.items():
+            resized_height, resized_width = stacked_images.shape[-2:]
+            # Fused rescale and normalize
+            patches = self.rescale_and_normalize(
+                stacked_images, self.do_rescale, self.rescale_factor, self.do_normalize, self.image_mean, self.image_std
             )
-            image = resize(
-                image, size=(resized_height, resized_width), resample=self.resample, input_data_format=input_data_format
+            if patches.ndim == 4:
+                # add a temporal dimension if we have images
+                patches = patches.unsqueeze(1)
+            if patches.shape[1] % self.temporal_patch_size != 0:
+                repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1)
+                patches = torch.cat([patches, repeats], dim=1)
+            batch_size, grid_t, channel = patches.shape[:3]
+            grid_t = grid_t // self.temporal_patch_size
+            grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+
+            patches = (
+                patches.view(
+                    batch_size,
+                    grid_t,
+                    self.temporal_patch_size,
+                    channel,
+                    grid_h // self.merge_size,
+                    self.merge_size,
+                    self.patch_size,
+                    grid_w // self.merge_size,
+                    self.merge_size,
+                    self.patch_size,
+                )
+                .permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
+                .contiguous()
             )
-
-        if self.do_rescale:
-            image = self.rescale(image, scale=self.rescale_factor, input_data_format=input_data_format)
-
-        if self.do_normalize:
-            image = self.normalize(
-                image=image, mean=self.image_mean, std=self.image_std, input_data_format=input_data_format
+            # Reorder dimensions to group grid and patch information for subsequent flattening.
+            # (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w)
+            flatten_patches = patches.view(
+                batch_size,
+                grid_t * grid_h * grid_w,
+                channel * self.temporal_patch_size * self.patch_size * self.patch_size,
            )

-        image = to_channel_dimension_format(image, self.data_format, input_channel_dim=input_data_format)
-
-        patches = np.array([image])
-
-        if patches.shape[0] == 1:
-            # why to copy image 2 times. use self.temporal_patch_size = 2.
-            patches = np.tile(patches, (self.temporal_patch_size, 1, 1, 1))
-        channel = patches.shape[1]
-        grid_t = patches.shape[0] // self.temporal_patch_size
-        grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
-        patches = patches.reshape(
-            grid_t,
-            self.temporal_patch_size,
-            channel,
-            grid_h // self.merge_size,
-            self.merge_size,
-            self.patch_size,
-            grid_w // self.merge_size,
-            self.merge_size,
-            self.patch_size,
-        )
-        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
-        flatten_patches = patches.reshape(
-            grid_t * grid_h * grid_w, channel * self.temporal_patch_size * self.patch_size * self.patch_size
-        )
-        image_grid_thw = (grid_t, grid_h, grid_w)
-        pixel_values = torch.as_tensor(flatten_patches)
-        grid_thw = torch.as_tensor([image_grid_thw])
+            processed_images_grouped[shape] = flatten_patches
+            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        processed_grids = reorder_images(processed_grids, grouped_images_index)
+        pixel_values = torch.cat(processed_images, dim=0)  # (num_patches_total, C*T*ps*ps)
+        image_grid_thw = torch.as_tensor(processed_grids)

-        return pixel_values, grid_thw
+        return pixel_values, image_grid_thw

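The rewrite above ports preprocessing to BaseImageProcessorFast-style batched tensor ops (torchvision resize, shape-grouped batching) and fuses rescale with normalize: _get_fused_mean_std scales mean and std by 1 / rescale_factor and caches the result per device, so a single normalize over the raw pixel values replaces the former rescale-then-normalize pair. A small self-contained check of that identity (a sketch; the mean/std values are rough approximations of OPENAI_CLIP_MEAN / OPENAI_CLIP_STD, not copied from the repo):

import torch

s = 1.0 / 255.0                               # typical rescale_factor
mean = torch.tensor([0.481, 0.458, 0.408])    # approx. OPENAI_CLIP_MEAN
std = torch.tensor([0.269, 0.261, 0.276])     # approx. OPENAI_CLIP_STD

x = torch.randint(0, 256, (3, 4, 4)).float()  # stand-in CHW image with raw 0..255 values

# Two passes: rescale to [0, 1], then normalize.
two_pass = (x * s - mean[:, None, None]) / std[:, None, None]
# Fused: normalize the raw values with mean/s and std/s (what _get_fused_mean_std caches).
fused = (x - (mean / s)[:, None, None]) / (std / s)[:, None, None]

assert torch.allclose(two_pass, fused, atol=1e-5)
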
lightllm/server/httpserver/manager.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 import asyncio
 import uvloop
 import rpyc
+import socket
 import time
 import copy
 import hashlib
@@ -79,6 +80,7 @@ def __init__(
         self.enable_multimodal = args.enable_multimodal
         if self.enable_multimodal:
             self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True})
+            self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
             self.send_to_visual = context.socket(zmq.PUSH)
             self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{args.visual_port}")
         if args.enable_cpu_cache and not self.args.enable_multimodal:

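The added setsockopt line (repeated for the rpyc clients in visualserver/manager.py and model_rpc.py below) disables Nagle's algorithm on the cache-client connection, so small cache RPCs are sent immediately instead of being coalesced with later writes. The diff reaches the raw socket through rpyc's private _channel.stream.sock attribute chain; a self-contained sketch of the option itself on a plain TCP socket:

import socket

# Throwaway local TCP listener so the snippet runs on its own.
server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server.bind(("127.0.0.1", 0))
server.listen(1)

client = socket.create_connection(server.getsockname())
client.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)       # same call as in the diff
print(client.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY))   # non-zero: Nagle disabled

client.close()
server.close()
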
lightllm/server/visualserver/manager.py

Lines changed: 2 additions & 0 deletions
@@ -3,6 +3,7 @@
 import asyncio
 import uvloop
 import rpyc
+import socket
 import pickle
 import inspect
 import setproctitle
@@ -45,6 +46,7 @@ def __init__(
         self.zmq_recv_socket = context.socket(zmq.PULL)
         self.zmq_recv_socket.bind(f"{args.zmq_mode}127.0.0.1:{args.visual_port}")
         self.cache_client = rpyc.connect("localhost", args.cache_port, config={"allow_pickle": True})
+        self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         self.cache_port = args.cache_port
         self.waiting_reqs: List[GroupReqIndexes] = []
         self.model_weightdir = args.model_dir

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,7 @@
 import numpy as np
 import rpyc
 import torch
+import socket
 import inspect
 from datetime import timedelta
 from typing import Dict, List, Tuple
@@ -39,6 +40,7 @@ def exposed_init_model(self, kvargs):
         weight_dir = kvargs["weight_dir"]
         self.vit_rank_id = kvargs["vit_rank_id"]
         self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
+        self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
         self.data_type = kvargs["data_type"]

         init_vision_distributed_env(kvargs)
@@ -161,6 +163,8 @@ def _init_env(port, device_id):
     # register the graceful-exit handler
     graceful_registry(inspect.currentframe().f_code.co_name)

+    import lightllm.utils.rpyc_fix_utils as _
+
     t = ThreadedServer(VisualModelRpcServer(), port=port, protocol_config={"allow_pickle": True})
     t.start()
     return
@@ -182,6 +186,7 @@ async def start_model_process(port, vit_tp, device_id):
     while repeat_count < 20:
         try:
             con = rpyc.connect("localhost", port, config={"allow_pickle": True})
+            con._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
             break
         except BaseException:
             await asyncio.sleep(1)

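Taken together, the commit moves Qwen2-VL image preprocessing onto the GPU (torchvision resize, fused rescale/normalize, batched patch reshaping) and trims RPC latency with TCP_NODELAY. A hedged end-to-end sketch of the new preprocess path, assuming a CUDA device is present (preprocess moves the tensor to "cuda") and that the processor's default constructor arguments, not all shown in the diff, are sufficient; "example.jpg" is a placeholder image:

from PIL import Image
from lightllm.models.qwen2_vl.vision_process import Qwen2VLImageProcessor

processor = Qwen2VLImageProcessor()                  # assumes the defaults are enough
image = Image.open("example.jpg").convert("RGB")     # placeholder RGB image

pixel_values, image_grid_thw = processor.preprocess(image)
print(pixel_values.shape)     # flattened patches, channel * temporal_patch_size * patch_size**2 features each
print(image_grid_thw)         # per-image (grid_t, grid_h, grid_w)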