@@ -2606,29 +2606,10 @@ common::Status InferenceSession::CheckShapes(const std::string& input_output_nam
26062606 " Please fix either the inputs/outputs or the model." );
26072607 }
26082608
2609- #ifdef USE_QNN
2610- // Helper function to check whether QNN EP is used & all nodes are assigned to QNN EP,
2611- // and relax the constraint to support batch multiplier on the first dimension.
2612- // We will check whether only the Htp backend is used inside QnnModel::ExecuteGraph.
2613- auto is_valid_qnn_batch_multiplier = [this ](int64_t input_dim, int64_t expected_dim, const Graph& graph) -> bool {
2614- if (!AreAllNodesInMainGraphAssignedToOneEp (graph, kQnnExecutionProvider )) {
2615- LOGS_IF (input_dim != expected_dim, *session_logger_, WARNING) << " input batch size and expected batch size are different. "
2616- << " Batch multiplier is only supported on the QNN EP, "
2617- << " but some nodes in the graph are assigned to other EPs" ;
2618- return false ;
2619- }
2620- return expected_dim > 0 && input_dim % expected_dim == 0 ;
2621- };
2622- #endif
2623-
26242609 InlinedVector<size_t > invalid_dim_indices;
26252610 for (size_t i = 0 ; i < shape_size; ++i) {
26262611 if (expected_shape[i] < 0 ) {
26272612 continue ; // this represents a symbolic shape dimension
2628- #ifdef USE_QNN
2629- } else if (i == 0 && is_valid_qnn_batch_multiplier (input_output_shape[i], expected_shape[i], model_->MainGraph ())) {
2630- continue ; // Qnn API supports batch multiplier, but the running batch size must be divisible by the original batch size.
2631- #endif
26322613 } else if (input_output_shape[i] != expected_shape[i]) {
26332614 invalid_dim_indices.push_back (i);
26342615 }
@@ -2999,10 +2980,17 @@ Status InferenceSession::Run(const RunOptions& run_options,
29992980
30002981 // log evaluation start to trace logging provider
30012982 env.GetTelemetryProvider ().LogEvaluationStart (session_id_);
3002-
2983+ #ifdef USE_QNN
2984+ const std::string& batch_multiplier = run_options.config_options .GetConfigOrDefault (kOrtRunOptionsConfigQnnBatchMultiplier , " " );
2985+ if (batch_multiplier.empty ()) {
2986+ LOGS (*session_logger_, INFO) << " Enable QNN HTP batch mutliplier. Don't validate the inputs/outptus" ;
2987+ ORT_RETURN_IF_ERROR_SESSIONID_ (ValidateInputs (feed_names, feeds));
2988+ ORT_RETURN_IF_ERROR_SESSIONID_ (ValidateOutputs (output_names, p_fetches));
2989+ }
2990+ #else
30032991 ORT_RETURN_IF_ERROR_SESSIONID_ (ValidateInputs (feed_names, feeds));
30042992 ORT_RETURN_IF_ERROR_SESSIONID_ (ValidateOutputs (output_names, p_fetches));
3005-
2993+ # endif
30062994 // shrink certain default memory arenas if the user has requested for it
30072995 const std::string& shrink_memory_arenas =
30082996 run_options.config_options .GetConfigOrDefault (kOrtRunOptionsConfigEnableMemoryArenaShrinkage , " " );
0 commit comments