
Commit 6b95156

0911-add-other-multimodal's vit dispatch
1 parent 0a296a1 commit 6b95156

File tree: 10 files changed (+46, -15 lines)


lightllm/models/gemma3/gemma3_visual.py

Lines changed: 6 additions & 2 deletions
@@ -16,7 +16,8 @@
 
 
 class Gemma3VisionModel:
-    def __init__(self):
+    def __init__(self, kvargs):
+        self.remote_vit = kvargs.get("remote_vit", False)
         pass
 
     def load_model(self, weight_dir):

@@ -122,7 +123,10 @@ def encode(self, images: List[ImageItem]):
         for i, img in enumerate(images):
             if isinstance(img, ImageItem):
                 uuids.append(img.uuid)
-                image_data = read_shm(get_shm_name_data(img.uuid))
+                if self.remote_vit:
+                    image_data = img._preload_data
+                else:
+                    image_data = read_shm(get_shm_name_data(img.uuid))
                 image_data = Image.open(BytesIO(image_data))
                 t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
                 img_tensors.append(t)
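
Every encode() touched by this commit repeats this same two-way branch: a remote VIT instance receives the raw image bytes preloaded on the ImageItem, while a co-located instance still reads them from shared memory by uuid. A hypothetical helper that would centralize the dispatch (read_shm and get_shm_name_data are the shm utilities these files already import; _load_image_bytes and the import path are invented for this sketch):

    # Sketch only; the import path is an assumption, use whatever module the
    # encode() files above already import read_shm/get_shm_name_data from.
    from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data

    def _load_image_bytes(img, remote_vit: bool) -> bytes:
        # Remote VIT instances ship the raw bytes with the request, so there
        # is no shared-memory segment to read; local instances still fetch
        # the bytes from shm keyed by the image uuid.
        if remote_vit:
            return img._preload_data
        return read_shm(get_shm_name_data(img.uuid))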

lightllm/models/llava/llava_visual.py

Lines changed: 6 additions & 2 deletions
@@ -15,7 +15,8 @@
 
 
 class LlavaVisionModel:
-    def __init__(self):
+    def __init__(self, kvargs):
+        self.remote_vit = kvargs.get("remote_vit", False)
         pass
 
     def load_model(self, weight_dir):

@@ -133,7 +134,10 @@ def encode(self, images: List[ImageItem]):
         for i, img in enumerate(images):
             if isinstance(img, ImageItem):
                 uuids.append(img.uuid)
-                image_data = read_shm(get_shm_name_data(img.uuid))
+                if self.remote_vit:
+                    image_data = img._preload_data
+                else:
+                    image_data = read_shm(get_shm_name_data(img.uuid))
                 image_data = Image.open(BytesIO(image_data)).convert("RGB")
                 t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
                 img_tensors.append(t)

lightllm/models/qwen2_5_vl/qwen2_5_visual.py

Lines changed: 5 additions & 1 deletion
@@ -171,6 +171,7 @@ def __init__(
         self.window_size = window_size
         self.fullatt_block_indexes = fullatt_block_indexes
         self.out_hidden_size = out_hidden_size
+        self.remote_vit = kvargs.get("remote_vit", False)
 
         self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
 

@@ -381,7 +382,10 @@ def encode(self, images: List[ImageItem]):
         for i, img in enumerate(images):
             if isinstance(img, ImageItem):
                 uuids.append(img.uuid)
-                image_data = read_shm(get_shm_name_data(img.uuid))
+                if self.remote_vit:
+                    image_data = img._preload_data
+                else:
+                    image_data = read_shm(get_shm_name_data(img.uuid))
                 image_data = Image.open(BytesIO(image_data))
                 image_data = resize_image(image_data)
                 pixel_values, image_grid_thw = self.processor.preprocess(image_data)

lightllm/models/qwen2_vl/qwen2_visual.py

Lines changed: 5 additions & 1 deletion
@@ -200,6 +200,7 @@ def __init__(
         self.patch_size = patch_size
         self.spatial_merge_size = spatial_merge_size
         self.temporal_patch_size = temporal_patch_size
+        self.remote_vit = kvargs.get("remote_vit", False)
 
         self.patch_embed = PatchEmbed(
             patch_size=self.patch_size,

@@ -309,7 +310,10 @@ def encode(self, images: List[ImageItem]):
         for i, img in enumerate(images):
             if isinstance(img, ImageItem):
                 uuids.append(img.uuid)
-                image_data = read_shm(get_shm_name_data(img.uuid))
+                if self.remote_vit:
+                    image_data = img._preload_data
+                else:
+                    image_data = read_shm(get_shm_name_data(img.uuid))
                 image_data = Image.open(BytesIO(image_data))
                 image_data = resize_image(image_data)
                 pixel_values, image_grid_thw = self.processor.preprocess(image_data)

lightllm/models/qwen_vl/qwen_visual.py

Lines changed: 6 additions & 1 deletion
@@ -333,6 +333,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
 class QWenVisionTransformer(nn.Module):
     def __init__(
         self,
+        kvargs,
         image_size: int,
         patch_size: int,
         width: int,

@@ -344,6 +345,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__()
+        self.remote_vit = kvargs.get("remote_vit", False)
         image_height, image_width = self.image_size = (image_size, image_size)
         patch_height, patch_width = self.patch_size = (patch_size, patch_size)
         self.grid_size = (image_height // patch_height, image_width // patch_width)

@@ -422,7 +424,10 @@ def encode(self, image_uuids: List):
         for i, item in enumerate(image_uuids):
             if isinstance(item, int):
                 uuids.append(item)
-                image_data = read_shm(get_shm_name_data(item))
+                if self.remote_vit:
+                    image_data = item._preload_data
+                else:
+                    image_data = read_shm(get_shm_name_data(item.uuid))
                 image_data = Image.open(BytesIO(image_data)).convert("RGB")
                 t = self.image_transform(image_data)
                 img_tensors.append(t)
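
One wrinkle in this hunk: the loop guard still checks isinstance(item, int), yet both new branches dereference item._preload_data or item.uuid, attributes a bare int does not have, so the guard and the body now disagree about the element type. A defensive variant (hypothetical, not what the commit does) that accepts either shape:

    def _qwen_image_bytes(item, remote_vit: bool) -> bytes:
        # Accept either a bare uuid (int) or an ImageItem-like object; only
        # an object carrying _preload_data can serve the remote path
        # (assumption based on the other encode() methods in this commit).
        if remote_vit:
            return item._preload_data
        uuid = item if isinstance(item, int) else item.uuid
        return read_shm(get_shm_name_data(uuid))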

lightllm/models/tarsier2/tarsier2_visual.py

Lines changed: 6 additions & 1 deletion
@@ -152,6 +152,7 @@ def forward(self, image_features, input_embeddings):
 class TarsierVisionTransformerPretrainedModel(nn.Module):
     def __init__(
         self,
+        kvargs,
         vision_config=None,
         text_config=None,
         ignore_index=-100,

@@ -165,6 +166,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__()
+        self.remote_vit = kvargs.get("remote_vit", False)
         self.vision_tower = Qwen2VisionTransformerPretrainedModel(**vision_config)
 
         if projection_head == "Pixel_Shuffle":

@@ -251,7 +253,10 @@ def encode(self, images: List[ImageItem]):
         for i, img in enumerate(images):
             if isinstance(img, ImageItem):
                 uuids.append(img.uuid)
-                image_data = read_shm(get_shm_name_data(img.uuid))
+                if self.remote_vit:
+                    image_data = img._preload_data
+                else:
+                    image_data = read_shm(get_shm_name_data(img.uuid))
                 image_data = Image.open(BytesIO(image_data))
                 image_data = resize_image(image_data)
                 pixel_values, image_grid_thw = self.processor.preprocess(image=image_data)

lightllm/models/vit/model.py

Lines changed: 5 additions & 2 deletions
@@ -47,6 +47,7 @@ def __init__(self, kvargs):
         self.quant_cfg_path = kvargs.get("quant_cfg", None)
         self.load_image_func = get_load_image_func(self.weight_dir_)
         self.max_batch_size = kvargs.get("max_batch_size", 1)
+        self.remote_vit = kvargs.get("remote_vit", False)
 
         self._init_datatype()
         self._init_config()

@@ -178,8 +179,10 @@ def encode(self, images: List[ImageItem]):
         for i, img in enumerate(images):
             if isinstance(img, ImageItem):
                 uuids.append(img.uuid)
-                image_data = img._preload_data
-                # image_data = read_shm(get_shm_name_data(img.uuid))
+                if self.remote_vit:
+                    image_data = img._preload_data
+                else:
+                    image_data = read_shm(get_shm_name_data(img.uuid))
                 image_data = Image.open(BytesIO(image_data))
                 t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"])
                 img_tensors.append(t)

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 6 additions & 4 deletions
@@ -46,6 +46,7 @@ def exposed_init_model(self, kvargs):
         quant_type = self.args.vit_quant_type
         quant_cfg = self.args.vit_quant_cfg
         max_batch_size = min(self.args.visual_infer_batch_size // self.args.visual_dp, 1)
+        remote_vit = True if self.args.run_mode == "visual" else False
 
         self.dp_rank_id = kvargs["dp_rank_id"]
         self.tp_rank_id = kvargs["tp_rank_id"]

@@ -62,10 +63,11 @@ def exposed_init_model(self, kvargs):
             "quant_type": quant_type,
             "quant_cfg": quant_cfg,
             "max_batch_size": max_batch_size,
+            "remote_vit": remote_vit,
         }
         self.model_type = model_cfg["model_type"]
         if self.model_type == "qwen":
-            self.model = QWenVisionTransformer(**model_cfg["visual"]).eval().bfloat16()
+            self.model = QWenVisionTransformer(kvargs, **model_cfg["visual"]).eval().bfloat16()
         elif self.model_type == "qwen2_vl":
             self.model = (
                 Qwen2VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16()

@@ -75,14 +77,14 @@ def exposed_init_model(self, kvargs):
                 Qwen2_5_VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16()
             )
         elif model_cfg["architectures"][0] == "TarsierForConditionalGeneration":
-            self.model = TarsierVisionTransformerPretrainedModel(**model_cfg).eval().bfloat16()
+            self.model = TarsierVisionTransformerPretrainedModel(kvargs, **model_cfg).eval().bfloat16()
         elif self.model_type == "llava":
-            self.model = LlavaVisionModel()
+            self.model = LlavaVisionModel(kvargs)
         elif self.model_type == "internvl_chat":
             self.model = VisionTransformer(kvargs)
             # self.model = InternVLVisionModel()
         elif self.model_type == "gemma3":
-            self.model = Gemma3VisionModel()
+            self.model = Gemma3VisionModel(kvargs)
         else:
             raise Exception(f"can not support {self.model_type} now")
         self.model.load_model(weight_dir)
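
The flag is derived once per visual worker: only a process launched with run_mode == "visual" (a dedicated remote VIT instance) sets remote_vit = True, and the value rides into every vision model through the shared kvargs dict. A minimal sketch of the wiring, assuming the constructor signatures above (the run_mode value and weight path are made up):

    run_mode = "visual"  # hypothetical CLI argument; anything else means local mode
    kvargs = {
        "max_batch_size": 1,
        "remote_vit": run_mode == "visual",  # same condition as the diff above
    }
    model = Gemma3VisionModel(kvargs)    # constructor stores kvargs["remote_vit"]
    model.load_model("/path/to/weights")
    # model.encode(images) now branches on self.remote_vit per image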

lightllm/server/visualserver/vit_connect.py

Lines changed: 0 additions & 1 deletion
@@ -50,7 +50,6 @@ def _setup_vit_connections(self):
         """
         if self.remote_vit:
             # remote VIT instance mode
-            print("remote")
             self._setup_remote_vit_connections()
         else:
             print("not remote")

lightllm/utils/shm_size_check.py

Lines changed: 1 addition & 0 deletions
@@ -117,6 +117,7 @@ def _get_recommended_shm_size_gb(args, max_image_resolution=(3940, 2160), dtype_
     )
     fake_image_item.image_w = fake_image_item._data[0]
     fake_image_item.image_h = fake_image_item._data[1]
+    fake_image_item.extra_params["image_patch_max_num"] = 12
     max_image_tokens = tokenizer.get_image_token_length(fake_image_item)
 
     # estimate the resources needed for image tokens
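
The added extra_params entry matters because patch-tiling tokenizers derive an image's token length from its patch budget, and the fake ImageItem used for this worst-case estimate previously lacked the key. A toy illustration of how such a budget drives the token estimate (the formula and the 256-tokens-per-patch figure are assumptions for illustration, not lightllm's actual get_image_token_length):

    def estimate_image_tokens(image_patch_max_num: int, tokens_per_patch: int = 256) -> int:
        # Assumed model: one thumbnail tile plus up to image_patch_max_num
        # detail tiles, each costing tokens_per_patch tokens.
        return (image_patch_max_num + 1) * tokens_per_patch

    print(estimate_image_tokens(12))  # worst case under these assumptions: 3328 tokens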
