Skip to content

Commit ffe2f6b

Browse files
committed
[fix]fix redis
1 parent 4853561 commit ffe2f6b

File tree

7 files changed

+126
-91
lines changed

7 files changed

+126
-91
lines changed

lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import rpyc
2+
import socket
13
import torch
24
import torch.distributed as dist
35

@@ -31,6 +33,8 @@ class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer):
3133
def __init__(self, network_config, mode):
3234
super().__init__(network_config, mode)
3335
self.args = get_env_start_args()
36+
self.cache_client = rpyc.connect("localhost", self.args.cache_port, config={"allow_pickle": True})
37+
self.cache_client._channel.stream.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
3438
return
3539

3640
def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
@@ -57,6 +61,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
5761
embed = read_afs(get_shm_name_embed(img["uuid"]), self.args.image_embed_dir)
5862
else:
5963
embed = read_shm(get_shm_name_embed(img["uuid"]))
64+
self.cache_client.root.release([img["uuid"]])
6065
img_weight.append(bytes2tensor(embed).cuda().reshape(img["token_num"], -1))
6166
img_start_token_ids.append(img["token_id"])
6267
img_token_lens.append(img["token_num"])

lightllm/server/embed_cache/impl/memory_cache_with_redis.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ def __init__(self, args) -> None:
3232
# llm 负责release
3333
def release(self, ids: list[int]) -> None:
3434
with self.lock:
35-
for id_ in ids:
36-
self._records[id_].ref -= 1
37-
if self.redis_cache.query(str(id_)):
38-
self.redis_cache.decr(str(id_))
35+
for id in ids:
36+
self._records[id].ref -= 1
37+
if self.redis_cache.query(str(id)):
38+
self.redis_cache.decr(str(id))
3939
# print(self.redis_cache.stats(), flush=True)
4040

4141
# vit 负责set
@@ -44,27 +44,31 @@ def set_items_embed(self, ids: list[int]) -> None:
4444
for id in ids:
4545
self.redis_cache.insert(str(id))
4646
self._records[id].embed = True
47-
self._records[id].ref -= 1 # vit端alloc之后ref+1 vit完成后ref-1
47+
self._records[id].ref -= 1
48+
self.redis_cache.decr(str(id)) # vit端alloc之后ref+1 vit完成后ref-1
4849

49-
def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]:
50+
def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]:
5051
ret = []
5152
for id in ids:
52-
exist = self.redis_cache.query(str(id))
53+
if embeding_only:
54+
exist = self.redis_cache.query(str(id))
55+
else:
56+
exist = self.redis_cache.query_and_incre(str(id))
5357
ret.append(exist)
5458
if exist:
5559
self._records[id].embed = True
5660
return ret
5761

58-
def get_items_embed_and_incre(self, ids: list[int]) -> list[Optional[bool]]:
59-
ret = []
60-
for id in ids:
61-
# if self.redis_cache.query(str(id)):
62-
# ret.append(True)
63-
# continue
64-
# 避免重复的引用计数增加
65-
if self._records[id].embed:
66-
ret.append(True)
67-
continue
68-
self._records[id].embed = self.redis_cache.query_and_incre(str(id))
69-
ret.append(self._records[id].embed)
70-
return ret
62+
# def get_items_embed_and_incre(self, ids: list[int]) -> list[Optional[bool]]:
63+
# ret = []
64+
# for id in ids:
65+
# # if self.redis_cache.query(str(id)):
66+
# # ret.append(True)
67+
# # continue
68+
# # 避免重复的引用计数增加
69+
# if self._records[id].embed:
70+
# ret.append(True)
71+
# continue
72+
# self._records[id].embed = self.redis_cache.query_and_incre(str(id))
73+
# ret.append(self._records[id].embed)
74+
# return ret

lightllm/server/embed_cache/impl/naive_memory_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,5 +144,5 @@ def set_items_embed(self, ids: list[int]) -> None:
144144
for id_ in ids:
145145
self._records[id_].embed = True
146146

147-
def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]:
147+
def get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[Optional[bool]]:
148148
return [self._records.get(id_).embed if id_ in self._records else False for id_ in ids]

lightllm/server/embed_cache/manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def exposed_set_items_embed(self, ids: list[int]) -> None:
4949
ids = obtain(ids)
5050
return self._impl.set_items_embed(ids)
5151

52-
def exposed_get_items_embed(self, ids: list[int]) -> list[bool]:
52+
def exposed_get_items_embed(self, ids: list[int], embeding_only: bool = False) -> list[bool]:
5353
ids = obtain(ids)
54-
return self._impl.get_items_embed(ids)
54+
return self._impl.get_items_embed(ids, embeding_only)
5555

5656

5757
def get_cache_manager(args):

lightllm/server/embed_cache/utils.py

Lines changed: 79 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def __init__(
118118
self,
119119
redis_url: str = "redis://localhost:6379/0",
120120
capacity: int = 50000,
121-
evict_fraction: float = 0.2,
121+
evict_fraction: float = 0.1,
122122
key_prefix: str = "md5:",
123123
image_embed_dir: str = None,
124124
path_ext: str = "-embed",
@@ -128,7 +128,7 @@ def __init__(
128128
- capacity: max count of md5 entries allowed in Redis
129129
- evict_fraction: fraction to evict when inserting a NEW md5 and at capacity
130130
- image_embed_dir: base directory for image embed files (e.g., "/afs/embeds")
131-
- path_ext: file extension for embed files (default: ".embed")
131+
- path_ext: file extension for embed files (default: "-embed")
132132
"""
133133
if not (0.0 <= evict_fraction <= 1.0):
134134
raise ValueError("evict_fraction must be 0..1")
@@ -152,7 +152,7 @@ def __init__(
152152
self._evict_and_insert_script = self.r.register_script(self._EVICT_AND_INSERT_LUA)
153153

154154
def insert(self, md5: str) -> Tuple[bool, List[str]]:
155-
"""Insert a new md5 with default ref_count=0. May trigger LRU eviction."""
155+
"""Insert a new md5 with default ref_count=1. May trigger LRU eviction."""
156156
# 等待任何正在进行的逐出操作
157157
self._wait_if_eviction()
158158

@@ -176,16 +176,20 @@ def insert(self, md5: str) -> Tuple[bool, List[str]]:
176176
success = bool(evict_res[0])
177177
victims = evict_res[1:] if len(evict_res) > 1 else []
178178

179-
# 删除被逐出md5对应的AFS文件
180-
if victims and self.image_embed_dir:
181-
self._delete_afs_files(victims)
182-
183-
return success, victims
179+
if success:
180+
# 删除被逐出md5对应的AFS文件
181+
if victims and self.image_embed_dir:
182+
self._delete_afs_files(victims)
183+
return True, victims
184+
else:
185+
# 逐出失败,短暂退避后重试
186+
time.sleep(0.01)
187+
return self.insert(md5)
184188
finally:
185189
self._release_lock()
186190
else:
187191
# 等待锁释放后重试
188-
time.sleep(0.1)
192+
time.sleep(0.01)
189193
return self.insert(md5)
190194
except Exception as e:
191195
self._release_lock()
@@ -199,7 +203,6 @@ def query(self, md5: str) -> bool:
199203
def query_and_incre(self, md5: str) -> bool:
200204
"""Query if md5 exists and increment ref_count if found."""
201205
self._wait_if_eviction()
202-
203206
res = self._query_incre_script(
204207
keys=[self.zset_key, self.ref_prefix],
205208
args=[md5],
@@ -228,6 +231,11 @@ def stats(self) -> dict:
228231
"evict_fraction": self.evict_fraction,
229232
}
230233

234+
def get_ref(self, md5: str) -> int | None:
235+
self._wait_if_eviction()
236+
val = self.r.get(self.ref_prefix + md5)
237+
return int(val) if val is not None else None
238+
231239
def _wait_if_eviction(self) -> None:
232240
max_wait = 30
233241
start_time = time.time()
@@ -284,8 +292,8 @@ def _delete_afs_files(self, victims: List[str]) -> None:
284292
285293
local size = redis.call('ZCARD', zset)
286294
if size < capacity then
287-
-- Insert with ref_count=0
288-
redis.call('SET', ref_key, 0)
295+
-- Insert with ref_count=1
296+
redis.call('SET', ref_key, 1)
289297
local now = redis.call('TIME')[1] * 1000
290298
redis.call('ZADD', zset, now, md5)
291299
return {0} -- Success, no eviction
@@ -332,17 +340,16 @@ def _delete_afs_files(self, victims: List[str]) -> None:
332340
333341
-- when ref drops to 0 the key is kept and only the count is written; the LRU score is refreshed only while rc > 0
334342
local rc = tonumber(val) - 1
335-
if rc < 0 then
336-
rc = 0
337-
end
338-
343+
if rc < 0 then rc = 0 end
339344
redis.call('SET', ref_key, rc)
340345
341-
-- 更新 LRU 时间戳(最近释放的条目更不容易被立即逐出)
342-
local now = redis.call('TIME')[1] * 1000
343-
redis.call('ZADD', zset, now, md5)
346+
if rc > 0 then
347+
-- 只有仍被引用时才更新 LRU
348+
local now = redis.call('TIME')[1] * 1000
349+
redis.call('ZADD', zset, now, md5)
350+
end
344351
345-
return {rc, 0} -- 未删除
352+
return {rc, 0}
346353
"""
347354

348355
_EVICT_AND_INSERT_LUA = r"""
@@ -354,43 +361,64 @@ def _delete_afs_files(self, victims: List[str]) -> None:
354361
local capacity = tonumber(ARGV[2])
355362
local evict_fraction = tonumber(ARGV[3])
356363
357-
-- 计算需要逐出的数量
358-
local need = math.max(1, math.floor(capacity * evict_fraction + 0.5))
364+
local unpack = unpack or table.unpack
365+
366+
-- helper: now millis
367+
local function now_ms()
368+
local t = redis.call('TIME')
369+
return t[1] * 1000 + math.floor(t[2] / 1000)
370+
end
371+
372+
local new_ref_key = ref_prefix .. new_md5
373+
374+
-- If already exists, treat as a hit: bump ref_count and refresh LRU
375+
local cur = redis.call('GET', new_ref_key)
376+
if cur then
377+
local rc = tonumber(cur) + 1
378+
redis.call('SET', new_ref_key, rc)
379+
redis.call('ZADD', zset, now_ms(), new_md5)
380+
return {1} -- success, no victims
381+
end
382+
383+
-- If not at capacity, just insert
384+
local size = redis.call('ZCARD', zset)
385+
if size < capacity then
386+
redis.call('SET', new_ref_key, 1)
387+
redis.call('ZADD', zset, now_ms(), new_md5)
388+
return {1} -- success, no victims
389+
end
390+
391+
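-- At capacity: scan from the LRU end for the first item with rc==0 and evict it; the loop stops after one victim (freed < 1), so at most one entry is freed per call and max_try is only an upper bound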
392+
local max_try = math.max(1, math.floor(size * evict_fraction + 0.5))
359393
local victims = {}
394+
local freed = 0
360395
361-
-- 获取所有键并按LRU排序
396+
-- Scan from LRU (smallest score) to MRU
362397
local all_keys = redis.call('ZRANGE', zset, 0, -1, 'WITHSCORES')
363398
local i = 1
364-
365-
-- 查找引用计数为0的键作为逐出候选
366-
while #victims < need and i <= #all_keys do
367-
local md5 = all_keys[i]
368-
local ref_key = ref_prefix .. md5
369-
local rc = redis.call('GET', ref_key)
370-
371-
if rc and tonumber(rc) <= 0 then
372-
table.insert(victims, md5)
373-
end
374-
i = i + 2 -- 跳过分数
399+
while freed < 1 and i <= #all_keys and #victims < max_try do
400+
local md5 = all_keys[i]
401+
local ref_key = ref_prefix .. md5
402+
local v = redis.call('GET', ref_key)
403+
if v and tonumber(v) <= 0 then
404+
table.insert(victims, md5)
405+
freed = freed + 1
406+
end
407+
i = i + 2 -- skip score
375408
end
376409
377-
-- 如果找到足够的候选,执行逐出
378-
if #victims >= need then
379-
-- 删除受害者
380-
for _, v in ipairs(victims) do
381-
local ref_key = ref_prefix .. v
382-
redis.call('DEL', ref_key)
383-
redis.call('ZREM', zset, v)
384-
end
385-
386-
-- 插入新的md5
387-
local ref_key = ref_prefix .. new_md5
388-
redis.call('SET', ref_key, 0)
389-
local now = redis.call('TIME')[1] * 1000
390-
redis.call('ZADD', zset, now, new_md5)
391-
392-
return {1, unpack(victims)} -- success + victims
410+
if freed >= 1 then
411+
-- delete victims
412+
for _, v in ipairs(victims) do
413+
redis.call('DEL', ref_prefix .. v)
414+
redis.call('ZREM', zset, v)
415+
end
416+
-- insert new
417+
redis.call('SET', new_ref_key, 1)
418+
redis.call('ZADD', zset, now_ms(), new_md5)
419+
return {1, unpack(victims)}
393420
else
394-
return {0} -- 逐出失败,没有足够的候选
421+
-- no zero-ref items found
422+
return {0}
395423
end
396424
"""

lightllm/server/httpserver/manager.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,6 @@ def __init__(
122122
async def _alloc_resource(self, items, uuids, token_nums, datas):
123123

124124
while True:
125-
# 检查这个图片在redis中是否已经存在
126-
# embed_exists = obtain(self.cache_client.root.get_items_embed(uuids))
127-
# for exist in embed_exists:
128-
# if exist:
129-
# continue
130-
# else:
131125
records = obtain(self.cache_client.root.alloc(uuids, token_nums))
132126

133127
if records is None:
@@ -212,8 +206,8 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam
212206
audio.uuid = None
213207
audio.token_id = None
214208
audio.token_num = None
215-
if ids_to_release:
216-
self.cache_client.root.release(ids_to_release)
209+
# if ids_to_release:
210+
# self.cache_client.root.release(ids_to_release)
217211
return
218212

219213
def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None):
@@ -370,7 +364,7 @@ async def generate(
370364
# 对于还没有形成正式请求对象管理的多模态资源,需要单独自己释放
371365
# 已经放入到 req_id_to_out_inf 中的请求对象,由统一的回收循环
372366
# 进行回收。
373-
if group_request_id not in self.req_id_to_out_inf and self.args.run_mode != "llm_only":
367+
if group_request_id not in self.req_id_to_out_inf:
374368
await self._release_multimodal_resources(multimodal_params)
375369
await self.abort(group_request_id)
376370
raise e
@@ -410,7 +404,7 @@ async def get_image_embeding(
410404
visual_req_status = GroupReqObjs(group_request_id, multimodal_params, None, start_time)
411405

412406
await self.transfer_to_next_module_or_node(
413-
None, sampling_params, original_multimodal_params, visual_req_status
407+
None, sampling_params, original_multimodal_params, visual_req_status, embeding_only=True
414408
)
415409

416410
except Exception as e:
@@ -513,6 +507,7 @@ async def transfer_to_next_module_or_node(
513507
sampling_params: SamplingParams,
514508
original_multimodal_params: MultimodalParams,
515509
group_req_objs: Optional[GroupReqObjs] = None,
510+
embeding_only: Optional[bool] = False,
516511
):
517512
# 多节点纯tp 运行模式下,master 节点需要将请求转发给slave节点.
518513
if self.is_multinode_tp_master:
@@ -522,19 +517,21 @@ async def transfer_to_next_module_or_node(
522517
protocol=pickle.HIGHEST_PROTOCOL,
523518
)
524519

525-
await self.transfer_to_next_module(group_req_objs)
520+
await self.transfer_to_next_module(group_req_objs, embeding_only)
526521
return
527522

528523
async def transfer_to_next_module(
529524
self,
530525
group_req_objs: Optional[GroupReqObjs] = None,
526+
embeding_only: Optional[bool] = False,
531527
):
532528

533529
if self.pd_mode.is_P_or_NORMAL():
534530
if self.enable_multimodal:
535531
await self.vit_manager.send_to_vit(
536532
group_req_objs.to_group_req_index(),
537533
protocol=pickle.HIGHEST_PROTOCOL,
534+
embeding_only=embeding_only,
538535
)
539536

540537
if not self.enable_multimodal or self.args.enable_remote_vit:

0 commit comments

Comments
 (0)