Skip to content

Commit 0aa4aa1

Browse files
Correct initial max value for the SoftMax kernel (#1426) (#1433)
Co-authored-by: Wei Lin <wei2.lin@intel.com>
1 parent 53417f8 commit 0aa4aa1

File tree

2 files changed: +49 -14 lines changed

2 files changed: +49 -14 lines changed

csrc/cpu/vec/vec512/perf_kernel/add_softmax.h

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ inline void _dil_div_add_reduce_max_fusion_kernel(
8585
const int& size,
8686
float* out,
8787
float& max) {
88-
auto vec_ps_min = _mm512_set1_ps(std::numeric_limits<float>::min());
88+
auto vec_ps_min = _mm512_set1_ps(std::numeric_limits<float>::lowest());
8989
auto vec_a = vec_ps_min;
9090
auto vec_b = vec_ps_min;
9191
auto vec_out = vec_ps_min;
@@ -219,7 +219,7 @@ inline void _dil_add_reduce_max_fusion_kernel(
219219
const int& size,
220220
float* out,
221221
float& max) {
222-
auto vec_ps_min = _mm512_set1_ps(std::numeric_limits<float>::min());
222+
auto vec_ps_min = _mm512_set1_ps(std::numeric_limits<float>::lowest());
223223
auto vec_a = vec_ps_min;
224224
auto vec_b = vec_ps_min;
225225
auto vec_out = vec_ps_min;
@@ -252,7 +252,7 @@ inline void _dil_mul_reduce_max_fusion_kernel(
252252
const int& size,
253253
float* out,
254254
float& max) {
255-
auto vec_ps_min = _mm512_set1_ps(std::numeric_limits<float>::min());
255+
auto vec_ps_min = _mm512_set1_ps(std::numeric_limits<float>::lowest());
256256
auto vec_a = vec_ps_min;
257257
auto vec_out = vec_ps_min;
258258

@@ -278,7 +278,7 @@ inline void _dil_mul_reduce_max_fusion_kernel(
278278
}
279279

280280
inline void _init_mha_buffer_kernel(float* max, float* sum, const int& size) {
281-
auto vec_ps_min = _mm512_set1_ps(std::numeric_limits<float>::min());
281+
auto vec_ps_min = _mm512_set1_ps(std::numeric_limits<float>::lowest());
282282
auto vec_zeros = _mm512_setzero_ps();
283283

284284
int i = 0;

tests/cpu/test_mha.py

Lines changed: 45 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import torch.nn as nn
55
import torch.nn.functional as F
66
import intel_extension_for_pytorch as ipex
7+
import math
78
from common_utils import TestCase
89

910
#(from Diffusers 0.12.1)
@@ -162,6 +163,20 @@ def forward(self, x, y):
162163
output = hidden_states.to(query.dtype)
163164
return output
164165

166+
#(Fake Diffusers Model - Fall back to ipex::mha_scores_calc)
167+
class Fake_SD_MHA_Model(nn.Module):
168+
def __init__(self, dim_per_head, softmax_dim=-1):
169+
super(Fake_SD_MHA_Model, self).__init__()
170+
self.softmax = nn.Softmax(dim=softmax_dim)
171+
self.dim_per_head = dim_per_head
172+
173+
def forward(self, mat1, mat2, mat3, bias):
174+
mat1 = mat1 / math.sqrt(self.dim_per_head)
175+
qk = torch.matmul(mat1, mat2.transpose(2, 3))
176+
scores = self.softmax(qk + bias)
177+
output = torch.matmul(scores, mat3)
178+
return output
179+
165180
class MHA_Model_BERT(nn.Module):
166181
def __init__(self, scale, num_heads, head_dims, permute_idx, trans_a, trans_b):
167182
super(MHA_Model_BERT, self).__init__()
@@ -267,7 +282,7 @@ def forward(self, x):
267282
class TransFreeMHATester(TestCase):
268283

269284
def test_sd_mha_bf16_v1(self):
270-
mat = torch.randn(2, 4096, 320).to(torch.bfloat16)
285+
mat = (torch.randn(2, 4096, 320) + 15).to(torch.bfloat16)
271286
sd_mha_model = SD_MHA_Model_v1(0.3, 8, 320, 320).eval()
272287
mha_ipex = ipex.optimize(sd_mha_model, dtype=torch.bfloat16, level="O1")
273288

@@ -278,14 +293,14 @@ def test_sd_mha_bf16_v1(self):
278293
for _ in range(2):
279294
mha_jit = mha_ipex(mat)
280295
mha_ref = sd_mha_model(mat)
281-
self.assertEqual(mha_ref, mha_jit, prec=1e-2)
296+
self.assertEqual(mha_ref, mha_jit, prec=1e-0)
282297

283298
mha_graph = mha_ipex.graph_for(mat)
284299
self.assertTrue(any(n.kind() == "ipex::sd_flash_mha" for n in mha_graph.nodes()))
285300

286301
def test_sd_mha_bf16_v2(self):
287-
mat1 = torch.randn(2, 4096, 320).to(torch.bfloat16)
288-
mat2 = torch.randn(2, 77, 320).to(torch.bfloat16)
302+
mat1 = (torch.randn(2, 4096, 320) + 15).to(torch.bfloat16)
303+
mat2 = (torch.randn(2, 77, 320) + 15).to(torch.bfloat16)
289304
sd_mha_model = SD_MHA_Model_v2(0.3, 8, 320, 320).eval()
290305
mha_ipex = ipex.optimize(sd_mha_model, dtype=torch.bfloat16, level="O1")
291306

@@ -296,13 +311,13 @@ def test_sd_mha_bf16_v2(self):
296311
for _ in range(2):
297312
mha_jit = mha_ipex(mat1, mat2)
298313
mha_ref = sd_mha_model(mat1, mat2)
299-
self.assertEqual(mha_ref, mha_jit, prec=1e-2)
314+
self.assertEqual(mha_ref, mha_jit, prec=1e-0)
300315

301316
mha_graph = mha_ipex.graph_for(mat1, mat2)
302317
self.assertTrue(any(n.kind() == "ipex::sd_flash_mha" for n in mha_graph.nodes()))
303318

304319
def test_sd_mha_bf16_v3(self):
305-
mat = torch.randn(2, 4096, 320).to(torch.bfloat16)
320+
mat = (torch.randn(2, 4096, 320) + 15).to(torch.bfloat16)
306321
sd_mha_model = SD_MHA_Model_v3(8, 320, 320).eval()
307322
mha_ipex = ipex.optimize(sd_mha_model, dtype=torch.bfloat16, level="O1")
308323

@@ -313,14 +328,14 @@ def test_sd_mha_bf16_v3(self):
313328
for _ in range(2):
314329
mha_jit = mha_ipex(mat)
315330
mha_ref = sd_mha_model(mat)
316-
self.assertEqual(mha_ref, mha_jit, prec=1e-2)
331+
self.assertEqual(mha_ref, mha_jit, prec=1e-0)
317332

318333
mha_graph = mha_ipex.graph_for(mat)
319334
self.assertTrue(any(n.kind() == "ipex::sd_flash_mha" for n in mha_graph.nodes()))
320335

321336
def test_sd_mha_bf16_v4(self):
322-
mat1 = torch.randn(2, 4096, 320).to(torch.bfloat16)
323-
mat2 = torch.randn(2, 77, 320).to(torch.bfloat16)
337+
mat1 = (torch.randn(2, 4096, 320) + 15).to(torch.bfloat16)
338+
mat2 = (torch.randn(2, 77, 320) + 15).to(torch.bfloat16)
324339
sd_mha_model = SD_MHA_Model_v4(8, 320, 320).eval()
325340
mha_ipex = ipex.optimize(sd_mha_model, dtype=torch.bfloat16, level="O1")
326341

@@ -331,11 +346,31 @@ def test_sd_mha_bf16_v4(self):
331346
for _ in range(2):
332347
mha_jit = mha_ipex(mat1, mat2)
333348
mha_ref = sd_mha_model(mat1, mat2)
334-
self.assertEqual(mha_ref, mha_jit, prec=1e-2)
349+
self.assertEqual(mha_ref, mha_jit, prec=1e-0)
335350

336351
mha_graph = mha_ipex.graph_for(mat1, mat2)
337352
self.assertTrue(any(n.kind() == "ipex::sd_flash_mha" for n in mha_graph.nodes()))
338353

354+
def test_fake_sd_mha_bf16(self):
355+
mat1 = (torch.randn(1, 2, 64, 64) + 20).to(torch.bfloat16)
356+
mat2 = (torch.randn(1, 2, 64, 64) - 20).to(torch.bfloat16)
357+
mat3 = torch.randn(1, 2, 64, 64).to(torch.bfloat16)
358+
mask = (torch.ones(1, 1, 1, 64)).to(torch.bfloat16)
359+
fake_sd_mha_model = Fake_SD_MHA_Model(64, -1).eval()
360+
fake_mha_ipex = ipex.optimize(fake_sd_mha_model, dtype=torch.bfloat16, level="O1")
361+
362+
with torch.cpu.amp.autocast(), torch.no_grad():
363+
fake_mha_ipex = torch.jit.trace(fake_mha_ipex, (mat1, mat2, mat3, mask, ))
364+
fake_mha_ipex = torch.jit.freeze(fake_mha_ipex)
365+
366+
for _ in range(2):
367+
fake_mha_jit = fake_mha_ipex(mat1, mat2, mat3, mask)
368+
fake_mha_ref = fake_sd_mha_model(mat1, mat2, mat3, mask)
369+
self.assertEqual(fake_mha_ref, fake_mha_jit, prec=1e-1)
370+
371+
fake_mha_graph = fake_mha_ipex.graph_for(mat1, mat2, mat3, mask)
372+
self.assertTrue(any(n.kind() == "ipex::mha_scores_calc" for n in fake_mha_graph.nodes()))
373+
339374
def test_transfree_mha_bf16(self):
340375
for i in range(len(bs)):
341376
mat = torch.randn(bs[i], seq[i], num_heads[i] * head_dims[i]).to(torch.bfloat16)

0 commit comments

Comments
 (0)