Skip to content

Commit 0655a50

Browse files
committed
Relax Inputs/Outputs validation
* add RunOptions for qnn htp batch multiplier * avoid validating inputs/outputs if this option is used
1 parent e9da0b1 commit 0655a50

File tree

2 files changed

+14
-22
lines changed

2 files changed

+14
-22
lines changed

include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_con
4646
// Set QNN Lora Config File for apply Lora in QNN context binary
4747
static const char* const kOrtRunOptionsConfigQnnLoraConfig = "qnn.lora_config";
4848

49+
// Set QNN enable batch multiplier for htp backend
50+
static const char* const kOrtRunOptionsConfigQnnBatchMultiplier = "qnn.enable_htp_batch_multiplier";
51+
4952
// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true.
5053
// The value should be an integer. If the value is not set, the default value is 0 and
5154
// ORT session only captures one cuda graph before another capture is requested.

onnxruntime/core/session/inference_session.cc

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2606,30 +2606,12 @@ common::Status InferenceSession::CheckShapes(const std::string& input_output_nam
26062606
" Please fix either the inputs/outputs or the model.");
26072607
}
26082608

2609-
#ifdef USE_QNN
2610-
// Helper function to check whether QNN EP is used & all nodes are assigned to QNN EP,
2611-
// and relax the constraint to support batch multiplier on the first dimension.
2612-
// We will check whether only the Htp backend is used inside QnnModel::ExecuteGraph.
2613-
auto is_valid_qnn_batch_multiplier = [this](int64_t input_dim, int64_t expected_dim, const Graph& graph) -> bool {
2614-
if (!AreAllNodesInMainGraphAssignedToOneEp(graph, kQnnExecutionProvider)) {
2615-
LOGS_IF(input_dim != expected_dim, *session_logger_, WARNING) << "input batch size and expected batch size are different. "
2616-
<< "Batch multiplier is only supported on the QNN EP, "
2617-
<< "but some nodes in the graph are assigned to other EPs";
2618-
return false;
2619-
}
2620-
return expected_dim > 0 && input_dim % expected_dim == 0;
2621-
};
2622-
#endif
2623-
26242609
InlinedVector<size_t> invalid_dim_indices;
26252610
for (size_t i = 0; i < shape_size; ++i) {
26262611
if (expected_shape[i] < 0) {
26272612
continue; // this represents a symbolic shape dimension
2628-
#ifdef USE_QNN
2629-
} else if (i == 0 && is_valid_qnn_batch_multiplier(input_output_shape[i], expected_shape[i], model_->MainGraph())) {
2630-
continue; // Qnn API supports batch multiplier, but the running batch size must be divisible by the original batch size.
2631-
#endif
2632-
} else if (input_output_shape[i] != expected_shape[i]) {
2613+
}
2614+
if (input_output_shape[i] != expected_shape[i]) {
26332615
invalid_dim_indices.push_back(i);
26342616
}
26352617
}
@@ -2999,10 +2981,17 @@ Status InferenceSession::Run(const RunOptions& run_options,
29992981

30002982
// log evaluation start to trace logging provider
30012983
env.GetTelemetryProvider().LogEvaluationStart(session_id_);
3002-
2984+
#ifdef USE_QNN
2985+
const std::string& batch_multiplier = run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigQnnBatchMultiplier, "");
2986+
if (batch_multiplier.empty()) {
2987+
LOGS(*session_logger_, INFO) << "Enable QNN HTP batch multiplier. Don't validate the inputs/outputs";
2988+
ORT_RETURN_IF_ERROR_SESSIONID_(ValidateInputs(feed_names, feeds));
2989+
ORT_RETURN_IF_ERROR_SESSIONID_(ValidateOutputs(output_names, p_fetches));
2990+
}
2991+
#else
30032992
ORT_RETURN_IF_ERROR_SESSIONID_(ValidateInputs(feed_names, feeds));
30042993
ORT_RETURN_IF_ERROR_SESSIONID_(ValidateOutputs(output_names, p_fetches));
3005-
2994+
#endif
30062995
// shrink certain default memory arenas if the user has requested for it
30072996
const std::string& shrink_memory_arenas =
30082997
run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigEnableMemoryArenaShrinkage, "");

0 commit comments

Comments
 (0)