Skip to content

Commit 2b9138d

Browse files
committed
[Fix](ai) Fix _exec_plan_fragment_impl meet unknown error when call AI_Functions
1 parent eb93602 commit 2b9138d

File tree

6 files changed

+83
-10
lines changed

6 files changed

+83
-10
lines changed

be/src/runtime/query_context.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,12 @@ class QueryContext : public std::enable_shared_from_this<QueryContext> {
269269
std::make_unique<std::map<std::string, TAIResource>>(std::move(ai_resources));
270270
}
271271

272-
const std::map<std::string, TAIResource>& get_ai_resources() const {
272+
Status get_ai_resources(const std::map<std::string, TAIResource>** ai_resources) const {
273273
if (_ai_resources == nullptr) {
274-
throw Status::InternalError("AI resources not found");
274+
return Status::InternalError("AI resources not found");
275275
}
276-
return *_ai_resources;
276+
*ai_resources = _ai_resources.get();
277+
return Status::OK();
277278
}
278279

279280
std::unordered_map<TNetworkAddress, std::shared_ptr<PBackendService_Stub>>

be/src/vec/aggregate_functions/aggregate_function_ai_agg.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,10 @@ class AggregateFunctionAIAggData {
141141
_task = task_ref.to_string();
142142

143143
std::string resource_name = resource_name_ref.to_string();
144-
const std::map<std::string, TAIResource>& ai_resources = _ctx->get_ai_resources();
145-
auto it = ai_resources.find(resource_name);
146-
if (it == ai_resources.end()) {
144+
const std::map<std::string, TAIResource>* ai_resources = nullptr;
145+
THROW_IF_ERROR(_ctx->get_ai_resources(&ai_resources));
146+
auto it = ai_resources->find(resource_name);
147+
if (it == ai_resources->end()) {
147148
throw Exception(ErrorCode::NOT_FOUND, "AI resource not found: " + resource_name);
148149
}
149150
_ai_config = it->second;

be/src/vec/functions/ai/ai_functions.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -190,10 +190,10 @@ class AIFunction : public IFunction {
190190
StringRef resource_name_ref = resource_column.column->get_data_at(0);
191191
std::string resource_name = std::string(resource_name_ref.data, resource_name_ref.size);
192192

193-
const std::map<std::string, TAIResource>& ai_resources =
194-
context->state()->get_query_ctx()->get_ai_resources();
195-
auto it = ai_resources.find(resource_name);
196-
if (it == ai_resources.end()) {
193+
const std::map<std::string, TAIResource>* ai_resources = nullptr;
194+
RETURN_IF_ERROR(context->state()->get_query_ctx()->get_ai_resources(&ai_resources));
195+
auto it = ai_resources->find(resource_name);
196+
if (it == ai_resources->end()) {
197197
return Status::InvalidArgument("AI resource not found: " + resource_name);
198198
}
199199
config = it->second;

be/test/ai/aggregate_function_ai_agg_test.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,4 +413,32 @@ TEST_F(AggregateFunctionAIAggTest, mock_resource_send_request_test) {
413413
_agg_function->destroy(place);
414414
}
415415

416+
TEST_F(AggregateFunctionAIAggTest, missing_ai_resources_metadata_test) {
417+
auto empty_query_ctx = MockQueryContext::create();
418+
_agg_function->set_query_context(empty_query_ctx.get());
419+
420+
std::vector<std::string> resources = {"resource_name"};
421+
std::vector<std::string> texts = {"test input"};
422+
std::vector<std::string> task = {"summarize"};
423+
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
424+
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
425+
auto col_task = ColumnHelper::create_column<DataTypeString>(task);
426+
427+
std::unique_ptr<char[]> memory(new char[_agg_function->size_of_data()]);
428+
AggregateDataPtr place = memory.get();
429+
_agg_function->create(place);
430+
431+
const IColumn* columns[3] = {col_resource.get(), col_text.get(), col_task.get()};
432+
433+
try {
434+
_agg_function->add(place, columns, 0, _arena);
435+
FAIL() << "Expected exception for missing AI resources";
436+
} catch (const Exception& e) {
437+
EXPECT_EQ(e.code(), ErrorCode::INTERNAL_ERROR);
438+
EXPECT_NE(e.to_string().find("AI resources metadata missing"), std::string::npos);
439+
}
440+
441+
_agg_function->destroy(place);
442+
}
443+
416444
} // namespace doris::vectorized

be/test/ai/ai_function_test.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -551,6 +551,35 @@ TEST(AIFunctionTest, MockResourceSendRequest) {
551551
ASSERT_EQ(val, "this is a mock response. test input");
552552
}
553553

554+
TEST(AIFunctionTest, MissingAIResourcesMetadataTest) {
555+
auto query_ctx = MockQueryContext::create();
556+
TQueryOptions query_options;
557+
TQueryGlobals query_globals;
558+
RuntimeState runtime_state(TUniqueId(), 0, query_options, query_globals, nullptr,
559+
query_ctx.get());
560+
auto ctx = FunctionContext::create_context(&runtime_state, {}, {});
561+
562+
std::vector<std::string> resources = {"resource_name"};
563+
std::vector<std::string> texts = {"test"};
564+
auto col_resource = ColumnHelper::create_column<DataTypeString>(resources);
565+
auto col_text = ColumnHelper::create_column<DataTypeString>(texts);
566+
567+
Block block;
568+
block.insert({std::move(col_resource), std::make_shared<DataTypeString>(), "resource"});
569+
block.insert({std::move(col_text), std::make_shared<DataTypeString>(), "text"});
570+
block.insert({nullptr, std::make_shared<DataTypeString>(), "result"});
571+
572+
ColumnNumbers arguments = {0, 1};
573+
size_t result_idx = 2;
574+
575+
auto sentiment_func = FunctionAISentiment::create();
576+
Status exec_status =
577+
sentiment_func->execute_impl(ctx.get(), block, arguments, result_idx, texts.size());
578+
579+
ASSERT_FALSE(exec_status.ok());
580+
ASSERT_NE(exec_status.to_string().find("AI resources metadata missing"), std::string::npos);
581+
}
582+
554583
TEST(AIFunctionTest, ReturnTypeTest) {
555584
FunctionAIClassify func_classify;
556585
DataTypes args;

fe/fe-core/src/main/java/org/apache/doris/qe/Coordinator.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919

2020
import org.apache.doris.analysis.DescriptorTable;
2121
import org.apache.doris.analysis.StorageBackend;
22+
import org.apache.doris.catalog.AIResource;
2223
import org.apache.doris.catalog.Env;
2324
import org.apache.doris.catalog.FsBroker;
25+
import org.apache.doris.catalog.Resource;
2426
import org.apache.doris.common.Config;
2527
import org.apache.doris.common.MarkedCountDownLatch;
2628
import org.apache.doris.common.Pair;
@@ -89,6 +91,7 @@
8991
import org.apache.doris.system.SystemInfoService;
9092
import org.apache.doris.task.LoadEtlTask;
9193
import org.apache.doris.thrift.PaloInternalServiceVersion;
94+
import org.apache.doris.thrift.TAIResource;
9295
import org.apache.doris.thrift.TBrokerScanRange;
9396
import org.apache.doris.thrift.TDataSinkType;
9497
import org.apache.doris.thrift.TDescriptorTable;
@@ -3235,6 +3238,17 @@ Map<TNetworkAddress, TPipelineFragmentParams> toThrift(int backendNum) {
32353238
if (ignoreDataDistribution) {
32363239
params.setParallelInstances(parallelTasksNum);
32373240
}
3241+
3242+
// Used for AI Functions
3243+
Map<String, TAIResource> aiResourceMap = Maps.newLinkedHashMap();
3244+
for (Resource resource : Env.getCurrentEnv().getResourceMgr()
3245+
.getResource(Resource.ResourceType.AI)) {
3246+
if (resource instanceof AIResource) {
3247+
aiResourceMap.put(resource.getName(), ((AIResource) resource).toThrift());
3248+
}
3249+
}
3250+
3251+
params.setAiResources(aiResourceMap);
32383252
res.put(instanceExecParam.host, params);
32393253
res.get(instanceExecParam.host).setBucketSeqToInstanceIdx(new HashMap<Integer, Integer>());
32403254
res.get(instanceExecParam.host).setShuffleIdxToInstanceIdx(new HashMap<Integer, Integer>());

0 commit comments

Comments
 (0)