Skip to content

Commit 5ceb442

Browse files
zhztheplayerglutenperfbot
authored andcommitted
Register merge extract companion agg functions without suffix
1 parent 948f1dd commit 5ceb442

File tree

1 file changed

+38
-41
lines changed

1 file changed

+38
-41
lines changed

velox/exec/AggregateCompanionAdapter.cpp

Lines changed: 38 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,13 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
266266
const core::QueryConfig& config)
267267
-> std::unique_ptr<Aggregate> {
268268
if (auto func = getAggregateFunctionEntry(name)) {
269+
core::AggregationNode::Step usedStep{
270+
core::AggregationNode::Step::kPartial};
269271
if (!exec::isRawInput(step)) {
270-
step = core::AggregationNode::Step::kIntermediate;
272+
usedStep = core::AggregationNode::Step::kIntermediate;
271273
}
272-
auto fn = func->factory(step, argTypes, resultType, config);
274+
auto fn =
275+
func->factory(usedStep, argTypes, resultType, config);
273276
VELOX_CHECK_NOT_NULL(fn);
274277
return std::make_unique<
275278
AggregateCompanionAdapter::PartialFunction>(
@@ -387,56 +390,50 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
387390
const std::string& name,
388391
const std::vector<AggregateFunctionSignaturePtr>& signatures,
389392
bool overwrite) {
393+
bool registered = false;
390394
if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures(
391395
signatures)) {
392-
return registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
396+
registered |=
397+
registerMergeExtractFunctionWithSuffix(name, signatures, overwrite);
393398
}
394399

395400
auto mergeExtractSignatures =
396401
CompanionSignatures::mergeExtractFunctionSignatures(signatures);
397402
if (mergeExtractSignatures.empty()) {
398-
return false;
403+
return registered;
399404
}
400405

401406
auto mergeExtractFunctionName =
402407
CompanionSignatures::mergeExtractFunctionName(name);
403-
return exec::registerAggregateFunction(
404-
mergeExtractFunctionName,
405-
std::move(mergeExtractSignatures),
406-
[name, mergeExtractFunctionName](
407-
core::AggregationNode::Step /*step*/,
408-
const std::vector<TypePtr>& argTypes,
409-
const TypePtr& resultType,
410-
const core::QueryConfig& config)
411-
-> std::unique_ptr<Aggregate> {
412-
const auto& [originalResultType, _] =
413-
resolveAggregateFunction(mergeExtractFunctionName, argTypes);
414-
if (!originalResultType) {
415-
// TODO: limitation -- result type must be resolveable given
416-
// intermediate type of the original UDAF.
417-
VELOX_UNREACHABLE(
418-
"Signatures whose result types are not resolvable given intermediate types should have been excluded.");
419-
}
420-
421-
if (auto func = getAggregateFunctionEntry(name)) {
422-
auto fn = func->factory(
423-
core::AggregationNode::Step::kFinal,
424-
argTypes,
425-
originalResultType,
426-
config);
427-
VELOX_CHECK_NOT_NULL(fn);
428-
return std::make_unique<
429-
AggregateCompanionAdapter::MergeExtractFunction>(
430-
std::move(fn), resultType);
431-
}
432-
VELOX_FAIL(
433-
"Original aggregation function {} not found: {}",
434-
name,
435-
mergeExtractFunctionName);
436-
},
437-
/*registerCompanionFunctions*/ false,
438-
overwrite)
439-
.mainFunction;
408+
registered |=
409+
exec::registerAggregateFunction(
410+
mergeExtractFunctionName,
411+
std::move(mergeExtractSignatures),
412+
[name, mergeExtractFunctionName](
413+
core::AggregationNode::Step /*step*/,
414+
const std::vector<TypePtr>& argTypes,
415+
const TypePtr& resultType,
416+
const core::QueryConfig& config) -> std::unique_ptr<Aggregate> {
417+
if (auto func = getAggregateFunctionEntry(name)) {
418+
auto fn = func->factory(
419+
core::AggregationNode::Step::kFinal,
420+
argTypes,
421+
resultType,
422+
config);
423+
VELOX_CHECK_NOT_NULL(fn);
424+
return std::make_unique<
425+
AggregateCompanionAdapter::MergeExtractFunction>(
426+
std::move(fn), resultType);
427+
}
428+
VELOX_FAIL(
429+
"Original aggregation function {} not found: {}",
430+
name,
431+
mergeExtractFunctionName);
432+
},
433+
/*registerCompanionFunctions*/ false,
434+
overwrite)
435+
.mainFunction;
436+
return registered;
440437
}
441438

442439
bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix(

0 commit comments

Comments
 (0)