44import torch .nn as nn
55import torch .nn .functional as F
66import intel_extension_for_pytorch as ipex
7+ import math
78from common_utils import TestCase
89
910#(from Diffusers 0.12.1)
@@ -162,6 +163,20 @@ def forward(self, x, y):
162163 output = hidden_states .to (query .dtype )
163164 return output
164165
166+ #(Fake Diffusers Model - Fall back to ipex::mha_scores_calc)
167+ class Fake_SD_MHA_Model (nn .Module ):
168+ def __init__ (self , dim_per_head , softmax_dim = - 1 ):
169+ super (Fake_SD_MHA_Model , self ).__init__ ()
170+ self .softmax = nn .Softmax (dim = softmax_dim )
171+ self .dim_per_head = dim_per_head
172+
173+ def forward (self , mat1 , mat2 , mat3 , bias ):
174+ mat1 = mat1 / math .sqrt (self .dim_per_head )
175+ qk = torch .matmul (mat1 , mat2 .transpose (2 , 3 ))
176+ scores = self .softmax (qk + bias )
177+ output = torch .matmul (scores , mat3 )
178+ return output
179+
165180class MHA_Model_BERT (nn .Module ):
166181 def __init__ (self , scale , num_heads , head_dims , permute_idx , trans_a , trans_b ):
167182 super (MHA_Model_BERT , self ).__init__ ()
@@ -267,7 +282,7 @@ def forward(self, x):
267282class TransFreeMHATester (TestCase ):
268283
269284 def test_sd_mha_bf16_v1 (self ):
270- mat = torch .randn (2 , 4096 , 320 ).to (torch .bfloat16 )
285+ mat = ( torch .randn (2 , 4096 , 320 ) + 15 ).to (torch .bfloat16 )
271286 sd_mha_model = SD_MHA_Model_v1 (0.3 , 8 , 320 , 320 ).eval ()
272287 mha_ipex = ipex .optimize (sd_mha_model , dtype = torch .bfloat16 , level = "O1" )
273288
@@ -278,14 +293,14 @@ def test_sd_mha_bf16_v1(self):
278293 for _ in range (2 ):
279294 mha_jit = mha_ipex (mat )
280295 mha_ref = sd_mha_model (mat )
281- self .assertEqual (mha_ref , mha_jit , prec = 1e-2 )
296+ self .assertEqual (mha_ref , mha_jit , prec = 1e-0 )
282297
283298 mha_graph = mha_ipex .graph_for (mat )
284299 self .assertTrue (any (n .kind () == "ipex::sd_flash_mha" for n in mha_graph .nodes ()))
285300
286301 def test_sd_mha_bf16_v2 (self ):
287- mat1 = torch .randn (2 , 4096 , 320 ).to (torch .bfloat16 )
288- mat2 = torch .randn (2 , 77 , 320 ).to (torch .bfloat16 )
302+ mat1 = ( torch .randn (2 , 4096 , 320 ) + 15 ).to (torch .bfloat16 )
303+ mat2 = ( torch .randn (2 , 77 , 320 ) + 15 ).to (torch .bfloat16 )
289304 sd_mha_model = SD_MHA_Model_v2 (0.3 , 8 , 320 , 320 ).eval ()
290305 mha_ipex = ipex .optimize (sd_mha_model , dtype = torch .bfloat16 , level = "O1" )
291306
@@ -296,13 +311,13 @@ def test_sd_mha_bf16_v2(self):
296311 for _ in range (2 ):
297312 mha_jit = mha_ipex (mat1 , mat2 )
298313 mha_ref = sd_mha_model (mat1 , mat2 )
299- self .assertEqual (mha_ref , mha_jit , prec = 1e-2 )
314+ self .assertEqual (mha_ref , mha_jit , prec = 1e-0 )
300315
301316 mha_graph = mha_ipex .graph_for (mat1 , mat2 )
302317 self .assertTrue (any (n .kind () == "ipex::sd_flash_mha" for n in mha_graph .nodes ()))
303318
304319 def test_sd_mha_bf16_v3 (self ):
305- mat = torch .randn (2 , 4096 , 320 ).to (torch .bfloat16 )
320+ mat = ( torch .randn (2 , 4096 , 320 ) + 15 ).to (torch .bfloat16 )
306321 sd_mha_model = SD_MHA_Model_v3 (8 , 320 , 320 ).eval ()
307322 mha_ipex = ipex .optimize (sd_mha_model , dtype = torch .bfloat16 , level = "O1" )
308323
@@ -313,14 +328,14 @@ def test_sd_mha_bf16_v3(self):
313328 for _ in range (2 ):
314329 mha_jit = mha_ipex (mat )
315330 mha_ref = sd_mha_model (mat )
316- self .assertEqual (mha_ref , mha_jit , prec = 1e-2 )
331+ self .assertEqual (mha_ref , mha_jit , prec = 1e-0 )
317332
318333 mha_graph = mha_ipex .graph_for (mat )
319334 self .assertTrue (any (n .kind () == "ipex::sd_flash_mha" for n in mha_graph .nodes ()))
320335
321336 def test_sd_mha_bf16_v4 (self ):
322- mat1 = torch .randn (2 , 4096 , 320 ).to (torch .bfloat16 )
323- mat2 = torch .randn (2 , 77 , 320 ).to (torch .bfloat16 )
337+ mat1 = ( torch .randn (2 , 4096 , 320 ) + 15 ).to (torch .bfloat16 )
338+ mat2 = ( torch .randn (2 , 77 , 320 ) + 15 ).to (torch .bfloat16 )
324339 sd_mha_model = SD_MHA_Model_v4 (8 , 320 , 320 ).eval ()
325340 mha_ipex = ipex .optimize (sd_mha_model , dtype = torch .bfloat16 , level = "O1" )
326341
@@ -331,11 +346,31 @@ def test_sd_mha_bf16_v4(self):
331346 for _ in range (2 ):
332347 mha_jit = mha_ipex (mat1 , mat2 )
333348 mha_ref = sd_mha_model (mat1 , mat2 )
334- self .assertEqual (mha_ref , mha_jit , prec = 1e-2 )
349+ self .assertEqual (mha_ref , mha_jit , prec = 1e-0 )
335350
336351 mha_graph = mha_ipex .graph_for (mat1 , mat2 )
337352 self .assertTrue (any (n .kind () == "ipex::sd_flash_mha" for n in mha_graph .nodes ()))
338353
354+ def test_fake_sd_mha_bf16 (self ):
355+ mat1 = (torch .randn (1 , 2 , 64 , 64 ) + 20 ).to (torch .bfloat16 )
356+ mat2 = (torch .randn (1 , 2 , 64 , 64 ) - 20 ).to (torch .bfloat16 )
357+ mat3 = torch .randn (1 , 2 , 64 , 64 ).to (torch .bfloat16 )
358+ mask = (torch .ones (1 , 1 , 1 , 64 )).to (torch .bfloat16 )
359+ fake_sd_mha_model = Fake_SD_MHA_Model (64 , - 1 ).eval ()
360+ fake_mha_ipex = ipex .optimize (fake_sd_mha_model , dtype = torch .bfloat16 , level = "O1" )
361+
362+ with torch .cpu .amp .autocast (), torch .no_grad ():
363+ fake_mha_ipex = torch .jit .trace (fake_mha_ipex , (mat1 , mat2 , mat3 , mask , ))
364+ fake_mha_ipex = torch .jit .freeze (fake_mha_ipex )
365+
366+ for _ in range (2 ):
367+ fake_mha_jit = fake_mha_ipex (mat1 , mat2 , mat3 , mask )
368+ fake_mha_ref = fake_sd_mha_model (mat1 , mat2 , mat3 , mask )
369+ self .assertEqual (fake_mha_ref , fake_mha_jit , prec = 1e-1 )
370+
371+ fake_mha_graph = fake_mha_ipex .graph_for (mat1 , mat2 , mat3 , mask )
372+ self .assertTrue (any (n .kind () == "ipex::mha_scores_calc" for n in fake_mha_graph .nodes ()))
373+
339374 def test_transfree_mha_bf16 (self ):
340375 for i in range (len (bs )):
341376 mat = torch .randn (bs [i ], seq [i ], num_heads [i ] * head_dims [i ]).to (torch .bfloat16 )
0 commit comments