From f03a2d1d4c31f9aa8aeab200d76b82ab9d8d1aec Mon Sep 17 00:00:00 2001 From: Geoffrey Claude Date: Fri, 9 Jan 2026 11:59:30 +0100 Subject: [PATCH] Refactor InListExpr to use modular StaticFilter architecture Introduces the StaticFilter trait to decouple membership testing from InListExpr. Migrates existing HashSet optimizations into primitive_filter.rs to maintain performance parity while enabling future specialized implementations. Triggers for all constant IN lists. (cherry picked from commit 797b7fc155d7c830717b1305815b75ddece057cb) --- .../physical-expr/src/expressions/in_list.rs | 501 +----------------- .../src/expressions/in_list/nested_filter.rs | 175 ++++++ .../expressions/in_list/primitive_filter.rs | 283 ++++++++++ .../src/expressions/in_list/result.rs | 116 ++++ .../src/expressions/in_list/static_filter.rs | 49 ++ .../src/expressions/in_list/strategy.rs | 50 ++ 6 files changed, 683 insertions(+), 491 deletions(-) create mode 100644 datafusion/physical-expr/src/expressions/in_list/nested_filter.rs create mode 100644 datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs create mode 100644 datafusion/physical-expr/src/expressions/in_list/result.rs create mode 100644 datafusion/physical-expr/src/expressions/in_list/static_filter.rs create mode 100644 datafusion/physical-expr/src/expressions/in_list/strategy.rs diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 2377edf9375cf..04c1459ba4eb7 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -26,29 +26,24 @@ use crate::physical_expr::physical_exprs_bag_equal; use arrow::array::*; use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::compute::SortOptions; use arrow::compute::kernels::boolean::{not, or_kleene}; use arrow::compute::kernels::cmp::eq as arrow_eq; -use arrow::compute::{SortOptions, take}; use arrow::datatypes::*; -use arrow::util::bit_iterator::BitIndexIterator; -use datafusion_common::hash_utils::with_hashes; + use datafusion_common::{ - DFSchema, HashSet, Result, ScalarValue, assert_or_internal_err, exec_datafusion_err, - exec_err, + DFSchema, Result, ScalarValue, assert_or_internal_err, exec_err, }; use datafusion_expr::{ColumnarValue, expr_vec_fmt}; -use datafusion_common::HashMap; -use datafusion_common::hash_utils::RandomState; -use hashbrown::hash_map::RawEntryMut; - -/// Trait for InList static filters -trait StaticFilter { - fn null_count(&self) -> usize; +mod nested_filter; +mod primitive_filter; +mod result; +mod static_filter; +mod strategy; - /// Checks if values in `v` are contained in the filter - fn contains(&self, v: &dyn Array, negated: bool) -> Result; -} +use static_filter::StaticFilter; +use strategy::instantiate_static_filter; /// InList pub struct InListExpr { @@ -68,89 +63,7 @@ impl Debug for InListExpr { } } -/// Static filter for InList that stores the array and hash set for O(1) lookups -#[derive(Debug, Clone)] -struct ArrayStaticFilter { - in_array: ArrayRef, - state: RandomState, - /// Used to provide a lookup from value to in list index - /// - /// Note: usize::hash is not used, instead the raw entry - /// API is used to store entries w.r.t their value - map: HashMap, -} - -impl StaticFilter for ArrayStaticFilter { - fn null_count(&self) -> usize { - self.in_array.null_count() - } - - /// Checks if values in `v` are contained in the `in_array` using this hash set for lookup. - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Null type comparisons always return null (SQL three-valued logic) - if v.data_type() == &DataType::Null - || self.in_array.data_type() == &DataType::Null - { - let nulls = NullBuffer::new_null(v.len()); - return Ok(BooleanArray::new( - BooleanBuffer::new_unset(v.len()), - Some(nulls), - )); - } - - // Unwrap dictionary-encoded needles when the value type matches - // in_array, evaluating against the dictionary values and mapping - // back via keys. - downcast_dictionary_array! { - v => { - // Only unwrap when the haystack (in_array) type matches - // the dictionary value type - if v.values().data_type() == self.in_array.data_type() { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())); - } - } - _ => {} - } - - let needle_nulls = v.logical_nulls(); - let needle_nulls = needle_nulls.as_ref(); - let haystack_has_nulls = self.in_array.null_count() != 0; - - with_hashes([v], &self.state, |hashes| { - let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; - Ok((0..v.len()) - .map(|i| { - // SQL three-valued logic: null IN (...) is always null - if needle_nulls.is_some_and(|nulls| nulls.is_null(i)) { - return None; - } - - let hash = hashes[i]; - let contains = self - .map - .raw_entry() - .from_hash(hash, |idx| cmp(i, *idx).is_eq()) - .is_some(); - - match contains { - true => Some(!negated), - false if haystack_has_nulls => None, - false => Some(negated), - } - }) - .collect()) - }) - } -} - /// Returns true if Arrow's vectorized `eq` kernel supports this data type. -/// -/// Supported: primitives, boolean, strings (Utf8/LargeUtf8/Utf8View), -/// binary (Binary/LargeBinary/BinaryView/FixedSizeBinary), Null, and -/// Dictionary-encoded variants of the above. -/// Unsupported: nested types (Struct, List, Map, Union) and RunEndEncoded. fn supports_arrow_eq(dt: &DataType) -> bool { use DataType::*; match dt { @@ -160,400 +73,6 @@ fn supports_arrow_eq(dt: &DataType) -> bool { } } -fn instantiate_static_filter( - in_array: ArrayRef, -) -> Result> { - match in_array.data_type() { - // Integer primitive types - DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), - DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)), - DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), - DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), - DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)), - DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)), - DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), - DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), - // Float primitive types (use ordered wrappers for Hash/Eq) - DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)), - DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)), - _ => { - /* fall through to generic implementation for unsupported types (Struct, etc.) */ - Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) - } - } -} - -impl ArrayStaticFilter { - /// Computes a [`StaticFilter`] for the provided [`Array`] if there - /// are nulls present or there are more than the configured number of - /// elements. - /// - /// Note: This is split into a separate function as higher-rank trait bounds currently - /// cause type inference to misbehave - fn try_new(in_array: ArrayRef) -> Result { - // Null type has no natural order - return empty hash set - if in_array.data_type() == &DataType::Null { - return Ok(ArrayStaticFilter { - in_array, - state: RandomState::default(), - map: HashMap::with_hasher(()), - }); - } - - let state = RandomState::default(); - let mut map: HashMap = HashMap::with_hasher(()); - - with_hashes([&in_array], &state, |hashes| -> Result<()> { - let cmp = make_comparator(&in_array, &in_array, SortOptions::default())?; - - let insert_value = |idx| { - let hash = hashes[idx]; - if let RawEntryMut::Vacant(v) = map - .raw_entry_mut() - .from_hash(hash, |x| cmp(*x, idx).is_eq()) - { - v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); - } - }; - - match in_array.nulls() { - Some(nulls) => { - BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) - .for_each(insert_value) - } - None => (0..in_array.len()).for_each(insert_value), - } - - Ok(()) - })?; - - Ok(Self { - in_array, - state, - map, - }) - } -} - -/// Wrapper for f32 that implements Hash and Eq using bit comparison. -/// This treats NaN values as equal to each other when they have the same bit pattern. -#[derive(Clone, Copy)] -struct OrderedFloat32(f32); - -impl Hash for OrderedFloat32 { - fn hash(&self, state: &mut H) { - self.0.to_ne_bytes().hash(state); - } -} - -impl PartialEq for OrderedFloat32 { - fn eq(&self, other: &Self) -> bool { - self.0.to_bits() == other.0.to_bits() - } -} - -impl Eq for OrderedFloat32 {} - -impl From for OrderedFloat32 { - fn from(v: f32) -> Self { - Self(v) - } -} - -/// Wrapper for f64 that implements Hash and Eq using bit comparison. -/// This treats NaN values as equal to each other when they have the same bit pattern. -#[derive(Clone, Copy)] -struct OrderedFloat64(f64); - -impl Hash for OrderedFloat64 { - fn hash(&self, state: &mut H) { - self.0.to_ne_bytes().hash(state); - } -} - -impl PartialEq for OrderedFloat64 { - fn eq(&self, other: &Self) -> bool { - self.0.to_bits() == other.0.to_bits() - } -} - -impl Eq for OrderedFloat64 {} - -impl From for OrderedFloat64 { - fn from(v: f64) -> Self { - Self(v) - } -} - -// Macro to generate specialized StaticFilter implementations for primitive types -macro_rules! primitive_static_filter { - ($Name:ident, $ArrowType:ty) => { - struct $Name { - null_count: usize, - values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>, - } - - impl $Name { - fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let mut values = HashSet::with_capacity(in_array.len()); - let null_count = in_array.null_count(); - - for v in in_array.iter().flatten() { - values.insert(v); - } - - Ok(Self { null_count, values }) - } - } - - impl StaticFilter for $Name { - fn null_count(&self) -> usize { - self.null_count - } - - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Handle dictionary arrays by recursing on the values - downcast_dictionary_array! { - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } - - let v = v - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let haystack_has_nulls = self.null_count > 0; - - let needle_values = v.values(); - let needle_nulls = v.nulls(); - let needle_has_nulls = v.null_count() > 0; - - // Truth table for `value [NOT] IN (set)` with SQL three-valued logic: - // ("-" means the value doesn't affect the result) - // - // | needle_null | haystack_null | negated | in set? | result | - // |-------------|---------------|---------|---------|--------| - // | true | - | false | - | null | - // | true | - | true | - | null | - // | false | true | false | yes | true | - // | false | true | false | no | null | - // | false | true | true | yes | false | - // | false | true | true | no | null | - // | false | false | false | yes | true | - // | false | false | false | no | false | - // | false | false | true | yes | false | - // | false | false | true | no | true | - - // Compute the "contains" result using collect_bool (fast batched approach) - // This ignores nulls - we handle them separately - let contains_buffer = if negated { - BooleanBuffer::collect_bool(needle_values.len(), |i| { - !self.values.contains(&needle_values[i]) - }) - } else { - BooleanBuffer::collect_bool(needle_values.len(), |i| { - self.values.contains(&needle_values[i]) - }) - }; - - // Compute the null mask - // Output is null when: - // 1. needle value is null, OR - // 2. needle value is not in set AND haystack has nulls - let result_nulls = match (needle_has_nulls, haystack_has_nulls) { - (false, false) => { - // No nulls anywhere - None - } - (true, false) => { - // Only needle has nulls - just use needle's null mask - needle_nulls.cloned() - } - (false, true) => { - // Only haystack has nulls - result is null when value not in set - // Valid (not null) when original "in set" is true - // For NOT IN: contains_buffer = !original, so validity = !contains_buffer - let validity = if negated { - !&contains_buffer - } else { - contains_buffer.clone() - }; - Some(NullBuffer::new(validity)) - } - (true, true) => { - // Both have nulls - combine needle nulls with haystack-induced nulls - let needle_validity = needle_nulls.map(|n| n.inner().clone()) - .unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len())); - - // Valid when original "in set" is true (see above) - let haystack_validity = if negated { - !&contains_buffer - } else { - contains_buffer.clone() - }; - - // Combined validity: valid only where both are valid - let combined_validity = &needle_validity & &haystack_validity; - Some(NullBuffer::new(combined_validity)) - } - }; - - Ok(BooleanArray::new(contains_buffer, result_nulls)) - } - } - }; -} - -// Generate specialized filters for all integer primitive types -primitive_static_filter!(Int8StaticFilter, Int8Type); -primitive_static_filter!(Int16StaticFilter, Int16Type); -primitive_static_filter!(Int32StaticFilter, Int32Type); -primitive_static_filter!(Int64StaticFilter, Int64Type); -primitive_static_filter!(UInt8StaticFilter, UInt8Type); -primitive_static_filter!(UInt16StaticFilter, UInt16Type); -primitive_static_filter!(UInt32StaticFilter, UInt32Type); -primitive_static_filter!(UInt64StaticFilter, UInt64Type); - -// Macro to generate specialized StaticFilter implementations for float types -// Floats require a wrapper type (OrderedFloat*) to implement Hash/Eq due to NaN semantics -macro_rules! float_static_filter { - ($Name:ident, $ArrowType:ty, $OrderedType:ty) => { - struct $Name { - null_count: usize, - values: HashSet<$OrderedType>, - } - - impl $Name { - fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let mut values = HashSet::with_capacity(in_array.len()); - let null_count = in_array.null_count(); - - for v in in_array.iter().flatten() { - values.insert(<$OrderedType>::from(v)); - } - - Ok(Self { null_count, values }) - } - } - - impl StaticFilter for $Name { - fn null_count(&self) -> usize { - self.null_count - } - - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Handle dictionary arrays by recursing on the values - downcast_dictionary_array! { - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } - - let v = v - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let haystack_has_nulls = self.null_count > 0; - - let needle_values = v.values(); - let needle_nulls = v.nulls(); - let needle_has_nulls = v.null_count() > 0; - - // Truth table for `value [NOT] IN (set)` with SQL three-valued logic: - // ("-" means the value doesn't affect the result) - // - // | needle_null | haystack_null | negated | in set? | result | - // |-------------|---------------|---------|---------|--------| - // | true | - | false | - | null | - // | true | - | true | - | null | - // | false | true | false | yes | true | - // | false | true | false | no | null | - // | false | true | true | yes | false | - // | false | true | true | no | null | - // | false | false | false | yes | true | - // | false | false | false | no | false | - // | false | false | true | yes | false | - // | false | false | true | no | true | - - // Compute the "contains" result using collect_bool (fast batched approach) - // This ignores nulls - we handle them separately - let contains_buffer = if negated { - BooleanBuffer::collect_bool(needle_values.len(), |i| { - !self.values.contains(&<$OrderedType>::from(needle_values[i])) - }) - } else { - BooleanBuffer::collect_bool(needle_values.len(), |i| { - self.values.contains(&<$OrderedType>::from(needle_values[i])) - }) - }; - - // Compute the null mask - // Output is null when: - // 1. needle value is null, OR - // 2. needle value is not in set AND haystack has nulls - let result_nulls = match (needle_has_nulls, haystack_has_nulls) { - (false, false) => { - // No nulls anywhere - None - } - (true, false) => { - // Only needle has nulls - just use needle's null mask - needle_nulls.cloned() - } - (false, true) => { - // Only haystack has nulls - result is null when value not in set - // Valid (not null) when original "in set" is true - // For NOT IN: contains_buffer = !original, so validity = !contains_buffer - let validity = if negated { - !&contains_buffer - } else { - contains_buffer.clone() - }; - Some(NullBuffer::new(validity)) - } - (true, true) => { - // Both have nulls - combine needle nulls with haystack-induced nulls - let needle_validity = needle_nulls.map(|n| n.inner().clone()) - .unwrap_or_else(|| BooleanBuffer::new_set(needle_values.len())); - - // Valid when original "in set" is true (see above) - let haystack_validity = if negated { - !&contains_buffer - } else { - contains_buffer.clone() - }; - - // Combined validity: valid only where both are valid - let combined_validity = &needle_validity & &haystack_validity; - Some(NullBuffer::new(combined_validity)) - } - }; - - Ok(BooleanArray::new(contains_buffer, result_nulls)) - } - } - }; -} - -// Generate specialized filters for float types using ordered wrappers -float_static_filter!(Float32StaticFilter, Float32Type, OrderedFloat32); -float_static_filter!(Float64StaticFilter, Float64Type, OrderedFloat64); - /// Evaluates the list of expressions into an array, flattening any dictionaries fn evaluate_list( list: &[Arc], diff --git a/datafusion/physical-expr/src/expressions/in_list/nested_filter.rs b/datafusion/physical-expr/src/expressions/in_list/nested_filter.rs new file mode 100644 index 0000000000000..903bb6abe43b2 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/nested_filter.rs @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Fallback filter for nested/complex types (List, Struct, Map, Union, etc.) + +use arrow::array::{ + Array, ArrayRef, BooleanArray, downcast_array, downcast_dictionary_array, + make_comparator, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::compute::{SortOptions, take}; +use arrow::datatypes::DataType; +use arrow::util::bit_iterator::BitIndexIterator; +use datafusion_common::Result; +use datafusion_common::hash_utils::with_hashes; + +use datafusion_common::hash_utils::RandomState; +use hashbrown::HashTable; + +use super::result::build_in_list_result; +use super::static_filter::StaticFilter; + +/// Fallback filter for nested/complex types (List, Struct, Map, Union, etc.) +/// +/// Uses dynamic comparator via `make_comparator` since these types don't have +/// a simple typed comparison. For primitive and byte array types, use the +/// specialized filters instead (PrimitiveFilter, ByteArrayFilter, etc.) +#[derive(Debug, Clone)] +pub(crate) struct NestedTypeFilter { + in_array: ArrayRef, + state: RandomState, + /// Stores indices into `in_array` for O(1) lookups. + table: HashTable, +} + +impl NestedTypeFilter { + /// Creates a filter for nested/complex array types. + /// + /// This filter uses dynamic comparison and should only be used for types + /// that don't have specialized filters (List, Struct, Map, Union). + pub(crate) fn try_new(in_array: ArrayRef) -> Result { + // Null type has no natural order - return empty hash set + if in_array.data_type() == &DataType::Null { + return Ok(Self { + in_array, + state: RandomState::default(), + table: HashTable::new(), + }); + } + + let state = RandomState::default(); + let table = Self::build_haystack_table(&in_array, &state)?; + + Ok(Self { + in_array, + state, + table, + }) + } + + /// Build a hash table from haystack values for O(1) lookups. + /// + /// Each unique non-null value's index is stored, keyed by its hash. + /// Uses dynamic comparison via `make_comparator` for complex types. + fn build_haystack_table( + haystack: &ArrayRef, + state: &RandomState, + ) -> Result> { + let mut table = HashTable::new(); + + with_hashes([haystack.as_ref()], state, |hashes| -> Result<()> { + let cmp = make_comparator(haystack, haystack, SortOptions::default())?; + + let insert_value = |idx| { + let hash = hashes[idx]; + // Only insert if not already present (deduplication) + if table.find(hash, |&x| cmp(x, idx).is_eq()).is_none() { + table.insert_unique(hash, idx, |&x| hashes[x]); + } + }; + + match haystack.nulls() { + Some(nulls) => { + BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) + .for_each(insert_value) + } + None => (0..haystack.len()).for_each(insert_value), + } + + Ok(()) + })?; + + Ok(table) + } + + /// Check which needle values exist in the haystack. + /// + /// Hashes each needle value and looks it up in the pre-built haystack table. + /// Uses dynamic comparison via `make_comparator` for complex types. + fn find_needles_in_haystack( + &self, + needles: &dyn Array, + negated: bool, + ) -> Result { + let needle_nulls = needles.logical_nulls(); + let haystack_has_nulls = self.in_array.null_count() != 0; + + with_hashes([needles], &self.state, |needle_hashes| { + let cmp = make_comparator(needles, &self.in_array, SortOptions::default())?; + + Ok(build_in_list_result( + needles.len(), + needle_nulls.as_ref(), + haystack_has_nulls, + negated, + #[inline(always)] + |i| { + let hash = needle_hashes[i]; + self.table.find(hash, |&idx| cmp(i, idx).is_eq()).is_some() + }, + )) + }) + } +} + +impl StaticFilter for NestedTypeFilter { + fn null_count(&self) -> usize { + self.in_array.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Null type comparisons always return null (SQL three-valued logic) + if v.data_type() == &DataType::Null + || self.in_array.data_type() == &DataType::Null + { + let nulls = NullBuffer::new_null(v.len()); + return Ok(BooleanArray::new( + BooleanBuffer::new_unset(v.len()), + Some(nulls), + )); + } + + // Unwrap dictionary-encoded needles when the value type matches + // in_array, evaluating against the dictionary values and mapping + // back via keys. + downcast_dictionary_array! { + v => { + // Only unwrap when the haystack (in_array) type matches + // the dictionary value type + if v.values().data_type() == self.in_array.data_type() { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())); + } + } + _ => {} + } + + self.find_needles_in_haystack(v, negated) + } +} diff --git a/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs new file mode 100644 index 0000000000000..b5367a9f686a8 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/primitive_filter.rs @@ -0,0 +1,283 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Specialized primitive type filters for InList expressions + +use arrow::array::{ + Array, ArrayRef, AsArray, BooleanArray, downcast_array, downcast_dictionary_array, +}; +use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::compute::take; +use arrow::datatypes::*; +use datafusion_common::{HashSet, Result, exec_datafusion_err}; +use std::hash::{Hash, Hasher}; + +use super::static_filter::StaticFilter; + +/// Wrapper for f32 that implements Hash and Eq using bit comparison. +#[derive(Clone, Copy)] +pub(crate) struct OrderedFloat32(pub(crate) f32); + +impl Hash for OrderedFloat32 { + fn hash(&self, state: &mut H) { + self.0.to_ne_bytes().hash(state); + } +} + +impl PartialEq for OrderedFloat32 { + fn eq(&self, other: &Self) -> bool { + self.0.to_bits() == other.0.to_bits() + } +} + +impl Eq for OrderedFloat32 {} + +impl From for OrderedFloat32 { + fn from(v: f32) -> Self { + Self(v) + } +} + +/// Wrapper for f64 that implements Hash and Eq using bit comparison. +#[derive(Clone, Copy)] +pub(crate) struct OrderedFloat64(pub(crate) f64); + +impl Hash for OrderedFloat64 { + fn hash(&self, state: &mut H) { + self.0.to_ne_bytes().hash(state); + } +} + +impl PartialEq for OrderedFloat64 { + fn eq(&self, other: &Self) -> bool { + self.0.to_bits() == other.0.to_bits() + } +} + +impl Eq for OrderedFloat64 {} + +impl From for OrderedFloat64 { + fn from(v: f64) -> Self { + Self(v) + } +} + +macro_rules! primitive_static_filter { + ($Name:ident, $ArrowType:ty) => { + pub(crate) struct $Name { + null_count: usize, + values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>, + } + + impl $Name { + pub(crate) fn try_new(in_array: &ArrayRef) -> Result { + let in_array = in_array.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; + + let mut values = HashSet::with_capacity(in_array.len()); + let null_count = in_array.null_count(); + + for v in in_array.iter().flatten() { + values.insert(v); + } + + Ok(Self { null_count, values }) + } + } + + impl StaticFilter for $Name { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + downcast_dictionary_array! { + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + + let v = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; + + let haystack_has_nulls = self.null_count > 0; + let needle_values = v.values(); + let needle_nulls = v.nulls(); + let needle_has_nulls = v.null_count() > 0; + + let contains_buffer = if negated { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + !self.values.contains(&needle_values[i]) + }) + } else { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + self.values.contains(&needle_values[i]) + }) + }; + + let result_nulls = match (needle_has_nulls, haystack_has_nulls) { + (false, false) => None, + (true, false) => needle_nulls.cloned(), + (false, true) => { + let validity = if negated { + !&contains_buffer + } else { + contains_buffer.clone() + }; + Some(NullBuffer::new(validity)) + } + (true, true) => { + let needle_validity = needle_nulls + .map(|n| n.inner().clone()) + .unwrap_or_else(|| { + BooleanBuffer::new_set(needle_values.len()) + }); + let haystack_validity = if negated { + !&contains_buffer + } else { + contains_buffer.clone() + }; + let combined_validity = &needle_validity & &haystack_validity; + Some(NullBuffer::new(combined_validity)) + } + }; + + Ok(BooleanArray::new(contains_buffer, result_nulls)) + } + } + }; +} + +primitive_static_filter!(Int8StaticFilter, Int8Type); +primitive_static_filter!(Int16StaticFilter, Int16Type); +primitive_static_filter!(Int32StaticFilter, Int32Type); +primitive_static_filter!(Int64StaticFilter, Int64Type); +primitive_static_filter!(UInt8StaticFilter, UInt8Type); +primitive_static_filter!(UInt16StaticFilter, UInt16Type); +primitive_static_filter!(UInt32StaticFilter, UInt32Type); +primitive_static_filter!(UInt64StaticFilter, UInt64Type); + +macro_rules! float_static_filter { + ($Name:ident, $ArrowType:ty, $OrderedType:ty) => { + pub(crate) struct $Name { + null_count: usize, + values: HashSet<$OrderedType>, + } + + impl $Name { + pub(crate) fn try_new(in_array: &ArrayRef) -> Result { + let in_array = in_array.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; + + let mut values = HashSet::with_capacity(in_array.len()); + let null_count = in_array.null_count(); + + for v in in_array.iter().flatten() { + values.insert(<$OrderedType>::from(v)); + } + + Ok(Self { null_count, values }) + } + } + + impl StaticFilter for $Name { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + downcast_dictionary_array! { + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + + let v = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; + + let haystack_has_nulls = self.null_count > 0; + let needle_values = v.values(); + let needle_nulls = v.nulls(); + let needle_has_nulls = v.null_count() > 0; + + let contains_buffer = if negated { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + !self.values.contains(&<$OrderedType>::from(needle_values[i])) + }) + } else { + BooleanBuffer::collect_bool(needle_values.len(), |i| { + self.values.contains(&<$OrderedType>::from(needle_values[i])) + }) + }; + + let result_nulls = match (needle_has_nulls, haystack_has_nulls) { + (false, false) => None, + (true, false) => needle_nulls.cloned(), + (false, true) => { + let validity = if negated { + !&contains_buffer + } else { + contains_buffer.clone() + }; + Some(NullBuffer::new(validity)) + } + (true, true) => { + let needle_validity = needle_nulls + .map(|n| n.inner().clone()) + .unwrap_or_else(|| { + BooleanBuffer::new_set(needle_values.len()) + }); + let haystack_validity = if negated { + !&contains_buffer + } else { + contains_buffer.clone() + }; + let combined_validity = &needle_validity & &haystack_validity; + Some(NullBuffer::new(combined_validity)) + } + }; + + Ok(BooleanArray::new(contains_buffer, result_nulls)) + } + } + }; +} + +float_static_filter!(Float32StaticFilter, Float32Type, OrderedFloat32); +float_static_filter!(Float64StaticFilter, Float64Type, OrderedFloat64); diff --git a/datafusion/physical-expr/src/expressions/in_list/result.rs b/datafusion/physical-expr/src/expressions/in_list/result.rs new file mode 100644 index 0000000000000..9ee20a7cc9707 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/result.rs @@ -0,0 +1,116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Result building helpers for InList operations +//! +//! This module provides unified logic for building BooleanArray results +//! from IN list membership tests, handling null propagation correctly +//! according to SQL three-valued logic. + +use arrow::array::BooleanArray; +use arrow::buffer::{BooleanBuffer, NullBuffer}; + +// ============================================================================= +// RESULT BUILDER FOR IN LIST OPERATIONS +// ============================================================================= +// +// Truth table for (needle_nulls, haystack_has_nulls, negated): +// (Some, true, false) → values: valid & contains, nulls: valid & contains +// (None, true, false) → values: contains, nulls: contains +// (Some, true, true) → values: valid ^ (valid & contains), nulls: valid & contains +// (None, true, true) → values: !contains, nulls: contains +// (Some, false, false) → values: valid & contains, nulls: valid +// (Some, false, true) → values: valid & !contains, nulls: valid +// (None, false, false) → values: contains, nulls: none +// (None, false, true) → values: !contains, nulls: none + +/// Builds a BooleanArray result for IN list operations (optimized for cheap contains). +/// +/// This function handles the complex null propagation logic for SQL IN lists: +/// - If the needle value is null, the result is null +/// - If the needle is not in the set AND the haystack has nulls, the result is null +/// - Otherwise, the result is true/false based on membership and negation +/// +/// This version computes contains for ALL positions (including nulls), then applies +/// null masking via bitmap operations. This is optimal for cheap contains checks +/// (like DirectProbeFilter) where the branch overhead exceeds the check cost. +#[inline] +pub(crate) fn build_in_list_result( + len: usize, + needle_nulls: Option<&NullBuffer>, + haystack_has_nulls: bool, + negated: bool, + contains: C, +) -> BooleanArray +where + C: FnMut(usize) -> bool, +{ + // Always compute the contains buffer without checking nulls in the loop. + // The null check inside the loop hurts vectorization and branch prediction. + // Nulls are handled by build_result_from_contains using bitmap operations. + let contains_buf = BooleanBuffer::collect_bool(len, contains); + build_result_from_contains(needle_nulls, haystack_has_nulls, negated, contains_buf) +} + +/// Builds a BooleanArray result from a pre-computed contains buffer. +/// +/// This version does NOT assume contains_buf is pre-masked at null positions. +/// It handles nulls using bitmap operations which are more vectorization-friendly. +#[inline] +pub(crate) fn build_result_from_contains( + needle_nulls: Option<&NullBuffer>, + haystack_has_nulls: bool, + negated: bool, + contains_buf: BooleanBuffer, +) -> BooleanArray { + match (needle_nulls, haystack_has_nulls, negated) { + // Haystack has nulls: result is null unless value is found + (Some(v), true, false) => { + // values: valid & contains, nulls: valid & contains + // Result is valid (not null) only when needle is valid AND found in haystack + let values = v.inner() & &contains_buf; + BooleanArray::new(values.clone(), Some(NullBuffer::new(values))) + } + (None, true, false) => { + BooleanArray::new(contains_buf.clone(), Some(NullBuffer::new(contains_buf))) + } + (Some(v), true, true) => { + // NOT IN with nulls: true if valid and not found, null if found or needle null + // values: valid & !contains, nulls: valid & contains + // Result is valid only when needle is valid AND found (because NOT IN with + // haystack nulls returns NULL when value isn't definitively excluded) + let valid = v.inner(); + let values = valid & &(!&contains_buf); + let nulls = valid & &contains_buf; + BooleanArray::new(values, Some(NullBuffer::new(nulls))) + } + (None, true, true) => { + BooleanArray::new(!&contains_buf, Some(NullBuffer::new(contains_buf))) + } + // Haystack has no nulls: result validity follows needle validity + (Some(v), false, false) => { + // values: valid & contains (mask out nulls), nulls: valid + BooleanArray::new(v.inner() & &contains_buf, Some(v.clone())) + } + (Some(v), false, true) => { + // values: valid & !contains, nulls: valid + BooleanArray::new(v.inner() & &(!&contains_buf), Some(v.clone())) + } + (None, false, false) => BooleanArray::new(contains_buf, None), + (None, false, true) => BooleanArray::new(!&contains_buf, None), + } +} diff --git a/datafusion/physical-expr/src/expressions/in_list/static_filter.rs b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs new file mode 100644 index 0000000000000..9dbc00d35125c --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/static_filter.rs @@ -0,0 +1,49 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Core trait for InList static filters + +use arrow::array::{Array, BooleanArray}; +use datafusion_common::Result; + +/// Trait for InList static filters. +/// +/// Static filters are pre-computed lookup structures that enable efficient +/// membership testing for IN list expressions. Different implementations +/// optimize for different data types: +/// +/// - [`super::primitive_filter::BitmapFilter`]: O(1) bit test for u8/u16 +/// - [`super::primitive_filter::BranchlessFilter`]: Unrolled OR-chain for small lists +/// - [`super::primitive_filter::DirectProbeFilter`]: O(1) hash lookups for larger primitive types +/// - [`super::transform::Utf8TwoStageFilter`]: Two-stage filter for Utf8/LargeUtf8 +/// - [`super::nested_filter::NestedTypeFilter`]: Dynamic comparator for complex types +pub(crate) trait StaticFilter { + /// Returns the number of null values in the filter's haystack. + fn null_count(&self) -> usize; + + /// Checks if values in `v` are contained in the filter. + /// + /// Returns a `BooleanArray` with the same length as `v`, where each element + /// indicates whether the corresponding value is in the filter (or NOT in, + /// if `negated` is true). + /// + /// Follows SQL three-valued logic: + /// - If the needle value is null, the result is null + /// - If the needle is not found AND the haystack contains nulls, the result is null + /// - Otherwise, the result is true/false based on membership + fn contains(&self, v: &dyn Array, negated: bool) -> Result; +} diff --git a/datafusion/physical-expr/src/expressions/in_list/strategy.rs b/datafusion/physical-expr/src/expressions/in_list/strategy.rs new file mode 100644 index 0000000000000..e59798df62158 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/in_list/strategy.rs @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Filter selection strategy for InList expressions + +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; +use datafusion_common::Result; + +use super::nested_filter::NestedTypeFilter; +use super::primitive_filter::*; +use super::static_filter::StaticFilter; + +/// Creates the optimal static filter for the given array. +/// +/// This is the main entry point for filter creation. It analyzes the array's +/// data type and size to select the best lookup strategy. +pub(crate) fn instantiate_static_filter( + in_array: ArrayRef, +) -> Result> { + match in_array.data_type() { + DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), + DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)), + DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), + DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), + DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)), + DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)), + DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), + DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), + DataType::Float32 => Ok(Arc::new(Float32StaticFilter::try_new(&in_array)?)), + DataType::Float64 => Ok(Arc::new(Float64StaticFilter::try_new(&in_array)?)), + _ => Ok(Arc::new(NestedTypeFilter::try_new(in_array)?)), + } +}