@@ -341,3 +341,58 @@ def test_add_dtype_error(
341341 assert_raises_regex (
342342 TypeError , "Output array of type.*is needed" , dpt .add , ar1 , ar2 , y
343343 )
344+
345+
346+ @pytest .mark .parametrize ("dtype" , _all_dtypes )
347+ def test_add_inplace_python_scalar (dtype ):
348+ q = get_queue_or_skip ()
349+ skip_if_dtype_not_supported (dtype , q )
350+ X = dpt .zeros ((10 , 10 ), dtype = dtype , sycl_queue = q )
351+ dt_kind = X .dtype .kind
352+ if dt_kind in "ui" :
353+ X += int (0 )
354+ elif dt_kind == "f" :
355+ X += float (0 )
356+ elif dt_kind == "c" :
357+ X += complex (0 )
358+ elif dt_kind == "b" :
359+ X += bool (0 )
360+
361+
362+ @pytest .mark .parametrize ("op1_dtype" , _all_dtypes )
363+ @pytest .mark .parametrize ("op2_dtype" , _all_dtypes )
364+ def test_add_inplace_dtype_matrix (op1_dtype , op2_dtype ):
365+ q = get_queue_or_skip ()
366+ skip_if_dtype_not_supported (op1_dtype , q )
367+ skip_if_dtype_not_supported (op2_dtype , q )
368+
369+ if dpt .can_cast (op2_dtype , op1_dtype , casting = "safe" ):
370+ sz = 127
371+ ar1 = dpt .ones (sz , dtype = op1_dtype )
372+ ar2 = dpt .ones_like (ar1 , dtype = op2_dtype )
373+
374+ ar1 += ar2
375+ assert (
376+ dpt .asnumpy (ar1 ) == np .full (ar1 .shape , 2 , dtype = ar1 .dtype )
377+ ).all ()
378+
379+ ar3 = dpt .ones (sz , dtype = op1_dtype )
380+ ar4 = dpt .ones (2 * sz , dtype = op2_dtype )
381+
382+ ar3 [::- 1 ] += ar4 [::2 ]
383+ assert (
384+ dpt .asnumpy (ar3 ) == np .full (ar3 .shape , 2 , dtype = ar3 .dtype )
385+ ).all ()
386+
387+ else :
388+ assert pytest .raises (TypeError )
389+
390+
391+ def test_add_inplace_broadcasting ():
392+ get_queue_or_skip ()
393+
394+ m = dpt .ones ((100 , 5 ), dtype = "i4" )
395+ v = dpt .arange (5 , dtype = "i4" )
396+
397+ m += v
398+ assert (dpt .asnumpy (m ) == np .arange (1 , 6 , dtype = "i4" )[np .newaxis , :]).all ()
0 commit comments