1- from __future__ import annotations
2-
31import functools
42import inspect
3+ import typing
54from typing import Any
65from typing import Callable
76from typing import List
87from typing import Optional
98from typing import Type
109from typing import Union
1110
11+ from . import utils
12+ from .dtypes import SQLString
13+
1214
1315ParameterType = Union [
1416 str ,
15- Callable [..., str ],
16- List [Union [str , Callable [..., str ]]],
17+ Callable [..., SQLString ],
18+ List [Union [str , Callable [..., SQLString ]]],
1719 Type [Any ],
1820]
1921
2022ReturnType = ParameterType
2123
2224
def is_valid_type(obj: Any) -> bool:
    """
    Report whether ``obj`` is a class usable as a schema definition.

    Only classes are considered. A class is accepted when one of the
    ``utils`` predicates ``is_typeddict``, ``is_namedtuple``,
    ``is_dataclass`` or ``is_pydantic`` accepts it.

    Parameters
    ----------
    obj : Any
        The object to inspect

    Returns
    -------
    bool
        True for an accepted class, False for everything else

    """
    # Anything that is not a class can never describe a schema.
    if not inspect.isclass(obj):
        return False

    # The predicates are tried in this order and evaluation stops at the
    # first one that accepts the class. Pydantic models are detected via
    # a ``utils`` helper so that pydantic is not imported here.
    return bool(
        utils.is_typeddict(obj)
        or utils.is_namedtuple(obj)
        or utils.is_dataclass(obj)
        or utils.is_pydantic(obj),
    )
45+
46+
def is_valid_callable(obj: Any) -> bool:
    """
    Check if the object is a valid callable for a parameter type.

    A callable is valid when its ``return`` annotation is ``str`` or a
    subclass of ``str``. String (postponed) annotations are not resolved,
    so a callable whose return annotation is still a string is rejected.

    Parameters
    ----------
    obj : Any
        The object to inspect

    Returns
    -------
    bool
        False if ``obj`` is not callable; True if it is a callable with
        a ``str`` (or ``str`` subclass) return annotation

    Raises
    ------
    TypeError
        If ``obj`` is callable but its return annotation is missing or
        is not a ``str`` subclass

    """
    if not callable(obj):
        return False

    # ``inspect.get_annotations`` only exists on Python 3.10+; fall back
    # to the raw ``__annotations__`` mapping on older interpreters.
    get_annotations = getattr(inspect, 'get_annotations', None)
    if get_annotations is not None:
        annotations = get_annotations(obj)
    else:
        annotations = getattr(obj, '__annotations__', None) or {}

    returns = annotations.get('return', None)

    if inspect.isclass(returns) and issubclass(returns, str):
        return True

    # Callables are rejected loudly rather than by returning False so
    # the caller sees which callable and annotation were at fault.
    raise TypeError(
        f'callable {obj} must return a str, '
        f'but got {returns}',
    )
61+
62+
def verify_mask(obj: Any) -> bool:
    """
    Verify that the object is a tuple annotation of two vector types.

    Parameters
    ----------
    obj : Any
        The annotation to check; expected to be a two-element tuple
        annotation whose elements both satisfy ``utils.is_vector``

    Returns
    -------
    bool
        Always True when the check passes; failures raise instead

    Raises
    ------
    TypeError
        If ``obj`` is not a tuple annotation with exactly two arguments,
        or if either argument is not a vector type

    """
    args = typing.get_args(obj)

    if typing.get_origin(obj) is not tuple or len(args) != 2:
        # Report the annotation itself: ``type(obj)`` would only show the
        # typing machinery class (e.g. ``typing._GenericAlias``).
        raise TypeError(
            f'Expected a tuple of two vector types, but got {obj}',
        )

    if not utils.is_vector(args[0]):
        raise TypeError(
            f'Expected a vector type for the first element, but got {args[0]}',
        )

    if not utils.is_vector(args[1]):
        raise TypeError(
            f'Expected a vector type for the second element, but got {args[1]}',
        )

    return True
83+
84+
def verify_masks(obj: Callable[..., Any]) -> bool:
    """
    Verify that the function parameters and return value are all masks.

    Every annotation reported by ``utils.get_annotations`` for ``obj``
    is checked with ``verify_mask``.

    Parameters
    ----------
    obj : Callable
        The function whose annotations are checked

    Returns
    -------
    bool
        Always True when every annotation passes; failures raise instead

    Raises
    ------
    TypeError
        If any annotation fails ``verify_mask``. The message names the
        offending parameter and the function, and the original error is
        chained as the cause.

    """
    # Not every callable has ``__name__`` (e.g. ``functools.partial``).
    func_name = getattr(obj, '__name__', repr(obj))

    for name, value in utils.get_annotations(obj).items():
        cause = None
        try:
            valid = verify_mask(value)
        except TypeError as exc:
            # ``verify_mask`` signals failure by raising, so the error is
            # caught here to add the parameter and function context.
            valid = False
            cause = exc

        if not valid:
            raise TypeError(
                f'Expected a vector type for the parameter {name} '
                f'in function {func_name}, but got {value}',
            ) from cause

    return True
95+
96+
2397def expand_types (args : Any ) -> Optional [Union [List [str ], Type [Any ]]]:
2498 """Expand the types for the function arguments / return values."""
2599 if args is None :
@@ -30,18 +104,11 @@ def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
30104 return [args ]
31105
32106 # General way of accepting pydantic.BaseModel, NamedTuple, TypedDict
33- elif inspect . isclass (args ):
107+ elif is_valid_type (args ):
34108 return args
35109
36- # Callable that returns a SQL string
37- elif callable (args ):
38- out = args ()
39- if not isinstance (out , str ):
40- raise TypeError (f'unrecognized type for parameter: { args } ' )
41- return [out ]
42-
43110 # List of SQL strings or callables
44- else :
111+ elif isinstance ( args , list ) :
45112 new_args = []
46113 for arg in args :
47114 if isinstance (arg , str ):
@@ -52,14 +119,23 @@ def expand_types(args: Any) -> Optional[Union[List[str], Type[Any]]]:
52119 raise TypeError (f'unrecognized type for parameter: { arg } ' )
53120 return new_args
54121
122+ # Callable that returns a SQL string
123+ elif is_valid_callable (args ):
124+ out = args ()
125+ if not isinstance (out , str ):
126+ raise TypeError (f'unrecognized type for parameter: { args } ' )
127+ return [out ]
128+
129+ raise TypeError (f'unrecognized type for parameter: { args } ' )
130+
55131
56132def _func (
57133 func : Optional [Callable [..., Any ]] = None ,
58134 * ,
59135 name : Optional [str ] = None ,
60136 args : Optional [ParameterType ] = None ,
61137 returns : Optional [ReturnType ] = None ,
62- include_masks : bool = False ,
138+ with_null_masks : bool = False ,
63139 function_type : str = 'udf' ,
64140) -> Callable [..., Any ]:
65141 """Generic wrapper for UDF and TVF decorators."""
@@ -69,7 +145,7 @@ def _func(
69145 name = name ,
70146 args = expand_types (args ),
71147 returns = expand_types (returns ),
72- include_masks = include_masks ,
148+ with_null_masks = with_null_masks ,
73149 function_type = function_type ,
74150 ).items () if v is not None
75151 }
@@ -79,12 +155,21 @@ def _func(
79155 # in at that time.
80156 if func is None :
81157 def decorate (func : Callable [..., Any ]) -> Callable [..., Any ]:
158+ if with_null_masks :
159+ verify_masks (func )
160+
82161 def wrapper (* args : Any , ** kwargs : Any ) -> Callable [..., Any ]:
83162 return func (* args , ** kwargs ) # type: ignore
163+
84164 wrapper ._singlestoredb_attrs = _singlestoredb_attrs # type: ignore
165+
85166 return functools .wraps (func )(wrapper )
167+
86168 return decorate
87169
170+ if with_null_masks :
171+ verify_masks (func )
172+
88173 def wrapper (* args : Any , ** kwargs : Any ) -> Callable [..., Any ]:
89174 return func (* args , ** kwargs ) # type: ignore
90175
@@ -99,10 +184,9 @@ def udf(
99184 name : Optional [str ] = None ,
100185 args : Optional [ParameterType ] = None ,
101186 returns : Optional [ReturnType ] = None ,
102- include_masks : bool = False ,
103187) -> Callable [..., Any ]:
104188 """
105- Apply attributes to a UDF.
189+ Define a user-defined function ( UDF) .
106190
107191 Parameters
108192 ----------
@@ -126,10 +210,6 @@ def udf(
126210 returns : str, optional
127211 Specifies the return data type of the function. If not specified,
128212 the type annotation from the function is used.
129- include_masks : bool, optional
130- Should boolean masks be included with each input parameter to indicate
131- which elements are NULL? This is only used when a input parameters are
132- configured to a vector type (numpy, pandas, polars, arrow).
133213
134214 Returns
135215 -------
@@ -141,7 +221,55 @@ def udf(
141221 name = name ,
142222 args = args ,
143223 returns = returns ,
144- include_masks = include_masks ,
224+ with_null_masks = False ,
225+ function_type = 'udf' ,
226+ )
227+
228+
def udf_with_null_masks(
    func: Optional[Callable[..., Any]] = None,
    *,
    name: Optional[str] = None,
    args: Optional[ParameterType] = None,
    returns: Optional[ReturnType] = None,
) -> Callable[..., Any]:
    """
    Define a user-defined function (UDF) with null masks.

    This forwards to ``_func`` with ``function_type='udf'`` and
    ``with_null_masks=True``. With that flag set, ``_func`` runs
    ``verify_masks`` on the decorated function.

    Parameters
    ----------
    func : callable, optional
        The UDF to apply parameters to. When omitted, a decorator is
        returned so the function can be supplied later.
    name : str, optional
        The name to use for the UDF in the database
    args : str | Callable | List[str | Callable], optional
        Overrides the argument data types that are otherwise derived
        from the function parameter annotations. A single SQL string
        describes all parameters; a list of SQL strings gives one type
        per parameter. Callables may be used in place of SQL strings,
        primarily the functions in the ``dtypes`` module that map to SQL
        types with all default options (e.g., ``dt.FLOAT``).
        NOTE(review): earlier documentation also described a dictionary
        keyed by parameter name; ``ParameterType`` does not list a dict,
        so confirm whether that form is still supported.
    returns : str, optional
        Specifies the return data type of the function. If not specified,
        the type annotation from the function is used.

    Returns
    -------
    Callable

    """
    return _func(
        func,
        name=name,
        args=args,
        returns=returns,
        with_null_masks=True,
        function_type='udf',
    )
147275
@@ -152,10 +280,57 @@ def tvf(
152280 name : Optional [str ] = None ,
153281 args : Optional [ParameterType ] = None ,
154282 returns : Optional [ReturnType ] = None ,
155- include_masks : bool = False ,
156283) -> Callable [..., Any ]:
157284 """
158- Apply attributes to a TVF.
285+ Define a table-valued function (TVF).
286+
287+ Parameters
288+ ----------
289+ func : callable, optional
290+ The TVF to apply parameters to
291+ name : str, optional
292+ The name to use for the TVF in the database
293+ args : str | Callable | List[str | Callable], optional
294+ Specifies the data types of the function arguments. Typically,
295+ the function data types are derived from the function parameter
296+ annotations. These annotations can be overridden. If the function
297+ takes a single type for all parameters, `args` can be set to a
298+ SQL string describing all parameters. If the function takes more
299+ than one parameter and all of the parameters are being manually
300+ defined, a list of SQL strings may be used (one for each parameter).
301+ A dictionary of SQL strings may be used to specify a parameter type
302+ for a subset of parameters; the keys are the names of the
303+ function parameters. Callables may also be used for datatypes. This
304+ is primarily for using the functions in the ``dtypes`` module that
305+ are associated with SQL types with all default options (e.g., ``dt.FLOAT``).
306+ returns : str, optional
307+ Specifies the return data type of the function. If not specified,
308+ the type annotation from the function is used.
309+
310+ Returns
311+ -------
312+ Callable
313+
314+ """
315+ return _func (
316+ func = func ,
317+ name = name ,
318+ args = args ,
319+ returns = returns ,
320+ with_null_masks = False ,
321+ function_type = 'tvf' ,
322+ )
323+
324+
325+ def tvf_with_null_masks (
326+ func : Optional [Callable [..., Any ]] = None ,
327+ * ,
328+ name : Optional [str ] = None ,
329+ args : Optional [ParameterType ] = None ,
330+ returns : Optional [ReturnType ] = None ,
331+ ) -> Callable [..., Any ]:
332+ """
333+ Define a table-valued function (TVF) using null masks.
159334
160335 Parameters
161336 ----------
@@ -179,10 +354,6 @@ def tvf(
179354 returns : str, optional
180355 Specifies the return data type of the function. If not specified,
181356 the type annotation from the function is used.
182- include_masks : bool, optional
183- Should boolean masks be included with each input parameter to indicate
184- which elements are NULL? This is only used when a input parameters are
185- configured to a vector type (numpy, pandas, polars, arrow).
186357
187358 Returns
188359 -------
@@ -194,6 +365,6 @@ def tvf(
194365 name = name ,
195366 args = args ,
196367 returns = returns ,
197- include_masks = include_masks ,
368+ with_null_masks = True ,
198369 function_type = 'tvf' ,
199370 )
0 commit comments