
Commit 0bae1ee

Implement transformers for UDF inputs and outputs
1 parent 9996e09 commit 0bae1ee

13 files changed: +1254 -65 lines


singlestoredb/functions/ext/asgi.py

Lines changed: 11 additions & 8 deletions
@@ -311,6 +311,7 @@ def cancel_on_event(

 def build_udf_endpoint(
     func: Callable[..., Any],
+    args_data_format: str,
     returns_data_format: str,
 ) -> Callable[..., Any]:
     """
@@ -352,11 +353,12 @@ async def do_func(

         return do_func

-    return build_vector_udf_endpoint(func, returns_data_format)
+    return build_vector_udf_endpoint(func, args_data_format, returns_data_format)


 def build_vector_udf_endpoint(
     func: Callable[..., Any],
+    args_data_format: str,
     returns_data_format: str,
 ) -> Callable[..., Any]:
     """
@@ -422,6 +424,7 @@ async def do_func(

 def build_tvf_endpoint(
     func: Callable[..., Any],
+    args_data_format: str,
     returns_data_format: str,
 ) -> Callable[..., Any]:
     """
@@ -451,27 +454,27 @@ async def do_func(
             rows: Sequence[Sequence[Any]],
         ) -> Tuple[Sequence[int], List[Tuple[Any, ...]]]:
             '''Call function on given rows of data.'''
-            out_ids: List[int] = []
-            out = []
+            out: List[Tuple[Any, ...]] = []
             # Call function on each row of data
             async with timer('call_function'):
+                out = []
                 for i, row in zip(row_ids, rows):
                     cancel_on_event(cancel_event)
                     if is_async:
                         res = await func(*row)
                     else:
                         res = func(*row)
                     out.extend(as_list_of_tuples(res))
-                    out_ids.extend([row_ids[i]] * (len(out)-len(out_ids)))
-            return out_ids, out
+            return [row_ids[0]] * len(out), out

         return do_func

-    return build_vector_tvf_endpoint(func, returns_data_format)
+    return build_vector_tvf_endpoint(func, args_data_format, returns_data_format)


 def build_vector_tvf_endpoint(
     func: Callable[..., Any],
+    args_data_format: str,
     returns_data_format: str,
 ) -> Callable[..., Any]:
     """
@@ -575,9 +578,9 @@ def make_func(
     )

     if function_type == 'tvf':
-        do_func = build_tvf_endpoint(func, returns_data_format)
+        do_func = build_tvf_endpoint(func, args_data_format, returns_data_format)
     else:
-        do_func = build_udf_endpoint(func, returns_data_format)
+        do_func = build_udf_endpoint(func, args_data_format, returns_data_format)

     do_func.__name__ = name
     do_func.__doc__ = func.__doc__

singlestoredb/functions/ext/rowdat_1.py

Lines changed: 26 additions & 9 deletions
@@ -462,7 +462,7 @@ def _dump_vectors(
             default = DEFAULT_VALUES[rtype]
             try:
                 if rtype in numeric_formats:
-                    if value is None:
+                    if is_null or value is None:
                         out.write(struct.pack(numeric_formats[rtype], default))
                     else:
                         if rtype in int_types:
@@ -486,14 +486,14 @@ def _dump_vectors(
                             ),
                         )
                 elif rtype in string_types:
-                    if value is None:
+                    if is_null or value is None:
                         out.write(struct.pack('<q', 0))
                     else:
                         sval = value.encode('utf-8')
                         out.write(struct.pack('<q', len(sval)))
                         out.write(sval)
                 elif rtype in binary_types:
-                    if value is None:
+                    if is_null or value is None:
                         out.write(struct.pack('<q', 0))
                     else:
                         out.write(struct.pack('<q', len(value)))
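For context, a minimal standalone sketch of the NULL encodings used in the branches above (not part of the commit); '<q' is a little-endian 64-bit value, and NULL numerics fall back to a per-type default, presumably paired with a separate null flag on the wire:

import io
import struct

out = io.BytesIO()

# NULL numeric (a BIGINT column, say): the type's default value is
# written in the type's own format.
out.write(struct.pack('<q', 0))

# NULL string/binary: a zero length prefix and no payload.
out.write(struct.pack('<q', 0))

# Non-NULL string: the byte length followed by the UTF-8 payload.
sval = 'hello'.encode('utf-8')
out.write(struct.pack('<q', len(sval)))
out.write(sval)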
@@ -571,8 +571,18 @@ def _load_numpy_accel(

     for i, (_, dtype, transformer) in enumerate(colspec):
         if transformer is not None:
-            t = np.vectorize(transformer)
-            numpy_cols[i] = (t(numpy_cols[i][0]), numpy_cols[i][1])
+            # Numpy will try to be "helpful" and create multidimensional arrays
+            # from nested iterables. We don't usually want that. What we want is
+            # numpy arrays of Python objects (e.g., lists, dicts, etc). To do that,
+            # we have to create an empty array of the correct length and dtype=object,
+            # then fill it in with the transformed values. The transformer may have
+            # an output_type attribute that we can use to create a more specific type.
+            if getattr(transformer, 'output_type', None):
+                new_col = np.empty(len(numpy_cols[i][0]), dtype=transformer.output_type)
+                new_col[:] = list(map(transformer, numpy_cols[i][0]))
+            else:
+                new_col = np.array(list(map(transformer, numpy_cols[i][0])))
+            numpy_cols[i] = (new_col, numpy_cols[i][1])

     return numpy_ids, numpy_cols
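The comment in this hunk is easy to demonstrate standalone (this snippet is illustrative only, not part of the commit):

import numpy as np

values = [[1, 2], [3, 4], [5, 6]]

# np.array "helpfully" turns nested lists into a 3x2 numeric matrix...
a = np.array(values)
print(a.shape)    # (3, 2)

# ...whereas preallocating an object array of the right length and
# assigning into it keeps each element a Python list.
b = np.empty(len(values), dtype=object)
b[:] = values
print(b.shape)    # (3,)
print(b[0])       # [1, 2]

This is also why the transformer's optional output_type attribute matters: it lets the column keep a more specific dtype than object when one is known.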

@@ -589,8 +599,7 @@ def _dump_numpy_accel(

     for i, (_, dtype, transformer) in enumerate(returns):
         if transformer is not None:
-            t = np.vectorize(transformer)
-            cols[i] = (t(cols[i][0]), cols[i][1])
+            cols[i] = (np.array(list(map(transformer, cols[i][0]))), cols[i][1])

     return _singlestoredb_accel.dump_rowdat_1_numpy(returns, row_ids, cols)
@@ -678,10 +687,18 @@ def _dump_polars_accel(
     if not has_accel:
         raise RuntimeError('could not load SingleStoreDB extension')

+    import numpy as np
+    import polars as pl
+
     numpy_ids = row_ids.to_numpy()
     numpy_cols = [
         (
-            data.to_numpy(),
+            # Polars will try to be "helpful" and convert nested iterables into
+            # multidimensional arrays. We don't usually want that. What we want is
+            # numpy arrays of Python objects (e.g., lists, dicts, etc). To
+            # do that, we have to convert the Series to a list first.
+            np.array(data.to_list())
+            if isinstance(data.dtype, (pl.Struct, pl.Object)) else data.to_numpy(),
             mask.to_numpy() if mask is not None else None,
         )
         for data, mask in cols
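A standalone illustration of the Struct/Object special case (not part of the commit; assumes a recent polars):

import numpy as np
import polars as pl

s = pl.Series('vals', [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}])   # Struct dtype

# Round-tripping through to_list() keeps each element a plain Python
# dict, so the resulting numpy column has dtype=object.
col = np.array(s.to_list())
print(col.dtype)   # object
print(col[0])      # {'a': 1, 'b': 2}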
@@ -722,7 +739,7 @@ def _create_arrow_mask(
     if mask is None:
         return data.is_null().to_numpy(zero_copy_only=False)

-    return pc.or_(data.is_null(), mask.is_null()).to_numpy(zero_copy_only=False)
+    return pc.or_(data.is_null(), mask).to_numpy(zero_copy_only=False)


 def _dump_arrow_accel(
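The fix in _create_arrow_mask changes the semantics: the incoming mask is already a boolean array marking NULL rows, so it should be OR'ed in directly rather than queried for nulls of its own. A standalone sketch (not part of the commit):

import pyarrow as pa
import pyarrow.compute as pc

data = pa.array([1, None, 3])
mask = pa.array([False, False, True])   # True == "treat this row as NULL"

# Old: pc.or_(data.is_null(), mask.is_null()) asked whether the *mask*
# contained nulls (always False here), losing the masked third row.
# New: the boolean mask participates directly.
combined = pc.or_(data.is_null(), mask)
print(combined.to_numpy(zero_copy_only=False))   # [False  True  True]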

singlestoredb/functions/ext/utils.py

Lines changed: 1 addition & 3 deletions
@@ -7,7 +7,6 @@
 import zipfile
 from copy import copy
 from typing import Any
-from typing import Callable
 from typing import Dict
 from typing import List
 from typing import Optional
@@ -32,8 +31,7 @@ def formatMessage(self, record: logging.LogRecord) -> str:
         recordcopy.__dict__['levelprefix'] = levelname + ':' + seperator
         return super().formatMessage(recordcopy)

-
-Transformer = Callable[..., Any]
+from ..typing import Transformer


 def apply_transformer(func: Optional[Transformer], v: Any) -> Any:
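apply_transformer's body is not part of this diff; based on its signature, a minimal sketch of the expected behavior (an assumption, not the library's actual implementation) would be:

from typing import Any, Callable, Optional

Transformer = Callable[..., Any]

def apply_transformer(func: Optional[Transformer], v: Any) -> Any:
    # No transformer configured: pass the value through untouched.
    if func is None:
        return v
    # Otherwise apply the user-supplied callable to the single value.
    return func(v)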

singlestoredb/functions/signature.py

Lines changed: 43 additions & 7 deletions
@@ -923,6 +923,7 @@ def get_schema(
     spec: Any,
     overrides: Optional[List[ParamSpec]] = None,
     mode: str = 'parameter',
+    masks: Optional[List[bool]] = None,
 ) -> Tuple[List[ParamSpec], str, str]:
     """
     Expand a return type annotation into a list of types and field names.
@@ -935,6 +936,8 @@ def get_schema(
         List of SQL type specifications for the return type
     mode : str
         The mode of the function, either 'parameter' or 'return'
+    masks : Optional[List[bool]]
+        Whether each type is wrapped in a Masked type

     Returns
     -------
@@ -996,7 +999,13 @@ def get_schema(
                 'dataclass, TypedDict, or pydantic model',
             )
         spec = typing.get_args(unpacked_spec[0])[0]
-        data_format = 'list'
+        # Lists as output from TVFs are considered scalar outputs
+        # since they correspond to individual Python objects, not
+        # a true vector type.
+        if function_type == 'tvf':
+            data_format = 'scalar'
+        else:
+            data_format = 'list'

     elif all([utils.is_vector(x, include_masks=True) for x in unpacked_spec]):
         pass
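To make the new comment concrete: a TVF whose return annotation is a List yields one Python list per call, and each element of that list becomes an output row, so the list behaves like a scalar Python object rather than a vector column. A hypothetical TVF for illustration only (not from the codebase):

from typing import List

def split_words(sentence: str) -> List[str]:
    # One input row fans out into one output row per word; the returned
    # list is a per-call Python object, not a columnar (vector) value.
    return sentence.split()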
@@ -1113,7 +1122,11 @@ def get_schema(
         _, inner_apply_meta = unpack_annotated(typing.get_args(spec)[0])
         if inner_apply_meta.sql_type:
             udf_attrs = inner_apply_meta
-            colspec = get_schema(typing.get_args(spec)[0], mode=mode)[0]
+            colspec = get_schema(
+                typing.get_args(spec)[0],
+                mode=mode,
+                masks=[masks[0]] if masks else None,
+            )[0]
         else:
             colspec = [
                 ParamSpec(
@@ -1144,6 +1157,7 @@ def get_schema(
                 overrides=[overrides[i]] if overrides else [],
                 # Always pass UDF mode for individual items
                 mode=mode,
+                masks=[masks[i]] if masks else None,
             )

             # Use the name from the overrides if specified
@@ -1185,7 +1199,7 @@ def get_schema(
     out = []

     # Normalize colspec data types
-    for c in colspec:
+    for i, c in enumerate(colspec):

         # if the dtype is a string, it is resolved already
         if isinstance(c.dtype, str):
@@ -1203,13 +1217,27 @@ def get_schema(
                 include_null=c.is_optional,
             )

+        sql_type = c.sql_type if isinstance(c.sql_type, str) else udf_attrs.sql_type
+
+        is_optional = (
+            c.is_optional
+            or bool(dtype and dtype.endswith('?'))
+            or bool(masks and masks[i])
+        )
+
+        if is_optional:
+            if dtype and not dtype.endswith('?'):
+                dtype += '?'
+            if sql_type and re.search(r' NOT NULL\b', sql_type):
+                sql_type = re.sub(r' NOT NULL\b', r' NULL', sql_type)
+
         p = ParamSpec(
             name=c.name,
             dtype=dtype,
-            sql_type=c.sql_type if isinstance(c.sql_type, str) else udf_attrs.sql_type,
-            is_optional=c.is_optional or bool(dtype and dtype.endswith('?')),
-            transformer=udf_attrs.input_transformer
-            if mode == 'parameter' else udf_attrs.output_transformer,
+            sql_type=sql_type,
+            is_optional=is_optional,
+            transformer=udf_attrs.args_transformer
+            if mode == 'parameter' else udf_attrs.returns_transformer,
         )

         out.append(p)
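The nullability normalization added here boils down to a suffix and a regex substitution; a standalone illustration (not part of the commit):

import re

dtype = 'str'
sql_type = 'TEXT NOT NULL'

# A masked (or otherwise optional) column is forced to be nullable:
# the dtype gains a '?' suffix and ' NOT NULL' becomes ' NULL'.
if not dtype.endswith('?'):
    dtype += '?'
sql_type = re.sub(r' NOT NULL\b', r' NULL', sql_type)
print(dtype, '|', sql_type)   # str? | TEXT NULL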
@@ -1347,6 +1375,7 @@ def get_signature(
                 unpack_masked_type(param.annotation),
                 overrides=[args_colspec[i]] if args_colspec else [],
                 mode='parameter',
+                masks=[args_masks[i]] if args_masks else [],
             )
             args_data_formats.append(args_data_format)
@@ -1406,6 +1435,7 @@ def get_signature(
         unpack_masked_type(signature.return_annotation),
         overrides=returns_colspec if returns_colspec else None,
         mode='return',
+        masks=ret_masks or [],
     )

     rdf = out['returns_data_format'] = out['returns_data_format'] or 'scalar'
@@ -1421,6 +1451,12 @@ def get_signature(
             'scalar or vector types.',
         )

+    # If we have function parameters and the function is a TVF, then
+    # the return type should just match the parameter vector types. This ensures
+    # the output producers for scalars and vectors are consistent.
+    elif function_type == 'tvf' and rdf == 'scalar' and args_schema:
+        out['returns_data_format'] = out['args_data_format']
+
     # All functions have to return a value, so if none was specified try to
     # insert a reasonable default that includes NULLs.
     if not ret_schema:
