Skip to content

Commit b083169

Browse files
Authored by: Chen666, mamba-chen, ydshieh
Adapt some test case on npu (#42335)
* Adapt some test case on npu

* Adapt some test case on npu

---------

Co-authored-by: mamba-chen <chenhao388@huawei.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
1 parent f6dcac6 commit b083169

File tree

3 files changed

+7
-1
lines changed

3 files changed

+7
-1
lines changed

src/transformers/testing_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,12 @@
222222
IS_ROCM_SYSTEM = torch.version.hip is not None
223223
IS_CUDA_SYSTEM = torch.version.cuda is not None
224224
IS_XPU_SYSTEM = getattr(torch.version, "xpu", None) is not None
225+
IS_NPU_SYSTEM = getattr(torch, "npu", None) is not None
225226
else:
226227
IS_ROCM_SYSTEM = False
227228
IS_CUDA_SYSTEM = False
228229
IS_XPU_SYSTEM = False
230+
IS_NPU_SYSTEM = False
229231

230232
logger = transformers_logging.get_logger(__name__)
231233

@@ -3174,6 +3176,8 @@ def get_device_properties() -> DeviceProperties:
31743176
gen_mask = 0x000000FF00000000
31753177
gen = (arch & gen_mask) >> 32
31763178
return ("xpu", gen, None)
3179+
elif IS_NPU_SYSTEM:
3180+
return ("npu", None, None)
31773181
else:
31783182
return (torch_device, None, None)
31793183

tests/models/qwen3/test_modeling_qwen3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def test_speculative_generation(self):
168168
("xpu", 3): "My favourite condiment is 100% beef and comes in a 12 oz. jar. It is sold in",
169169
("cuda", 7): "My favourite condiment is 100% natural. It's a little spicy and a little sweet, but it's the",
170170
("cuda", 8): "My favourite condiment is 100% beef, 100% beef, 100% beef.",
171+
("npu", None): "My favourite condiment is 100% chicken and beef. I love it because it's so good and I love it",
171172
}
172173
) # fmt: skip
173174
EXPECTED_TEXT_COMPLETION = EXPECTED_TEXT_COMPLETIONS.get_expectation()
@@ -214,6 +215,7 @@ def test_export_static_cache(self):
214215
("xpu", None): ["My favourite condiment is 100% plain, unflavoured, and unadulterated. It is"],
215216
("rocm", (9, 5)): ["My favourite condiment is 100% plain, unflavoured, and unadulterated."],
216217
("cuda", None): cuda_expectation,
218+
("npu", None): ["My favourite condiment is 100% plain, unsalted, unsweetened, and unflavored. It is"],
217219
}
218220
) # fmt: skip
219221
EXPECTED_TEXT_COMPLETION = expected_text_completions.get_expectation()

tests/test_modeling_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def _can_output_attn(model):
448448
if torch_device in ["cpu", "cuda"]:
449449
atol = atols[torch_device, enable_kernels, dtype]
450450
rtol = rtols[torch_device, enable_kernels, dtype]
451-
elif torch_device == "hpu":
451+
elif torch_device in ["hpu", "npu"]:
452452
atol = atols["cuda", enable_kernels, dtype]
453453
rtol = rtols["cuda", enable_kernels, dtype]
454454
elif torch_device == "xpu":

0 commit comments

Comments (0)