diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java index 64029c1e34b8..844ec1d82230 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java @@ -49,7 +49,7 @@ public class AINodeConcurrentForecastIT { private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class); private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = - "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, forecast_length=>%d)"; + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)"; @BeforeClass public static void setUp() throws Exception { diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index 0de90c42925f..1d21a4d90f01 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -20,6 +20,7 @@ package org.apache.iotdb.ainode.utils; import com.google.common.collect.ImmutableSet; +import org.junit.Assert; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -34,6 +35,7 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -101,6 +103,7 @@ public static void errorTest(Statement statement, String sql, String errorMessag public static void concurrentInference( Statement statement, String sql, int threadCnt, int loop, int expectedOutputLength) throws InterruptedException { + AtomicBoolean allPass = new AtomicBoolean(true); Thread[] threads = new Thread[threadCnt]; for (int i = 0; i < threadCnt; i++) { threads[i] = @@ -113,12 +116,23 @@ public static void concurrentInference( while (resultSet.next()) { outputCnt++; } - assertEquals(expectedOutputLength, outputCnt); + if (expectedOutputLength != outputCnt) { + allPass.set(false); + fail( + "Output count mismatch for SQL: " + + sql + + ". Expected: " + + expectedOutputLength + + ", but got: " + + outputCnt); + } } catch (SQLException e) { + allPass.set(false); fail(e.getMessage()); } } } catch (Exception e) { + allPass.set(false); fail(e.getMessage()); } }); @@ -130,6 +144,7 @@ public static void concurrentInference( fail("Thread timeout after 10 minutes"); } } + Assert.assertTrue(allPass.get()); } public static void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py index a6c415a6c848..c31bcd3d762e 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py @@ -120,7 +120,9 @@ def _step(self): grouped_requests = list(grouped_requests.values()) for requests in grouped_requests: - batch_inputs = self._batcher.batch_request(requests).to(self.device) + batch_inputs = self._batcher.batch_request(requests).to( + "cpu" + ) # The input data should first load to CPU in current version if isinstance(self._inference_pipeline, ForecastPipeline): batch_output = self._inference_pipeline.forecast( batch_inputs,