Skip to content

Commit cb2bfc3

Browse files
authored
[mypyc] Add primitives for bytes and str multiply (#20303)
These primitives are a bit faster in microbenchmarks for short str and bytes objects (roughly 20%). I experimented with using LLMs to generate much of the code: I gave them detailed instructions and reviewed all of the output.
1 parent 0c2bf7a commit cb2bfc3

File tree

9 files changed

+168
-0
lines changed

9 files changed

+168
-0
lines changed

mypyc/lib-rt/CPy.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -770,6 +770,7 @@ PyObject *CPy_Encode(PyObject *obj, PyObject *encoding, PyObject *errors);
770770
Py_ssize_t CPyStr_Count(PyObject *unicode, PyObject *substring, CPyTagged start);
771771
Py_ssize_t CPyStr_CountFull(PyObject *unicode, PyObject *substring, CPyTagged start, CPyTagged end);
772772
CPyTagged CPyStr_Ord(PyObject *obj);
773+
PyObject *CPyStr_Multiply(PyObject *str, CPyTagged count);
773774

774775

775776
// Bytes operations
@@ -781,6 +782,7 @@ CPyTagged CPyBytes_GetItem(PyObject *o, CPyTagged index);
781782
PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
782783
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
783784
CPyTagged CPyBytes_Ord(PyObject *obj);
785+
PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count);
784786

785787

786788
int CPyBytes_Compare(PyObject *left, PyObject *right);

mypyc/lib-rt/bytes_ops.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,12 @@ CPyTagged CPyBytes_Ord(PyObject *obj) {
162162
PyErr_SetString(PyExc_TypeError, "ord() expects a character");
163163
return CPY_INT_TAG;
164164
}
165+
166+
// Primitive behind `bytes * int` / `int * bytes`: converts the tagged integer
// `count` to a Py_ssize_t and hands the repetition to PySequence_Repeat().
// Returns whatever PySequence_Repeat() returns, or NULL with OverflowError
// set when the count conversion signals an error.
PyObject *CPyBytes_Multiply(PyObject *bytes, CPyTagged count) {
167+
Py_ssize_t temp_count = CPyTagged_AsSsize_t(count);
168+
// -1 on its own is a legitimate repeat count, so the conversion is only
// treated as failed when an exception is pending as well.
if (temp_count == -1 && PyErr_Occurred()) {
169+
// Report the failure as OverflowError with CPYTHON_LARGE_INT_ERRMSG,
// replacing the exception that was pending.
PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
170+
return NULL;
171+
}
172+
return PySequence_Repeat(bytes, temp_count);
173+
}

mypyc/lib-rt/str_ops.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,3 +621,12 @@ CPyTagged CPyStr_Ord(PyObject *obj) {
621621
PyExc_TypeError, "ord() expected a character, but a string of length %zd found", s);
622622
return CPY_INT_TAG;
623623
}
624+
625+
// Primitive behind `str * int` / `int * str`: converts the tagged integer
// `count` to a Py_ssize_t and hands the repetition to PySequence_Repeat().
// Returns whatever PySequence_Repeat() returns, or NULL with OverflowError
// set when the count conversion signals an error.
PyObject *CPyStr_Multiply(PyObject *str, CPyTagged count) {
626+
Py_ssize_t temp_count = CPyTagged_AsSsize_t(count);
627+
// -1 on its own is a legitimate repeat count, so the conversion is only
// treated as failed when an exception is pending as well.
if (temp_count == -1 && PyErr_Occurred()) {
628+
// Report the failure as OverflowError with CPYTHON_LARGE_INT_ERRMSG,
// replacing the exception that was pending.
PyErr_SetString(PyExc_OverflowError, CPYTHON_LARGE_INT_ERRMSG);
629+
return NULL;
630+
}
631+
return PySequence_Repeat(str, temp_count);
632+
}

mypyc/primitives/bytes_ops.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,25 @@
8282
steals=[True, False],
8383
)
8484

85+
# bytes * int
# Registers the "*" operator for a bytes left operand and an int right
# operand; the operation is lowered to a call to the C helper
# CPyBytes_Multiply and produces bytes.
86+
binary_op(
87+
name="*",
88+
arg_types=[bytes_rprimitive, int_rprimitive],
89+
return_type=bytes_rprimitive,
90+
c_function_name="CPyBytes_Multiply",
91+
error_kind=ERR_MAGIC,
92+
)
93+
94+
# int * bytes
# Registers the "*" operator with the int on the left; it reuses the same C
# helper as bytes * int.
95+
binary_op(
96+
name="*",
97+
arg_types=[int_rprimitive, bytes_rprimitive],
98+
return_type=bytes_rprimitive,
99+
c_function_name="CPyBytes_Multiply",
100+
error_kind=ERR_MAGIC,
101+
# Swap the operands so the C helper still receives (bytes, count).
ordering=[1, 0],
102+
)
103+
85104
# bytes[begin:end]
86105
bytes_slice_op = custom_op(
87106
arg_types=[bytes_rprimitive, int_rprimitive, int_rprimitive],

mypyc/primitives/str_ops.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,25 @@
7979
steals=[True, False],
8080
)
8181

82+
# str * int
# Registers the "*" operator for a str left operand and an int right
# operand; the operation is lowered to a call to the C helper
# CPyStr_Multiply and produces str.
83+
binary_op(
84+
name="*",
85+
arg_types=[str_rprimitive, int_rprimitive],
86+
return_type=str_rprimitive,
87+
c_function_name="CPyStr_Multiply",
88+
error_kind=ERR_MAGIC,
89+
)
90+
91+
# int * str
# Registers the "*" operator with the int on the left; it reuses the same C
# helper as str * int.
92+
binary_op(
93+
name="*",
94+
arg_types=[int_rprimitive, str_rprimitive],
95+
return_type=str_rprimitive,
96+
c_function_name="CPyStr_Multiply",
97+
error_kind=ERR_MAGIC,
98+
# Swap the operands so the C helper still receives (str, count).
ordering=[1, 0],
99+
)
100+
82101
# str1 == str2 (very common operation, so we provide our own)
83102
str_eq = custom_primitive_op(
84103
name="str_eq",

mypyc/test-data/irbuild-bytes.test

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,24 @@ L2:
217217
L3:
218218
keep_alive y
219219
return r2
220+
221+
[case testBytesMultiply]
222+
def b_times_i(s: bytes, n: int) -> bytes:
223+
return s * n
224+
def i_times_b(s: bytes, n: int) -> bytes:
225+
return n * s
226+
[out]
227+
def b_times_i(s, n):
228+
s :: bytes
229+
n :: int
230+
r0 :: bytes
231+
L0:
232+
r0 = CPyBytes_Multiply(s, n)
233+
return r0
234+
def i_times_b(s, n):
235+
s :: bytes
236+
n :: int
237+
r0 :: bytes
238+
L0:
239+
r0 = CPyBytes_Multiply(s, n)
240+
return r0

mypyc/test-data/irbuild-str.test

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -771,3 +771,24 @@ L0:
771771
r0 = 'literal'
772772
r1 = 'literal'
773773
return 1
774+
775+
[case testStrMultiply]
776+
def s_times_i(s: str, n: int) -> str:
777+
return s * n
778+
def i_times_s(s: str, n: int) -> str:
779+
return n * s
780+
[out]
781+
def s_times_i(s, n):
782+
s :: str
783+
n :: int
784+
r0 :: str
785+
L0:
786+
r0 = CPyStr_Multiply(s, n)
787+
return r0
788+
def i_times_s(s, n):
789+
s :: str
790+
n :: int
791+
r0 :: str
792+
L0:
793+
r0 = CPyStr_Multiply(s, n)
794+
return r0

mypyc/test-data/run-bytes.test

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,40 @@ def test_ord_bytesarray() -> None:
134134
with assertRaises(TypeError):
135135
ord(bytearray(b''))
136136

137+
def test_multiply() -> None:
138+
# Run test for bytes repetition with the bytes operand on either side of `*`.
# Use bytes() and int() to avoid constant folding
139+
b = b'ab' + bytes()
140+
zero = int()
141+
one = 1 + zero
142+
three = 3 + zero
143+
neg_one = -1 + zero
144+
145+
# Counts of 0, 1, 3 and -1 in both operand orders; zero and negative
# counts are expected to give empty bytes.
assert b * zero == b''
146+
assert b * one == b'ab'
147+
assert b * three == b'ababab'
148+
assert b * neg_one == b''
149+
assert zero * b == b''
150+
assert one * b == b'ab'
151+
assert three * b == b'ababab'
152+
assert neg_one * b == b''
153+
154+
# Test with empty bytes: repeating an empty value stays empty
155+
empty = bytes()
156+
five = 5 + zero
157+
assert empty * five == b''
158+
assert five * empty == b''
159+
160+
# Test with single byte (0xff, a non-ASCII byte value)
161+
single = b'\xff' + bytes()
162+
four = 4 + zero
163+
assert single * four == b'\xff\xff\xff\xff'
164+
assert four * single == b'\xff\xff\xff\xff'
165+
166+
# Test type preservation: the result must be exactly bytes
167+
two = 2 + zero
168+
result = b * two
169+
assert type(result) == bytes
170+
137171
[case testBytesSlicing]
138172
def test_bytes_slicing() -> None:
139173
b = b'abcdefg'

mypyc/test-data/run-strings.test

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,40 @@ def test_str_min_max() -> None:
362362
assert max(x, y) == 'bbb'
363363
assert max(x, z) == 'aaa'
364364

365+
def test_multiply() -> None:
366+
# Run test for str repetition with the str operand on either side of `*`.
# Use str() and int() to avoid constant folding
367+
s = 'ab' + str()
368+
zero = int()
369+
one = 1 + zero
370+
three = 3 + zero
371+
neg_one = -1 + zero
372+
373+
# Counts of 0, 1, 3 and -1 in both operand orders; zero and negative
# counts are expected to give the empty string.
assert s * zero == ''
374+
assert s * one == 'ab'
375+
assert s * three == 'ababab'
376+
assert s * neg_one == ''
377+
assert zero * s == ''
378+
assert one * s == 'ab'
379+
assert three * s == 'ababab'
380+
assert neg_one * s == ''
381+
382+
# Test with empty string: repeating an empty value stays empty
383+
empty = str()
384+
five = 5 + zero
385+
assert empty * five == ''
386+
assert five * empty == ''
387+
388+
# Test with single character
389+
single = 'x' + str()
390+
four = 4 + zero
391+
assert single * four == 'xxxx'
392+
assert four * single == 'xxxx'
393+
394+
# Test type preservation: the result must be exactly str
395+
two = 2 + zero
396+
result = s * two
397+
assert type(result) == str
398+
365399
[case testStringFormattingCStyle]
366400
from typing import Tuple
367401

0 commit comments

Comments
 (0)