
Commit 4853561

[fix]0915-fix-rpyc-cost
1 parent 6b95156 commit 4853561

5 files changed: +52 additions, -28 deletions


lightllm/models/vit/model.py

Lines changed: 2 additions & 4 deletions
@@ -1,4 +1,5 @@
 import os
+import time
 import json
 import torch
 from lightllm.models.vit.layer_infer.pre_layer_infer import ViTPreLayerInfer
@@ -179,10 +180,7 @@ def encode(self, images: List[ImageItem]):
         for i, img in enumerate(images):
             if isinstance(img, ImageItem):
                 uuids.append(img.uuid)
-                if self.remote_vit:
-                    image_data = img._preload_data
-                else:
-                    image_data = read_shm(get_shm_name_data(img.uuid))
+                image_data = read_shm(get_shm_name_data(img.uuid))
                 image_data = Image.open(BytesIO(image_data))
                 t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"])
                 img_tensors.append(t)
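The encode path now always pulls the raw image bytes out of shared memory via read_shm(get_shm_name_data(...)) instead of branching on remote_vit. As a rough illustration of that kind of named-shared-memory handoff, here is a standard-library sketch; write_bytes_to_shm and read_bytes_from_shm are hypothetical names, and the real read_shm/create_shm helpers in lightllm.server.embed_cache.utils may behave differently:

# Hedged sketch of a named-shared-memory handoff between processes,
# not the actual lightllm implementation.
from multiprocessing import shared_memory


def write_bytes_to_shm(name: str, data: bytes) -> None:
    # Producer: create a named segment sized to the payload and copy it in.
    shm = shared_memory.SharedMemory(name=name, create=True, size=len(data))
    shm.buf[: len(data)] = data
    shm.close()


def read_bytes_from_shm(name: str, size: int) -> bytes:
    # Consumer: attach by name, copy the payload out, then detach.
    shm = shared_memory.SharedMemory(name=name)
    data = bytes(shm.buf[:size])
    shm.close()
    return data


if __name__ == "__main__":
    payload = b"raw image bytes"
    write_bytes_to_shm("img-123", payload)
    print(read_bytes_from_shm("img-123", len(payload)))
    cleanup = shared_memory.SharedMemory(name="img-123")
    cleanup.unlink()  # the owning process eventually frees the segment
    cleanup.close()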

lightllm/server/embed_cache/impl/naive_memory_cache.py

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ def _clear(self, free_max_count: int):
             # if self.args.run_mode == "visual":
             #     free_afs(get_shm_name_embed(id), self.args.image_embed_dir)
             # elif not self.args.enable_remote_vit:
-            if not self.args.run_mode == "visual":
+            if not self.args.enable_remote_vit and self.args.run_mode != "visual":
                 free_shm(get_shm_name_embed(id))
             del self._md5_to_record[record.md5sum]
             del self._records[id]

lightllm/server/httpserver/manager.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 import asyncio
 import uvloop
 import rpyc
+import socket
 import time
 import copy
 import hashlib
@@ -84,6 +85,7 @@ def __init__(
         self.enable_multimodal = enable_multimodal
         if self.enable_multimodal:
             self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
+            self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
             # Initialize the VIT connection manager
             from lightllm.server.visualserver.vit_connect import VITConnectionManager
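The recurring theme of this commit is disabling Nagle's algorithm on every rpyc client connection, so small RPC frames are flushed immediately instead of being coalesced. A minimal sketch of the client-side pattern, reaching through the same conn._channel.stream.sock internals the diff touches (connect_low_latency is a hypothetical helper, not part of lightllm):

import socket

import rpyc


def connect_low_latency(host, port):
    # Hedged sketch: set TCP_NODELAY on the raw socket behind an rpyc
    # connection, trading a few more packets for lower per-call latency.
    conn = rpyc.connect(host, port, config={"allow_pickle": True})
    sock = getattr(getattr(conn._channel, "stream", None), "sock", None)
    if sock is not None:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    return conn

Because _channel.stream.sock is an rpyc internal rather than public API, the guarded getattr lookup keeps the helper from crashing if a future rpyc release reshuffles those attributes.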

lightllm/server/visualserver/manager.py

Lines changed: 23 additions & 15 deletions
@@ -4,13 +4,15 @@
 import asyncio
 import uvloop
 import rpyc
+import socket
 import pickle
 import hashlib
 import datetime
 import inspect
 from fastapi import Request
 from ..tokenizer import get_tokenizer
 from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes
+from lightllm.server.embed_cache.utils import get_shm_name_data, create_shm
 from lightllm.server.core.objs import ShmReqManager
 from lightllm.server.core.objs import SamplingParams
 from lightllm.server.core.objs import Req, FinishStatus
@@ -63,6 +65,7 @@ def _setup_connections(self):
         self.send_to_next_module = context.socket(zmq.PUSH)  # router or audio server (if --enable_multimodal_audio)
         self.send_to_next_module.connect(f"{self.args.zmq_mode}127.0.0.1:{self.next_module_port}")
         self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
+        self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

     async def wait_to_model_ready(self):
         visual_dp = self.args.visual_dp
@@ -100,7 +103,6 @@ async def infer_imgs(self, images: List[ImageItem]):
             for vit_tp_rank in range(self.args.visual_tp):
                 task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_images))
                 tasks.append(task)
-
         await asyncio.gather(*tasks)
         return

@@ -162,19 +164,34 @@ async def _recv_reqs(self):
             # ]
             logger.info(f"Receive req {recv_req.group_req_id}, image_count:{len(recv_req.multimodal_params.images)}")
             uuids = [img.uuid for img in recv_req.multimodal_params.images]
-            already_embed = self.cache_client.root.get_items_embed(uuids)
+            already_embed = await asyncio.to_thread(self.cache_client.root.get_items_embed, uuids)
             if all(already_embed):
                 return None
+
+            uuids = []
             token_nums = []
+            datas = []
             for img, embed in zip(recv_req.multimodal_params.images, already_embed):
                 if not embed:
                     uuids.append(img.uuid)
                     token_nums.append(img.token_num)
+                    datas.append(img._preload_data)
+                    img.free()
             while True:
-                records = self.cache_client.root.alloc(uuids, token_nums)
+                records = await asyncio.to_thread(self.cache_client.root.alloc, uuids, token_nums)
                 if records is not None:
                     break
-                await asyncio.sleep(0.1)
+                await asyncio.sleep(0.01)
+            ready_flags = obtain(self.cache_client.root.get_items_data(uuids))
+            update_data_ids = []
+
+            for uid, ready, data in zip(uuids, ready_flags, datas):
+                if not ready:
+                    create_shm(get_shm_name_data(uid), data)
+                    update_data_ids.append(uid)
+
+            if update_data_ids:
+                await asyncio.to_thread(self.cache_client.root.set_items_data, update_data_ids)
             return recv_req
         else:
             return self.vit_receiver.recv_pyobj(zmq.NOBLOCK)
@@ -193,7 +210,8 @@ async def loop_for_netio_req(self):
                     self.waiting_reqs.append(recv_req)
                 else:
                     assert False, f"Error Req Inf {recv_req}"
-                self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256)
+                await asyncio.sleep(0)
+                self.visual_recv_max_count = min(int(self.visual_recv_max_count * 1.3), 256)
             except zmq.ZMQError:
                 # When the queue has started to drain, scale down how much is received per pass
                 self.visual_recv_max_count = 64
@@ -217,21 +235,11 @@ async def loop_for_fwd_visual_only(self):
                     images_need_infer.append(img)

                     if len(images_need_infer) == self.infer_batch_size:
-                        _t0 = time.perf_counter()
                         await self.infer_imgs(images_need_infer)
-                        logger.info(
-                            f"[visual] batch infer complete, image_count: {len(images_need_infer)}, "
-                            f"elapsed_time {(time.perf_counter()-_t0) * 1000}ms"
-                        )
                         images_need_infer = []

                 if len(images_need_infer) > 0:
-                    _t1 = time.perf_counter()
                     await self.infer_imgs(images_need_infer)
-                    logger.info(
-                        f"[visual] batch infer complete, image_count:{len(images_need_infer)}, "
-                        f"elapsed_time {(time.perf_counter()-_t1) * 1000}ms"
-                    )
                     images_need_infer = []
                 # Release the image here (ref count - 1)
                 logger.info(f"req-id {visual_req.group_req_id} has been release ok")

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 24 additions & 8 deletions
@@ -2,6 +2,7 @@
 import numpy as np
 import rpyc
 import torch
+import time
 import inspect
 from datetime import timedelta
 from typing import Dict, List, Tuple
@@ -30,6 +31,11 @@
 from lightllm.utils.dist_utils import init_vision_distributed_env
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.utils.log_utils import init_logger
+import pickle
+import socket
+
+logger = init_logger(__name__)


 class VisualModelRpcServer(rpyc.Service):
@@ -48,10 +54,12 @@ def exposed_init_model(self, kvargs):
         max_batch_size = min(self.args.visual_infer_batch_size // self.args.visual_dp, 1)
         remote_vit = True if self.args.run_mode == "visual" else False

+        self.image_embed_dir = self.args.image_embed_dir
         self.dp_rank_id = kvargs["dp_rank_id"]
         self.tp_rank_id = kvargs["tp_rank_id"]
         kvargs["vit_rank_id"] = self.dp_rank_id * self.args.visual_tp + self.tp_rank_id
         self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
+        self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)

         init_vision_distributed_env(kvargs)
         model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
@@ -109,19 +117,18 @@ def forward(self, images: List[ImageItem]):
     def exposed_encode(self, images: List[ImageItem]):
         images = obtain(images)
         all_img_embeds, uuids, valid_ids = self.forward(images)
-        all_img_embeds = all_img_embeds.to(torch.device("cpu"))
-
+        all_img_embeds = all_img_embeds.to(torch.device("cpu"), non_blocking=True)
         if self.tp_rank_id == 0:
             # ready_flags = obtain(self.cache_client.root.get_items_embed(uuids))
             ids_to_set = []
-            for i, img in enumerate(images):
+            for i in range(len(images)):
                 # if ready:
                 #     continue
                 uid = uuids[i]
                 start, end = valid_ids[i]
                 cur_embed_bytes = tensor2bytes(all_img_embeds[start:end])
                 if self.args.run_mode == "visual":
-                    create_afs(get_shm_name_embed(uid), cur_embed_bytes, self.args.image_embed_dir)
+                    create_afs(get_shm_name_embed(uid), cur_embed_bytes, self.image_embed_dir)
                 else:
                     create_shm(get_shm_name_embed(uid), cur_embed_bytes)
                     ids_to_set.append(uid)
@@ -131,11 +138,13 @@ def exposed_encode(self, images: List[ImageItem]):


 class VisualModelRpcClient:
-    def __init__(self, model_rpc, vit_tp, rpc_server_process=None):
-        self.model: VisualModelRpcServer = model_rpc
+    def __init__(self, conn, vit_tp, rpc_server_process=None):
+        self.conn = conn
+        self.model: VisualModelRpcServer = conn.root
         self.vit_tp = vit_tp
         self.rpc_server_process = rpc_server_process
         self.use_rpc = True
+        self._bg = rpyc.BgServingThread(self.conn)
         if self.use_rpc:

             def async_wrap(f):
@@ -176,7 +185,13 @@ def _init_env(port, device_id):
     # Register graceful-shutdown handling
     graceful_registry(inspect.currentframe().f_code.co_name)

-    t = ThreadedServer(VisualModelRpcServer(), port=port, protocol_config={"allow_pickle": True})
+    auth = lambda sock: (sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) or (sock, None))
+    t = ThreadedServer(
+        VisualModelRpcServer(),
+        port=port,
+        protocol_config={"allow_pickle": True},
+        authenticator=auth,
+    )
     t.start()
     return

@@ -197,6 +212,7 @@ async def start_model_process(port, vit_tp, device_id):
     while repeat_count < 20:
         try:
             con = rpyc.connect("localhost", port, config={"allow_pickle": True})
+            con._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
             break
         except BaseException:
             await asyncio.sleep(1)
@@ -205,4 +221,4 @@ async def start_model_process(port, vit_tp, device_id):
         raise Exception("init rpc env error!")

     assert proc.is_alive()
-    return VisualModelRpcClient(con.root, vit_tp, rpc_server_process=proc)
+    return VisualModelRpcClient(con, vit_tp, rpc_server_process=proc)
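rpyc's ThreadedServer has no per-connection hook for socket options, so the commit repurposes the authenticator callback, which receives each accepted socket before the protocol handshake, to set TCP_NODELAY on the server side as well. A minimal standalone sketch of that trick (EchoService and port 18861 are placeholders; an authenticator must return a (socket, credentials) tuple, and raising AuthenticationError would reject the client):

import socket

import rpyc
from rpyc.utils.server import ThreadedServer


class EchoService(rpyc.Service):
    # Hypothetical stand-in service; the commit wires VisualModelRpcServer here.
    def exposed_echo(self, x):
        return x


def nodelay_authenticator(sock):
    # Not real authentication: just a per-connection hook that lets us flip
    # TCP_NODELAY on the accepted socket before rpyc takes over.
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    return sock, None


if __name__ == "__main__":
    server = ThreadedServer(
        EchoService(),
        port=18861,
        protocol_config={"allow_pickle": True},
        authenticator=nodelay_authenticator,
    )
    server.start()

Separately, VisualModelRpcClient now keeps the raw connection (self.conn) and starts a rpyc.BgServingThread on it, so incoming traffic on that connection is serviced by a background thread rather than only while a request is in flight.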
