Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,7 @@ public PlanFragment visitPhysicalHashAggregate(

aggregationNode.setNereidsId(aggregate.getId());
context.getNereidsIdToPlanNodeIdMap().put(aggregate.getId(), aggregationNode.getId());
if (isPartial) {
if (isPartial || aggregate.getAggregateParam().aggPhase.isLocal()) {
aggregationNode.unsetNeedsFinalize();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,22 +35,26 @@
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.planner.AggregationNode;
import org.apache.doris.planner.OlapScanNode;
import org.apache.doris.planner.PlanFragment;
import org.apache.doris.planner.PlanNode;
import org.apache.doris.planner.Planner;
import org.apache.doris.utframe.TestWithFeService;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import mockit.Injectable;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

import java.lang.reflect.Field;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

public class PhysicalPlanTranslatorTest {
public class PhysicalPlanTranslatorTest extends TestWithFeService {

@Test
public void testOlapPrune(@Injectable LogicalProperties placeHolder) throws Exception {
Expand Down Expand Up @@ -86,4 +90,39 @@ public void testOlapPrune(@Injectable LogicalProperties placeHolder) throws Exce
planNode.collect(OlapScanNode.class::isInstance, scanNodeList);
Assertions.assertEquals(2, scanNodeList.get(0).getTupleDesc().getSlots().size());
}

@Test
public void testAggNeedsFinalize() throws Exception {
createDatabase("test_db");
createTable("create table test_db.t(a int, b int) distributed by hash(a) buckets 3 "
+ "properties('replication_num' = '1');");
connectContext.getSessionVariable().setDisableNereidsRules("prune_empty_partition");
String querySql = "select b from test_db.t group by b";
Planner planner = getSQLPlanner(querySql);
Assertions.assertNotNull(planner);

List<PlanFragment> fragments = planner.getFragments();
Assertions.assertNotNull(fragments);
Assertions.assertFalse(fragments.isEmpty());

List<AggregationNode> aggNodes = new ArrayList<>();
for (PlanFragment fragment : fragments) {
PlanNode root = fragment.getPlanRoot();
if (root != null) {
root.collect(AggregationNode.class::isInstance, aggNodes);
}
}
Assertions.assertEquals(2, aggNodes.size());
Field needsFinalizeField = AggregationNode.class.getDeclaredField("needsFinalize");
needsFinalizeField.setAccessible(true);
AggregationNode upperAggNode = aggNodes.get(0);
AggregationNode lowerAggNode = aggNodes.get(1);

boolean lowerNeedsFinalize = needsFinalizeField.getBoolean(lowerAggNode);
Assertions.assertFalse(lowerNeedsFinalize,
"lower AggregationNode needsFinalize should be false");
boolean upperNeedsFinalize = needsFinalizeField.getBoolean(upperAggNode);
Assertions.assertTrue(upperNeedsFinalize,
"upper AggregationNode needsFinalize should be true");
}
}
Loading