
Commit dca1f2c

Hanxian97 authored and facebook-github-bot committed
support skip atten in export
Summary:
Support export for llama model variants with attention layer skipping. We only need to specify the attention skip pattern via the "layer_types" field in config.json, e.g.:

"layer_types": [
  "full_attention",
  "full_attention",
  "full_attention",
  "skip_attention",
  "skip_attention",
  "skip_attention"
]

Differential Revision: D88399533
1 parent 4014597 commit dca1f2c
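
For reference, the sketch below shows how such a "layer_types" list could be read from a model's config.json and mapped to per-layer attention kinds. The summarize_layer_types helper and the config path are illustrative only; the actual wiring goes through ModelArgs and construct_transformer in llama_transformer.py.

import json

def summarize_layer_types(config_path: str) -> None:
    # Illustrative helper (not part of the repo): reports which layers the
    # "layer_types" entry would build with skip attention.
    with open(config_path) as f:
        config = json.load(f)
    layer_types = config.get("layer_types")
    if not layer_types:
        print("no layer_types entry: every layer keeps full attention")
        return
    for layer_id, kind in enumerate(layer_types):
        if kind == "skip_attention":
            print(f"layer {layer_id}: attention skipped (AttentionSkip)")
        else:
            print(f"layer {layer_id}: full attention")

# Example call, assuming a config.json containing the pattern from the summary:
# summarize_layer_types("config.json")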

File tree

1 file changed: +7 −0 lines


examples/models/llama/llama_transformer.py

Lines changed: 7 additions & 0 deletions
@@ -132,6 +132,9 @@ def forward(self, x, freqs_cos, freqs_sin, attn_options: ForwardOptions): # x:
             )
         if not isinstance(self.attention, AttentionSkip):
             h = x + h
+        else:
+            h = x
+            attn_options_update = None
 
         if hasattr(self, "block_sparse_moe"):
             out = h + self.block_sparse_moe(self.ffn_norm(h))

@@ -272,6 +275,10 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
                     norm_eps=model_args.norm_eps,
                 )
             )
+        elif model_args.layer_types and model_args.layer_types[layer_id] == "skip_attention":
+            attention = AttentionSkip()
+            transformer_block = TransformerBlock(model_args, attention)
+            layers.append(transformer_block)
         else:
             attention = cls(
                 model_args, layer_id, rope, **model_args.attention_kwargs
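
As a rough, self-contained illustration of the forward-path change in the first hunk: attention is still invoked per block, but when the block's attention module is the skip variant, the residual add of the attention output is bypassed and the hidden state flows through unchanged to the FFN. The classes below are simplified stand-ins, not the repo's AttentionSkip or TransformerBlock.

import torch
import torch.nn as nn

class SkipAttentionStandIn(nn.Module):
    # Stand-in for AttentionSkip: no attention is computed.
    def forward(self, x, *args, **kwargs):
        return x, None

class FullAttentionStandIn(nn.Module):
    # Toy single-head self-attention, no masking or rotary embeddings.
    def __init__(self, dim: int):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=1, batch_first=True)

    def forward(self, x):
        out, _ = self.attn(x, x, x)
        return out, None

class BlockStandIn(nn.Module):
    def __init__(self, dim: int, skip_attention: bool):
        super().__init__()
        self.attention = SkipAttentionStandIn() if skip_attention else FullAttentionStandIn(dim)
        self.attention_norm = nn.LayerNorm(dim)
        self.ffn_norm = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, x):
        h, _ = self.attention(self.attention_norm(x))
        # Mirrors the branch added above: the residual add only applies when
        # attention was actually computed; otherwise x passes through as-is.
        if not isinstance(self.attention, SkipAttentionStandIn):
            h = x + h
        else:
            h = x
        return h + self.ffn(self.ffn_norm(h))

x = torch.randn(1, 4, 8)
for skip in (False, True):
    print(skip, BlockStandIn(dim=8, skip_attention=skip)(x).shape)  # both: torch.Size([1, 4, 8])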

0 commit comments
