From fbd6c7cac3c87c3884117f2346cf68269dba9750 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 15 Aug 2025 06:04:40 +0000 Subject: [PATCH 01/40] 0815-temp --- lightllm/server/api_http.py | 28 ++++++- lightllm/server/api_lightllm.py | 16 ++++ lightllm/server/api_server.py | 4 +- lightllm/server/api_start.py | 102 ++++++++++++++++++++++++ lightllm/server/visualserver/manager.py | 6 +- 5 files changed, 151 insertions(+), 5 deletions(-) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 3835994e5..8a4e98da2 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -40,8 +40,9 @@ from lightllm.server.core.objs.sampling_params import SamplingParams from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager +from .visualserver.manager import VisualManager from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster -from .api_lightllm import lightllm_get_score +from .api_lightllm import lightllm_get_score, lightllm_get_image_embedding from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size from lightllm.utils.log_utils import init_logger from lightllm.utils.error_utils import ServerBusyError @@ -69,6 +70,7 @@ class G_Objs: g_generate_func: Callable = None g_generate_stream_func: Callable = None httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None + visual_manager: VisualManager = None shared_token_load: TokenLoad = None def set_args(self, args): @@ -89,6 +91,16 @@ def set_args(self, args): args, metric_port=args.metric_port, ) + elif args.run_mode == "visual_only": + self.metric_client = MetricClient(args.metric_port) + self.httpserver_manager = None + self.visual_manager = VisualManager( + args, + next_module_port=args.next_module_port, + visual_port=args.visual_port, + cache_port=args.cache_port, + visual_model_rpc_ports=args.visual_model_rpc_ports, + ) else: init_tokenizer(args) # for openai api SamplingParams.load_generation_cfg(args.model_dir) @@ -139,7 +151,7 @@ def get_model_name(): @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") async def healthcheck(request: Request): - if g_objs.args.run_mode == "pd_master": + if g_objs.args.run_mode in ["pd_master", "visual_only"]: return JSONResponse({"message": "Ok"}, status_code=200) if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true": @@ -209,6 +221,18 @@ async def get_score(request: Request) -> Response: return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) +@app.post("/get_image_embed") +async def get_image_embed(request: Request) -> Response: + try: + return await lightllm_get_image_embedding(request, g_objs.visual_manager) + except ServerBusyError as e: + logger.error("%s", str(e), exc_info=True) + return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e)) + except Exception as e: + logger.error("An error occurred: %s", str(e), exc_info=True) + return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) + + @app.post("/") async def compat_generate(request: Request) -> Response: request_dict = await request.json() diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index 6f258379e..7a52ef667 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -5,6 +5,7 @@ from lightllm.server.core.objs.sampling_params import SamplingParams from .multimodal_params import MultimodalParams from .httpserver.manager import 
HttpServerManager +from .visualserver.manager import VisualManager import ujson as json @@ -136,3 +137,18 @@ async def stream_results() -> AsyncGenerator[bytes, None]: background_tasks = BackgroundTasks() return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) + + +async def lightllm_get_image_embedding(request: Request, visual_manager: VisualManager) -> Response: + request_dict = await request.json() + sample_params_dict = request_dict["parameters"] + sampling_params = SamplingParams() + sampling_params.init(tokenizer=None, **sample_params_dict) + sampling_params.verify() + multimodal_params_dict = request_dict.get("multimodal_params", {}) + multimodal_params = MultimodalParams(**multimodal_params_dict) + + result_embeddings = VisualManager.generate(multimodal_params, sampling_params, request=request) + + # 5. Return JSON result + return {"embeddings": result_embeddings} diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index b4447d808..6ffe6b31d 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -5,11 +5,13 @@ torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess parser = make_argument_parser() args = parser.parse_args() - from .api_start import pd_master_start, normal_or_p_d_start, config_server_start + from .api_start import pd_master_start, normal_or_p_d_start, visual_only_start, config_server_start if args.run_mode == "pd_master": pd_master_start(args) elif args.run_mode == "config_server": config_server_start(args) + elif args.run_mode == "visual_only": + visual_only_start(args) else: normal_or_p_d_start(args) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index c2a87b4c3..9483532a4 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -416,6 +416,108 @@ def pd_master_start(args): http_server_process.wait() +def visual_only_start(args): + set_unique_server_name(args) + if args.run_mode != "visual_only": + return + + can_use_ports = alloc_can_use_network_port( + num=5 + args.visual_dp * args.visual_tp, + ) + logger.info(f"alloced ports: {can_use_ports}") + ( + router_port, + visual_port, + audio_port, + cache_port, + metric_port, + ) = can_use_ports[0:5] + can_use_ports = can_use_ports[5:] + + visual_model_tp_ports = [] + for _ in range(args.visual_dp): + tp_ports_for_dp = can_use_ports[0 : args.visual_tp] + can_use_ports = can_use_ports[args.visual_tp :] + visual_model_tp_ports.append(tp_ports_for_dp) + + # 将申请好的端口放入args参数中 + args.router_port = router_port + args.visual_port = visual_port + args.audio_port = audio_port + args.cache_port = cache_port + args.metric_port = metric_port + + logger.info(f"all start args:{args}") + + set_env_start_args(args) + + from .visualserver.manager import start_visual_process + + process_manager.start_submodule_processes( + start_funcs=[ + start_cache_manager, + ], + start_args=[(cache_port, args)], + ) + process_manager.start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args, audio_port, visual_port, cache_port, visual_model_tp_ports), + ], + ) + if args.enable_multimodal_audio: + from .audioserver.manager import start_audio_process + + process_manager.start_submodule_processes( + start_funcs=[ + start_audio_process, + ], + start_args=[ + (args, router_port, audio_port, cache_port), + ], + ) + + # 启动 gunicorn + command = [ + "gunicorn", + "--workers", + f"{args.httpserver_workers}", + "--worker-class", 
+ "uvicorn.workers.UvicornWorker", + "--bind", + f"{args.host}:{args.port}", + "--log-level", + "info", + "--access-logfile", + "-", + "--error-logfile", + "-", + "lightllm.server.api_http:app", + "--timeout", + f"{get_lightllm_gunicorn_time_out_seconds()}", + "--keep-alive", + f"{get_lightllm_gunicorn_keep_alive()}", + ] + + # 启动子进程 + http_server_process = subprocess.Popen(command) + + if "s3://" in args.model_dir: + from lightllm.utils.petrel_helper import s3_model_clear + + s3_model_clear(args.model_dir) + + if args.health_monitor: + from lightllm.server.health_monitor.manager import start_health_check_process + + process_manager.start_submodule_processes(start_funcs=[start_health_check_process], start_args=[(args,)]) + setup_signal_handlers(http_server_process, process_manager) + http_server_process.wait() + return + + def config_server_start(args): set_unique_server_name(args) if args.run_mode != "config_server": diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 4a3dec826..34df571f6 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -30,9 +30,11 @@ def __init__( cache_port, visual_model_rpc_ports, ): + self.visual_only = True if args.run_mode == "visual_only" else False context = zmq.Context(2) - self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if --enable_multimodal_audio) - self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{next_module_port}") + if not self.visual_only: + self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if --enable_multimodal_audio) + self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{next_module_port}") self.recv_from_httpserver = context.socket(zmq.PULL) self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{visual_port}") From 4acd7e7940f75291c0cc0cfcc90f6d54a5ed5d76 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 15 Aug 2025 11:28:30 +0000 Subject: [PATCH 02/40] 0815-add-visual-only --- lightllm/server/api_cli.py | 2 +- lightllm/server/api_http.py | 20 +- lightllm/server/api_lightllm.py | 15 +- lightllm/server/api_start.py | 5 +- .../httpserver_for_visual_only/__init__.py | 0 .../httpserver_for_visual_only/manager.py | 461 ++++++++++++++++++ lightllm/server/visualserver/manager.py | 16 +- .../visualserver/model_infer/model_rpc.py | 3 +- 8 files changed, 497 insertions(+), 25 deletions(-) create mode 100644 lightllm/server/httpserver_for_visual_only/__init__.py create mode 100644 lightllm/server/httpserver_for_visual_only/manager.py diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index dfbf0c84b..b7f5e7707 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "pd_master", "config_server"], + choices=["normal", "prefill", "decode", "pd_master", "config_server", "visual_only"], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 8a4e98da2..bd5c23d88 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -42,6 +42,7 @@ from .httpserver.manager import HttpServerManager from .visualserver.manager import 
VisualManager from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster +from .httpserver_for_visual_only.manager import HttpServerManagerForVisualOnly from .api_lightllm import lightllm_get_score, lightllm_get_image_embedding from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size from lightllm.utils.log_utils import init_logger @@ -69,7 +70,7 @@ class G_Objs: args: object = None g_generate_func: Callable = None g_generate_stream_func: Callable = None - httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None + httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster, HttpServerManagerForVisualOnly] = None visual_manager: VisualManager = None shared_token_load: TokenLoad = None @@ -92,14 +93,12 @@ def set_args(self, args): metric_port=args.metric_port, ) elif args.run_mode == "visual_only": - self.metric_client = MetricClient(args.metric_port) - self.httpserver_manager = None - self.visual_manager = VisualManager( + # self.metric_client = MetricClient(args.metric_port) + self.httpserver_manager = HttpServerManagerForVisualOnly( args, - next_module_port=args.next_module_port, - visual_port=args.visual_port, cache_port=args.cache_port, - visual_model_rpc_ports=args.visual_model_rpc_ports, + visual_port=args.visual_port, + # metric_port=args.metric_port, ) else: init_tokenizer(args) # for openai api @@ -221,10 +220,10 @@ async def get_score(request: Request) -> Response: return create_error_response(HTTPStatus.EXPECTATION_FAILED, str(e)) -@app.post("/get_image_embed") +@app.post("/get_image_embedding") async def get_image_embed(request: Request) -> Response: try: - return await lightllm_get_image_embedding(request, g_objs.visual_manager) + return await lightllm_get_image_embedding(request, g_objs.httpserver_manager) except ServerBusyError as e: logger.error("%s", str(e), exc_info=True) return create_error_response(HTTPStatus.SERVICE_UNAVAILABLE, str(e)) @@ -358,6 +357,7 @@ async def startup_event(): logger.info("server start up") loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) - loop.create_task(g_objs.httpserver_manager.handle_loop()) + if g_objs.httpserver_manager is not None: + loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index 7a52ef667..660608f6e 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -5,7 +5,8 @@ from lightllm.server.core.objs.sampling_params import SamplingParams from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager -from .visualserver.manager import VisualManager +from .httpserver_for_visual_only.manager import HttpServerManagerForVisualOnly +from fastapi.responses import JSONResponse import ujson as json @@ -139,8 +140,12 @@ async def stream_results() -> AsyncGenerator[bytes, None]: return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) -async def lightllm_get_image_embedding(request: Request, visual_manager: VisualManager) -> Response: +async def lightllm_get_image_embedding( + request: Request, httpserver_manager: HttpServerManagerForVisualOnly +) -> Response: request_dict = await request.json() + # request_dict: {'parameters': {'max_new_tokens': 128}, + # 'multimodal_params': {'images': [{'type': 'base64', 'data': 'base64'}]}} sample_params_dict = 
request_dict["parameters"] sampling_params = SamplingParams() sampling_params.init(tokenizer=None, **sample_params_dict) @@ -148,7 +153,7 @@ async def lightllm_get_image_embedding(request: Request, visual_manager: VisualM multimodal_params_dict = request_dict.get("multimodal_params", {}) multimodal_params = MultimodalParams(**multimodal_params_dict) - result_embeddings = VisualManager.generate(multimodal_params, sampling_params, request=request) - + await httpserver_manager.generate(sampling_params, multimodal_params, request=request) # 5. Return JSON result - return {"embeddings": result_embeddings} + print("embedding OK") + return JSONResponse({"message": "OK"}, status_code=200) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 9483532a4..40594f39f 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -420,9 +420,9 @@ def visual_only_start(args): set_unique_server_name(args) if args.run_mode != "visual_only": return - + already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port] can_use_ports = alloc_can_use_network_port( - num=5 + args.visual_dp * args.visual_tp, + num=5 + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( @@ -446,6 +446,7 @@ def visual_only_start(args): args.audio_port = audio_port args.cache_port = cache_port args.metric_port = metric_port + args.visual_model_rpc_ports = visual_model_tp_ports logger.info(f"all start args:{args}") diff --git a/lightllm/server/httpserver_for_visual_only/__init__.py b/lightllm/server/httpserver_for_visual_only/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightllm/server/httpserver_for_visual_only/manager.py b/lightllm/server/httpserver_for_visual_only/manager.py new file mode 100644 index 000000000..2b0a6bd4a --- /dev/null +++ b/lightllm/server/httpserver_for_visual_only/manager.py @@ -0,0 +1,461 @@ +import sys +import zmq +import zmq.asyncio +import asyncio +import uvloop +import rpyc +import time +import copy +import hashlib +import datetime +import pickle +from frozendict import frozendict + +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +from typing import Union, List, Tuple, Dict, Optional +from fastapi import Request +from ..tokenizer import get_tokenizer +from ..pd_io_struct import NodeRole +from ..embed_cache.utils import get_shm_name_data, create_shm +from ..multimodal_params import AudioItem, MultimodalParams, ImageItem +from ..req_id_generator import ReqIDGenerator +from lightllm.server.core.objs import Req, FinishStatus +from lightllm.server.core.objs import SamplingParams +from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE +from lightllm.server.core.objs.io_objs import GroupReqObjs +from lightllm.server.core.objs.shm_req_manager import ShmReqManager +from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.utils.log_utils import init_logger +from lightllm.server.metrics.manager import MetricClient +from lightllm.utils.statics_utils import MovingAverage +from lightllm.utils.config_utils import get_vocab_size +from lightllm.utils.envs_utils import get_unique_server_name +from rpyc.utils.classic import obtain + +logger = init_logger(__name__) + + +class HttpServerManagerForVisualOnly: + def __init__( + self, + args, + cache_port, + visual_port, + # metric_port + ): + self.args = args + 
context = zmq.asyncio.Context(2) + + self._shm_lock_pool = AtomicShmArrayLock(f"{get_unique_server_name()}_lightllm_resource_lock", 1) + self._resource_lock = AsyncLock(self._shm_lock_pool.get_lock_context(0)) + self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) + self.send_to_visual = context.socket(zmq.PUSH) + self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") + print("self.send_to_visual") + self.shm_req_manager = ShmReqManager() + self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) + self.req_id_to_out_inf: Dict[int, ReqStatus] = {} # value type (out_str, metadata, finished, event) + self.max_req_total_len = args.max_req_total_len + self.id_gen = ReqIDGenerator() + # self.metric_client = MetricClient(metric_port) + # 有的模型的vocab size 读取tokenizer和config.json中不一致 + self.vocab_size = max(get_vocab_size(args.model_dir), self.tokenizer.vocab_size) + return + + async def _alloc_resource(self, items, md5sums, token_nums, datas): + + while True: + records = obtain(self.cache_client.root.alloc(md5sums, token_nums)) + + if records is None: + await asyncio.sleep(0.1) + continue + + uid_list = [] + for item, rec in zip(items, records): + item.uuid = rec["id"] + item.token_id = rec["token_id"] + item.token_num = rec["token_num"] + uid_list.append(rec["id"]) + + ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) + update_data_ids = [] + + for uid, ready, data in zip(uid_list, ready_flags, datas): + if not ready: + create_shm(get_shm_name_data(uid), data) + update_data_ids.append(uid) + + if update_data_ids: + self.cache_client.root.set_items_data(update_data_ids) + return + + async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams): + # 这里的锁是为了 防止多个含有多张图片的请求 同时申请的record数量 大于cache_capacity,从而造成死锁的问题。 + # 如果不加任何锁,假如请求1和请求2都有6张图片,而cache_capacity为10, + # 那么如果某一时刻shm中存在请求1的5张图和请求2的5张图,将会资源竞争产生死锁。 + print("in alloc_multimodal_resoueces") + async with self._resource_lock: + items, md5sums, tokens_nums, datas = [], [], [], [] + for img in multimodal_params.images: + self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) + data = img.read() + # must after init_imageitem_extral_params + token_num = self.tokenizer.get_image_token_length(img) + md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params))) + md5sums.append(md5sum) + tokens_nums.append(token_num) + datas.append(data) + items.append(img) + for audio in multimodal_params.audios: + self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params) + data = audio.read() + token_num = self.tokenizer.get_audio_token_length(audio) + md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(audio.extra_params))) + md5sums.append(md5sum) + tokens_nums.append(token_num) + datas.append(data) + items.append(audio) + + await self._alloc_resource(items, md5sums, tokens_nums, datas) + return + + async def _release_multimodal_resources(self, multimodal_params: MultimodalParams): + if multimodal_params is not None: + ids_to_release = [] + for img in multimodal_params.images: + if img.uuid is not None: + ids_to_release.append(img.uuid) + # 将 uuid 等 赋值为 None, 防止因为abort等异常情况造成重复释放异常 + img.uuid = None + img.token_id = None + img.token_num = None + for audio in multimodal_params.audios: + if audio.uuid is not None: + ids_to_release.append(audio.uuid) + # 将 uuid 等 赋值为 None, 防止因为abort等异常情况造成重复释放异常 + 
audio.uuid = None + audio.token_id = None + audio.token_num = None + if ids_to_release: + self.cache_client.root.release(ids_to_release) + return + + def tokens(self, multimodal_params, samping_params: SamplingParams, kwargs=None): + kwargs = {} if kwargs is None else kwargs + image_tokens = 0 + img_count = 0 + audio_tokens = 0 + audio_count = 0 + for img in multimodal_params.images: + img_count += 1 + self.tokenizer.init_imageitem_extral_params(img, multimodal_params, samping_params) + image_tokens += self.tokenizer.get_image_token_length(img) + for audio in multimodal_params.audios: + audio_count += 1 + self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, samping_params) + audio_tokens += self.tokenizer.get_audio_token_length(audio) + return image_tokens + img_count + audio_tokens + audio_count + + async def loop_for_request(self): + assert self.args.node_rank > 0 + while True: + ( + sampling_params, + multimodal_params, + ) = await self.multinode_req_manager.recv_pyobj() + results_generator = self.generate(sampling_params, multimodal_params, None) + + async def generate_wrapper(results_generator): + async for _, _, _, _ in results_generator: + pass + + asyncio.create_task(generate_wrapper(results_generator)) + return + + def alloc_req_id(self, sampling_params, is_health_req: bool = False): + # 请求的 id 可以由外部传入,也可以由内部生成,但是由外部传入的时候,要自己保证全局唯一性 + # 否则会造成异常问题。目前限制 NORMAL 模式都使用内部id替换, P 和 D 模式按需设置 + # health 请求 request_id 为负数,直接返回 + if is_health_req: + return sampling_params.group_request_id + group_request_id = self.id_gen.generate_id() + + sampling_params.group_request_id = group_request_id + return group_request_id + + async def generate( + self, + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + request: Request, + is_health_req: bool = False, + ) -> Tuple[int, str, dict, FinishStatus]: + print(" in generate") + start_time = time.time() + request_headers = request.headers if request is not None else {} + group_request_id = self.alloc_req_id(sampling_params, is_health_req) + + try: + original_multimodal_params = None + await multimodal_params.verify_and_preload(request) + + # 记录请求到达的相关信息 + await self._log_req_header(request_headers, group_request_id) + assert ( + len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity + ), "too many multimodal items!" 
+ await self._alloc_multimodal_resources(multimodal_params, sampling_params) + + # 监控 + # if group_request_id > 0: + # self.metric_client.counter_inc("lightllm_request_count") + # self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens) + # self.metric_client.histogram_observe("lightllm_request_max_new_tokens", + # sampling_params.max_new_tokens) + # prompt_ids = await self._check_and_repair_length(prompt_ids, sampling_params) + + # 申请资源并存储 + alloced_req_indexes = [] + while len(alloced_req_indexes) < sampling_params.n: + alloc_req_index = await self.shm_req_manager.async_alloc_req_index() + sleep_time = 0.1 + while alloc_req_index is None: + await asyncio.sleep(sleep_time) + sleep_time *= 1.1 + sleep_time = min(1, sleep_time) + + alloc_req_index = await self.shm_req_manager.async_alloc_req_index() + alloced_req_indexes.append(alloc_req_index) + req_objs = [] + for i, req_index in enumerate(alloced_req_indexes): + req_obj = await self.shm_req_manager.async_get_req_obj_by_index(req_index) + req_obj.init( + group_request_id + i, + # 随便写的,后面改掉 + [24, 67], + sampling_params, + self.tokenizer, + chunked_prefill_size=self.args.chunked_prefill_size, + ) + req_objs.append(req_obj) + + req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time) + self.req_id_to_out_inf[group_request_id] = req_status + + await self.transfer_to_visual(sampling_params, original_multimodal_params, req_status.group_req_objs) + + # results_generator = self._wait_to_token_package( + # start_time, + # group_request_id, + # sampling_params, + # req_status, + # request, + # ) + # async for sub_req_id, request_output, metadata, finish_status in results_generator: + # yield sub_req_id, request_output, metadata, finish_status + + except Exception as e: + logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") + # error need to release multimodel resources. 
+ # 对于还没有形成正式请求对象管理的多模态资源,需要单独自己释放 + # 已经放入到 req_id_to_out_inf 中的请求对象,由统一的回收循环 + # 进行回收。 + if group_request_id not in self.req_id_to_out_inf: + await self._release_multimodal_resources(multimodal_params) + await self.abort(group_request_id) + raise e + return + + async def _log_req_header(self, request_headers, group_request_id: int): + + x_request_id = request_headers.get("X-Request-Id", "") + x_session_id = request_headers.get("X-Session-Id", "") + + format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") + logger.info( + f"recieved req X-Request-Id:{x_request_id} " + f"X-Session-Id:{x_session_id} start_time:{format_in_time} " + f"lightllm_req_id:{group_request_id} " + ) + return + + async def transfer_to_visual( + self, + sampling_params: SamplingParams, + original_multimodal_params: MultimodalParams, + group_req_objs: Optional[GroupReqObjs] = None, + ): + await self.send_to_visual.send_pyobj( + group_req_objs.to_group_req_index(), + protocol=pickle.HIGHEST_PROTOCOL, + ) + return + + async def _wait_to_token_package( + self, + start_time, + group_request_id: int, + sampling_params: SamplingParams, + req_status: "ReqStatus", + request: Request, + ): + + event = req_status.event + unfinished_count = sampling_params.best_of + out_token_counter = 0 + first_token_cost_ms = sys.float_info.max + is_first_token = True + + while True: + try: + await asyncio.wait_for(event.wait(), timeout=5) + except asyncio.TimeoutError: + pass + + if not self.disable_abort and request is not None and await request.is_disconnected(): + await self.abort(group_request_id) + raise Exception(f"req_id {group_request_id} disconnected") + + async with req_status.lock: + event.clear() + if len(req_status.out_token_info_list) == 0: + continue + + for sub_req_id, out_str, metadata, finish_status in req_status.out_token_info_list: + # pd master 节点需要这个做统计信息, 所以放在元数据中返回给 pd master 节点 + # p 节点返回 prompt_ids 信息,防止 d 节点重新 encode + + prompt_cache_len = metadata.pop("prompt_cache_len", 0) + if is_first_token: + first_token_cost_ms = (time.time() - start_time) * 1000 + is_first_token = False + self.first_time_costs.add(first_token_cost_ms) + + out_token_counter += 1 + + # update inference timemark + self.latest_success_infer_time_mark.set_value(int(time.time())) + + yield sub_req_id, out_str, metadata, finish_status + # 如果有子请求完成,就更新计数 + if finish_status.is_finished(): + unfinished_count -= 1 + + if unfinished_count == 0: + total_cost_time_ms = (time.time() - start_time) * 1000 + mean_per_token_cost_time_ms = (total_cost_time_ms - first_token_cost_ms) / out_token_counter + self.per_token_costs.add(mean_per_token_cost_time_ms) + x_request_id = request.headers.get("X-Request-Id", "") if request is not None else "" + x_session_id = request.headers.get("X-Session-Id", "") if request is not None else "" + + mtp_avg_token_per_step = out_token_counter / max( + (out_token_counter - metadata["mtp_accepted_token_num"]), 1 + ) + format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S") + logger.info( + f"X-Request-Id:{x_request_id} " + f"X-Session-Id:{x_session_id} start_time:{format_start_time} " + f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms " + f"total_cost_time:{total_cost_time_ms}ms,out_token_counter:{out_token_counter} " + f"mean_per_token_cost_time: {mean_per_token_cost_time_ms}ms " + f"prompt_cache_len:{prompt_cache_len} " + f"mtp_avg_token_per_step:{mtp_avg_token_per_step} " + ) + if group_request_id < 0: + # health 探测请求,不记录日志和监控 + 
return + self.metric_client.histogram_observe("lightllm_cache_length", prompt_cache_len) + self.metric_client.histogram_observe( + "lightllm_request_inference_duration", total_cost_time_ms / 1000.0 + ) + self.metric_client.histogram_observe( + "lightllm_request_mean_time_per_token_duration", mean_per_token_cost_time_ms / 1000.0 + ) + self.metric_client.histogram_observe( + "lightllm_request_first_token_duration", first_token_cost_ms / 1000.0 + ) + self.metric_client.histogram_observe("lightllm_request_generated_tokens", out_token_counter) + self.metric_client.counter_inc("lightllm_request_success") + + return + req_status.out_token_info_list.clear() + return + + async def abort(self, group_req_id: int): + req_status: ReqStatus = self.req_id_to_out_inf.get(group_req_id, None) + if req_status is None: + logger.warning(f"aborted group_request_id {group_req_id} not exist") + return + + group_req_objs: GroupReqObjs = req_status.group_req_objs + for req in group_req_objs.shm_req_objs: + req.is_aborted = True + logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}") + return + + async def recycle_resource_loop(self): + pre_time_mark = time.time() + + while True: + + try: + await asyncio.wait_for(self.recycle_event.wait(), timeout=0.02) + except asyncio.TimeoutError: + pass + self.recycle_event.clear() + + # 清理已经处理完的可以删除的请求 + release_req_status: List[ReqStatus] = [] + for group_req_id_ in list(self.req_id_to_out_inf.keys()): + req_status: ReqStatus = self.req_id_to_out_inf.get(group_req_id_, None) + if req_status is not None and req_status.can_release(): + release_req_status.append(req_status) + + for req_status in release_req_status: + self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None) + for req in req_status.group_req_objs.shm_req_objs: + await self.shm_req_manager.async_put_back_req_obj(req) + await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem) + await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) + + # 先保留这个关键得日志,用于方便定位重构中的问题。 + if time.time() - pre_time_mark > 120: + pre_time_mark = time.time() + for group_req_id_ in list(self.req_id_to_out_inf.keys()): + req_status: ReqStatus = self.req_id_to_out_inf.get(group_req_id_, None) + if req_status is None: + continue + + logger.info( + f"left req id {req_status.group_req_objs.group_req_id}" + f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} " + f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}" + ) + return + + async def handle_loop(self): + self.recycle_event = asyncio.Event() + asyncio.create_task(self.recycle_resource_loop()) + + return + + +class ReqStatus: + def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None: + self.lock = asyncio.Lock() + self.event = asyncio.Event() + self.group_req_objs = GroupReqObjs( + group_req_id=group_request_id, + multimodal_params=multimodal_params, + shm_req_objs=req_objs, + time_mark=start_time, + ) + self.out_token_info_list = [] + + def can_release(self): + for req in self.group_req_objs.shm_req_objs: + if not req.can_release(): + return False + return True diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 34df571f6..5a9b7e191 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -118,7 +118,8 @@ async def loop_for_fwd(self): # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理 # 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了 # 需要一些一致的流程来保证不出现异步问题。 - 
self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + if not self.visual_only: + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) continue multimodal_params = group_req_indexes.multimodal_params @@ -134,20 +135,23 @@ async def loop_for_fwd(self): await self.infer_imgs(images_need_infer) images_need_infer = [] for _group_req_indexes in processing_group_reqs: - self.send_to_next_module.send_pyobj( - _group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL - ) + if not self.visual_only: + self.send_to_next_module.send_pyobj( + _group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL + ) processing_group_reqs = [] if len(images_need_infer) == 0: - self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + if not self.visual_only: + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) else: processing_group_reqs.append(group_req_indexes) if len(images_need_infer) > 0: await self.infer_imgs(images_need_infer) for _group_req_indexes in processing_group_reqs: - self.send_to_next_module.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + if not self.visual_only: + self.send_to_next_module.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) processing_group_reqs = [] images_need_infer = [] diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index a25065e42..0abf983b0 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -74,7 +74,7 @@ def exposed_init_model(self, kvargs): self.model = Gemma3VisionModel() else: raise Exception(f"can not support {self.model_type} now") - + print("begin load visual model weight") self.model.load_model(weight_dir) self.model = self.model.cuda() except Exception as e: @@ -98,6 +98,7 @@ def exposed_encode(self, images: List[ImageItem]): images = obtain(images) all_img_embeds, uuids, valid_ids = self.forward(images) all_img_embeds = all_img_embeds.to(torch.device("cpu")) + print(f"all_img_embeds is {all_img_embeds}") if self.tp_rank_id == 0: ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) From 68dd163cf353a4712c5002314652e75217a5b75a Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 20 Aug 2025 06:36:50 +0000 Subject: [PATCH 03/40] 0820-add-llm-only --- .../qwen_vl/layer_infer/pre_layer_infer.py | 11 ++- lightllm/server/api_cli.py | 3 +- lightllm/server/api_http.py | 3 +- lightllm/server/api_start.py | 19 +++-- .../embed_cache/impl/naive_memory_cache.py | 11 ++- lightllm/server/embed_cache/manager.py | 4 + lightllm/server/embed_cache/utils.py | 42 ++++++++++ lightllm/server/httpserver/manager.py | 78 +++++++++++++++---- .../httpserver_for_visual_only/manager.py | 16 ++-- lightllm/server/multimodal_params.py | 2 + lightllm/server/pd_io_struct.py | 5 +- lightllm/server/visualserver/manager.py | 27 ++++--- .../visualserver/model_infer/model_rpc.py | 17 +++- 13 files changed, 187 insertions(+), 51 deletions(-) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index b5b31a413..dbf2c8ea0 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -6,9 +6,10 @@ from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.utils.infer_utils import mark_cost_time -from 
lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed +from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed, read_afs from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb from lightllm.distributed.communication_op import all_reduce +from lightllm.utils.envs_utils import get_env_start_args """ @@ -29,6 +30,7 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) + self.args = get_env_start_args() return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): @@ -50,8 +52,11 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei # skip the same image if img["token_id"] in img_start_token_ids or img["_prefill_"] is False: continue - # pull the img_embeds by uid from shm - data = read_shm(get_shm_name_embed(img["uuid"])) + # pull the img_embeds by uid from shm or afs + if self.args.run_mode == "llm_only": + data = read_afs(get_shm_name_embed(img["uuid"])) + else: + data = read_shm(get_shm_name_embed(img["uuid"])) img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index b7f5e7707..825c6cbb3 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "pd_master", "config_server", "visual_only"], + choices=["normal", "prefill", "decode", "pd_master", "config_server", "visual_only", "llm_only"], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, @@ -314,6 +314,7 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument("--metric_gateway", type=str, default=None, help="address for collecting monitoring metrics") parser.add_argument("--job_name", type=str, default="lightllm", help="job name for monitor") + parser.add_argument("--visual_embed_path", type=str, default=None, help="path for vit embed") parser.add_argument( "--grouping_key", action="append", default=[], help="grouping_key for the monitor in the form key=value" ) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index bd5c23d88..857a4cf37 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -357,7 +357,6 @@ async def startup_event(): logger.info("server start up") loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) - if g_objs.httpserver_manager is not None: - loop.create_task(g_objs.httpserver_manager.handle_loop()) + loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 40594f39f..40bbbbf3c 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -59,7 +59,7 @@ def signal_handler(sig, frame): return -def normal_or_p_d_start(args): +def check_and_set_args(args): set_unique_server_name(args) if args.enable_mps: @@ -67,7 +67,7 @@ def normal_or_p_d_start(args): enable_mps() - if args.run_mode not 
in ["normal", "prefill", "decode"]: + if args.run_mode not in ["normal", "prefill", "decode", "llm_only", "visual_only"]: return assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] @@ -138,6 +138,11 @@ def normal_or_p_d_start(args): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + # visual_only模式下才需要设置visual_embed_path + if args.visual_embed_path is not None: + assert ( + args.run_mode == "visual_only" or args.run_mode == "llm_only" + ), "only visual_only or llm_only mode need visual_embed_path" # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) @@ -202,6 +207,10 @@ def normal_or_p_d_start(args): args.data_type = get_dtype(args.model_dir) assert args.data_type in ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"] + +def normal_or_p_d_start(args): + + check_and_set_args(args) already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port] if args.run_mode == "decode": already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port, args.pd_decode_rpyc_port] @@ -269,7 +278,7 @@ def normal_or_p_d_start(args): ], start_args=[(cache_port, args)], ) - if args.enable_multimodal_audio: + if args.enable_multimodal_audio and args.run_mode != "llm_only": from .audioserver.manager import start_audio_process process_manager.start_submodule_processes( @@ -289,7 +298,7 @@ def normal_or_p_d_start(args): ], ) - else: + elif args.run_mode != "llm_only": process_manager.start_submodule_processes( start_funcs=[ start_visual_process, @@ -417,7 +426,7 @@ def pd_master_start(args): def visual_only_start(args): - set_unique_server_name(args) + check_and_set_args(args) if args.run_mode != "visual_only": return already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port] diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index 5477be22b..323fdf420 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -7,7 +7,7 @@ import time from collections import deque import multiprocessing.shared_memory as shm -from ..utils import get_shm_name_data, get_shm_name_embed, free_shm +from ..utils import get_shm_name_data, get_shm_name_embed, free_shm, free_afs from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -65,7 +65,7 @@ def _check_and_set_new_id_range(self, alloced_token_num): except BaseException as e: logger.exception(str(e)) time.sleep(3) - return + return self.token_id_range_start def _clear(self, free_max_count: int): deleted = 0 @@ -77,7 +77,10 @@ def _clear(self, free_max_count: int): if record.data: free_shm(get_shm_name_data(id)) if record.embed: - free_shm(get_shm_name_embed(id)) + if self.args.run_mode == "visual_only": + free_afs(get_shm_name_embed(id)) + else: + free_shm(get_shm_name_embed(id)) del self._md5_to_record[record.md5sum] del self._records[id] self.occupied -= 1 @@ -103,7 +106,7 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[l rec.visittime = now rec.ref += 1 else: - uid_int = uuid.uuid1().int + uid_int = int(md5sum, 16) self._check_and_set_new_id_range(token_num) rec = Record( id=uid_int, diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py index 557bdcf3b..8059bfb2a 100644 --- a/lightllm/server/embed_cache/manager.py +++ b/lightllm/server/embed_cache/manager.py @@ -22,6 +22,10 @@ def on_disconnect(self, conn): # (to finalize the 
service, if needed) pass + def exposed__check_and_set_new_id_range(self, token_num: int) -> int: + token_num = obtain(token_num) + return self._impl._check_and_set_new_id_range(token_num) + def exposed_alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]: md5sum_list = obtain(md5sum_list) token_num_list = obtain(token_num_list) diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 6df031293..b1227ec6b 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -1,7 +1,11 @@ +import os +import time import torch import numpy as np from io import BytesIO +from pathlib import Path import multiprocessing.shared_memory as shm +from lightllm.utils.envs_utils import get_env_start_args def tensor2bytes(t: torch.Tensor): @@ -35,21 +39,59 @@ def create_shm(name, data): print("Warning create shm {} failed because of FileExistsError!".format(name)) +def create_afs(name, data): + try: + data_size = len(data) + path = os.path.join(get_env_start_args().visual_embed_path, name) + with open(path, "xb") as f: + mem_view = memoryview(data) + f.write(mem_view[:data_size]) + f.flush() + os.fsync(f.fileno()) + except FileExistsError: + print("Warning create afs {} failed because of FileExistsError!".format(name)) + + def read_shm(name): shared_memory = shm.SharedMemory(name=name) data = shared_memory.buf.tobytes() return data +def read_afs(name: str, base_dir: str = "/mtc/sangchengmeng/afs") -> bytes: + + path = Path(base_dir) / name + return path.read_bytes() + + def free_shm(name): shared_memory = shm.SharedMemory(name=name) shared_memory.close() shared_memory.unlink() +def free_afs(name): + path = os.path.join(get_env_start_args().visual_embed_path, name) + try: + os.remove(path) + except FileNotFoundError: + print("Warning free afs {} failed because of FileNotFoundError!".format(name)) + return + except PermissionError as e: + print("Warning free afs {} failed due to PermissionError: {}".format(name, e)) + return + + def get_shm_name_data(uid): return str(uid) + "-data" def get_shm_name_embed(uid): return str(uid) + "-embed" + + +def afs_embed_exists(md5sum: str): + uid_int = int(md5sum, 16) + filename = f"{uid_int}-embed" + fullpath = os.path.join(get_env_start_args().visual_embed_path, filename) + return True if os.path.isfile(fullpath) else False diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index de552d80c..92e143c5e 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -16,7 +16,7 @@ from fastapi import Request from ..tokenizer import get_tokenizer from ..pd_io_struct import NodeRole -from ..embed_cache.utils import get_shm_name_data, create_shm +from ..embed_cache.utils import get_shm_name_data, create_shm, afs_embed_exists from ..multimodal_params import AudioItem, MultimodalParams, ImageItem from ..req_id_generator import ReqIDGenerator from .async_queue import AsyncQueue @@ -83,8 +83,9 @@ def __init__( self.enable_multimodal = enable_multimodal if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) - self.send_to_visual = context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") + if self.args.run_mode != "llm_only": + self.send_to_visual = context.socket(zmq.PUSH) + self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") self.shm_req_manager = ShmReqManager() @@ -101,7 +102,7 @@ def 
__init__( self.metric_client = MetricClient(metric_port) self.pd_mode: NodeRole = NodeRole(self.args.run_mode) - assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL] + assert self.pd_mode in [NodeRole.P, NodeRole.D, NodeRole.NORMAL, NodeRole.LLM_ONLY] self.id_gen = ReqIDGenerator() self.first_time_costs = MovingAverage() self.per_token_costs = MovingAverage() @@ -155,7 +156,10 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, data = img.read() # must after init_imageitem_extral_params token_num = self.tokenizer.get_image_token_length(img) - md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params))) + md5sum = "{}_{}".format( + hashlib.md5(data).hexdigest(), + hashlib.md5(pickle.dumps(img.extra_params, protocol=4)).hexdigest(), + ) md5sums.append(md5sum) tokens_nums.append(token_num) datas.append(data) @@ -164,7 +168,10 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params) data = audio.read() token_num = self.tokenizer.get_audio_token_length(audio) - md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(audio.extra_params))) + md5sum = "{}_{}".format( + hashlib.md5(data).hexdigest(), + hashlib.md5(pickle.dumps(audio.extra_params, protocol=4)).hexdigest(), + ) md5sums.append(md5sum) tokens_nums.append(token_num) datas.append(data) @@ -173,6 +180,47 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, await self._alloc_resource(items, md5sums, tokens_nums, datas) return + async def _wait_for_afs_embed(self, md5sum_hex: str, interval_sec: float = 0.01) -> None: + while not afs_embed_exists(md5sum_hex): + await asyncio.sleep(interval_sec) + + async def _get_image_embedding_from_afs(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams): + for img in multimodal_params.images: + self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) + data = img.read() + # must after init_imageitem_extral_params + token_num = self.tokenizer.get_image_token_length(img) + md5sum = "{}_{}".format( + hashlib.md5(data).hexdigest(), + hashlib.md5(pickle.dumps(img.extra_params, protocol=4)).hexdigest(), + ) + uid_int = int(md5sum, 16) + if not afs_embed_exists(md5sum): + await self._wait_for_afs_embed(md5sum) + img.uuid = uid_int + img.afs_embed = True + token_id_range_start = self.cache_client.root._check_and_set_new_id_range(token_num) + img.token_id = token_id_range_start + img.token_num = token_num + + for audio in multimodal_params.audios: + self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params) + data = audio.read() + token_num = self.tokenizer.get_audio_token_length(audio) + md5sum = "{}_{}".format( + hashlib.md5(data).hexdigest(), + hashlib.md5(pickle.dumps(audio.extra_params, protocol=4)).hexdigest(), + ) + if not afs_embed_exists(md5sum): + await self._wait_for_afs_embed(md5sum) + uid_int = int(md5sum, 16) + audio.uuid = uid_int + audio.afs_embed = True + token_id_range_start = self.cache_client.root._check_and_set_new_id_range(token_num) + audio.token_id = token_id_range_start + audio.token_num = token_num + return + async def _release_multimodal_resources(self, multimodal_params: MultimodalParams): # 只有 P 和 NORMAL 节点需要真的管理多模态资源 if self.pd_mode.is_P_or_NORMAL(): @@ -236,7 +284,7 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): # health 请求 request_id 为负数,直接返回 if 
is_health_req: return sampling_params.group_request_id - if self.pd_mode == NodeRole.NORMAL: + if self.pd_mode == NodeRole.NORMAL or self.pd_mode == NodeRole.LLM_ONLY: if not self.is_multinode_tp: group_request_id = self.id_gen.generate_id() else: @@ -334,7 +382,7 @@ async def generate( # 对于还没有形成正式请求对象管理的多模态资源,需要单独自己释放 # 已经放入到 req_id_to_out_inf 中的请求对象,由统一的回收循环 # 进行回收。 - if group_request_id not in self.req_id_to_out_inf: + if group_request_id not in self.req_id_to_out_inf and self.args.run_mode != "llm_only": await self._release_multimodal_resources(multimodal_params) await self.abort(group_request_id) raise e @@ -363,7 +411,10 @@ async def _encode( ), "too many multimodal items!" if multimodal_params.audios: assert self.args.enable_multimodal_audio, "audio multimodal not enabled" - await self._alloc_multimodal_resources(multimodal_params, sampling_params) + if self.args.run_mode == "llm_only": + await self._get_image_embedding_from_afs(multimodal_params, sampling_params) + else: + await self._alloc_multimodal_resources(multimodal_params, sampling_params) prompt_ids = self.tokenizer.encode( prompt, multimodal_params, add_special_tokens=sampling_params.add_special_tokens ) @@ -438,7 +489,7 @@ async def transfer_to_next_module( ): if self.pd_mode == NodeRole.P: - if self.enable_multimodal: + if self.enable_multimodal and self.args.run_mode != "llm_only": self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, @@ -458,8 +509,8 @@ async def transfer_to_next_module( ) return - if self.pd_mode == NodeRole.NORMAL: - if self.enable_multimodal: + if self.pd_mode == NodeRole.NORMAL or self.pd_mode == NodeRole.LLM_ONLY: + if self.enable_multimodal and self.args.run_mode != "llm_only": self.send_to_visual.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, @@ -608,7 +659,8 @@ async def recycle_resource_loop(self): for req in req_status.group_req_objs.shm_req_objs: await self.shm_req_manager.async_put_back_req_obj(req) await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem) - await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) + if self.args.run_mode != "llm_only": + await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) # 先保留这个关键得日志,用于方便定位重构中的问题。 if time.time() - pre_time_mark > 120: diff --git a/lightllm/server/httpserver_for_visual_only/manager.py b/lightllm/server/httpserver_for_visual_only/manager.py index 2b0a6bd4a..5d6fe315c 100644 --- a/lightllm/server/httpserver_for_visual_only/manager.py +++ b/lightllm/server/httpserver_for_visual_only/manager.py @@ -5,6 +5,7 @@ import uvloop import rpyc import time +import json import copy import hashlib import datetime @@ -16,7 +17,7 @@ from fastapi import Request from ..tokenizer import get_tokenizer from ..pd_io_struct import NodeRole -from ..embed_cache.utils import get_shm_name_data, create_shm +from ..embed_cache.utils import get_shm_name_data, create_shm, afs_embed_exists from ..multimodal_params import AudioItem, MultimodalParams, ImageItem from ..req_id_generator import ReqIDGenerator from lightllm.server.core.objs import Req, FinishStatus @@ -52,7 +53,6 @@ def __init__( self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) self.send_to_visual = context.socket(zmq.PUSH) self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") - print("self.send_to_visual") self.shm_req_manager = ShmReqManager() self.tokenizer = get_tokenizer(args.model_dir, 
args.tokenizer_mode, trust_remote_code=args.trust_remote_code) self.req_id_to_out_inf: Dict[int, ReqStatus] = {} # value type (out_str, metadata, finished, event) @@ -95,7 +95,6 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, # 这里的锁是为了 防止多个含有多张图片的请求 同时申请的record数量 大于cache_capacity,从而造成死锁的问题。 # 如果不加任何锁,假如请求1和请求2都有6张图片,而cache_capacity为10, # 那么如果某一时刻shm中存在请求1的5张图和请求2的5张图,将会资源竞争产生死锁。 - print("in alloc_multimodal_resoueces") async with self._resource_lock: items, md5sums, tokens_nums, datas = [], [], [], [] for img in multimodal_params.images: @@ -103,7 +102,10 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, data = img.read() # must after init_imageitem_extral_params token_num = self.tokenizer.get_image_token_length(img) - md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(img.extra_params))) + md5sum = "{}_{}".format( + hashlib.md5(data).hexdigest(), + hashlib.md5(pickle.dumps(img.extra_params, protocol=4)).hexdigest(), + ) md5sums.append(md5sum) tokens_nums.append(token_num) datas.append(data) @@ -112,7 +114,10 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params) data = audio.read() token_num = self.tokenizer.get_audio_token_length(audio) - md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(audio.extra_params))) + md5sum = "{}_{}".format( + hashlib.md5(data).hexdigest(), + hashlib.md5(pickle.dumps(audio.extra_params, protocol=4)).hexdigest(), + ) md5sums.append(md5sum) tokens_nums.append(token_num) datas.append(data) @@ -192,7 +197,6 @@ async def generate( request: Request, is_health_req: bool = False, ) -> Tuple[int, str, dict, FinishStatus]: - print(" in generate") start_time = time.time() request_headers = request.headers if request is not None else {} group_request_id = self.alloc_req_id(sampling_params, is_health_req) diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index e3c1d19d2..1171c7c32 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -24,6 +24,7 @@ def __init__(self, **kwargs): self.token_num = None # the audio length self.audio_length = None + self.afs_embed = False self._preload_data = None self.extra_params = {} @@ -77,6 +78,7 @@ def __init__(self, **kwargs): self.token_num = None self.image_w = 0 self.image_h = 0 + self.afs_embed = False self._preload_data = None self.extra_params = {} diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 414e3c74a..b713c08a3 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -15,6 +15,7 @@ class NodeRole(enum.Enum): D = "decode" NORMAL = "normal" PD_MASTER = "pd_master" + LLM_ONLY = "llm_only" def is_D(self): return self == NodeRole.D @@ -23,10 +24,10 @@ def is_P(self): return self == NodeRole.P def is_normal(self): - return self == NodeRole.NORMAL + return (self == NodeRole.NORMAL) or (self == NodeRole.LLM_ONLY) def is_P_or_NORMAL(self): - return (self == NodeRole.P) or (self == NodeRole.NORMAL) + return (self == NodeRole.P) or (self == NodeRole.NORMAL) or (self == NodeRole.LLM_ONLY) def is_P_or_D(self): return (self == NodeRole.P) or (self == NodeRole.D) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 5a9b7e191..eb7620f9b 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ 
-102,6 +102,17 @@ async def infer_imgs(self, images: List[ImageItem]): await asyncio.gather(*tasks) return + async def _send_to_next_module_or_release(self, group_req_indexes: GroupReqIndexes): + if self.visual_only: + for idx in group_req_indexes.shm_req_indexes: + shm_req = self.shm_req_manager.get_req_obj_by_index(idx) + shm_req.can_released_mark = True + shm_req.finish_status.set_status(1) + self.shm_req_manager.put_back_req_obj(shm_req) + logger.info(f"router release req id {shm_req.request_id}") + else: + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + async def loop_for_fwd(self): while True: if len(self.waiting_reqs) == 0: @@ -118,13 +129,12 @@ # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理 # 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了 # 需要一些一致的流程来保证不出现异步问题。 - if not self.visual_only: - self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + await self._send_to_next_module_or_release(group_req_indexes) continue multimodal_params = group_req_indexes.multimodal_params - img_uuids = [img.uuid for img in multimodal_params.images] + img_uuids = [img.uuid for img in multimodal_params.images if not img.afs_embed] ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids)) for img, ready in zip(multimodal_params.images, ready_image): @@ -135,23 +145,18 @@ await self.infer_imgs(images_need_infer) images_need_infer = [] for _group_req_indexes in processing_group_reqs: - if not self.visual_only: - self.send_to_next_module.send_pyobj( - _group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL - ) + await self._send_to_next_module_or_release(_group_req_indexes) processing_group_reqs = [] if len(images_need_infer) == 0: - if not self.visual_only: - self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + await self._send_to_next_module_or_release(group_req_indexes) else: processing_group_reqs.append(group_req_indexes) if len(images_need_infer) > 0: await self.infer_imgs(images_need_infer) for _group_req_indexes in processing_group_reqs: - if not self.visual_only: - self.send_to_next_module.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) + await self._send_to_next_module_or_release(_group_req_indexes) processing_group_reqs = [] images_need_infer = [] diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 0abf983b0..c219e3685 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -17,7 +17,14 @@ from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel -from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed +from lightllm.server.embed_cache.utils import ( + tensor2bytes, + read_shm, + create_shm, + create_afs, + get_shm_name_data, + get_shm_name_embed, +) from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end from lightllm.utils.dist_utils import init_vision_distributed_env @@ -31,6 +38,7 @@ def exposed_init_model(self, kvargs): import torch import torch.distributed as dist + self.args = get_env_start_args() self.vit_dp =
kvargs["vit_dp"] self.vit_tp = kvargs["vit_tp"] self.dp_rank_id = kvargs["dp_rank_id"] @@ -74,7 +82,6 @@ def exposed_init_model(self, kvargs): self.model = Gemma3VisionModel() else: raise Exception(f"can not support {self.model_type} now") - print("begin load visual model weight") self.model.load_model(weight_dir) self.model = self.model.cuda() except Exception as e: @@ -98,7 +105,6 @@ def exposed_encode(self, images: List[ImageItem]): images = obtain(images) all_img_embeds, uuids, valid_ids = self.forward(images) all_img_embeds = all_img_embeds.to(torch.device("cpu")) - print(f"all_img_embeds is {all_img_embeds}") if self.tp_rank_id == 0: ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) @@ -109,7 +115,10 @@ def exposed_encode(self, images: List[ImageItem]): uid = uuids[i] start, end = valid_ids[i] cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) - create_shm(get_shm_name_embed(uid), cur_embed_bytes) + if self.args.run_mode == "visual_only": + create_afs(get_shm_name_embed(uid), cur_embed_bytes) + else: + create_shm(get_shm_name_embed(uid), cur_embed_bytes) ids_to_set.append(uid) if ids_to_set: self.cache_client.root.set_items_embed(ids_to_set) From 9aaf63beb5c773a51dd3e83fa868cc702511fd18 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 20 Aug 2025 07:11:43 +0000 Subject: [PATCH 04/40] 0820 --- lightllm/server/api_http.py | 15 +++ lightllm/server/api_server.py | 4 +- lightllm/server/api_start.py | 99 +++++++++++++++++++ .../embed_cache/impl/naive_memory_cache.py | 2 +- lightllm/server/httpserver/manager.py | 18 ++-- .../httpserver_for_visual_only/manager.py | 2 +- 6 files changed, 131 insertions(+), 9 deletions(-) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 857a4cf37..ef28db373 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -100,6 +100,21 @@ def set_args(self, args): visual_port=args.visual_port, # metric_port=args.metric_port, ) + elif args.run_mode == "llm_only": + init_tokenizer(args) # for openai api + SamplingParams.load_generation_cfg(args.model_dir) + self.metric_client = MetricClient(args.metric_port) + self.httpserver_manager = HttpServerManager( + args, + router_port=args.router_port, + cache_port=None, + detokenization_pub_port=args.detokenization_pub_port, + visual_port=None, + enable_multimodal=args.enable_multimodal, + metric_port=args.metric_port, + ) + dp_size_in_node = max(1, args.dp // args.nnodes) # 兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 + self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node) else: init_tokenizer(args) # for openai api SamplingParams.load_generation_cfg(args.model_dir) diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 6ffe6b31d..51cd53da0 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -5,7 +5,7 @@ torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess parser = make_argument_parser() args = parser.parse_args() - from .api_start import pd_master_start, normal_or_p_d_start, visual_only_start, config_server_start + from .api_start import pd_master_start, normal_or_p_d_start, visual_only_start, config_server_start, llm_only_start if args.run_mode == "pd_master": pd_master_start(args) @@ -13,5 +13,7 @@ config_server_start(args) elif args.run_mode == "visual_only": visual_only_start(args) + elif args.run_mode == "llm_only": + llm_only_start(args) else: normal_or_p_d_start(args) diff --git 
a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 40bbbbf3c..e3adf18dd 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -362,6 +362,105 @@ def normal_or_p_d_start(args): return +def llm_only_start(args): + + check_and_set_args(args) + already_uesd_ports = [args.nccl_port, args.port] + + # 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能 + # 捕获到端口设置冲突的问题 + ports_locker = PortLocker(already_uesd_ports) + ports_locker.lock_port() + + node_world_size = args.tp // args.nnodes + can_use_ports = alloc_can_use_network_port(num=4 + node_world_size, used_nccl_ports=already_uesd_ports) + logger.info(f"alloced ports: {can_use_ports}") + ( + router_port, + detokenization_port, + detokenization_pub_port, + metric_port, + ) = can_use_ports[0:4] + can_use_ports = can_use_ports[4:] + + # 将申请好的端口放入args参数中 + args.router_port = router_port + args.detokenization_port = detokenization_port + args.detokenization_pub_port = detokenization_pub_port + args.metric_port = metric_port + + # 申请在 p d 分离模式下,会用的端口 + args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] + # p d 分离模式下用于标识节点的id + args.pd_node_id = uuid.uuid4().int + # p 节点用来建立torch kv 传输分布组的可用端口范围 + args.pd_p_allowed_port_min = 20000 + args.pd_p_allowed_port_max = 30000 + + # p d 分离模式下,decode节点的调度间隙是0 + if args.run_mode == "decode": + args.router_max_wait_tokens = 0 + + send_and_receive_node_ip(args) # 多机用于收发node ip + set_env_start_args(args) + logger.info(f"all start args:{args}") + + ports_locker.release_port() + + process_manager.start_submodule_processes( + start_funcs=[ + start_metric_manager, + ], + start_args=[(metric_port, args)], + ) + + process_manager.start_submodule_processes( + start_funcs=[start_router_process, start_detokenization_process], + start_args=[ + (args, router_port, detokenization_port, metric_port), + (args, detokenization_port, detokenization_pub_port), + ], + ) + + # 启动 gunicorn + command = [ + "gunicorn", + "--workers", + f"{args.httpserver_workers}", + "--worker-class", + "uvicorn.workers.UvicornWorker", + "--bind", + f"{args.host}:{args.port}", + "--log-level", + "info", + "--access-logfile", + "-", + "--error-logfile", + "-", + "lightllm.server.api_http:app", + "--timeout", + f"{get_lightllm_gunicorn_time_out_seconds()}", + "--keep-alive", + f"{get_lightllm_gunicorn_keep_alive()}", + ] + + # 启动子进程 + http_server_process = subprocess.Popen(command) + + if "s3://" in args.model_dir: + from lightllm.utils.petrel_helper import s3_model_clear + + s3_model_clear(args.model_dir) + + if args.health_monitor: + from lightllm.server.health_monitor.manager import start_health_check_process + + process_manager.start_submodule_processes(start_funcs=[start_health_check_process], start_args=[(args,)]) + setup_signal_handlers(http_server_process, process_manager) + http_server_process.wait() + return + + def pd_master_start(args): set_unique_server_name(args) if args.run_mode != "pd_master": diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index 323fdf420..d263b3967 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -65,7 +65,7 @@ def _check_and_set_new_id_range(self, alloced_token_num): except BaseException as e: logger.exception(str(e)) time.sleep(3) - return self.token_id_range_start + return def _clear(self, free_max_count: int): deleted = 0 diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 
92e143c5e..506e48f19 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -81,11 +81,13 @@ def __init__( ) self.enable_multimodal = enable_multimodal - if self.enable_multimodal: + if self.enable_multimodal and self.args.run_mode != "llm_only": self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) - if self.args.run_mode != "llm_only": - self.send_to_visual = context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") + self.send_to_visual = context.socket(zmq.PUSH) + self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") + + self.token_id_range_start = 100000000 + self.token_id_range_end = 2 ** 63 - 1 self.shm_req_manager = ShmReqManager() @@ -115,6 +117,10 @@ def __init__( self.latest_success_infer_time_mark.set_value(int(time.time())) return + async def _check_and_set_new_id_range(self, token_num): + assert self.token_id_range_start + token_num < self.token_id_range_end + self.token_id_range_start += token_num + async def _alloc_resource(self, items, md5sums, token_nums, datas): while True: @@ -199,7 +205,7 @@ async def _get_image_embedding_from_afs(self, multimodal_params: MultimodalParam await self._wait_for_afs_embed(md5sum) img.uuid = uid_int img.afs_embed = True - token_id_range_start = self.cache_client.root._check_and_set_new_id_range(token_num) + token_id_range_start = self.token_id_range_start img.token_id = token_id_range_start img.token_num = token_num @@ -216,7 +222,7 @@ async def _get_image_embedding_from_afs(self, multimodal_params: MultimodalParam uid_int = int(md5sum, 16) audio.uuid = uid_int audio.afs_embed = True - token_id_range_start = self.cache_client.root._check_and_set_new_id_range(token_num) + token_id_range_start = self.token_id_range_start audio.token_id = token_id_range_start audio.token_num = token_num return diff --git a/lightllm/server/httpserver_for_visual_only/manager.py b/lightllm/server/httpserver_for_visual_only/manager.py index 5d6fe315c..e67071cab 100644 --- a/lightllm/server/httpserver_for_visual_only/manager.py +++ b/lightllm/server/httpserver_for_visual_only/manager.py @@ -238,7 +238,7 @@ async def generate( req_obj.init( group_request_id + i, # 随便写的,后面改掉 - [24, 67], + [21456], sampling_params, self.tokenizer, chunked_prefill_size=self.args.chunked_prefill_size, From 66d0c105179fe33fc0a22665d0ee9304ee57d7db Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 20 Aug 2025 08:06:33 +0000 Subject: [PATCH 05/40] 0820-del-metric --- lightllm/server/api_http.py | 4 +- lightllm/server/api_start.py | 7 + .../httpserver_for_visual_only/manager.py | 125 +----------------- 3 files changed, 13 insertions(+), 123 deletions(-) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index ef28db373..4fd4edb3f 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -93,12 +93,12 @@ def set_args(self, args): metric_port=args.metric_port, ) elif args.run_mode == "visual_only": - # self.metric_client = MetricClient(args.metric_port) + self.metric_client = MetricClient(args.metric_port) self.httpserver_manager = HttpServerManagerForVisualOnly( args, cache_port=args.cache_port, visual_port=args.visual_port, - # metric_port=args.metric_port, + metric_port=args.metric_port, ) elif args.run_mode == "llm_only": init_tokenizer(args) # for openai api diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index e3adf18dd..3ae0375c5 100644 --- a/lightllm/server/api_start.py +++ 
b/lightllm/server/api_start.py @@ -560,6 +560,13 @@ def visual_only_start(args): set_env_start_args(args) + process_manager.start_submodule_processes( + start_funcs=[ + start_metric_manager, + ], + start_args=[(metric_port, args)], + ) + from .visualserver.manager import start_visual_process process_manager.start_submodule_processes( diff --git a/lightllm/server/httpserver_for_visual_only/manager.py b/lightllm/server/httpserver_for_visual_only/manager.py index e67071cab..beb8bbb18 100644 --- a/lightllm/server/httpserver_for_visual_only/manager.py +++ b/lightllm/server/httpserver_for_visual_only/manager.py @@ -38,13 +38,7 @@ class HttpServerManagerForVisualOnly: - def __init__( - self, - args, - cache_port, - visual_port, - # metric_port - ): + def __init__(self, args, cache_port, visual_port, metric_port): self.args = args context = zmq.asyncio.Context(2) @@ -58,7 +52,7 @@ def __init__( self.req_id_to_out_inf: Dict[int, ReqStatus] = {} # value type (out_str, metadata, finished, event) self.max_req_total_len = args.max_req_total_len self.id_gen = ReqIDGenerator() - # self.metric_client = MetricClient(metric_port) + self.metric_client = MetricClient(metric_port) # 有的模型的vocab size 读取tokenizer和config.json中不一致 self.vocab_size = max(get_vocab_size(args.model_dir), self.tokenizer.vocab_size) return @@ -202,7 +196,6 @@ async def generate( group_request_id = self.alloc_req_id(sampling_params, is_health_req) try: - original_multimodal_params = None await multimodal_params.verify_and_preload(request) # 记录请求到达的相关信息 @@ -212,14 +205,6 @@ async def generate( ), "too many multimodal items!" await self._alloc_multimodal_resources(multimodal_params, sampling_params) - # 监控 - # if group_request_id > 0: - # self.metric_client.counter_inc("lightllm_request_count") - # self.metric_client.histogram_observe("lightllm_request_input_length", prompt_tokens) - # self.metric_client.histogram_observe("lightllm_request_max_new_tokens", - # sampling_params.max_new_tokens) - # prompt_ids = await self._check_and_repair_length(prompt_ids, sampling_params) - # 申请资源并存储 alloced_req_indexes = [] while len(alloced_req_indexes) < sampling_params.n: @@ -248,17 +233,7 @@ async def generate( req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time) self.req_id_to_out_inf[group_request_id] = req_status - await self.transfer_to_visual(sampling_params, original_multimodal_params, req_status.group_req_objs) - - # results_generator = self._wait_to_token_package( - # start_time, - # group_request_id, - # sampling_params, - # req_status, - # request, - # ) - # async for sub_req_id, request_output, metadata, finish_status in results_generator: - # yield sub_req_id, request_output, metadata, finish_status + await self.transfer_to_visual(req_status.group_req_objs) except Exception as e: logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") @@ -287,8 +262,6 @@ async def _log_req_header(self, request_headers, group_request_id: int): async def transfer_to_visual( self, - sampling_params: SamplingParams, - original_multimodal_params: MultimodalParams, group_req_objs: Optional[GroupReqObjs] = None, ): await self.send_to_visual.send_pyobj( @@ -297,96 +270,6 @@ async def transfer_to_visual( ) return - async def _wait_to_token_package( - self, - start_time, - group_request_id: int, - sampling_params: SamplingParams, - req_status: "ReqStatus", - request: Request, - ): - - event = req_status.event - unfinished_count = sampling_params.best_of - out_token_counter = 0 - first_token_cost_ms = 
sys.float_info.max - is_first_token = True - - while True: - try: - await asyncio.wait_for(event.wait(), timeout=5) - except asyncio.TimeoutError: - pass - - if not self.disable_abort and request is not None and await request.is_disconnected(): - await self.abort(group_request_id) - raise Exception(f"req_id {group_request_id} disconnected") - - async with req_status.lock: - event.clear() - if len(req_status.out_token_info_list) == 0: - continue - - for sub_req_id, out_str, metadata, finish_status in req_status.out_token_info_list: - # pd master 节点需要这个做统计信息, 所以放在元数据中返回给 pd master 节点 - # p 节点返回 prompt_ids 信息,防止 d 节点重新 encode - - prompt_cache_len = metadata.pop("prompt_cache_len", 0) - if is_first_token: - first_token_cost_ms = (time.time() - start_time) * 1000 - is_first_token = False - self.first_time_costs.add(first_token_cost_ms) - - out_token_counter += 1 - - # update inference timemark - self.latest_success_infer_time_mark.set_value(int(time.time())) - - yield sub_req_id, out_str, metadata, finish_status - # 如果有子请求完成,就更新计数 - if finish_status.is_finished(): - unfinished_count -= 1 - - if unfinished_count == 0: - total_cost_time_ms = (time.time() - start_time) * 1000 - mean_per_token_cost_time_ms = (total_cost_time_ms - first_token_cost_ms) / out_token_counter - self.per_token_costs.add(mean_per_token_cost_time_ms) - x_request_id = request.headers.get("X-Request-Id", "") if request is not None else "" - x_session_id = request.headers.get("X-Session-Id", "") if request is not None else "" - - mtp_avg_token_per_step = out_token_counter / max( - (out_token_counter - metadata["mtp_accepted_token_num"]), 1 - ) - format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S") - logger.info( - f"X-Request-Id:{x_request_id} " - f"X-Session-Id:{x_session_id} start_time:{format_start_time} " - f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms " - f"total_cost_time:{total_cost_time_ms}ms,out_token_counter:{out_token_counter} " - f"mean_per_token_cost_time: {mean_per_token_cost_time_ms}ms " - f"prompt_cache_len:{prompt_cache_len} " - f"mtp_avg_token_per_step:{mtp_avg_token_per_step} " - ) - if group_request_id < 0: - # health 探测请求,不记录日志和监控 - return - self.metric_client.histogram_observe("lightllm_cache_length", prompt_cache_len) - self.metric_client.histogram_observe( - "lightllm_request_inference_duration", total_cost_time_ms / 1000.0 - ) - self.metric_client.histogram_observe( - "lightllm_request_mean_time_per_token_duration", mean_per_token_cost_time_ms / 1000.0 - ) - self.metric_client.histogram_observe( - "lightllm_request_first_token_duration", first_token_cost_ms / 1000.0 - ) - self.metric_client.histogram_observe("lightllm_request_generated_tokens", out_token_counter) - self.metric_client.counter_inc("lightllm_request_success") - - return - req_status.out_token_info_list.clear() - return - async def abort(self, group_req_id: int): req_status: ReqStatus = self.req_id_to_out_inf.get(group_req_id, None) if req_status is None: @@ -456,7 +339,7 @@ def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], sta shm_req_objs=req_objs, time_mark=start_time, ) - self.out_token_info_list = [] + self.finished = False def can_release(self): for req in self.group_req_objs.shm_req_objs: From 1f46fd288e737331a446c6b297e1aeb3d63fd373 Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 26 Aug 2025 20:24:11 +0800 Subject: [PATCH 06/40] add redis server for vit/llm disaggaggregation --- lightllm/server/api_cli.py | 12 + 
lightllm/server/api_start.py | 25 +- .../impl/memory_cache_with_redis.py | 48 +++ .../embed_cache/impl/naive_memory_cache.py | 5 +- lightllm/server/embed_cache/utils.py | 314 +++++++++++++++++- lightllm/server/httpserver/manager.py | 45 +-- lightllm/utils/redis_utils.py | 60 ++++ 7 files changed, 450 insertions(+), 59 deletions(-) create mode 100644 lightllm/server/embed_cache/impl/memory_cache_with_redis.py create mode 100644 lightllm/utils/redis_utils.py diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 825c6cbb3..97b3cef7f 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -478,4 +478,16 @@ def make_argument_parser() -> argparse.ArgumentParser: default=0.03, help="""The interval of the schedule time, default is 30ms.""", ) + # redis for vit llm disaggregation + parser.add_argument( + "--redis_port", + type=int, + default=6379, + help="The port number for the redis service in config_server mode.", + ) + parser.add_argument( + "--start_redis", + action="store_true", + help="Whether to start the redis service in config_server mode.", + ) return parser diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 3ae0375c5..f8601a237 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -15,16 +15,19 @@ from .router.manager import start_router_process from lightllm.utils.process_check import is_process_active from lightllm.utils.multinode_utils import send_and_receive_node_ip +from lightllm.utils.redis_utils import start_redis_service logger = init_logger(__name__) -def setup_signal_handlers(http_server_process, process_manager): +def setup_signal_handlers(http_server_process, process_manager, redis_process=None): def signal_handler(sig, frame): if sig == signal.SIGINT: logger.info("Received SIGINT (Ctrl+C), forcing immediate exit...") if http_server_process: kill_recursive(http_server_process) + if redis_process and redis_process.poll() is None: + redis_process.terminate() process_manager.terminate_all_processes() logger.info("All processes have been forcefully terminated.") @@ -47,6 +50,19 @@ def signal_handler(sig, frame): logger.warning("HTTP server did not exit in time, killing it...") kill_recursive(http_server_process) + # 优雅关闭Redis + if redis_process and redis_process.poll() is None: + redis_process.send_signal(signal.SIGTERM) + start_time = time.time() + while (time.time() - start_time) < 10: + if redis_process.poll() is not None: + logger.info("Redis service has exited gracefully") + break + time.sleep(0.5) + else: + logger.warning("Redis service did not exit in time, killing it...") + redis_process.terminate() + process_manager.terminate_all_processes() logger.info("All processes have been terminated gracefully.") sys.exit(0) @@ -56,6 +72,8 @@ def signal_handler(sig, frame): logger.info(f"start process pid {os.getpid()}") logger.info(f"http server pid {http_server_process.pid}") + if redis_process: + logger.info(f"redis service pid {redis_process.pid}") return @@ -639,6 +657,9 @@ def config_server_start(args): if args.run_mode != "config_server": return + # 启动Redis服务(如果指定) + redis_process = start_redis_service(args) + logger.info(f"all start args:{args}") set_env_start_args(args) @@ -666,5 +687,5 @@ def config_server_start(args): ] http_server_process = subprocess.Popen(command) - setup_signal_handlers(http_server_process, process_manager) + setup_signal_handlers(http_server_process, process_manager, redis_process) http_server_process.wait() diff --git 
a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py new file mode 100644 index 000000000..48b4b5680 --- /dev/null +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -0,0 +1,48 @@ +import uuid +import threading +import dataclasses +import requests +from typing import Union, Optional +import torch +import time +from collections import deque +import multiprocessing.shared_memory as shm +from ..utils import get_shm_name_data, get_shm_name_embed, free_shm, free_afs, EmbedRefCountRedis +from .naive_memory_cache import Record, InMemoryCache +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class MemoryCacheWithRedis(InMemoryCache): + def __init__(self, args) -> None: + super().__init__(args) + redis_url = f"redis://{args.config_server_host}:{args.redis_port}" + self.redis_cache = EmbedRefCountRedis( + redis_url=redis_url, + capacity=args.cache_capacity, + evict_fraction=args.evict_fraction, + image_embed_dir=args.image_embed_dir, + ) + # 这里之所以把cache * 2是因为,在分离模式下,cache 服务只是为了更新redis状态,以及维护图片cache的 token_id + # 便于 dynamic prompt cache 的使用。所以要把cache_capacity * 2,保障其保留的图片cache > redis 服务维护的 + # 硬盘里的图片image embed 数量。 + self.cache_capacity = args.cache_capacity * 2 + + def release(self, ids: list[int]) -> None: + with self.lock: + for id_ in ids: + self._records[id_].ref -= 1 + self.redis_cache.decr(id_) + + def set_items_data(self, ids: list[int]) -> None: + pass + + def get_items_data(self, ids: list[int]) -> list[Optional[bool]]: + return [self._records.get(id_).data if id_ in self._records else False for id_ in ids] + + def set_items_embed(self, ids: list[int]) -> None: + pass + + def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: + pass diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index d263b3967..8b0f83f02 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -77,10 +77,7 @@ def _clear(self, free_max_count: int): if record.data: free_shm(get_shm_name_data(id)) if record.embed: - if self.args.run_mode == "visual_only": - free_afs(get_shm_name_embed(id)) - else: - free_shm(get_shm_name_embed(id)) + free_shm(get_shm_name_embed(id)) del self._md5_to_record[record.md5sum] del self._records[id] self.occupied -= 1 diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index b1227ec6b..9d9a0361b 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -1,7 +1,9 @@ import os import time import torch +import redis import numpy as np +from typing import List, Tuple from io import BytesIO from pathlib import Path import multiprocessing.shared_memory as shm @@ -70,18 +72,6 @@ def free_shm(name): shared_memory.unlink() -def free_afs(name): - path = os.path.join(get_env_start_args().visual_embed_path, name) - try: - os.remove(path) - except FileNotFoundError: - print("Warning free afs {} failed because of FileNotFoundError!".format(name)) - return - except PermissionError as e: - print("Warning free afs {} failed due to PermissionError: {}".format(name, e)) - return - - def get_shm_name_data(uid): return str(uid) + "-data" @@ -95,3 +85,303 @@ def afs_embed_exists(md5sum: str): filename = f"{uid_int}-embed" fullpath = os.path.join(get_env_start_args().visual_embed_path, filename) return True if os.path.isfile(fullpath) else False + + +""" 
+Importable Redis-backed MD5 refcount with LRU eviction. + +Public API: + from md5_refcount import EmbedRefCountRedis + + cache = EmbedRefCountRedis( + redis_url="redis://localhost:6379/0", + capacity=10000, + evict_fraction=0.2 + ) + + # Insert a new md5 with default ref_count=0 + success, evicted_list = cache.insert(md5) + + # Query if exists and increment ref_count if found + exists = cache.query_and_incre(md5) + + # Decrement ref_count + rc, deleted = cache.decr(md5) + + s = cache.stats() +""" + + +class EmbedRefCountRedis: + def __init__( + self, + redis_url: str = "redis://localhost:6379/0", + capacity: int = 50_000, + evict_fraction: float = 0.2, + key_prefix: str = "md5:", + image_embed_dir: str = None, + path_ext: str = ".embed", + **redis_kwargs, + ) -> None: + """ + - capacity: max count of md5 entries allowed in Redis + - evict_fraction: fraction to evict when inserting a NEW md5 and at capacity + - image_embed_dir: base directory for image embed files (e.g., "/afs/embeds") + - path_ext: file extension for embed files (default: ".embed") + """ + if not (0.0 <= evict_fraction <= 1.0): + raise ValueError("evict_fraction must be 0..1") + if capacity < 1: + raise ValueError("capacity must be >=1") + + self.capacity = int(capacity) + self.evict_fraction = float(evict_fraction) + self.zset_key = f"{key_prefix}lru" + self.ref_prefix = f"{key_prefix}rc:" + self.lock_key = f"{key_prefix}evict:lock" + self.image_embed_dir = image_embed_dir + self.path_ext = path_ext + + self.r = redis.Redis.from_url(redis_url, decode_responses=True, **redis_kwargs) + + # Register Lua scripts + self._insert_script = self.r.register_script(self._INSERT_LUA) + self._query_incre_script = self.r.register_script(self._QUERY_INCRE_LUA) + self._decr_script = self.r.register_script(self._DECR_LUA) + self._evict_and_insert_script = self.r.register_script(self._EVICT_AND_INSERT_LUA) + + def insert(self, md5: str) -> Tuple[bool, List[str]]: + """Insert a new md5 with default ref_count=0. May trigger LRU eviction.""" + # 等待任何正在进行的逐出操作 + self._wait_if_eviction() + + res = self._insert_script( + keys=[self.zset_key, self.ref_prefix], + args=[md5, self.capacity, self.evict_fraction], + ) + + if res[0] == 0: # No eviction needed + return True, [] + + # Need eviction - use atomic eviction script + try: + if self._try_acquire_lock(): + try: + # 原子执行逐出和插入 + evict_res = self._evict_and_insert_script( + keys=[self.zset_key, self.ref_prefix], + args=[md5, self.capacity, self.evict_fraction], + ) + success = bool(evict_res[0]) + victims = evict_res[1:] if len(evict_res) > 1 else [] + + # 删除被逐出md5对应的AFS文件 + if victims and self.image_embed_dir: + self._delete_afs_files(victims) + + return success, victims + finally: + self._release_lock() + else: + # 等待锁释放后重试 + time.sleep(0.1) + return self.insert(md5) + except Exception as e: + self._release_lock() + raise e + + def query_and_incre(self, md5: str) -> bool: + """Query if md5 exists and increment ref_count if found.""" + self._wait_if_eviction() + + res = self._query_incre_script( + keys=[self.zset_key, self.ref_prefix], + args=[md5], + ) + return bool(res[0]) + + def decr(self, md5: str) -> Tuple[int, bool]: + """Decrement ref_count for md5. 
Returns (ref_count, deleted).""" + self._wait_if_eviction() + + res = self._decr_script( + keys=[self.zset_key, self.ref_prefix], + args=[md5], + ) + if res[0] == -1: + raise KeyError("md5 not found") + return int(res[0]), bool(res[1]) + + def stats(self) -> dict: + self._wait_if_eviction() + + size = self.r.zcard(self.zset_key) + return { + "items": size, + "capacity": self.capacity, + "evict_fraction": self.evict_fraction, + } + + def _wait_if_eviction(self) -> None: + max_wait = 30 + start_time = time.time() + + while self.r.exists(self.lock_key): + if time.time() - start_time > max_wait: + raise TimeoutError("Eviction operation timeout, waited too long") + time.sleep(0.01) # 短暂等待 + + def _try_acquire_lock(self) -> bool: + return bool(self.r.set(self.lock_key, "1", nx=True, ex=30)) + + def _release_lock(self) -> None: + try: + self.r.delete(self.lock_key) + except Exception: + pass + + def _md5_to_afs_path(self, md5: str) -> str: + """Convert md5 to AFS file path.""" + if not self.image_embed_dir: + return None + filename = md5 + self.path_ext + return filename + + def _delete_afs_files(self, victims: List[str]) -> None: + """Delete AFS files for evicted md5s.""" + if not self.image_embed_dir: + return + + for md5 in victims: + try: + file_path = self._md5_to_afs_path(md5) + if file_path and os.path.exists(file_path): + os.remove(file_path) + print(f"Deleted AFS file: {file_path}") + except Exception as e: + print(f"Warning: Failed to delete AFS file for {md5}: {e}") + + # ---------------- Lua scripts ---------------- + _INSERT_LUA = r""" +-- KEYS[1] = zset key, KEYS[2] = ref_prefix +-- ARGV[1] = md5, ARGV[2] = capacity, ARGV[3] = evict_fraction +local zset = KEYS[1] +local ref_prefix = KEYS[2] +local md5 = ARGV[1] +local capacity = tonumber(ARGV[2]) + +local ref_key = ref_prefix .. md5 +if redis.call('GET', ref_key) then + return {0} -- Already exists +end + +local size = redis.call('ZCARD', zset) +if size < capacity then + -- Insert with ref_count=0 + redis.call('SET', ref_key, 0) + local now = redis.call('TIME')[1] * 1000 + redis.call('ZADD', zset, now, md5) + return {0} -- Success, no eviction +end + +return {1} -- Need eviction +""" + + _QUERY_INCRE_LUA = r""" +-- KEYS[1] = zset key, KEYS[2] = ref_prefix +-- ARGV[1] = md5 +local zset = KEYS[1] +local ref_prefix = KEYS[2] +local md5 = ARGV[1] + +local ref_key = ref_prefix .. md5 +local val = redis.call('GET', ref_key) + +if not val then + return {0} -- Not found +end + +-- Found, increment ref_count and update LRU +local rc = tonumber(val) + 1 +redis.call('SET', ref_key, rc) +local now = redis.call('TIME')[1] * 1000 +redis.call('ZADD', zset, now, md5) +return {1} -- Found and incremented +""" + + _DECR_LUA = r""" +-- KEYS[1] = zset key, KEYS[2] = ref_prefix +-- ARGV[1] = md5 +local zset = KEYS[1] +local ref_prefix = KEYS[2] +local md5 = ARGV[1] + +local ref_key = ref_prefix .. 
md5 +local val = redis.call('GET', ref_key) + +if not val then + return {-1, 0} -- Not found +end + +local rc = tonumber(val) - 1 +if rc <= 0 then + redis.call('DEL', ref_key) + redis.call('ZREM', zset, md5) + return {0, 1} -- Deleted +else + redis.call('SET', ref_key, rc) + local now = redis.call('TIME')[1] * 1000 + redis.call('ZADD', zset, now, md5) + return {rc, 0} -- Updated +end +""" + + _EVICT_AND_INSERT_LUA = r""" +-- KEYS[1] = zset key, KEYS[2] = ref_prefix +-- ARGV[1] = new_md5, ARGV[2] = capacity, ARGV[3] = evict_fraction +local zset = KEYS[1] +local ref_prefix = KEYS[2] +local new_md5 = ARGV[1] +local capacity = tonumber(ARGV[2]) +local evict_fraction = tonumber(ARGV[3]) + +-- 计算需要逐出的数量 +local need = math.max(1, math.floor(capacity * evict_fraction + 0.5)) +local victims = {} + +-- 获取所有键并按LRU排序 +local all_keys = redis.call('ZRANGE', zset, 0, -1, 'WITHSCORES') +local i = 1 + +-- 查找引用计数为0的键作为逐出候选 +while #victims < need and i <= #all_keys do + local md5 = all_keys[i] + local ref_key = ref_prefix .. md5 + local rc = redis.call('GET', ref_key) + + if rc and tonumber(rc) <= 0 then + table.insert(victims, md5) + end + i = i + 2 -- 跳过分数 +end + +-- 如果找到足够的候选,执行逐出 +if #victims >= need then + -- 删除受害者 + for _, v in ipairs(victims) do + local ref_key = ref_prefix .. v + redis.call('DEL', ref_key) + redis.call('ZREM', zset, v) + end + + -- 插入新的md5 + local ref_key = ref_prefix .. new_md5 + redis.call('SET', ref_key, 0) + local now = redis.call('TIME')[1] * 1000 + redis.call('ZADD', zset, now, new_md5) + + return {1, table.unpack(victims)} -- success + victims +else + return {0} -- 逐出失败,没有足够的候选 +end +""" diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 506e48f19..b3b65987a 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -137,6 +137,10 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas): item.token_num = rec["token_num"] uid_list.append(rec["id"]) + # If enable the vit/audio-llm disaggregation, no need to cache the data in the memory of the server + if self.args.run_mode == "llm_only": + return + ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) update_data_ids = [] @@ -186,47 +190,6 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, await self._alloc_resource(items, md5sums, tokens_nums, datas) return - async def _wait_for_afs_embed(self, md5sum_hex: str, interval_sec: float = 0.01) -> None: - while not afs_embed_exists(md5sum_hex): - await asyncio.sleep(interval_sec) - - async def _get_image_embedding_from_afs(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams): - for img in multimodal_params.images: - self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) - data = img.read() - # must after init_imageitem_extral_params - token_num = self.tokenizer.get_image_token_length(img) - md5sum = "{}_{}".format( - hashlib.md5(data).hexdigest(), - hashlib.md5(pickle.dumps(img.extra_params, protocol=4)).hexdigest(), - ) - uid_int = int(md5sum, 16) - if not afs_embed_exists(md5sum): - await self._wait_for_afs_embed(md5sum) - img.uuid = uid_int - img.afs_embed = True - token_id_range_start = self.token_id_range_start - img.token_id = token_id_range_start - img.token_num = token_num - - for audio in multimodal_params.audios: - self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params) - data = audio.read() - token_num = 
self.tokenizer.get_audio_token_length(audio) - md5sum = "{}_{}".format( - hashlib.md5(data).hexdigest(), - hashlib.md5(pickle.dumps(audio.extra_params, protocol=4)).hexdigest(), - ) - if not afs_embed_exists(md5sum): - await self._wait_for_afs_embed(md5sum) - uid_int = int(md5sum, 16) - audio.uuid = uid_int - audio.afs_embed = True - token_id_range_start = self.token_id_range_start - audio.token_id = token_id_range_start - audio.token_num = token_num - return - async def _release_multimodal_resources(self, multimodal_params: MultimodalParams): # 只有 P 和 NORMAL 节点需要真的管理多模态资源 if self.pd_mode.is_P_or_NORMAL(): diff --git a/lightllm/utils/redis_utils.py b/lightllm/utils/redis_utils.py new file mode 100644 index 000000000..bd4d87ca3 --- /dev/null +++ b/lightllm/utils/redis_utils.py @@ -0,0 +1,60 @@ +import subprocess +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +def start_redis_service(args): + """launch redis service""" + if not hasattr(args, "start_redis") or not args.start_redis: + return None + + try: + redis_port = args.redis_port + + redis_command = [ + "redis-server", + "--port", + str(redis_port), + "--bind", + f"{args.config_server_host}", + "--daemonize", + "no", + "--logfile", + "-", + "--loglevel", + "notice", + ] + + logger.info(f"Starting Redis service on port {redis_port}") + redis_process = subprocess.Popen(redis_command) + + import redis + import time + + max_wait = 10 + start_time = time.time() + + while time.time() - start_time < max_wait: + try: + r = redis.Redis(host=args.config_server_host, port=redis_port, socket_connect_timeout=1) + r.ping() + logger.info(f"Redis service started successfully on port {redis_port}") + del r + break + except Exception: + time.sleep(0.5) + if redis_process.poll() is not None: + logger.error("Redis service failed to start") + return None + else: + logger.error("Redis service startup timeout") + if redis_process.poll() is None: + redis_process.terminate() + return None + + return redis_process + + except Exception as e: + logger.error(f"Failed to start Redis service: {e}") + return None From 3561a1792a9c6984abb8b762b375740eb9c4d126 Mon Sep 17 00:00:00 2001 From: baishihao Date: Tue, 26 Aug 2025 20:29:04 +0800 Subject: [PATCH 07/40] remove unused code of http manager --- lightllm/server/embed_cache/utils.py | 7 ------- lightllm/server/httpserver/manager.py | 8 ++------ 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 9d9a0361b..04eed2ecb 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -80,13 +80,6 @@ def get_shm_name_embed(uid): return str(uid) + "-embed" -def afs_embed_exists(md5sum: str): - uid_int = int(md5sum, 16) - filename = f"{uid_int}-embed" - fullpath = os.path.join(get_env_start_args().visual_embed_path, filename) - return True if os.path.isfile(fullpath) else False - - """ Importable Redis-backed MD5 refcount with LRU eviction. 
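Note: the commits above add EmbedRefCountRedis but only MemoryCacheWithRedis.release calls into it so far. The snippet below is a minimal usage sketch based solely on the class docstring and that call site; it assumes a reachable Redis instance, and the URL, capacity, directory and md5 value are placeholders rather than values taken from these patches.
from lightllm.server.embed_cache.utils import EmbedRefCountRedis

# Assumed setup: a redis-server is already running (e.g. launched with --start_redis in config_server mode).
cache = EmbedRefCountRedis(
    redis_url="redis://127.0.0.1:6379/0",      # placeholder; the patch builds this from config_server_host/redis_port
    capacity=10_000,
    evict_fraction=0.2,
    image_embed_dir="/path/to/image_embeds",   # placeholder embed directory
)

md5 = "d41d8cd98f00b204e9800998ecf8427e"       # example digest, not a real cache entry

# Producer (visual_only) side: after persisting the embed file, register the md5 with ref_count=0.
# insert() may evict cold entries whose ref_count is 0 and returns their md5s.
inserted, evicted = cache.insert(md5)

# Consumer (llm_only) side: a hit increments the refcount and refreshes the LRU position.
if cache.query_and_incre(md5):
    # ... build the request from the stored embedding ...
    # When the request is released, drop the reference (mirrors MemoryCacheWithRedis.release).
    ref_count, deleted = cache.decr(md5)

print(cache.stats())  # {"items": ..., "capacity": ..., "evict_fraction": ...}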
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index b3b65987a..162258ec1 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -16,7 +16,7 @@ from fastapi import Request from ..tokenizer import get_tokenizer from ..pd_io_struct import NodeRole -from ..embed_cache.utils import get_shm_name_data, create_shm, afs_embed_exists +from ..embed_cache.utils import get_shm_name_data, create_shm from ..multimodal_params import AudioItem, MultimodalParams, ImageItem from ..req_id_generator import ReqIDGenerator from .async_queue import AsyncQueue @@ -117,10 +117,6 @@ def __init__( self.latest_success_infer_time_mark.set_value(int(time.time())) return - async def _check_and_set_new_id_range(self, token_num): - assert self.token_id_range_start + token_num < self.token_id_range_end - self.token_id_range_start += token_num - async def _alloc_resource(self, items, md5sums, token_nums, datas): while True: @@ -253,7 +249,7 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): # health 请求 request_id 为负数,直接返回 if is_health_req: return sampling_params.group_request_id - if self.pd_mode == NodeRole.NORMAL or self.pd_mode == NodeRole.LLM_ONLY: + if self.pd_mode.is_normal(): if not self.is_multinode_tp: group_request_id = self.id_gen.generate_id() else: From 0ef48cb62e1d5f5bd57abaf675bd0ffb24d45d32 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Tue, 26 Aug 2025 12:33:37 +0000 Subject: [PATCH 08/40] [0826]modify visual server --- lightllm/models/vit/model.py | 2 +- lightllm/server/api_http.py | 16 +- lightllm/server/api_lightllm.py | 9 +- lightllm/server/api_start.py | 46 +-- .../server/core/objs/io_objs/group_req.py | 6 + lightllm/server/core/objs/req.py | 28 +- lightllm/server/embed_cache/utils.py | 5 +- .../httpserver_for_visual_only/__init__.py | 0 .../httpserver_for_visual_only/manager.py | 348 ------------------ lightllm/server/multimodal_params.py | 9 +- lightllm/server/visualserver/manager.py | 158 ++++++-- .../visualserver/model_infer/model_rpc.py | 43 ++- 12 files changed, 226 insertions(+), 444 deletions(-) delete mode 100644 lightllm/server/httpserver_for_visual_only/__init__.py delete mode 100644 lightllm/server/httpserver_for_visual_only/manager.py diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 01bb69bdf..a114bc1dd 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -178,7 +178,7 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = img._preload_data image_data = Image.open(BytesIO(image_data)) t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) img_tensors.append(t) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 4fd4edb3f..43f9780e8 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -42,7 +42,6 @@ from .httpserver.manager import HttpServerManager from .visualserver.manager import VisualManager from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster -from .httpserver_for_visual_only.manager import HttpServerManagerForVisualOnly from .api_lightllm import lightllm_get_score, lightllm_get_image_embedding from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size from lightllm.utils.log_utils import init_logger @@ -70,7 +69,7 @@ class G_Objs: 
args: object = None g_generate_func: Callable = None g_generate_stream_func: Callable = None - httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster, HttpServerManagerForVisualOnly] = None + httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster, VisualManager] = None visual_manager: VisualManager = None shared_token_load: TokenLoad = None @@ -94,11 +93,12 @@ def set_args(self, args): ) elif args.run_mode == "visual_only": self.metric_client = MetricClient(args.metric_port) - self.httpserver_manager = HttpServerManagerForVisualOnly( + self.httpserver_manager = VisualManager( args, - cache_port=args.cache_port, + next_module_port=None, visual_port=args.visual_port, - metric_port=args.metric_port, + cache_port=None, + visual_model_rpc_ports=args.visual_model_rpc_ports, ) elif args.run_mode == "llm_only": init_tokenizer(args) # for openai api @@ -372,6 +372,10 @@ async def startup_event(): logger.info("server start up") loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) - loop.create_task(g_objs.httpserver_manager.handle_loop()) + if g_objs.args.run_mode == "visual_only": + await g_objs.httpserver_manager.wait_to_model_ready() + loop.create_task(g_objs.httpserver_manager.loop_for_fwd_visual_only()) + else: + loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index 660608f6e..97032ba2e 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -5,7 +5,7 @@ from lightllm.server.core.objs.sampling_params import SamplingParams from .multimodal_params import MultimodalParams from .httpserver.manager import HttpServerManager -from .httpserver_for_visual_only.manager import HttpServerManagerForVisualOnly +from .visualserver.manager import VisualManager from fastapi.responses import JSONResponse import ujson as json @@ -140,9 +140,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) -async def lightllm_get_image_embedding( - request: Request, httpserver_manager: HttpServerManagerForVisualOnly -) -> Response: +async def lightllm_get_image_embedding(request: Request, httpserver_manager: VisualManager) -> Response: request_dict = await request.json() # request_dict: {'parameters': {'max_new_tokens': 128}, # 'multimodal_params': {'images': [{'type': 'base64', 'data': 'base64'}]}} @@ -154,6 +152,5 @@ async def lightllm_get_image_embedding( multimodal_params = MultimodalParams(**multimodal_params_dict) await httpserver_manager.generate(sampling_params, multimodal_params, request=request) - # 5. 
Return JSON result - print("embedding OK") + return JSONResponse({"message": "OK"}, status_code=200) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index f8601a237..34263a038 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -548,17 +548,16 @@ def visual_only_start(args): return already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port] can_use_ports = alloc_can_use_network_port( - num=5 + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports + num=4 + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( router_port, visual_port, audio_port, - cache_port, metric_port, - ) = can_use_ports[0:5] - can_use_ports = can_use_ports[5:] + ) = can_use_ports[0:4] + can_use_ports = can_use_ports[4:] visual_model_tp_ports = [] for _ in range(args.visual_dp): @@ -570,7 +569,6 @@ def visual_only_start(args): args.router_port = router_port args.visual_port = visual_port args.audio_port = audio_port - args.cache_port = cache_port args.metric_port = metric_port args.visual_model_rpc_ports = visual_model_tp_ports @@ -585,33 +583,17 @@ def visual_only_start(args): start_args=[(metric_port, args)], ) - from .visualserver.manager import start_visual_process - - process_manager.start_submodule_processes( - start_funcs=[ - start_cache_manager, - ], - start_args=[(cache_port, args)], - ) - process_manager.start_submodule_processes( - start_funcs=[ - start_visual_process, - ], - start_args=[ - (args, audio_port, visual_port, cache_port, visual_model_tp_ports), - ], - ) - if args.enable_multimodal_audio: - from .audioserver.manager import start_audio_process - - process_manager.start_submodule_processes( - start_funcs=[ - start_audio_process, - ], - start_args=[ - (args, router_port, audio_port, cache_port), - ], - ) + # if args.enable_multimodal_audio: + # from .audioserver.manager import start_audio_process + + # process_manager.start_submodule_processes( + # start_funcs=[ + # start_audio_process, + # ], + # start_args=[ + # (args, router_port, audio_port, cache_port), + # ], + # ) # 启动 gunicorn command = [ diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py index d16dc4d06..4236b713c 100644 --- a/lightllm/server/core/objs/io_objs/group_req.py +++ b/lightllm/server/core/objs/io_objs/group_req.py @@ -4,6 +4,12 @@ from ..req import Req +@dataclass +class VisualOnlyReqIndexes: + group_req_id: int + multimodal_params: MultimodalParams + + @dataclass class GroupReqIndexes: group_req_id: int diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 06a728925..5fb89ad54 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -153,6 +153,30 @@ def init( self.post_init() + def init_visual_only( + self, + request_id: int, + ): + # 只是为了有更好的编码辅助类型提示 + self.index_in_shm_mem: int = self.index_in_shm_mem + self.ref_count: int = self.ref_count + + self.request_id = request_id + self.group_req_id = convert_sub_id_to_group_id(request_id) + self.is_paused = False + self.finish_status = FinishStatus() + self.is_aborted = False + self.router_aborted = False + self.shm_infer_released = False + self.shm_cur_kv_len = 0 + self.shm_cur_output_len = 0 + self.candetoken_out_len = 0 + self.prompt_cache_len = 0 + self.finish_token_index = -1 + self.can_released_mark = False + + self.post_init() + def post_init(self): # 子类继承进行一些额外的初始化操作 pass @@ -206,7 +230,9 @@ def 
can_release(self): # 只有管理节点有一个引用 ref_count_ok = self.ref_count == 1 can_released_mark = self.can_released_mark - + print(f"self.is_aborted is {self.is_aborted}") + print(f"self.finish_status.is_finished() is {self.finish_status.is_finished()}") + print(f"self.ref_count is {self.ref_count}") if self.is_aborted and can_released_mark and ref_count_ok: return True diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 04eed2ecb..7fd9978bf 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -44,7 +44,7 @@ def create_shm(name, data): def create_afs(name, data): try: data_size = len(data) - path = os.path.join(get_env_start_args().visual_embed_path, name) + path = os.path.join("/mtc/sangchengmeng/afs", name) with open(path, "xb") as f: mem_view = memoryview(data) f.write(mem_view[:data_size]) @@ -79,7 +79,6 @@ def get_shm_name_data(uid): def get_shm_name_embed(uid): return str(uid) + "-embed" - """ Importable Redis-backed MD5 refcount with LRU eviction. @@ -377,4 +376,4 @@ def _delete_afs_files(self, victims: List[str]) -> None: else return {0} -- 逐出失败,没有足够的候选 end -""" +""" \ No newline at end of file diff --git a/lightllm/server/httpserver_for_visual_only/__init__.py b/lightllm/server/httpserver_for_visual_only/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/lightllm/server/httpserver_for_visual_only/manager.py b/lightllm/server/httpserver_for_visual_only/manager.py deleted file mode 100644 index beb8bbb18..000000000 --- a/lightllm/server/httpserver_for_visual_only/manager.py +++ /dev/null @@ -1,348 +0,0 @@ -import sys -import zmq -import zmq.asyncio -import asyncio -import uvloop -import rpyc -import time -import json -import copy -import hashlib -import datetime -import pickle -from frozendict import frozendict - -asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) -from typing import Union, List, Tuple, Dict, Optional -from fastapi import Request -from ..tokenizer import get_tokenizer -from ..pd_io_struct import NodeRole -from ..embed_cache.utils import get_shm_name_data, create_shm, afs_embed_exists -from ..multimodal_params import AudioItem, MultimodalParams, ImageItem -from ..req_id_generator import ReqIDGenerator -from lightllm.server.core.objs import Req, FinishStatus -from lightllm.server.core.objs import SamplingParams -from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE -from lightllm.server.core.objs.io_objs import GroupReqObjs -from lightllm.server.core.objs.shm_req_manager import ShmReqManager -from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem -from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt -from lightllm.utils.log_utils import init_logger -from lightllm.server.metrics.manager import MetricClient -from lightllm.utils.statics_utils import MovingAverage -from lightllm.utils.config_utils import get_vocab_size -from lightllm.utils.envs_utils import get_unique_server_name -from rpyc.utils.classic import obtain - -logger = init_logger(__name__) - - -class HttpServerManagerForVisualOnly: - def __init__(self, args, cache_port, visual_port, metric_port): - self.args = args - context = zmq.asyncio.Context(2) - - self._shm_lock_pool = AtomicShmArrayLock(f"{get_unique_server_name()}_lightllm_resource_lock", 1) - self._resource_lock = AsyncLock(self._shm_lock_pool.get_lock_context(0)) - self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": 
True}) - self.send_to_visual = context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") - self.shm_req_manager = ShmReqManager() - self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) - self.req_id_to_out_inf: Dict[int, ReqStatus] = {} # value type (out_str, metadata, finished, event) - self.max_req_total_len = args.max_req_total_len - self.id_gen = ReqIDGenerator() - self.metric_client = MetricClient(metric_port) - # 有的模型的vocab size 读取tokenizer和config.json中不一致 - self.vocab_size = max(get_vocab_size(args.model_dir), self.tokenizer.vocab_size) - return - - async def _alloc_resource(self, items, md5sums, token_nums, datas): - - while True: - records = obtain(self.cache_client.root.alloc(md5sums, token_nums)) - - if records is None: - await asyncio.sleep(0.1) - continue - - uid_list = [] - for item, rec in zip(items, records): - item.uuid = rec["id"] - item.token_id = rec["token_id"] - item.token_num = rec["token_num"] - uid_list.append(rec["id"]) - - ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) - update_data_ids = [] - - for uid, ready, data in zip(uid_list, ready_flags, datas): - if not ready: - create_shm(get_shm_name_data(uid), data) - update_data_ids.append(uid) - - if update_data_ids: - self.cache_client.root.set_items_data(update_data_ids) - return - - async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams): - # 这里的锁是为了 防止多个含有多张图片的请求 同时申请的record数量 大于cache_capacity,从而造成死锁的问题。 - # 如果不加任何锁,假如请求1和请求2都有6张图片,而cache_capacity为10, - # 那么如果某一时刻shm中存在请求1的5张图和请求2的5张图,将会资源竞争产生死锁。 - async with self._resource_lock: - items, md5sums, tokens_nums, datas = [], [], [], [] - for img in multimodal_params.images: - self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) - data = img.read() - # must after init_imageitem_extral_params - token_num = self.tokenizer.get_image_token_length(img) - md5sum = "{}_{}".format( - hashlib.md5(data).hexdigest(), - hashlib.md5(pickle.dumps(img.extra_params, protocol=4)).hexdigest(), - ) - md5sums.append(md5sum) - tokens_nums.append(token_num) - datas.append(data) - items.append(img) - for audio in multimodal_params.audios: - self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params) - data = audio.read() - token_num = self.tokenizer.get_audio_token_length(audio) - md5sum = "{}_{}".format( - hashlib.md5(data).hexdigest(), - hashlib.md5(pickle.dumps(audio.extra_params, protocol=4)).hexdigest(), - ) - md5sums.append(md5sum) - tokens_nums.append(token_num) - datas.append(data) - items.append(audio) - - await self._alloc_resource(items, md5sums, tokens_nums, datas) - return - - async def _release_multimodal_resources(self, multimodal_params: MultimodalParams): - if multimodal_params is not None: - ids_to_release = [] - for img in multimodal_params.images: - if img.uuid is not None: - ids_to_release.append(img.uuid) - # 将 uuid 等 赋值为 None, 防止因为abort等异常情况造成重复释放异常 - img.uuid = None - img.token_id = None - img.token_num = None - for audio in multimodal_params.audios: - if audio.uuid is not None: - ids_to_release.append(audio.uuid) - # 将 uuid 等 赋值为 None, 防止因为abort等异常情况造成重复释放异常 - audio.uuid = None - audio.token_id = None - audio.token_num = None - if ids_to_release: - self.cache_client.root.release(ids_to_release) - return - - def tokens(self, multimodal_params, samping_params: SamplingParams, kwargs=None): - kwargs = {} if kwargs is None else kwargs 
- image_tokens = 0 - img_count = 0 - audio_tokens = 0 - audio_count = 0 - for img in multimodal_params.images: - img_count += 1 - self.tokenizer.init_imageitem_extral_params(img, multimodal_params, samping_params) - image_tokens += self.tokenizer.get_image_token_length(img) - for audio in multimodal_params.audios: - audio_count += 1 - self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, samping_params) - audio_tokens += self.tokenizer.get_audio_token_length(audio) - return image_tokens + img_count + audio_tokens + audio_count - - async def loop_for_request(self): - assert self.args.node_rank > 0 - while True: - ( - sampling_params, - multimodal_params, - ) = await self.multinode_req_manager.recv_pyobj() - results_generator = self.generate(sampling_params, multimodal_params, None) - - async def generate_wrapper(results_generator): - async for _, _, _, _ in results_generator: - pass - - asyncio.create_task(generate_wrapper(results_generator)) - return - - def alloc_req_id(self, sampling_params, is_health_req: bool = False): - # 请求的 id 可以由外部传入,也可以由内部生成,但是由外部传入的时候,要自己保证全局唯一性 - # 否则会造成异常问题。目前限制 NORMAL 模式都使用内部id替换, P 和 D 模式按需设置 - # health 请求 request_id 为负数,直接返回 - if is_health_req: - return sampling_params.group_request_id - group_request_id = self.id_gen.generate_id() - - sampling_params.group_request_id = group_request_id - return group_request_id - - async def generate( - self, - sampling_params: SamplingParams, - multimodal_params: MultimodalParams, - request: Request, - is_health_req: bool = False, - ) -> Tuple[int, str, dict, FinishStatus]: - start_time = time.time() - request_headers = request.headers if request is not None else {} - group_request_id = self.alloc_req_id(sampling_params, is_health_req) - - try: - await multimodal_params.verify_and_preload(request) - - # 记录请求到达的相关信息 - await self._log_req_header(request_headers, group_request_id) - assert ( - len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity - ), "too many multimodal items!" - await self._alloc_multimodal_resources(multimodal_params, sampling_params) - - # 申请资源并存储 - alloced_req_indexes = [] - while len(alloced_req_indexes) < sampling_params.n: - alloc_req_index = await self.shm_req_manager.async_alloc_req_index() - sleep_time = 0.1 - while alloc_req_index is None: - await asyncio.sleep(sleep_time) - sleep_time *= 1.1 - sleep_time = min(1, sleep_time) - - alloc_req_index = await self.shm_req_manager.async_alloc_req_index() - alloced_req_indexes.append(alloc_req_index) - req_objs = [] - for i, req_index in enumerate(alloced_req_indexes): - req_obj = await self.shm_req_manager.async_get_req_obj_by_index(req_index) - req_obj.init( - group_request_id + i, - # 随便写的,后面改掉 - [21456], - sampling_params, - self.tokenizer, - chunked_prefill_size=self.args.chunked_prefill_size, - ) - req_objs.append(req_obj) - - req_status = ReqStatus(group_request_id, multimodal_params, req_objs, start_time) - self.req_id_to_out_inf[group_request_id] = req_status - - await self.transfer_to_visual(req_status.group_req_objs) - - except Exception as e: - logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") - # error need to release multimodel resources. 
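The generate() path above acquires one shared-memory request slot per sampled sequence (sampling_params.n) and, when the pool is exhausted, retries with a capped backoff: the sleep starts at 0.1s, grows by 1.1x per miss, and is capped at 1s. A standalone sketch of that acquire loop under those assumptions (the helper name is hypothetical; the real code inlines the loop):

    import asyncio

    async def acquire_req_index(shm_req_manager):
        # retry until a slot frees up; backoff grows 1.1x per miss, capped at 1 second
        req_index = await shm_req_manager.async_alloc_req_index()
        sleep_time = 0.1
        while req_index is None:
            await asyncio.sleep(sleep_time)
            sleep_time = min(1.0, sleep_time * 1.1)
            req_index = await shm_req_manager.async_alloc_req_index()
        return req_index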
- # 对于还没有形成正式请求对象管理的多模态资源,需要单独自己释放 - # 已经放入到 req_id_to_out_inf 中的请求对象,由统一的回收循环 - # 进行回收。 - if group_request_id not in self.req_id_to_out_inf: - await self._release_multimodal_resources(multimodal_params) - await self.abort(group_request_id) - raise e - return - - async def _log_req_header(self, request_headers, group_request_id: int): - - x_request_id = request_headers.get("X-Request-Id", "") - x_session_id = request_headers.get("X-Session-Id", "") - - format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") - logger.info( - f"recieved req X-Request-Id:{x_request_id} " - f"X-Session-Id:{x_session_id} start_time:{format_in_time} " - f"lightllm_req_id:{group_request_id} " - ) - return - - async def transfer_to_visual( - self, - group_req_objs: Optional[GroupReqObjs] = None, - ): - await self.send_to_visual.send_pyobj( - group_req_objs.to_group_req_index(), - protocol=pickle.HIGHEST_PROTOCOL, - ) - return - - async def abort(self, group_req_id: int): - req_status: ReqStatus = self.req_id_to_out_inf.get(group_req_id, None) - if req_status is None: - logger.warning(f"aborted group_request_id {group_req_id} not exist") - return - - group_req_objs: GroupReqObjs = req_status.group_req_objs - for req in group_req_objs.shm_req_objs: - req.is_aborted = True - logger.warning(f"aborted group_request_id {group_req_objs.group_req_id}") - return - - async def recycle_resource_loop(self): - pre_time_mark = time.time() - - while True: - - try: - await asyncio.wait_for(self.recycle_event.wait(), timeout=0.02) - except asyncio.TimeoutError: - pass - self.recycle_event.clear() - - # 清理已经处理完的可以删除的请求 - release_req_status: List[ReqStatus] = [] - for group_req_id_ in list(self.req_id_to_out_inf.keys()): - req_status: ReqStatus = self.req_id_to_out_inf.get(group_req_id_, None) - if req_status is not None and req_status.can_release(): - release_req_status.append(req_status) - - for req_status in release_req_status: - self.req_id_to_out_inf.pop(req_status.group_req_objs.group_req_id, None) - for req in req_status.group_req_objs.shm_req_objs: - await self.shm_req_manager.async_put_back_req_obj(req) - await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem) - await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) - - # 先保留这个关键得日志,用于方便定位重构中的问题。 - if time.time() - pre_time_mark > 120: - pre_time_mark = time.time() - for group_req_id_ in list(self.req_id_to_out_inf.keys()): - req_status: ReqStatus = self.req_id_to_out_inf.get(group_req_id_, None) - if req_status is None: - continue - - logger.info( - f"left req id {req_status.group_req_objs.group_req_id}" - f"can release {req_status.group_req_objs.shm_req_objs[0].can_released_mark} " - f"refcount {req_status.group_req_objs.shm_req_objs[0].ref_count}" - ) - return - - async def handle_loop(self): - self.recycle_event = asyncio.Event() - asyncio.create_task(self.recycle_resource_loop()) - - return - - -class ReqStatus: - def __init__(self, group_request_id, multimodal_params, req_objs: List[Req], start_time) -> None: - self.lock = asyncio.Lock() - self.event = asyncio.Event() - self.group_req_objs = GroupReqObjs( - group_req_id=group_request_id, - multimodal_params=multimodal_params, - shm_req_objs=req_objs, - time_mark=start_time, - ) - self.finished = False - - def can_release(self): - for req in self.group_req_objs.shm_req_objs: - if not req.can_release(): - return False - return True diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 
1171c7c32..6c06082c8 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -54,8 +54,8 @@ async def preload(self, request: Request): def read(self): assert self._preload_data is not None ans = self._preload_data - self._preload_data = None - self._data = None + # self._preload_data = None + # self._data = None return ans def to_dict(self): @@ -79,6 +79,7 @@ def __init__(self, **kwargs): self.image_w = 0 self.image_h = 0 self.afs_embed = False + self.is_abort = False self._preload_data = None self.extra_params = {} @@ -113,8 +114,8 @@ async def preload(self, request: Request): def read(self): assert self._preload_data is not None ans = self._preload_data - self._preload_data = None - self._data = None + # self._preload_data = None + # self._data = None return ans def to_dict(self): diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index eb7620f9b..eea4e8fe9 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -1,13 +1,22 @@ import zmq +import time import zmq.asyncio import asyncio import uvloop import rpyc import pickle +import hashlib +import datetime import inspect -from typing import List -from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes +from fastapi import Request +from ..tokenizer import get_tokenizer +from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes, VisualOnlyReqIndexes from lightllm.server.core.objs import ShmReqManager +from lightllm.server.core.objs import SamplingParams +from lightllm.server.core.objs import Req, FinishStatus +from typing import Union, List, Tuple, Dict, Optional +from ..req_id_generator import ReqIDGenerator +from lightllm.server.core.objs.io_objs import GroupReqObjs asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from lightllm.server.multimodal_params import MultimodalParams, ImageItem @@ -30,29 +39,33 @@ def __init__( cache_port, visual_model_rpc_ports, ): - self.visual_only = True if args.run_mode == "visual_only" else False + self.args = args + self.visual_only = True if self.args.run_mode == "visual_only" else False context = zmq.Context(2) + self.id_gen = ReqIDGenerator() if not self.visual_only: self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if --enable_multimodal_audio) self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{next_module_port}") + self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) self.recv_from_httpserver = context.socket(zmq.PULL) self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{visual_port}") - self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) + self.cache_port = cache_port - self.waiting_reqs: List[GroupReqIndexes] = [] + self.waiting_reqs_from_httpserver: List[GroupReqIndexes] = [] + self.waiting_reqs_visual_only: List[VisualOnlyReqIndexes] = [] self.model_weightdir = args.model_dir self.tp_world_size = args.tp self.vit_dp = args.visual_dp self.vit_tp = args.visual_tp self.infer_batch_size = args.visual_infer_batch_size self.trust_remote_code = args.trust_remote_code - self.args = args self.visual_model_rpc_ports = visual_model_rpc_ports self.shm_req_manager = ShmReqManager() + self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) async def wait_to_model_ready(self): - + # 待完成,需要读取config_server来起多个vit self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in 
range(self.vit_dp)] for dp_rank_id in range(self.vit_dp): @@ -102,26 +115,15 @@ async def infer_imgs(self, images: List[ImageItem]): await asyncio.gather(*tasks) return - async def _send_to_next_module_or_release(self, group_req_indexes: GroupReqIndexes): - if self.visual_only: - for idx in group_req_indexes.shm_req_indexes: - shm_req = self.shm_req_manager.get_req_obj_by_index(idx) - shm_req.can_released_mark = True - shm_req.finish_status.set_status(1) - self.shm_req_manager.put_back_req_obj(shm_req) - logger.info(f"router release req id {shm_req.request_id}") - else: - self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) - async def loop_for_fwd(self): while True: - if len(self.waiting_reqs) == 0: + if len(self.waiting_reqs_from_httpserver) == 0: await asyncio.sleep(0.01) # 10ms else: processing_group_reqs = [] images_need_infer = [] - while len(self.waiting_reqs) > 0: - group_req_indexes = self.waiting_reqs.pop(0) + while len(self.waiting_reqs_from_httpserver) > 0: + group_req_indexes = self.waiting_reqs_from_httpserver.pop(0) shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0]) is_aborted = shm_req.is_aborted self.shm_req_manager.put_back_req_obj(shm_req) @@ -129,7 +131,7 @@ async def loop_for_fwd(self): # 因为连接断开 aborted 掉的请求也需要传输到后续的模块进行处理 # 因为采用 shm 来映射所有的 req 对象以后,引用管理情况复杂了 # 需要一些一致的流程来保证不出现异步问题。 - await self._send_to_next_module_or_release(group_req_indexes) + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) continue multimodal_params = group_req_indexes.multimodal_params @@ -145,21 +147,125 @@ async def loop_for_fwd(self): await self.infer_imgs(images_need_infer) images_need_infer = [] for _group_req_indexes in processing_group_reqs: - await self._send_to_next_module_or_release(group_req_indexes) + self.send_to_next_module.send_pyobj( + _group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL + ) processing_group_reqs = [] if len(images_need_infer) == 0: - await self._send_to_next_module_or_release(group_req_indexes) + self.send_to_next_module.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) else: processing_group_reqs.append(group_req_indexes) if len(images_need_infer) > 0: await self.infer_imgs(images_need_infer) for _group_req_indexes in processing_group_reqs: - await self._send_to_next_module_or_release(group_req_indexes) + self.send_to_next_module.send_pyobj(_group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL) processing_group_reqs = [] images_need_infer = [] + async def loop_for_fwd_visual_only(self): + while True: + if len(self.waiting_reqs_visual_only) == 0: + await asyncio.sleep(0.01) # 10ms + else: + images_need_infer = [] + + while len(self.waiting_reqs_visual_only) > 0: + visual_req = self.waiting_reqs_visual_only.pop(0) + + for img in visual_req.multimodal_params.images: + if img.is_abort: + continue + images_need_infer.append(img) + + if len(images_need_infer) == self.infer_batch_size: + await self.infer_imgs(images_need_infer) + images_need_infer = [] + + if len(images_need_infer) > 0: + await self.infer_imgs(images_need_infer) + images_need_infer = [] + # 在这里release这个image,ref-1 + logger.info(f"req-id {visual_req.group_req_id} has been release ok") + + async def _initialize_multimodal_metadata( + self, multimodal_params: MultimodalParams, sampling_params: SamplingParams + ): + for img in multimodal_params.images: + self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) + data = img.read() + # must after 
init_imageitem_extral_params + token_num = self.tokenizer.get_image_token_length(img) + md5sum = "{}_{}".format( + hashlib.md5(data).hexdigest(), + hashlib.md5(pickle.dumps(img.extra_params, protocol=4)).hexdigest(), + ) + img.uuid = int(md5sum, 16) + img.token_num = token_num + + async def _log_req_header(self, request_headers, group_request_id: int, image_count: int): + + x_request_id = request_headers.get("X-Request-Id", "") + x_session_id = request_headers.get("X-Session-Id", "") + + format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") + logger.info( + f"recieved req X-Request-Id:{x_request_id} " + f"X-Session-Id:{x_session_id} start_time:{format_in_time} " + f"lightllm_req_id:{group_request_id} " + f"image_count:{image_count}" + ) + return + + def alloc_req_id(self, sampling_params, is_health_req: bool = False): + # 请求的 id 可以由外部传入,也可以由内部生成,但是由外部传入的时候,要自己保证全局唯一性 + # 否则会造成异常问题。目前限制 NORMAL 模式都使用内部id替换, P 和 D 模式按需设置 + # health 请求 request_id 为负数,直接返回 + if is_health_req: + return sampling_params.group_request_id + group_request_id = self.id_gen.generate_id() + + sampling_params.group_request_id = group_request_id + return group_request_id + + async def generate( + self, + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + request: Request, + is_health_req: bool = False, + ) -> Tuple[int, str, dict, FinishStatus]: + + request_headers = request.headers if request is not None else {} + group_request_id = self.alloc_req_id(sampling_params, is_health_req) + + try: + await multimodal_params.verify_and_preload(request) + image_count = len(multimodal_params.images) + # 记录请求到达的相关信息 + await self._log_req_header(request_headers, group_request_id, image_count) + assert ( + len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity + ), "too many multimodal items!" 
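loop_for_fwd_visual_only above drains the visual_only queue and batches images up to the configured visual_infer_batch_size before each ViT forward, skipping any item already marked is_abort and flushing whatever remains as a tail batch. A simplified sketch of that batching policy (names follow the manager above; the batch-size check here triggers per image, and the real loop additionally logs the per-request release):

    import asyncio

    async def drain_and_infer(manager):
        batch = []
        while manager.waiting_reqs_visual_only:
            visual_req = manager.waiting_reqs_visual_only.pop(0)
            for img in visual_req.multimodal_params.images:
                if img.is_abort:
                    continue                      # aborted items never reach the ViT workers
                batch.append(img)
                if len(batch) == manager.infer_batch_size:
                    await manager.infer_imgs(batch)
                    batch = []
        if batch:
            await manager.infer_imgs(batch)       # flush the tail batch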
+ + await self._initialize_multimodal_metadata(multimodal_params, sampling_params) + + visual_req_status = VisualOnlyReqIndexes(group_req_id=group_request_id, multimodal_params=multimodal_params) + self.waiting_reqs_visual_only.append(visual_req_status) + + except Exception as e: + logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") + await self.abort(group_request_id, multimodal_params) + raise e + return + + async def abort(self, group_req_id: int, multimodal_params: MultimodalParams): + logger.warning(f"aborted group_request_id {group_req_id}") + for img in multimodal_params.images: + img.is_abort = True + return + async def loop_for_netio_req(self): if not hasattr(self, "visual_recv_max_count"): self.visual_recv_max_count = 64 @@ -169,7 +275,7 @@ async def loop_for_netio_req(self): for _ in range(self.visual_recv_max_count): recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GroupReqIndexes): - self.waiting_reqs.append(recv_req) + self.waiting_reqs_from_httpserver.append(recv_req) else: assert False, f"Error Req Inf {recv_req}" self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index c219e3685..17355ede8 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -46,8 +46,10 @@ def exposed_init_model(self, kvargs): self.cache_port = kvargs["cache_port"] weight_dir = kvargs["weight_dir"] self.vit_rank_id = kvargs["vit_rank_id"] - self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) + if self.args.run_mode != "visual_only": + self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) self.data_type = kvargs["data_type"] + self.visual_only = True if self.args.run_mode == "visual_only" else False init_vision_distributed_env(kvargs) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) @@ -107,21 +109,28 @@ def exposed_encode(self, images: List[ImageItem]): all_img_embeds = all_img_embeds.to(torch.device("cpu")) if self.tp_rank_id == 0: - ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) - ids_to_set = [] - for i, ready in enumerate(ready_flags): - if ready: - continue - uid = uuids[i] - start, end = valid_ids[i] - cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) - if self.args.run_mode == "visual_only": - create_afs(get_shm_name_embed(uid), cur_embed_bytes) - else: - create_shm(get_shm_name_embed(uid), cur_embed_bytes) - ids_to_set.append(uid) - if ids_to_set: - self.cache_client.root.set_items_embed(ids_to_set) + if self.visual_only: + for i, img in enumerate(images): + uid = img.uuid + start, end = valid_ids[i] + cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) + create_afs(get_shm_name_embed(uid), cur_embed_bytes) # 后面替换成redis存 + else: + ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) + ids_to_set = [] + for i, ready in enumerate(ready_flags): + if ready: + continue + uid = uuids[i] + start, end = valid_ids[i] + cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) + if self.args.run_mode == "visual_only": + create_afs(get_shm_name_embed(uid), cur_embed_bytes) + else: + create_shm(get_shm_name_embed(uid), cur_embed_bytes) + ids_to_set.append(uid) + if ids_to_set: + self.cache_client.root.set_items_embed(ids_to_set) return @@ -179,7 +188,7 @@ def _init_env(port, 
device_id): async def start_model_process(port, vit_tp, device_id): import multiprocessing - proc = multiprocessing.Process( + proc = multiprocessing.get_context("spawn").Process( target=_init_env, args=( port, From 27ef8f336fd34fbfb6936a24aef011f35cbae688 Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 27 Aug 2025 19:45:45 +0800 Subject: [PATCH 09/40] add vit mananger for vit-llm disaggr --- lightllm/server/api_cli.py | 11 ++ lightllm/server/api_server.py | 2 - lightllm/server/api_start.py | 108 +------------- .../server/core/objs/io_objs/group_req.py | 6 - lightllm/server/httpserver/manager.py | 32 ++--- lightllm/server/httpserver/vit_loop.py | 136 ++++++++++++++++++ lightllm/utils/start_utils.py | 8 ++ 7 files changed, 168 insertions(+), 135 deletions(-) create mode 100644 lightllm/server/httpserver/vit_loop.py diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 586467c7b..72a7b7b0d 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -506,6 +506,17 @@ def make_argument_parser() -> argparse.ArgumentParser: default=0.03, help="""The interval of the schedule time, default is 30ms.""", ) + parser.add_argument( + "--enable_remote_vit", + action="store_true", + help="Whether to enable remote vit for multimodal service.", + ) + parser.add_argument( + "--remote_vit_port", + type=int, + default=12346, + help="The port number for the remote vit service.", + ) # redis for vit llm disaggregation parser.add_argument( "--redis_port", diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 51cd53da0..0f8a440b9 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -13,7 +13,5 @@ config_server_start(args) elif args.run_mode == "visual_only": visual_only_start(args) - elif args.run_mode == "llm_only": - llm_only_start(args) else: normal_or_p_d_start(args) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index a15442e9b..543b8bfa8 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -5,7 +5,7 @@ import subprocess import signal from lightllm.utils.net_utils import alloc_can_use_network_port, PortLocker -from lightllm.utils.start_utils import process_manager, kill_recursive +from lightllm.utils.start_utils import process_manager, kill_recursive, is_multimodal_mode from .metrics.manager import start_metric_manager from .embed_cache.manager import start_cache_manager from lightllm.utils.log_utils import init_logger @@ -157,11 +157,13 @@ def check_and_set_args(args): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + args.enable_multimodal = is_multimodal_mode(args) # visual_only模式下才需要设置visual_embed_path if args.visual_embed_path is not None: assert ( args.run_mode == "visual_only" or args.run_mode == "llm_only" ), "only visual_only or llm_only mode need visual_embed_path" + # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) @@ -174,13 +176,11 @@ def check_and_set_args(args): args.visual_gpu_ids = args.visual_gpu_ids[:total_required_gpus] # 检查visual_nccl_port数量是否足够 - if len(args.visual_nccl_ports) < args.visual_dp: + if args.visual_nccl_ports is not None and len(args.visual_nccl_ports) < args.visual_dp: raise ValueError( f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, " f"but got ({len(args.visual_nccl_ports)})." 
) - else: - args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] if args.visual_dp <= 0: raise ValueError("visual_dp must be a positive integer.") @@ -287,7 +287,6 @@ def normal_or_p_d_start(args): logger.info(f"all start args:{args}") ports_locker.release_port() - if args.enable_multimodal: from .visualserver.manager import start_visual_process @@ -381,105 +380,6 @@ def normal_or_p_d_start(args): return -def llm_only_start(args): - - check_and_set_args(args) - already_uesd_ports = [args.nccl_port, args.port] - - # 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能 - # 捕获到端口设置冲突的问题 - ports_locker = PortLocker(already_uesd_ports) - ports_locker.lock_port() - - node_world_size = args.tp // args.nnodes - can_use_ports = alloc_can_use_network_port(num=4 + node_world_size, used_nccl_ports=already_uesd_ports) - logger.info(f"alloced ports: {can_use_ports}") - ( - router_port, - detokenization_port, - detokenization_pub_port, - metric_port, - ) = can_use_ports[0:4] - can_use_ports = can_use_ports[4:] - - # 将申请好的端口放入args参数中 - args.router_port = router_port - args.detokenization_port = detokenization_port - args.detokenization_pub_port = detokenization_pub_port - args.metric_port = metric_port - - # 申请在 p d 分离模式下,会用的端口 - args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size] - # p d 分离模式下用于标识节点的id - args.pd_node_id = uuid.uuid4().int - # p 节点用来建立torch kv 传输分布组的可用端口范围 - args.pd_p_allowed_port_min = 20000 - args.pd_p_allowed_port_max = 30000 - - # p d 分离模式下,decode节点的调度间隙是0 - if args.run_mode == "decode": - args.router_max_wait_tokens = 0 - - send_and_receive_node_ip(args) # 多机用于收发node ip - set_env_start_args(args) - logger.info(f"all start args:{args}") - - ports_locker.release_port() - - process_manager.start_submodule_processes( - start_funcs=[ - start_metric_manager, - ], - start_args=[(metric_port, args)], - ) - - process_manager.start_submodule_processes( - start_funcs=[start_router_process, start_detokenization_process], - start_args=[ - (args, router_port, detokenization_port, metric_port), - (args, detokenization_port, detokenization_pub_port), - ], - ) - - # 启动 gunicorn - command = [ - "gunicorn", - "--workers", - f"{args.httpserver_workers}", - "--worker-class", - "uvicorn.workers.UvicornWorker", - "--bind", - f"{args.host}:{args.port}", - "--log-level", - "info", - "--access-logfile", - "-", - "--error-logfile", - "-", - "lightllm.server.api_http:app", - "--timeout", - f"{get_lightllm_gunicorn_time_out_seconds()}", - "--keep-alive", - f"{get_lightllm_gunicorn_keep_alive()}", - ] - - # 启动子进程 - http_server_process = subprocess.Popen(command) - - if "s3://" in args.model_dir: - from lightllm.utils.petrel_helper import s3_model_clear - - s3_model_clear(args.model_dir) - - if args.health_monitor: - from lightllm.server.health_monitor.manager import start_health_check_process - - process_manager.start_submodule_processes(start_funcs=[start_health_check_process], start_args=[(args,)]) - setup_signal_handlers(http_server_process, process_manager) - http_server_process.wait() - return - - def pd_master_start(args): set_unique_server_name(args) if args.run_mode != "pd_master": diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py index 4cd0c2cf4..dfcbdd256 100644 --- a/lightllm/server/core/objs/io_objs/group_req.py +++ b/lightllm/server/core/objs/io_objs/group_req.py @@ -4,12 +4,6 @@ from ..req import Req -@dataclass -class VisualOnlyReqIndexes: - group_req_id: int - multimodal_params: MultimodalParams - - @dataclass class 
GroupReqIndexes: group_req_id: int diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 552fdcde2..6486f590e 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -81,10 +81,12 @@ def __init__( ) self.enable_multimodal = enable_multimodal - if self.enable_multimodal and self.args.run_mode != "llm_only": + if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) - self.send_to_visual = context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") + # 初始化VIT连接管理器 + from .vit_loop import VITConnectionManager + + self.vit_manager = VITConnectionManager(args, context, visual_port) self.token_id_range_start = 100000000 self.token_id_range_end = 2 ** 63 - 1 @@ -406,10 +408,7 @@ async def _encode( ), "too many multimodal items!" if multimodal_params.audios: assert self.args.enable_multimodal_audio, "audio multimodal not enabled" - if self.args.run_mode == "llm_only": - await self._get_image_embedding_from_afs(multimodal_params, sampling_params) - else: - await self._alloc_multimodal_resources(multimodal_params, sampling_params) + await self._alloc_multimodal_resources(multimodal_params, sampling_params) prompt_ids = self.tokenizer.encode( prompt, multimodal_params, add_special_tokens=sampling_params.add_special_tokens ) @@ -483,9 +482,9 @@ async def transfer_to_next_module( group_req_objs: Optional[GroupReqObjs] = None, ): - if self.pd_mode == NodeRole.P: - if self.enable_multimodal and self.args.run_mode != "llm_only": - self.send_to_visual.send_pyobj( + if self.pd_mode.is_P_or_NORMAL(): + if self.enable_multimodal: + await self.vit_manager.send_to_vit( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) @@ -504,19 +503,6 @@ async def transfer_to_next_module( ) return - if self.pd_mode == NodeRole.NORMAL or self.pd_mode == NodeRole.LLM_ONLY: - if self.enable_multimodal and self.args.run_mode != "llm_only": - self.send_to_visual.send_pyobj( - group_req_objs.to_group_req_index(), - protocol=pickle.HIGHEST_PROTOCOL, - ) - else: - self.send_to_router.send_pyobj( - group_req_objs.to_group_req_index(), - protocol=pickle.HIGHEST_PROTOCOL, - ) - return - assert False, "dead code path" return diff --git a/lightllm/server/httpserver/vit_loop.py b/lightllm/server/httpserver/vit_loop.py new file mode 100644 index 000000000..d1ae69b5c --- /dev/null +++ b/lightllm/server/httpserver/vit_loop.py @@ -0,0 +1,136 @@ +import asyncio +import zmq +import zmq.asyncio +import time +import pickle +from typing import Dict, List, Optional, Any +from lightllm.utils.log_utils import init_logger +import httpx +import base64 +from dataclasses import dataclass + +logger = init_logger(__name__) + + +@dataclass +class VIT_Obj: + node_id: int + host_ip_port: str + + def to_log_str(self): + return f"VIT host_ip_port: {self.host_ip_port} node_id: {self.node_id}" + + +class VITConnectionManager: + """VIT连接管理器""" + + def __init__(self, args, context, local_visual_port: int): + self.args = args + self.context = context + self.local_visual_port = local_visual_port + + self.send_to_visual = None + self.remote_vit_instances = [] + self.current_vit_index = 0 + self.remote_vit = args.enable_remote_vit + self.remote_vit_port = args.remote_vit_port + + self._setup_vit_connections() + + def _setup_vit_connections(self): + """ + 设置VIT连接,支持本地和远程VIT实例 + 支持多种连接模式: + 1. 本地VIT实例 (默认) + 2. 远程单个VIT实例 + 3. 
远程多个VIT实例 (负载均衡) + """ + if self.remote_vit: + # 远程VIT实例模式 + self._setup_remote_vit_connections() + else: + self._setup_local_vit_connection() + + def _setup_local_vit_connection(self): + self.send_to_visual = self.context.socket(zmq.PUSH) + self.send_to_visual.connect(f"{self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") + logger.info(f"Connected to local VIT instance at {self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") + + def _setup_remote_vit_connections(self): + asyncio.create_task(self.vit_handle_loop()) + + # wait for remote vit instances + while True: + if len(self.remote_vit_instances) > 0: + break + time.sleep(1) + + def _get_vit_instance(self): + """ + 获取下一个可用的VIT实例 (轮询负载均衡) + """ + if not self.remote_vit: + return self.send_to_visual + + # 简单的轮询负载均衡 + index = (self.current_vit_index + 1) % len(self.remote_vit_instances) + self.current_vit_index = index + return self.remote_vit_instances[index] + + async def send_to_vit(self, data, protocol=pickle.HIGHEST_PROTOCOL): + """ + 发送数据到VIT实例,支持本地和远程模式 + """ + instance = self._get_vit_instance() + try: + instance.send_pyobj(data, protocol=protocol) + except Exception as e: + logger.error(f"Failed to send to VIT instance {instance.host_ip_port}: {e}") + raise Exception(f"Failed to send to VIT instance {instance.host_ip_port}: {e}") + + async def vit_handle_loop(self): + while True: + try: + id_to_vit_obj = await self._get_vit_objs() + logger.info(f"get vit_objs {id_to_vit_obj}") + for id, remote_instance in self.remote_vit_instances.items(): + if id not in id_to_vit_obj: + try: + remote_instance[id].close() + except: + pass + self.remote_vit_instances.pop(id) + logger.info(f"remote vit {id} closed") + + for id, vit_obj in id_to_vit_obj.items(): + if id not in self.remote_vit_instances: + self.remote_vit_instances[id] = self.context.socket(zmq.PUSH) + self.remote_vit_instances[id].connect( + f"tcp://{vit_obj.host_ip_port}:{self.args.remote_vit_port}" + ) + await asyncio.sleep(30) + except Exception as e: + logger.exception(str(e)) + await asyncio.sleep(10) + + async def _get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: + """ + get_vit_objs 主要负责从 config_server 获取所有的vit远程服务。 + """ + # 使用 config_server 服务来发现所有的 pd_master 节点。 + uri = f"ws://{self.args.config_server_host}:{self.args.config_server_port}/registered_vit" + + try: + async with httpx.AsyncClient() as client: + response = await client.get(uri) + if response.status_code == 200: + base64data = response.json()["data"] + id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) + return id_to_vit_obj + else: + logger.error(f"get pd_master_objs error {response.status_code}") + return None + except Exception as e: + logger.exception(str(e)) + await asyncio.sleep(10) + return None diff --git a/lightllm/utils/start_utils.py b/lightllm/utils/start_utils.py index 372b7e1cf..824543108 100644 --- a/lightllm/utils/start_utils.py +++ b/lightllm/utils/start_utils.py @@ -111,4 +111,12 @@ def kill_recursive(proc): logger.warning(f"Process {proc.pid} does not exist.") +def is_multimodal_mode(args): + from transformers import PretrainedConfig + + model_cfg, _ = PretrainedConfig.get_config_dict(args.model_dir) + is_multimodal = "visual" in model_cfg or "vision_config" in model_cfg + return is_multimodal + + process_manager = SubmoduleManager() From 70bc95660360e9ceaeb8530ce4d362996f10b48c Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 27 Aug 2025 11:52:54 +0000 Subject: [PATCH 10/40] [0827]temp --- lightllm/server/api_cli.py | 1 + lightllm/server/api_http.py | 16 +--- 
lightllm/server/api_lightllm.py | 4 +- lightllm/server/api_start.py | 72 +++++++------- lightllm/server/config_server/api_http.py | 36 ++++++- .../impl/memory_cache_with_redis.py | 2 +- .../embed_cache/impl/naive_memory_cache.py | 2 +- lightllm/server/embed_cache/utils.py | 5 +- lightllm/server/httpserver/manager.py | 71 ++++++++++++++ lightllm/server/pd_io_struct.py | 9 ++ lightllm/server/visualserver/manager.py | 96 ++++++++++++------- .../visualserver/model_infer/model_rpc.py | 4 +- 12 files changed, 226 insertions(+), 92 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 72a7b7b0d..564b33135 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -338,6 +338,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument("--metric_gateway", type=str, default=None, help="address for collecting monitoring metrics") parser.add_argument("--job_name", type=str, default="lightllm", help="job name for monitor") parser.add_argument("--visual_embed_path", type=str, default=None, help="path for vit embed") + parser.add_argument("--visual_only_port", type=int, default=18097, help="port for visual only server") parser.add_argument( "--grouping_key", action="append", default=[], help="grouping_key for the monitor in the form key=value" ) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 43f9780e8..1d9da925c 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -42,6 +42,8 @@ from .httpserver.manager import HttpServerManager from .visualserver.manager import VisualManager from .httpserver_for_pd_master.manager import HttpServerManagerForPDMaster + +# from .visualserver.manager import VisualManager from .api_lightllm import lightllm_get_score, lightllm_get_image_embedding from lightllm.utils.envs_utils import get_env_start_args, get_lightllm_websocket_max_message_size from lightllm.utils.log_utils import init_logger @@ -69,7 +71,7 @@ class G_Objs: args: object = None g_generate_func: Callable = None g_generate_stream_func: Callable = None - httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster, VisualManager] = None + httpserver_manager: Union[HttpServerManager, HttpServerManagerForPDMaster] = None visual_manager: VisualManager = None shared_token_load: TokenLoad = None @@ -93,13 +95,6 @@ def set_args(self, args): ) elif args.run_mode == "visual_only": self.metric_client = MetricClient(args.metric_port) - self.httpserver_manager = VisualManager( - args, - next_module_port=None, - visual_port=args.visual_port, - cache_port=None, - visual_model_rpc_ports=args.visual_model_rpc_ports, - ) elif args.run_mode == "llm_only": init_tokenizer(args) # for openai api SamplingParams.load_generation_cfg(args.model_dir) @@ -372,10 +367,7 @@ async def startup_event(): logger.info("server start up") loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) - if g_objs.args.run_mode == "visual_only": - await g_objs.httpserver_manager.wait_to_model_ready() - loop.create_task(g_objs.httpserver_manager.loop_for_fwd_visual_only()) - else: + if g_objs.args.run_mode != "visual_only": loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/api_lightllm.py b/lightllm/server/api_lightllm.py index 1812facdf..ecf113f38 100644 --- a/lightllm/server/api_lightllm.py +++ b/lightllm/server/api_lightllm.py @@ -154,7 +154,7 @@ async def stream_results() -> 
AsyncGenerator[bytes, None]: return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks) -async def lightllm_get_image_embedding(request: Request, httpserver_manager: VisualManager) -> Response: +async def lightllm_get_image_embedding(request: Request, httpserver_manager: HttpServerManager) -> Response: request_dict = await request.json() # request_dict: {'parameters': {'max_new_tokens': 128}, # 'multimodal_params': {'images': [{'type': 'base64', 'data': 'base64'}]}} @@ -165,6 +165,6 @@ async def lightllm_get_image_embedding(request: Request, httpserver_manager: Vis multimodal_params_dict = request_dict.get("multimodal_params", {}) multimodal_params = MultimodalParams(**multimodal_params_dict) - await httpserver_manager.generate(sampling_params, multimodal_params, request=request) + await httpserver_manager.get_image_embeding(sampling_params, multimodal_params, request=request) return JSONResponse({"message": "OK"}, status_code=200) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 543b8bfa8..a20e60793 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -21,14 +21,12 @@ logger = init_logger(__name__) -def setup_signal_handlers(http_server_process, process_manager, redis_process=None): +def setup_signal_handlers(http_server_process, process_manager): def signal_handler(sig, frame): if sig == signal.SIGINT: logger.info("Received SIGINT (Ctrl+C), forcing immediate exit...") if http_server_process: kill_recursive(http_server_process) - if redis_process and redis_process.poll() is None: - redis_process.terminate() process_manager.terminate_all_processes() logger.info("All processes have been forcefully terminated.") @@ -51,19 +49,6 @@ def signal_handler(sig, frame): logger.warning("HTTP server did not exit in time, killing it...") kill_recursive(http_server_process) - # 优雅关闭Redis - if redis_process and redis_process.poll() is None: - redis_process.send_signal(signal.SIGTERM) - start_time = time.time() - while (time.time() - start_time) < 10: - if redis_process.poll() is not None: - logger.info("Redis service has exited gracefully") - break - time.sleep(0.5) - else: - logger.warning("Redis service did not exit in time, killing it...") - redis_process.terminate() - process_manager.terminate_all_processes() logger.info("All processes have been terminated gracefully.") sys.exit(0) @@ -73,8 +58,6 @@ def signal_handler(sig, frame): logger.info(f"start process pid {os.getpid()}") logger.info(f"http server pid {http_server_process.pid}") - if redis_process: - logger.info(f"redis service pid {redis_process.pid}") return @@ -159,10 +142,10 @@ def check_and_set_args(args): args.enable_multimodal = is_multimodal_mode(args) # visual_only模式下才需要设置visual_embed_path - if args.visual_embed_path is not None: + if args.visual_only_port is not None: assert ( args.run_mode == "visual_only" or args.run_mode == "llm_only" - ), "only visual_only or llm_only mode need visual_embed_path" + ), "only visual_only or llm_only mode need visual_only_port" # 检查GPU数量是否足够 if args.visual_gpu_ids is None: @@ -449,16 +432,17 @@ def visual_only_start(args): return already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port] can_use_ports = alloc_can_use_network_port( - num=4 + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports + num=5 + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports ) logger.info(f"alloced ports: {can_use_ports}") ( router_port, visual_port, audio_port, + cache_port, 
metric_port, - ) = can_use_ports[0:4] - can_use_ports = can_use_ports[4:] + ) = can_use_ports[0:5] + can_use_ports = can_use_ports[5:] visual_model_tp_ports = [] for _ in range(args.visual_dp): @@ -470,6 +454,7 @@ def visual_only_start(args): args.router_port = router_port args.visual_port = visual_port args.audio_port = audio_port + args.cache_port = cache_port args.metric_port = metric_port args.visual_model_rpc_ports = visual_model_tp_ports @@ -484,17 +469,33 @@ def visual_only_start(args): start_args=[(metric_port, args)], ) - # if args.enable_multimodal_audio: - # from .audioserver.manager import start_audio_process + from .visualserver.manager import start_visual_process + + process_manager.start_submodule_processes( + start_funcs=[ + start_cache_manager, + ], + start_args=[(cache_port, args)], + ) + process_manager.start_submodule_processes( + start_funcs=[ + start_visual_process, + ], + start_args=[ + (args, audio_port, visual_port, cache_port, visual_model_tp_ports), + ], + ) + if args.enable_multimodal_audio: + from .audioserver.manager import start_audio_process - # process_manager.start_submodule_processes( - # start_funcs=[ - # start_audio_process, - # ], - # start_args=[ - # (args, router_port, audio_port, cache_port), - # ], - # ) + process_manager.start_submodule_processes( + start_funcs=[ + start_audio_process, + ], + start_args=[ + (args, router_port, audio_port, cache_port), + ], + ) # 启动 gunicorn command = [ @@ -540,9 +541,6 @@ def config_server_start(args): if args.run_mode != "config_server": return - # 启动Redis服务(如果指定) - redis_process = start_redis_service(args) - logger.info(f"all start args:{args}") set_env_start_args(args) @@ -570,5 +568,5 @@ def config_server_start(args): ] http_server_process = subprocess.Popen(command) - setup_signal_handlers(http_server_process, process_manager, redis_process) + setup_signal_handlers(http_server_process, process_manager) http_server_process.wait() diff --git a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py index 4f19a3bdd..56645f47f 100644 --- a/lightllm/server/config_server/api_http.py +++ b/lightllm/server/config_server/api_http.py @@ -8,7 +8,7 @@ from typing import Dict, List from fastapi.responses import JSONResponse from lightllm.utils.log_utils import init_logger -from ..pd_io_struct import PD_Master_Obj +from ..pd_io_struct import PD_Master_Obj, Visual_Server_Obj from .nccl_tcp_store import start_tcp_store_server from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.process_check import start_parent_check_thread @@ -18,7 +18,9 @@ app = FastAPI() registered_pd_master_objs: Dict[str, PD_Master_Obj] = {} +registered_visual_server_obj: Dict[str, Visual_Server_Obj] = {} registered_pd_master_obj_lock = Lock() +registered_visual_server_obj_lock = Lock() global_req_id = 0 global_req_id_lock = Lock() @@ -71,6 +73,30 @@ async def websocket_endpoint(websocket: WebSocket): return +@app.websocket("/visual_server_register") +async def visual_websocket_endpoint(websocket: WebSocket): + await websocket.accept() + client_ip, client_port = websocket.client + logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}") + registered_visual_server_obj: Visual_Server_Obj = pickle.loads(await websocket.receive_bytes()) + logger.info(f"recieved registered_visual_server_obj {registered_visual_server_obj}") + with registered_visual_server_obj_lock: + registered_visual_server_obj_lock[registered_visual_server_obj.node_id] = registered_visual_server_obj + + try: + while True: + 
data = await websocket.receive_text() + assert data == "heartbeat" + except (WebSocketDisconnect, Exception, RuntimeError) as e: + logger.error(f"registered_visual_server_obj {registered_visual_server_obj} has error {str(e)}") + logger.exception(str(e)) + finally: + logger.error(f"registered_visual_server_obj {registered_visual_server_obj} removed") + with registered_visual_server_obj_lock: + registered_visual_server_obj.pop(registered_visual_server_obj.node_id, None) + return + + @app.get("/registered_objects") async def get_registered_objects(): with registered_pd_master_obj_lock: @@ -79,6 +105,14 @@ async def get_registered_objects(): return {"data": base64_encoded} +@app.get("/registered_visual_server_objects") +async def get_vit_registered_objects(): + with registered_visual_server_obj_lock: + serialized_data = pickle.dumps(registered_visual_server_obj) + base64_encoded = base64.b64encode(serialized_data).decode("utf-8") + return {"data": base64_encoded} + + @app.get("/allocate_global_unique_id_range") async def allocate_global_id_range(): """ diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index 48b4b5680..9996245db 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -7,7 +7,7 @@ import time from collections import deque import multiprocessing.shared_memory as shm -from ..utils import get_shm_name_data, get_shm_name_embed, free_shm, free_afs, EmbedRefCountRedis +from ..utils import get_shm_name_data, get_shm_name_embed, free_shm, EmbedRefCountRedis from .naive_memory_cache import Record, InMemoryCache from lightllm.utils.log_utils import init_logger diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index 8b0f83f02..7f0f5cf90 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -7,7 +7,7 @@ import time from collections import deque import multiprocessing.shared_memory as shm -from ..utils import get_shm_name_data, get_shm_name_embed, free_shm, free_afs +from ..utils import get_shm_name_data, get_shm_name_embed, free_shm from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 7fd9978bf..4be50177f 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -79,6 +79,7 @@ def get_shm_name_data(uid): def get_shm_name_embed(uid): return str(uid) + "-embed" + """ Importable Redis-backed MD5 refcount with LRU eviction. 
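The /visual_server_register endpoint above expects a visual server to open a websocket, send one pickled Visual_Server_Obj so the config_server can record it in the module-level registry keyed by node_id, and then keep the connection alive with periodic "heartbeat" text frames; when the socket closes, the registration is dropped. A minimal sketch of the registering side under those assumptions, using the third-party websockets client library and a 10-second heartbeat interval (the actual client in the tree may differ on both points):

    import asyncio
    import pickle
    import websockets
    from lightllm.server.pd_io_struct import Visual_Server_Obj

    async def register_visual_server(config_host, config_port, node_id, host_ip_port):
        uri = f"ws://{config_host}:{config_port}/visual_server_register"
        obj = Visual_Server_Obj(node_id=node_id, host_ip_port=host_ip_port)
        async with websockets.connect(uri) as ws:
            await ws.send(pickle.dumps(obj))   # one-time registration payload
            while True:
                await ws.send("heartbeat")     # keeps the registration alive
                await asyncio.sleep(10)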
@@ -108,7 +109,7 @@ class EmbedRefCountRedis: def __init__( self, redis_url: str = "redis://localhost:6379/0", - capacity: int = 50_000, + capacity: int = 50000, evict_fraction: float = 0.2, key_prefix: str = "md5:", image_embed_dir: str = None, @@ -376,4 +377,4 @@ def _delete_afs_files(self, victims: List[str]) -> None: else return {0} -- 逐出失败,没有足够的候选 end -""" \ No newline at end of file +""" diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 6486f590e..85e7fbc8e 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -13,6 +13,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from typing import Union, List, Tuple, Dict, Optional +from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes, VisualOnlyReqIndexes from fastapi import Request from ..tokenizer import get_tokenizer from ..pd_io_struct import NodeRole @@ -87,6 +88,12 @@ def __init__( from .vit_loop import VITConnectionManager self.vit_manager = VITConnectionManager(args, context, visual_port) + # self.send_to_visual = context.socket(zmq.PUSH) + # if self.args.run_mode == "llm_only": + # self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{self.args.visual_only_port}") + # else: + # self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") + # self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) self.token_id_range_start = 100000000 self.token_id_range_end = 2 ** 63 - 1 @@ -268,6 +275,70 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): assert False, "dead code path" return group_request_id + async def _log_req_header_for_visual_only(self, request_headers, group_request_id: int, image_count: int): + + x_request_id = request_headers.get("X-Request-Id", "") + x_session_id = request_headers.get("X-Session-Id", "") + + format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") + logger.info( + f"recieved req X-Request-Id:{x_request_id} " + f"X-Session-Id:{x_session_id} start_time:{format_in_time} " + f"lightllm_req_id:{group_request_id} " + f"image_count:{image_count}" + ) + return + + async def _initialize_multimodal_metadata( + self, multimodal_params: MultimodalParams, sampling_params: SamplingParams + ): + for img in multimodal_params.images: + self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) + data = img.read() + # must after init_imageitem_extral_params + token_num = self.tokenizer.get_image_token_length(img) + md5sum = "{}_{}".format( + hashlib.md5(data).hexdigest(), + hashlib.md5(pickle.dumps(img.extra_params, protocol=4)).hexdigest(), + ) + img.uuid = int(md5sum, 16) + img.token_num = token_num + + # async def get_image_embeding( + # self, + # sampling_params: SamplingParams, + # multimodal_params: MultimodalParams, + # request: Request, + # is_health_req: bool = False, + # ) -> Tuple[int, str, dict, FinishStatus]: + + # request_headers = request.headers if request is not None else {} + # group_request_id = self.alloc_req_id(sampling_params, is_health_req) + + # try: + # await multimodal_params.verify_and_preload(request) + # image_count = len(multimodal_params.images) + # # 记录请求到达的相关信息 + # await self._log_req_header_for_visual_only(request_headers, group_request_id, image_count) + # assert ( + # len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity + # ), "too many multimodal items!" 
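In the remote-ViT path the image uuid assigned above is not a cache-server-allocated id as in the earlier patches: it is the content hash itself interpreted as an integer via int(md5sum, 16), so the httpserver, the visual server and the embedding store all derive the same id from the same bytes without a shared cache round trip. A minimal sketch of that derivation (the helper name is hypothetical; the real code inlines it in _initialize_multimodal_metadata):

    import hashlib
    import pickle

    def deterministic_image_uuid(data: bytes, extra_params: dict) -> int:
        md5sum = "{}_{}".format(
            hashlib.md5(data).hexdigest(),
            hashlib.md5(pickle.dumps(extra_params, protocol=4)).hexdigest(),
        )
        # int() accepts the "_" separator between hex digits, yielding a stable integer id
        return int(md5sum, 16)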
+ + # await self._initialize_multimodal_metadata(multimodal_params, sampling_params) + + # visual_req_status = VisualOnlyReqIndexes(group_req_id=group_request_id, multimodal_params=multimodal_params) + + # self.send_to_visual.send_pyobj( + # visual_req_status, + # protocol=pickle.HIGHEST_PROTOCOL, + # ) + + # except Exception as e: + # logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") + # await self.abort(group_request_id, multimodal_params) + # raise e + # return + async def generate( self, prompt: Union[str, List[int]], diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 938c2a06a..351ee36fa 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -73,6 +73,15 @@ def to_log_str(self): return f"PD_MASTER host_ip_port: {self.host_ip_port} node_id: {self.node_id}" +@dataclass +class Visual_Server_Obj: + node_id: int + host_ip_port: str + + def to_log_str(self): + return f"Visual_Server host_ip_port: {self.host_ip_port} node_id: {self.node_id}" + + @dataclass class UpKVStatus: type: str = "kv_move_status" diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 98d587725..de15618c6 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -17,6 +17,7 @@ from typing import Union, List, Tuple, Dict, Optional from ..req_id_generator import ReqIDGenerator from lightllm.server.core.objs.io_objs import GroupReqObjs +from lightllm.server.embed_cache.impl.memory_cache_with_redis import MemoryCacheWithRedis asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from lightllm.server.multimodal_params import MultimodalParams, ImageItem @@ -43,15 +44,17 @@ def __init__( self.visual_only = True if self.args.run_mode == "visual_only" else False context = zmq.Context(2) self.id_gen = ReqIDGenerator() - if not self.visual_only: + self.recv_from_httpserver = context.socket(zmq.PULL) + if self.visual_only: + self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{self.args.visual_only_port}") + else: + self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{visual_port}") self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if --enable_multimodal_audio) self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{next_module_port}") self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) - self.recv_from_httpserver = context.socket(zmq.PULL) - self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{visual_port}") - self.cache_port = cache_port + self.memory_cache = MemoryCacheWithRedis(args) self.waiting_reqs_from_httpserver: List[GroupReqIndexes] = [] self.waiting_reqs_visual_only: List[VisualOnlyReqIndexes] = [] self.model_weightdir = args.model_dir @@ -248,36 +251,37 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): sampling_params.group_request_id = group_request_id return group_request_id - async def generate( - self, - sampling_params: SamplingParams, - multimodal_params: MultimodalParams, - request: Request, - is_health_req: bool = False, - ) -> Tuple[int, str, dict, FinishStatus]: - - request_headers = request.headers if request is not None else {} - group_request_id = self.alloc_req_id(sampling_params, is_health_req) - - try: - await multimodal_params.verify_and_preload(request) - image_count = len(multimodal_params.images) - # 记录请求到达的相关信息 - await self._log_req_header(request_headers, group_request_id, image_count) - assert ( - 
len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity - ), "too many multimodal items!" - - await self._initialize_multimodal_metadata(multimodal_params, sampling_params) - - visual_req_status = VisualOnlyReqIndexes(group_req_id=group_request_id, multimodal_params=multimodal_params) - self.waiting_reqs_visual_only.append(visual_req_status) - - except Exception as e: - logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") - await self.abort(group_request_id, multimodal_params) - raise e - return + # async def generate( + # self, + # sampling_params: SamplingParams, + # multimodal_params: MultimodalParams, + # request: Request, + # is_health_req: bool = False, + # ) -> Tuple[int, str, dict, FinishStatus]: + + # request_headers = request.headers if request is not None else {} + # group_request_id = self.alloc_req_id(sampling_params, is_health_req) + + # try: + # await multimodal_params.verify_and_preload(request) + # image_count = len(multimodal_params.images) + # # 记录请求到达的相关信息 + # await self._log_req_header(request_headers, group_request_id, image_count) + # assert ( + # len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity + # ), "too many multimodal items!" + + # await self._initialize_multimodal_metadata(multimodal_params, sampling_params) + + # visual_req_status = VisualOnlyReqIndexes(group_req_id=group_request_id, + # multimodal_params=multimodal_params) + # self.waiting_reqs_visual_only.append(visual_req_status) + + # except Exception as e: + # logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") + # await self.abort(group_request_id, multimodal_params) + # raise e + # return async def abort(self, group_req_id: int, multimodal_params: MultimodalParams): logger.warning(f"aborted group_request_id {group_req_id}") @@ -285,6 +289,25 @@ async def abort(self, group_req_id: int, multimodal_params: MultimodalParams): img.is_abort = True return + async def loop_for_netio_req(self): + if not hasattr(self, "visual_recv_max_count"): + self.visual_recv_max_count = 64 + + while True: + try: + for _ in range(self.visual_recv_max_count): + recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) + print(f"recv_req is {recv_req}") + if isinstance(recv_req, GroupReqIndexes): + self.waiting_reqs_from_httpserver.append(recv_req) + else: + assert False, f"Error Req Inf {recv_req}" + self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256) + except zmq.ZMQError: + # 当队列已经开始清空的时候,将一次接受数量下调 + self.visual_recv_max_count = 64 + await asyncio.sleep(0.01) + def clean_up(self): for model_rpc in self.model_rpcs: model_rpc.rpc_server_process.kill() @@ -313,6 +336,9 @@ def handle_exception(loop, context): loop = asyncio.new_event_loop() loop.set_exception_handler(handle_exception) asyncio.set_event_loop(loop) - loop.create_task(visualserver.loop_for_fwd()) + if args.run_mode == "visual_only": + loop.create_task(visualserver.loop_for_fwd_visual_only()) + else: + loop.create_task(visualserver.loop_for_fwd()) loop.run_until_complete(visualserver.loop_for_netio_req()) return diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 17355ede8..4947fe04d 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -85,7 +85,9 @@ def exposed_init_model(self, kvargs): else: raise Exception(f"can not support {self.model_type} now") 
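loop_for_netio_req above drains the ZMQ socket with a self-tuning batch size: it attempts up to visual_recv_max_count non-blocking receives, widens the budget by 1.3x (capped at 256) while traffic keeps arriving, and resets to 64 with a short sleep as soon as a receive raises zmq.ZMQError, i.e. once the queue starts to empty. A condensed sketch of that receive policy, mirroring the manager's loop (an explicit int() cast is added here so the batch count stays usable after the 1.3x growth):

    import asyncio
    import zmq

    async def drain_socket(sock, handle, max_count=64):
        while True:
            try:
                for _ in range(int(max_count)):
                    req = sock.recv_pyobj(zmq.NOBLOCK)   # raises zmq.ZMQError when empty
                    handle(req)
                max_count = min(max_count * 1.3, 256)    # still busy: widen the batch window
            except zmq.ZMQError:
                max_count = 64                           # queue drained: reset and back off
                await asyncio.sleep(0.01)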
self.model.load_model(weight_dir) + print("begin load model") self.model = self.model.cuda() + print("load model OK") except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) @@ -188,7 +190,7 @@ def _init_env(port, device_id): async def start_model_process(port, vit_tp, device_id): import multiprocessing - proc = multiprocessing.get_context("spawn").Process( + proc = multiprocessing.Process( target=_init_env, args=( port, From ded28b701755758d738e6431d5f413be3ed38bba Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 27 Aug 2025 20:49:20 +0800 Subject: [PATCH 11/40] update visual server mananger --- lightllm/server/api_cli.py | 10 +- lightllm/server/api_start.py | 10 +- lightllm/server/embed_cache/manager.py | 10 +- lightllm/server/visualserver/manager.py | 185 ++++-------------- .../visualserver/model_infer/model_rpc.py | 63 +++--- lightllm/utils/dist_utils.py | 10 +- 6 files changed, 97 insertions(+), 191 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 564b33135..c07be6cbe 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "pd_master", "config_server", "visual_only", "llm_only"], + choices=["normal", "prefill", "decode", "pd_master", "config_server", "visual_only"], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, @@ -337,8 +337,6 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument("--metric_gateway", type=str, default=None, help="address for collecting monitoring metrics") parser.add_argument("--job_name", type=str, default="lightllm", help="job name for monitor") - parser.add_argument("--visual_embed_path", type=str, default=None, help="path for vit embed") - parser.add_argument("--visual_only_port", type=int, default=18097, help="port for visual only server") parser.add_argument( "--grouping_key", action="append", default=[], help="grouping_key for the monitor in the form key=value" ) @@ -507,6 +505,12 @@ def make_argument_parser() -> argparse.ArgumentParser: default=0.03, help="""The interval of the schedule time, default is 30ms.""", ) + parser.add_argument( + "--image_embed_dir", + type=str, + default=None, + help="path for vit embed", + ) parser.add_argument( "--enable_remote_vit", action="store_true", diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index a20e60793..6b5a4ee4f 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -141,12 +141,6 @@ def check_and_set_args(args): assert args.mtp_step == 0 args.enable_multimodal = is_multimodal_mode(args) - # visual_only模式下才需要设置visual_embed_path - if args.visual_only_port is not None: - assert ( - args.run_mode == "visual_only" or args.run_mode == "llm_only" - ), "only visual_only or llm_only mode need visual_only_port" - # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) @@ -279,7 +273,7 @@ def normal_or_p_d_start(args): ], start_args=[(cache_port, args)], ) - if args.enable_multimodal_audio and args.run_mode != "llm_only": + if args.enable_multimodal_audio and not args.enable_remote_vit: from .audioserver.manager import start_audio_process 
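+        # note: with --enable_remote_vit the image/audio encoding stage runs in a separate visual
+        # server process group, so the local audio (and later the visual) submodules are only
+        # started when that flag is unset.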
process_manager.start_submodule_processes( @@ -299,7 +293,7 @@ def normal_or_p_d_start(args): ], ) - elif args.run_mode != "llm_only": + elif not args.enable_remote_vit: process_manager.start_submodule_processes( start_funcs=[ start_visual_process, diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py index 8059bfb2a..344e0bf1a 100644 --- a/lightllm/server/embed_cache/manager.py +++ b/lightllm/server/embed_cache/manager.py @@ -4,6 +4,7 @@ from typing import Union, Optional from lightllm.utils.graceful_utils import graceful_registry from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache +from lightllm.server.embed_cache.impl.memory_cache_with_redis import MemoryCacheWithRedis from rpyc.utils.classic import obtain @@ -53,11 +54,18 @@ def exposed_get_items_embed(self, ids: list[int]) -> list[bool]: return self._impl.get_items_embed(ids) +def get_cache_manager(args): + if args.enable_remote_vit: + return MemoryCacheWithRedis(args) + else: + return InMemoryCache(args) + + def start_cache_manager(port: int, args, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) - manager = InMemoryCache(args) + manager = get_cache_manager(args) service = CacheServer(manager) from rpyc.utils.server import ThreadedServer diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index de15618c6..d846a5bdb 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -41,63 +41,44 @@ def __init__( visual_model_rpc_ports, ): self.args = args - self.visual_only = True if self.args.run_mode == "visual_only" else False - context = zmq.Context(2) - self.id_gen = ReqIDGenerator() - self.recv_from_httpserver = context.socket(zmq.PULL) - if self.visual_only: - self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{self.args.visual_only_port}") - else: - self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{visual_port}") - self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if --enable_multimodal_audio) - self.send_to_next_module.connect(f"{args.zmq_mode}127.0.0.1:{next_module_port}") - self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) - + self.remote_vit = args.enable_remote_vit self.cache_port = cache_port self.memory_cache = MemoryCacheWithRedis(args) - self.waiting_reqs_from_httpserver: List[GroupReqIndexes] = [] - self.waiting_reqs_visual_only: List[VisualOnlyReqIndexes] = [] - self.model_weightdir = args.model_dir - self.tp_world_size = args.tp - self.vit_dp = args.visual_dp - self.vit_tp = args.visual_tp + self.waiting_reqs: List[GroupReqIndexes] = [] self.infer_batch_size = args.visual_infer_batch_size self.trust_remote_code = args.trust_remote_code self.visual_model_rpc_ports = visual_model_rpc_ports - self.shm_req_manager = ShmReqManager() - self.tokenizer = get_tokenizer(args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code) + self._setup_connections() + + def _setup_connections(self): + context = zmq.Context(2) + if self.remote_vit: + self.recv_from_httpserver.bind(f"tcp://*:{self.args.remote_vit_port}") + else: + self.recv_from_httpserver.bind(f"{self.args.zmq_mode}127.0.0.1:{self.visual_port}") + self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if --enable_multimodal_audio) + self.send_to_next_module.connect(f"{self.args.zmq_mode}127.0.0.1:{self.next_module_port}") + self.cache_client = 
rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) async def wait_to_model_ready(self): # 待完成,需要读取config_server来起多个vit self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] - for dp_rank_id in range(self.vit_dp): + for dp_rank_id in range(self.args.visual_dp): tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id] - for tp_rank_id in range(self.vit_tp): - device_id = self.args.visual_gpu_ids[dp_rank_id * self.vit_tp + tp_rank_id] + for tp_rank_id in range(self.args.visual_tp): + device_id = self.args.visual_gpu_ids[dp_rank_id * self.args.visual_tp + tp_rank_id] rpc_model = await start_model_process( - port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp, device_id=device_id + port=tp_ports_each_dp[tp_rank_id], vit_tp=self.args.visual_tp, device_id=device_id ) self.model_rpcs[dp_rank_id].append(rpc_model) init_model_ret = [] - for dp_rank_id in range(self.vit_dp): # async init model process - for tp_rank_id in range(self.vit_tp): + for dp_rank_id in range(self.args.visual_dp): # async init model process + for tp_rank_id in range(self.args.visual_tp): kvargs = { - "weight_dir": self.model_weightdir, - "trust_remote_code": self.trust_remote_code, - "vit_dp": self.vit_dp, - "vit_tp": self.vit_tp, - "cache_port": self.cache_port, "tp_rank_id": tp_rank_id, "dp_rank_id": dp_rank_id, - "vit_rank_id": dp_rank_id * self.vit_tp + tp_rank_id, - "data_type": self.args.data_type, - "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], - "visual_gpu_ids": self.args.visual_gpu_ids, - "quant_type": self.args.vit_quant_type, - "quant_cfg": self.args.vit_quant_cfg, - "max_batch_size": min(self.infer_batch_size // self.vit_dp, 1), } init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) @@ -108,10 +89,10 @@ async def infer_imgs(self, images: List[ImageItem]): return tasks = [] - for vit_dp_rank in range(self.vit_dp): - assigned_images = [images[i] for i in range(vit_dp_rank, len(images), self.vit_dp)] + for vit_dp_rank in range(self.args.visual_dp): + assigned_images = [images[i] for i in range(vit_dp_rank, len(images), self.args.visual_dp)] if assigned_images: - for vit_tp_rank in range(self.vit_tp): + for vit_tp_rank in range(self.args.visual_tp): task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_images)) tasks.append(task) @@ -120,13 +101,13 @@ async def infer_imgs(self, images: List[ImageItem]): async def loop_for_fwd(self): while True: - if len(self.waiting_reqs_from_httpserver) == 0: + if len(self.waiting_reqs) == 0: await asyncio.sleep(0.01) # 10ms else: processing_group_reqs = [] images_need_infer = [] - while len(self.waiting_reqs_from_httpserver) > 0: - group_req_indexes = self.waiting_reqs_from_httpserver.pop(0) + while len(self.waiting_reqs) > 0: + group_req_indexes = self.waiting_reqs.pop(0) shm_req = self.shm_req_manager.get_req_obj_by_index(group_req_indexes.shm_req_indexes[0]) is_aborted = shm_req.is_aborted self.shm_req_manager.put_back_req_obj(shm_req) @@ -167,6 +148,21 @@ async def loop_for_fwd(self): processing_group_reqs = [] images_need_infer = [] + def _recv_reqs(self): + if self.remote_vit: + recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) + for img in recv_req.multimodal_params.images: + data = img._preload_data + img._preload_data = None + md5sum = hashlib.md5(data).hexdigest() + uid = int(md5sum, 16) + # create_shm(get_shm_name_data(uid), data) + self.cache_client.root.set_items_data([uid]) 
+ + return recv_req + else: + return self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) + async def loop_for_netio_req(self): if not hasattr(self, "visual_recv_max_count"): self.visual_recv_max_count = 64 @@ -174,9 +170,9 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) + recv_req: GroupReqIndexes = self._recv_reqs() if isinstance(recv_req, GroupReqIndexes): - self.waiting_reqs_from_httpserver.append(recv_req) + self.waiting_reqs.append(recv_req) else: assert False, f"Error Req Inf {recv_req}" self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256) @@ -211,103 +207,6 @@ async def loop_for_fwd_visual_only(self): # 在这里release这个image,ref-1 logger.info(f"req-id {visual_req.group_req_id} has been release ok") - async def _initialize_multimodal_metadata( - self, multimodal_params: MultimodalParams, sampling_params: SamplingParams - ): - for img in multimodal_params.images: - self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) - data = img.read() - # must after init_imageitem_extral_params - token_num = self.tokenizer.get_image_token_length(img) - md5sum = "{}_{}".format( - hashlib.md5(data).hexdigest(), - hashlib.md5(pickle.dumps(img.extra_params, protocol=4)).hexdigest(), - ) - img.uuid = int(md5sum, 16) - img.token_num = token_num - - async def _log_req_header(self, request_headers, group_request_id: int, image_count: int): - - x_request_id = request_headers.get("X-Request-Id", "") - x_session_id = request_headers.get("X-Session-Id", "") - - format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") - logger.info( - f"recieved req X-Request-Id:{x_request_id} " - f"X-Session-Id:{x_session_id} start_time:{format_in_time} " - f"lightllm_req_id:{group_request_id} " - f"image_count:{image_count}" - ) - return - - def alloc_req_id(self, sampling_params, is_health_req: bool = False): - # 请求的 id 可以由外部传入,也可以由内部生成,但是由外部传入的时候,要自己保证全局唯一性 - # 否则会造成异常问题。目前限制 NORMAL 模式都使用内部id替换, P 和 D 模式按需设置 - # health 请求 request_id 为负数,直接返回 - if is_health_req: - return sampling_params.group_request_id - group_request_id = self.id_gen.generate_id() - - sampling_params.group_request_id = group_request_id - return group_request_id - - # async def generate( - # self, - # sampling_params: SamplingParams, - # multimodal_params: MultimodalParams, - # request: Request, - # is_health_req: bool = False, - # ) -> Tuple[int, str, dict, FinishStatus]: - - # request_headers = request.headers if request is not None else {} - # group_request_id = self.alloc_req_id(sampling_params, is_health_req) - - # try: - # await multimodal_params.verify_and_preload(request) - # image_count = len(multimodal_params.images) - # # 记录请求到达的相关信息 - # await self._log_req_header(request_headers, group_request_id, image_count) - # assert ( - # len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity - # ), "too many multimodal items!" 
- - # await self._initialize_multimodal_metadata(multimodal_params, sampling_params) - - # visual_req_status = VisualOnlyReqIndexes(group_req_id=group_request_id, - # multimodal_params=multimodal_params) - # self.waiting_reqs_visual_only.append(visual_req_status) - - # except Exception as e: - # logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") - # await self.abort(group_request_id, multimodal_params) - # raise e - # return - - async def abort(self, group_req_id: int, multimodal_params: MultimodalParams): - logger.warning(f"aborted group_request_id {group_req_id}") - for img in multimodal_params.images: - img.is_abort = True - return - - async def loop_for_netio_req(self): - if not hasattr(self, "visual_recv_max_count"): - self.visual_recv_max_count = 64 - - while True: - try: - for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) - print(f"recv_req is {recv_req}") - if isinstance(recv_req, GroupReqIndexes): - self.waiting_reqs_from_httpserver.append(recv_req) - else: - assert False, f"Error Req Inf {recv_req}" - self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256) - except zmq.ZMQError: - # 当队列已经开始清空的时候,将一次接受数量下调 - self.visual_recv_max_count = 64 - await asyncio.sleep(0.01) - def clean_up(self): for model_rpc in self.model_rpcs: model_rpc.rpc_server_process.kill() diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 4947fe04d..8fd22d5ad 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -39,16 +39,20 @@ def exposed_init_model(self, kvargs): import torch.distributed as dist self.args = get_env_start_args() - self.vit_dp = kvargs["vit_dp"] - self.vit_tp = kvargs["vit_tp"] + + weight_dir = (self.args.model_dir,) + cache_port = (self.args.cache_port,) + data_type = (self.args.data_type,) + quant_type = (self.args.vit_quant_type,) + quant_cfg = (self.args.vit_quant_cfg,) + max_batch_size = (min(self.args.visual_infer_batch_size // self.args.visual_dp, 1),) + self.dp_rank_id = kvargs["dp_rank_id"] self.tp_rank_id = kvargs["tp_rank_id"] - self.cache_port = kvargs["cache_port"] - weight_dir = kvargs["weight_dir"] - self.vit_rank_id = kvargs["vit_rank_id"] + kvargs["vit_rank_id"] = self.dp_rank_id * self.args.visual_tp + self.tp_rank_id + if self.args.run_mode != "visual_only": - self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) - self.data_type = kvargs["data_type"] + self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) self.visual_only = True if self.args.run_mode == "visual_only" else False init_vision_distributed_env(kvargs) @@ -57,10 +61,10 @@ def exposed_init_model(self, kvargs): try: kvargs = { "weight_dir": weight_dir, - "data_type": self.data_type, - "quant_type": kvargs["quant_type"], - "quant_cfg": kvargs["quant_cfg"], - "max_batch_size": kvargs["max_batch_size"], + "data_type": data_type, + "quant_type": quant_type, + "quant_cfg": quant_cfg, + "max_batch_size": max_batch_size, } self.model_type = model_cfg["model_type"] if self.model_type == "qwen": @@ -111,28 +115,21 @@ def exposed_encode(self, images: List[ImageItem]): all_img_embeds = all_img_embeds.to(torch.device("cpu")) if self.tp_rank_id == 0: - if self.visual_only: - for i, img in enumerate(images): - uid = img.uuid - start, end = valid_ids[i] - cur_embed_bytes = 
tensor2bytes(all_img_embeds[start:end]) - create_afs(get_shm_name_embed(uid), cur_embed_bytes) # 后面替换成redis存 - else: - ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) - ids_to_set = [] - for i, ready in enumerate(ready_flags): - if ready: - continue - uid = uuids[i] - start, end = valid_ids[i] - cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) - if self.args.run_mode == "visual_only": - create_afs(get_shm_name_embed(uid), cur_embed_bytes) - else: - create_shm(get_shm_name_embed(uid), cur_embed_bytes) - ids_to_set.append(uid) - if ids_to_set: - self.cache_client.root.set_items_embed(ids_to_set) + ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) + ids_to_set = [] + for i, ready in enumerate(ready_flags): + if ready: + continue + uid = uuids[i] + start, end = valid_ids[i] + cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) + if self.args.enable_remote_vit: + create_afs(get_shm_name_embed(uid), cur_embed_bytes) + else: + create_shm(get_shm_name_embed(uid), cur_embed_bytes) + ids_to_set.append(uid) + if ids_to_set: + self.cache_client.root.set_items_embed(ids_to_set) return diff --git a/lightllm/utils/dist_utils.py b/lightllm/utils/dist_utils.py index 2d4123170..fa9075c59 100644 --- a/lightllm/utils/dist_utils.py +++ b/lightllm/utils/dist_utils.py @@ -55,19 +55,23 @@ def get_environ(environ_name): def init_vision_distributed_env(kvargs): - tp_world_size = kvargs["vit_tp"] + from lightllm.utils.envs_utils import get_env_start_args + + args = get_env_start_args() + tp_world_size = args.visual_tp dp_size = 1 tp_rank_id = kvargs["tp_rank_id"] set_dp_size(dp_size) set_dp_world_size(tp_world_size) set_current_rank_in_dp(tp_rank_id) - visual_gpu_ids = kvargs["visual_gpu_ids"] + visual_gpu_ids = args.visual_gpu_ids device_id = visual_gpu_ids[kvargs["vit_rank_id"]] set_current_device_id(device_id) torch.cuda.set_device(device_id) + visual_nccl_port = args.visual_nccl_ports[kvargs["dp_rank_id"]] dist.init_process_group( "nccl", - init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}', + init_method=f"tcp://127.0.0.1:{visual_nccl_port}", rank=kvargs["tp_rank_id"], world_size=tp_world_size, ) From 3a89cf0a1d350fdb1d49a28f95a07d425594f48f Mon Sep 17 00:00:00 2001 From: baishihao Date: Wed, 27 Aug 2025 21:52:20 +0800 Subject: [PATCH 12/40] add visual start --- lightllm/server/api_cli.py | 8 +- lightllm/server/api_http.py | 6 +- lightllm/server/api_server.py | 6 +- lightllm/server/api_start.py | 79 ++++--------------- lightllm/server/core/objs/req.py | 24 ------ .../impl/memory_cache_with_redis.py | 2 +- lightllm/server/embed_cache/manager.py | 2 +- lightllm/server/httpserver/manager.py | 39 +-------- lightllm/server/visualserver/manager.py | 25 +++--- .../visualserver/model_infer/model_rpc.py | 18 ++--- 10 files changed, 55 insertions(+), 154 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index c07be6cbe..6195dac92 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -7,7 +7,7 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--run_mode", type=str, - choices=["normal", "prefill", "decode", "pd_master", "config_server", "visual_only"], + choices=["normal", "prefill", "decode", "pd_master", "config_server", "visual"], default="normal", help="""set run mode, normal is started for a single server, prefill decode pd_master is for pd split run mode, config_server is for pd split mode used to register pd_master node, and get pd_master node list, @@ -529,6 +529,12 @@ 
def make_argument_parser() -> argparse.ArgumentParser: default=6379, help="The port number for the redis service in config_server mode.", ) + parser.add_argument( + "--redis_evict_fraction", + type=float, + default=0.3, + help="The evict fraction for the redis service in config_server mode.", + ) parser.add_argument( "--start_redis", action="store_true", diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 1d9da925c..8863dd9f8 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -93,7 +93,7 @@ def set_args(self, args): args, metric_port=args.metric_port, ) - elif args.run_mode == "visual_only": + elif args.run_mode == "visual": self.metric_client = MetricClient(args.metric_port) elif args.run_mode == "llm_only": init_tokenizer(args) # for openai api @@ -160,7 +160,7 @@ def get_model_name(): @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") async def healthcheck(request: Request): - if g_objs.args.run_mode in ["pd_master", "visual_only"]: + if g_objs.args.run_mode in ["pd_master", "visual"]: return JSONResponse({"message": "Ok"}, status_code=200) if os.environ.get("DEBUG_HEALTHCHECK_RETURN_FAIL") == "true": @@ -367,7 +367,7 @@ async def startup_event(): logger.info("server start up") loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) - if g_objs.args.run_mode != "visual_only": + if g_objs.args.run_mode != "visual": loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 0f8a440b9..c6700c041 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -5,13 +5,13 @@ torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess parser = make_argument_parser() args = parser.parse_args() - from .api_start import pd_master_start, normal_or_p_d_start, visual_only_start, config_server_start, llm_only_start + from .api_start import pd_master_start, normal_or_p_d_start, visual_start, config_server_start if args.run_mode == "pd_master": pd_master_start(args) elif args.run_mode == "config_server": config_server_start(args) - elif args.run_mode == "visual_only": - visual_only_start(args) + elif args.run_mode == "visual": + visual_start(args) else: normal_or_p_d_start(args) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 6b5a4ee4f..ddff59c4c 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -57,7 +57,8 @@ def signal_handler(sig, frame): signal.signal(signal.SIGINT, signal_handler) logger.info(f"start process pid {os.getpid()}") - logger.info(f"http server pid {http_server_process.pid}") + if http_server_process: + logger.info(f"http server pid {http_server_process.pid}") return @@ -72,7 +73,7 @@ def check_and_set_args(args): enable_mps() - if args.run_mode not in ["normal", "prefill", "decode", "llm_only", "visual_only"]: + if args.run_mode not in ["normal", "prefill", "decode", "llm_only", "visual"]: return assert args.zmq_mode in ["tcp://", "ipc:///tmp/"] @@ -420,11 +421,9 @@ def pd_master_start(args): http_server_process.wait() -def visual_only_start(args): +def visual_start(args): check_and_set_args(args) - if args.run_mode != "visual_only": - return - already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port] + already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, 
args.remote_vit_port] can_use_ports = alloc_can_use_network_port( num=5 + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports ) @@ -437,6 +436,7 @@ def visual_only_start(args): metric_port, ) = can_use_ports[0:5] can_use_ports = can_use_ports[5:] + print(cache_port) visual_model_tp_ports = [] for _ in range(args.visual_dp): @@ -456,13 +456,6 @@ def visual_only_start(args): set_env_start_args(args) - process_manager.start_submodule_processes( - start_funcs=[ - start_metric_manager, - ], - start_args=[(metric_port, args)], - ) - from .visualserver.manager import start_visual_process process_manager.start_submodule_processes( @@ -476,58 +469,18 @@ def visual_only_start(args): start_visual_process, ], start_args=[ - (args, audio_port, visual_port, cache_port, visual_model_tp_ports), + (args, router_port, visual_port, cache_port, visual_model_tp_ports), ], ) - if args.enable_multimodal_audio: - from .audioserver.manager import start_audio_process - - process_manager.start_submodule_processes( - start_funcs=[ - start_audio_process, - ], - start_args=[ - (args, router_port, audio_port, cache_port), - ], - ) - - # 启动 gunicorn - command = [ - "gunicorn", - "--workers", - f"{args.httpserver_workers}", - "--worker-class", - "uvicorn.workers.UvicornWorker", - "--bind", - f"{args.host}:{args.port}", - "--log-level", - "info", - "--access-logfile", - "-", - "--error-logfile", - "-", - "lightllm.server.api_http:app", - "--timeout", - f"{get_lightllm_gunicorn_time_out_seconds()}", - "--keep-alive", - f"{get_lightllm_gunicorn_keep_alive()}", - ] - - # 启动子进程 - http_server_process = subprocess.Popen(command) - - if "s3://" in args.model_dir: - from lightllm.utils.petrel_helper import s3_model_clear - - s3_model_clear(args.model_dir) - - if args.health_monitor: - from lightllm.server.health_monitor.manager import start_health_check_process - - process_manager.start_submodule_processes(start_funcs=[start_health_check_process], start_args=[(args,)]) - setup_signal_handlers(http_server_process, process_manager) - http_server_process.wait() - return + setup_signal_handlers(None, process_manager) + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + logger.info("Received keyboard interrupt, shutting down...") + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully.") + sys.exit(0) def config_server_start(args): diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 62b41221c..38397d4e5 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -161,30 +161,6 @@ def init( self.post_init() - def init_visual_only( - self, - request_id: int, - ): - # 只是为了有更好的编码辅助类型提示 - self.index_in_shm_mem: int = self.index_in_shm_mem - self.ref_count: int = self.ref_count - - self.request_id = request_id - self.group_req_id = convert_sub_id_to_group_id(request_id) - self.is_paused = False - self.finish_status = FinishStatus() - self.is_aborted = False - self.router_aborted = False - self.shm_infer_released = False - self.shm_cur_kv_len = 0 - self.shm_cur_output_len = 0 - self.candetoken_out_len = 0 - self.prompt_cache_len = 0 - self.finish_token_index = -1 - self.can_released_mark = False - - self.post_init() - def post_init(self): # 子类继承进行一些额外的初始化操作 pass diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index 9996245db..1351d6b28 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ 
b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -21,7 +21,7 @@ def __init__(self, args) -> None: self.redis_cache = EmbedRefCountRedis( redis_url=redis_url, capacity=args.cache_capacity, - evict_fraction=args.evict_fraction, + evict_fraction=args.redis_evict_fraction, image_embed_dir=args.image_embed_dir, ) # 这里之所以把cache * 2是因为,在分离模式下,cache 服务只是为了更新redis状态,以及维护图片cache的 token_id diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py index 344e0bf1a..421fc14e6 100644 --- a/lightllm/server/embed_cache/manager.py +++ b/lightllm/server/embed_cache/manager.py @@ -55,7 +55,7 @@ def exposed_get_items_embed(self, ids: list[int]) -> list[bool]: def get_cache_manager(args): - if args.enable_remote_vit: + if args.enable_remote_vit or args.run_mode == "visual": return MemoryCacheWithRedis(args) else: return InMemoryCache(args) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 85e7fbc8e..f607e3443 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -13,7 +13,7 @@ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) from typing import Union, List, Tuple, Dict, Optional -from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes, VisualOnlyReqIndexes +from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from fastapi import Request from ..tokenizer import get_tokenizer from ..pd_io_struct import NodeRole @@ -143,7 +143,7 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas): uid_list.append(rec["id"]) # If enable the vit/audio-llm disaggregation, no need to cache the data in the memory of the server - if self.args.run_mode == "llm_only": + if self.enable_remote_vit: return ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) @@ -304,41 +304,6 @@ async def _initialize_multimodal_metadata( img.uuid = int(md5sum, 16) img.token_num = token_num - # async def get_image_embeding( - # self, - # sampling_params: SamplingParams, - # multimodal_params: MultimodalParams, - # request: Request, - # is_health_req: bool = False, - # ) -> Tuple[int, str, dict, FinishStatus]: - - # request_headers = request.headers if request is not None else {} - # group_request_id = self.alloc_req_id(sampling_params, is_health_req) - - # try: - # await multimodal_params.verify_and_preload(request) - # image_count = len(multimodal_params.images) - # # 记录请求到达的相关信息 - # await self._log_req_header_for_visual_only(request_headers, group_request_id, image_count) - # assert ( - # len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity - # ), "too many multimodal items!" 
- - # await self._initialize_multimodal_metadata(multimodal_params, sampling_params) - - # visual_req_status = VisualOnlyReqIndexes(group_req_id=group_request_id, multimodal_params=multimodal_params) - - # self.send_to_visual.send_pyobj( - # visual_req_status, - # protocol=pickle.HIGHEST_PROTOCOL, - # ) - - # except Exception as e: - # logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") - # await self.abort(group_request_id, multimodal_params) - # raise e - # return - async def generate( self, prompt: Union[str, List[int]], diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index d846a5bdb..a0a91acac 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -10,7 +10,7 @@ import inspect from fastapi import Request from ..tokenizer import get_tokenizer -from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes, VisualOnlyReqIndexes +from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes from lightllm.server.core.objs import ShmReqManager from lightllm.server.core.objs import SamplingParams from lightllm.server.core.objs import Req, FinishStatus @@ -41,9 +41,8 @@ def __init__( visual_model_rpc_ports, ): self.args = args - self.remote_vit = args.enable_remote_vit + self.remote_vit = args.enable_remote_vit or args.run_mode == "visual" self.cache_port = cache_port - self.memory_cache = MemoryCacheWithRedis(args) self.waiting_reqs: List[GroupReqIndexes] = [] self.infer_batch_size = args.visual_infer_batch_size self.trust_remote_code = args.trust_remote_code @@ -53,8 +52,10 @@ def __init__( def _setup_connections(self): context = zmq.Context(2) if self.remote_vit: - self.recv_from_httpserver.bind(f"tcp://*:{self.args.remote_vit_port}") + self.recv_from_remote_llm = context.socket(zmq.PULL) + self.recv_from_remote_llm.bind(f"tcp://*:{self.args.remote_vit_port}") else: + self.recv_from_httpserver = context.socket(zmq.PULL) self.recv_from_httpserver.bind(f"{self.args.zmq_mode}127.0.0.1:{self.visual_port}") self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if --enable_multimodal_audio) self.send_to_next_module.connect(f"{self.args.zmq_mode}127.0.0.1:{self.next_module_port}") @@ -62,20 +63,22 @@ def _setup_connections(self): async def wait_to_model_ready(self): # 待完成,需要读取config_server来起多个vit - self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] + visual_dp = self.args.visual_dp + visual_tp = self.args.visual_tp + self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(visual_dp)] - for dp_rank_id in range(self.args.visual_dp): + for dp_rank_id in range(visual_dp): tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id] - for tp_rank_id in range(self.args.visual_tp): - device_id = self.args.visual_gpu_ids[dp_rank_id * self.args.visual_tp + tp_rank_id] + for tp_rank_id in range(visual_tp): + device_id = self.args.visual_gpu_ids[dp_rank_id * visual_tp + tp_rank_id] rpc_model = await start_model_process( - port=tp_ports_each_dp[tp_rank_id], vit_tp=self.args.visual_tp, device_id=device_id + port=tp_ports_each_dp[tp_rank_id], vit_tp=visual_tp, device_id=device_id ) self.model_rpcs[dp_rank_id].append(rpc_model) init_model_ret = [] - for dp_rank_id in range(self.args.visual_dp): # async init model process - for tp_rank_id in range(self.args.visual_tp): + for dp_rank_id in range(visual_dp): # async init model process + for tp_rank_id in range(visual_tp): kvargs = { "tp_rank_id": tp_rank_id, 
"dp_rank_id": dp_rank_id, diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 8fd22d5ad..5b5bb20ea 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -40,20 +40,18 @@ def exposed_init_model(self, kvargs): self.args = get_env_start_args() - weight_dir = (self.args.model_dir,) - cache_port = (self.args.cache_port,) - data_type = (self.args.data_type,) - quant_type = (self.args.vit_quant_type,) - quant_cfg = (self.args.vit_quant_cfg,) - max_batch_size = (min(self.args.visual_infer_batch_size // self.args.visual_dp, 1),) + weight_dir = self.args.model_dir + cache_port = self.args.cache_port + data_type = self.args.data_type + quant_type = self.args.vit_quant_type + quant_cfg = self.args.vit_quant_cfg + max_batch_size = min(self.args.visual_infer_batch_size // self.args.visual_dp, 1) self.dp_rank_id = kvargs["dp_rank_id"] self.tp_rank_id = kvargs["tp_rank_id"] kvargs["vit_rank_id"] = self.dp_rank_id * self.args.visual_tp + self.tp_rank_id - - if self.args.run_mode != "visual_only": - self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) - self.visual_only = True if self.args.run_mode == "visual_only" else False + print(cache_port) + self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) init_vision_distributed_env(kvargs) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) From 630d3eee855864a4b7091f2d1c579e0ec3feda38 Mon Sep 17 00:00:00 2001 From: baishihao Date: Thu, 28 Aug 2025 13:11:48 +0800 Subject: [PATCH 13/40] rename --- lightllm/server/httpserver/manager.py | 11 +- lightllm/server/httpserver/vit_connect.py | 136 ++++++++++++++++++++++ 2 files changed, 137 insertions(+), 10 deletions(-) create mode 100644 lightllm/server/httpserver/vit_connect.py diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index f607e3443..d06ec357d 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -85,18 +85,9 @@ def __init__( if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) # 初始化VIT连接管理器 - from .vit_loop import VITConnectionManager + from .vit_connect import VITConnectionManager self.vit_manager = VITConnectionManager(args, context, visual_port) - # self.send_to_visual = context.socket(zmq.PUSH) - # if self.args.run_mode == "llm_only": - # self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{self.args.visual_only_port}") - # else: - # self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}") - # self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) - - self.token_id_range_start = 100000000 - self.token_id_range_end = 2 ** 63 - 1 self.shm_req_manager = ShmReqManager() diff --git a/lightllm/server/httpserver/vit_connect.py b/lightllm/server/httpserver/vit_connect.py new file mode 100644 index 000000000..d1ae69b5c --- /dev/null +++ b/lightllm/server/httpserver/vit_connect.py @@ -0,0 +1,136 @@ +import asyncio +import zmq +import zmq.asyncio +import time +import pickle +from typing import Dict, List, Optional, Any +from lightllm.utils.log_utils import init_logger +import httpx +import base64 +from dataclasses import dataclass + +logger = init_logger(__name__) + + +@dataclass +class VIT_Obj: + node_id: int + host_ip_port: str + + def to_log_str(self): + return f"VIT host_ip_port: 
{self.host_ip_port} node_id: {self.node_id}" + + +class VITConnectionManager: + """VIT连接管理器""" + + def __init__(self, args, context, local_visual_port: int): + self.args = args + self.context = context + self.local_visual_port = local_visual_port + + self.send_to_visual = None + self.remote_vit_instances = [] + self.current_vit_index = 0 + self.remote_vit = args.enable_remote_vit + self.remote_vit_port = args.remote_vit_port + + self._setup_vit_connections() + + def _setup_vit_connections(self): + """ + 设置VIT连接,支持本地和远程VIT实例 + 支持多种连接模式: + 1. 本地VIT实例 (默认) + 2. 远程单个VIT实例 + 3. 远程多个VIT实例 (负载均衡) + """ + if self.remote_vit: + # 远程VIT实例模式 + self._setup_remote_vit_connections() + else: + self._setup_local_vit_connection() + + def _setup_local_vit_connection(self): + self.send_to_visual = self.context.socket(zmq.PUSH) + self.send_to_visual.connect(f"{self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") + logger.info(f"Connected to local VIT instance at {self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") + + def _setup_remote_vit_connections(self): + asyncio.create_task(self.vit_handle_loop()) + + # wait for remote vit instances + while True: + if len(self.remote_vit_instances) > 0: + break + time.sleep(1) + + def _get_vit_instance(self): + """ + 获取下一个可用的VIT实例 (轮询负载均衡) + """ + if not self.remote_vit: + return self.send_to_visual + + # 简单的轮询负载均衡 + index = (self.current_vit_index + 1) % len(self.remote_vit_instances) + self.current_vit_index = index + return self.remote_vit_instances[index] + + async def send_to_vit(self, data, protocol=pickle.HIGHEST_PROTOCOL): + """ + 发送数据到VIT实例,支持本地和远程模式 + """ + instance = self._get_vit_instance() + try: + instance.send_pyobj(data, protocol=protocol) + except Exception as e: + logger.error(f"Failed to send to VIT instance {instance.host_ip_port}: {e}") + raise Exception(f"Failed to send to VIT instance {instance.host_ip_port}: {e}") + + async def vit_handle_loop(self): + while True: + try: + id_to_vit_obj = await self._get_vit_objs() + logger.info(f"get vit_objs {id_to_vit_obj}") + for id, remote_instance in self.remote_vit_instances.items(): + if id not in id_to_vit_obj: + try: + remote_instance[id].close() + except: + pass + self.remote_vit_instances.pop(id) + logger.info(f"remote vit {id} closed") + + for id, vit_obj in id_to_vit_obj.items(): + if id not in self.remote_vit_instances: + self.remote_vit_instances[id] = self.context.socket(zmq.PUSH) + self.remote_vit_instances[id].connect( + f"tcp://{vit_obj.host_ip_port}:{self.args.remote_vit_port}" + ) + await asyncio.sleep(30) + except Exception as e: + logger.exception(str(e)) + await asyncio.sleep(10) + + async def _get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: + """ + get_vit_objs 主要负责从 config_server 获取所有的vit远程服务。 + """ + # 使用 config_server 服务来发现所有的 pd_master 节点。 + uri = f"ws://{self.args.config_server_host}:{self.args.config_server_port}/registered_vit" + + try: + async with httpx.AsyncClient() as client: + response = await client.get(uri) + if response.status_code == 200: + base64data = response.json()["data"] + id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) + return id_to_vit_obj + else: + logger.error(f"get pd_master_objs error {response.status_code}") + return None + except Exception as e: + logger.exception(str(e)) + await asyncio.sleep(10) + return None From 44070407f17008d2968ad16420ebfb3e986d4a61 Mon Sep 17 00:00:00 2001 From: baishihao Date: Thu, 28 Aug 2025 16:08:18 +0800 Subject: [PATCH 14/40] add vit register loop --- lightllm/server/api_start.py | 3 + 
lightllm/server/httpserver/manager.py | 2 +- lightllm/server/httpserver/vit_loop.py | 136 ------------------ lightllm/server/visualserver/manager.py | 13 +- lightllm/server/visualserver/register_loop.py | 42 ++++++ .../vit_connect.py | 0 6 files changed, 58 insertions(+), 138 deletions(-) delete mode 100644 lightllm/server/httpserver/vit_loop.py create mode 100644 lightllm/server/visualserver/register_loop.py rename lightllm/server/{httpserver => visualserver}/vit_connect.py (100%) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index ddff59c4c..998ffc6a2 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -452,6 +452,9 @@ def visual_start(args): args.metric_port = metric_port args.visual_model_rpc_ports = visual_model_tp_ports + # 远程vit server 需要一个唯一的id + args.visual_node_id = uuid.uuid4().int + logger.info(f"all start args:{args}") set_env_start_args(args) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index d06ec357d..47ac71509 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -85,7 +85,7 @@ def __init__( if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) # 初始化VIT连接管理器 - from .vit_connect import VITConnectionManager + from lightllm.server.visualserver.vit_connect import VITConnectionManager self.vit_manager = VITConnectionManager(args, context, visual_port) diff --git a/lightllm/server/httpserver/vit_loop.py b/lightllm/server/httpserver/vit_loop.py deleted file mode 100644 index d1ae69b5c..000000000 --- a/lightllm/server/httpserver/vit_loop.py +++ /dev/null @@ -1,136 +0,0 @@ -import asyncio -import zmq -import zmq.asyncio -import time -import pickle -from typing import Dict, List, Optional, Any -from lightllm.utils.log_utils import init_logger -import httpx -import base64 -from dataclasses import dataclass - -logger = init_logger(__name__) - - -@dataclass -class VIT_Obj: - node_id: int - host_ip_port: str - - def to_log_str(self): - return f"VIT host_ip_port: {self.host_ip_port} node_id: {self.node_id}" - - -class VITConnectionManager: - """VIT连接管理器""" - - def __init__(self, args, context, local_visual_port: int): - self.args = args - self.context = context - self.local_visual_port = local_visual_port - - self.send_to_visual = None - self.remote_vit_instances = [] - self.current_vit_index = 0 - self.remote_vit = args.enable_remote_vit - self.remote_vit_port = args.remote_vit_port - - self._setup_vit_connections() - - def _setup_vit_connections(self): - """ - 设置VIT连接,支持本地和远程VIT实例 - 支持多种连接模式: - 1. 本地VIT实例 (默认) - 2. 远程单个VIT实例 - 3. 
远程多个VIT实例 (负载均衡) - """ - if self.remote_vit: - # 远程VIT实例模式 - self._setup_remote_vit_connections() - else: - self._setup_local_vit_connection() - - def _setup_local_vit_connection(self): - self.send_to_visual = self.context.socket(zmq.PUSH) - self.send_to_visual.connect(f"{self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") - logger.info(f"Connected to local VIT instance at {self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") - - def _setup_remote_vit_connections(self): - asyncio.create_task(self.vit_handle_loop()) - - # wait for remote vit instances - while True: - if len(self.remote_vit_instances) > 0: - break - time.sleep(1) - - def _get_vit_instance(self): - """ - 获取下一个可用的VIT实例 (轮询负载均衡) - """ - if not self.remote_vit: - return self.send_to_visual - - # 简单的轮询负载均衡 - index = (self.current_vit_index + 1) % len(self.remote_vit_instances) - self.current_vit_index = index - return self.remote_vit_instances[index] - - async def send_to_vit(self, data, protocol=pickle.HIGHEST_PROTOCOL): - """ - 发送数据到VIT实例,支持本地和远程模式 - """ - instance = self._get_vit_instance() - try: - instance.send_pyobj(data, protocol=protocol) - except Exception as e: - logger.error(f"Failed to send to VIT instance {instance.host_ip_port}: {e}") - raise Exception(f"Failed to send to VIT instance {instance.host_ip_port}: {e}") - - async def vit_handle_loop(self): - while True: - try: - id_to_vit_obj = await self._get_vit_objs() - logger.info(f"get vit_objs {id_to_vit_obj}") - for id, remote_instance in self.remote_vit_instances.items(): - if id not in id_to_vit_obj: - try: - remote_instance[id].close() - except: - pass - self.remote_vit_instances.pop(id) - logger.info(f"remote vit {id} closed") - - for id, vit_obj in id_to_vit_obj.items(): - if id not in self.remote_vit_instances: - self.remote_vit_instances[id] = self.context.socket(zmq.PUSH) - self.remote_vit_instances[id].connect( - f"tcp://{vit_obj.host_ip_port}:{self.args.remote_vit_port}" - ) - await asyncio.sleep(30) - except Exception as e: - logger.exception(str(e)) - await asyncio.sleep(10) - - async def _get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: - """ - get_vit_objs 主要负责从 config_server 获取所有的vit远程服务。 - """ - # 使用 config_server 服务来发现所有的 pd_master 节点。 - uri = f"ws://{self.args.config_server_host}:{self.args.config_server_port}/registered_vit" - - try: - async with httpx.AsyncClient() as client: - response = await client.get(uri) - if response.status_code == 200: - base64data = response.json()["data"] - id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) - return id_to_vit_obj - else: - logger.error(f"get pd_master_objs error {response.status_code}") - return None - except Exception as e: - logger.exception(str(e)) - await asyncio.sleep(10) - return None diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index a0a91acac..fc1c1df3e 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -218,6 +218,17 @@ def clean_up(self): return +def create_forward_loop(args, visualserver: VisualManager, loop: asyncio.AbstractEventLoop): + if args.run_mode == "visual": + from .register_loop import register_loop + + loop.create_task(visualserver.loop_for_fwd_visual_only()) + loop.create_task(register_loop(args)) + else: + loop.create_task(visualserver.loop_for_fwd()) + return + + def start_visual_process(args, next_module_port, visual_port, cache_port, model_rpc_ports, pipe_writer): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) @@ -238,7 +249,7 
@@ def handle_exception(loop, context): loop = asyncio.new_event_loop() loop.set_exception_handler(handle_exception) asyncio.set_event_loop(loop) - if args.run_mode == "visual_only": + if args.run_mode == "visual": loop.create_task(visualserver.loop_for_fwd_visual_only()) else: loop.create_task(visualserver.loop_for_fwd()) diff --git a/lightllm/server/visualserver/register_loop.py b/lightllm/server/visualserver/register_loop.py new file mode 100644 index 000000000..41c749749 --- /dev/null +++ b/lightllm/server/visualserver/register_loop.py @@ -0,0 +1,42 @@ +import asyncio +import pickle +import websockets +import socket +from lightllm.utils.net_utils import get_hostname_ip +from lightllm.utils.log_utils import init_logger +from .vit_connect import VIT_Obj + +logger = init_logger(__name__) + + +async def register_loop(args): + assert args.host not in ["127.0.0.1", "localhost"], "remote visual server must specify host ip" + + if args.host in ["0.0.0.0"]: + host_ip = get_hostname_ip() + else: + host_ip = args.host + + while True: + + try: + uri = f"ws://{args.config_server_host}:{args.config_server_port}/visual_server_register" + async with websockets.connect(uri, max_queue=(2048 * 1024, 2048 * 1023)) as websocket: + + sock = websocket.transport.get_extra_info("socket") + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + + vit_obj = VIT_Obj(node_id=args.visual_node_id, host_ip_port=f"{host_ip}:{args.port}") + + await websocket.send(pickle.dumps(vit_obj)) + logger.info(f"Sent registration vit_obj: {vit_obj}") + + while True: + await websocket.send("heartbeat") + await asyncio.sleep(60) + + except Exception as e: + logger.error("connetion to config_server has error") + logger.exception(str(e)) + await asyncio.sleep(10) + logger.info("reconnection to config_server") diff --git a/lightllm/server/httpserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py similarity index 100% rename from lightllm/server/httpserver/vit_connect.py rename to lightllm/server/visualserver/vit_connect.py From a566580abb549995fff5ef38fbf05e05bf8b135a Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 28 Aug 2025 08:11:29 +0000 Subject: [PATCH 15/40] [0828]temp --- lightllm/models/internvl/model.py | 12 ++++++------ lightllm/server/httpserver/manager.py | 6 ++---- lightllm/server/visualserver/manager.py | 18 +++++++++--------- 3 files changed, 17 insertions(+), 19 deletions(-) diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py index a724d5668..2f08c39bf 100644 --- a/lightllm/models/internvl/model.py +++ b/lightllm/models/internvl/model.py @@ -71,14 +71,14 @@ def init_audioitem_extral_params( ): return - def get_image_token_length(self, img: ImageItem): - return ( - self.get_image_patch_func( - img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True - ) - * self.image_length + def get_image_patch(self, img: ImageItem): + return self.get_image_patch_func( + img.image_w, img.image_h, max_num=img.extra_params["image_patch_max_num"], use_thumbnail=True ) + def get_image_token_length(self, img: ImageItem): + return self.get_image_patch(img) * self.image_length + def get_audio_token_length(self, audio: AudioItem): L = audio.audio_length L = L if L <= 480000 else 480000 # max_length < 30s diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 47ac71509..f3d7d2b28 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -159,13 +159,11 @@ async def 
_alloc_multimodal_resources(self, multimodal_params: MultimodalParams, items, md5sums, tokens_nums, datas = [], [], [], [] for img in multimodal_params.images: self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) + patch_num = self.tokenizer.get_image_patch(img) data = img.read() # must after init_imageitem_extral_params token_num = self.tokenizer.get_image_token_length(img) - md5sum = "{}_{}".format( - hashlib.md5(data).hexdigest(), - hashlib.md5(pickle.dumps(img.extra_params, protocol=4)).hexdigest(), - ) + md5sum = "{}_{}".format(hashlib.md5(data).hexdigest(), patch_num) md5sums.append(md5sum) tokens_nums.append(token_num) datas.append(data) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index fc1c1df3e..ebf011418 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -47,6 +47,7 @@ def __init__( self.infer_batch_size = args.visual_infer_batch_size self.trust_remote_code = args.trust_remote_code self.visual_model_rpc_ports = visual_model_rpc_ports + self.shm_req_manager = ShmReqManager() self._setup_connections() def _setup_connections(self): @@ -62,7 +63,6 @@ def _setup_connections(self): self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) async def wait_to_model_ready(self): - # 待完成,需要读取config_server来起多个vit visual_dp = self.args.visual_dp visual_tp = self.args.visual_tp self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(visual_dp)] @@ -155,13 +155,13 @@ def _recv_reqs(self): if self.remote_vit: recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) for img in recv_req.multimodal_params.images: + image_patch = self.tokenizer.get_image_patch_func(img) data = img._preload_data - img._preload_data = None - md5sum = hashlib.md5(data).hexdigest() - uid = int(md5sum, 16) + # img._preload_data = None + md5sum = "{}_{}".format(hashlib.md5(data).hexdigest(), image_patch) + md5 = int(md5sum, 16) # create_shm(get_shm_name_data(uid), data) - self.cache_client.root.set_items_data([uid]) - + self.cache_client.root.set_items_data([md5]) return recv_req else: return self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) @@ -187,13 +187,13 @@ async def loop_for_netio_req(self): # code for visual only mode async def loop_for_fwd_visual_only(self): while True: - if len(self.waiting_reqs_visual_only) == 0: + if len(self.waiting_reqs) == 0: await asyncio.sleep(0.01) # 10ms else: images_need_infer = [] - while len(self.waiting_reqs_visual_only) > 0: - visual_req = self.waiting_reqs_visual_only.pop(0) + while len(self.waiting_reqs) > 0: + visual_req = self.waiting_reqs.pop(0) for img in visual_req.multimodal_params.images: if img.is_abort: From 1ae9cd308f34b38ce67d7428b3cf25e67fab138c Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 28 Aug 2025 10:48:04 +0000 Subject: [PATCH 16/40] [0828]temp --- lightllm/models/internvl/model.py | 1 + .../qwen_vl/layer_infer/pre_layer_infer.py | 8 ++--- .../impl/memory_cache_with_redis.py | 8 +++-- lightllm/server/httpserver/manager.py | 9 +++-- lightllm/server/multimodal_params.py | 7 ++-- lightllm/server/visualserver/manager.py | 35 ++++++++----------- 6 files changed, 31 insertions(+), 37 deletions(-) diff --git a/lightllm/models/internvl/model.py b/lightllm/models/internvl/model.py index 2f08c39bf..d1d269436 100644 --- a/lightllm/models/internvl/model.py +++ b/lightllm/models/internvl/model.py @@ -64,6 +64,7 @@ def init_imageitem_extral_params( 
img.extra_params["image_patch_max_num"] = 6 elif num_images > 6: img.extra_params["image_patch_max_num"] = 0 + img.patch_num = self.get_image_patch(img) return def init_audioitem_extral_params( diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index dbf2c8ea0..916c75637 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -53,11 +53,11 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei if img["token_id"] in img_start_token_ids or img["_prefill_"] is False: continue # pull the img_embeds by uid from shm or afs - if self.args.run_mode == "llm_only": - data = read_afs(get_shm_name_embed(img["uuid"])) + if self.args.enable_remote_vit: + embed = read_afs(get_shm_name_embed(img["uuid"])) else: - data = read_shm(get_shm_name_embed(img["uuid"])) - img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1)) + embed = read_shm(get_shm_name_embed(img["uuid"])) + img_weight.append(bytes2tensor(embed).cuda().reshape(img["token_num"], -1)) img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) img_start_locs.append(img_start_loc) diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index 1351d6b28..b5dc7a25e 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -36,13 +36,15 @@ def release(self, ids: list[int]) -> None: self.redis_cache.decr(id_) def set_items_data(self, ids: list[int]) -> None: - pass + for id_ in ids: + self._records[id_].data = True def get_items_data(self, ids: list[int]) -> list[Optional[bool]]: return [self._records.get(id_).data if id_ in self._records else False for id_ in ids] def set_items_embed(self, ids: list[int]) -> None: - pass + for id in ids: + self.redis_cache.insert(id) def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: - pass + return [self.redis_cache.query_and_incre(id) for id in ids] diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index f3d7d2b28..993a342f5 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -133,9 +133,9 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas): item.token_num = rec["token_num"] uid_list.append(rec["id"]) - # If enable the vit/audio-llm disaggregation, no need to cache the data in the memory of the server - if self.enable_remote_vit: - return + # # If enable the vit/audio-llm disaggregation, no need to cache the data in the memory of the server + # if self.enable_remote_vit: + # return ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) update_data_ids = [] @@ -159,11 +159,10 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, items, md5sums, tokens_nums, datas = [], [], [], [] for img in multimodal_params.images: self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) - patch_num = self.tokenizer.get_image_patch(img) data = img.read() # must after init_imageitem_extral_params token_num = self.tokenizer.get_image_token_length(img) - md5sum = "{}_{}".format(hashlib.md5(data).hexdigest(), patch_num) + md5sum = "{}_{}".format(hashlib.md5(data).hexdigest(), img.patch_num) md5sums.append(md5sum) tokens_nums.append(token_num) datas.append(data) diff 
--git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 6c06082c8..282673eaf 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -78,8 +78,7 @@ def __init__(self, **kwargs): self.token_num = None self.image_w = 0 self.image_h = 0 - self.afs_embed = False - self.is_abort = False + self.patch_num = 0 self._preload_data = None self.extra_params = {} @@ -114,8 +113,8 @@ async def preload(self, request: Request): def read(self): assert self._preload_data is not None ans = self._preload_data - # self._preload_data = None - # self._data = None + self._preload_data = None + self._data = None return ans def to_dict(self): diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index ebf011418..c0a5a3b6a 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -43,6 +43,8 @@ def __init__( self.args = args self.remote_vit = args.enable_remote_vit or args.run_mode == "visual" self.cache_port = cache_port + self.visual_port = visual_port + self.next_module_port = next_module_port self.waiting_reqs: List[GroupReqIndexes] = [] self.infer_batch_size = args.visual_infer_batch_size self.trust_remote_code = args.trust_remote_code @@ -151,20 +153,16 @@ async def loop_for_fwd(self): processing_group_reqs = [] images_need_infer = [] - def _recv_reqs(self): - if self.remote_vit: - recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) - for img in recv_req.multimodal_params.images: - image_patch = self.tokenizer.get_image_patch_func(img) - data = img._preload_data - # img._preload_data = None - md5sum = "{}_{}".format(hashlib.md5(data).hexdigest(), image_patch) - md5 = int(md5sum, 16) - # create_shm(get_shm_name_data(uid), data) - self.cache_client.root.set_items_data([md5]) - return recv_req - else: - return self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) + # def _recv_reqs(self): + # if self.remote_vit: + # recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) + # recv_req.multimodal_params.images[:]= [ + # img for img in recv_req.multimodal_params.images + # if not self.cache_client.root.get_item_embed(img.uuid) # embed已存在的被丢弃 , ref +1 + # ] + # return recv_req + # else: + # return self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) async def loop_for_netio_req(self): if not hasattr(self, "visual_recv_max_count"): @@ -173,7 +171,7 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self._recv_reqs() + recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GroupReqIndexes): self.waiting_reqs.append(recv_req) else: @@ -196,8 +194,6 @@ async def loop_for_fwd_visual_only(self): visual_req = self.waiting_reqs.pop(0) for img in visual_req.multimodal_params.images: - if img.is_abort: - continue images_need_infer.append(img) if len(images_need_infer) == self.infer_batch_size: @@ -249,9 +245,6 @@ def handle_exception(loop, context): loop = asyncio.new_event_loop() loop.set_exception_handler(handle_exception) asyncio.set_event_loop(loop) - if args.run_mode == "visual": - loop.create_task(visualserver.loop_for_fwd_visual_only()) - else: - loop.create_task(visualserver.loop_for_fwd()) + create_forward_loop(visualserver, loop) loop.run_until_complete(visualserver.loop_for_netio_req()) return From c99bb462875e0d1ad08946428f27f6a018c161f9 Mon Sep 17 00:00:00 2001 From: baishihao Date: Thu, 28 Aug 
2025 19:21:03 +0800 Subject: [PATCH 17/40] fix vit manager --- lightllm/server/api_cli.py | 2 +- lightllm/server/api_http.py | 20 +++-------------- lightllm/server/api_start.py | 10 +++++---- lightllm/server/config_server/api_http.py | 17 +++++++------- lightllm/server/httpserver/manager.py | 3 +++ lightllm/server/pd_io_struct.py | 9 -------- lightllm/server/visualserver/manager.py | 22 +++++++++---------- lightllm/server/visualserver/register_loop.py | 4 ++-- lightllm/server/visualserver/vit_connect.py | 6 +++-- 9 files changed, 39 insertions(+), 54 deletions(-) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 6195dac92..00cc66a7b 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -353,7 +353,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--visual_nccl_ports", nargs="+", type=int, - default=[29500], + default=None, help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502", ) parser.add_argument( diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 8863dd9f8..b9d6119c0 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -95,21 +95,6 @@ def set_args(self, args): ) elif args.run_mode == "visual": self.metric_client = MetricClient(args.metric_port) - elif args.run_mode == "llm_only": - init_tokenizer(args) # for openai api - SamplingParams.load_generation_cfg(args.model_dir) - self.metric_client = MetricClient(args.metric_port) - self.httpserver_manager = HttpServerManager( - args, - router_port=args.router_port, - cache_port=None, - detokenization_pub_port=args.detokenization_pub_port, - visual_port=None, - enable_multimodal=args.enable_multimodal, - metric_port=args.metric_port, - ) - dp_size_in_node = max(1, args.dp // args.nnodes) # 兼容多机纯tp的运行模式,这时候 1 // 2 == 0, 需要兼容 - self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node) else: init_tokenizer(args) # for openai api SamplingParams.load_generation_cfg(args.model_dir) @@ -365,9 +350,10 @@ async def shutdown(): @app.on_event("startup") async def startup_event(): logger.info("server start up") + if g_objs.httpserver_manager is None: + return loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) - if g_objs.args.run_mode != "visual": - loop.create_task(g_objs.httpserver_manager.handle_loop()) + loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 998ffc6a2..65d6d2392 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -208,9 +208,9 @@ def check_and_set_args(args): def normal_or_p_d_start(args): check_and_set_args(args) - already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port] + already_uesd_ports = [args.nccl_port, args.port] if args.run_mode == "decode": - already_uesd_ports = args.visual_nccl_ports + [args.nccl_port, args.port, args.pd_decode_rpyc_port] + already_uesd_ports = [args.nccl_port, args.port, args.pd_decode_rpyc_port] # 提前锁定端口,防止在单个机器上启动多个实列的时候,要到模型启动的时候才能 # 捕获到端口设置冲突的问题 @@ -219,7 +219,7 @@ def normal_or_p_d_start(args): node_world_size = args.tp // args.nnodes can_use_ports = alloc_can_use_network_port( - num=7 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports + num=7 + node_world_size + args.visual_dp * args.visual_tp + args.visual_dp, used_nccl_ports=already_uesd_ports ) 
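# Illustrative sketch (not part of the patch): the port budget implied by the
# alloc_can_use_network_port call above, using the same argument names. The
# trailing `args.visual_dp` ports are the ones assigned to
# `args.visual_nccl_ports` just below, replacing the old fixed default.
def visual_port_budget(node_world_size: int, visual_dp: int, visual_tp: int) -> int:
    # 7 fixed service ports
    # + one port per local tp rank
    # + visual_tp model ports per visual dp group
    # + one nccl port per visual dp group
    return 7 + node_world_size + visual_dp * visual_tp + visual_dp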
logger.info(f"alloced ports: {can_use_ports}") ( @@ -239,6 +239,9 @@ def normal_or_p_d_start(args): can_use_ports = can_use_ports[args.visual_tp :] visual_model_tp_ports.append(tp_ports_for_dp) + args.visual_nccl_ports = can_use_ports[0 : args.visual_dp] + can_use_ports = can_use_ports[args.visual_dp :] + # 将申请好的端口放入args参数中 args.router_port = router_port args.detokenization_port = detokenization_port @@ -436,7 +439,6 @@ def visual_start(args): metric_port, ) = can_use_ports[0:5] can_use_ports = can_use_ports[5:] - print(cache_port) visual_model_tp_ports = [] for _ in range(args.visual_dp): diff --git a/lightllm/server/config_server/api_http.py b/lightllm/server/config_server/api_http.py index 56645f47f..3d95aa5f0 100644 --- a/lightllm/server/config_server/api_http.py +++ b/lightllm/server/config_server/api_http.py @@ -8,7 +8,8 @@ from typing import Dict, List from fastapi.responses import JSONResponse from lightllm.utils.log_utils import init_logger -from ..pd_io_struct import PD_Master_Obj, Visual_Server_Obj +from lightllm.server.visualserver.vit_connect import VIT_Obj +from ..pd_io_struct import PD_Master_Obj from .nccl_tcp_store import start_tcp_store_server from lightllm.utils.envs_utils import get_env_start_args from lightllm.utils.process_check import start_parent_check_thread @@ -18,7 +19,7 @@ app = FastAPI() registered_pd_master_objs: Dict[str, PD_Master_Obj] = {} -registered_visual_server_obj: Dict[str, Visual_Server_Obj] = {} +registered_visual_server_objs: Dict[str, VIT_Obj] = {} registered_pd_master_obj_lock = Lock() registered_visual_server_obj_lock = Lock() @@ -73,15 +74,15 @@ async def websocket_endpoint(websocket: WebSocket): return -@app.websocket("/visual_server_register") +@app.websocket("/visual_register") async def visual_websocket_endpoint(websocket: WebSocket): await websocket.accept() client_ip, client_port = websocket.client logger.info(f"ws connected from IP: {client_ip}, Port: {client_port}") - registered_visual_server_obj: Visual_Server_Obj = pickle.loads(await websocket.receive_bytes()) + registered_visual_server_obj: VIT_Obj = pickle.loads(await websocket.receive_bytes()) logger.info(f"recieved registered_visual_server_obj {registered_visual_server_obj}") with registered_visual_server_obj_lock: - registered_visual_server_obj_lock[registered_visual_server_obj.node_id] = registered_visual_server_obj + registered_visual_server_objs[registered_visual_server_obj.node_id] = registered_visual_server_obj try: while True: @@ -93,7 +94,7 @@ async def visual_websocket_endpoint(websocket: WebSocket): finally: logger.error(f"registered_visual_server_obj {registered_visual_server_obj} removed") with registered_visual_server_obj_lock: - registered_visual_server_obj.pop(registered_visual_server_obj.node_id, None) + registered_visual_server_objs.pop(registered_visual_server_obj.node_id, None) return @@ -105,10 +106,10 @@ async def get_registered_objects(): return {"data": base64_encoded} -@app.get("/registered_visual_server_objects") +@app.get("/registered_visual_objects") async def get_vit_registered_objects(): with registered_visual_server_obj_lock: - serialized_data = pickle.dumps(registered_visual_server_obj) + serialized_data = pickle.dumps(registered_visual_server_objs) base64_encoded = base64.b64encode(serialized_data).decode("utf-8") return {"data": base64_encoded} diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index f3d7d2b28..777b7b322 100644 --- a/lightllm/server/httpserver/manager.py +++ 
b/lightllm/server/httpserver/manager.py @@ -697,6 +697,9 @@ async def handle_loop(self): asyncio.create_task(pd_handle_loop(self)) + if self.enable_multimodal: + asyncio.create_task(self.vit_manager.vit_handle_loop()) + while True: try: await asyncio.wait_for(self.recv_from_detokenization.recv_pyobj(), timeout=0.05) diff --git a/lightllm/server/pd_io_struct.py b/lightllm/server/pd_io_struct.py index 351ee36fa..938c2a06a 100644 --- a/lightllm/server/pd_io_struct.py +++ b/lightllm/server/pd_io_struct.py @@ -73,15 +73,6 @@ def to_log_str(self): return f"PD_MASTER host_ip_port: {self.host_ip_port} node_id: {self.node_id}" -@dataclass -class Visual_Server_Obj: - node_id: int - host_ip_port: str - - def to_log_str(self): - return f"Visual_Server host_ip_port: {self.host_ip_port} node_id: {self.node_id}" - - @dataclass class UpKVStatus: type: str = "kv_move_status" diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index ebf011418..6d9258c52 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -53,11 +53,11 @@ def __init__( def _setup_connections(self): context = zmq.Context(2) if self.remote_vit: - self.recv_from_remote_llm = context.socket(zmq.PULL) - self.recv_from_remote_llm.bind(f"tcp://*:{self.args.remote_vit_port}") + self.vit_receiver = context.socket(zmq.PULL) + self.vit_receiver.bind(f"tcp://*:{self.args.remote_vit_port}") else: - self.recv_from_httpserver = context.socket(zmq.PULL) - self.recv_from_httpserver.bind(f"{self.args.zmq_mode}127.0.0.1:{self.visual_port}") + self.vit_receiver = context.socket(zmq.PULL) + self.vit_receiver.bind(f"{self.args.zmq_mode}127.0.0.1:{self.visual_port}") self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if --enable_multimodal_audio) self.send_to_next_module.connect(f"{self.args.zmq_mode}127.0.0.1:{self.next_module_port}") self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) @@ -153,7 +153,7 @@ async def loop_for_fwd(self): def _recv_reqs(self): if self.remote_vit: - recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) + recv_req: GroupReqIndexes = self.vit_receiver.recv_pyobj(zmq.NOBLOCK) for img in recv_req.multimodal_params.images: image_patch = self.tokenizer.get_image_patch_func(img) data = img._preload_data @@ -164,7 +164,7 @@ def _recv_reqs(self): self.cache_client.root.set_items_data([md5]) return recv_req else: - return self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) + return self.vit_receiver.recv_pyobj(zmq.NOBLOCK) async def loop_for_netio_req(self): if not hasattr(self, "visual_recv_max_count"): @@ -173,7 +173,7 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self._recv_reqs() + recv_req: GroupReqIndexes = self.vit_receiver.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GroupReqIndexes): self.waiting_reqs.append(recv_req) else: @@ -182,6 +182,9 @@ async def loop_for_netio_req(self): except zmq.ZMQError: # 当队列已经开始清空的时候,将一次接受数量下调 self.visual_recv_max_count = 64 + except Exception as e: + logger.exception(f"Error in loop_for_netio_req: {e}") + raise e await asyncio.sleep(0.01) # code for visual only mode @@ -249,9 +252,6 @@ def handle_exception(loop, context): loop = asyncio.new_event_loop() loop.set_exception_handler(handle_exception) asyncio.set_event_loop(loop) - if args.run_mode == "visual": - loop.create_task(visualserver.loop_for_fwd_visual_only()) - else: - 
loop.create_task(visualserver.loop_for_fwd()) + create_forward_loop(args, visualserver, loop) loop.run_until_complete(visualserver.loop_for_netio_req()) return diff --git a/lightllm/server/visualserver/register_loop.py b/lightllm/server/visualserver/register_loop.py index 41c749749..95a1dc069 100644 --- a/lightllm/server/visualserver/register_loop.py +++ b/lightllm/server/visualserver/register_loop.py @@ -20,7 +20,7 @@ async def register_loop(args): while True: try: - uri = f"ws://{args.config_server_host}:{args.config_server_port}/visual_server_register" + uri = f"ws://{args.config_server_host}:{args.config_server_port}/visual_register" async with websockets.connect(uri, max_queue=(2048 * 1024, 2048 * 1023)) as websocket: sock = websocket.transport.get_extra_info("socket") @@ -33,7 +33,7 @@ async def register_loop(args): while True: await websocket.send("heartbeat") - await asyncio.sleep(60) + await asyncio.sleep(40) except Exception as e: logger.error("connetion to config_server has error") diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py index d1ae69b5c..787d70330 100644 --- a/lightllm/server/visualserver/vit_connect.py +++ b/lightllm/server/visualserver/vit_connect.py @@ -57,6 +57,7 @@ def _setup_local_vit_connection(self): logger.info(f"Connected to local VIT instance at {self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") def _setup_remote_vit_connections(self): + print("_setup_remote_vit_connections", "fdakpgdakgjadpgkjadk") asyncio.create_task(self.vit_handle_loop()) # wait for remote vit instances @@ -89,6 +90,7 @@ async def send_to_vit(self, data, protocol=pickle.HIGHEST_PROTOCOL): raise Exception(f"Failed to send to VIT instance {instance.host_ip_port}: {e}") async def vit_handle_loop(self): + print("vit_handle_loop", "fdakpgdakgjadpgkjadk") while True: try: id_to_vit_obj = await self._get_vit_objs() @@ -118,8 +120,8 @@ async def _get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: get_vit_objs 主要负责从 config_server 获取所有的vit远程服务。 """ # 使用 config_server 服务来发现所有的 pd_master 节点。 - uri = f"ws://{self.args.config_server_host}:{self.args.config_server_port}/registered_vit" - + uri = f"ws://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" + print("uri", uri) try: async with httpx.AsyncClient() as client: response = await client.get(uri) From 00b3b533a33ff582ba67d8d6696bfc56fd51bc32 Mon Sep 17 00:00:00 2001 From: baishihao Date: Thu, 28 Aug 2025 20:20:50 +0800 Subject: [PATCH 18/40] fix llm remote vit init --- lightllm/server/api_http.py | 4 +- lightllm/server/core/objs/req.py | 3 - lightllm/server/httpserver/manager.py | 2 +- lightllm/server/visualserver/register_loop.py | 2 +- lightllm/server/visualserver/vit_connect.py | 145 +++++++++++++----- 5 files changed, 111 insertions(+), 45 deletions(-) diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index b9d6119c0..a3604e520 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -350,10 +350,10 @@ async def shutdown(): @app.on_event("startup") async def startup_event(): logger.info("server start up") - if g_objs.httpserver_manager is None: - return loop = asyncio.get_event_loop() g_objs.set_args(get_env_start_args()) + if g_objs.httpserver_manager is None: + return loop.create_task(g_objs.httpserver_manager.handle_loop()) logger.info(f"server start up ok, loop use is {asyncio.get_event_loop()}") return diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py index 
38397d4e5..f96bad4ff 100644 --- a/lightllm/server/core/objs/req.py +++ b/lightllm/server/core/objs/req.py @@ -214,9 +214,6 @@ def can_release(self): # 只有管理节点有一个引用 ref_count_ok = self.ref_count == 1 can_released_mark = self.can_released_mark - print(f"self.is_aborted is {self.is_aborted}") - print(f"self.finish_status.is_finished() is {self.finish_status.is_finished()}") - print(f"self.ref_count is {self.ref_count}") if self.is_aborted and can_released_mark and ref_count_ok: return True diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index e4989a5a3..1f0a8f09a 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -512,7 +512,7 @@ async def transfer_to_next_module( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) - else: + if not self.enable_multimodal or self.args.enable_remote_vit: self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, diff --git a/lightllm/server/visualserver/register_loop.py b/lightllm/server/visualserver/register_loop.py index 95a1dc069..31d0f7b8a 100644 --- a/lightllm/server/visualserver/register_loop.py +++ b/lightllm/server/visualserver/register_loop.py @@ -26,7 +26,7 @@ async def register_loop(args): sock = websocket.transport.get_extra_info("socket") sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - vit_obj = VIT_Obj(node_id=args.visual_node_id, host_ip_port=f"{host_ip}:{args.port}") + vit_obj = VIT_Obj(node_id=args.visual_node_id, host_ip_port=f"{host_ip}:{args.remote_vit_port}") await websocket.send(pickle.dumps(vit_obj)) logger.info(f"Sent registration vit_obj: {vit_obj}") diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py index 787d70330..376b85532 100644 --- a/lightllm/server/visualserver/vit_connect.py +++ b/lightllm/server/visualserver/vit_connect.py @@ -30,7 +30,7 @@ def __init__(self, args, context, local_visual_port: int): self.local_visual_port = local_visual_port self.send_to_visual = None - self.remote_vit_instances = [] + self.remote_vit_instances = {} self.current_vit_index = 0 self.remote_vit = args.enable_remote_vit self.remote_vit_port = args.remote_vit_port @@ -42,8 +42,7 @@ def _setup_vit_connections(self): 设置VIT连接,支持本地和远程VIT实例 支持多种连接模式: 1. 本地VIT实例 (默认) - 2. 远程单个VIT实例 - 3. 远程多个VIT实例 (负载均衡) + 2. 远程多个VIT实例 (负载均衡) """ if self.remote_vit: # 远程VIT实例模式 @@ -57,14 +56,86 @@ def _setup_local_vit_connection(self): logger.info(f"Connected to local VIT instance at {self.args.zmq_mode}127.0.0.1:{self.local_visual_port}") def _setup_remote_vit_connections(self): - print("_setup_remote_vit_connections", "fdakpgdakgjadpgkjadk") - asyncio.create_task(self.vit_handle_loop()) + """ + 初始化远程VIT连接,同步获取初始实例 + """ + logger.info("Setting up remote VIT connections...") - # wait for remote vit instances - while True: - if len(self.remote_vit_instances) > 0: - break + self._sync_init_vit_instances() + + retry_count = 0 + max_retries = 30 # 最多等待30秒 + while len(self.remote_vit_instances) == 0 and retry_count < max_retries: + logger.info(f"Waiting for VIT instances... 
(attempt {retry_count + 1}/{max_retries})") time.sleep(1) + retry_count += 1 + self._sync_init_vit_instances() + + if len(self.remote_vit_instances) == 0: + logger.warning("No VIT instances available after initialization") + else: + logger.info(f"Successfully connected to {len(self.remote_vit_instances)} VIT instances") + + def _sync_init_vit_instances(self): + """ + 同步初始化VIT实例连接 + """ + try: + # 使用同步方式获取VIT实例 + vit_objs = self._sync_get_vit_objs() + if vit_objs: + self._update_vit_connections(vit_objs) + except Exception as e: + logger.error(f"Failed to initialize VIT instances: {e}") + + def _sync_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: + """ + 同步获取VIT实例信息 + """ + import requests + + uri = f"http://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" + try: + response = requests.get(uri, timeout=10) + if response.status_code == 200: + base64data = response.json()["data"] + id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) + return id_to_vit_obj + else: + logger.error(f"Failed to get VIT instances: {response.status_code}") + return None + except Exception as e: + logger.error(f"Error getting VIT instances: {e}") + return None + + def _update_vit_connections(self, id_to_vit_obj: Dict[int, VIT_Obj]): + """ + 更新VIT连接,添加新的连接,关闭失效的连接 + """ + # 关闭不再存在的连接 + closed_ids = [] + for id, remote_instance in self.remote_vit_instances.items(): + if id not in id_to_vit_obj: + try: + remote_instance.close() + except: + pass + closed_ids.append(id) + logger.info(f"Closed VIT connection {id}") + + for id in closed_ids: + self.remote_vit_instances.pop(id) + + # 建立新的连接 + for id, vit_obj in id_to_vit_obj.items(): + if id not in self.remote_vit_instances: + try: + socket = self.context.socket(zmq.PUSH) + socket.connect(f"tcp://{vit_obj.host_ip_port}:{self.args.remote_vit_port}") + self.remote_vit_instances[id] = socket + logger.info(f"Connected to VIT instance {id} at {vit_obj.host_ip_port}") + except Exception as e: + logger.error(f"Failed to connect to VIT instance {id}: {e}") def _get_vit_instance(self): """ @@ -73,10 +144,13 @@ def _get_vit_instance(self): if not self.remote_vit: return self.send_to_visual + if len(self.remote_vit_instances) == 0: + raise Exception("No available VIT instances") + # 简单的轮询负载均衡 index = (self.current_vit_index + 1) % len(self.remote_vit_instances) self.current_vit_index = index - return self.remote_vit_instances[index] + return list(self.remote_vit_instances.values())[index] async def send_to_vit(self, data, protocol=pickle.HIGHEST_PROTOCOL): """ @@ -86,42 +160,32 @@ async def send_to_vit(self, data, protocol=pickle.HIGHEST_PROTOCOL): try: instance.send_pyobj(data, protocol=protocol) except Exception as e: - logger.error(f"Failed to send to VIT instance {instance.host_ip_port}: {e}") - raise Exception(f"Failed to send to VIT instance {instance.host_ip_port}: {e}") + logger.error(f"Failed to send to VIT instance: {e}") + raise Exception(f"Failed to send to VIT instance: {e}") + + await self._wait_visual_embed_ready() async def vit_handle_loop(self): - print("vit_handle_loop", "fdakpgdakgjadpgkjadk") + """ + 异步VIT连接管理循环,由外部启动 + """ + logger.info("Starting VIT connection management loop") while True: try: - id_to_vit_obj = await self._get_vit_objs() - logger.info(f"get vit_objs {id_to_vit_obj}") - for id, remote_instance in self.remote_vit_instances.items(): - if id not in id_to_vit_obj: - try: - remote_instance[id].close() - except: - pass - self.remote_vit_instances.pop(id) - logger.info(f"remote vit {id} closed") - - for 
id, vit_obj in id_to_vit_obj.items(): - if id not in self.remote_vit_instances: - self.remote_vit_instances[id] = self.context.socket(zmq.PUSH) - self.remote_vit_instances[id].connect( - f"tcp://{vit_obj.host_ip_port}:{self.args.remote_vit_port}" - ) + id_to_vit_obj = await self._async_get_vit_objs() + if id_to_vit_obj: + logger.debug(f"Retrieved {len(id_to_vit_obj)} VIT instances") + self._update_vit_connections(id_to_vit_obj) await asyncio.sleep(30) except Exception as e: - logger.exception(str(e)) + logger.exception(f"Error in VIT handle loop: {e}") await asyncio.sleep(10) - async def _get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: + async def _async_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: """ - get_vit_objs 主要负责从 config_server 获取所有的vit远程服务。 + 异步获取VIT实例信息 """ - # 使用 config_server 服务来发现所有的 pd_master 节点。 uri = f"ws://{self.args.config_server_host}:{self.args.config_server_port}/registered_visual_objects" - print("uri", uri) try: async with httpx.AsyncClient() as client: response = await client.get(uri) @@ -130,9 +194,14 @@ async def _get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: id_to_vit_obj = pickle.loads(base64.b64decode(base64data)) return id_to_vit_obj else: - logger.error(f"get pd_master_objs error {response.status_code}") + logger.error(f"Failed to get VIT instances: {response.status_code}") return None except Exception as e: - logger.exception(str(e)) - await asyncio.sleep(10) + logger.exception(f"Error getting VIT instances: {e}") return None + + async def _wait_visual_embed_ready(self): + """ + 等待VIT实例的embed准备好 + """ + await asyncio.sleep(10) From 62f80c40cc0b0b2f6b8bb86992e64d65f7a07c93 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 28 Aug 2025 12:52:09 +0000 Subject: [PATCH 19/40] [0828]temp --- lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py | 2 +- lightllm/server/embed_cache/utils.py | 6 +++--- lightllm/server/visualserver/model_infer/model_rpc.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 916c75637..ed22d96ad 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -54,7 +54,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei continue # pull the img_embeds by uid from shm or afs if self.args.enable_remote_vit: - embed = read_afs(get_shm_name_embed(img["uuid"])) + embed = read_afs(get_shm_name_embed(img["uuid"], self.args.image_embed_dir)) else: embed = read_shm(get_shm_name_embed(img["uuid"])) img_weight.append(bytes2tensor(embed).cuda().reshape(img["token_num"], -1)) diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 4be50177f..4a048d8b7 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -41,10 +41,10 @@ def create_shm(name, data): print("Warning create shm {} failed because of FileExistsError!".format(name)) -def create_afs(name, data): +def create_afs(name, data, path): try: data_size = len(data) - path = os.path.join("/mtc/sangchengmeng/afs", name) + path = os.path.join(path, name) with open(path, "xb") as f: mem_view = memoryview(data) f.write(mem_view[:data_size]) @@ -60,7 +60,7 @@ def read_shm(name): return data -def read_afs(name: str, base_dir: str = "/mtc/sangchengmeng/afs") -> bytes: +def read_afs(name: str, base_dir) -> bytes: path = Path(base_dir) / name return path.read_bytes() diff 
--git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 5b5bb20ea..d78454ff2 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -122,7 +122,7 @@ def exposed_encode(self, images: List[ImageItem]): start, end = valid_ids[i] cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) if self.args.enable_remote_vit: - create_afs(get_shm_name_embed(uid), cur_embed_bytes) + create_afs(get_shm_name_embed(uid), cur_embed_bytes, self.args.image_embed_dir) else: create_shm(get_shm_name_embed(uid), cur_embed_bytes) ids_to_set.append(uid) From b673a3650a7bf3dd6622a812dbbced6473536da3 Mon Sep 17 00:00:00 2001 From: baishihao Date: Thu, 28 Aug 2025 20:52:55 +0800 Subject: [PATCH 20/40] fix vit transfer --- lightllm/server/httpserver/manager.py | 5 +++-- lightllm/server/multimodal_params.py | 16 ++++++++++------ lightllm/server/visualserver/manager.py | 1 + lightllm/server/visualserver/vit_connect.py | 9 +++++++-- 4 files changed, 21 insertions(+), 10 deletions(-) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 1f0a8f09a..5170ee9c1 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -134,8 +134,8 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas): uid_list.append(rec["id"]) # # If enable the vit/audio-llm disaggregation, no need to cache the data in the memory of the server - # if self.enable_remote_vit: - # return + if self.enable_remote_vit: + return ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) update_data_ids = [] @@ -512,6 +512,7 @@ async def transfer_to_next_module( group_req_objs.to_group_req_index(), protocol=pickle.HIGHEST_PROTOCOL, ) + if not self.enable_multimodal or self.args.enable_remote_vit: self.send_to_router.send_pyobj( group_req_objs.to_group_req_index(), diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 282673eaf..5f38894fe 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -53,10 +53,7 @@ async def preload(self, request: Request): def read(self): assert self._preload_data is not None - ans = self._preload_data - # self._preload_data = None - # self._data = None - return ans + return self._preload_data def to_dict(self): ret = {} @@ -112,10 +109,11 @@ async def preload(self, request: Request): def read(self): assert self._preload_data is not None - ans = self._preload_data + return self._preload_data + + def free(self): self._preload_data = None self._data = None - return ans def to_dict(self): ret = {} @@ -144,6 +142,12 @@ def __init__( self.audios = [AudioItem(**a) for a in audios] return + def free(self): + for image in self.images: + image.free() + for audio in self.audios: + audio.free() + async def verify_and_preload(self, request: Request): for image in self.images: await image.preload(request) diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 240469517..e786ee634 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -173,6 +173,7 @@ async def loop_for_netio_req(self): for _ in range(self.visual_recv_max_count): recv_req: GroupReqIndexes = self.vit_receiver.recv_pyobj(zmq.NOBLOCK) if isinstance(recv_req, GroupReqIndexes): + print(recv_req, flush=True) self.waiting_reqs.append(recv_req) else: assert False, f"Error 
Req Inf {recv_req}" diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py index 376b85532..3d3d6a76c 100644 --- a/lightllm/server/visualserver/vit_connect.py +++ b/lightllm/server/visualserver/vit_connect.py @@ -131,7 +131,9 @@ def _update_vit_connections(self, id_to_vit_obj: Dict[int, VIT_Obj]): if id not in self.remote_vit_instances: try: socket = self.context.socket(zmq.PUSH) - socket.connect(f"tcp://{vit_obj.host_ip_port}:{self.args.remote_vit_port}") + print(vit_obj.host_ip_port, self.args.remote_vit_port, flush=True) + ip, port = vit_obj.host_ip_port.split(":") + socket.connect(f"tcp://{ip}:{port}") self.remote_vit_instances[id] = socket logger.info(f"Connected to VIT instance {id} at {vit_obj.host_ip_port}") except Exception as e: @@ -158,11 +160,14 @@ async def send_to_vit(self, data, protocol=pickle.HIGHEST_PROTOCOL): """ instance = self._get_vit_instance() try: + print(instance, flush=True) instance.send_pyobj(data, protocol=protocol) except Exception as e: logger.error(f"Failed to send to VIT instance: {e}") raise Exception(f"Failed to send to VIT instance: {e}") - + finally: + # 释放图片资源 + data.multimodal_params.free() await self._wait_visual_embed_ready() async def vit_handle_loop(self): From 67a3c3824a6c7c67257f1c26c07e4f1091d7cc95 Mon Sep 17 00:00:00 2001 From: baishihao Date: Thu, 28 Aug 2025 21:47:09 +0800 Subject: [PATCH 21/40] fix connection bug --- .../qwen_vl/layer_infer/pre_layer_infer.py | 2 +- lightllm/server/api_start.py | 3 ++ .../impl/memory_cache_with_redis.py | 8 +++-- lightllm/server/embed_cache/utils.py | 2 +- lightllm/server/httpserver/manager.py | 31 +------------------ lightllm/server/visualserver/- | 17 ++++++++++ .../visualserver/model_infer/model_rpc.py | 2 +- 7 files changed, 29 insertions(+), 36 deletions(-) create mode 100644 lightllm/server/visualserver/- diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index ed22d96ad..5808c02bb 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -54,7 +54,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei continue # pull the img_embeds by uid from shm or afs if self.args.enable_remote_vit: - embed = read_afs(get_shm_name_embed(img["uuid"], self.args.image_embed_dir)) + embed = read_afs(get_shm_name_embed(img["uuid"]), self.args.image_embed_dir) else: embed = read_shm(get_shm_name_embed(img["uuid"])) img_weight.append(bytes2tensor(embed).cuda().reshape(img["token_num"], -1)) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 65d6d2392..6b3212daa 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -495,6 +495,9 @@ def config_server_start(args): logger.info(f"all start args:{args}") + if args.start_redis: + start_redis_service(args) + set_env_start_args(args) command = [ diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index b5dc7a25e..e73450e3e 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -17,7 +17,8 @@ class MemoryCacheWithRedis(InMemoryCache): def __init__(self, args) -> None: super().__init__(args) - redis_url = f"redis://{args.config_server_host}:{args.redis_port}" + redis_url = f"redis://{args.config_server_host}:{args.redis_port}/0" + 
print(redis_url, flush=True) self.redis_cache = EmbedRefCountRedis( redis_url=redis_url, capacity=args.cache_capacity, @@ -28,6 +29,7 @@ def __init__(self, args) -> None: # 便于 dynamic prompt cache 的使用。所以要把cache_capacity * 2,保障其保留的图片cache > redis 服务维护的 # 硬盘里的图片image embed 数量。 self.cache_capacity = args.cache_capacity * 2 + print(self.redis_cache.stats(), flush=True) def release(self, ids: list[int]) -> None: with self.lock: @@ -44,7 +46,7 @@ def get_items_data(self, ids: list[int]) -> list[Optional[bool]]: def set_items_embed(self, ids: list[int]) -> None: for id in ids: - self.redis_cache.insert(id) + self.redis_cache.insert(str(id)) def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: - return [self.redis_cache.query_and_incre(id) for id in ids] + return [self.redis_cache.query_and_incre(str(id)) for id in ids] diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 4a048d8b7..a9df5fc60 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -113,7 +113,7 @@ def __init__( evict_fraction: float = 0.2, key_prefix: str = "md5:", image_embed_dir: str = None, - path_ext: str = ".embed", + path_ext: str = "-embed", **redis_kwargs, ) -> None: """ diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 5170ee9c1..89e300266 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -134,7 +134,7 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas): uid_list.append(rec["id"]) # # If enable the vit/audio-llm disaggregation, no need to cache the data in the memory of the server - if self.enable_remote_vit: + if self.args.enable_remote_vit: return ready_flags = obtain(self.cache_client.root.get_items_data(uid_list)) @@ -263,35 +263,6 @@ def alloc_req_id(self, sampling_params, is_health_req: bool = False): assert False, "dead code path" return group_request_id - async def _log_req_header_for_visual_only(self, request_headers, group_request_id: int, image_count: int): - - x_request_id = request_headers.get("X-Request-Id", "") - x_session_id = request_headers.get("X-Session-Id", "") - - format_in_time = datetime.datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d %H:%M:%S") - logger.info( - f"recieved req X-Request-Id:{x_request_id} " - f"X-Session-Id:{x_session_id} start_time:{format_in_time} " - f"lightllm_req_id:{group_request_id} " - f"image_count:{image_count}" - ) - return - - async def _initialize_multimodal_metadata( - self, multimodal_params: MultimodalParams, sampling_params: SamplingParams - ): - for img in multimodal_params.images: - self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) - data = img.read() - # must after init_imageitem_extral_params - token_num = self.tokenizer.get_image_token_length(img) - md5sum = "{}_{}".format( - hashlib.md5(data).hexdigest(), - hashlib.md5(pickle.dumps(img.extra_params, protocol=4)).hexdigest(), - ) - img.uuid = int(md5sum, 16) - img.token_num = token_num - async def generate( self, prompt: Union[str, List[int]], diff --git a/lightllm/server/visualserver/- b/lightllm/server/visualserver/- new file mode 100644 index 000000000..68e2f733d --- /dev/null +++ b/lightllm/server/visualserver/- @@ -0,0 +1,17 @@ +529205:C 28 Aug 2025 13:07:04.500 # oO0OoO0OoO0Oo Redis is starting oO0OoO0OoO0Oo +529205:C 28 Aug 2025 13:07:04.501 # Redis version=6.0.16, bits=64, commit=00000000, modified=0, pid=529205, just started +529205:C 28 Aug 2025 
13:07:04.503 # Configuration loaded +529205:M 28 Aug 2025 13:07:04.505 * Running mode=standalone, port=6379. +529205:M 28 Aug 2025 13:07:04.506 # Server initialized +529205:M 28 Aug 2025 13:07:04.507 # WARNING overcommit_memory is set to 0! Background save may fail under low memory condition. To fix this issue add 'vm.overcommit_memory = 1' to /etc/sysctl.conf and then reboot or run the command 'sysctl vm.overcommit_memory=1' for this to take effect. +529205:M 28 Aug 2025 13:07:04.509 * Ready to accept connections +529205:signal-handler (1756386794) Received SIGINT scheduling shutdown... +529205:M 28 Aug 2025 13:13:14.912 # User requested shutdown... +529205:M 28 Aug 2025 13:13:14.914 # Redis is now ready to exit, bye bye... +533706:C 28 Aug 2025 13:13:21.718 # oO0OoO0OoO0Oo Redis is starting oO0OoO0OoO0Oo +533706:C 28 Aug 2025 13:13:21.719 # Redis version=6.0.16, bits=64, commit=00000000, modified=0, pid=533706, just started +533706:C 28 Aug 2025 13:13:21.720 # Configuration loaded +533706:M 28 Aug 2025 13:13:21.723 * Running mode=standalone, port=6379. +533706:M 28 Aug 2025 13:13:21.724 # Server initialized +533706:M 28 Aug 2025 13:13:21.724 # WARNING overcommit_memory is set to 0! Background save may fail under low memory condition. To fix this issue add 'vm.overcommit_memory = 1' to /etc/sysctl.conf and then reboot or run the command 'sysctl vm.overcommit_memory=1' for this to take effect. +533706:M 28 Aug 2025 13:13:21.727 * Ready to accept connections diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index d78454ff2..be0877a46 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -121,7 +121,7 @@ def exposed_encode(self, images: List[ImageItem]): uid = uuids[i] start, end = valid_ids[i] cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) - if self.args.enable_remote_vit: + if self.args.run_mode == "visual": create_afs(get_shm_name_embed(uid), cur_embed_bytes, self.args.image_embed_dir) else: create_shm(get_shm_name_embed(uid), cur_embed_bytes) From 676215ed515cb8b7a3fbd265f1738cf7b57a2a33 Mon Sep 17 00:00:00 2001 From: baishihao Date: Thu, 28 Aug 2025 22:15:18 +0800 Subject: [PATCH 22/40] add wait for embed for llm --- .../impl/memory_cache_with_redis.py | 20 +++++----- lightllm/server/httpserver/manager.py | 2 +- lightllm/server/multimodal_params.py | 3 ++ lightllm/server/visualserver/- | 10 +++++ lightllm/server/visualserver/vit_connect.py | 38 ++++++++++++++----- 5 files changed, 52 insertions(+), 21 deletions(-) diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index e73450e3e..6f9ad6c29 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -18,7 +18,6 @@ class MemoryCacheWithRedis(InMemoryCache): def __init__(self, args) -> None: super().__init__(args) redis_url = f"redis://{args.config_server_host}:{args.redis_port}/0" - print(redis_url, flush=True) self.redis_cache = EmbedRefCountRedis( redis_url=redis_url, capacity=args.cache_capacity, @@ -29,24 +28,25 @@ def __init__(self, args) -> None: # 便于 dynamic prompt cache 的使用。所以要把cache_capacity * 2,保障其保留的图片cache > redis 服务维护的 # 硬盘里的图片image embed 数量。 self.cache_capacity = args.cache_capacity * 2 - print(self.redis_cache.stats(), flush=True) def release(self, ids: list[int]) -> None: with self.lock: for id_ in 
ids: self._records[id_].ref -= 1 self.redis_cache.decr(id_) - - def set_items_data(self, ids: list[int]) -> None: - for id_ in ids: - self._records[id_].data = True - - def get_items_data(self, ids: list[int]) -> list[Optional[bool]]: - return [self._records.get(id_).data if id_ in self._records else False for id_ in ids] + print(self.redis_cache.stats(), flush=True) def set_items_embed(self, ids: list[int]) -> None: for id in ids: self.redis_cache.insert(str(id)) def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: - return [self.redis_cache.query_and_incre(str(id)) for id in ids] + ret = [] + for id in ids: + # 避免重复的引用计数增加 + if self._records[id].embed: + ret.append(True) + continue + self._records[id].embed = self.redis_cache.query_and_incre(str(id)) + ret.append(self._records[id].embed) + return ret diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 89e300266..468b43028 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -87,7 +87,7 @@ def __init__( # 初始化VIT连接管理器 from lightllm.server.visualserver.vit_connect import VITConnectionManager - self.vit_manager = VITConnectionManager(args, context, visual_port) + self.vit_manager = VITConnectionManager(args, context, visual_port, self.cache_client) self.shm_req_manager = ShmReqManager() diff --git a/lightllm/server/multimodal_params.py b/lightllm/server/multimodal_params.py index 5f38894fe..43f9b611f 100644 --- a/lightllm/server/multimodal_params.py +++ b/lightllm/server/multimodal_params.py @@ -148,6 +148,9 @@ def free(self): for audio in self.audios: audio.free() + def get_all_uuids(self): + return [image.uuid for image in self.images] + [audio.uuid for audio in self.audios] + async def verify_and_preload(self, request: Request): for image in self.images: await image.preload(request) diff --git a/lightllm/server/visualserver/- b/lightllm/server/visualserver/- index 68e2f733d..170eff3a7 100644 --- a/lightllm/server/visualserver/- +++ b/lightllm/server/visualserver/- @@ -15,3 +15,13 @@ 533706:M 28 Aug 2025 13:13:21.724 # Server initialized 533706:M 28 Aug 2025 13:13:21.724 # WARNING overcommit_memory is set to 0! Background save may fail under low memory condition. To fix this issue add 'vm.overcommit_memory = 1' to /etc/sysctl.conf and then reboot or run the command 'sysctl vm.overcommit_memory=1' for this to take effect. 533706:M 28 Aug 2025 13:13:21.727 * Ready to accept connections +533706:signal-handler (1756390331) Received SIGINT scheduling shutdown... +533706:M 28 Aug 2025 14:12:11.921 # User requested shutdown... +533706:M 28 Aug 2025 14:12:11.922 # Redis is now ready to exit, bye bye... +546119:C 28 Aug 2025 14:12:19.084 # oO0OoO0OoO0Oo Redis is starting oO0OoO0OoO0Oo +546119:C 28 Aug 2025 14:12:19.086 # Redis version=6.0.16, bits=64, commit=00000000, modified=0, pid=546119, just started +546119:C 28 Aug 2025 14:12:19.087 # Configuration loaded +546119:M 28 Aug 2025 14:12:19.089 * Running mode=standalone, port=6379. +546119:M 28 Aug 2025 14:12:19.090 # Server initialized +546119:M 28 Aug 2025 14:12:19.091 # WARNING overcommit_memory is set to 0! Background save may fail under low memory condition. To fix this issue add 'vm.overcommit_memory = 1' to /etc/sysctl.conf and then reboot or run the command 'sysctl vm.overcommit_memory=1' for this to take effect. 
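# Rough sketch (not part of the patch) of the embed ref-count flow that the
# memory_cache_with_redis hunks above are building towards; `cache` stands for
# an EmbedRefCountRedis instance, `uid`, `embed_bytes` and `embed_dir` are
# placeholders, and embed_dir corresponds to args.image_embed_dir.
from lightllm.server.embed_cache.utils import (
    EmbedRefCountRedis,
    create_afs,
    read_afs,
    get_shm_name_embed,
)

def vit_publish(cache: EmbedRefCountRedis, uid: int, embed_bytes: bytes, embed_dir: str) -> None:
    # VIT side: write the embedding to the shared dir once, then make it discoverable.
    create_afs(get_shm_name_embed(uid), embed_bytes, embed_dir)
    cache.insert(str(uid))

def llm_consume(cache: EmbedRefCountRedis, uid: int, embed_dir: str) -> bytes:
    # LLM side: take a reference before reading; callers poll until this succeeds.
    if not cache.query_and_incre(str(uid)):
        raise KeyError(f"embedding for {uid} is not ready yet")
    return read_afs(get_shm_name_embed(uid), embed_dir)

def llm_release(cache: EmbedRefCountRedis, uid: int) -> None:
    # LLM side: drop the reference once the request is recycled.
    cache.decr(str(uid))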
+546119:M 28 Aug 2025 14:12:19.093 * Ready to accept connections diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py index 3d3d6a76c..ecaf13ee8 100644 --- a/lightllm/server/visualserver/vit_connect.py +++ b/lightllm/server/visualserver/vit_connect.py @@ -8,6 +8,7 @@ import httpx import base64 from dataclasses import dataclass +import rpyc logger = init_logger(__name__) @@ -24,7 +25,7 @@ def to_log_str(self): class VITConnectionManager: """VIT连接管理器""" - def __init__(self, args, context, local_visual_port: int): + def __init__(self, args, context, local_visual_port: int, cache_client: rpyc.Connection): self.args = args self.context = context self.local_visual_port = local_visual_port @@ -34,6 +35,7 @@ def __init__(self, args, context, local_visual_port: int): self.current_vit_index = 0 self.remote_vit = args.enable_remote_vit self.remote_vit_port = args.remote_vit_port + self.cache_client = cache_client self._setup_vit_connections() @@ -159,16 +161,21 @@ async def send_to_vit(self, data, protocol=pickle.HIGHEST_PROTOCOL): 发送数据到VIT实例,支持本地和远程模式 """ instance = self._get_vit_instance() + # 本地模式下,提前释放图片资源,降低传输开销 + if not self.remote_vit: + data.multimodal_params.free() + try: print(instance, flush=True) instance.send_pyobj(data, protocol=protocol) except Exception as e: logger.error(f"Failed to send to VIT instance: {e}") raise Exception(f"Failed to send to VIT instance: {e}") - finally: - # 释放图片资源 + + # 远程模式下,发送完以后,在释放图片资源 + await self._wait_visual_embed_ready(data) + if self.remote_vit: data.multimodal_params.free() - await self._wait_visual_embed_ready() async def vit_handle_loop(self): """ @@ -179,7 +186,6 @@ async def vit_handle_loop(self): try: id_to_vit_obj = await self._async_get_vit_objs() if id_to_vit_obj: - logger.debug(f"Retrieved {len(id_to_vit_obj)} VIT instances") self._update_vit_connections(id_to_vit_obj) await asyncio.sleep(30) except Exception as e: @@ -205,8 +211,20 @@ async def _async_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: logger.exception(f"Error getting VIT instances: {e}") return None - async def _wait_visual_embed_ready(self): - """ - 等待VIT实例的embed准备好 - """ - await asyncio.sleep(10) + async def _wait_visual_embed_ready(self, data, timeout_seconds: int = 20): + # 本地模式不需要等待 + if not self.remote_vit: + return + + uuids = data.multimodal_params.get_all_uuids() + + async def wait_for_embeds(): + while not all(self.cache_client.root.get_items_embed(uuids)): + await asyncio.sleep(0.05) + + try: + await asyncio.wait_for(wait_for_embeds(), timeout=timeout_seconds) + except asyncio.TimeoutError: + logger.error( + f"Req {data.group_req_id}: timeout waiting for visual embed ready after {timeout_seconds} seconds" + ) From 33923b99d764e594a2bf8b9a73fb5d575ed5d4b7 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 29 Aug 2025 07:24:02 +0000 Subject: [PATCH 23/40] [0828]fix vit embed --- .../embed_cache/impl/naive_memory_cache.py | 2 +- lightllm/server/embed_cache/utils.py | 5 ++++ lightllm/server/httpserver/manager.py | 20 +++++++++---- lightllm/server/visualserver/manager.py | 28 +++++++++++-------- 4 files changed, 37 insertions(+), 18 deletions(-) diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index 7f0f5cf90..a9a1bafe4 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -103,7 +103,7 @@ def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> 
Optional[l rec.visittime = now rec.ref += 1 else: - uid_int = int(md5sum, 16) + uid_int = md5sum self._check_and_set_new_id_range(token_num) rec = Record( id=uid_int, diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index a9df5fc60..817bd8051 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -183,6 +183,11 @@ def insert(self, md5: str) -> Tuple[bool, List[str]]: self._release_lock() raise e + def query(self, md5: str) -> bool: + """Quert if md5 exists.""" + self._wait_if_eviction() + return bool(self.r.exists(self.ref_prefix + md5)) + def query_and_incre(self, md5: str) -> bool: """Query if md5 exists and increment ref_count if found.""" self._wait_if_eviction() diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 468b43028..f9d92767c 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -117,10 +117,16 @@ def __init__( self.latest_success_infer_time_mark.set_value(int(time.time())) return - async def _alloc_resource(self, items, md5sums, token_nums, datas): + async def _alloc_resource(self, items, uuids, token_nums, datas): while True: - records = obtain(self.cache_client.root.alloc(md5sums, token_nums)) + # 检查这个图片在redis总是否已经存在 + # embed_exists = obtain(self.cache_client.root.get_items_embed(uuids)) + # for exist in embed_exists: + # if exist: + # continue + # else: + records = obtain(self.cache_client.root.alloc(uuids, token_nums)) if records is None: await asyncio.sleep(0.1) @@ -156,14 +162,15 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, # 如果不加任何锁,假如请求1和请求2都有6张图片,而cache_capacity为10, # 那么如果某一时刻shm中存在请求1的5张图和请求2的5张图,将会资源竞争产生死锁。 async with self._resource_lock: - items, md5sums, tokens_nums, datas = [], [], [], [] + items, uuids, tokens_nums, datas = [], [], [], [] for img in multimodal_params.images: self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params) data = img.read() # must after init_imageitem_extral_params token_num = self.tokenizer.get_image_token_length(img) md5sum = "{}_{}".format(hashlib.md5(data).hexdigest(), img.patch_num) - md5sums.append(md5sum) + uuid = int(md5sum, 16) + uuids.append(uuid) tokens_nums.append(token_num) datas.append(data) items.append(img) @@ -175,12 +182,13 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, hashlib.md5(data).hexdigest(), hashlib.md5(pickle.dumps(audio.extra_params, protocol=4)).hexdigest(), ) - md5sums.append(md5sum) + uuid = int(md5sum, 16) + uuids.append(uuid) tokens_nums.append(token_num) datas.append(data) items.append(audio) - await self._alloc_resource(items, md5sums, tokens_nums, datas) + await self._alloc_resource(items, uuids, tokens_nums, datas) return async def _release_multimodal_resources(self, multimodal_params: MultimodalParams): diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index e786ee634..558eb2337 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -153,16 +153,22 @@ async def loop_for_fwd(self): processing_group_reqs = [] images_need_infer = [] - # def _recv_reqs(self): - # if self.remote_vit: - # recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) - # recv_req.multimodal_params.images[:]= [ - # img for img in recv_req.multimodal_params.images - # if not self.cache_client.root.get_item_embed(img.uuid) # embed已存在的被丢弃 , ref +1 - 
# ] - # return recv_req - # else: - # return self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK) + def _recv_reqs(self): + if self.remote_vit: + recv_req: GroupReqIndexes = self.vit_receiver.recv_pyobj(zmq.NOBLOCK) + # recv_req.multimodal_params.images[:]= [ + # img for img in recv_req.multimodal_params.images + # if not self.cache_client.root.get_item_embed(img.uuid) # embed已存在的被丢弃 , ref +1 + # ] + uuids = [] + token_nums = [] + for img in recv_req.multimodal_params.images: + uuids.append(img.uuid) + token_nums.append(img.token_num) + self.cache_client.root.alloc(uuids, token_nums) + return recv_req + else: + return self.vit_receiver.recv_pyobj(zmq.NOBLOCK) async def loop_for_netio_req(self): if not hasattr(self, "visual_recv_max_count"): @@ -171,7 +177,7 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self.vit_receiver.recv_pyobj(zmq.NOBLOCK) + recv_req: GroupReqIndexes = self._recv_reqs() if isinstance(recv_req, GroupReqIndexes): print(recv_req, flush=True) self.waiting_reqs.append(recv_req) From fcac8e5de4ed4c97bd5fa49fcf280d949f0c262b Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 29 Aug 2025 08:51:49 +0000 Subject: [PATCH 24/40] [0828]temp --- .../embed_cache/impl/memory_cache_with_redis.py | 13 +++++++++---- lightllm/server/visualserver/manager.py | 5 +++-- .../server/visualserver/model_infer/model_rpc.py | 1 + lightllm/server/visualserver/vit_connect.py | 2 +- 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index 6f9ad6c29..258e4b84d 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -29,16 +29,21 @@ def __init__(self, args) -> None: # 硬盘里的图片image embed 数量。 self.cache_capacity = args.cache_capacity * 2 + # llm 负责release def release(self, ids: list[int]) -> None: with self.lock: for id_ in ids: self._records[id_].ref -= 1 - self.redis_cache.decr(id_) - print(self.redis_cache.stats(), flush=True) + if self.redis_cache.query(str(id_)): + self.redis_cache.decr(str(id_)) + print(self.redis_cache.stats(), flush=True) + # vit 负责set def set_items_embed(self, ids: list[int]) -> None: - for id in ids: - self.redis_cache.insert(str(id)) + with self.lock: + for id in ids: + self.redis_cache.insert(str(id)) + self._records[id].embed = True def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: ret = [] diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 558eb2337..b2929266d 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -165,7 +165,8 @@ def _recv_reqs(self): for img in recv_req.multimodal_params.images: uuids.append(img.uuid) token_nums.append(img.token_num) - self.cache_client.root.alloc(uuids, token_nums) + record = self.cache_client.root.alloc(uuids, token_nums) + print(f"record is {record}") return recv_req else: return self.vit_receiver.recv_pyobj(zmq.NOBLOCK) @@ -179,7 +180,7 @@ async def loop_for_netio_req(self): for _ in range(self.visual_recv_max_count): recv_req: GroupReqIndexes = self._recv_reqs() if isinstance(recv_req, GroupReqIndexes): - print(recv_req, flush=True) + # print(recv_req, flush=True) self.waiting_reqs.append(recv_req) else: assert False, f"Error Req Inf {recv_req}" diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py 
b/lightllm/server/visualserver/model_infer/model_rpc.py index be0877a46..c4f0de925 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -114,6 +114,7 @@ def exposed_encode(self, images: List[ImageItem]): if self.tp_rank_id == 0: ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) + print(f"ready_flags is {ready_flags}") ids_to_set = [] for i, ready in enumerate(ready_flags): if ready: diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py index ecaf13ee8..f9ae2cbd4 100644 --- a/lightllm/server/visualserver/vit_connect.py +++ b/lightllm/server/visualserver/vit_connect.py @@ -211,7 +211,7 @@ async def _async_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: logger.exception(f"Error getting VIT instances: {e}") return None - async def _wait_visual_embed_ready(self, data, timeout_seconds: int = 20): + async def _wait_visual_embed_ready(self, data, timeout_seconds: int = 100): # 本地模式不需要等待 if not self.remote_vit: return From 06f7817987cb7c7b35688de5c1be4523696dc977 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 29 Aug 2025 08:57:56 +0000 Subject: [PATCH 25/40] [0828]temp --- lightllm/server/embed_cache/impl/memory_cache_with_redis.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index 258e4b84d..6ffb364a5 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -44,6 +44,7 @@ def set_items_embed(self, ids: list[int]) -> None: for id in ids: self.redis_cache.insert(str(id)) self._records[id].embed = True + self._records[id].ref -= 1 def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: ret = [] From daf131849bb7ccae6096be2b7b87d02577d7e954 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 29 Aug 2025 09:49:24 +0000 Subject: [PATCH 26/40] [0829]add free_afs --- lightllm/server/embed_cache/impl/naive_memory_cache.py | 8 ++++++-- lightllm/server/embed_cache/utils.py | 5 +++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index a9a1bafe4..5b8af402d 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -7,7 +7,7 @@ import time from collections import deque import multiprocessing.shared_memory as shm -from ..utils import get_shm_name_data, get_shm_name_embed, free_shm +from ..utils import get_shm_name_data, get_shm_name_embed, free_shm, free_afs from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) @@ -77,7 +77,11 @@ def _clear(self, free_max_count: int): if record.data: free_shm(get_shm_name_data(id)) if record.embed: - free_shm(get_shm_name_embed(id)) + # 仅vit释放掉afs里的, llm端不做释放 + if self.args.run_mode == "visual": + free_afs(get_shm_name_embed(id), self.args.image_embed_dir) + elif not self.args.enable_remote_vit: + free_shm(get_shm_name_embed(id)) del self._md5_to_record[record.md5sum] del self._records[id] self.occupied -= 1 diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 817bd8051..904d9a38f 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -72,6 +72,11 @@ def free_shm(name): shared_memory.unlink() +def 
free_afs(name: str, base_dir) -> None: + path = Path(base_dir) / name + path.unlink() + + def get_shm_name_data(uid): return str(uid) + "-data" From 1c169039d7ce54a11d98728999ada19ad36560e3 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 3 Sep 2025 07:25:29 +0000 Subject: [PATCH 27/40] [support]add get_image_embedding --- .../server/core/objs/io_objs/group_req.py | 4 +- lightllm/server/httpserver/manager.py | 43 +++++++++++++++++++ lightllm/server/visualserver/vit_connect.py | 1 + 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/lightllm/server/core/objs/io_objs/group_req.py b/lightllm/server/core/objs/io_objs/group_req.py index dfcbdd256..75f2c0e2f 100644 --- a/lightllm/server/core/objs/io_objs/group_req.py +++ b/lightllm/server/core/objs/io_objs/group_req.py @@ -23,7 +23,9 @@ def to_group_req_index(self): return GroupReqIndexes( group_req_id=self.group_req_id, multimodal_params=self.multimodal_params, - shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs], + shm_req_indexes=[req.index_in_shm_mem for req in self.shm_req_objs] + if self.shm_req_objs is not None + else None, time_mark=self.time_mark, ) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index f9d92767c..f28935304 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -374,6 +374,49 @@ async def generate( raise e return + async def get_image_embeding( + self, + sampling_params: SamplingParams, + multimodal_params: MultimodalParams, + request: Request, + is_health_req: bool = False, + ) -> Tuple[int, str, dict, FinishStatus]: + start_time = time.time() + request_headers = request.headers if request is not None else {} + group_request_id = self.alloc_req_id(sampling_params, is_health_req) + + try: + original_multimodal_params = None + if self.is_multinode_tp_master: + original_multimodal_params = copy.deepcopy(multimodal_params) + + if self.pd_mode.is_P_or_NORMAL(): + await multimodal_params.verify_and_preload(request) + + await multimodal_params.verify_and_preload(request) + image_count = len(multimodal_params.images) + # 记录请求到达的相关信息 + + await self._log_req_header(request_headers, group_request_id) + logger.info(f"image_count:{image_count}") + assert ( + len(multimodal_params.images + multimodal_params.audios) <= self.args.cache_capacity + ), "too many multimodal items!" 
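For orientation, a minimal client-side sketch of how this embedding-only path could be exercised; the endpoint name follows the route added in this series, while the image-item fields, port, and sampling parameter shown are assumptions, not a confirmed API:

# Hypothetical client for the embedding-only path; payload field names are assumptions.
import base64
import requests

with open("demo.jpg", "rb") as f:
    img_b64 = base64.b64encode(f.read()).decode()

payload = {
    "parameters": {"max_new_tokens": 1},  # sampling params are still parsed for this route
    "multimodal_params": {"images": [{"type": "base64", "data": img_b64}]},
}
resp = requests.post("http://127.0.0.1:8000/get_image_embed", json=payload)
print(resp.json())  # expected to carry the computed image embeddings
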
+ + await self._alloc_multimodal_resources(multimodal_params, sampling_params) + + visual_req_status = GroupReqObjs(group_request_id, multimodal_params, None, start_time) + + await self.transfer_to_next_module_or_node( + None, sampling_params, original_multimodal_params, visual_req_status + ) + + except Exception as e: + logger.error(f"group_request_id: {group_request_id} has exception {str(e)}") + await self.abort(group_request_id) + raise e + return + def _count_multimodal_tokens(self, multimodal_params: MultimodalParams) -> Tuple[int, int]: image_tokens = 0 audio_tokens = 0 diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py index f9ae2cbd4..a16b0b278 100644 --- a/lightllm/server/visualserver/vit_connect.py +++ b/lightllm/server/visualserver/vit_connect.py @@ -217,6 +217,7 @@ async def _wait_visual_embed_ready(self, data, timeout_seconds: int = 100): return uuids = data.multimodal_params.get_all_uuids() + print(f"uuids is {uuids}") async def wait_for_embeds(): while not all(self.cache_client.root.get_items_embed(uuids)): From c1d98eb8164a3f3810dbcf6bef50bacd0142ba78 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Wed, 3 Sep 2025 12:08:56 +0000 Subject: [PATCH 28/40] 0903 --- lightllm/server/embed_cache/impl/memory_cache_with_redis.py | 3 +++ lightllm/server/httpserver/manager.py | 4 ++-- lightllm/server/visualserver/manager.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index 6ffb364a5..1a7ea564b 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -49,6 +49,9 @@ def set_items_embed(self, ids: list[int]) -> None: def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: ret = [] for id in ids: + # if self.redis_cache.query(str(id)): + # ret.append(True) + # continue # 避免重复的引用计数增加 if self._records[id].embed: ret.append(True) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index f28935304..5704a7ad8 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -687,8 +687,8 @@ async def recycle_resource_loop(self): for req in req_status.group_req_objs.shm_req_objs: await self.shm_req_manager.async_put_back_req_obj(req) await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem) - if self.args.run_mode != "llm_only": - await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) + print("begin release") + await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) # 先保留这个关键得日志,用于方便定位重构中的问题。 if time.time() - pre_time_mark > 120: diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index b2929266d..892723ec0 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -125,7 +125,7 @@ async def loop_for_fwd(self): multimodal_params = group_req_indexes.multimodal_params - img_uuids = [img.uuid for img in multimodal_params.images if not img.afs_embed] + img_uuids = [img.uuid for img in multimodal_params.images] ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids)) for img, ready in zip(multimodal_params.images, ready_image): From cffa0a0a044b66e420d56e573ead2a8a47dcc935 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Tue, 9 Sep 2025 11:07:32 +0000 Subject: 
[PATCH 29/40] 0909 --- lightllm/models/vit/model.py | 2 +- .../impl/memory_cache_with_redis.py | 11 ++++++++++ lightllm/server/embed_cache/utils.py | 20 +++++++++-------- lightllm/server/httpserver/manager.py | 1 - lightllm/server/visualserver/manager.py | 14 +++++++----- lightllm/server/visualserver/vit_connect.py | 22 ++++++++++++------- 6 files changed, 45 insertions(+), 25 deletions(-) diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index a114bc1dd..01bb69bdf 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -178,7 +178,7 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - image_data = img._preload_data + image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) img_tensors.append(t) diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index 1a7ea564b..90f5ce11e 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -47,6 +47,17 @@ def set_items_embed(self, ids: list[int]) -> None: self._records[id].ref -= 1 def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: + ret = [] + for id in ids: + print(f"id is {id}") + print(f"self.redis_cache.query(str(id)) is {self.redis_cache.query(str(id))}") + exist = self.redis_cache.query(str(id)) + ret.append(exist) + if exist: + self._records[id].embed = True + return ret + + def get_items_embed_and_incre(self, ids: list[int]) -> list[Optional[bool]]: ret = [] for id in ids: # if self.redis_cache.query(str(id)): diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index 904d9a38f..101c6bec7 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -326,17 +326,19 @@ def _delete_afs_files(self, victims: List[str]) -> None: return {-1, 0} -- Not found end +--ref 递减到 0 时保留键,只更新计数与 LRU local rc = tonumber(val) - 1 -if rc <= 0 then - redis.call('DEL', ref_key) - redis.call('ZREM', zset, md5) - return {0, 1} -- Deleted -else - redis.call('SET', ref_key, rc) - local now = redis.call('TIME')[1] * 1000 - redis.call('ZADD', zset, now, md5) - return {rc, 0} -- Updated +if rc < 0 then + rc = 0 end + +redis.call('SET', ref_key, rc) + +-- 更新 LRU 时间戳(最近释放的条目更不容易被立即逐出) +local now = redis.call('TIME')[1] * 1000 +redis.call('ZADD', zset, now, md5) + +return {rc, 0} -- 未删除 """ _EVICT_AND_INSERT_LUA = r""" diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 5704a7ad8..6587e7369 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -687,7 +687,6 @@ async def recycle_resource_loop(self): for req in req_status.group_req_objs.shm_req_objs: await self.shm_req_manager.async_put_back_req_obj(req) await self.shm_req_manager.async_release_req_index(req.index_in_shm_mem) - print("begin release") await self._release_multimodal_resources(req_status.group_req_objs.multimodal_params) # 先保留这个关键得日志,用于方便定位重构中的问题。 diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 892723ec0..ef2265d59 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -160,13 +160,14 @@ def _recv_reqs(self): # img for img in 
recv_req.multimodal_params.images # if not self.cache_client.root.get_item_embed(img.uuid) # embed已存在的被丢弃 , ref +1 # ] - uuids = [] + uuids = [img.uuid for img in recv_req.multimodal_params.images] + already_embed = self.cache_client.root.get_items_embed(uuids) token_nums = [] - for img in recv_req.multimodal_params.images: - uuids.append(img.uuid) - token_nums.append(img.token_num) - record = self.cache_client.root.alloc(uuids, token_nums) - print(f"record is {record}") + for img, embed in zip(recv_req.multimodal_params.images, already_embed): + if not embed: + uuids.append(img.uuid) + token_nums.append(img.token_num) + self.cache_client.root.alloc(uuids, token_nums) return recv_req else: return self.vit_receiver.recv_pyobj(zmq.NOBLOCK) @@ -182,6 +183,7 @@ async def loop_for_netio_req(self): if isinstance(recv_req, GroupReqIndexes): # print(recv_req, flush=True) self.waiting_reqs.append(recv_req) + print(f"recv_req.multimodal_params is {recv_req.multimodal_params}") else: assert False, f"Error Req Inf {recv_req}" self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256) diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py index a16b0b278..982ddf018 100644 --- a/lightllm/server/visualserver/vit_connect.py +++ b/lightllm/server/visualserver/vit_connect.py @@ -5,6 +5,8 @@ import pickle from typing import Dict, List, Optional, Any from lightllm.utils.log_utils import init_logger +from lightllm.server.core.objs.io_objs import GroupReqObjs, GroupReqIndexes +from lightllm.server.multimodal_params import MultimodalParams import httpx import base64 from dataclasses import dataclass @@ -48,8 +50,10 @@ def _setup_vit_connections(self): """ if self.remote_vit: # 远程VIT实例模式 + print("remote") self._setup_remote_vit_connections() else: + print("not remote") self._setup_local_vit_connection() def _setup_local_vit_connection(self): @@ -156,31 +160,33 @@ def _get_vit_instance(self): self.current_vit_index = index return list(self.remote_vit_instances.values())[index] - async def send_to_vit(self, data, protocol=pickle.HIGHEST_PROTOCOL): + async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL): """ 发送数据到VIT实例,支持本地和远程模式 """ instance = self._get_vit_instance() # 本地模式下,提前释放图片资源,降低传输开销 if not self.remote_vit: - data.multimodal_params.free() + req.multimodal_params.free() try: print(instance, flush=True) - instance.send_pyobj(data, protocol=protocol) + instance.send_pyobj(req, protocol=protocol) except Exception as e: logger.error(f"Failed to send to VIT instance: {e}") raise Exception(f"Failed to send to VIT instance: {e}") # 远程模式下,发送完以后,在释放图片资源 - await self._wait_visual_embed_ready(data) + await self._wait_visual_embed_ready(req) if self.remote_vit: - data.multimodal_params.free() + req.multimodal_params.free() async def vit_handle_loop(self): """ 异步VIT连接管理循环,由外部启动 """ + if not self.remote_vit: + return logger.info("Starting VIT connection management loop") while True: try: @@ -211,12 +217,12 @@ async def _async_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]: logger.exception(f"Error getting VIT instances: {e}") return None - async def _wait_visual_embed_ready(self, data, timeout_seconds: int = 100): + async def _wait_visual_embed_ready(self, req: GroupReqIndexes, timeout_seconds: int = 100): # 本地模式不需要等待 if not self.remote_vit: return - uuids = data.multimodal_params.get_all_uuids() + uuids = req.multimodal_params.get_all_uuids() print(f"uuids is {uuids}") async def wait_for_embeds(): @@ -227,5 +233,5 @@ async def 
wait_for_embeds(): await asyncio.wait_for(wait_for_embeds(), timeout=timeout_seconds) except asyncio.TimeoutError: logger.error( - f"Req {data.group_req_id}: timeout waiting for visual embed ready after {timeout_seconds} seconds" + f"Req {req.group_req_id}: timeout waiting for visual embed ready after {timeout_seconds} seconds" ) From 0a296a185c67504c5a8132a8d612e44d761cd9ea Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 11 Sep 2025 12:22:52 +0000 Subject: [PATCH 30/40] 0911 --- lightllm/models/vit/model.py | 3 ++- .../impl/memory_cache_with_redis.py | 6 ++--- .../embed_cache/impl/naive_memory_cache.py | 7 ++--- lightllm/server/embed_cache/utils.py | 12 ++++++--- lightllm/server/visualserver/manager.py | 27 +++++++++++++++---- .../visualserver/model_infer/model_rpc.py | 12 +++------ lightllm/server/visualserver/vit_connect.py | 3 +-- 7 files changed, 43 insertions(+), 27 deletions(-) diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 01bb69bdf..2042ee109 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -178,7 +178,8 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = img._preload_data + # image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) img_tensors.append(t) diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index 90f5ce11e..d024136a8 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -36,7 +36,7 @@ def release(self, ids: list[int]) -> None: self._records[id_].ref -= 1 if self.redis_cache.query(str(id_)): self.redis_cache.decr(str(id_)) - print(self.redis_cache.stats(), flush=True) + # print(self.redis_cache.stats(), flush=True) # vit 负责set def set_items_embed(self, ids: list[int]) -> None: @@ -44,13 +44,11 @@ def set_items_embed(self, ids: list[int]) -> None: for id in ids: self.redis_cache.insert(str(id)) self._records[id].embed = True - self._records[id].ref -= 1 + self._records[id].ref -= 1 # vit端alloc之后ref+1 vit完成后ref-1 def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: ret = [] for id in ids: - print(f"id is {id}") - print(f"self.redis_cache.query(str(id)) is {self.redis_cache.query(str(id))}") exist = self.redis_cache.query(str(id)) ret.append(exist) if exist: diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index 5b8af402d..b810f5658 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -78,9 +78,10 @@ def _clear(self, free_max_count: int): free_shm(get_shm_name_data(id)) if record.embed: # 仅vit释放掉afs里的, llm端不做释放 - if self.args.run_mode == "visual": - free_afs(get_shm_name_embed(id), self.args.image_embed_dir) - elif not self.args.enable_remote_vit: + # if self.args.run_mode == "visual": + # free_afs(get_shm_name_embed(id), self.args.image_embed_dir) + # elif not self.args.enable_remote_vit: + if not self.args.run_mode == "visual": free_shm(get_shm_name_embed(id)) del self._md5_to_record[record.md5sum] del self._records[id] diff --git a/lightllm/server/embed_cache/utils.py 
b/lightllm/server/embed_cache/utils.py index 101c6bec7..c5e259e85 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -8,6 +8,9 @@ from pathlib import Path import multiprocessing.shared_memory as shm from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) def tensor2bytes(t: torch.Tensor): @@ -247,7 +250,7 @@ def _md5_to_afs_path(self, md5: str) -> str: """Convert md5 to AFS file path.""" if not self.image_embed_dir: return None - filename = md5 + self.path_ext + filename = self.image_embed_dir + md5 + self.path_ext return filename def _delete_afs_files(self, victims: List[str]) -> None: @@ -260,9 +263,9 @@ def _delete_afs_files(self, victims: List[str]) -> None: file_path = self._md5_to_afs_path(md5) if file_path and os.path.exists(file_path): os.remove(file_path) - print(f"Deleted AFS file: {file_path}") + logger.debug(f"Deleted AFS file: {file_path}") except Exception as e: - print(f"Warning: Failed to delete AFS file for {md5}: {e}") + logger.debug(f"Warning: Failed to delete AFS file for {md5}: {e}") # ---------------- Lua scripts ---------------- _INSERT_LUA = r""" @@ -273,6 +276,7 @@ def _delete_afs_files(self, victims: List[str]) -> None: local md5 = ARGV[1] local capacity = tonumber(ARGV[2]) +local unpack = unpack or table.unpack local ref_key = ref_prefix .. md5 if redis.call('GET', ref_key) then return {0} -- Already exists @@ -385,7 +389,7 @@ def _delete_afs_files(self, victims: List[str]) -> None: local now = redis.call('TIME')[1] * 1000 redis.call('ZADD', zset, now, new_md5) - return {1, table.unpack(victims)} -- success + victims + return {1, unpack(victims)} -- success + victims else return {0} -- 逐出失败,没有足够的候选 end diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index ef2265d59..8b6f474bd 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -153,21 +153,28 @@ async def loop_for_fwd(self): processing_group_reqs = [] images_need_infer = [] - def _recv_reqs(self): + async def _recv_reqs(self): if self.remote_vit: recv_req: GroupReqIndexes = self.vit_receiver.recv_pyobj(zmq.NOBLOCK) # recv_req.multimodal_params.images[:]= [ # img for img in recv_req.multimodal_params.images # if not self.cache_client.root.get_item_embed(img.uuid) # embed已存在的被丢弃 , ref +1 # ] + logger.info(f"Receive req {recv_req.group_req_id}, image_count:{len(recv_req.multimodal_params.images)}") uuids = [img.uuid for img in recv_req.multimodal_params.images] already_embed = self.cache_client.root.get_items_embed(uuids) + if all(already_embed): + return None token_nums = [] for img, embed in zip(recv_req.multimodal_params.images, already_embed): if not embed: uuids.append(img.uuid) token_nums.append(img.token_num) - self.cache_client.root.alloc(uuids, token_nums) + while True: + records = self.cache_client.root.alloc(uuids, token_nums) + if records is not None: + break + await asyncio.sleep(0.1) return recv_req else: return self.vit_receiver.recv_pyobj(zmq.NOBLOCK) @@ -179,11 +186,11 @@ async def loop_for_netio_req(self): while True: try: for _ in range(self.visual_recv_max_count): - recv_req: GroupReqIndexes = self._recv_reqs() + recv_req: GroupReqIndexes = await self._recv_reqs() + if recv_req is None: + continue if isinstance(recv_req, GroupReqIndexes): - # print(recv_req, flush=True) self.waiting_reqs.append(recv_req) - print(f"recv_req.multimodal_params is {recv_req.multimodal_params}") 
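The new async _recv_reqs above skips requests whose images are already embedded and then spins on cache allocation. A standalone sketch of the same retry-with-yield idiom, using illustrative names (the real allocator is the rpyc cache client; the asyncio.to_thread wrapping is the variant a later commit in this series adopts):

import asyncio

async def alloc_with_backoff(alloc_fn, uuids, token_nums, interval=0.1):
    # Retry a blocking allocator until it returns records, sleeping between
    # attempts so the event loop can keep serving other coroutines.
    while True:
        records = await asyncio.to_thread(alloc_fn, uuids, token_nums)
        if records is not None:
            return records
        await asyncio.sleep(interval)
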
else: assert False, f"Error Req Inf {recv_req}" self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256) @@ -210,11 +217,21 @@ async def loop_for_fwd_visual_only(self): images_need_infer.append(img) if len(images_need_infer) == self.infer_batch_size: + _t0 = time.perf_counter() await self.infer_imgs(images_need_infer) + logger.info( + f"[visual] batch infer complete, image_count: {len(images_need_infer)}, " + f"elapsed_time {(time.perf_counter()-_t0) * 1000}ms" + ) images_need_infer = [] if len(images_need_infer) > 0: + _t1 = time.perf_counter() await self.infer_imgs(images_need_infer) + logger.info( + f"[visual] batch infer complete, image_count:{len(images_need_infer)}, " + f"elapsed_time {(time.perf_counter()-_t1) * 1000}ms" + ) images_need_infer = [] # 在这里release这个image,ref-1 logger.info(f"req-id {visual_req.group_req_id} has been release ok") diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index c4f0de925..cb4f5278e 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -50,7 +50,6 @@ def exposed_init_model(self, kvargs): self.dp_rank_id = kvargs["dp_rank_id"] self.tp_rank_id = kvargs["tp_rank_id"] kvargs["vit_rank_id"] = self.dp_rank_id * self.args.visual_tp + self.tp_rank_id - print(cache_port) self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) init_vision_distributed_env(kvargs) @@ -87,9 +86,7 @@ def exposed_init_model(self, kvargs): else: raise Exception(f"can not support {self.model_type} now") self.model.load_model(weight_dir) - print("begin load model") self.model = self.model.cuda() - print("load model OK") except Exception as e: print("#" * 16) print("load model error:", str(e), e, type(e)) @@ -113,12 +110,11 @@ def exposed_encode(self, images: List[ImageItem]): all_img_embeds = all_img_embeds.to(torch.device("cpu")) if self.tp_rank_id == 0: - ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) - print(f"ready_flags is {ready_flags}") + # ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) ids_to_set = [] - for i, ready in enumerate(ready_flags): - if ready: - continue + for i, img in enumerate(images): + # if ready: + # continue uid = uuids[i] start, end = valid_ids[i] cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py index 982ddf018..b93c7c84c 100644 --- a/lightllm/server/visualserver/vit_connect.py +++ b/lightllm/server/visualserver/vit_connect.py @@ -137,7 +137,7 @@ def _update_vit_connections(self, id_to_vit_obj: Dict[int, VIT_Obj]): if id not in self.remote_vit_instances: try: socket = self.context.socket(zmq.PUSH) - print(vit_obj.host_ip_port, self.args.remote_vit_port, flush=True) + # print(vit_obj.host_ip_port, self.args.remote_vit_port, flush=True) ip, port = vit_obj.host_ip_port.split(":") socket.connect(f"tcp://{ip}:{port}") self.remote_vit_instances[id] = socket @@ -223,7 +223,6 @@ async def _wait_visual_embed_ready(self, req: GroupReqIndexes, timeout_seconds: return uuids = req.multimodal_params.get_all_uuids() - print(f"uuids is {uuids}") async def wait_for_embeds(): while not all(self.cache_client.root.get_items_embed(uuids)): From 6b951563023b9f0586f85b6ed459f96de36c2468 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Thu, 11 Sep 2025 13:16:32 +0000 Subject: [PATCH 31/40] 0911-add-other-multimodal's vit dispatch --- 
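The diffs below thread a remote_vit flag into each vision encoder; the shared pattern is to take image bytes from the request's preloaded data when the VIT runs as a remote service, and from local shared memory otherwise. A condensed sketch of that branch (helper callables are passed in to keep it self-contained; the actual code imports them from the embed-cache utils):

from io import BytesIO
from PIL import Image

def load_image(img, remote_vit, read_shm, get_shm_name_data):
    # remote VIT: the bytes were preloaded and shipped with the request object
    # local VIT: fetch the bytes from the shared-memory segment keyed by uuid
    raw = img._preload_data if remote_vit else read_shm(get_shm_name_data(img.uuid))
    return Image.open(BytesIO(raw))
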
lightllm/models/gemma3/gemma3_visual.py | 8 ++++++-- lightllm/models/llava/llava_visual.py | 8 ++++++-- lightllm/models/qwen2_5_vl/qwen2_5_visual.py | 6 +++++- lightllm/models/qwen2_vl/qwen2_visual.py | 6 +++++- lightllm/models/qwen_vl/qwen_visual.py | 7 ++++++- lightllm/models/tarsier2/tarsier2_visual.py | 7 ++++++- lightllm/models/vit/model.py | 7 +++++-- lightllm/server/visualserver/model_infer/model_rpc.py | 10 ++++++---- lightllm/server/visualserver/vit_connect.py | 1 - lightllm/utils/shm_size_check.py | 1 + 10 files changed, 46 insertions(+), 15 deletions(-) diff --git a/lightllm/models/gemma3/gemma3_visual.py b/lightllm/models/gemma3/gemma3_visual.py index b2f7a6b77..8e373c128 100644 --- a/lightllm/models/gemma3/gemma3_visual.py +++ b/lightllm/models/gemma3/gemma3_visual.py @@ -16,7 +16,8 @@ class Gemma3VisionModel: - def __init__(self): + def __init__(self, kvargs): + self.remote_vit = kvargs.get("remote_vit", False) pass def load_model(self, weight_dir): @@ -122,7 +123,10 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - image_data = read_shm(get_shm_name_data(img.uuid)) + if self.remote_vit: + image_data = img._preload_data + else: + image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"] img_tensors.append(t) diff --git a/lightllm/models/llava/llava_visual.py b/lightllm/models/llava/llava_visual.py index 293bcd445..09e9f0119 100644 --- a/lightllm/models/llava/llava_visual.py +++ b/lightllm/models/llava/llava_visual.py @@ -15,7 +15,8 @@ class LlavaVisionModel: - def __init__(self): + def __init__(self, kvargs): + self.remote_vit = kvargs.get("remote_vit", False) pass def load_model(self, weight_dir): @@ -133,7 +134,10 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - image_data = read_shm(get_shm_name_data(img.uuid)) + if self.remote_vit: + image_data = img._preload_data + else: + image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)).convert("RGB") t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"] img_tensors.append(t) diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 01e3a7268..33eb69e9a 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -171,6 +171,7 @@ def __init__( self.window_size = window_size self.fullatt_block_indexes = fullatt_block_indexes self.out_hidden_size = out_hidden_size + self.remote_vit = kvargs.get("remote_vit", False) self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size @@ -381,7 +382,10 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - image_data = read_shm(get_shm_name_data(img.uuid)) + if self.remote_vit: + image_data = img._preload_data + else: + image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) image_data = resize_image(image_data) pixel_values, image_grid_thw = self.processor.preprocess(image_data) diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index 68e161737..fba404cff 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ 
b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -200,6 +200,7 @@ def __init__( self.patch_size = patch_size self.spatial_merge_size = spatial_merge_size self.temporal_patch_size = temporal_patch_size + self.remote_vit = kvargs.get("remote_vit", False) self.patch_embed = PatchEmbed( patch_size=self.patch_size, @@ -309,7 +310,10 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - image_data = read_shm(get_shm_name_data(img.uuid)) + if self.remote_vit: + image_data = img._preload_data + else: + image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) image_data = resize_image(image_data) pixel_values, image_grid_thw = self.processor.preprocess(image_data) diff --git a/lightllm/models/qwen_vl/qwen_visual.py b/lightllm/models/qwen_vl/qwen_visual.py index f6468b144..e4b35d624 100644 --- a/lightllm/models/qwen_vl/qwen_visual.py +++ b/lightllm/models/qwen_vl/qwen_visual.py @@ -333,6 +333,7 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): class QWenVisionTransformer(nn.Module): def __init__( self, + kvargs, image_size: int, patch_size: int, width: int, @@ -344,6 +345,7 @@ def __init__( **kwargs, ): super().__init__() + self.remote_vit = kvargs.get("remote_vit", False) image_height, image_width = self.image_size = (image_size, image_size) patch_height, patch_width = self.patch_size = (patch_size, patch_size) self.grid_size = (image_height // patch_height, image_width // patch_width) @@ -422,7 +424,10 @@ def encode(self, image_uuids: List): for i, item in enumerate(image_uuids): if isinstance(item, int): uuids.append(item) - image_data = read_shm(get_shm_name_data(item)) + if self.remote_vit: + image_data = item._preload_data + else: + image_data = read_shm(get_shm_name_data(item.uuid)) image_data = Image.open(BytesIO(image_data)).convert("RGB") t = self.image_transform(image_data) img_tensors.append(t) diff --git a/lightllm/models/tarsier2/tarsier2_visual.py b/lightllm/models/tarsier2/tarsier2_visual.py index 9deaf0857..0915c4ac0 100644 --- a/lightllm/models/tarsier2/tarsier2_visual.py +++ b/lightllm/models/tarsier2/tarsier2_visual.py @@ -152,6 +152,7 @@ def forward(self, image_features, input_embeddings): class TarsierVisionTransformerPretrainedModel(nn.Module): def __init__( self, + kvargs, vision_config=None, text_config=None, ignore_index=-100, @@ -165,6 +166,7 @@ def __init__( **kwargs, ): super().__init__() + self.remote_vit = kvargs.get("remote_vit", False) self.vision_tower = Qwen2VisionTransformerPretrainedModel(**vision_config) if projection_head == "Pixel_Shuffle": @@ -251,7 +253,10 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - image_data = read_shm(get_shm_name_data(img.uuid)) + if self.remote_vit: + image_data = img._preload_data + else: + image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) image_data = resize_image(image_data) pixel_values, image_grid_thw = self.processor.preprocess(image=image_data) diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 2042ee109..72d8964e0 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -47,6 +47,7 @@ def __init__(self, kvargs): self.quant_cfg_path = kvargs.get("quant_cfg", None) self.load_image_func = get_load_image_func(self.weight_dir_) self.max_batch_size = kvargs.get("max_batch_size", 1) + self.remote_vit = 
kvargs.get("remote_vit", False) self._init_datatype() self._init_config() @@ -178,8 +179,10 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - image_data = img._preload_data - # image_data = read_shm(get_shm_name_data(img.uuid)) + if self.remote_vit: + image_data = img._preload_data + else: + image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) img_tensors.append(t) diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index cb4f5278e..ecf1e3993 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -46,6 +46,7 @@ def exposed_init_model(self, kvargs): quant_type = self.args.vit_quant_type quant_cfg = self.args.vit_quant_cfg max_batch_size = min(self.args.visual_infer_batch_size // self.args.visual_dp, 1) + remote_vit = True if self.args.run_mode == "visual" else False self.dp_rank_id = kvargs["dp_rank_id"] self.tp_rank_id = kvargs["tp_rank_id"] @@ -62,10 +63,11 @@ def exposed_init_model(self, kvargs): "quant_type": quant_type, "quant_cfg": quant_cfg, "max_batch_size": max_batch_size, + "remote_vit": remote_vit, } self.model_type = model_cfg["model_type"] if self.model_type == "qwen": - self.model = QWenVisionTransformer(**model_cfg["visual"]).eval().bfloat16() + self.model = QWenVisionTransformer(kvargs, **model_cfg["visual"]).eval().bfloat16() elif self.model_type == "qwen2_vl": self.model = ( Qwen2VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() @@ -75,14 +77,14 @@ def exposed_init_model(self, kvargs): Qwen2_5_VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() ) elif model_cfg["architectures"][0] == "TarsierForConditionalGeneration": - self.model = TarsierVisionTransformerPretrainedModel(**model_cfg).eval().bfloat16() + self.model = TarsierVisionTransformerPretrainedModel(kvargs, **model_cfg).eval().bfloat16() elif self.model_type == "llava": - self.model = LlavaVisionModel() + self.model = LlavaVisionModel(kvargs) elif self.model_type == "internvl_chat": self.model = VisionTransformer(kvargs) # self.model = InternVLVisionModel() elif self.model_type == "gemma3": - self.model = Gemma3VisionModel() + self.model = Gemma3VisionModel(kvargs) else: raise Exception(f"can not support {self.model_type} now") self.model.load_model(weight_dir) diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py index b93c7c84c..6877bf95a 100644 --- a/lightllm/server/visualserver/vit_connect.py +++ b/lightllm/server/visualserver/vit_connect.py @@ -50,7 +50,6 @@ def _setup_vit_connections(self): """ if self.remote_vit: # 远程VIT实例模式 - print("remote") self._setup_remote_vit_connections() else: print("not remote") diff --git a/lightllm/utils/shm_size_check.py b/lightllm/utils/shm_size_check.py index 84a9b8f3f..144e524bd 100644 --- a/lightllm/utils/shm_size_check.py +++ b/lightllm/utils/shm_size_check.py @@ -117,6 +117,7 @@ def _get_recommended_shm_size_gb(args, max_image_resolution=(3940, 2160), dtype_ ) fake_image_item.image_w = fake_image_item._data[0] fake_image_item.image_h = fake_image_item._data[1] + fake_image_item.extra_params["image_patch_max_num"] = 12 max_image_tokens = tokenizer.get_image_token_length(fake_image_item) # 估算图片 token 
所需的资源 From 485356104c7c9026b4247e0850ab8e253aa7af0a Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Tue, 16 Sep 2025 07:33:44 +0000 Subject: [PATCH 32/40] [fix]0915-fix-rpyc-cost --- lightllm/models/vit/model.py | 6 +-- .../embed_cache/impl/naive_memory_cache.py | 2 +- lightllm/server/httpserver/manager.py | 2 + lightllm/server/visualserver/manager.py | 38 +++++++++++-------- .../visualserver/model_infer/model_rpc.py | 32 ++++++++++++---- 5 files changed, 52 insertions(+), 28 deletions(-) diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 72d8964e0..2a6a549a6 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -1,4 +1,5 @@ import os +import time import json import torch from lightllm.models.vit.layer_infer.pre_layer_infer import ViTPreLayerInfer @@ -179,10 +180,7 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - if self.remote_vit: - image_data = img._preload_data - else: - image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"]) img_tensors.append(t) diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index b810f5658..e4f9544f1 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -81,7 +81,7 @@ def _clear(self, free_max_count: int): # if self.args.run_mode == "visual": # free_afs(get_shm_name_embed(id), self.args.image_embed_dir) # elif not self.args.enable_remote_vit: - if not self.args.run_mode == "visual": + if not self.args.enable_remote_vit and self.args.run_mode != "visual": free_shm(get_shm_name_embed(id)) del self._md5_to_record[record.md5sum] del self._records[id] diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 6587e7369..95617a09c 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -4,6 +4,7 @@ import asyncio import uvloop import rpyc +import socket import time import copy import hashlib @@ -84,6 +85,7 @@ def __init__( self.enable_multimodal = enable_multimodal if self.enable_multimodal: self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) + self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) # 初始化VIT连接管理器 from lightllm.server.visualserver.vit_connect import VITConnectionManager diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index 8b6f474bd..0196ff65c 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -4,6 +4,7 @@ import asyncio import uvloop import rpyc +import socket import pickle import hashlib import datetime @@ -11,6 +12,7 @@ from fastapi import Request from ..tokenizer import get_tokenizer from lightllm.server.core.objs.io_objs.group_req import GroupReqIndexes +from lightllm.server.embed_cache.utils import get_shm_name_data, create_shm from lightllm.server.core.objs import ShmReqManager from lightllm.server.core.objs import SamplingParams from lightllm.server.core.objs import Req, FinishStatus @@ -63,6 +65,7 @@ def _setup_connections(self): self.send_to_next_module = context.socket(zmq.PUSH) # router or audio server (if 
--enable_multimodal_audio) self.send_to_next_module.connect(f"{self.args.zmq_mode}127.0.0.1:{self.next_module_port}") self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True}) + self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) async def wait_to_model_ready(self): visual_dp = self.args.visual_dp @@ -100,7 +103,6 @@ async def infer_imgs(self, images: List[ImageItem]): for vit_tp_rank in range(self.args.visual_tp): task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_images)) tasks.append(task) - await asyncio.gather(*tasks) return @@ -162,19 +164,34 @@ async def _recv_reqs(self): # ] logger.info(f"Receive req {recv_req.group_req_id}, image_count:{len(recv_req.multimodal_params.images)}") uuids = [img.uuid for img in recv_req.multimodal_params.images] - already_embed = self.cache_client.root.get_items_embed(uuids) + already_embed = await asyncio.to_thread(self.cache_client.root.get_items_embed, uuids) if all(already_embed): return None + + uuids = [] token_nums = [] + datas = [] for img, embed in zip(recv_req.multimodal_params.images, already_embed): if not embed: uuids.append(img.uuid) token_nums.append(img.token_num) + datas.append(img._preload_data) + img.free() while True: - records = self.cache_client.root.alloc(uuids, token_nums) + records = await asyncio.to_thread(self.cache_client.root.alloc, uuids, token_nums) if records is not None: break - await asyncio.sleep(0.1) + await asyncio.sleep(0.01) + ready_flags = obtain(self.cache_client.root.get_items_data(uuids)) + update_data_ids = [] + + for uid, ready, data in zip(uuids, ready_flags, datas): + if not ready: + create_shm(get_shm_name_data(uid), data) + update_data_ids.append(uid) + + if update_data_ids: + await asyncio.to_thread(self.cache_client.root.set_items_data, update_data_ids) return recv_req else: return self.vit_receiver.recv_pyobj(zmq.NOBLOCK) @@ -193,7 +210,8 @@ async def loop_for_netio_req(self): self.waiting_reqs.append(recv_req) else: assert False, f"Error Req Inf {recv_req}" - self.visual_recv_max_count = min(self.visual_recv_max_count * 1.3, 256) + await asyncio.sleep(0) + self.visual_recv_max_count = min(int(self.visual_recv_max_count * 1.3), 256) except zmq.ZMQError: # 当队列已经开始清空的时候,将一次接受数量下调 self.visual_recv_max_count = 64 @@ -217,21 +235,11 @@ async def loop_for_fwd_visual_only(self): images_need_infer.append(img) if len(images_need_infer) == self.infer_batch_size: - _t0 = time.perf_counter() await self.infer_imgs(images_need_infer) - logger.info( - f"[visual] batch infer complete, image_count: {len(images_need_infer)}, " - f"elapsed_time {(time.perf_counter()-_t0) * 1000}ms" - ) images_need_infer = [] if len(images_need_infer) > 0: - _t1 = time.perf_counter() await self.infer_imgs(images_need_infer) - logger.info( - f"[visual] batch infer complete, image_count:{len(images_need_infer)}, " - f"elapsed_time {(time.perf_counter()-_t1) * 1000}ms" - ) images_need_infer = [] # 在这里release这个image,ref-1 logger.info(f"req-id {visual_req.group_req_id} has been release ok") diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index ecf1e3993..d52a0382c 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -2,6 +2,7 @@ import numpy as np import rpyc import torch +import time import inspect from datetime import timedelta from typing import Dict, List, Tuple @@ -30,6 +31,11 @@ 
from lightllm.utils.dist_utils import init_vision_distributed_env from lightllm.utils.graceful_utils import graceful_registry from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.log_utils import init_logger +import pickle +import socket + +logger = init_logger(__name__) class VisualModelRpcServer(rpyc.Service): @@ -48,10 +54,12 @@ def exposed_init_model(self, kvargs): max_batch_size = min(self.args.visual_infer_batch_size // self.args.visual_dp, 1) remote_vit = True if self.args.run_mode == "visual" else False + self.image_embed_dir = self.args.image_embed_dir self.dp_rank_id = kvargs["dp_rank_id"] self.tp_rank_id = kvargs["tp_rank_id"] kvargs["vit_rank_id"] = self.dp_rank_id * self.args.visual_tp + self.tp_rank_id self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True}) + self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) init_vision_distributed_env(kvargs) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) @@ -109,19 +117,18 @@ def forward(self, images: List[ImageItem]): def exposed_encode(self, images: List[ImageItem]): images = obtain(images) all_img_embeds, uuids, valid_ids = self.forward(images) - all_img_embeds = all_img_embeds.to(torch.device("cpu")) - + all_img_embeds = all_img_embeds.to(torch.device("cpu"), non_blocking=True) if self.tp_rank_id == 0: # ready_flags = obtain(self.cache_client.root.get_items_embed(uuids)) ids_to_set = [] - for i, img in enumerate(images): + for i in range(len(images)): # if ready: # continue uid = uuids[i] start, end = valid_ids[i] cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) if self.args.run_mode == "visual": - create_afs(get_shm_name_embed(uid), cur_embed_bytes, self.args.image_embed_dir) + create_afs(get_shm_name_embed(uid), cur_embed_bytes, self.image_embed_dir) else: create_shm(get_shm_name_embed(uid), cur_embed_bytes) ids_to_set.append(uid) @@ -131,11 +138,13 @@ def exposed_encode(self, images: List[ImageItem]): class VisualModelRpcClient: - def __init__(self, model_rpc, vit_tp, rpc_server_process=None): - self.model: VisualModelRpcServer = model_rpc + def __init__(self, conn, vit_tp, rpc_server_process=None): + self.conn = conn + self.model: VisualModelRpcServer = conn.root self.vit_tp = vit_tp self.rpc_server_process = rpc_server_process self.use_rpc = True + self._bg = rpyc.BgServingThread(self.conn) if self.use_rpc: def async_wrap(f): @@ -176,7 +185,13 @@ def _init_env(port, device_id): # 注册graceful 退出的处理 graceful_registry(inspect.currentframe().f_code.co_name) - t = ThreadedServer(VisualModelRpcServer(), port=port, protocol_config={"allow_pickle": True}) + auth = lambda sock: (sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) or (sock, None)) + t = ThreadedServer( + VisualModelRpcServer(), + port=port, + protocol_config={"allow_pickle": True}, + authenticator=auth, + ) t.start() return @@ -197,6 +212,7 @@ async def start_model_process(port, vit_tp, device_id): while repeat_count < 20: try: con = rpyc.connect("localhost", port, config={"allow_pickle": True}) + con._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) break except BaseException: await asyncio.sleep(1) @@ -205,4 +221,4 @@ async def start_model_process(port, vit_tp, device_id): raise Exception("init rpc env error!") assert proc.is_alive() - return VisualModelRpcClient(con.root, vit_tp, rpc_server_process=proc) + return VisualModelRpcClient(con, vit_tp, rpc_server_process=proc) From ffe2f6bafdf8251b21817a0516824c9dd99b5d67 Mon 
Sep 17 00:00:00 2001 From: sangchengmeng Date: Fri, 19 Sep 2025 09:11:06 +0000 Subject: [PATCH 33/40] [fix]fix redis --- .../qwen_vl/layer_infer/pre_layer_infer.py | 5 + .../impl/memory_cache_with_redis.py | 44 +++--- .../embed_cache/impl/naive_memory_cache.py | 2 +- lightllm/server/embed_cache/manager.py | 4 +- lightllm/server/embed_cache/utils.py | 130 +++++++++++------- lightllm/server/httpserver/manager.py | 19 ++- lightllm/server/visualserver/vit_connect.py | 13 +- 7 files changed, 126 insertions(+), 91 deletions(-) diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 5808c02bb..dd6585d60 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -1,3 +1,5 @@ +import rpyc +import socket import torch import torch.distributed as dist @@ -31,6 +33,8 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer): def __init__(self, network_config, mode): super().__init__(network_config, mode) self.args = get_env_start_args() + self.cache_client = rpyc.connect("localhost", self.args.cache_port, config={"allow_pickle": True}) + self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) return def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight): @@ -57,6 +61,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei embed = read_afs(get_shm_name_embed(img["uuid"]), self.args.image_embed_dir) else: embed = read_shm(get_shm_name_embed(img["uuid"])) + self.cache_client.root.release([img["uuid"]]) img_weight.append(bytes2tensor(embed).cuda().reshape(img["token_num"], -1)) img_start_token_ids.append(img["token_id"]) img_token_lens.append(img["token_num"]) diff --git a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py index d024136a8..05bd0bc23 100644 --- a/lightllm/server/embed_cache/impl/memory_cache_with_redis.py +++ b/lightllm/server/embed_cache/impl/memory_cache_with_redis.py @@ -32,10 +32,10 @@ def __init__(self, args) -> None: # llm 负责release def release(self, ids: list[int]) -> None: with self.lock: - for id_ in ids: - self._records[id_].ref -= 1 - if self.redis_cache.query(str(id_)): - self.redis_cache.decr(str(id_)) + for id in ids: + self._records[id].ref -= 1 + if self.redis_cache.query(str(id)): + self.redis_cache.decr(str(id)) # print(self.redis_cache.stats(), flush=True) # vit 负责set @@ -44,27 +44,31 @@ def set_items_embed(self, ids: list[int]) -> None: for id in ids: self.redis_cache.insert(str(id)) self._records[id].embed = True - self._records[id].ref -= 1 # vit端alloc之后ref+1 vit完成后ref-1 + self._records[id].ref -= 1 + self.redis_cache.decr(str(id)) # vit端alloc之后ref+1 vit完成后ref-1 - def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: + def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]: ret = [] for id in ids: - exist = self.redis_cache.query(str(id)) + if embeding_only: + exist = self.redis_cache.query(str(id)) + else: + exist = self.redis_cache.query_and_incre(str(id)) ret.append(exist) if exist: self._records[id].embed = True return ret - def get_items_embed_and_incre(self, ids: list[int]) -> list[Optional[bool]]: - ret = [] - for id in ids: - # if self.redis_cache.query(str(id)): - # ret.append(True) - # continue - # 避免重复的引用计数增加 - if self._records[id].embed: - ret.append(True) - continue 
- self._records[id].embed = self.redis_cache.query_and_incre(str(id)) - ret.append(self._records[id].embed) - return ret + # def get_items_embed_and_incre(self, ids: list[int]) -> list[Optional[bool]]: + # ret = [] + # for id in ids: + # # if self.redis_cache.query(str(id)): + # # ret.append(True) + # # continue + # # 避免重复的引用计数增加 + # if self._records[id].embed: + # ret.append(True) + # continue + # self._records[id].embed = self.redis_cache.query_and_incre(str(id)) + # ret.append(self._records[id].embed) + # return ret diff --git a/lightllm/server/embed_cache/impl/naive_memory_cache.py b/lightllm/server/embed_cache/impl/naive_memory_cache.py index e4f9544f1..7f9ac58f4 100644 --- a/lightllm/server/embed_cache/impl/naive_memory_cache.py +++ b/lightllm/server/embed_cache/impl/naive_memory_cache.py @@ -144,5 +144,5 @@ def set_items_embed(self, ids: list[int]) -> None: for id_ in ids: self._records[id_].embed = True - def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]: + def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]: return [self._records.get(id_).embed if id_ in self._records else False for id_ in ids] diff --git a/lightllm/server/embed_cache/manager.py b/lightllm/server/embed_cache/manager.py index 421fc14e6..a417a7d79 100644 --- a/lightllm/server/embed_cache/manager.py +++ b/lightllm/server/embed_cache/manager.py @@ -49,9 +49,9 @@ def exposed_set_items_embed(self, ids: list[int]) -> None: ids = obtain(ids) return self._impl.set_items_embed(ids) - def exposed_get_items_embed(self, ids: list[int]) -> list[bool]: + def exposed_get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[bool]: ids = obtain(ids) - return self._impl.get_items_embed(ids) + return self._impl.get_items_embed(ids, embeding_only) def get_cache_manager(args): diff --git a/lightllm/server/embed_cache/utils.py b/lightllm/server/embed_cache/utils.py index c5e259e85..66bb72674 100644 --- a/lightllm/server/embed_cache/utils.py +++ b/lightllm/server/embed_cache/utils.py @@ -118,7 +118,7 @@ def __init__( self, redis_url: str = "redis://localhost:6379/0", capacity: int = 50000, - evict_fraction: float = 0.2, + evict_fraction: float = 0.1, key_prefix: str = "md5:", image_embed_dir: str = None, path_ext: str = "-embed", @@ -128,7 +128,7 @@ def __init__( - capacity: max count of md5 entries allowed in Redis - evict_fraction: fraction to evict when inserting a NEW md5 and at capacity - image_embed_dir: base directory for image embed files (e.g., "/afs/embeds") - - path_ext: file extension for embed files (default: ".embed") + - path_ext: file extension for embed files (default: "-embed") """ if not (0.0 <= evict_fraction <= 1.0): raise ValueError("evict_fraction must be 0..1") @@ -152,7 +152,7 @@ def __init__( self._evict_and_insert_script = self.r.register_script(self._EVICT_AND_INSERT_LUA) def insert(self, md5: str) -> Tuple[bool, List[str]]: - """Insert a new md5 with default ref_count=0. May trigger LRU eviction.""" + """Insert a new md5 with default ref_count=1. 
May trigger LRU eviction.""" # 等待任何正在进行的逐出操作 self._wait_if_eviction() @@ -176,16 +176,20 @@ def insert(self, md5: str) -> Tuple[bool, List[str]]: success = bool(evict_res[0]) victims = evict_res[1:] if len(evict_res) > 1 else [] - # 删除被逐出md5对应的AFS文件 - if victims and self.image_embed_dir: - self._delete_afs_files(victims) - - return success, victims + if success: + # 删除被逐出md5对应的AFS文件 + if victims and self.image_embed_dir: + self._delete_afs_files(victims) + return True, victims + else: + # 逐出失败,短暂退避后重试 + time.sleep(0.01) + return self.insert(md5) finally: self._release_lock() else: # 等待锁释放后重试 - time.sleep(0.1) + time.sleep(0.01) return self.insert(md5) except Exception as e: self._release_lock() @@ -199,7 +203,6 @@ def query(self, md5: str) -> bool: def query_and_incre(self, md5: str) -> bool: """Query if md5 exists and increment ref_count if found.""" self._wait_if_eviction() - res = self._query_incre_script( keys=[self.zset_key, self.ref_prefix], args=[md5], @@ -228,6 +231,11 @@ def stats(self) -> dict: "evict_fraction": self.evict_fraction, } + def get_ref(self, md5: str) -> int | None: + self._wait_if_eviction() + val = self.r.get(self.ref_prefix + md5) + return int(val) if val is not None else None + def _wait_if_eviction(self) -> None: max_wait = 30 start_time = time.time() @@ -284,8 +292,8 @@ def _delete_afs_files(self, victims: List[str]) -> None: local size = redis.call('ZCARD', zset) if size < capacity then - -- Insert with ref_count=0 - redis.call('SET', ref_key, 0) + -- Insert with ref_count=1 + redis.call('SET', ref_key, 1) local now = redis.call('TIME')[1] * 1000 redis.call('ZADD', zset, now, md5) return {0} -- Success, no eviction @@ -332,17 +340,16 @@ def _delete_afs_files(self, victims: List[str]) -> None: --ref 递减到 0 时保留键,只更新计数与 LRU local rc = tonumber(val) - 1 -if rc < 0 then - rc = 0 -end - +if rc < 0 then rc = 0 end redis.call('SET', ref_key, rc) --- 更新 LRU 时间戳(最近释放的条目更不容易被立即逐出) -local now = redis.call('TIME')[1] * 1000 -redis.call('ZADD', zset, now, md5) +if rc > 0 then + -- 只有仍被引用时才更新 LRU + local now = redis.call('TIME')[1] * 1000 + redis.call('ZADD', zset, now, md5) +end -return {rc, 0} -- 未删除 +return {rc, 0} """ _EVICT_AND_INSERT_LUA = r""" @@ -354,43 +361,64 @@ def _delete_afs_files(self, victims: List[str]) -> None: local capacity = tonumber(ARGV[2]) local evict_fraction = tonumber(ARGV[3]) --- 计算需要逐出的数量 -local need = math.max(1, math.floor(capacity * evict_fraction + 0.5)) +local unpack = unpack or table.unpack + +-- helper: now millis +local function now_ms() + local t = redis.call('TIME') + return t[1] * 1000 + math.floor(t[2] / 1000) +end + +local new_ref_key = ref_prefix .. 
new_md5 + +-- If already exists, treat as a hit: bump ref_count and refresh LRU +local cur = redis.call('GET', new_ref_key) +if cur then + local rc = tonumber(cur) + 1 + redis.call('SET', new_ref_key, rc) + redis.call('ZADD', zset, now_ms(), new_md5) + return {1} -- success, no victims +end + +-- If not at capacity, just insert +local size = redis.call('ZCARD', zset) +if size < capacity then + redis.call('SET', new_ref_key, 1) + redis.call('ZADD', zset, now_ms(), new_md5) + return {1} -- success, no victims +end + +-- At capacity: try to evict up to max_try items with rc==0, but success if at least 1 is freed +local max_try = math.max(1, math.floor(size * evict_fraction + 0.5)) local victims = {} +local freed = 0 --- 获取所有键并按LRU排序 +-- Scan from LRU (smallest score) to MRU local all_keys = redis.call('ZRANGE', zset, 0, -1, 'WITHSCORES') local i = 1 - --- 查找引用计数为0的键作为逐出候选 -while #victims < need and i <= #all_keys do - local md5 = all_keys[i] - local ref_key = ref_prefix .. md5 - local rc = redis.call('GET', ref_key) - - if rc and tonumber(rc) <= 0 then - table.insert(victims, md5) - end - i = i + 2 -- 跳过分数 +while freed < 1 and i <= #all_keys and #victims < max_try do + local md5 = all_keys[i] + local ref_key = ref_prefix .. md5 + local v = redis.call('GET', ref_key) + if v and tonumber(v) <= 0 then + table.insert(victims, md5) + freed = freed + 1 + end + i = i + 2 -- skip score end --- 如果找到足够的候选,执行逐出 -if #victims >= need then - -- 删除受害者 - for _, v in ipairs(victims) do - local ref_key = ref_prefix .. v - redis.call('DEL', ref_key) - redis.call('ZREM', zset, v) - end - - -- 插入新的md5 - local ref_key = ref_prefix .. new_md5 - redis.call('SET', ref_key, 0) - local now = redis.call('TIME')[1] * 1000 - redis.call('ZADD', zset, now, new_md5) - - return {1, unpack(victims)} -- success + victims +if freed >= 1 then + -- delete victims + for _, v in ipairs(victims) do + redis.call('DEL', ref_prefix .. 
v) + redis.call('ZREM', zset, v) + end + -- insert new + redis.call('SET', new_ref_key, 1) + redis.call('ZADD', zset, now_ms(), new_md5) + return {1, unpack(victims)} else - return {0} -- 逐出失败,没有足够的候选 + -- no zero-ref items found + return {0} end """ diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 95617a09c..a0e394049 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -122,12 +122,6 @@ def __init__( async def _alloc_resource(self, items, uuids, token_nums, datas): while True: - # 检查这个图片在redis总是否已经存在 - # embed_exists = obtain(self.cache_client.root.get_items_embed(uuids)) - # for exist in embed_exists: - # if exist: - # continue - # else: records = obtain(self.cache_client.root.alloc(uuids, token_nums)) if records is None: @@ -212,8 +206,8 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam audio.uuid = None audio.token_id = None audio.token_num = None - if ids_to_release: - self.cache_client.root.release(ids_to_release) + # if ids_to_release: + # self.cache_client.root.release(ids_to_release) return def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None): @@ -370,7 +364,7 @@ async def generate( # 对于还没有形成正式请求对象管理的多模态资源,需要单独自己释放 # 已经放入到 req_id_to_out_inf 中的请求对象,由统一的回收循环 # 进行回收。 - if group_request_id not in self.req_id_to_out_inf and self.args.run_mode != "llm_only": + if group_request_id not in self.req_id_to_out_inf: await self._release_multimodal_resources(multimodal_params) await self.abort(group_request_id) raise e @@ -410,7 +404,7 @@ async def get_image_embeding( visual_req_status = GroupReqObjs(group_request_id, multimodal_params, None, start_time) await self.transfer_to_next_module_or_node( - None, sampling_params, original_multimodal_params, visual_req_status + None, sampling_params, original_multimodal_params, visual_req_status, embeding_only=True ) except Exception as e: @@ -513,6 +507,7 @@ async def transfer_to_next_module_or_node( sampling_params: SamplingParams, original_multimodal_params: MultimodalParams, group_req_objs: Optional[GroupReqObjs] = None, + embeding_only: Optional[bool] = False, ): # 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点. 
diff --git a/lightllm/server/visualserver/vit_connect.py b/lightllm/server/visualserver/vit_connect.py
index 6877bf95a..7a1443f02 100644
--- a/lightllm/server/visualserver/vit_connect.py
+++ b/lightllm/server/visualserver/vit_connect.py
@@ -159,7 +159,7 @@ def _get_vit_instance(self):
        self.current_vit_index = index
        return list(self.remote_vit_instances.values())[index]

-    async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL):
+    async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOCOL, embeding_only=False):
        """
        Send data to a VIT instance; supports both local and remote modes
        """
@@ -176,7 +176,7 @@ async def send_to_vit(self, req: GroupReqIndexes, protocol=pickle.HIGHEST_PROTOC
            raise Exception(f"Failed to send to VIT instance: {e}")

        # In remote mode, release the image resources only after sending has finished
-        await self._wait_visual_embed_ready(req)
+        await self._wait_visual_embed_ready(req, embeding_only)
        if self.remote_vit:
            req.multimodal_params.free()
@@ -216,16 +216,17 @@ async def _async_get_vit_objs(self) -> Optional[Dict[int, VIT_Obj]]:
            logger.exception(f"Error getting VIT instances: {e}")
            return None

-    async def _wait_visual_embed_ready(self, req: GroupReqIndexes, timeout_seconds: int = 100):
+    async def _wait_visual_embed_ready(
+        self, req: GroupReqIndexes, embeding_only: bool = False, timeout_seconds: int = 1000
+    ):
        # Local mode does not need to wait
        if not self.remote_vit:
            return
-        uuids = req.multimodal_params.get_all_uuids()

        async def wait_for_embeds():
-            while not all(self.cache_client.root.get_items_embed(uuids)):
-                await asyncio.sleep(0.05)
+            while not all(self.cache_client.root.get_items_embed(uuids, embeding_only)):
+                await asyncio.sleep(0.01)

        try:
            await asyncio.wait_for(wait_for_embeds(), timeout=timeout_seconds)

From e723c404e1b1f8161c761e5947468b59e5c565e9 Mon Sep 17 00:00:00 2001
From: sangchengmeng
Date: Tue, 23 Sep 2025 07:28:45 +0000
Subject: [PATCH 34/40] [fix]clean redis before start

---
 lightllm/utils/redis_utils.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/lightllm/utils/redis_utils.py b/lightllm/utils/redis_utils.py
index bd4d87ca3..0cf19afb9 100644
--- a/lightllm/utils/redis_utils.py
+++ b/lightllm/utils/redis_utils.py
@@ -9,21 +9,35 @@ def start_redis_service(args):
    if not hasattr(args, "start_redis") or not args.start_redis:
        return None

+    config_server_host = args.config_server_host
+    redis_port = args.redis_port
    try:
-        redis_port = args.redis_port
+        subprocess.run(
+            ["redis-cli", "-h", config_server_host, "-p", str(redis_port), "FLUSHALL", "ASYNC"], check=False, timeout=2
+        )
+        subprocess.run(
+            ["redis-cli", "-h", config_server_host, "-p", str(redis_port), "SHUTDOWN", "NOSAVE"], check=False, timeout=2
+        )
+    except Exception:
+        pass

+    try:
        redis_command = [
            "redis-server",
            "--port",
            str(redis_port),
            "--bind",
-            f"{args.config_server_host}",
+            f"{config_server_host}",
            "--daemonize",
            "no",
            "--logfile",
            "-",
"--loglevel", "notice", + "--save", + '""', # 不触发 RDB 快照 + "--appendonly", + "no", # 关闭 AOF ] logger.info(f"Starting Redis service on port {redis_port}") From d53a924ac83589d816f5754cf55df4f0ce6c1796 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 13 Oct 2025 08:14:01 +0000 Subject: [PATCH 35/40] merge main --- lightllm/server/api_cli.py | 2 +- lightllm/server/visualserver/- | 27 --------------------------- 2 files changed, 1 insertion(+), 28 deletions(-) delete mode 100644 lightllm/server/visualserver/- diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index d3f3efb84..fd4f6bad6 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -367,7 +367,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "--visual_nccl_ports", nargs="+", type=int, - default=None, + default=[29500], help="List of NCCL ports to build a distributed environment for Vit, e.g., 29500 29501 29502", ) parser.add_argument( diff --git a/lightllm/server/visualserver/- b/lightllm/server/visualserver/- deleted file mode 100644 index 170eff3a7..000000000 --- a/lightllm/server/visualserver/- +++ /dev/null @@ -1,27 +0,0 @@ -529205:C 28 Aug 2025 13:07:04.500 # oO0OoO0OoO0Oo Redis is starting oO0OoO0OoO0Oo -529205:C 28 Aug 2025 13:07:04.501 # Redis version=6.0.16, bits=64, commit=00000000, modified=0, pid=529205, just started -529205:C 28 Aug 2025 13:07:04.503 # Configuration loaded -529205:M 28 Aug 2025 13:07:04.505 * Running mode=standalone, port=6379. -529205:M 28 Aug 2025 13:07:04.506 # Server initialized -529205:M 28 Aug 2025 13:07:04.507 # WARNING overcommit_memory is set to 0! Background save may fail under low memory condition. To fix this issue add 'vm.overcommit_memory = 1' to /etc/sysctl.conf and then reboot or run the command 'sysctl vm.overcommit_memory=1' for this to take effect. -529205:M 28 Aug 2025 13:07:04.509 * Ready to accept connections -529205:signal-handler (1756386794) Received SIGINT scheduling shutdown... -529205:M 28 Aug 2025 13:13:14.912 # User requested shutdown... -529205:M 28 Aug 2025 13:13:14.914 # Redis is now ready to exit, bye bye... -533706:C 28 Aug 2025 13:13:21.718 # oO0OoO0OoO0Oo Redis is starting oO0OoO0OoO0Oo -533706:C 28 Aug 2025 13:13:21.719 # Redis version=6.0.16, bits=64, commit=00000000, modified=0, pid=533706, just started -533706:C 28 Aug 2025 13:13:21.720 # Configuration loaded -533706:M 28 Aug 2025 13:13:21.723 * Running mode=standalone, port=6379. -533706:M 28 Aug 2025 13:13:21.724 # Server initialized -533706:M 28 Aug 2025 13:13:21.724 # WARNING overcommit_memory is set to 0! Background save may fail under low memory condition. To fix this issue add 'vm.overcommit_memory = 1' to /etc/sysctl.conf and then reboot or run the command 'sysctl vm.overcommit_memory=1' for this to take effect. -533706:M 28 Aug 2025 13:13:21.727 * Ready to accept connections -533706:signal-handler (1756390331) Received SIGINT scheduling shutdown... -533706:M 28 Aug 2025 14:12:11.921 # User requested shutdown... -533706:M 28 Aug 2025 14:12:11.922 # Redis is now ready to exit, bye bye... -546119:C 28 Aug 2025 14:12:19.084 # oO0OoO0OoO0Oo Redis is starting oO0OoO0OoO0Oo -546119:C 28 Aug 2025 14:12:19.086 # Redis version=6.0.16, bits=64, commit=00000000, modified=0, pid=546119, just started -546119:C 28 Aug 2025 14:12:19.087 # Configuration loaded -546119:M 28 Aug 2025 14:12:19.089 * Running mode=standalone, port=6379. -546119:M 28 Aug 2025 14:12:19.090 # Server initialized -546119:M 28 Aug 2025 14:12:19.091 # WARNING overcommit_memory is set to 0! 
Background save may fail under low memory condition. To fix this issue add 'vm.overcommit_memory = 1' to /etc/sysctl.conf and then reboot or run the command 'sysctl vm.overcommit_memory=1' for this to take effect. -546119:M 28 Aug 2025 14:12:19.093 * Ready to accept connections From 7d5b9a63fecf659c636b74b68c44e3568daac143 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 13 Oct 2025 08:17:21 +0000 Subject: [PATCH 36/40] fix other vlm --- lightllm/models/llava/llava_visual.py | 5 +---- lightllm/models/qwen2_5_vl/qwen2_5_visual.py | 5 +---- lightllm/models/qwen2_vl/qwen2_visual.py | 5 +---- lightllm/models/tarsier2/tarsier2_visual.py | 5 +---- 4 files changed, 4 insertions(+), 16 deletions(-) diff --git a/lightllm/models/llava/llava_visual.py b/lightllm/models/llava/llava_visual.py index 09e9f0119..59a10dff1 100644 --- a/lightllm/models/llava/llava_visual.py +++ b/lightllm/models/llava/llava_visual.py @@ -134,10 +134,7 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - if self.remote_vit: - image_data = img._preload_data - else: - image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)).convert("RGB") t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"] img_tensors.append(t) diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 33eb69e9a..ab1478e75 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -382,10 +382,7 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - if self.remote_vit: - image_data = img._preload_data - else: - image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) image_data = resize_image(image_data) pixel_values, image_grid_thw = self.processor.preprocess(image_data) diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index fba404cff..febe9ad83 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -310,10 +310,7 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - if self.remote_vit: - image_data = img._preload_data - else: - image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) image_data = resize_image(image_data) pixel_values, image_grid_thw = self.processor.preprocess(image_data) diff --git a/lightllm/models/tarsier2/tarsier2_visual.py b/lightllm/models/tarsier2/tarsier2_visual.py index 0915c4ac0..e432f14fb 100644 --- a/lightllm/models/tarsier2/tarsier2_visual.py +++ b/lightllm/models/tarsier2/tarsier2_visual.py @@ -253,10 +253,7 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - if self.remote_vit: - image_data = img._preload_data - else: - image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) image_data = resize_image(image_data) pixel_values, image_grid_thw = self.processor.preprocess(image=image_data) From 
b477506e863a299d972a66cbe64461d4ea953dd2 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 13 Oct 2025 08:22:39 +0000 Subject: [PATCH 37/40] fix other vlm --- lightllm/models/gemma3/gemma3_visual.py | 6 +----- lightllm/models/llava/llava_visual.py | 1 - lightllm/models/qwen2_5_vl/qwen2_5_visual.py | 1 - lightllm/models/qwen2_vl/qwen2_visual.py | 1 - lightllm/models/qwen_vl/qwen_visual.py | 6 +----- lightllm/models/tarsier2/tarsier2_visual.py | 1 - lightllm/models/vit/model.py | 1 - 7 files changed, 2 insertions(+), 15 deletions(-) diff --git a/lightllm/models/gemma3/gemma3_visual.py b/lightllm/models/gemma3/gemma3_visual.py index 8e373c128..6469174ed 100644 --- a/lightllm/models/gemma3/gemma3_visual.py +++ b/lightllm/models/gemma3/gemma3_visual.py @@ -17,7 +17,6 @@ class Gemma3VisionModel: def __init__(self, kvargs): - self.remote_vit = kvargs.get("remote_vit", False) pass def load_model(self, weight_dir): @@ -123,10 +122,7 @@ def encode(self, images: List[ImageItem]): for i, img in enumerate(images): if isinstance(img, ImageItem): uuids.append(img.uuid) - if self.remote_vit: - image_data = img._preload_data - else: - image_data = read_shm(get_shm_name_data(img.uuid)) + image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"] img_tensors.append(t) diff --git a/lightllm/models/llava/llava_visual.py b/lightllm/models/llava/llava_visual.py index 59a10dff1..07ce7f86d 100644 --- a/lightllm/models/llava/llava_visual.py +++ b/lightllm/models/llava/llava_visual.py @@ -16,7 +16,6 @@ class LlavaVisionModel: def __init__(self, kvargs): - self.remote_vit = kvargs.get("remote_vit", False) pass def load_model(self, weight_dir): diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index ab1478e75..01e3a7268 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -171,7 +171,6 @@ def __init__( self.window_size = window_size self.fullatt_block_indexes = fullatt_block_indexes self.out_hidden_size = out_hidden_size - self.remote_vit = kvargs.get("remote_vit", False) self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index febe9ad83..68e161737 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -200,7 +200,6 @@ def __init__( self.patch_size = patch_size self.spatial_merge_size = spatial_merge_size self.temporal_patch_size = temporal_patch_size - self.remote_vit = kvargs.get("remote_vit", False) self.patch_embed = PatchEmbed( patch_size=self.patch_size, diff --git a/lightllm/models/qwen_vl/qwen_visual.py b/lightllm/models/qwen_vl/qwen_visual.py index e4b35d624..96f8440e8 100644 --- a/lightllm/models/qwen_vl/qwen_visual.py +++ b/lightllm/models/qwen_vl/qwen_visual.py @@ -345,7 +345,6 @@ def __init__( **kwargs, ): super().__init__() - self.remote_vit = kvargs.get("remote_vit", False) image_height, image_width = self.image_size = (image_size, image_size) patch_height, patch_width = self.patch_size = (patch_size, patch_size) self.grid_size = (image_height // patch_height, image_width // patch_width) @@ -424,10 +423,7 @@ def encode(self, image_uuids: List): for i, item in enumerate(image_uuids): if isinstance(item, int): uuids.append(item) - if self.remote_vit: - image_data = item._preload_data - else: - 
image_data = read_shm(get_shm_name_data(item.uuid)) + image_data = read_shm(get_shm_name_data(item.uuid)) image_data = Image.open(BytesIO(image_data)).convert("RGB") t = self.image_transform(image_data) img_tensors.append(t) diff --git a/lightllm/models/tarsier2/tarsier2_visual.py b/lightllm/models/tarsier2/tarsier2_visual.py index e432f14fb..5b34c637a 100644 --- a/lightllm/models/tarsier2/tarsier2_visual.py +++ b/lightllm/models/tarsier2/tarsier2_visual.py @@ -166,7 +166,6 @@ def __init__( **kwargs, ): super().__init__() - self.remote_vit = kvargs.get("remote_vit", False) self.vision_tower = Qwen2VisionTransformerPretrainedModel(**vision_config) if projection_head == "Pixel_Shuffle": diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 2a6a549a6..16e8f8300 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -48,7 +48,6 @@ def __init__(self, kvargs): self.quant_cfg_path = kvargs.get("quant_cfg", None) self.load_image_func = get_load_image_func(self.weight_dir_) self.max_batch_size = kvargs.get("max_batch_size", 1) - self.remote_vit = kvargs.get("remote_vit", False) self._init_datatype() self._init_config() From ac67fcc9e3fea21128c51ddc8e741f1c05b46b40 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 13 Oct 2025 08:38:29 +0000 Subject: [PATCH 38/40] fix other vlm --- lightllm/models/gemma3/gemma3_visual.py | 2 +- lightllm/models/llava/llava_visual.py | 2 +- lightllm/models/qwen_vl/qwen_visual.py | 2 +- lightllm/models/tarsier2/tarsier2_visual.py | 1 - lightllm/models/vit/model.py | 1 - 5 files changed, 3 insertions(+), 5 deletions(-) diff --git a/lightllm/models/gemma3/gemma3_visual.py b/lightllm/models/gemma3/gemma3_visual.py index 6469174ed..b2f7a6b77 100644 --- a/lightllm/models/gemma3/gemma3_visual.py +++ b/lightllm/models/gemma3/gemma3_visual.py @@ -16,7 +16,7 @@ class Gemma3VisionModel: - def __init__(self, kvargs): + def __init__(self): pass def load_model(self, weight_dir): diff --git a/lightllm/models/llava/llava_visual.py b/lightllm/models/llava/llava_visual.py index 07ce7f86d..293bcd445 100644 --- a/lightllm/models/llava/llava_visual.py +++ b/lightllm/models/llava/llava_visual.py @@ -15,7 +15,7 @@ class LlavaVisionModel: - def __init__(self, kvargs): + def __init__(self): pass def load_model(self, weight_dir): diff --git a/lightllm/models/qwen_vl/qwen_visual.py b/lightllm/models/qwen_vl/qwen_visual.py index 96f8440e8..0e65216b8 100644 --- a/lightllm/models/qwen_vl/qwen_visual.py +++ b/lightllm/models/qwen_vl/qwen_visual.py @@ -423,7 +423,7 @@ def encode(self, image_uuids: List): for i, item in enumerate(image_uuids): if isinstance(item, int): uuids.append(item) - image_data = read_shm(get_shm_name_data(item.uuid)) + image_data = read_shm(get_shm_name_data(item)) image_data = Image.open(BytesIO(image_data)).convert("RGB") t = self.image_transform(image_data) img_tensors.append(t) diff --git a/lightllm/models/tarsier2/tarsier2_visual.py b/lightllm/models/tarsier2/tarsier2_visual.py index 5b34c637a..9deaf0857 100644 --- a/lightllm/models/tarsier2/tarsier2_visual.py +++ b/lightllm/models/tarsier2/tarsier2_visual.py @@ -152,7 +152,6 @@ def forward(self, image_features, input_embeddings): class TarsierVisionTransformerPretrainedModel(nn.Module): def __init__( self, - kvargs, vision_config=None, text_config=None, ignore_index=-100, diff --git a/lightllm/models/vit/model.py b/lightllm/models/vit/model.py index 16e8f8300..01bb69bdf 100644 --- a/lightllm/models/vit/model.py +++ b/lightllm/models/vit/model.py @@ -1,5 +1,4 @@ 
import os -import time import json import torch from lightllm.models.vit.layer_infer.pre_layer_infer import ViTPreLayerInfer From 40f8c6aed64503bb3e6c5c07dc015e31a465b234 Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 13 Oct 2025 08:39:34 +0000 Subject: [PATCH 39/40] fix other vlm --- lightllm/models/qwen_vl/qwen_visual.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightllm/models/qwen_vl/qwen_visual.py b/lightllm/models/qwen_vl/qwen_visual.py index 0e65216b8..f6468b144 100644 --- a/lightllm/models/qwen_vl/qwen_visual.py +++ b/lightllm/models/qwen_vl/qwen_visual.py @@ -333,7 +333,6 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): class QWenVisionTransformer(nn.Module): def __init__( self, - kvargs, image_size: int, patch_size: int, width: int, From 9838d89ee9c7027658331c07db751c1ea6301aaa Mon Sep 17 00:00:00 2001 From: sangchengmeng Date: Mon, 24 Nov 2025 10:51:43 +0000 Subject: [PATCH 40/40] fix1124 --- .../qwen2_vl/triton_kernel/rotary_pos_emb.py | 92 +++++++++++++------ lightllm/models/qwen2_vl/vision_process.py | 47 ++++++++-- 2 files changed, 104 insertions(+), 35 deletions(-) diff --git a/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py index 07e7c8b3f..0ae8099c6 100644 --- a/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py +++ b/lightllm/models/qwen2_vl/triton_kernel/rotary_pos_emb.py @@ -1,7 +1,6 @@ -import math -import torch import triton import triton.language as tl +import torch @triton.jit @@ -17,46 +16,72 @@ def rotary_kernel( stride_cos_d, stride_sin_l, stride_sin_d, - D: tl.constexpr, - HALF_D: tl.constexpr, + L, + H, + D, + BLOCK_SEQ: tl.constexpr, + BLOCK_HEAD: tl.constexpr, BLOCK_D: tl.constexpr, ): - pid_h = tl.program_id(0).to(tl.int64) - pid_l = tl.program_id(1).to(tl.int64) - pid_blk = tl.program_id(2).to(tl.int64) + pid_head_blk = tl.program_id(0) + pid_seq_blk = tl.program_id(1) + offs_h = pid_head_blk * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + offs_l = pid_seq_blk * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) offs_d = tl.arange(0, BLOCK_D) - d = pid_blk * BLOCK_D + offs_d - mask = d < D - base = pid_l * stride_l + pid_h * stride_h + offs_h = offs_h.to(tl.int64) + offs_l = offs_l.to(tl.int64) + offs_d = offs_d.to(tl.int64) + + mask_h = offs_h < H + mask_l = offs_l < L + mask_d = offs_d < D + + HALF_D = D // 2 + + l_b = offs_l[:, None, None] + h_b = offs_h[None, :, None] + d_b = offs_d[None, None, :] + + mask = mask_l[:, None, None] & mask_h[None, :, None] & mask_d[None, None, :] + + base = l_b * stride_l + h_b * stride_h + d_b * stride_d + x = tl.load(inp_ptr + base, mask=mask, other=0.0) - in_ptr = inp_ptr + base + d * stride_d - cos_ptr_ = cos_ptr + pid_l * stride_cos_l + d - sin_ptr_ = sin_ptr + pid_l * stride_sin_l + d + cos_base_2d = offs_l[:, None] * stride_cos_l + offs_d[None, :] * stride_cos_d + sin_base_2d = offs_l[:, None] * stride_sin_l + offs_d[None, :] * stride_sin_d + mask_ld = mask_l[:, None] & mask_d[None, :] - x = tl.load(in_ptr, mask=mask) - cos = tl.load(cos_ptr_, mask=mask) - sin = tl.load(sin_ptr_, mask=mask) + cos_2d = tl.load(cos_ptr + cos_base_2d, mask=mask_ld, other=0.0) + sin_2d = tl.load(sin_ptr + sin_base_2d, mask=mask_ld, other=0.0) - partner_d = tl.where(d < HALF_D, d + HALF_D, d - HALF_D) - partner_ptr = inp_ptr + base + partner_d * stride_d - partner_val = tl.load(partner_ptr, mask=mask) - rotated = tl.where(d < HALF_D, -partner_val, partner_val) + cos = cos_2d[:, None, :] + sin = sin_2d[:, None, :] + + partner_d = 
tl.where(offs_d < HALF_D, offs_d + HALF_D, offs_d - HALF_D) + partner_d_b = partner_d[None, None, :] + + partner_base = l_b * stride_l + h_b * stride_h + partner_d_b * stride_d + partner_val = tl.load(inp_ptr + partner_base, mask=mask, other=0.0) + + rotated = tl.where(d_b < HALF_D, -partner_val, partner_val) y = x * cos + rotated * sin - out_ptr_ = out_ptr + base + d - tl.store(out_ptr_, y, mask=mask) + tl.store(out_ptr + base, y, mask=mask) def apply_rotary_pos_emb_triton( - tensor: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, BLOCK_D: int = 128 + tensor: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, ) -> torch.Tensor: assert tensor.is_cuda and cos.is_cuda and sin.is_cuda assert cos.is_contiguous() and sin.is_contiguous() if tensor.ndim != 3: raise RuntimeError("tensor shape should be [L, H, D]") + orig_dtype = tensor.dtype x = tensor.float() @@ -64,10 +89,21 @@ def apply_rotary_pos_emb_triton( sin = sin.repeat(1, 2).view(sin.size(0), -1).contiguous().float() L, H, D = x.shape - HALF_D = D // 2 y = torch.empty_like(x) - grid = (H, L, triton.cdiv(D, BLOCK_D)) + BLOCK_SEQ = 16 + BLOCK_HEAD = 4 + BLOCK_D = triton.next_power_of_2(D) + + if D >= 128: + num_warps = 8 + else: + num_warps = 4 + + grid = ( + triton.cdiv(H, BLOCK_HEAD), + triton.cdiv(L, BLOCK_SEQ), + ) rotary_kernel[grid]( inp_ptr=x, @@ -81,9 +117,13 @@ def apply_rotary_pos_emb_triton( stride_cos_d=cos.stride(1), stride_sin_l=sin.stride(0), stride_sin_d=sin.stride(1), + L=L, + H=H, D=D, - HALF_D=HALF_D, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_HEAD=BLOCK_HEAD, BLOCK_D=BLOCK_D, + num_warps=num_warps, ) return y.to(orig_dtype) diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index 95622fc02..90601f759 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -22,6 +22,11 @@ ) from torchvision.transforms.v2 import functional as F +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + IMAGE_FACTOR = 28 MIN_PIXELS = 4 * 28 * 28 MAX_PIXELS = 16384 * 28 * 28 @@ -160,9 +165,19 @@ def rescale_and_normalize( return images + @torch.inference_mode() def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: + try: + return self._preprocess_bydevice(image, device="cuda") + except Exception as e: + logger.warning(f"Exception during image preprocessing on CUDA: {str(e)}") + torch.cuda.current_stream().synchronize() + return self._preprocess_bydevice(image, device="cpu") + + def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torch.Tensor]: image_arr = np.asarray(image, dtype=np.uint8) - image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to("cuda", non_blocking=True) + image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True) + grouped_images, grouped_images_index = group_images_by_shape( [image_data], disable_grouping=self.disable_grouping ) @@ -183,27 +198,39 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: interpolation=self.interpolation, ) resized_images_grouped[shape] = stacked_images + + grouped_images = None resized_images = reorder_images(resized_images_grouped, grouped_images_index) + resized_images_grouped = None - # Group images by size for further processing - # Needed in case do_resize is False, or resize returns images with different sizes grouped_images, grouped_images_index = group_images_by_shape( resized_images, disable_grouping=self.disable_grouping ) + resized_images 
= None + processed_images_grouped = {} processed_grids = {} + for shape, stacked_images in grouped_images.items(): + stacked_images = stacked_images.to("cuda", non_blocking=True) + resized_height, resized_width = stacked_images.shape[-2:] - # Fused rescale and normalize + patches = self.rescale_and_normalize( - stacked_images, self.do_rescale, self.rescale_factor, self.do_normalize, self.image_mean, self.image_std + stacked_images, + self.do_rescale, + self.rescale_factor, + self.do_normalize, + self.image_mean, + self.image_std, ) if patches.ndim == 4: - # add a temporal dimension if we have images patches = patches.unsqueeze(1) + if patches.shape[1] % self.temporal_patch_size != 0: repeats = patches[:, -1:].repeat(1, self.temporal_patch_size - 1, 1, 1, 1) patches = torch.cat([patches, repeats], dim=1) + batch_size, grid_t, channel = patches.shape[:3] grid_t = grid_t // self.temporal_patch_size grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size @@ -224,8 +251,7 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: .permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9) .contiguous() ) - # Reorder dimensions to group grid and patch information for subsequent flattening. - # (batch, grid_t, grid_h, grid_w, merge_h, merge_w, channel, temp_patch_size, patch_h, patch_w) + flatten_patches = patches.view( batch_size, grid_t * grid_h * grid_w, @@ -235,9 +261,12 @@ def preprocess(self, image) -> Tuple[torch.Tensor, torch.Tensor]: processed_images_grouped[shape] = flatten_patches processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size + grouped_images = None + processed_images = reorder_images(processed_images_grouped, grouped_images_index) processed_grids = reorder_images(processed_grids, grouped_images_index) - pixel_values = torch.cat(processed_images, dim=0) # (num_patches_total, C*T*ps*ps) + + pixel_values = torch.cat(processed_images, dim=0) image_grid_thw = torch.as_tensor(processed_grids) return pixel_values, image_grid_thw
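The rewritten kernel in this last patch applies rotary position embedding in the half-rotation ("rotate_half") form: for each head-dimension index the partner element from the other half is fetched, negated when it comes from the upper half, and blended as x * cos + rotated * sin. A plain PyTorch reference of that math is useful for spot-checking the Triton path on small shapes; the shapes, dtypes, and tolerances below are illustrative only.

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Split the last dim in half, negate the upper half, and swap the halves.
        half = x.shape[-1] // 2
        return torch.cat((-x[..., half:], x[..., :half]), dim=-1)

    def apply_rotary_reference(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
        # x: [L, H, D]; cos/sin: [L, D], already expanded to the full head dim.
        cos = cos[:, None, :]
        sin = sin[:, None, :]
        return x * cos + rotate_half(x) * sin

    # Illustrative spot-check against the Triton kernel (requires CUDA):
    # L, H, D = 64, 4, 80
    # x = torch.randn(L, H, D, device="cuda", dtype=torch.float16)
    # freqs = torch.randn(L, D // 2, device="cuda")
    # out_ref = apply_rotary_reference(x.float(), freqs.cos().repeat(1, 2), freqs.sin().repeat(1, 2)).half()
    # out_triton = apply_rotary_pos_emb_triton(x, freqs.cos(), freqs.sin())
    # torch.testing.assert_close(out_triton, out_ref, rtol=2e-2, atol=2e-2)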