Skip to content

Commit

Permalink
[fix](nereids) unnest in-subquery with agg node in proper condition (a…
Browse files Browse the repository at this point in the history
…pache#25800)

consider sql having in-subquery

SELECT count(*)
        FROM sub_query_correlated_subquery6
        WHERE k1 IN 
            (SELECT k1
            FROM 
                (**SELECT k1,
                sum(k3) AS bbb,
                count(k2) AS aaa
                FROM sub_query_correlated_subquery7
                WHERE k1 > 0
                        AND k3 > 0
                GROUP BY  k1** ) y
                WHERE y.aaa>0
                        AND k1>1); 

The subquery part having agg is un-correlated, which can be unnested.

on the other side:
SELECT count(*)
                    FROM sub_query_correlated_subquery6
                    WHERE k1 IN 
                        (SELECT k1
                        FROM 
                            (**SELECT k1,
                            sum(k3) AS bbb,
                            count(k2) AS aaa
                            FROM sub_query_correlated_subquery7
                            WHERE k1 > 0
                                    AND k3 > 0 and sub_query_correlated_subquery6.k1 > 2
                            GROUP BY  k1** ) y
                            WHERE y.aaa>0
                                    AND k1>1);

The subquery part having agg is correlated, which can't be unnested.
  • Loading branch information
starocean999 authored Nov 7, 2023
1 parent 16644ef commit f138aaa
Show file tree
Hide file tree
Showing 5 changed files with 181 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
Expand Down Expand Up @@ -89,8 +91,7 @@ public Expression visitInSubquery(InSubquery expr, CascadesContext context) {
AnalyzedResult analyzedResult = analyzeSubquery(expr);

checkOutputColumn(analyzedResult.getLogicalPlan());
checkHasNotAgg(analyzedResult);
checkHasGroupBy(analyzedResult);
checkNoCorrelatedSlotsUnderAgg(analyzedResult);
checkRootIsLimit(analyzedResult);

return new InSubquery(
Expand All @@ -105,7 +106,7 @@ public Expression visitScalarSubquery(ScalarSubquery scalar, CascadesContext con

checkOutputColumn(analyzedResult.getLogicalPlan());
checkHasAgg(analyzedResult);
checkHasGroupBy(analyzedResult);
checkHasNoGroupBy(analyzedResult);

return new ScalarSubquery(analyzedResult.getLogicalPlan(), analyzedResult.getCorrelatedSlots());
}
Expand Down Expand Up @@ -135,7 +136,7 @@ private void checkHasAgg(AnalyzedResult analyzedResult) {
}
}

private void checkHasGroupBy(AnalyzedResult analyzedResult) {
private void checkHasNoGroupBy(AnalyzedResult analyzedResult) {
if (!analyzedResult.isCorrelated()) {
return;
}
Expand All @@ -145,13 +146,11 @@ private void checkHasGroupBy(AnalyzedResult analyzedResult) {
}
}

private void checkHasNotAgg(AnalyzedResult analyzedResult) {
if (!analyzedResult.isCorrelated()) {
return;
}
if (analyzedResult.hasAgg()) {
throw new AnalysisException("Unsupported correlated subquery with grouping and/or aggregation "
+ analyzedResult.getLogicalPlan());
private void checkNoCorrelatedSlotsUnderAgg(AnalyzedResult analyzedResult) {
if (analyzedResult.hasCorrelatedSlotsUnderAgg()) {
throw new AnalysisException(
"Unsupported correlated subquery with grouping and/or aggregation "
+ analyzedResult.getLogicalPlan());
}
}

Expand Down Expand Up @@ -223,6 +222,29 @@ public boolean hasGroupBy() {
return false;
}

public boolean hasCorrelatedSlotsUnderAgg() {
return correlatedSlots.isEmpty() ? false
: findAggContainsCorrelatedSlots(logicalPlan, ImmutableSet.copyOf(correlatedSlots));
}

private boolean findAggContainsCorrelatedSlots(Plan rootPlan, ImmutableSet<Slot> slots) {
ArrayDeque<Plan> planQueue = new ArrayDeque<>();
planQueue.add(rootPlan);
while (!planQueue.isEmpty()) {
Plan plan = planQueue.poll();
if (plan instanceof LogicalAggregate) {
if (plan.containsSlots(slots)) {
return true;
}
} else {
for (Plan child : plan.children()) {
planQueue.add(child);
}
}
}
return false;
}

public boolean rootIsLimit() {
return logicalPlan instanceof LogicalLimit;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
Expand Down Expand Up @@ -269,8 +270,10 @@ private LogicalPlan subqueryToApply(List<SubqueryExpr> subqueryExprs, LogicalPla
private boolean nonMarkJoinExistsWithAgg(SubqueryExpr exists,
Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot) {
return exists instanceof Exists
&& exists.getQueryPlan().anyMatch(Aggregate.class::isInstance)
&& !subqueryToMarkJoinSlot.get(exists).isPresent();
&& exists.getQueryPlan()
.anyMatch(planTreeNode -> planTreeNode instanceof LogicalAggregate
&& ((LogicalAggregate<?>) planTreeNode).getGroupByExpressions().isEmpty())
&& !subqueryToMarkJoinSlot.get(exists).isPresent();
}

private LogicalPlan addApply(SubqueryExpr subquery, LogicalPlan childPlan,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ default boolean hasUnboundExpression() {
return getExpressions().stream().anyMatch(Expression::hasUnbound);
}

default boolean containsSlots(ImmutableSet<Slot> slots) {
return getExpressions().stream().anyMatch(
expression -> !Sets.intersection(slots, expression.getInputSlots()).isEmpty()
|| children().stream().anyMatch(plan -> plan.containsSlots(slots)));
}

default LogicalProperties computeLogicalProperties() {
throw new IllegalStateException("Not support compute logical properties for " + getClass().getName());
}
Expand Down
21 changes: 21 additions & 0 deletions regression-test/data/nereids_syntax_p0/sub_query_correlated.out
Original file line number Diff line number Diff line change
Expand Up @@ -465,3 +465,24 @@ true
\N
\N

-- !cir_5218_in_ok --
4

-- !cir_5218_exists_ok_1 --
13

-- !cir_5218_exists_ok_2 --
3

-- !cir_5218_exists_ok_3 --
5

-- !cir_5218_exists_ok_4 --
13

-- !cir_5218_exists_ok_5 --
13

-- !cir_5218_exists_ok_6 --
0

116 changes: 116 additions & 0 deletions regression-test/suites/nereids_syntax_p0/sub_query_correlated.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ suite ("sub_query_correlated") {
DROP TABLE IF EXISTS `sub_query_correlated_subquery9`
"""

sql """
DROP TABLE IF EXISTS `sub_query_correlated_subquery10`
"""

sql """
create table if not exists sub_query_correlated_subquery1
(k1 bigint, k2 bigint)
Expand Down Expand Up @@ -128,6 +132,13 @@ suite ("sub_query_correlated") {
properties('replication_num' = '1');
"""

sql """
create table if not exists sub_query_correlated_subquery10
(k1 int, k2 varchar(128), k3 bigint, v1 bigint, v2 bigint)
distributed by hash(k2) buckets 1
properties('replication_num' = '1');
"""

sql """
insert into sub_query_correlated_subquery1 values (1,2), (1,3), (2,4), (2,5), (3,3), (3,4), (20,2), (22,3), (24,4)
"""
Expand Down Expand Up @@ -532,6 +543,111 @@ suite ("sub_query_correlated") {
select sub_query_correlated_subquery8.k1 in (select sub_query_correlated_subquery9.k3 from sub_query_correlated_subquery9) from sub_query_correlated_subquery8 order by k1, k2;
"""

qt_cir_5218_in_ok """
SELECT count(*)
FROM sub_query_correlated_subquery6
WHERE k1 IN
(SELECT k1
FROM
(SELECT k1,
sum(k3) AS bbb,
count(k2) AS aaa
FROM sub_query_correlated_subquery7
WHERE k1 > 0
AND k3 > 0
GROUP BY k1 ) y
WHERE y.aaa>0
AND k1>1);
"""

qt_cir_5218_exists_ok_1 """
SELECT count(*)
FROM sub_query_correlated_subquery6
WHERE exists
(SELECT k1
FROM
(SELECT k1,
sum(k3) AS bbb,
count(k2) AS aaa
FROM sub_query_correlated_subquery7
WHERE k1 > 0
AND k3 > 0
GROUP BY k1 ) y
WHERE y.aaa>0
AND k1>1);
"""

qt_cir_5218_exists_ok_2 """
SELECT count(*)
FROM sub_query_correlated_subquery6
WHERE exists
(SELECT k1
FROM
(SELECT k1
FROM sub_query_correlated_subquery7
WHERE sub_query_correlated_subquery6.k1 > 7
GROUP BY k1 ) y);
"""

qt_cir_5218_exists_ok_3 """
SELECT count(*)
FROM sub_query_correlated_subquery6
WHERE exists
(SELECT k1
FROM
(SELECT k1
FROM sub_query_correlated_subquery7
WHERE sub_query_correlated_subquery6.k1 > sub_query_correlated_subquery7.k3
GROUP BY k1 ) y);
"""

qt_cir_5218_exists_ok_4 """
SELECT count(*)
FROM sub_query_correlated_subquery6
WHERE exists
(SELECT sum(k3)
FROM
sub_query_correlated_subquery7
WHERE sub_query_correlated_subquery6.k1 > sub_query_correlated_subquery7.k3);
"""

qt_cir_5218_exists_ok_5 """
SELECT count(*)
FROM sub_query_correlated_subquery6
WHERE exists
(SELECT sum(k3)
FROM
sub_query_correlated_subquery10);
"""

qt_cir_5218_exists_ok_6 """
SELECT count(*)
FROM sub_query_correlated_subquery6
WHERE exists
(SELECT sum(k3)
FROM
sub_query_correlated_subquery10 group by k2);
"""

test {
sql """
SELECT count(*)
FROM sub_query_correlated_subquery6
WHERE k1 IN
(SELECT k1
FROM
(SELECT k1,
sum(k3) AS bbb,
count(k2) AS aaa
FROM sub_query_correlated_subquery7
WHERE k1 > 0
AND k3 > 0 and sub_query_correlated_subquery6.k1 > 2
GROUP BY k1 ) y
WHERE y.aaa>0
AND k1>1); """
exception "Unsupported correlated subquery with grouping and/or aggregation";
}

// order_qt_doris_6937_2 """
// select * from sub_query_correlated_subquery1 where sub_query_correlated_subquery1.k1 not in (select sub_query_correlated_subquery3.k3 from sub_query_correlated_subquery3 where sub_query_correlated_subquery3.v2 > sub_query_correlated_subquery1.k2) or k1 < 10 order by k1, k2;
// """
Expand Down

0 comments on commit f138aaa

Please sign in to comment.