@@ -37,13 +37,13 @@ def test_sdp_bs_last(self):
         beta = 1.0
         max_len = 2048

-        query_layer = torch.randn(beam_width, q_len, num_heads, head_dim).xpu().half()
-        key_layer = torch.randn(beam_width, kv_len, num_heads, head_dim).xpu().half()
-        value_layer = torch.randn(beam_width, kv_len, num_heads, head_dim).xpu().half()
+        query_layer = torch.randn(beam_width, q_len, num_heads, head_dim)
+        key_layer = torch.randn(beam_width, kv_len, num_heads, head_dim)
+        value_layer = torch.randn(beam_width, kv_len, num_heads, head_dim)

         # attention_mask = torch.zeros(beam_width, 1, q_len, kv_len).half()
         # attention_mask[0][0][0] = -65504.
-        attention_mask = torch.zeros(beam_width, 1, q_len, kv_len).xpu().half()
+        attention_mask = torch.zeros(beam_width, 1, q_len, kv_len)
         attention_mask[0, 0, 0:q_len, 0] = -65504
         attention_mask[0, 0, 0:q_len, kv_len - 1 : kv_len] = -float("inf")
         attention_mask[0, 0, 0, kv_len - 3 : kv_len] = -float("inf")
@@ -68,21 +68,21 @@ def test_sdp_bs_last(self):
         # None,
         # alpha)
         #
-        # ref_out_cpu, _ = naive_sdp(
-        # query_layer.cpu().float().permute(1, 2, 0, 3),
-        # key_layer.cpu().float().permute(1, 2, 0, 3),
-        # value_layer.cpu().float().permute(1, 2, 0, 3),
-        # attention_mask.cpu().float(),
+        # ref_out_xpu, _ = naive_sdp(
+        # query_layer.permute(0, 2, 1, 3),
+        # key_layer.permute(0, 2, 1, 3),
+        # value_layer.permute(0, 2, 1, 3),
+        # attention_mask,
         # None,
         # None,
         # alpha)
         attention_mask_padded = torch.zeros(beam_width, 1, q_len, max_len).half()
         attention_mask_padded[:, :, :, 0:kv_len] = attention_mask

         res_out = torch.xpu.IpexSDP(
-            query_layer.to("xpu").permute(0, 2, 1, 3),
-            key_layer.to("xpu").permute(0, 2, 1, 3),
-            value_layer.to("xpu").permute(0, 2, 1, 3),
+            query_layer.half().to("xpu").permute(0, 2, 1, 3),
+            key_layer.half().to("xpu").permute(0, 2, 1, 3),
+            value_layer.half().to("xpu").permute(0, 2, 1, 3),
             None,
             attention_mask_padded.to("xpu"),
             None,
@@ -99,11 +99,7 @@ def test_sdp_bs_last(self):
             "sdp half vs naive xpu half: ",
             torch.max(torch.abs(ref_out.cpu() - res_out.cpu())).item(),
         )
-        self.assertEqual(
-            ref_out.cpu().float(), res_out.cpu().float(), atol=1e-3, rtol=1e-4
-        )
+        self.assertEqual(ref_out, res_out.cpu().float(), atol=1e-3, rtol=1e-4)
         # print("sdp half vs sdp half non padded: ", torch.max(torch.abs(res_non_pad_out.cpu() - res_out.cpu())).item())
-        # print("sdp half vs naive cpu float: ", torch.max(torch.abs(res_out.cpu() - ref_out_cpu)).item())
         # print("sdp half vs naive xpu float: ", torch.max(torch.abs(res_out.cpu() - ref_out_float.cpu())).item())
-        # print("naive xpu half vs naive cpu float: ", torch.max(torch.abs(ref_out.cpu() - ref_out_cpu)).item())
         # print("naive xpu half vs naive xpu float: ", torch.max(torch.abs(ref_out.cpu() - ref_out_float.cpu())).item())
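
Note: naive_sdp itself is not part of this diff; it is the reference implementation the test compares IpexSDP against. As a rough, non-authoritative sketch of what such a naive scaled-dot-product-attention helper typically computes (the signature is inferred from the commented-out call site above, and the head_mask / dropout_p parameter names are placeholders for the two arguments that are always None there), it could look like:

import torch


def naive_sdp(query, key, value, attention_mask, head_mask, dropout_p, alpha):
    # query/key/value are assumed to arrive in (batch, num_heads, seq_len, head_dim) layout
    scores = torch.matmul(query, key.transpose(-1, -2)) * alpha
    if attention_mask is not None:
        # additive mask: -inf / -65504 entries suppress the corresponding positions
        scores = scores + attention_mask
    weights = torch.softmax(scores, dim=-1)
    # head_mask and dropout_p are accepted only for call-site compatibility in this sketch
    output = torch.matmul(weights, value)
    return output, weights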