Commit 72fd468

Hanxian97 authored and facebook-github-bot committed
support skip atten in export (#16104)
Summary: Support export for llama model variants with attention-layer skipping. The skip pattern only needs to be specified in the layer_types field of config.json, e.g.:

"layer_types": [
    "full_attention",
    "full_attention",
    "full_attention",
    "skip_attention",
    "skip_attention",
    "skip_attention"
]

Differential Revision: D88399533
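As a rough illustration of the workflow described in the summary, the snippet below reads the layer_types field from a checkpoint's config.json and reports which layers would be exported with skipped attention. The field name "layer_types" and the "skip_attention" value come from this commit; the file path and the rest of the snippet are a hedged sketch, not code from the repository.

import json

# Path is illustrative; in practice this is the config.json shipped with the checkpoint.
with open("config.json") as f:
    config = json.load(f)

# One entry per transformer layer, e.g. "full_attention" or "skip_attention".
layer_types = config.get("layer_types", [])

# Layer ids whose attention sub-module would be replaced by the skip variant.
skip_layers = [i for i, t in enumerate(layer_types) if t == "skip_attention"]
print(f"Layers with skipped attention: {skip_layers}")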
1 parent: 4014597

File tree

1 file changed: +7 −0 lines


examples/models/llama/llama_transformer.py

Lines changed: 7 additions & 0 deletions
@@ -272,6 +272,13 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
                     norm_eps=model_args.norm_eps,
                 )
             )
+        elif (
+            model_args.layer_types
+            and model_args.layer_types[layer_id] == "skip_attention"
+        ):
+            attention = AttentionSkip()
+            transformer_block = TransformerBlock(model_args, attention)
+            layers.append(transformer_block)
         else:
             attention = cls(
                 model_args, layer_id, rope, **model_args.attention_kwargs
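The AttentionSkip class used in the new elif branch is defined elsewhere in the repository, and its exact interface is not shown in this diff. Conceptually, a skip-attention layer performs no attention computation and passes the hidden states through, so the transformer block reduces to its feed-forward path. A minimal sketch of that idea, with an assumed forward signature that may differ from the real module:

import torch
import torch.nn as nn

class AttentionSkipSketch(nn.Module):
    # Hypothetical stand-in: no QKV projections, no KV cache, no attention scores.
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Return the hidden states unchanged; the real AttentionSkip may use a
        # different return convention to match the other attention classes.
        return x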

0 commit comments
