
Commit f756420

[bugfix]: qwen2_vl rope_type default (#1129)
1 parent a4ff3c6 commit f756420

3 files changed: +10 −1 lines changed

lightllm/models/llama/model.py

Lines changed: 3 additions & 1 deletion
@@ -118,7 +118,9 @@ def _init_custom(self):
             scaling_type = rope_scaling["type"]
         else:
             raise ValueError(f"Unknown RoPE scaling format {rope_scaling}")
-        if scaling_type == "yarn":
+        if scaling_type == "default":
+            self._init_to_get_rotary()
+        elif scaling_type == "yarn":
             self._init_to_get_yarn_rotary()
         elif scaling_type == "dynamic":
             self._init_to_get_dynamic_ntk_rotary()

lightllm/models/qwen2_vl/infer_struct.py

Lines changed: 5 additions & 0 deletions
@@ -11,6 +11,11 @@ def __init__(self):
         self.position_sin = None
 
     def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        rope_scaling = model.config.get("rope_scaling", {})
+        self.rope_type = rope_scaling.get("rope_type", rope_scaling.get("type", None))
+        if self.rope_type != "mrope":
+            super().init_some_extra_state(model, input_ids)
+            return
         InferStateInfo.init_some_extra_state(self, model, input_ids)
         if self.is_prefill:
             position_ids = self.position_ids
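Before building the qwen2_vl-specific positional state, the hunk above first resolves the RoPE type from the model config, accepting either the newer "rope_type" key or the legacy "type" key, and defers to the parent class whenever the type is not "mrope". A small sketch of that resolution, assuming a plain dict config (resolve_rope_type is a hypothetical helper):

def resolve_rope_type(config: dict):
    rope_scaling = config.get("rope_scaling") or {}
    # Prefer the newer "rope_type" key, fall back to the legacy "type" key.
    return rope_scaling.get("rope_type", rope_scaling.get("type", None))

assert resolve_rope_type({"rope_scaling": {"rope_type": "mrope"}}) == "mrope"  # qwen2_vl mrope path
assert resolve_rope_type({"rope_scaling": {"type": "default"}}) == "default"   # parent-class path
assert resolve_rope_type({}) is None                                           # no rope_scaling: parent-class path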

lightllm/models/qwen2_vl/layer_infer/transformer_layer_infer.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,8 @@ def __init__(self, layer_num, network_config, mode=[]):
         self.axis_map = torch.tensor(axis_map, dtype=torch.int32, device="cuda")
 
     def _get_qkv(self, input, infer_state, layer_weight):
+        if infer_state.rope_type != "mrope":
+            return super()._get_qkv(input, infer_state, layer_weight)
         q = layer_weight.q_proj.mm(input)
         cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         seq_len, _ = q.shape
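The same guard appears in the layer inference: when the infer state is not tagged with rope_type == "mrope", the qwen2_vl layer simply reuses the base _get_qkv. An illustrative sketch of that early-return-to-super pattern (class names and return values here are placeholders, not the real lightllm classes):

class BaseTransformerLayerInfer:
    def _get_qkv(self, input, infer_state, layer_weight):
        return "standard rotary q/k/v"

class Qwen2VLTransformerLayerInfer(BaseTransformerLayerInfer):
    def _get_qkv(self, input, infer_state, layer_weight):
        # Only run the mrope-specific path when the infer state asks for it.
        if getattr(infer_state, "rope_type", None) != "mrope":
            return super()._get_qkv(input, infer_state, layer_weight)
        return "mrope q/k/v"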
