diff --git a/backends/candle/src/models/flash_qwen2.rs b/backends/candle/src/models/flash_qwen2.rs
index c9116311..f94e26ae 100644
--- a/backends/candle/src/models/flash_qwen2.rs
+++ b/backends/candle/src/models/flash_qwen2.rs
@@ -15,6 +15,7 @@ struct Qwen2Attention {
     num_key_value_heads: usize,
     attention_head_size: usize,
     softmax_scale: f32,
+    is_causal: bool,
     span: tracing::Span,
 }
 
@@ -66,6 +67,7 @@ impl Qwen2Attention {
             num_key_value_heads,
             attention_head_size,
             softmax_scale,
+            is_causal: config.is_causal,
             span: tracing::span!(tracing::Level::TRACE, "attention"),
         })
     }
@@ -111,7 +113,7 @@ impl Qwen2Attention {
             max_s,
             max_s,
             self.softmax_scale,
-            false,
+            self.is_causal,
             None,
             None,
         )?;
diff --git a/backends/candle/src/models/qwen2.rs b/backends/candle/src/models/qwen2.rs
index 42559b87..f7855cd9 100644
--- a/backends/candle/src/models/qwen2.rs
+++ b/backends/candle/src/models/qwen2.rs
@@ -1,5 +1,11 @@
 use crate::layers::HiddenAct;
 use serde::Deserialize;
+use tracing;
+
+fn default_is_causal() -> bool {
+    tracing::warn!("is_causal is not set in Qwen2Config; defaulting to true. Note that e.g. Alibaba-NLP/gte-Qwen2-1.5B-instruct was trained with causal=False attention, while jinaai/jina-code-embeddings-0.5b was trained with causal=True. Set this field explicitly in the Hugging Face repo's config.json to avoid this warning.");
+    true
+}
 
 #[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Qwen2Config {
@@ -15,4 +21,6 @@ pub struct Qwen2Config {
     pub rope_theta: f32,
     pub sliding_window: Option<usize>,
     pub use_sliding_window: bool,
+    #[serde(default = "default_is_causal")]
+    pub is_causal: bool,
 }
diff --git a/router/src/lib.rs b/router/src/lib.rs
index d83bd95c..a9837f1b 100644
--- a/router/src/lib.rs
+++ b/router/src/lib.rs
@@ -141,9 +141,13 @@ pub async fn run(
         "tokenizer.json not found. text-embeddings-inference only supports fast tokenizers",
     );
     tokenizer.with_padding(None);
-    // Qwen2 updates the post processor manually instead of into the tokenizer.json...
+    // Old Qwen2 repos update the post processor manually instead of shipping it in tokenizer.json.
+    // Newer ones (https://huggingface.co/jinaai/jina-code-embeddings-0.5b/tree/main) include it in tokenizer.json; this check supports both cases.
     // https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct/blob/main/tokenization_qwen.py#L246
-    if config.model_type == "qwen2" {
+    if config.model_type == "qwen2" && config.auto_map.as_ref().map_or(false, |m| {
+        m.get("AutoModel") == Some(&"modeling_qwen.Qwen2Model".to_string())
+    }) {
+        tracing::warn!("Detected a Qwen2 model with remote code; tokenizer.json does not contain a post processor, adding one manually.");
         let template = TemplateProcessing::builder()
             .try_single("$A:0 <|endoftext|>:0")
             .unwrap()
@@ -449,6 +453,7 @@ pub struct ModelConfig {
     pub pad_token_id: usize,
     pub id2label: Option<HashMap<String, String>>,
     pub label2id: Option<HashMap<String, usize>>,
+    pub auto_map: Option<HashMap<String, String>>,
 }
 
 #[derive(Debug, Clone, PartialEq, Deserialize)]
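
Note on the serde default: the following is a minimal sketch, not part of the patch, assuming serde (with the derive feature) and serde_json as dependencies; MiniConfig is a hypothetical stand-in for the real Qwen2Config. It shows what the #[serde(default = "default_is_causal")] attribute buys: the value is taken from config.json when the field is present, and falls back to default_is_causal() (true, after logging the warning) when it is absent.

use serde::Deserialize;

fn default_is_causal() -> bool {
    // The real implementation also emits a tracing::warn! here.
    true
}

#[derive(Debug, Deserialize)]
struct MiniConfig {
    #[serde(default = "default_is_causal")]
    is_causal: bool,
}

fn main() {
    // Field present in config.json: the value is used as-is.
    let set: MiniConfig = serde_json::from_str(r#"{"is_causal": false}"#).unwrap();
    assert!(!set.is_causal);

    // Field absent (older repos): serde calls default_is_causal() -> true.
    let unset: MiniConfig = serde_json::from_str("{}").unwrap();
    assert!(unset.is_causal);
}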
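
Similarly, a sketch of the router's auto_map gate (again not from the patch; MiniModelConfig and needs_manual_post_processor are illustrative names): the post processor is injected only for old remote-code Qwen2 repos, i.e. configs whose auto_map points AutoModel at modeling_qwen.Qwen2Model, while newer repos that ship the post processor in tokenizer.json are left untouched.

use std::collections::HashMap;

use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct MiniModelConfig {
    model_type: String,
    auto_map: Option<HashMap<String, String>>,
}

fn needs_manual_post_processor(config: &MiniModelConfig) -> bool {
    // Mirrors the predicate added in router/src/lib.rs: a missing auto_map
    // (map_or's `false` branch) means no manual post processor is needed.
    config.model_type == "qwen2"
        && config.auto_map.as_ref().map_or(false, |m| {
            m.get("AutoModel") == Some(&"modeling_qwen.Qwen2Model".to_string())
        })
}

fn main() {
    // Old-style remote-code repo (e.g. Alibaba-NLP/gte-Qwen2-1.5B-instruct): gated in.
    let old: MiniModelConfig = serde_json::from_str(
        r#"{"model_type": "qwen2",
            "auto_map": {"AutoModel": "modeling_qwen.Qwen2Model"}}"#,
    )
    .unwrap();
    assert!(needs_manual_post_processor(&old));

    // New-style repo with the post processor already in tokenizer.json: skipped.
    let new: MiniModelConfig = serde_json::from_str(r#"{"model_type": "qwen2"}"#).unwrap();
    assert!(!needs_manual_post_processor(&new));
}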