use crate::flash_attn::flash_attn_varlen;
use crate::layers::{get_cos_sin, get_inv_freqs, LayerNorm, Linear};
-use crate::models::nomic::{NomicBertEmbeddings, NomicBertGatedMLP};
+use crate::models::nomic::{NomicBertEmbeddings, NomicMLP};
use crate::models::{Model, NomicConfig};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::VarBuilder;
@@ -25,16 +25,25 @@ impl NomicAttention {
        let attention_head_size = config.n_embd / config.n_head;
        let hidden_size = config.n_embd;

-        let qkv_weight = vb.pp("Wqkv").get(
-            (3 * num_attention_heads * attention_head_size, hidden_size),
-            "weight",
-        )?;
-        let qkv_linear = Linear::new(qkv_weight, None, None);
+        let qkv_dim = 3 * num_attention_heads * attention_head_size;
+
+        let qkv_weight = vb.pp("Wqkv").get((qkv_dim, hidden_size), "weight")?;
+        let qkv_bias = if config.qkv_proj_bias {
+            Some(vb.pp("Wqkv").get((qkv_dim,), "bias")?)
+        } else {
+            None
+        };
+        let qkv_linear = Linear::new(qkv_weight, qkv_bias, None);

        let out_proj_weight = vb
            .pp("out_proj")
            .get((hidden_size, hidden_size), "weight")?;
-        let out_proj = Linear::new(out_proj_weight, None, None);
+        let out_proj_bias = if config.qkv_proj_bias {
+            Some(vb.pp("out_proj").get((hidden_size,), "bias")?)
+        } else {
+            None
+        };
+        let out_proj = Linear::new(out_proj_weight, out_proj_bias, None);

        let softmax_scale = (1. / (attention_head_size as f64).sqrt()) as f32;

@@ -93,17 +102,18 @@ impl NomicAttention {

struct NomicBertBlock {
    attention: NomicAttention,
-    mlp: NomicBertGatedMLP,
+    mlp: NomicMLP,
    post_attention_layer_norm: LayerNorm,
    output_layer_norm: LayerNorm,

    span: tracing::Span,
}

impl NomicBertBlock {
-    pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
+    pub fn load(vb: VarBuilder, index: usize, config: &NomicConfig) -> Result<Self> {
        let attention = NomicAttention::load(vb.pp("attn"), config)?;
-        let mlp = NomicBertGatedMLP::load(vb.pp("mlp"), config)?;
+
+        let mlp = NomicMLP::load(vb.pp("mlp"), index, config)?;

        let post_attention_layer_norm =
            LayerNorm::load(vb.pp("norm1"), config.n_embd, config.layer_norm_epsilon)?;
@@ -132,6 +142,7 @@ impl NomicBertBlock {
        let attn_output = self
            .attention
            .forward(&hidden_states, cu_seqlens, cos, sin, max_s)?;
+
        let hidden_states = self
            .post_attention_layer_norm
            .forward(&hidden_states, Some(&attn_output))?;
@@ -145,13 +156,14 @@ impl NomicBertBlock {

struct NomicBertEncoder {
    layers: Vec<NomicBertBlock>,
+
    span: tracing::Span,
}

impl NomicBertEncoder {
    pub fn load(vb: VarBuilder, config: &NomicConfig) -> Result<Self> {
        let layers = (0..config.n_layer)
-            .map(|index| NomicBertBlock::load(vb.pp(format!("layers.{index}")), config))
+            .map(|index| NomicBertBlock::load(vb.pp(format!("layers.{index}")), index, config))
            .collect::<Result<Vec<_>>>()?;

        let span = tracing::span!(tracing::Level::TRACE, "encoder");
@@ -170,7 +182,6 @@ impl NomicBertEncoder {

        let mut hidden_states = hidden_states.clone();

-        // Use a loop rather than a fold as it's easier to modify when adding debug/...
        for layer in self.layers.iter() {
            hidden_states = layer.forward(&hidden_states, cu_seqlens, cos, sin, max_s)?
        }
@@ -419,6 +430,7 @@ impl Model for FlashNomicBertModel {
    fn is_padded(&self) -> bool {
        false
    }
+
    fn embed(&self, batch: Batch) -> Result<(Option<Tensor>, Option<Tensor>)> {
        self.forward(batch)
    }
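
For context, not part of the commit itself: the diff replaces the block's NomicBertGatedMLP with a NomicMLP that is loaded with the layer index, which suggests the MLP variant is now chosen per layer. A minimal sketch of what such a dispatching loader could look like is below. The NomicMoELayer stub, the moe_every_n_layers config field, and the forward signatures are assumptions made for illustration only; they are not confirmed by this diff.

use crate::models::nomic::NomicBertGatedMLP;
use crate::models::NomicConfig;
use candle::{Result, Tensor};
use candle_nn::VarBuilder;

// Stand-in for a routed expert block; its real implementation is not shown in this diff.
struct NomicMoELayer;

impl NomicMoELayer {
    fn load(_vb: VarBuilder, _config: &NomicConfig) -> Result<Self> {
        Ok(Self)
    }

    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        Ok(hidden_states.clone())
    }
}

// Hypothetical shape of the per-layer MLP wrapper loaded by NomicBertBlock::load above.
enum NomicMLP {
    GatedMLP(NomicBertGatedMLP),
    MoE(NomicMoELayer),
}

impl NomicMLP {
    pub fn load(vb: VarBuilder, index: usize, config: &NomicConfig) -> Result<Self> {
        // `moe_every_n_layers` is an assumed Option<usize> config field, used here only to
        // illustrate index-based routing between a dense gated MLP and an expert block.
        let use_moe = config
            .moe_every_n_layers
            .map(|n| n > 0 && index % n == 1)
            .unwrap_or(false);

        if use_moe {
            Ok(Self::MoE(NomicMoELayer::load(vb, config)?))
        } else {
            Ok(Self::GatedMLP(NomicBertGatedMLP::load(vb, config)?))
        }
    }

    pub fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
        match self {
            Self::GatedMLP(mlp) => mlp.forward(hidden_states),
            Self::MoE(moe) => moe.forward(hidden_states),
        }
    }
}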