diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 4f73169ad2827..9bb180cdd93a5 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2496,6 +2496,19 @@ impl Filter { Self::try_new_internal(predicate, input) } + /// Create a new filter operator without re-validating the predicate type. + /// + /// This is intended for internal optimizer use when rearranging predicates + /// that are already known to be valid filter expressions. Like + /// [`Self::try_new`], this removes nested aliases from the predicate. + #[doc(hidden)] + pub fn new_unchecked(predicate: Expr, input: Arc<LogicalPlan>) -> Self { + Self { + predicate: predicate.unalias_nested().data, + input, + } + } + fn is_allowed_filter_type(data_type: &DataType) -> bool { match data_type { // Interpret NULL as a missing boolean value. @@ -2520,10 +2533,7 @@ impl Filter { ); } - Ok(Self { - predicate: predicate.unalias_nested().data, - input, - }) + Ok(Self::new_unchecked(predicate, input)) } /// Is this filter guaranteed to return 0 or 1 row in a given instantiation? 
@@ -5163,6 +5173,43 @@ mod tests { assert!(filter.is_scalar()); } + #[test] + fn test_filter_new_unchecked_strips_aliases() { + let scan = Arc::new( + table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0])) + .unwrap() + .build() + .unwrap(), + ); + + let predicate = col("id").alias("employee_id").eq(lit(1i32)).alias("pred"); + let unchecked = Filter::new_unchecked(predicate, scan); + + assert_eq!(unchecked.predicate, col("id").eq(lit(1i32))); + } + + #[test] + fn test_filter_new_unchecked_skips_type_validation() { + let scan = Arc::new( + table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0])) + .unwrap() + .build() + .unwrap(), + ); + + let predicate = col("id") + lit(1i32); + + let err = Filter::try_new(predicate.clone(), Arc::clone(&scan)).unwrap_err(); + assert!( + err.to_string() + .contains("Cannot create filter with non-boolean predicate"), + "{err}" + ); + + let unchecked = Filter::new_unchecked(predicate.clone(), scan); + assert_eq!(unchecked.predicate, predicate); + } + #[test] fn test_transform_explain() { let schema = Schema::new(vec![ diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index a1a636cfef9af..6320737fbcac6 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -496,11 +496,10 @@ fn push_down_all_join( } if let Some(predicate) = conjunction(left_push) { - join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?)); + join.left = Arc::new(make_filter(predicate, join.left)?); } if let Some(predicate) = conjunction(right_push) { - join.right = - Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?)); + join.right = Arc::new(make_filter(predicate, join.right)?); } // Add any new join conditions as the non join predicates @@ -510,7 +509,7 @@ fn push_down_all_join( // wrap the join on the filter whose predicates must be kept, if any let plan = LogicalPlan::Join(join); let 
plan = if let Some(predicate) = conjunction(keep_predicates) { - LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?) + make_filter(predicate, Arc::new(plan))? } else { plan }; @@ -825,28 +824,21 @@ impl OptimizerRule for PushDownFilter { let Some(new_predicate) = conjunction(new_predicates) else { return plan_err!("at least one expression exists"); }; - let new_filter = LogicalPlan::Filter(Filter::try_new( - new_predicate, - child_filter.input, - )?); + let new_filter = make_filter(new_predicate, child_filter.input)?; self.rewrite(new_filter, config) } LogicalPlan::Repartition(repartition) => { let new_filter = - Filter::try_new(filter.predicate, Arc::clone(&repartition.input)) - .map(LogicalPlan::Filter)?; + make_filter(filter.predicate, Arc::clone(&repartition.input))?; insert_below(LogicalPlan::Repartition(repartition), new_filter) } LogicalPlan::Distinct(distinct) => { let new_filter = - Filter::try_new(filter.predicate, Arc::clone(distinct.input())) - .map(LogicalPlan::Filter)?; + make_filter(filter.predicate, Arc::clone(distinct.input()))?; insert_below(LogicalPlan::Distinct(distinct), new_filter) } LogicalPlan::Sort(sort) => { - let new_filter = - Filter::try_new(filter.predicate, Arc::clone(&sort.input)) - .map(LogicalPlan::Filter)?; + let new_filter = make_filter(filter.predicate, Arc::clone(&sort.input))?; insert_below(LogicalPlan::Sort(sort), new_filter) } LogicalPlan::SubqueryAlias(subquery_alias) => { @@ -863,10 +855,8 @@ impl OptimizerRule for PushDownFilter { } let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?; - let new_filter = LogicalPlan::Filter(Filter::try_new( - new_predicate, - Arc::clone(&subquery_alias.input), - )?); + let new_filter = + make_filter(new_predicate, Arc::clone(&subquery_alias.input))?; insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter) } LogicalPlan::Projection(projection) => { @@ -877,8 +867,7 @@ impl OptimizerRule for PushDownFilter { match keep_predicate { None => 
Ok(new_projection), Some(keep_predicate) => new_projection.map_data(|child_plan| { - Filter::try_new(keep_predicate, Arc::new(child_plan)) - .map(LogicalPlan::Filter) + make_filter(keep_predicate, Arc::new(child_plan)) }), } } else { @@ -944,10 +933,10 @@ impl OptimizerRule for PushDownFilter { let unnest_input = std::mem::take(&mut unnest.input); - let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new( + let filter_with_unnest_input = make_filter( conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty. unnest_input, - )?); + )?; // Directly assign new filter plan as the new unnest's input. // The new filter plan will go through another rewrite pass since the rule itself @@ -957,9 +946,10 @@ impl OptimizerRule for PushDownFilter { match conjunction(unnest_predicates) { None => Ok(unnest_plan), - Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter( - Filter::try_new(predicate, Arc::new(unnest_plan.data))?, - ))), + Some(predicate) => Ok(Transformed::yes(make_filter( + predicate, + Arc::new(unnest_plan.data), + )?)), } } LogicalPlan::Union(ref union) => { @@ -977,10 +967,7 @@ impl OptimizerRule for PushDownFilter { let push_predicate = replace_cols_by_name(filter.predicate.clone(), &replace_map)?; - inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new( - push_predicate, - Arc::clone(input), - )?))) + inputs.push(Arc::new(make_filter(push_predicate, Arc::clone(input))?)) } Ok(Transformed::yes(LogicalPlan::Union(Union { inputs, @@ -1250,10 +1237,7 @@ impl OptimizerRule for PushDownFilter { .inputs() .into_iter() .map(|child| { - Ok(LogicalPlan::Filter(Filter::try_new( - predicate.clone(), - Arc::new(child.clone()), - )?)) + make_filter(predicate.clone(), Arc::new(child.clone())) }) .collect::<Result<Vec<_>>>()?, None => extension_plan.node.inputs().into_iter().cloned().collect(), @@ -1264,10 +1248,7 @@ impl OptimizerRule for PushDownFilter { child_plan.with_new_exprs(child_plan.expressions(), new_children)?; 
let new_plan = match conjunction(keep_predicates) { - Some(predicate) => LogicalPlan::Filter(Filter::try_new( - predicate, - Arc::new(new_extension), - )?), + Some(predicate) => make_filter(predicate, Arc::new(new_extension))?, None => new_extension, }; Ok(Transformed::yes(new_plan)) @@ -1346,10 +1327,10 @@ fn rewrite_projection( Some(expr) => { // re-write all filters based on this projection // E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1" - let new_filter = LogicalPlan::Filter(Filter::try_new( + let new_filter = make_filter( replace_cols_by_name(expr, &pushable_map)?, std::mem::take(&mut projection.input), - )?); + )?; projection.input = Arc::new(new_filter); @@ -1365,9 +1346,12 @@ fn rewrite_projection( } } -/// Creates a new LogicalPlan::Filter node. -pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> { - Filter::try_new(predicate, input).map(LogicalPlan::Filter) +/// Creates a filter node without re-validating predicate type. +fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> { + // PushDownFilter only rebuilds predicates that already came from validated + // filter/join expressions, so re-running full boolean type validation here + // only recomputes expensive expression schemas. + Ok(LogicalPlan::Filter(Filter::new_unchecked(predicate, input))) } /// Replace the existing child of the single input node with `new_child`.