Commit d84677e
Fix test_sdp_bs_last UT tolerance (#4712)(#4727)
* Increase tol for test_sdp_bs_last
* Replace ref with cpu full precision results.
* Update input tensor type to be cpu float.
* Fix format.
1 parent: ba1e3f6

File tree

1 file changed (+13 -17 lines)


tests/gpu/examples/test_sdp_bs_last.py

Lines changed: 13 additions & 17 deletions
@@ -37,13 +37,13 @@ def test_sdp_bs_last(self):
         beta = 1.0
         max_len = 2048

-        query_layer = torch.randn(beam_width, q_len, num_heads, head_dim).xpu().half()
-        key_layer = torch.randn(beam_width, kv_len, num_heads, head_dim).xpu().half()
-        value_layer = torch.randn(beam_width, kv_len, num_heads, head_dim).xpu().half()
+        query_layer = torch.randn(beam_width, q_len, num_heads, head_dim)
+        key_layer = torch.randn(beam_width, kv_len, num_heads, head_dim)
+        value_layer = torch.randn(beam_width, kv_len, num_heads, head_dim)

         # attention_mask = torch.zeros(beam_width, 1, q_len, kv_len).half()
         # attention_mask[0][0][0] = -65504.
-        attention_mask = torch.zeros(beam_width, 1, q_len, kv_len).xpu().half()
+        attention_mask = torch.zeros(beam_width, 1, q_len, kv_len)
         attention_mask[0, 0, 0:q_len, 0] = -65504
         attention_mask[0, 0, 0:q_len, kv_len - 1 : kv_len] = -float("inf")
         attention_mask[0, 0, 0, kv_len - 3 : kv_len] = -float("inf")
@@ -68,21 +68,21 @@ def test_sdp_bs_last(self):
         #     None,
         #     alpha)
         #
-        # ref_out_cpu, _ = naive_sdp(
-        #     query_layer.cpu().float().permute(1, 2, 0, 3),
-        #     key_layer.cpu().float().permute(1, 2, 0, 3),
-        #     value_layer.cpu().float().permute(1, 2, 0, 3),
-        #     attention_mask.cpu().float(),
+        # ref_out_xpu, _ = naive_sdp(
+        #     query_layer.permute(0, 2, 1, 3),
+        #     key_layer.permute(0, 2, 1, 3),
+        #     value_layer.permute(0, 2, 1, 3),
+        #     attention_mask,
         #     None,
         #     None,
         #     alpha)
         attention_mask_padded = torch.zeros(beam_width, 1, q_len, max_len).half()
         attention_mask_padded[:, :, :, 0:kv_len] = attention_mask

         res_out = torch.xpu.IpexSDP(
-            query_layer.to("xpu").permute(0, 2, 1, 3),
-            key_layer.to("xpu").permute(0, 2, 1, 3),
-            value_layer.to("xpu").permute(0, 2, 1, 3),
+            query_layer.half().to("xpu").permute(0, 2, 1, 3),
+            key_layer.half().to("xpu").permute(0, 2, 1, 3),
+            value_layer.half().to("xpu").permute(0, 2, 1, 3),
             None,
             attention_mask_padded.to("xpu"),
             None,
@@ -99,11 +99,7 @@ def test_sdp_bs_last(self):
             "sdp half vs naive xpu half: ",
             torch.max(torch.abs(ref_out.cpu() - res_out.cpu())).item(),
         )
-        self.assertEqual(
-            ref_out.cpu().float(), res_out.cpu().float(), atol=1e-3, rtol=1e-4
-        )
+        self.assertEqual(ref_out, res_out.cpu().float(), atol=1e-3, rtol=1e-4)
         # print("sdp half vs sdp half non padded: ", torch.max(torch.abs(res_non_pad_out.cpu() - res_out.cpu())).item())
-        # print("sdp half vs naive cpu float: ", torch.max(torch.abs(res_out.cpu() - ref_out_cpu)).item())
         # print("sdp half vs naive xpu float: ", torch.max(torch.abs(res_out.cpu() - ref_out_float.cpu())).item())
-        # print("naive xpu half vs naive cpu float: ", torch.max(torch.abs(ref_out.cpu() - ref_out_cpu)).item())
         # print("naive xpu half vs naive xpu float: ", torch.max(torch.abs(ref_out.cpu() - ref_out_float.cpu())).item())
