diff --git a/docs/source/user-guide/common-operations/udf-and-udfa.rst b/docs/source/user-guide/common-operations/udf-and-udfa.rst
index d554e1e25..feed436b2 100644
--- a/docs/source/user-guide/common-operations/udf-and-udfa.rst
+++ b/docs/source/user-guide/common-operations/udf-and-udfa.rst
@@ -160,6 +160,17 @@ also see how the inputs to ``update`` and ``merge`` differ.
 
     df.aggregate([], [my_udaf(col("a"), col("b")).alias("col_diff")])
 
+FAQ
+^^^
+
+**How do I return a list from a UDAF?**
+Use a list-valued scalar and declare list types for both the return and state
+definitions. Returning a ``pyarrow.Array`` from ``evaluate`` is not supported
+unless you convert it to a list scalar. For example, in ``evaluate`` you can
+return ``pa.scalar([...], type=pa.list_(pa.timestamp("ms")))`` and register the
+UDAF with ``return_type=pa.list_(pa.timestamp("ms"))`` and
+``state_type=[pa.list_(pa.timestamp("ms"))]``.
+
 Window Functions
 ----------------
 
diff --git a/python/datafusion/user_defined.py b/python/datafusion/user_defined.py
index 5dd626568..d4ebfe049 100644
--- a/python/datafusion/user_defined.py
+++ b/python/datafusion/user_defined.py
@@ -310,7 +310,23 @@ def merge(self, states: list[pa.Array]) -> None:
 
     @abstractmethod
     def evaluate(self) -> pa.Scalar:
-        """Return the resultant value."""
+        """Return the resultant value.
+
+        If you need to return a list, wrap it in a scalar with the correct
+        list type, for example::
+
+            from datetime import datetime
+
+            import pyarrow as pa
+
+            return pa.scalar(
+                [datetime(2024, 1, 1)],
+                type=pa.list_(pa.timestamp("ms")),
+            )
+
+        Returning a ``pyarrow.Array`` from ``evaluate`` is not supported unless
+        you explicitly convert it to a list-valued scalar.
+        """
 
 
 class AggregateUDFExportable(Protocol):
diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
index 453ff6f4f..cfbbbca1c 100644
--- a/python/tests/test_udaf.py
+++ b/python/tests/test_udaf.py
@@ -17,6 +17,8 @@
 
 from __future__ import annotations
 
+from datetime import datetime, timezone
+
 import pyarrow as pa
 import pyarrow.compute as pc
 import pytest
@@ -58,6 +60,25 @@ def state(self) -> list[pa.Scalar]:
         return [self._sum]
 
 
+class CollectTimestamps(Accumulator):
+    def __init__(self):
+        self._values: list[datetime] = []
+
+    def state(self) -> list[pa.Scalar]:
+        return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))]
+
+    def update(self, values: pa.Array) -> None:
+        self._values.extend(values.to_pylist())
+
+    def merge(self, states: list[pa.Array]) -> None:
+        for state in states[0].to_pylist():
+            if state is not None:
+                self._values.extend(state)
+
+    def evaluate(self) -> pa.Scalar:
+        return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))
+
+
 @pytest.fixture
 def df(ctx):
     # create a RecordBatch and a new DataFrame from it
@@ -217,3 +238,30 @@ def test_register_udaf(ctx, df) -> None:
     df_result = ctx.sql("select summarize(b) from test_table")
 
     assert df_result.collect()[0][0][0].as_py() == 14.0
+
+
+def test_udaf_list_timestamp_return(ctx) -> None:
+    timestamps = [
+        datetime(2024, 1, 1, tzinfo=timezone.utc),
+        datetime(2024, 1, 2, tzinfo=timezone.utc),
+    ]
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array(timestamps, type=pa.timestamp("ns"))],
+        names=["ts"],
+    )
+    df = ctx.create_dataframe([[batch]], name="timestamp_table")
+
+    collect = udaf(
+        CollectTimestamps,
+        pa.timestamp("ns"),
+        pa.list_(pa.timestamp("ns")),
+        [pa.list_(pa.timestamp("ns"))],
+        volatility="immutable",
+    )
+
+    result = df.aggregate([], [collect(column("ts"))]).collect()[0]
+
+    assert result.column(0) == pa.array(
+        [timestamps],
+        type=pa.list_(pa.timestamp("ns")),
+    )
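The documentation, docstring, and test above all describe one pattern. For readers of this patch, here is a minimal end-to-end sketch of that pattern; it assumes only the public datafusion-python API (SessionContext, Accumulator, udaf, column), and the CollectValues name and float element type are illustrative, not part of the change:

    import pyarrow as pa
    from datafusion import Accumulator, SessionContext, column, udaf

    class CollectValues(Accumulator):
        """Collect every input value; emit them as one list scalar."""

        def __init__(self) -> None:
            self._values: list[float] = []

        def update(self, values: pa.Array) -> None:
            self._values.extend(values.to_pylist())

        def merge(self, states: list[pa.Array]) -> None:
            # Each state row holds the list produced by state() on
            # another partition.
            for state in states[0].to_pylist():
                if state is not None:
                    self._values.extend(state)

        def state(self) -> list[pa.Scalar]:
            # Must match the state_type declared when registering the UDAF.
            return [pa.scalar(self._values, type=pa.list_(pa.float64()))]

        def evaluate(self) -> pa.Scalar:
            # A list-valued scalar, not a pyarrow.Array, as the FAQ explains.
            return pa.scalar(self._values, type=pa.list_(pa.float64()))

    ctx = SessionContext()
    batch = pa.RecordBatch.from_arrays([pa.array([1.0, 2.0, 3.0])], names=["a"])
    df = ctx.create_dataframe([[batch]])

    collect = udaf(
        CollectValues,
        pa.float64(),              # input type
        pa.list_(pa.float64()),    # return_type: a list
        [pa.list_(pa.float64())],  # state_type: one list-typed state field
        volatility="immutable",
    )
    print(df.aggregate([], [collect(column("a"))]).collect()[0])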
diff --git a/src/udaf.rs b/src/udaf.rs
index 262366a8a..883170adf 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -17,7 +17,7 @@
 
 use std::sync::Arc;
 
-use datafusion::arrow::array::{Array, ArrayRef};
+use datafusion::arrow::array::ArrayRef;
 use datafusion::arrow::datatypes::DataType;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::common::ScalarValue;
@@ -32,7 +32,7 @@ use pyo3::types::{PyCapsule, PyTuple};
 
 use crate::common::data_type::PyScalarValue;
 use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
 use crate::expr::PyExpr;
-use crate::utils::{parse_volatility, validate_pycapsule};
+use crate::utils::{parse_volatility, py_obj_to_scalar_value, validate_pycapsule};
 
 #[derive(Debug)]
 struct RustAccumulator {
@@ -47,24 +47,30 @@ impl RustAccumulator {
 }
 
 impl Accumulator for RustAccumulator {
     fn state(&mut self) -> Result<Vec<ScalarValue>> {
-        Python::attach(|py| {
-            self.accum
-                .bind(py)
-                .call_method0("state")?
-                .extract::<Vec<PyScalarValue>>()
+        Python::attach(|py| -> PyResult<Vec<ScalarValue>> {
+            let values = self.accum.bind(py).call_method0("state")?;
+            let mut scalars = Vec::new();
+            for item in values.try_iter()? {
+                let item: Bound<'_, PyAny> = item?;
+                let scalar = match item.extract::<PyScalarValue>() {
+                    Ok(py_scalar) => py_scalar.0,
+                    Err(_) => py_obj_to_scalar_value(py, item.unbind())?,
+                };
+                scalars.push(scalar);
+            }
+            Ok(scalars)
         })
-        .map(|v| v.into_iter().map(|x| x.0).collect())
         .map_err(|e| DataFusionError::Execution(format!("{e}")))
     }
 
     fn evaluate(&mut self) -> Result<ScalarValue> {
-        Python::attach(|py| {
-            self.accum
-                .bind(py)
-                .call_method0("evaluate")?
-                .extract::<PyScalarValue>()
+        Python::attach(|py| -> PyResult<ScalarValue> {
+            let value = self.accum.bind(py).call_method0("evaluate")?;
+            match value.extract::<PyScalarValue>() {
+                Ok(py_scalar) => Ok(py_scalar.0),
+                Err(_) => py_obj_to_scalar_value(py, value.unbind()),
+            }
         })
-        .map(|v| v.0)
         .map_err(|e| DataFusionError::Execution(format!("{e}")))
     }
@@ -73,7 +79,7 @@ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
         Python::attach(|py| {
             // 1. cast args to Pyarrow array
             let py_args = values
                 .iter()
-                .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
+                .map(|arg| arg.to_data().to_pyarrow(py).unwrap())
                 .collect::<Vec<_>>();
             let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;
@@ -94,7 +100,7 @@ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
             .iter()
             .map(|state| {
                 state
-                    .into_data()
+                    .to_data()
                     .to_pyarrow(py)
                     .map_err(|e| DataFusionError::Execution(format!("{e}")))
             })
@@ -119,7 +125,7 @@ fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
         Python::attach(|py| {
             // 1. cast args to Pyarrow array
             let py_args = values
                 .iter()
-                .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
+                .map(|arg| arg.to_data().to_pyarrow(py).unwrap())
                 .collect::<Vec<_>>();
             let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;
@@ -144,7 +150,7 @@ impl Accumulator for RustAccumulator {
 }
 
 pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
-    Arc::new(move |_| -> Result<Box<dyn Accumulator>> {
+    Arc::new(move |_args| -> Result<Box<dyn Accumulator>> {
         let accum = Python::attach(|py| {
             accum
                 .call0(py)
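Both fallbacks above funnel into py_obj_to_scalar_value, shown next. As a rough Python analogue of the conversion order it implements; to_scalar_like is a hypothetical helper for illustration only, not an API in this patch:

    import pyarrow as pa

    def to_scalar_like(obj):
        """Hypothetical mirror of the Rust-side conversion order."""
        if isinstance(obj, pa.Scalar):
            return obj  # already a scalar: the strict extraction succeeds
        if isinstance(obj, pa.ChunkedArray):
            # The Rust code calls pyarrow's combine_chunks(); concat_arrays
            # is the unambiguous equivalent for this sketch.
            obj = pa.concat_arrays(obj.chunks)
        if isinstance(obj, pa.Array):
            # Wrap the whole array as a single list-valued scalar.
            return pa.scalar(obj.to_pylist(), type=pa.list_(obj.type))
        return pa.scalar(obj)  # plain Python values go through pa.scalar()

    assert to_scalar_like(pa.scalar(1.5)).as_py() == 1.5
    assert to_scalar_like(pa.array([1, 2, 3])).as_py() == [1, 2, 3]
    assert to_scalar_like(42).as_py() == 42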
diff --git a/src/utils.rs b/src/utils.rs
index eede34907..3b97ffb88 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -19,6 +19,10 @@
 use std::future::Future;
 use std::sync::{Arc, OnceLock};
 use std::time::Duration;
+use datafusion::arrow::array::{make_array, ArrayData, ListArray};
+use datafusion::arrow::buffer::{OffsetBuffer, ScalarBuffer};
+use datafusion::arrow::datatypes::Field;
+use datafusion::arrow::pyarrow::FromPyArrow;
 use datafusion::common::ScalarValue;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::context::SessionContext;
@@ -203,6 +207,41 @@
 
 pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<ScalarValue> {
     let pa = py.import("pyarrow")?;
+    let scalar_attr = pa.getattr("Scalar")?;
+    let scalar_type = scalar_attr.downcast::<PyType>()?;
+    let array_attr = pa.getattr("Array")?;
+    let array_type = array_attr.downcast::<PyType>()?;
+    let chunked_array_attr = pa.getattr("ChunkedArray")?;
+    let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
+
+    let obj_ref = obj.bind(py);
+
+    if obj_ref.is_instance(scalar_type)? {
+        let py_scalar = PyScalarValue::extract_bound(obj_ref)
+            .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
+        return Ok(py_scalar.into());
+    }
+
+    if obj_ref.is_instance(array_type)? || obj_ref.is_instance(chunked_array_type)? {
+        let array_obj = if obj_ref.is_instance(chunked_array_type)? {
+            obj_ref.call_method0("combine_chunks")?.unbind()
+        } else {
+            obj_ref.clone().unbind()
+        };
+        let array_bound = array_obj.bind(py);
+        let array_data = ArrayData::from_pyarrow_bound(array_bound)
+            .map_err(|e| PyValueError::new_err(format!("Failed to extract pyarrow array: {e}")))?;
+        let array = make_array(array_data);
+        let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, array.len() as i32]));
+        let list_array = Arc::new(ListArray::new(
+            Arc::new(Field::new_list_field(array.data_type().clone(), true)),
+            offsets,
+            array,
+            None,
+        ));
+
+        return Ok(ScalarValue::List(list_array));
+    }
 
     // Convert Python object to PyArrow scalar
     let scalar = pa.call_method1("scalar", (obj,))?;
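To make the ListArray construction above concrete: a flat array of length n is wrapped with offsets [0, n], producing a one-row list array whose single row becomes the ScalarValue::List that is returned. A pyarrow sketch of the same shape, illustrative only and not part of the patch:

    import pyarrow as pa

    flat = pa.array([1, 2, 3], type=pa.int64())
    # One list row spanning the entire flat array: offsets [0, len(flat)].
    offsets = pa.array([0, len(flat)], type=pa.int32())
    as_list = pa.ListArray.from_arrays(offsets, flat)
    assert as_list.to_pylist() == [[1, 2, 3]]
    # The single row plays the role of the returned list scalar.
    assert as_list[0].as_py() == [1, 2, 3]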