diff --git a/intermediate_source/scaled_dot_product_attention_tutorial.py b/intermediate_source/scaled_dot_product_attention_tutorial.py index 35b1ba7be4e..6b67169e7e1 100644 --- a/intermediate_source/scaled_dot_product_attention_tutorial.py +++ b/intermediate_source/scaled_dot_product_attention_tutorial.py @@ -223,7 +223,7 @@ def generate_rand_batch( torch.randn(seq_len, embed_dimension, dtype=dtype, device=device) for seq_len in seq_len_list - ] + ], layout=torch.jagged ), seq_len_list, )