Skip to content

Commit 1b02a35

Browse files
committed
Fix masked types; add asgi endpoint for function info
1 parent b36f7bf commit 1b02a35

13 files changed

Lines changed: 1073 additions & 638 deletions

File tree

singlestoredb/config.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,18 @@
407407
environ=['SINGLESTOREDB_EXT_FUNC_LOG_LEVEL'],
408408
)
409409

410+
register_option(
411+
'external_function.name_prefix', 'string', check_str, '',
412+
'Prefix to add to external function names.',
413+
environ=['SINGLESTOREDB_EXT_FUNC_NAME_PREFIX'],
414+
)
415+
416+
register_option(
417+
'external_function.name_suffix', 'string', check_str, '',
418+
'Suffix to add to external function names.',
419+
environ=['SINGLESTOREDB_EXT_FUNC_NAME_SUFFIX'],
420+
)
421+
410422
register_option(
411423
'external_function.connection', 'string', check_str,
412424
os.environ.get('SINGLESTOREDB_URL') or None,
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
from .decorator import tvf # noqa: F401
2+
from .decorator import tvf_with_null_masks # noqa: F401
23
from .decorator import udf # noqa: F401
4+
from .decorator import udf_with_null_masks # noqa: F401
5+
from .typing import Masked # noqa: F401
6+
from .typing import MaskedNDArray # noqa: F401

singlestoredb/functions/decorator.py

Lines changed: 200 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,99 @@
1-
from __future__ import annotations
2-
31
import functools
42
import inspect
3+
import typing
54
from typing import Any
65
from typing import Callable
76
from typing import List
87
from typing import Optional
98
from typing import Type
109
from typing import Union
1110

11+
from . import utils
12+
from .dtypes import SQLString
13+
1214

1315
ParameterType = Union[
1416
str,
15-
Callable[..., str],
16-
List[Union[str, Callable[..., str]]],
17+
Callable[..., SQLString],
18+
List[Union[str, Callable[..., SQLString]]],
1719
Type[Any],
1820
]
1921

2022
ReturnType = ParameterType
2123

2224

25+
def is_valid_type(obj: Any) -> bool:
26+
"""Check if the object is a valid type for a schema definition."""
27+
if not inspect.isclass(obj):
28+
return False
29+
30+
if utils.is_typeddict(obj):
31+
return True
32+
33+
if utils.is_namedtuple(obj):
34+
return True
35+
36+
if utils.is_dataclass(obj):
37+
return True
38+
39+
# We don't want to import pydantic here, so we check if
40+
# the class is a subclass
41+
if utils.is_pydantic(obj):
42+
return True
43+
44+
return False
45+
46+
47+
def is_valid_callable(obj: Any) -> bool:
48+
"""Check if the object is a valid callable for a parameter type."""
49+
if not callable(obj):
50+
return False
51+
52+
returns = inspect.get_annotations(obj).get('return', None)
53+
54+
if inspect.isclass(returns) and issubclass(returns, str):
55+
return True
56+
57+
raise TypeError(
58+
f'callable {obj} must return a str, '
59+
f'but got {returns}',
60+
)
61+
62+
63+
def verify_mask(obj: Any) -> bool:
64+
"""Verify that the object is a tuple of two vector types."""
65+
if typing.get_origin(obj) is not tuple or len(typing.get_args(obj)) != 2:
66+
raise TypeError(
67+
f'Expected a tuple of two vector types, but got {type(obj)}',
68+
)
69+
70+
args = typing.get_args(obj)
71+
72+
if not utils.is_vector(args[0]):
73+
raise TypeError(
74+
f'Expected a vector type for the first element, but got {args[0]}',
75+
)
76+
77+
if not utils.is_vector(args[1]):
78+
raise TypeError(
79+
f'Expected a vector type for the second element, but got {args[1]}',
80+
)
81+
82+
return True
83+
84+
85+
def verify_masks(obj: Callable[..., Any]) -> bool:
86+
"""Verify that the function parameters and return value are all masks."""
87+
ann = utils.get_annotations(obj)
88+
for name, value in ann.items():
89+
if not verify_mask(value):
90+
raise TypeError(
91+
f'Expected a vector type for the parameter {name} '
92+
f'in function {obj.__name__}, but got {value}',
93+
)
94+
return True
95+
96+
2397
def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
2498
"""Expand the types for the function arguments / return values."""
2599
if args is None:
@@ -30,18 +104,11 @@ def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
30104
return [args]
31105

32106
# General way of accepting pydantic.BaseModel, NamedTuple, TypedDict
33-
elif inspect.isclass(args):
107+
elif is_valid_type(args):
34108
return args
35109

36-
# Callable that returns a SQL string
37-
elif callable(args):
38-
out = args()
39-
if not isinstance(out, str):
40-
raise TypeError(f'unrecognized type for parameter: {args}')
41-
return [out]
42-
43110
# List of SQL strings or callables
44-
else:
111+
elif isinstance(args, list):
45112
new_args = []
46113
for arg in args:
47114
if isinstance(arg, str):
@@ -52,14 +119,23 @@ def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
52119
raise TypeError(f'unrecognized type for parameter: {arg}')
53120
return new_args
54121

122+
# Callable that returns a SQL string
123+
elif is_valid_callable(args):
124+
out = args()
125+
if not isinstance(out, str):
126+
raise TypeError(f'unrecognized type for parameter: {args}')
127+
return [out]
128+
129+
raise TypeError(f'unrecognized type for parameter: {args}')
130+
55131

56132
def _func(
57133
func: Optional[Callable[..., Any]] = None,
58134
*,
59135
name: Optional[str] = None,
60136
args: Optional[ParameterType] = None,
61137
returns: Optional[ReturnType] = None,
62-
include_masks: bool = False,
138+
with_null_masks: bool = False,
63139
function_type: str = 'udf',
64140
) -> Callable[..., Any]:
65141
"""Generic wrapper for UDF and TVF decorators."""
@@ -69,7 +145,7 @@ def _func(
69145
name=name,
70146
args=expand_types(args),
71147
returns=expand_types(returns),
72-
include_masks=include_masks,
148+
with_null_masks=with_null_masks,
73149
function_type=function_type,
74150
).items() if v is not None
75151
}
@@ -79,12 +155,21 @@ def _func(
79155
# in at that time.
80156
if func is None:
81157
def decorate(func: Callable[..., Any]) -> Callable[..., Any]:
158+
if with_null_masks:
159+
verify_masks(func)
160+
82161
def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
83162
return func(*args, **kwargs) # type: ignore
163+
84164
wrapper._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
165+
85166
return functools.wraps(func)(wrapper)
167+
86168
return decorate
87169

170+
if with_null_masks:
171+
verify_masks(func)
172+
88173
def wrapper(*args: Any, **kwargs: Any) -> Callable[..., Any]:
89174
return func(*args, **kwargs) # type: ignore
90175

@@ -99,10 +184,9 @@ def udf(
99184
name: Optional[str] = None,
100185
args: Optional[ParameterType] = None,
101186
returns: Optional[ReturnType] = None,
102-
include_masks: bool = False,
103187
) -> Callable[..., Any]:
104188
"""
105-
Apply attributes to a UDF.
189+
Define a user-defined function (UDF).
106190
107191
Parameters
108192
----------
@@ -126,10 +210,6 @@ def udf(
126210
returns : str, optional
127211
Specifies the return data type of the function. If not specified,
128212
the type annotation from the function is used.
129-
include_masks : bool, optional
130-
Should boolean masks be included with each input parameter to indicate
131-
which elements are NULL? This is only used when a input parameters are
132-
configured to a vector type (numpy, pandas, polars, arrow).
133213
134214
Returns
135215
-------
@@ -141,7 +221,55 @@ def udf(
141221
name=name,
142222
args=args,
143223
returns=returns,
144-
include_masks=include_masks,
224+
with_null_masks=False,
225+
function_type='udf',
226+
)
227+
228+
229+
def udf_with_null_masks(
230+
func: Optional[Callable[..., Any]] = None,
231+
*,
232+
name: Optional[str] = None,
233+
args: Optional[ParameterType] = None,
234+
returns: Optional[ReturnType] = None,
235+
) -> Callable[..., Any]:
236+
"""
237+
Define a user-defined function (UDF) with null masks.
238+
239+
Parameters
240+
----------
241+
func : callable, optional
242+
The UDF to apply parameters to
243+
name : str, optional
244+
The name to use for the UDF in the database
245+
args : str | Callable | List[str | Callable], optional
246+
Specifies the data types of the function arguments. Typically,
247+
the function data types are derived from the function parameter
248+
annotations. These annotations can be overridden. If the function
249+
takes a single type for all parameters, `args` can be set to a
250+
SQL string describing all parameters. If the function takes more
251+
than one parameter and all of the parameters are being manually
252+
defined, a list of SQL strings may be used (one for each parameter).
253+
A dictionary of SQL strings may be used to specify a parameter type
254+
for a subset of parameters; the keys are the names of the
255+
function parameters. Callables may also be used for datatypes. This
256+
is primarily for using the functions in the ``dtypes`` module that
257+
are associated with SQL types with all default options (e.g., ``dt.FLOAT``).
258+
returns : str, optional
259+
Specifies the return data type of the function. If not specified,
260+
the type annotation from the function is used.
261+
262+
Returns
263+
-------
264+
Callable
265+
266+
"""
267+
return _func(
268+
func=func,
269+
name=name,
270+
args=args,
271+
returns=returns,
272+
with_null_masks=True,
145273
function_type='udf',
146274
)
147275

@@ -152,10 +280,57 @@ def tvf(
152280
name: Optional[str] = None,
153281
args: Optional[ParameterType] = None,
154282
returns: Optional[ReturnType] = None,
155-
include_masks: bool = False,
156283
) -> Callable[..., Any]:
157284
"""
158-
Apply attributes to a TVF.
285+
Define a table-valued function (TVF).
286+
287+
Parameters
288+
----------
289+
func : callable, optional
290+
The TVF to apply parameters to
291+
name : str, optional
292+
The name to use for the TVF in the database
293+
args : str | Callable | List[str | Callable], optional
294+
Specifies the data types of the function arguments. Typically,
295+
the function data types are derived from the function parameter
296+
annotations. These annotations can be overridden. If the function
297+
takes a single type for all parameters, `args` can be set to a
298+
SQL string describing all parameters. If the function takes more
299+
than one parameter and all of the parameters are being manually
300+
defined, a list of SQL strings may be used (one for each parameter).
301+
A dictionary of SQL strings may be used to specify a parameter type
302+
for a subset of parameters; the keys are the names of the
303+
function parameters. Callables may also be used for datatypes. This
304+
is primarily for using the functions in the ``dtypes`` module that
305+
are associated with SQL types with all default options (e.g., ``dt.FLOAT``).
306+
returns : str, optional
307+
Specifies the return data type of the function. If not specified,
308+
the type annotation from the function is used.
309+
310+
Returns
311+
-------
312+
Callable
313+
314+
"""
315+
return _func(
316+
func=func,
317+
name=name,
318+
args=args,
319+
returns=returns,
320+
with_null_masks=False,
321+
function_type='tvf',
322+
)
323+
324+
325+
def tvf_with_null_masks(
326+
func: Optional[Callable[..., Any]] = None,
327+
*,
328+
name: Optional[str] = None,
329+
args: Optional[ParameterType] = None,
330+
returns: Optional[ReturnType] = None,
331+
) -> Callable[..., Any]:
332+
"""
333+
Define a table-valued function (TVF) using null masks.
159334
160335
Parameters
161336
----------
@@ -179,10 +354,6 @@ def tvf(
179354
returns : str, optional
180355
Specifies the return data type of the function. If not specified,
181356
the type annotation from the function is used.
182-
include_masks : bool, optional
183-
Should boolean masks be included with each input parameter to indicate
184-
which elements are NULL? This is only used when a input parameters are
185-
configured to a vector type (numpy, pandas, polars, arrow).
186357
187358
Returns
188359
-------
@@ -194,6 +365,6 @@ def tvf(
194365
name=name,
195366
args=args,
196367
returns=returns,
197-
include_masks=include_masks,
368+
with_null_masks=True,
198369
function_type='tvf',
199370
)

0 commit comments

Comments
 (0)