
Commit 58f8c1d

add visual_send_bs args (#1109)

Authored by SangChengC
Co-authored-by: sangchengmeng <sangchengmeng@sensetime.com>
Co-authored-by: wangzaijun <wangzaijun@sensetime.com>

1 parent 28e0048 · commit 58f8c1d

File tree: 4 files changed, +70 −17 lines

  lightllm/models/qwen2_vl/vision_process.py
  lightllm/server/api_cli.py
  lightllm/server/core/objs/start_args_type.py
  lightllm/server/visualserver/manager.py

lightllm/models/qwen2_vl/vision_process.py

Lines changed: 38 additions & 9 deletions
@@ -22,6 +22,11 @@
 )
 from torchvision.transforms.v2 import functional as F
 
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
 IMAGE_FACTOR = 28
 MIN_PIXELS = 4 * 28 * 28
 MAX_PIXELS = 16384 * 28 * 28
@@ -160,9 +165,19 @@ def rescale_and_normalize(
 
         return images
 
+    @torch.inference_mode()
     def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
+        try:
+            return self._preprocess_bydevice(image, device="cuda")
+        except Exception as e:
+            logger.warning(f"Exception during image preprocessing on CUDA: {str(e)}")
+            torch.cuda.current_stream().synchronize()
+            return self._preprocess_bydevice(image, device="cpu")
+
+    def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torch.Tensor]:
         image_arr = np.asarray(image, dtype=np.uint8)
-        image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to("cuda", non_blocking=True)
+        image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True)
+
         grouped_images, grouped_images_index = group_images_by_shape(
             [image_data], disable_grouping=self.disable_grouping
         )
@@ -183,27 +198,39 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
                 interpolation=self.interpolation,
             )
             resized_images_grouped[shape] = stacked_images
+
+        grouped_images = None
         resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+        resized_images_grouped = None
 
-        # Group images by size for further processing
-        # Needed in case do_resize is False, or resize returns images with different sizes
         grouped_images, grouped_images_index = group_images_by_shape(
             resized_images, disable_grouping=self.disable_grouping
         )
+        resized_images = None
+
         processed_images_grouped = {}
         processed_grids = {}
+
         for shape, stacked_images in grouped_images.items():
+            stacked_images = stacked_images.to("cuda", non_blocking=True)
+
             resized_height, resized_width = stacked_images.shape[-2:]
-            # Fused rescale and normalize
+
             patches = self.rescale_and_normalize(
-                stacked_images, self.do_rescale, self.rescale_factor, self.do_normalize, self.image_mean, self.image_std
+                stacked_images,
+                self.do_rescale,
+                self.rescale_factor,
+                self.do_normalize,
+                self.image_mean,
+                self.image_std,
             )
             if patches.ndim == 4:
-                # add a temporal dimension if we have images
                 patches = patches.unsqueeze(1)
+
             if patches.shape[1] % self.temporal_patch_size != 0:
                 repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1)
                 patches = torch.cat([patches, repeats], dim=1)
+
             batch_size, grid_t, channel = patches.shape[:3]
             grid_t = grid_t // self.temporal_patch_size
             grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
@@ -224,8 +251,7 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
                 .permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
                 .contiguous()
             )
-            # Reorder dimensions to group grid and patch information for subsequent flattening.
-            # (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w)
+
             flatten_patches = patches.view(
                 batch_size,
                 grid_t * grid_h * grid_w,
@@ -235,9 +261,12 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]:
             processed_images_grouped[shape] = flatten_patches
             processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
 
+        grouped_images = None
+
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
         processed_grids = reorder_images(processed_grids, grouped_images_index)
-        pixel_values = torch.cat(processed_images, dim=0)  # (num_patches_total, C*T*ps*ps)
+
+        pixel_values = torch.cat(processed_images, dim=0)
         image_grid_thw = torch.as_tensor(processed_grids)
 
         return pixel_values, image_grid_thw
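The preprocess change above wraps the GPU path in a try/except: if CUDA preprocessing raises (for example an out-of-memory error), the exception is logged, the current CUDA stream is synchronized, and the same _preprocess_bydevice routine is retried on CPU. A minimal standalone sketch of that fallback pattern follows; only the torch calls mirror the diff, and the helper names are hypothetical, not lightllm's actual API.

    # Sketch of the CUDA-then-CPU fallback pattern (illustrative only).
    import numpy as np
    import torch


    def _to_tensor_on(image_arr: np.ndarray, device: str) -> torch.Tensor:
        # (H, W, C) uint8 array -> (C, H, W) tensor on the requested device
        return torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True)


    def preprocess_with_fallback(image_arr: np.ndarray) -> torch.Tensor:
        try:
            return _to_tensor_on(image_arr, device="cuda")
        except Exception as e:
            # On a CUDA failure (e.g. OOM): log, drain the stream, then retry on CPU
            print(f"Exception during image preprocessing on CUDA: {e}")
            if torch.cuda.is_available():
                torch.cuda.current_stream().synchronize()
            return _to_tensor_on(image_arr, device="cpu")


    if __name__ == "__main__":
        dummy = np.zeros((28, 28, 3), dtype=np.uint8)
        print(preprocess_with_fallback(dummy).shape)  # torch.Size([3, 28, 28])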

lightllm/server/api_cli.py

Lines changed: 9 additions & 0 deletions
@@ -368,6 +368,15 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--visual_infer_batch_size", type=int, default=1, help="number of images to process in each inference batch"
     )
+    parser.add_argument(
+        "--visual_send_batch_size",
+        type=int,
+        default=1,
+        help="""
+        number of image embeddings to send to the llm process in each batch;
+        a bigger batch can improve throughput but may increase latency in some cases
+        """,
+    )
     parser.add_argument(
         "--visual_gpu_ids", nargs="+", type=int, default=None, help="List of GPU IDs to use, e.g., 0 1 2"
     )
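The new --visual_send_batch_size flag sits next to --visual_infer_batch_size and defaults to 1, so existing launch commands keep their previous behavior. A small self-contained sketch, assuming a bare argparse.ArgumentParser rather than lightllm's full make_argument_parser, shows how the flag parses:

    # Hypothetical minimal parser mirroring the add_argument call in the diff.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--visual_infer_batch_size", type=int, default=1)
    parser.add_argument(
        "--visual_send_batch_size",
        type=int,
        default=1,
        help="number of image embeddings to send to the llm process in each batch",
    )

    args = parser.parse_args(["--visual_send_batch_size", "4"])
    assert args.visual_send_batch_size == 4
    assert args.visual_infer_batch_size == 1  # default unchanged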

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ class StartArgs:
     grouping_key: List[str] = field(default_factory=list)
     push_interval: int = field(default=10)
     visual_infer_batch_size: int = field(default=1)
+    visual_send_batch_size: int = field(default=1)
     visual_gpu_ids: List[int] = field(default_factory=lambda: [0])
     visual_tp: int = field(default=1)
     visual_dp: int = field(default=1)

lightllm/server/visualserver/manager.py

Lines changed: 22 additions & 8 deletions
@@ -57,6 +57,7 @@ def __init__(
         self.trust_remote_code = args.trust_remote_code
         self.args = args
         self.visual_model_rpc_ports = visual_model_rpc_ports
+        self.send_batch_size = args.visual_send_batch_size
         self.shm_req_manager = ShmReqManager()
 
     async def wait_to_model_ready(self):
@@ -117,6 +118,18 @@ async def loop_for_fwd(self):
             else:
                 processing_group_reqs = []
                 images_need_infer = []
+                ready_to_send = []
+
+                def flush_ready(force: bool = False):
+                    if not ready_to_send:
+                        return
+                    if not force and len(ready_to_send) < self.send_batch_size:
+                        return
+
+                    for group_req_indexes in ready_to_send:
+                        self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
+                    ready_to_send.clear()
+
                 while len(self.waiting_reqs) > 0:
                     group_req_indexes = self.waiting_reqs.pop(0)
                     shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0])
@@ -146,24 +159,25 @@ async def loop_for_fwd(self):
                         if len(images_need_infer) == self.infer_batch_size:
                             await self.infer_imgs(images_need_infer)
                             images_need_infer = []
-                            for _group_req_indexes in processing_group_reqs:
-                                self.send_to_next_module.send_pyobj(
-                                    _group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL
-                                )
+                            ready_to_send.extend(processing_group_reqs)
                             processing_group_reqs = []
+                            flush_ready(force=False)
 
                     if len(images_need_infer) == 0:
-                        self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
+                        ready_to_send.append(group_req_indexes)
+                        flush_ready(force=False)
                     else:
                         processing_group_reqs.append(group_req_indexes)
 
                 if len(images_need_infer) > 0:
                     await self.infer_imgs(images_need_infer)
-                    for _group_req_indexes in processing_group_reqs:
-                        self.send_to_next_module.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)
-                    processing_group_reqs = []
                     images_need_infer = []
 
+                # Groups whose image processing has finished are also ready to send
+                ready_to_send.extend(processing_group_reqs)
+                processing_group_reqs = []
+                flush_ready(force=True)
+
     async def loop_for_netio_req(self):
         if not hasattr(self, "visual_recv_max_count"):
             self.visual_recv_max_count = 64
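The core of the manager change is an accumulate-and-flush pattern: finished request groups collect in ready_to_send, and flush_ready only pushes them to the next module once at least self.send_batch_size of them are buffered, with a forced flush at the end of each forward pass. A standalone sketch of the same batching idea follows; a plain list stands in for the zmq send_to_next_module socket, and all names are hypothetical.

    # Sketch of the accumulate-and-flush batching used by flush_ready (illustrative only).
    from typing import List


    def make_flusher(sent: List[int], send_batch_size: int):
        ready_to_send: List[int] = []

        def add(item: int) -> None:
            # Buffer a finished group, then try a non-forced flush.
            ready_to_send.append(item)
            flush(force=False)

        def flush(force: bool = False) -> None:
            if not ready_to_send:
                return
            # Only send once a full batch has accumulated, unless forced.
            if not force and len(ready_to_send) < send_batch_size:
                return
            sent.extend(ready_to_send)
            ready_to_send.clear()

        return add, flush


    sent: List[int] = []
    add, flush = make_flusher(sent, send_batch_size=3)
    for i in range(5):
        add(i)
    print(sent)        # [0, 1, 2]  -- first full batch flushed automatically
    flush(force=True)  # end of the pass: push whatever is left
    print(sent)        # [0, 1, 2, 3, 4]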
