Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 51 additions & 4 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2496,6 +2496,19 @@ impl Filter {
Self::try_new_internal(predicate, input)
}

/// Create a new filter operator without re-validating the predicate type.
///
/// Intended for internal optimizer use when rearranging predicates that are
/// already known to be valid filter expressions. Like [`Self::try_new`],
/// nested aliases are stripped from the predicate.
#[doc(hidden)]
pub fn new_unchecked(predicate: Expr, input: Arc<LogicalPlan>) -> Self {
    // Normalize the predicate the same way `try_new` does: remove any
    // nested aliases so consumers see the bare expression.
    let predicate = predicate.unalias_nested().data;
    Self { predicate, input }
}

fn is_allowed_filter_type(data_type: &DataType) -> bool {
match data_type {
// Interpret NULL as a missing boolean value.
Expand All @@ -2520,10 +2533,7 @@ impl Filter {
);
}

Ok(Self {
predicate: predicate.unalias_nested().data,
input,
})
Ok(Self::new_unchecked(predicate, input))
}

/// Is this filter guaranteed to return 0 or 1 row in a given instantiation?
Expand Down Expand Up @@ -5163,6 +5173,43 @@ mod tests {
assert!(filter.is_scalar());
}

#[test]
fn test_filter_new_unchecked_strips_aliases() {
    // Build a simple table scan to serve as the filter input.
    let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0]))
        .unwrap()
        .build()
        .unwrap();

    // Both the inner and the outer alias must be removed by `new_unchecked`.
    let aliased = col("id").alias("employee_id").eq(lit(1i32)).alias("pred");
    let filter = Filter::new_unchecked(aliased, Arc::new(plan));

    assert_eq!(filter.predicate, col("id").eq(lit(1i32)));
}

#[test]
fn test_filter_new_unchecked_skips_type_validation() {
    let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![0]))
        .unwrap()
        .build()
        .unwrap();
    let input = Arc::new(plan);

    // A numeric (non-boolean) predicate: rejected by `try_new`...
    let predicate = col("id") + lit(1i32);
    let err = Filter::try_new(predicate.clone(), Arc::clone(&input)).unwrap_err();
    assert!(
        err.to_string()
            .contains("Cannot create filter with non-boolean predicate"),
        "{err}"
    );

    // ...but accepted verbatim by `new_unchecked`, which skips validation.
    let filter = Filter::new_unchecked(predicate.clone(), input);
    assert_eq!(filter.predicate, predicate);
}

#[test]
fn test_transform_explain() {
let schema = Schema::new(vec![
Expand Down
70 changes: 27 additions & 43 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -496,11 +496,10 @@ fn push_down_all_join(
}

if let Some(predicate) = conjunction(left_push) {
join.left = Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.left)?));
join.left = Arc::new(make_filter(predicate, join.left)?);
}
if let Some(predicate) = conjunction(right_push) {
join.right =
Arc::new(LogicalPlan::Filter(Filter::try_new(predicate, join.right)?));
join.right = Arc::new(make_filter(predicate, join.right)?);
}

// Add any new join conditions as the non join predicates
Expand All @@ -510,7 +509,7 @@ fn push_down_all_join(
// wrap the join on the filter whose predicates must be kept, if any
let plan = LogicalPlan::Join(join);
let plan = if let Some(predicate) = conjunction(keep_predicates) {
LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(plan))?)
make_filter(predicate, Arc::new(plan))?
} else {
plan
};
Expand Down Expand Up @@ -825,28 +824,21 @@ impl OptimizerRule for PushDownFilter {
let Some(new_predicate) = conjunction(new_predicates) else {
return plan_err!("at least one expression exists");
};
let new_filter = LogicalPlan::Filter(Filter::try_new(
new_predicate,
child_filter.input,
)?);
let new_filter = make_filter(new_predicate, child_filter.input)?;
self.rewrite(new_filter, config)
}
LogicalPlan::Repartition(repartition) => {
let new_filter =
Filter::try_new(filter.predicate, Arc::clone(&repartition.input))
.map(LogicalPlan::Filter)?;
make_filter(filter.predicate, Arc::clone(&repartition.input))?;
insert_below(LogicalPlan::Repartition(repartition), new_filter)
}
LogicalPlan::Distinct(distinct) => {
let new_filter =
Filter::try_new(filter.predicate, Arc::clone(distinct.input()))
.map(LogicalPlan::Filter)?;
make_filter(filter.predicate, Arc::clone(distinct.input()))?;
insert_below(LogicalPlan::Distinct(distinct), new_filter)
}
LogicalPlan::Sort(sort) => {
let new_filter =
Filter::try_new(filter.predicate, Arc::clone(&sort.input))
.map(LogicalPlan::Filter)?;
let new_filter = make_filter(filter.predicate, Arc::clone(&sort.input))?;
insert_below(LogicalPlan::Sort(sort), new_filter)
}
LogicalPlan::SubqueryAlias(subquery_alias) => {
Expand All @@ -863,10 +855,8 @@ impl OptimizerRule for PushDownFilter {
}
let new_predicate = replace_cols_by_name(filter.predicate, &replace_map)?;

let new_filter = LogicalPlan::Filter(Filter::try_new(
new_predicate,
Arc::clone(&subquery_alias.input),
)?);
let new_filter =
make_filter(new_predicate, Arc::clone(&subquery_alias.input))?;
insert_below(LogicalPlan::SubqueryAlias(subquery_alias), new_filter)
}
LogicalPlan::Projection(projection) => {
Expand All @@ -877,8 +867,7 @@ impl OptimizerRule for PushDownFilter {
match keep_predicate {
None => Ok(new_projection),
Some(keep_predicate) => new_projection.map_data(|child_plan| {
Filter::try_new(keep_predicate, Arc::new(child_plan))
.map(LogicalPlan::Filter)
make_filter(keep_predicate, Arc::new(child_plan))
}),
}
} else {
Expand Down Expand Up @@ -944,10 +933,10 @@ impl OptimizerRule for PushDownFilter {

let unnest_input = std::mem::take(&mut unnest.input);

let filter_with_unnest_input = LogicalPlan::Filter(Filter::try_new(
let filter_with_unnest_input = make_filter(
conjunction(non_unnest_predicates).unwrap(), // Safe to unwrap since non_unnest_predicates is not empty.
unnest_input,
)?);
)?;

// Directly assign new filter plan as the new unnest's input.
// The new filter plan will go through another rewrite pass since the rule itself
Expand All @@ -957,9 +946,10 @@ impl OptimizerRule for PushDownFilter {

match conjunction(unnest_predicates) {
None => Ok(unnest_plan),
Some(predicate) => Ok(Transformed::yes(LogicalPlan::Filter(
Filter::try_new(predicate, Arc::new(unnest_plan.data))?,
))),
Some(predicate) => Ok(Transformed::yes(make_filter(
predicate,
Arc::new(unnest_plan.data),
)?)),
}
}
LogicalPlan::Union(ref union) => {
Expand All @@ -977,10 +967,7 @@ impl OptimizerRule for PushDownFilter {

let push_predicate =
replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
push_predicate,
Arc::clone(input),
)?)))
inputs.push(Arc::new(make_filter(push_predicate, Arc::clone(input))?))
}
Ok(Transformed::yes(LogicalPlan::Union(Union {
inputs,
Expand Down Expand Up @@ -1250,10 +1237,7 @@ impl OptimizerRule for PushDownFilter {
.inputs()
.into_iter()
.map(|child| {
Ok(LogicalPlan::Filter(Filter::try_new(
predicate.clone(),
Arc::new(child.clone()),
)?))
make_filter(predicate.clone(), Arc::new(child.clone()))
})
.collect::<Result<Vec<_>>>()?,
None => extension_plan.node.inputs().into_iter().cloned().collect(),
Expand All @@ -1264,10 +1248,7 @@ impl OptimizerRule for PushDownFilter {
child_plan.with_new_exprs(child_plan.expressions(), new_children)?;

let new_plan = match conjunction(keep_predicates) {
Some(predicate) => LogicalPlan::Filter(Filter::try_new(
predicate,
Arc::new(new_extension),
)?),
Some(predicate) => make_filter(predicate, Arc::new(new_extension))?,
None => new_extension,
};
Ok(Transformed::yes(new_plan))
Expand Down Expand Up @@ -1346,10 +1327,10 @@ fn rewrite_projection(
Some(expr) => {
// re-write all filters based on this projection
// E.g. in `Filter: b\n Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
let new_filter = LogicalPlan::Filter(Filter::try_new(
let new_filter = make_filter(
replace_cols_by_name(expr, &pushable_map)?,
std::mem::take(&mut projection.input),
)?);
)?;

projection.input = Arc::new(new_filter);

Expand All @@ -1365,9 +1346,12 @@ fn rewrite_projection(
}
}

/// Creates a new LogicalPlan::Filter node.
pub fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
Filter::try_new(predicate, input).map(LogicalPlan::Filter)
/// Creates a filter node without re-validating predicate type.
fn make_filter(predicate: Expr, input: Arc<LogicalPlan>) -> Result<LogicalPlan> {
// PushDownFilter only rebuilds predicates that already came from validated
// filter/join expressions, so re-running full boolean type validation here
// only recomputes expensive expression schemas.
Ok(LogicalPlan::Filter(Filter::new_unchecked(predicate, input)))
}

/// Replace the existing child of the single input node with `new_child`.
Expand Down
Loading