22import numpy as np
33import rpyc
44import torch
5+ import time
56import inspect
67from datetime import timedelta
78from typing import Dict , List , Tuple
3031from lightllm .utils .dist_utils import init_vision_distributed_env
3132from lightllm .utils .graceful_utils import graceful_registry
3233from lightllm .utils .envs_utils import get_env_start_args
34+ from lightllm .utils .log_utils import init_logger
35+ import pickle
36+ import socket
37+
38+ logger = init_logger (__name__ )
3339
3440
3541class VisualModelRpcServer (rpyc .Service ):
@@ -48,10 +54,12 @@ def exposed_init_model(self, kvargs):
4854 max_batch_size = min (self .args .visual_infer_batch_size // self .args .visual_dp , 1 )
4955 remote_vit = True if self .args .run_mode == "visual" else False
5056
57+ self .image_embed_dir = self .args .image_embed_dir
5158 self .dp_rank_id = kvargs ["dp_rank_id" ]
5259 self .tp_rank_id = kvargs ["tp_rank_id" ]
5360 kvargs ["vit_rank_id" ] = self .dp_rank_id * self .args .visual_tp + self .tp_rank_id
5461 self .cache_client = rpyc .connect ("localhost" , cache_port , config = {"allow_pickle" : True })
62+ self .cache_client ._channel .stream .sock .setsockopt (socket .IPPROTO_TCP , socket .TCP_NODELAY , 1 )
5563
5664 init_vision_distributed_env (kvargs )
5765 model_cfg , _ = PretrainedConfig .get_config_dict (weight_dir )
@@ -109,19 +117,18 @@ def forward(self, images: List[ImageItem]):
109117 def exposed_encode (self , images : List [ImageItem ]):
110118 images = obtain (images )
111119 all_img_embeds , uuids , valid_ids = self .forward (images )
112- all_img_embeds = all_img_embeds .to (torch .device ("cpu" ))
113-
120+ all_img_embeds = all_img_embeds .to (torch .device ("cpu" ), non_blocking = True )
114121 if self .tp_rank_id == 0 :
115122 # ready_flags = obtain(self.cache_client.root.get_items_embed(uuids))
116123 ids_to_set = []
117- for i , img in enumerate ( images ):
124+ for i in range ( len ( images ) ):
118125 # if ready:
119126 # continue
120127 uid = uuids [i ]
121128 start , end = valid_ids [i ]
122129 cur_embed_bytes = tensor2bytes (all_img_embeds [start :end ])
123130 if self .args .run_mode == "visual" :
124- create_afs (get_shm_name_embed (uid ), cur_embed_bytes , self .args . image_embed_dir )
131+ create_afs (get_shm_name_embed (uid ), cur_embed_bytes , self .image_embed_dir )
125132 else :
126133 create_shm (get_shm_name_embed (uid ), cur_embed_bytes )
127134 ids_to_set .append (uid )
@@ -131,11 +138,13 @@ def exposed_encode(self, images: List[ImageItem]):
131138
132139
133140class VisualModelRpcClient :
134- def __init__ (self , model_rpc , vit_tp , rpc_server_process = None ):
135- self .model : VisualModelRpcServer = model_rpc
141+ def __init__ (self , conn , vit_tp , rpc_server_process = None ):
142+ self .conn = conn
143+ self .model : VisualModelRpcServer = conn .root
136144 self .vit_tp = vit_tp
137145 self .rpc_server_process = rpc_server_process
138146 self .use_rpc = True
147+ self ._bg = rpyc .BgServingThread (self .conn )
139148 if self .use_rpc :
140149
141150 def async_wrap (f ):
@@ -176,7 +185,13 @@ def _init_env(port, device_id):
176185 # Register the graceful-exit handler
177186 graceful_registry (inspect .currentframe ().f_code .co_name )
178187
179- t = ThreadedServer (VisualModelRpcServer (), port = port , protocol_config = {"allow_pickle" : True })
188+ auth = lambda sock : (sock .setsockopt (socket .IPPROTO_TCP , socket .TCP_NODELAY , 1 ) or (sock , None ))
189+ t = ThreadedServer (
190+ VisualModelRpcServer (),
191+ port = port ,
192+ protocol_config = {"allow_pickle" : True },
193+ authenticator = auth ,
194+ )
180195 t .start ()
181196 return
182197
@@ -197,6 +212,7 @@ async def start_model_process(port, vit_tp, device_id):
197212 while repeat_count < 20 :
198213 try :
199214 con = rpyc .connect ("localhost" , port , config = {"allow_pickle" : True })
215+ con ._channel .stream .sock .setsockopt (socket .IPPROTO_TCP , socket .TCP_NODELAY , 1 )
200216 break
201217 except BaseException :
202218 await asyncio .sleep (1 )
@@ -205,4 +221,4 @@ async def start_model_process(port, vit_tp, device_id):
205221 raise Exception ("init rpc env error!" )
206222
207223 assert proc .is_alive ()
208- return VisualModelRpcClient (con . root , vit_tp , rpc_server_process = proc )
224+ return VisualModelRpcClient (con , vit_tp , rpc_server_process = proc )
0 commit comments