Skip to content

Commit 41baf10

Browse files
fix and add expectations for cuda and rocm platforms
1 parent 2e93004 commit 41baf10

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

tests/models/mistral3/test_modeling_mistral3.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def test_mistral3_integration_batched_generate(self):
355355
expected_outputs = Expectations(
356356
{
357357
("xpu", 3): "Calm lake's mirror gleams,\nWhispering pines stand in silence,\nPath to peace begins.",
358-
("cuda", 8): "Wooden path to calm,\nReflections whisper secrets,\nNature's peace unfolds.",
358+
("cuda", 8): "Calm waters reflect\nWooden path to distant shore\nSilence in the woods",
359359
("rocm", (9, 5)): "Calm waters reflect\nWooden path to distant shore\nSilence in the scene"
360360
}
361361
) # fmt: skip
@@ -432,7 +432,8 @@ def test_mistral3_integration_batched_generate_multi_image(self):
432432
decoded_output = processor.decode(gen_tokens[0], skip_special_tokens=True)
433433
expected_outputs = Expectations(
434434
{
435-
("cuda", 8): 'Calm waters reflect\nWooden path to distant shore\nSilence in the scene',
435+
("cuda", 8): "Calm waters reflect\nWooden path to distant shore\nPeace in nature's hold",
436+
("rocm", (9, 4)): "Calm waters reflect\nWooden path to distant shore\nSilence in the pines",
436437
}
437438
) # fmt: skip
438439
expected_output = expected_outputs.get_expectation()
@@ -448,6 +449,7 @@ def test_mistral3_integration_batched_generate_multi_image(self):
448449
{
449450
("xpu", 3): "Certainly! The images depict two iconic landmarks:\n\n1. The first image shows the Statue of Liberty in New York City.",
450451
("cuda", 8): 'Certainly! The images depict two famous landmarks in the United States:\n\n1. The first image shows the Statue of Liberty,',
452+
("rocm", (9, 4)): 'Certainly! The images depict two famous landmarks in the United States:\n\n1. The first image shows the Statue of Liberty,',
451453
}
452454
) # fmt: skip
453455
expected_output = expected_outputs.get_expectation()

0 commit comments

Comments
 (0)