@@ -266,10 +266,13 @@ bool CompanionFunctionsRegistrar::registerPartialFunction(
266266 const core::QueryConfig& config)
267267 -> std::unique_ptr<Aggregate> {
268268 if (auto func = getAggregateFunctionEntry (name)) {
269+ core::AggregationNode::Step usedStep{
270+ core::AggregationNode::Step::kPartial };
269271 if (!exec::isRawInput (step)) {
270- step = core::AggregationNode::Step::kIntermediate ;
272+ usedStep = core::AggregationNode::Step::kIntermediate ;
271273 }
272- auto fn = func->factory (step, argTypes, resultType, config);
274+ auto fn =
275+ func->factory (usedStep, argTypes, resultType, config);
273276 VELOX_CHECK_NOT_NULL (fn);
274277 return std::make_unique<
275278 AggregateCompanionAdapter::PartialFunction>(
@@ -387,56 +390,50 @@ bool CompanionFunctionsRegistrar::registerMergeExtractFunction(
387390 const std::string& name,
388391 const std::vector<AggregateFunctionSignaturePtr>& signatures,
389392 bool overwrite) {
393+ bool registered = false ;
390394 if (CompanionSignatures::hasSameIntermediateTypesAcrossSignatures (
391395 signatures)) {
392- return registerMergeExtractFunctionWithSuffix (name, signatures, overwrite);
396+ registered |=
397+ registerMergeExtractFunctionWithSuffix (name, signatures, overwrite);
393398 }
394399
395400 auto mergeExtractSignatures =
396401 CompanionSignatures::mergeExtractFunctionSignatures (signatures);
397402 if (mergeExtractSignatures.empty ()) {
398- return false ;
403+ return registered ;
399404 }
400405
401406 auto mergeExtractFunctionName =
402407 CompanionSignatures::mergeExtractFunctionName (name);
403- return exec::registerAggregateFunction (
404- mergeExtractFunctionName,
405- std::move (mergeExtractSignatures),
406- [name, mergeExtractFunctionName](
407- core::AggregationNode::Step /* step*/ ,
408- const std::vector<TypePtr>& argTypes,
409- const TypePtr& resultType,
410- const core::QueryConfig& config)
411- -> std::unique_ptr<Aggregate> {
412- const auto & [originalResultType, _] =
413- resolveAggregateFunction (mergeExtractFunctionName, argTypes);
414- if (!originalResultType) {
415- // TODO: limitation -- result type must be resolveable given
416- // intermediate type of the original UDAF.
417- VELOX_UNREACHABLE (
418- " Signatures whose result types are not resolvable given intermediate types should have been excluded." );
419- }
420-
421- if (auto func = getAggregateFunctionEntry (name)) {
422- auto fn = func->factory (
423- core::AggregationNode::Step::kFinal ,
424- argTypes,
425- originalResultType,
426- config);
427- VELOX_CHECK_NOT_NULL (fn);
428- return std::make_unique<
429- AggregateCompanionAdapter::MergeExtractFunction>(
430- std::move (fn), resultType);
431- }
432- VELOX_FAIL (
433- " Original aggregation function {} not found: {}" ,
434- name,
435- mergeExtractFunctionName);
436- },
437- /* registerCompanionFunctions*/ false ,
438- overwrite)
439- .mainFunction ;
408+ registered |=
409+ exec::registerAggregateFunction (
410+ mergeExtractFunctionName,
411+ std::move (mergeExtractSignatures),
412+ [name, mergeExtractFunctionName](
413+ core::AggregationNode::Step /* step*/ ,
414+ const std::vector<TypePtr>& argTypes,
415+ const TypePtr& resultType,
416+ const core::QueryConfig& config) -> std::unique_ptr<Aggregate> {
417+ if (auto func = getAggregateFunctionEntry (name)) {
418+ auto fn = func->factory (
419+ core::AggregationNode::Step::kFinal ,
420+ argTypes,
421+ resultType,
422+ config);
423+ VELOX_CHECK_NOT_NULL (fn);
424+ return std::make_unique<
425+ AggregateCompanionAdapter::MergeExtractFunction>(
426+ std::move (fn), resultType);
427+ }
428+ VELOX_FAIL (
429+ " Original aggregation function {} not found: {}" ,
430+ name,
431+ mergeExtractFunctionName);
432+ },
433+ /* registerCompanionFunctions*/ false ,
434+ overwrite)
435+ .mainFunction ;
436+ return registered;
440437}
441438
442439bool CompanionFunctionsRegistrar::registerExtractFunctionWithSuffix (
0 commit comments