From 50082384e4bb743a140ec45fd368ded7b5530822 Mon Sep 17 00:00:00 2001 From: Shuying Luo Date: Fri, 5 Dec 2025 18:50:24 -0800 Subject: [PATCH] Add DeepSeek V3.2 model support with Lightning Indexer sparse attention MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds support for DeepSeek-V3.2-Exp, which extends DeepSeek V3 with DeepSeek Sparse Attention (DSA) powered by a Lightning Indexer. Key features: - Lightning Indexer for O(L*k) sparse attention (vs O(L^2) dense) - Hadamard transform for activation rotation with PyTorch fallback - Non-interleaved RoPE in indexer (vs interleaved in MLA) - Configurable sparse attention toggle for training stages - Full training support with detachable indexer input New config parameters: - index_n_heads: Number of indexer heads (default: 64) - index_head_dim: Indexer head dimension (default: 128) - index_topk: Tokens selected for sparse attention (default: 2048) - use_sparse_attention: Toggle sparse vs dense attention - detach_indexer_input: For Stage 2 training optimization Reference: https://github.com/deepseek-ai/DeepSeek-V3.2-Exp 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- DEEPSEEK_V32_IMPLEMENTATION_PLAN.md | 278 +++++ src/transformers/models/__init__.py | 1 + .../models/auto/configuration_auto.py | 2 + src/transformers/models/auto/modeling_auto.py | 4 + .../models/deepseek_v32/__init__.py | 28 + .../configuration_deepseek_v32.py | 266 ++++ .../deepseek_v32/modeling_deepseek_v32.py | 1074 +++++++++++++++++ .../deepseek_v32/modular_deepseek_v32.py | 655 ++++++++++ 8 files changed, 2308 insertions(+) create mode 100644 DEEPSEEK_V32_IMPLEMENTATION_PLAN.md create mode 100644 src/transformers/models/deepseek_v32/__init__.py create mode 100644 src/transformers/models/deepseek_v32/configuration_deepseek_v32.py create mode 100644 src/transformers/models/deepseek_v32/modeling_deepseek_v32.py create mode 100644 src/transformers/models/deepseek_v32/modular_deepseek_v32.py diff --git a/DEEPSEEK_V32_IMPLEMENTATION_PLAN.md b/DEEPSEEK_V32_IMPLEMENTATION_PLAN.md new file mode 100644 index 000000000000..58dc3ed5caba --- /dev/null +++ b/DEEPSEEK_V32_IMPLEMENTATION_PLAN.md @@ -0,0 +1,278 @@ +# DeepSeek V3.2 Implementation Plan + +## Overview + +This document describes the implementation plan for adding `deepseek_v32` model support to HuggingFace Transformers, based on the official DeepSeek-V3.2-Exp release. 
+ +**Key Innovation**: DeepSeek V3.2 = DeepSeek V3 + DeepSeek Sparse Attention (DSA) + +## References + +- **Official Repository**: https://github.com/deepseek-ai/DeepSeek-V3.2-Exp +- **Technical Report**: DeepSeek_V3_2.pdf (in the repo) +- **HuggingFace Model**: https://huggingface.co/deepseek-ai/DeepSeek-V3.2-Exp +- **Reference Implementation**: `/tmp/DeepSeek-V3.2-Exp/inference/model.py` + +## Architecture Summary + +### Model Configuration (671B) + +| Parameter | Value | Description | +|-----------|-------|-------------| +| `vocab_size` | 129280 | Vocabulary size | +| `hidden_size` | 7168 | Model dimension (`dim` in reference) | +| `intermediate_size` | 18432 | Dense MLP intermediate size (`inter_dim`) | +| `moe_intermediate_size` | 2048 | MoE expert intermediate size (`moe_inter_dim`) | +| `num_hidden_layers` | 61 | Number of transformer layers (`n_layers`) | +| `first_k_dense_replace` | 3 | First k layers use dense MLP (`n_dense_layers`) | +| `num_attention_heads` | 128 | Number of attention heads (`n_heads`) | +| `n_routed_experts` | 256 | Number of routed experts | +| `n_shared_experts` | 1 | Number of shared experts | +| `num_experts_per_tok` | 8 | Activated experts per token (`n_activated_experts`) | +| `n_group` | 8 | Expert groups (`n_expert_groups`) | +| `topk_group` | 4 | Groups selected per token (`n_limited_groups`) | +| `routed_scaling_factor` | 2.5 | MoE routing scale (`route_scale`) | +| `scoring_func` | "sigmoid" | MoE scoring function (`score_func`) | +| `q_lora_rank` | 1536 | Query LoRA rank | +| `kv_lora_rank` | 512 | KV LoRA rank | +| `qk_nope_head_dim` | 128 | QK dimension without RoPE | +| `qk_rope_head_dim` | 64 | QK dimension with RoPE | +| `v_head_dim` | 128 | Value head dimension | +| **`index_n_heads`** | 64 | **NEW: Indexer heads** | +| **`index_head_dim`** | 128 | **NEW: Indexer head dimension** | +| **`index_topk`** | 2048 | **NEW: Top-k tokens for sparse attention** | + +### Key Components + +1. **MLA (Multi-Head Latent Attention)** - Same as V3 + - LoRA-compressed Q/KV projections + - Split head dims: `qk_nope_head_dim` + `qk_rope_head_dim` + - **Interleaved RoPE** layout + +2. **Lightning Indexer** - **NEW in V3.2** + - Computes index scores for sparse token selection + - **Non-interleaved RoPE** layout (critical difference!) + - Uses Hadamard transform for activation rotation + - Learnable parameters: `wq_b`, `wk`, `k_norm`, `weights_proj` + +3. **MoE** - Same as V3 + - Sigmoid scoring with group routing + - Shared experts always active + +4. **YaRN RoPE** - Same as V3 + - Extended context support + +## Training Strategy (from Technical Report) + +DeepSeek trains the sparse attention in **two stages**: + +### Stage 1: Dense Warm-up (Indexer Only) +- **Duration**: 1000 steps, 2.1B tokens +- **Learning rate**: 1e-3 +- **What's trained**: Only the Lightning Indexer +- **What's frozen**: All other model parameters +- **Attention**: Dense (full attention) +- **Objective**: KL-divergence loss to align indexer with main attention distribution + +``` +L_I = sum_t DKL(p_t,: || Softmax(I_t,:)) +``` + +Where `p_t,:` is the L1-normalized sum of main attention scores across heads. 
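+
+A minimal PyTorch sketch of this alignment loss (illustrative only, not the reference implementation; assumes `index_scores` holds the raw indexer scores `I_t,:` with the causal mask already applied, and `attn_weights` holds the main attention probabilities per head):
+
+```python
+import torch
+
+
+def indexer_kl_loss(index_scores: torch.Tensor, attn_weights: torch.Tensor) -> torch.Tensor:
+    """L_I = sum_t D_KL(p_t || Softmax(I_t)), averaged over the batch.
+
+    index_scores: [batch, seq_q, seq_k] raw indexer scores (causal mask already added)
+    attn_weights: [batch, heads, seq_q, seq_k] main attention probabilities
+    """
+    # p_t: sum the main attention scores over heads, then L1-normalize per query position
+    target = attn_weights.float().sum(dim=1)
+    target = target / target.sum(dim=-1, keepdim=True).clamp_min(1e-9)
+
+    # Indexer distribution in log space
+    log_pred = torch.log_softmax(index_scores.float(), dim=-1)
+
+    # D_KL(p_t || Softmax(I_t)) per query position, summed over positions
+    kl = (target * (target.clamp_min(1e-9).log() - log_pred)).sum(dim=-1)
+    return kl.sum(dim=-1).mean()
+```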
+ +### Stage 2: Sparse Training (Full Model) +- **Duration**: 15000 steps, 943.7B tokens +- **Learning rate**: 7.3e-6 +- **What's trained**: All parameters (main model + indexer) +- **Attention**: Sparse (top-k = 2048) +- **Key detail**: Indexer input is **detached** from computational graph + - Indexer optimized only via L_I (KL loss) + - Main model optimized only via language modeling loss + +``` +L_I = sum_t DKL(p_t,S_t || Softmax(I_t,S_t)) +``` + +Where `S_t` is the set of selected top-k tokens. + +## Implementation Approach + +### Strategy: Modular Extension of DeepSeek V3 + +We extend `deepseek_v3` with minimal changes, adding only the Indexer and sparse attention logic. + +### Files to Create + +``` +src/transformers/models/deepseek_v32/ +├── __init__.py +├── configuration_deepseek_v32.py +├── modular_deepseek_v32.py # Source of truth +└── modeling_deepseek_v32.py # Auto-generated +``` + +### Files to Modify + +1. `src/transformers/models/__init__.py` - Add import +2. `src/transformers/models/auto/configuration_auto.py` - Register config +3. `src/transformers/models/auto/modeling_auto.py` - Register models + +### New Classes + +| Class | Extends | Description | +|-------|---------|-------------| +| `DeepseekV32Config` | `DeepseekV3Config` | Adds indexer config params | +| `DeepseekV32Indexer` | `nn.Module` | Lightning Indexer implementation | +| `DeepseekV32Attention` | `DeepseekV3Attention` | Adds indexer + sparse attention | +| `DeepseekV32DecoderLayer` | `DeepseekV3DecoderLayer` | Uses new attention | +| `DeepseekV32Model` | `DeepseekV3Model` | Uses new decoder layers | +| `DeepseekV32ForCausalLM` | `DeepseekV3ForCausalLM` | Main model class | + +### Hadamard Transform Strategy + +**Option B: Optional with Pure PyTorch Fallback** + +```python +try: + from fast_hadamard_transform import hadamard_transform + HAS_FAST_HADAMARD = True +except ImportError: + HAS_FAST_HADAMARD = False + logger.warning( + "fast-hadamard-transform not installed. Using slower PyTorch fallback. " + "Install with: pip install fast-hadamard-transform" + ) + +def hadamard_transform_fallback(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + """Pure PyTorch Hadamard transform (slower but CPU-compatible).""" + dim = x.shape[-1] + # Pad to power of 2 if needed + if dim & (dim - 1) != 0: + next_pow2 = 1 << (dim - 1).bit_length() + x = F.pad(x, (0, next_pow2 - dim)) + dim = next_pow2 + + # Fast Walsh-Hadamard Transform + h = 1 + while h < dim: + for i in range(0, dim, h * 2): + for j in range(i, i + h): + a = x[..., j] + b = x[..., j + h] + x[..., j] = a + b + x[..., j + h] = a - b + h *= 2 + + return x * scale + +def rotate_activation(x: torch.Tensor) -> torch.Tensor: + """Apply Hadamard transform for activation rotation.""" + hidden_size = x.size(-1) + if HAS_FAST_HADAMARD: + return hadamard_transform(x.contiguous(), scale=hidden_size ** -0.5) + else: + return hadamard_transform_fallback(x.clone(), scale=hidden_size ** -0.5) +``` + +**Note on fallback limitations**: The pure PyTorch fallback will be significantly slower (10-100x) than the CUDA version. For production training, `fast-hadamard-transform` should be installed. + +### Training Support + +All components are **fully trainable** to support: +- Full fine-tuning +- LoRA/adapter on any component +- Freezing specific components (like the indexer during warm-up) + +Key implementation details for training: + +1. **No `@torch.no_grad()`** - All forward passes support gradients +2. 
**Detachable indexer input** - Config flag `detach_indexer_input` (default: False for inference, can be set True for Stage 2 training) +3. **Indexer loss computation** - Helper method to compute KL divergence loss for indexer training +4. **Sparse attention toggle** - Config flag `use_sparse_attention` to enable/disable + +### RoPE Layout Critical Note + +From the technical report update: + +> "The input tensor to RoPE in the indexer module requires a **non-interleaved** layout, whereas RoPE in the MLA module expects an **interleaved** layout." + +Implementation: +```python +# In MLA (main attention) - interleaved RoPE +q_pe = apply_rotary_emb(q_pe, freqs_cis, interleaved=True) + +# In Indexer - non-interleaved RoPE +q_pe = apply_rotary_emb(q_pe, freqs_cis, interleaved=False) +k_pe = apply_rotary_emb(k_pe, freqs_cis, interleaved=False) +``` + +## Sparse Attention Training Recommendation + +Based on the technical report, here's the recommended training approach: + +### For Fine-tuning from V3.2 Checkpoint +- Use sparse attention (same as inference) +- All parameters trainable +- Optionally detach indexer input for separate optimization + +### For Training from Scratch or V3 Checkpoint +Follow the two-stage approach from the paper: + +**Stage 1: Dense Warm-up** +```python +# Freeze all except indexer +for name, param in model.named_parameters(): + if "indexer" not in name: + param.requires_grad = False + +# Use dense attention +model.config.use_sparse_attention = False + +# Train with KL loss on indexer +``` + +**Stage 2: Sparse Training** +```python +# Unfreeze all parameters +for param in model.parameters(): + param.requires_grad = True + +# Enable sparse attention +model.config.use_sparse_attention = True + +# Detach indexer input for separate optimization +model.config.detach_indexer_input = True + +# Train with: +# - Language modeling loss for main model +# - KL loss for indexer (computed separately) +``` + +### Configuration Flags for Training + +| Flag | Default | Description | +|------|---------|-------------| +| `use_sparse_attention` | True | Enable/disable sparse attention | +| `detach_indexer_input` | False | Detach indexer input from main model graph | +| `index_topk` | 2048 | Number of tokens to select | + +## Testing Plan + +1. **Unit tests**: Test each component (Indexer, Attention, etc.) +2. **Integration test**: Load tiny model, run forward pass +3. **Numerical equivalence**: Compare with reference implementation +4. **Gradient flow**: Verify gradients flow through all components + +## Timeline + +1. Configuration class (~50 lines) +2. Modular implementation (~400 lines) +3. Auto-generation of modeling file +4. Registration in auto mappings +5. Basic tests + +## Open Questions / Future Work + +1. **FP8 support**: The reference uses FP8 quantization. This could be added later as an optimization. +2. **FlashAttention integration**: Sparse attention with FlashAttention kernels +3. 
**Gradient checkpointing**: For memory-efficient training diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index 71b2155e9bc5..54a3b7d9e20d 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -89,6 +89,7 @@ from .decision_transformer import * from .deepseek_v2 import * from .deepseek_v3 import * + from .deepseek_v32 import * from .deepseek_vl import * from .deepseek_vl_hybrid import * from .deformable_detr import * diff --git a/src/transformers/models/auto/configuration_auto.py b/src/transformers/models/auto/configuration_auto.py index 38a0abb9e2d7..0b1980859e46 100644 --- a/src/transformers/models/auto/configuration_auto.py +++ b/src/transformers/models/auto/configuration_auto.py @@ -109,6 +109,7 @@ ("decision_transformer", "DecisionTransformerConfig"), ("deepseek_v2", "DeepseekV2Config"), ("deepseek_v3", "DeepseekV3Config"), + ("deepseek_v32", "DeepseekV32Config"), ("deepseek_vl", "DeepseekVLConfig"), ("deepseek_vl_hybrid", "DeepseekVLHybridConfig"), ("deformable_detr", "DeformableDetrConfig"), @@ -544,6 +545,7 @@ ("decision_transformer", "Decision Transformer"), ("deepseek_v2", "DeepSeek-V2"), ("deepseek_v3", "DeepSeek-V3"), + ("deepseek_v32", "DeepSeek-V3.2"), ("deepseek_vl", "DeepseekVL"), ("deepseek_vl_hybrid", "DeepseekVLHybrid"), ("deformable_detr", "Deformable DETR"), diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index ddd29ad96d5b..40f3793dcce9 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -115,6 +115,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("decision_transformer", "DecisionTransformerModel"), ("deepseek_v2", "DeepseekV2Model"), ("deepseek_v3", "DeepseekV3Model"), + ("deepseek_v32", "DeepseekV32Model"), ("deepseek_vl", "DeepseekVLModel"), ("deepseek_vl_hybrid", "DeepseekVLHybridModel"), ("deformable_detr", "DeformableDetrModel"), @@ -649,6 +650,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("dbrx", "DbrxForCausalLM"), ("deepseek_v2", "DeepseekV2ForCausalLM"), ("deepseek_v3", "DeepseekV3ForCausalLM"), + ("deepseek_v32", "DeepseekV32ForCausalLM"), ("diffllama", "DiffLlamaForCausalLM"), ("doge", "DogeForCausalLM"), ("dots1", "Dots1ForCausalLM"), @@ -1217,6 +1219,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("deberta-v2", "DebertaV2ForSequenceClassification"), ("deepseek_v2", "DeepseekV2ForSequenceClassification"), ("deepseek_v3", "DeepseekV3ForSequenceClassification"), + ("deepseek_v32", "DeepseekV32ForSequenceClassification"), ("diffllama", "DiffLlamaForSequenceClassification"), ("distilbert", "DistilBertForSequenceClassification"), ("doge", "DogeForSequenceClassification"), @@ -1436,6 +1439,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("deberta", "DebertaForTokenClassification"), ("deberta-v2", "DebertaV2ForTokenClassification"), ("deepseek_v3", "DeepseekV3ForTokenClassification"), + ("deepseek_v32", "DeepseekV32ForTokenClassification"), ("diffllama", "DiffLlamaForTokenClassification"), ("distilbert", "DistilBertForTokenClassification"), ("electra", "ElectraForTokenClassification"), diff --git a/src/transformers/models/deepseek_v32/__init__.py b/src/transformers/models/deepseek_v32/__init__.py new file mode 100644 index 000000000000..7ceb4699c6ae --- /dev/null +++ b/src/transformers/models/deepseek_v32/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 DeepSeek AI and The 
HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...utils import _LazyModule +from ...utils.import_utils import define_import_structure + + +if TYPE_CHECKING: + from .configuration_deepseek_v32 import * + from .modeling_deepseek_v32 import * +else: + import sys + + _file = globals()["__file__"] + sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__) diff --git a/src/transformers/models/deepseek_v32/configuration_deepseek_v32.py b/src/transformers/models/deepseek_v32/configuration_deepseek_v32.py new file mode 100644 index 000000000000..c79912e36042 --- /dev/null +++ b/src/transformers/models/deepseek_v32/configuration_deepseek_v32.py @@ -0,0 +1,266 @@ +# coding=utf-8 +# Copyright 2025 DeepSeek AI and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on the DeepSeek-V3.2-Exp implementation from DeepSeek AI. +# Reference: https://github.com/deepseek-ai/DeepSeek-V3.2-Exp +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""DeepSeek V3.2 model configuration""" + +from typing import Optional + +from ...configuration_utils import PreTrainedConfig +from ...modeling_rope_utils import RopeParameters + + +class DeepseekV32Config(PreTrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV32Model`]. It is used to instantiate + a DeepSeek V3.2 model according to the specified arguments, defining the model architecture. + + DeepSeek V3.2 extends DeepSeek V3 with DeepSeek Sparse Attention (DSA), which uses a Lightning Indexer + to select top-k tokens for sparse attention, reducing complexity from O(L^2) to O(L*k). + + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. + Read the documentation from [`PreTrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the DeepSeek V3.2 model. + hidden_size (`int`, *optional*, defaults to 7168): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 18432): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 2048): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 61): + Number of hidden layers in the Transformer decoder. 
+ num_attention_heads (`int`, *optional*, defaults to 128): + Number of attention heads for each attention layer. + num_key_value_heads (`int`, *optional*, defaults to 128): + Number of key_value heads for Grouped Query Attention. + n_shared_experts (`int`, *optional*, defaults to 1): + Number of shared experts (always active). + n_routed_experts (`int`, *optional*, defaults to 256): + Number of routed experts. + routed_scaling_factor (`float`, *optional*, defaults to 2.5): + Scaling factor for routed experts. + kv_lora_rank (`int`, *optional*, defaults to 512): + Rank of the LoRA matrices for key and value projections. + q_lora_rank (`int`, *optional*, defaults to 1536): + Rank of the LoRA matrices for query projections. + qk_rope_head_dim (`int`, *optional*, defaults to 64): + Dimension of query/key heads that use rotary position embeddings. + v_head_dim (`int`, *optional*, defaults to 128): + Dimension of the value heads. + qk_nope_head_dim (`int`, *optional*, defaults to 128): + Dimension of query/key heads without rotary position embeddings. + n_group (`int`, *optional*, defaults to 8): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to 4): + Number of groups selected per token for expert routing. + num_experts_per_tok (`int`, *optional*, defaults to 8): + Number of experts activated per token. + first_k_dense_replace (`int`, *optional*, defaults to 3): + Number of dense layers before switching to MoE layers. + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the weights of the routed experts. + hidden_act (`str`, *optional*, defaults to `"silu"`): + The non-linear activation function in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether to return the last key/values attentions. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings. + rope_parameters (`RopeParameters`, *optional*): + Configuration for the RoPE embeddings. + rope_interleave (`bool`, *optional*, defaults to `True`): + Whether to interleave the rotary position embeddings (for MLA). + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in attention layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + index_n_heads (`int`, *optional*, defaults to 64): + Number of heads in the Lightning Indexer. + index_head_dim (`int`, *optional*, defaults to 128): + Dimension of each indexer head. + index_topk (`int`, *optional*, defaults to 2048): + Number of tokens to select for sparse attention. + use_sparse_attention (`bool`, *optional*, defaults to True): + Whether to use sparse attention. Set to False for dense attention + (useful for the dense warm-up training stage). + detach_indexer_input (`bool`, *optional*, defaults to False): + Whether to detach the indexer input from the computational graph. + Used in Stage 2 training for separate optimization of indexer. 
+ + Example: + + ```python + >>> from transformers import DeepseekV32Model, DeepseekV32Config + + >>> # Initializing a DeepSeek V3.2 style configuration + >>> configuration = DeepseekV32Config() + + >>> # Initializing a model from the configuration + >>> model = DeepseekV32Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "deepseek_v32" + keys_to_ignore_at_inference = ["past_key_values"] + + base_model_tp_plan = { + "layers.*.mlp.experts.gate_up_proj": "local_rowwise", + "layers.*.mlp.experts.down_proj": "local_rowwise", + "layers.*.mlp.experts": "gather", + "layers.*.mlp.shared_experts.gate_proj": "colwise", + "layers.*.mlp.shared_experts.up_proj": "colwise", + "layers.*.mlp.shared_experts.down_proj": "rowwise", + "layers.*.mlp.gate_proj": "colwise", + "layers.*.mlp.up_proj": "colwise", + "layers.*.mlp.down_proj": "rowwise", + } + + base_model_pp_plan = { + "embed_tokens": (["input_ids"], ["inputs_embeds"]), + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), + "norm": (["hidden_states"], ["hidden_states"]), + } + + attribute_map = { + "num_local_experts": "n_routed_experts", + } + + def __init__( + self, + vocab_size: Optional[int] = 129280, + hidden_size: Optional[int] = 7168, + intermediate_size: Optional[int] = 18432, + moe_intermediate_size: Optional[int] = 2048, + num_hidden_layers: Optional[int] = 61, + num_attention_heads: Optional[int] = 128, + num_key_value_heads: Optional[int] = 128, + n_shared_experts: Optional[int] = 1, + n_routed_experts: Optional[int] = 256, + routed_scaling_factor: Optional[float] = 2.5, + kv_lora_rank: Optional[int] = 512, + q_lora_rank: Optional[int] = 1536, + qk_rope_head_dim: Optional[int] = 64, + v_head_dim: Optional[int] = 128, + qk_nope_head_dim: Optional[int] = 128, + n_group: Optional[int] = 8, + topk_group: Optional[int] = 4, + num_experts_per_tok: Optional[int] = 8, + first_k_dense_replace: Optional[int] = 3, + norm_topk_prob: Optional[bool] = True, + hidden_act: Optional[str] = "silu", + max_position_embeddings: Optional[int] = 4096, + initializer_range: Optional[float] = 0.02, + rms_norm_eps: Optional[float] = 1e-6, + use_cache: Optional[bool] = True, + pad_token_id: Optional[int] = None, + bos_token_id: Optional[int] = 0, + eos_token_id: Optional[int] = 1, + pretraining_tp: Optional[int] = 1, + tie_word_embeddings: Optional[bool] = False, + rope_parameters: Optional[RopeParameters | dict[str, RopeParameters]] = None, + rope_interleave: Optional[bool] = True, + attention_bias: Optional[bool] = False, + attention_dropout: Optional[float] = 0.0, + # DeepSeek V3.2 specific: Lightning Indexer + index_n_heads: Optional[int] = 64, + index_head_dim: Optional[int] = 128, + index_topk: Optional[int] = 2048, + use_sparse_attention: Optional[bool] = True, + detach_indexer_input: Optional[bool] = False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + 
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.head_dim = qk_rope_head_dim + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.rope_interleave = rope_interleave + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.rope_parameters = rope_parameters + + # DeepSeek V3.2 specific: Lightning Indexer + self.index_n_heads = index_n_heads + self.index_head_dim = index_head_dim + self.index_topk = index_topk + self.use_sparse_attention = use_sparse_attention + self.detach_indexer_input = detach_indexer_input + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def convert_rope_params_to_dict(self, ignore_keys_at_rope_validation: Optional[set] = None, **kwargs): + rope_scaling = kwargs.pop("rope_scaling", None) + self.rope_parameters = rope_scaling or self.rope_parameters + self.rope_parameters = self.rope_parameters if self.rope_parameters is not None else {} + + # Standardize and validate the correctness of rotary position embeddings parameters + self.rope_parameters.setdefault("rope_theta", kwargs.pop("rope_theta", self.default_theta)) + self.standardize_rope_params() + self.validate_rope(ignore_keys=ignore_keys_at_rope_validation) + + # Convert to float because RoPE fn expect a float. Models on the hub were saved as int + for key in ["beta_fast", "beta_slow", "factor"]: + if key in self.rope_parameters: + self.rope_parameters[key] = float(self.rope_parameters[key]) + return kwargs + + +__all__ = ["DeepseekV32Config"] diff --git a/src/transformers/models/deepseek_v32/modeling_deepseek_v32.py b/src/transformers/models/deepseek_v32/modeling_deepseek_v32.py new file mode 100644 index 000000000000..063d965eaca2 --- /dev/null +++ b/src/transformers/models/deepseek_v32/modeling_deepseek_v32.py @@ -0,0 +1,1074 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/deepseek_v32/modular_deepseek_v32.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_deepseek_v32.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 DeepSeek AI and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on the DeepSeek-V3.2-Exp implementation from DeepSeek AI. +# Reference: https://github.com/deepseek-ai/DeepSeek-V3.2-Exp +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import math +from collections.abc import Callable +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ... import initialization as init +from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache +from ...generation import GenerationMixin +from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub +from ...masking_utils import create_causal_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_layers import ( + GenericForSequenceClassification, + GenericForTokenClassification, + GradientCheckpointingLayer, +) +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.generic import check_model_inputs +from .configuration_deepseek_v32 import DeepseekV32Config + + +@use_kernel_forward_from_hub("RMSNorm") +class DeepseekV32RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV32RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class DeepseekV32RotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: DeepseekV32Config, device=None): + super().__init__() + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + + self.rope_type = self.config.rope_parameters["rope_type"] + rope_init_fn: Callable = self.compute_default_rope_parameters + if self.rope_type != "default": + rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = rope_init_fn(self.config, device) + + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = inv_freq + + @staticmethod + def compute_default_rope_parameters( + config: Optional[DeepseekV32Config] = None, + device: Optional["torch.device"] = None, + seq_len: Optional[int] = None, + ) -> tuple["torch.Tensor", float]: + """ + Computes the inverse frequencies according to the original RoPE implementation + Args: + config ([`~transformers.PreTrainedConfig`]): + The model configuration. + device (`torch.device`): + The device to use for initialization of the inverse frequencies. + seq_len (`int`, *optional*): + The current sequence length. Unused for this type of RoPE. + Returns: + Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the + post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE). 
+ """ + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + + attention_factor = 1.0 # Unused in this type of RoPE + + # Compute the inverse frequencies + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, attention_factor + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class DeepseekV32MLP(nn.Module): + def __init__(self, config, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size if intermediate_size is None else intermediate_size + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class DeepseekV32TopkRouter(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.n_routed_experts = config.n_routed_experts + + self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size))) + self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts)) + + def forward(self, hidden_states): + hidden_states = hidden_states.view(-1, self.config.hidden_size) + router_logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32)) + return router_logits + + +class DeepseekV32NaiveMoe(nn.Module): + """Collection of expert weights stored as 3D tensors.""" + + def __init__(self, config): + super().__init__() + self.num_experts = config.num_local_experts + self.hidden_dim = config.hidden_size + self.intermediate_dim = config.intermediate_size + self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim)) + self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim)) + self.act_fn = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + top_k_pos, token_idx = 
torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) + + return final_hidden_states + + +class DeepseekV32MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.experts = DeepseekV32NaiveMoe(config) + self.gate = DeepseekV32TopkRouter(config) + self.shared_experts = DeepseekV32MLP( + config=config, intermediate_size=config.moe_intermediate_size * config.n_shared_experts + ) + self.n_routed_experts = config.n_routed_experts + self.n_group = config.n_group + self.topk_group = config.topk_group + self.norm_topk_prob = config.norm_topk_prob + self.routed_scaling_factor = config.routed_scaling_factor + self.top_k = config.num_experts_per_tok + + def route_tokens_to_experts(self, router_logits): + router_logits = router_logits.sigmoid() + router_logits_for_choice = router_logits + self.gate.e_score_correction_bias + group_scores = ( + router_logits_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + .topk(2, dim=-1)[0] + .sum(dim=-1) + ) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + score_mask = ( + group_mask.unsqueeze(-1) + .expand(-1, self.n_group, self.n_routed_experts // self.n_group) + .reshape(-1, self.n_routed_experts) + ) + scores_for_choice = router_logits_for_choice.masked_fill(~score_mask.bool(), 0.0) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=False)[1] + topk_weights = router_logits.gather(1, topk_indices) + if self.norm_topk_prob: + denominator = topk_weights.sum(dim=-1, keepdim=True) + 1e-20 + topk_weights /= denominator + topk_weights = topk_weights * self.routed_scaling_factor + return topk_indices, topk_weights + + def forward(self, hidden_states): + residuals = hidden_states + orig_shape = hidden_states.shape + router_logits = self.gate(hidden_states) + topk_indices, topk_weights = self.route_tokens_to_experts(router_logits) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape) + hidden_states = hidden_states + self.shared_experts(residuals) + return hidden_states + + +def hadamard_transform_fallback(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + """ + Pure PyTorch Fast Walsh-Hadamard Transform fallback. + + This is significantly slower than the CUDA version but works on CPU and + doesn't require the fast-hadamard-transform package. 
+ + Args: + x: Input tensor with shape (..., dim) where dim should be a power of 2 + scale: Multiplier for the output + + Returns: + Transformed tensor with same shape as input + """ + orig_dtype = x.dtype + x = x.float() + dim = x.shape[-1] + + # Pad to power of 2 if needed + if dim & (dim - 1) != 0: + next_pow2 = 1 << (dim - 1).bit_length() + x = F.pad(x, (0, next_pow2 - dim)) + dim = next_pow2 + + # Fast Walsh-Hadamard Transform using butterfly operations + h = 1 + while h < dim: + # Reshape for butterfly operation + x = x.view(*x.shape[:-1], dim // (2 * h), 2, h) + # Butterfly: [a, b] -> [a + b, a - b] + a = x[..., 0, :] + b = x[..., 1, :] + x = torch.stack([a + b, a - b], dim=-2) + x = x.view(*x.shape[:-3], dim) + h *= 2 + + return (x * scale).to(orig_dtype) + + +def rotate_activation(x: torch.Tensor) -> torch.Tensor: + """ + Apply Hadamard transform for activation rotation in the indexer. + + This is used in the Lightning Indexer to rotate Q and K activations + before computing index scores. + + Args: + x: Input tensor with shape (..., hidden_size) + + Returns: + Rotated tensor with same shape + """ + hidden_size = x.size(-1) + scale = hidden_size**-0.5 + + if HAS_FAST_HADAMARD: + # fast-hadamard-transform requires contiguous bfloat16 input + return hadamard_transform(x.contiguous(), scale=scale) + else: + return hadamard_transform_fallback(x, scale=scale) + + +def apply_rotary_pos_emb_non_interleave( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim: int = 1, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Applies Rotary Position Embedding with NON-INTERLEAVED layout. + + This is specifically for the Indexer, which requires non-interleaved RoPE + (different from the MLA which uses interleaved RoPE). + + The difference is in how dimensions are paired: + - Interleaved: pairs (0,1), (2,3), (4,5), ... + - Non-interleaved: pairs (0, dim/2), (1, dim/2+1), ... + + Args: + q: Query tensor of shape (batch, seq_len, heads, head_dim) + k: Key tensor of shape (batch, seq_len, heads, head_dim) + cos: Cosine of rotary angles + sin: Sine of rotary angles + unsqueeze_dim: Dimension to unsqueeze cos/sin for broadcasting + + Returns: + Tuple of rotated (query, key) tensors + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Non-interleaved: split in half and rotate + # q = [q1, q2] where q1 is first half, q2 is second half + # rotated = [q1 * cos - q2 * sin, q1 * sin + q2 * cos] + q1, q2 = q.chunk(2, dim=-1) + k1, k2 = k.chunk(2, dim=-1) + + q_embed = torch.cat([q1 * cos - q2 * sin, q1 * sin + q2 * cos], dim=-1) + k_embed = torch.cat([k1 * cos - k2 * sin, k1 * sin + k2 * cos], dim=-1) + + return q_embed, k_embed + + +class DeepseekV32Indexer(nn.Module): + """ + Lightning Indexer for DeepSeek Sparse Attention (DSA). + + The indexer computes index scores to select which tokens each query should + attend to, reducing attention complexity from O(L^2) to O(L*k). + + The index score formula is: + I_{t,s} = sum_j w^I_{t,j} * ReLU(q^I_{t,j} * k^I_s) + + Key implementation details: + 1. Uses Hadamard transform on Q and K before scoring + 2. Uses NON-INTERLEAVED RoPE (different from MLA which uses interleaved) + 3. 
Uses LayerNorm on K (not RMSNorm) + + Args: + config: DeepseekV32Config + layer_idx: Index of the layer this indexer belongs to + """ + + def __init__(self, config: DeepseekV32Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.index_n_heads + self.head_dim = config.index_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.index_topk = config.index_topk + self.q_lora_rank = config.q_lora_rank + + # Query projection from compressed representation + self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False) + + # Key projection (single head, broadcast to all heads) + self.k_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False) + + # LayerNorm for keys (not RMSNorm, following reference) + self.k_layernorm = nn.LayerNorm(self.head_dim) + + # Per-head weight projection + self.weight_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False) + + self.softmax_scale = self.head_dim**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, + q_compressed: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Compute top-k token indices for sparse attention. + + Args: + hidden_states: Input hidden states [batch, seq_len, hidden_size] + q_compressed: Compressed query representation [batch, seq_len, q_lora_rank] + position_embeddings: Tuple of (cos, sin) for RoPE + attention_mask: Optional attention mask + + Returns: + topk_indices: Indices of selected tokens [batch, seq_len, topk] + """ + batch_size, seq_len, _ = hidden_states.shape + cos, sin = position_embeddings + + # Query path + q = self.q_b_proj(q_compressed) # [B, S, num_heads * head_dim] + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) + + # Split into RoPE and non-RoPE parts + q_rope, q_nope = torch.split(q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1) + + # Key path + k = self.k_proj(hidden_states) # [B, S, head_dim] + k = self.k_layernorm(k) + + k_rope, k_nope = torch.split(k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1) + + # Apply NON-INTERLEAVED RoPE (critical difference from MLA!) 
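+        # (non-interleaved RoPE pairs dim i with dim i + d/2; MLA's interleaved RoPE pairs dims 2i and 2i+1)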
+ k_rope = k_rope.unsqueeze(2) # [B, S, 1, rope_dim] + q_rope, k_rope = apply_rotary_pos_emb_non_interleave(q_rope, k_rope, cos, sin) + k_rope = k_rope.squeeze(2) # [B, S, rope_dim] + + # Concatenate back + q = torch.cat([q_rope, q_nope], dim=-1) # [B, S, H, D] + k = torch.cat([k_rope, k_nope], dim=-1) # [B, S, D] + + # Apply Hadamard transform for activation rotation + q = rotate_activation(q) + k = rotate_activation(k) + + # Compute index scores: I_{t,s} = sum_j w_{t,j} * ReLU(q_{t,j} * k_s) + # q: [B, S, H, D], k: [B, S, D] + # First compute q * k for all pairs: [B, S_q, H, S_k] + q = q.transpose(1, 2) # [B, H, S_q, D] + k = k.unsqueeze(1) # [B, 1, S_k, D] + + # Compute attention-like scores + scores = torch.matmul(q, k.transpose(-1, -2)) # [B, H, S_q, S_k] + + # Apply ReLU + scores = F.relu(scores) + + # Get per-head weights + weights = self.weight_proj(hidden_states.float()) # [B, S, H] + weights = weights * (self.num_heads**-0.5) * self.softmax_scale + weights = weights.transpose(1, 2).unsqueeze(-1) # [B, H, S, 1] + + # Weighted sum over heads: [B, S_q, S_k] + index_scores = (scores * weights).sum(dim=1) # [B, S_q, S_k] + + # Apply attention mask if provided + if attention_mask is not None: + # attention_mask is typically [B, 1, S_q, S_k] or [B, S_q, S_k] + if attention_mask.dim() == 4: + attention_mask = attention_mask.squeeze(1) + index_scores = index_scores + attention_mask + + # Select top-k tokens + k_select = min(self.index_topk, seq_len) + topk_indices = index_scores.topk(k_select, dim=-1).indices # [B, S, topk] + + return topk_indices + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +@use_kernel_func_from_hub("rotary_pos_emb") +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def apply_rotary_pos_emb_interleave(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + r""" + TODO let's just use the original freqcis computation to not have the view + transpose + reshape! This is not optimized! + Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. 
+ k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekV32Attention(nn.Module): + """ + DeepSeek V3.2 Attention with Lightning Indexer for sparse attention. + + Extends DeepseekV3Attention by adding the Lightning Indexer which selects + top-k tokens for each query position, enabling sparse attention. 
+ """ + + def __init__(self, config: DeepseekV32Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_attention_heads + + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.qk_head_dim = config.qk_head_dim + + self.is_causal = True + if self.q_lora_rank is None: + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=False) + else: + self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=config.attention_bias) + self.q_a_layernorm = DeepseekV32RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=False) + + self.kv_a_proj_with_mqa = nn.Linear( + config.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV32RMSNorm(self.kv_lora_rank) + self.kv_b_proj = nn.Linear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + config.hidden_size, + bias=config.attention_bias, + ) + + self.scaling = self.qk_head_dim ** (-0.5) + if self.config.rope_parameters.get("rope_type", "default") != "default": + mscale_all_dim = self.config.rope_parameters.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_parameters["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scaling = self.scaling * mscale * mscale + + # Add the Lightning Indexer + self.indexer = DeepseekV32Indexer(config, layer_idx) + + # Update softmax scale with YaRN mscale if needed + if hasattr(config, "rope_parameters") and config.rope_parameters: + rope_type = config.rope_parameters.get("rope_type", "default") + if rope_type != "default": + mscale_all_dim = config.rope_parameters.get("mscale_all_dim", 0) + scaling_factor = config.rope_parameters.get("factor", 1.0) + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scaling = self.scaling * mscale * mscale + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + past_key_values: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + """ + Forward pass with sparse attention via Lightning Indexer. 
+ """ + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + + # Query path with LoRA compression + if self.q_lora_rank is None: + q_states = self.q_proj(hidden_states) + q_compressed = None + else: + q_compressed = self.q_a_layernorm(self.q_a_proj(hidden_states)) + q_states = self.q_b_proj(q_compressed) + + # Optionally detach for separate indexer optimization (Stage 2 training) + if self.config.detach_indexer_input: + q_compressed_for_indexer = q_compressed.detach() + else: + q_compressed_for_indexer = q_compressed + + q_states = q_states.view(query_shape).transpose(1, 2) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # KV path with compression + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + # Apply RoPE (INTERLEAVED for MLA) + cos, sin = position_embeddings + if self.config.rope_interleave: + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + from ..llama.modeling_llama import apply_rotary_pos_emb + + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) + + # Update cache if provided + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # Compute attention scores + # For flash attention with different head dims, pad value states + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) + + # Standard attention computation + attn_output, attn_weights = self._compute_attention( + query_states, + key_states, + value_states, + attention_mask, + q_compressed_for_indexer if self.q_lora_rank else None, + hidden_states, + position_embeddings, + **kwargs, + ) + + # Remove padding if added + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + def _compute_attention( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + q_compressed: Optional[torch.Tensor], + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Compute attention with optional sparse masking from the indexer. 
+ """ + batch_size, num_heads, seq_len, _ = query_states.shape + + # Check if we should use sparse attention + use_sparse = ( + self.config.use_sparse_attention + and q_compressed is not None + and seq_len > 1 # Only for prefill, not decode + ) + + if use_sparse: + # Get top-k indices from the indexer + topk_indices = self.indexer( + hidden_states, + q_compressed, + position_embeddings, + attention_mask, + ) + + # Create sparse attention mask + # topk_indices: [B, S, topk] + # We need to create a mask that only allows attention to selected tokens + kv_seq_len = key_states.shape[2] + sparse_mask = torch.full( + (batch_size, seq_len, kv_seq_len), + float("-inf"), + device=query_states.device, + dtype=query_states.dtype, + ) + + # Scatter 0s at the selected positions + sparse_mask.scatter_(-1, topk_indices, 0.0) + + # Combine with causal mask if provided + if attention_mask is not None: + if attention_mask.dim() == 4: + # [B, 1, S, S] -> [B, S, S] + attention_mask = attention_mask.squeeze(1) + sparse_mask = sparse_mask + attention_mask + + # Expand for heads: [B, H, S, S] + attention_mask = sparse_mask.unsqueeze(1).expand(-1, num_heads, -1, -1) + + # Use eager attention for now (can be extended to flash attention) + attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scaling + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class DeepseekV32DecoderLayer(GradientCheckpointingLayer): + """DeepSeek V3.2 decoder layer with sparse attention.""" + + def __init__(self, config: DeepseekV32Config, layer_idx: int): + # Call grandparent init to avoid V3 attention + super().__init__() + self.hidden_size = config.hidden_size + + # Use V3.2 attention with indexer + self.self_attn = DeepseekV32Attention(config=config, layer_idx=layer_idx) + + # MLP: dense for first k layers, MoE for rest + if layer_idx >= config.first_k_dense_replace: + self.mlp = DeepseekV32MoE(config) + else: + self.mlp = DeepseekV32MLP(config) + + self.input_layernorm = DeepseekV32RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepseekV32RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, _ = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + 
hidden_states + return hidden_states + + +@auto_docstring +class DeepseekV32PreTrainedModel(PreTrainedModel): + """Base class for DeepSeek V3.2 models.""" + + config: DeepseekV32Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV32DecoderLayer"] + _skip_keys_device_placement = ["past_key_values"] + _supports_flash_attn = True + _supports_sdpa = True + _supports_flex_attn = True + _can_compile_fullgraph = False + _supports_attention_backend = True + _can_record_outputs = { + "hidden_states": DeepseekV32DecoderLayer, + "attentions": DeepseekV32Attention, + } + + config_class = DeepseekV32Config + + @torch.no_grad() + def _init_weights(self, module): + super()._init_weights(module) + if isinstance(module, DeepseekV32TopkRouter): + init.normal_(module.weight, mean=0.0, std=self.config.initializer_range) + elif isinstance(module, DeepseekV32NaiveMoe): + init.normal_(module.gate_up_proj, mean=0.0, std=self.config.initializer_range) + init.normal_(module.down_proj, mean=0.0, std=self.config.initializer_range) + + +@auto_docstring +class DeepseekV32Model(DeepseekV32PreTrainedModel): + """DeepSeek V3.2 Model with sparse attention.""" + + _keys_to_ignore_on_load_unexpected = [r"model\.layers\.61.*"] + + config_class = DeepseekV32Config + + def __init__(self, config: DeepseekV32Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [DeepseekV32DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DeepseekV32RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DeepseekV32RotaryEmbedding(config=config) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + @check_model_inputs + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutputWithPast: + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if inputs_embeds is None: + inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) + + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position: torch.Tensor = ( + torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens + ) + + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = create_causal_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=position_ids, + ) + + hidden_states = inputs_embeds + position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) + + for decoder_layer in self.layers[: self.config.num_hidden_layers]: + hidden_states = decoder_layer( + hidden_states, + attention_mask=causal_mask, + 
position_embeddings=position_embeddings,
+                position_ids=position_ids,
+                past_key_values=past_key_values,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                **kwargs,
+            )
+
+        hidden_states = self.norm(hidden_states)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=past_key_values,
+        )
+
+
+@auto_docstring
+class DeepseekV32ForCausalLM(DeepseekV32PreTrainedModel, GenerationMixin):
+    """DeepSeek V3.2 for causal language modeling."""
+
+    _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+    config_class = DeepseekV32Config
+
+    def __init__(self, config: DeepseekV32Config):
+        super().__init__(config)
+        self.model = DeepseekV32Model(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @can_return_tuple
+    @auto_docstring
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Cache] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> CausalLMOutputWithPast:
+        r"""
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, DeepseekV32ForCausalLM
+
+        >>> model = DeepseekV32ForCausalLM.from_pretrained("deepseek-ai/DeepSeek-V3.2-Exp")
+        >>> tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3.2-Exp")
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```""" + outputs: BaseModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +class DeepseekV32ForSequenceClassification(GenericForSequenceClassification, DeepseekV32PreTrainedModel): + """DeepSeek V3.2 for sequence classification.""" + + config_class = DeepseekV32Config + + +class DeepseekV32ForTokenClassification(GenericForTokenClassification, DeepseekV32PreTrainedModel): + """DeepSeek V3.2 for token classification.""" + + config_class = DeepseekV32Config + + +__all__ = [ + "DeepseekV32PreTrainedModel", + "DeepseekV32Model", + "DeepseekV32ForCausalLM", + "DeepseekV32ForSequenceClassification", + "DeepseekV32ForTokenClassification", +] diff --git a/src/transformers/models/deepseek_v32/modular_deepseek_v32.py b/src/transformers/models/deepseek_v32/modular_deepseek_v32.py new file mode 100644 index 000000000000..b942230396ff --- /dev/null +++ b/src/transformers/models/deepseek_v32/modular_deepseek_v32.py @@ -0,0 +1,655 @@ +# coding=utf-8 +# Copyright 2025 DeepSeek AI and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on the DeepSeek-V3.2-Exp implementation from DeepSeek AI. +# Reference: https://github.com/deepseek-ai/DeepSeek-V3.2-Exp +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +DeepSeek V3.2 model implementation. + +DeepSeek V3.2 extends DeepSeek V3 with DeepSeek Sparse Attention (DSA), which uses a +Lightning Indexer to select top-k tokens for sparse attention, reducing complexity +from O(L^2) to O(L*k) where k is the number of selected tokens (default 2048). + +Key architectural differences from V3: +1. Lightning Indexer: Computes index scores to select relevant tokens +2. Hadamard Transform: Applied to Q/K in the indexer for activation rotation +3. Non-interleaved RoPE in Indexer: Different from interleaved RoPE in MLA +4. 
Sparse Attention: Only attends to top-k selected tokens +""" +import math +from typing import Callable, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn + +from ...cache_utils import Cache +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from ...processing_utils import Unpack +from ...utils import logging +from ..deepseek_v3.configuration_deepseek_v3 import DeepseekV3Config +from ..deepseek_v3.modeling_deepseek_v3 import ( + DeepseekV3Attention, + DeepseekV3DecoderLayer, + DeepseekV3ForCausalLM, + DeepseekV3ForSequenceClassification, + DeepseekV3ForTokenClassification, + DeepseekV3Model, + DeepseekV3MLP, + DeepseekV3MoE, + DeepseekV3PreTrainedModel, + DeepseekV3RMSNorm, + DeepseekV3RotaryEmbedding, + DeepseekV3TopkRouter, + apply_rotary_pos_emb_interleave, + yarn_get_mscale, +) + + +logger = logging.get_logger(__name__) + +# Try to import fast_hadamard_transform, fall back to pure PyTorch if not available +try: + from fast_hadamard_transform import hadamard_transform + + HAS_FAST_HADAMARD = True +except ImportError: + HAS_FAST_HADAMARD = False + logger.warning_once( + "fast-hadamard-transform not installed. Using slower PyTorch fallback. " + "For better performance, install with: pip install fast-hadamard-transform" + ) + + +def hadamard_transform_fallback(x: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + """ + Pure PyTorch Fast Walsh-Hadamard Transform fallback. + + This is significantly slower than the CUDA version but works on CPU and + doesn't require the fast-hadamard-transform package. + + Args: + x: Input tensor with shape (..., dim) where dim should be a power of 2 + scale: Multiplier for the output + + Returns: + Transformed tensor with same shape as input + """ + orig_dtype = x.dtype + x = x.float() + dim = x.shape[-1] + + # Pad to power of 2 if needed + if dim & (dim - 1) != 0: + next_pow2 = 1 << (dim - 1).bit_length() + x = F.pad(x, (0, next_pow2 - dim)) + dim = next_pow2 + + # Fast Walsh-Hadamard Transform using butterfly operations + h = 1 + while h < dim: + # Reshape for butterfly operation + x = x.view(*x.shape[:-1], dim // (2 * h), 2, h) + # Butterfly: [a, b] -> [a + b, a - b] + a = x[..., 0, :] + b = x[..., 1, :] + x = torch.stack([a + b, a - b], dim=-2) + x = x.view(*x.shape[:-3], dim) + h *= 2 + + return (x * scale).to(orig_dtype) + + +def rotate_activation(x: torch.Tensor) -> torch.Tensor: + """ + Apply Hadamard transform for activation rotation in the indexer. + + This is used in the Lightning Indexer to rotate Q and K activations + before computing index scores. + + Args: + x: Input tensor with shape (..., hidden_size) + + Returns: + Rotated tensor with same shape + """ + hidden_size = x.size(-1) + scale = hidden_size**-0.5 + + if HAS_FAST_HADAMARD: + # fast-hadamard-transform requires contiguous bfloat16 input + return hadamard_transform(x.contiguous(), scale=scale) + else: + return hadamard_transform_fallback(x, scale=scale) + + +def apply_rotary_pos_emb_non_interleave( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + unsqueeze_dim: int = 1, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Applies Rotary Position Embedding with NON-INTERLEAVED layout. + + This is specifically for the Indexer, which requires non-interleaved RoPE + (different from the MLA which uses interleaved RoPE). + + The difference is in how dimensions are paired: + - Interleaved: pairs (0,1), (2,3), (4,5), ... 
+ - Non-interleaved: pairs (0, dim/2), (1, dim/2+1), ... + + Args: + q: Query tensor of shape (batch, seq_len, heads, head_dim) + k: Key tensor of shape (batch, seq_len, heads, head_dim) + cos: Cosine of rotary angles + sin: Sine of rotary angles + unsqueeze_dim: Dimension to unsqueeze cos/sin for broadcasting + + Returns: + Tuple of rotated (query, key) tensors + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + # Non-interleaved: split in half and rotate + # q = [q1, q2] where q1 is first half, q2 is second half + # rotated = [q1 * cos - q2 * sin, q1 * sin + q2 * cos] + q1, q2 = q.chunk(2, dim=-1) + k1, k2 = k.chunk(2, dim=-1) + + q_embed = torch.cat([q1 * cos - q2 * sin, q1 * sin + q2 * cos], dim=-1) + k_embed = torch.cat([k1 * cos - k2 * sin, k1 * sin + k2 * cos], dim=-1) + + return q_embed, k_embed + + +class DeepseekV32Config(DeepseekV3Config): + """ + Configuration class for DeepSeek V3.2 model. + + Extends DeepseekV3Config with parameters for the Lightning Indexer + and DeepSeek Sparse Attention (DSA). + + Args: + index_n_heads (`int`, *optional*, defaults to 64): + Number of heads in the Lightning Indexer. + index_head_dim (`int`, *optional*, defaults to 128): + Dimension of each indexer head. + index_topk (`int`, *optional*, defaults to 2048): + Number of tokens to select for sparse attention. + use_sparse_attention (`bool`, *optional*, defaults to True): + Whether to use sparse attention. Set to False to use dense attention + (useful for the dense warm-up training stage). + detach_indexer_input (`bool`, *optional*, defaults to False): + Whether to detach the indexer input from the computational graph. + Used in Stage 2 training for separate optimization of indexer. + **kwargs: + Additional arguments passed to DeepseekV3Config. + """ + + model_type = "deepseek_v32" + + def __init__( + self, + index_n_heads: int = 64, + index_head_dim: int = 128, + index_topk: int = 2048, + use_sparse_attention: bool = True, + detach_indexer_input: bool = False, + **kwargs, + ): + # Set V3.2 specific defaults if not provided + kwargs.setdefault("n_routed_experts", 256) + kwargs.setdefault("n_shared_experts", 1) + kwargs.setdefault("num_experts_per_tok", 8) + kwargs.setdefault("n_group", 8) + kwargs.setdefault("topk_group", 4) + kwargs.setdefault("routed_scaling_factor", 2.5) + kwargs.setdefault("first_k_dense_replace", 3) + + super().__init__(**kwargs) + + self.index_n_heads = index_n_heads + self.index_head_dim = index_head_dim + self.index_topk = index_topk + self.use_sparse_attention = use_sparse_attention + self.detach_indexer_input = detach_indexer_input + + +class DeepseekV32RMSNorm(DeepseekV3RMSNorm): + pass + + +class DeepseekV32RotaryEmbedding(DeepseekV3RotaryEmbedding): + pass + + +class DeepseekV32MLP(DeepseekV3MLP): + pass + + +class DeepseekV32TopkRouter(DeepseekV3TopkRouter): + pass + + +class DeepseekV32MoE(DeepseekV3MoE): + pass + + +class DeepseekV32Indexer(nn.Module): + """ + Lightning Indexer for DeepSeek Sparse Attention (DSA). + + The indexer computes index scores to select which tokens each query should + attend to, reducing attention complexity from O(L^2) to O(L*k). + + The index score formula is: + I_{t,s} = sum_j w^I_{t,j} * ReLU(q^I_{t,j} * k^I_s) + + Key implementation details: + 1. Uses Hadamard transform on Q and K before scoring + 2. Uses NON-INTERLEAVED RoPE (different from MLA which uses interleaved) + 3. 
Uses LayerNorm on K (not RMSNorm) + + Args: + config: DeepseekV32Config + layer_idx: Index of the layer this indexer belongs to + """ + + def __init__(self, config: DeepseekV32Config, layer_idx: int): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.hidden_size = config.hidden_size + self.num_heads = config.index_n_heads + self.head_dim = config.index_head_dim + self.qk_rope_head_dim = config.qk_rope_head_dim + self.index_topk = config.index_topk + self.q_lora_rank = config.q_lora_rank + + # Query projection from compressed representation + self.q_b_proj = nn.Linear(self.q_lora_rank, self.num_heads * self.head_dim, bias=False) + + # Key projection (single head, broadcast to all heads) + self.k_proj = nn.Linear(self.hidden_size, self.head_dim, bias=False) + + # LayerNorm for keys (not RMSNorm, following reference) + self.k_layernorm = nn.LayerNorm(self.head_dim) + + # Per-head weight projection + self.weight_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False) + + self.softmax_scale = self.head_dim**-0.5 + + def forward( + self, + hidden_states: torch.Tensor, + q_compressed: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Compute top-k token indices for sparse attention. + + Args: + hidden_states: Input hidden states [batch, seq_len, hidden_size] + q_compressed: Compressed query representation [batch, seq_len, q_lora_rank] + position_embeddings: Tuple of (cos, sin) for RoPE + attention_mask: Optional attention mask + + Returns: + topk_indices: Indices of selected tokens [batch, seq_len, topk] + """ + batch_size, seq_len, _ = hidden_states.shape + cos, sin = position_embeddings + + # Query path + q = self.q_b_proj(q_compressed) # [B, S, num_heads * head_dim] + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) + + # Split into RoPE and non-RoPE parts + q_rope, q_nope = torch.split( + q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 + ) + + # Key path + k = self.k_proj(hidden_states) # [B, S, head_dim] + k = self.k_layernorm(k) + + k_rope, k_nope = torch.split( + k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 + ) + + # Apply NON-INTERLEAVED RoPE (critical difference from MLA!) 
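+        # q_rope is [B, S, H, rope_dim] with the indexer heads on dim 2, while the key uses a
+        # single head that is shared across all indexer heads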
+        k_rope = k_rope.unsqueeze(2)  # [B, S, 1, rope_dim]
+        # cos/sin from the rotary embedding are duplicated to the full rotary dim, while the
+        # split-half helper above expects half-width angles; heads sit on dim 2 in this layout
+        half_dim = cos.shape[-1] // 2
+        q_rope, k_rope = apply_rotary_pos_emb_non_interleave(
+            q_rope, k_rope, cos[..., :half_dim], sin[..., :half_dim], unsqueeze_dim=2
+        )
+        k_rope = k_rope.squeeze(2)  # [B, S, rope_dim]
+
+        # Concatenate back
+        q = torch.cat([q_rope, q_nope], dim=-1)  # [B, S, H, D]
+        k = torch.cat([k_rope, k_nope], dim=-1)  # [B, S, D]
+
+        # Apply Hadamard transform for activation rotation
+        q = rotate_activation(q)
+        k = rotate_activation(k)
+
+        # Compute index scores: I_{t,s} = sum_j w_{t,j} * ReLU(q_{t,j} * k_s)
+        # q: [B, S, H, D], k: [B, S, D]
+        # First compute q * k for all pairs: [B, S_q, H, S_k]
+        q = q.transpose(1, 2)  # [B, H, S_q, D]
+        k = k.unsqueeze(1)  # [B, 1, S_k, D]
+
+        # Compute attention-like scores
+        scores = torch.matmul(q, k.transpose(-1, -2))  # [B, H, S_q, S_k]
+
+        # Apply ReLU
+        scores = F.relu(scores)
+
+        # Get per-head weights (computed in float32 for numerical stability)
+        weights = self.weight_proj(hidden_states).float()  # [B, S, H]
+        weights = weights * (self.num_heads**-0.5) * self.softmax_scale
+        weights = weights.transpose(1, 2).unsqueeze(-1)  # [B, H, S, 1]
+
+        # Weighted sum over heads: [B, S_q, S_k]
+        index_scores = (scores * weights).sum(dim=1)  # [B, S_q, S_k]
+
+        # Apply attention mask if provided
+        if attention_mask is not None:
+            # attention_mask is typically [B, 1, S_q, S_k] or [B, S_q, S_k]
+            if attention_mask.dim() == 4:
+                attention_mask = attention_mask.squeeze(1)
+            index_scores = index_scores + attention_mask
+
+        # Select top-k tokens
+        k_select = min(self.index_topk, seq_len)
+        topk_indices = index_scores.topk(k_select, dim=-1).indices  # [B, S, topk]
+
+        return topk_indices
+
+
+class DeepseekV32Attention(DeepseekV3Attention):
+    """
+    DeepSeek V3.2 Attention with Lightning Indexer for sparse attention.
+
+    Extends DeepseekV3Attention by adding the Lightning Indexer which selects
+    top-k tokens for each query position, enabling sparse attention.
+    """
+
+    def __init__(self, config: DeepseekV32Config, layer_idx: int):
+        super().__init__(config, layer_idx)
+
+        # Add the Lightning Indexer
+        self.indexer = DeepseekV32Indexer(config, layer_idx)
+
+        # Update softmax scale with YaRN mscale if needed
+        if hasattr(config, "rope_parameters") and config.rope_parameters:
+            rope_type = config.rope_parameters.get("rope_type", "default")
+            if rope_type != "default":
+                mscale_all_dim = config.rope_parameters.get("mscale_all_dim", 0)
+                scaling_factor = config.rope_parameters.get("factor", 1.0)
+                if mscale_all_dim:
+                    mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
+                    self.scaling = self.scaling * mscale * mscale
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_values: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """
+        Forward pass with sparse attention via Lightning Indexer.
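+
+        The compressed query representation from the MLA path (detached when
+        `config.detach_indexer_input` is set) is reused as the indexer input; the main
+        attention path applies RoPE (interleaved when `config.rope_interleave` is set)
+        and the usual MLA up-projections.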
+ """ + batch_size, seq_length = hidden_states.shape[:-1] + query_shape = (batch_size, seq_length, -1, self.qk_head_dim) + key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim) + + # Query path with LoRA compression + if self.q_lora_rank is None: + q_states = self.q_proj(hidden_states) + q_compressed = None + else: + q_compressed = self.q_a_layernorm(self.q_a_proj(hidden_states)) + q_states = self.q_b_proj(q_compressed) + + # Optionally detach for separate indexer optimization (Stage 2 training) + if self.config.detach_indexer_input: + q_compressed_for_indexer = q_compressed.detach() + else: + q_compressed_for_indexer = q_compressed + + q_states = q_states.view(query_shape).transpose(1, 2) + q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + + # KV path with compression + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pass = self.kv_b_proj(self.kv_a_layernorm(k_pass)).view(key_shape).transpose(1, 2) + k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + + # Apply RoPE (INTERLEAVED for MLA) + cos, sin = position_embeddings + if self.config.rope_interleave: + q_rot, k_rot = apply_rotary_pos_emb_interleave(q_rot, k_rot, cos, sin) + else: + from ..llama.modeling_llama import apply_rotary_pos_emb + q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + + k_rot = k_rot.expand(*k_pass.shape[:-1], -1) + + query_states = torch.cat((q_pass, q_rot), dim=-1) + key_states = torch.cat((k_pass, k_rot), dim=-1) + + # Update cache if provided + if past_key_values is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_values.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # Compute attention scores + # For flash attention with different head dims, pad value states + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) + + # Standard attention computation + attn_output, attn_weights = self._compute_attention( + query_states, + key_states, + value_states, + attention_mask, + q_compressed_for_indexer if self.q_lora_rank else None, + hidden_states, + position_embeddings, + **kwargs, + ) + + # Remove padding if added + if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + + def _compute_attention( + self, + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor], + q_compressed: Optional[torch.Tensor], + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Compute attention with optional sparse masking from the indexer. 
+ """ + batch_size, num_heads, seq_len, _ = query_states.shape + + # Check if we should use sparse attention + use_sparse = ( + self.config.use_sparse_attention + and q_compressed is not None + and seq_len > 1 # Only for prefill, not decode + ) + + if use_sparse: + # Get top-k indices from the indexer + topk_indices = self.indexer( + hidden_states, + q_compressed, + position_embeddings, + attention_mask, + ) + + # Create sparse attention mask + # topk_indices: [B, S, topk] + # We need to create a mask that only allows attention to selected tokens + kv_seq_len = key_states.shape[2] + sparse_mask = torch.full( + (batch_size, seq_len, kv_seq_len), + float("-inf"), + device=query_states.device, + dtype=query_states.dtype, + ) + + # Scatter 0s at the selected positions + sparse_mask.scatter_(-1, topk_indices, 0.0) + + # Combine with causal mask if provided + if attention_mask is not None: + if attention_mask.dim() == 4: + # [B, 1, S, S] -> [B, S, S] + attention_mask = attention_mask.squeeze(1) + sparse_mask = sparse_mask + attention_mask + + # Expand for heads: [B, H, S, S] + attention_mask = sparse_mask.unsqueeze(1).expand(-1, num_heads, -1, -1) + + # Use eager attention for now (can be extended to flash attention) + attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.scaling + + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + +class DeepseekV32DecoderLayer(DeepseekV3DecoderLayer): + """DeepSeek V3.2 decoder layer with sparse attention.""" + + def __init__(self, config: DeepseekV32Config, layer_idx: int): + # Call grandparent init to avoid V3 attention + nn.Module.__init__(self) + self.hidden_size = config.hidden_size + + # Use V3.2 attention with indexer + self.self_attn = DeepseekV32Attention(config=config, layer_idx=layer_idx) + + # MLP: dense for first k layers, MoE for rest + if layer_idx >= config.first_k_dense_replace: + self.mlp = DeepseekV32MoE(config) + else: + self.mlp = DeepseekV32MLP(config) + + self.input_layernorm = DeepseekV32RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = DeepseekV32RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + +class DeepseekV32PreTrainedModel(DeepseekV3PreTrainedModel): + """Base class for DeepSeek V3.2 models.""" + + config_class = DeepseekV32Config + _no_split_modules = ["DeepseekV32DecoderLayer"] + + +class DeepseekV32Model(DeepseekV3Model): + """DeepSeek V3.2 Model with sparse attention.""" + + config_class = DeepseekV32Config + + def __init__(self, config: DeepseekV32Config): + DeepseekV32PreTrainedModel.__init__(self, config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [DeepseekV32DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = DeepseekV32RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = DeepseekV32RotaryEmbedding(config=config) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + +class 
DeepseekV32ForCausalLM(DeepseekV3ForCausalLM): + """DeepSeek V3.2 for causal language modeling.""" + + config_class = DeepseekV32Config + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config: DeepseekV32Config): + DeepseekV32PreTrainedModel.__init__(self, config) + self.model = DeepseekV32Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + +class DeepseekV32ForSequenceClassification(DeepseekV3ForSequenceClassification): + """DeepSeek V3.2 for sequence classification.""" + + config_class = DeepseekV32Config + + +class DeepseekV32ForTokenClassification(DeepseekV3ForTokenClassification): + """DeepSeek V3.2 for token classification.""" + + config_class = DeepseekV32Config + + +__all__ = [ + "DeepseekV32Config", + "DeepseekV32PreTrainedModel", + "DeepseekV32Model", + "DeepseekV32ForCausalLM", + "DeepseekV32ForSequenceClassification", + "DeepseekV32ForTokenClassification", +]