Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion-testing
62 changes: 51 additions & 11 deletions datafusion/optimizer/src/push_down_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -810,25 +810,22 @@ impl OptimizerRule for PushDownFilter {

match Arc::unwrap_or_clone(filter.input) {
LogicalPlan::Filter(child_filter) => {
let parents_predicates = split_conjunction_owned(filter.predicate);

// remove duplicated filters
let child_predicates = split_conjunction_owned(child_filter.predicate);
let new_predicates = parents_predicates
.into_iter()
.chain(child_predicates)
// use IndexSet to remove dupes while preserving predicate order
.collect::<IndexSet<_>>()
// child filters first to preserve execution order
let new_predicates = split_conjunction_owned(child_filter.predicate)
.into_iter()
.collect::<Vec<_>>();
.chain(split_conjunction_owned(filter.predicate))
// use IndexSet to remove duplicates while preserving predicate order
.collect::<IndexSet<_>>();

let Some(new_predicate) = conjunction(new_predicates) else {
return plan_err!("at least one expression exists");
};

let new_filter = LogicalPlan::Filter(Filter::try_new(
new_predicate,
child_filter.input,
)?);

self.rewrite(new_filter, config)
}
LogicalPlan::Repartition(repartition) => {
Expand Down Expand Up @@ -2474,7 +2471,7 @@ mod tests {
plan,
@r"
Projection: test.a
Filter: test.a >= Int64(1) AND test.a <= Int64(1)
Filter: test.a <= Int64(1) AND test.a >= Int64(1)
Limit: skip=0, fetch=1
TableScan: test
"
Expand Down Expand Up @@ -3253,6 +3250,28 @@ mod tests {
)
}

#[test]
fn multi_combined_two_filters() -> Result<()> {
let plan = table_scan_with_pushdown_provider_builder(
TableProviderFilterPushDown::Inexact,
vec![col("a").eq(lit(10i64)), col("b").gt(lit(11i64))],
Some(vec![0]),
)?
.filter(col("a").eq(lit(10i64)))?
.filter(col("b").gt(lit(11i64)))?
.project(vec![col("a"), col("b")])?
.build()?;

assert_optimized_plan_equal!(
plan,
@r"
Projection: a, b
Filter: a = Int64(10) AND b > Int64(11)
TableScan: test projection=[a], partial_filters=[a = Int64(10), b > Int64(11)]
"
)
}

#[test]
fn multi_combined_filter_exact() -> Result<()> {
let plan = table_scan_with_pushdown_provider_builder(
Expand All @@ -3273,6 +3292,27 @@ mod tests {
)
}

#[test]
fn multi_combined_two_filters_exact() -> Result<()> {
let plan = table_scan_with_pushdown_provider_builder(
TableProviderFilterPushDown::Exact,
vec![],
Some(vec![0]),
)?
.filter(col("a").eq(lit(10i64)))?
.filter(col("b").gt(lit(11i64)))?
.project(vec![col("a"), col("b")])?
.build()?;

assert_optimized_plan_equal!(
plan,
@r"
Projection: a, b
TableScan: test projection=[a], full_filters=[a = Int64(10), b > Int64(11)]
"
)
}

#[test]
fn test_filter_with_alias() -> Result<()> {
// in table scan the true col name is 'test.a',
Expand Down
4 changes: 2 additions & 2 deletions datafusion/sqllogictest/test_files/predicates.slt
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ logical_plan
02)--Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
03)----Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)
04)------TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)]
05)----Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
05)----Filter: part.p_size >= Int32(1) AND (part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15))
06)------TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_size <= Int32(15)]
physical_plan
01)HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_partkey@0]
Expand All @@ -674,7 +674,7 @@ physical_plan
04)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
05)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/lineitem.csv]]}, projection=[l_partkey, l_quantity], file_type=csv, has_header=true
06)--RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4
07)----FilterExec: (p_brand@1 = Brand#12 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_size@2 <= 15) AND p_size@2 >= 1
07)----FilterExec: p_size@2 >= 1 AND (p_brand@1 = Brand#12 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_size@2 <= 15)
08)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
09)--------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/tpch-csv/part.csv]]}, projection=[p_partkey, p_brand, p_size], file_type=csv, has_header=true

Expand Down
8 changes: 4 additions & 4 deletions datafusion/sqllogictest/test_files/tpch/plans/q19.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ logical_plan
03)----Projection: lineitem.l_extendedprice, lineitem.l_discount
04)------Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15)
05)--------Projection: lineitem.l_partkey, lineitem.l_quantity, lineitem.l_extendedprice, lineitem.l_discount
06)----------Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG")) AND lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON")
06)----------Filter: (lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG")) AND lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON") AND (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2))
07)------------TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], partial_filters=[lineitem.l_shipmode = Utf8View("AIR") OR lineitem.l_shipmode = Utf8View("AIR REG"), lineitem.l_shipinstruct = Utf8View("DELIVER IN PERSON"), lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)]
08)--------Filter: (part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1)
08)--------Filter: part.p_size >= Int32(1) AND (part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15))
09)----------TableScan: part projection=[p_partkey, p_brand, p_size, p_container], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8View("Brand#12") AND part.p_container IN ([Utf8View("SM CASE"), Utf8View("SM BOX"), Utf8View("SM PACK"), Utf8View("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8View("Brand#23") AND part.p_container IN ([Utf8View("MED BAG"), Utf8View("MED BOX"), Utf8View("MED PKG"), Utf8View("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8View("Brand#34") AND part.p_container IN ([Utf8View("LG CASE"), Utf8View("LG BOX"), Utf8View("LG PACK"), Utf8View("LG PKG")]) AND part.p_size <= Int32(15)]
physical_plan
01)ProjectionExec: expr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)@0 as revenue]
Expand All @@ -70,9 +70,9 @@ physical_plan
04)------AggregateExec: mode=Partial, gby=[], aggr=[sum(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]
05)--------HashJoinExec: mode=Partitioned, join_type=Inner, on=[(l_partkey@0, p_partkey@0)], filter=p_brand@1 = Brand#12 AND p_container@3 IN (SET) ([SM CASE, SM BOX, SM PACK, SM PKG]) AND l_quantity@0 >= Some(100),15,2 AND l_quantity@0 <= Some(1100),15,2 AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN (SET) ([MED BAG, MED BOX, MED PKG, MED PACK]) AND l_quantity@0 >= Some(1000),15,2 AND l_quantity@0 <= Some(2000),15,2 AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN (SET) ([LG CASE, LG BOX, LG PACK, LG PKG]) AND l_quantity@0 >= Some(2000),15,2 AND l_quantity@0 <= Some(3000),15,2 AND p_size@2 <= 15, projection=[l_extendedprice@2, l_discount@3]
06)----------RepartitionExec: partitioning=Hash([l_partkey@0], 4), input_partitions=4
07)------------FilterExec: (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2) AND (l_shipmode@5 = AIR OR l_shipmode@5 = AIR REG) AND l_shipinstruct@4 = DELIVER IN PERSON, projection=[l_partkey@0, l_quantity@1, l_extendedprice@2, l_discount@3]
07)------------FilterExec: (l_shipmode@5 = AIR OR l_shipmode@5 = AIR REG) AND l_shipinstruct@4 = DELIVER IN PERSON AND (l_quantity@1 >= Some(100),15,2 AND l_quantity@1 <= Some(1100),15,2 OR l_quantity@1 >= Some(1000),15,2 AND l_quantity@1 <= Some(2000),15,2 OR l_quantity@1 >= Some(2000),15,2 AND l_quantity@1 <= Some(3000),15,2), projection=[l_partkey@0, l_quantity@1, l_extendedprice@2, l_discount@3]
08)--------------DataSourceExec: file_groups={4 groups: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:0..18561749], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:18561749..37123498], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:37123498..55685247], [WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/lineitem.tbl:55685247..74246996]]}, projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode], file_type=csv, has_header=false
09)----------RepartitionExec: partitioning=Hash([p_partkey@0], 4), input_partitions=4
10)------------FilterExec: (p_brand@1 = Brand#12 AND p_container@3 IN (SET) ([SM CASE, SM BOX, SM PACK, SM PKG]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN (SET) ([MED BAG, MED BOX, MED PKG, MED PACK]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN (SET) ([LG CASE, LG BOX, LG PACK, LG PKG]) AND p_size@2 <= 15) AND p_size@2 >= 1
10)------------FilterExec: p_size@2 >= 1 AND (p_brand@1 = Brand#12 AND p_container@3 IN (SET) ([SM CASE, SM BOX, SM PACK, SM PKG]) AND p_size@2 <= 5 OR p_brand@1 = Brand#23 AND p_container@3 IN (SET) ([MED BAG, MED BOX, MED PKG, MED PACK]) AND p_size@2 <= 10 OR p_brand@1 = Brand#34 AND p_container@3 IN (SET) ([LG CASE, LG BOX, LG PACK, LG PKG]) AND p_size@2 <= 15)
11)--------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
12)----------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/part.tbl]]}, projection=[p_partkey, p_brand, p_size, p_container], file_type=csv, has_header=false
Loading