diff --git a/examples/models/llama/llama_transformer.py b/examples/models/llama/llama_transformer.py
index 6c1a5c05d66..00c3bfa1f47 100644
--- a/examples/models/llama/llama_transformer.py
+++ b/examples/models/llama/llama_transformer.py
@@ -272,6 +272,13 @@ def construct_transformer(model_args: ModelArgs) -> Transformer:
                     norm_eps=model_args.norm_eps,
                 )
             )
+        elif (
+            model_args.layer_types
+            and model_args.layer_types[layer_id] == "skip_attention"
+        ):
+            attention = AttentionSkip()
+            transformer_block = TransformerBlock(model_args, attention)
+            layers.append(transformer_block)
         else:
             attention = cls(
                 model_args, layer_id, rope, **model_args.attention_kwargs
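
Note (not part of the patch): AttentionSkip is referenced but not defined in this hunk. Below is a minimal sketch of what such a module could look like, assuming it only needs to pass hidden states through unchanged so that TransformerBlock's norms, residual connections, and feed-forward still run. The zero-argument constructor matches the diff; the forward signature here is an assumption and may differ from the attention interface used elsewhere in the file.

# Sketch only -- not from the diff above. Assumes a no-op attention module is
# sufficient for layers marked "skip_attention"; the real forward signature in
# ExecuTorch's attention registry may differ.
from typing import Any, Optional

import torch
import torch.nn as nn


class AttentionSkip(nn.Module):
    """No-op attention: returns hidden states unchanged, computing no scores."""

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> torch.Tensor:
        # Skip the attention computation entirely; the surrounding
        # TransformerBlock still applies its own normalization and FFN.
        return x

The new branch is selected when model_args.layer_types[layer_id] equals "skip_attention"; only that label comes from the diff, so a full layer_types list such as ["attention", "skip_attention", "attention"] is purely illustrative.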