Skip to content

Commit 7c7b2a6

Browse files
authored
[AINode] More strict concurrent inference IT (#16898)
1 parent fad9bde commit 7c7b2a6

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ public class AINodeConcurrentForecastIT {
4949
private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);
5050

5151
private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
52-
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, forecast_length=>%d)";
52+
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, output_length=>%d)";
5353

5454
@BeforeClass
5555
public static void setUp() throws Exception {

integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
package org.apache.iotdb.ainode.utils;
2121

2222
import com.google.common.collect.ImmutableSet;
23+
import org.junit.Assert;
2324
import org.slf4j.Logger;
2425
import org.slf4j.LoggerFactory;
2526

@@ -34,6 +35,7 @@
3435
import java.util.Objects;
3536
import java.util.Set;
3637
import java.util.concurrent.TimeUnit;
38+
import java.util.concurrent.atomic.AtomicBoolean;
3739
import java.util.stream.Collectors;
3840
import java.util.stream.Stream;
3941

@@ -101,6 +103,7 @@ public static void errorTest(Statement statement, String sql, String errorMessag
101103
public static void concurrentInference(
102104
Statement statement, String sql, int threadCnt, int loop, int expectedOutputLength)
103105
throws InterruptedException {
106+
AtomicBoolean allPass = new AtomicBoolean(true);
104107
Thread[] threads = new Thread[threadCnt];
105108
for (int i = 0; i < threadCnt; i++) {
106109
threads[i] =
@@ -113,12 +116,23 @@ public static void concurrentInference(
113116
while (resultSet.next()) {
114117
outputCnt++;
115118
}
116-
assertEquals(expectedOutputLength, outputCnt);
119+
if (expectedOutputLength != outputCnt) {
120+
allPass.set(false);
121+
fail(
122+
"Output count mismatch for SQL: "
123+
+ sql
124+
+ ". Expected: "
125+
+ expectedOutputLength
126+
+ ", but got: "
127+
+ outputCnt);
128+
}
117129
} catch (SQLException e) {
130+
allPass.set(false);
118131
fail(e.getMessage());
119132
}
120133
}
121134
} catch (Exception e) {
135+
allPass.set(false);
122136
fail(e.getMessage());
123137
}
124138
});
@@ -130,6 +144,7 @@ public static void concurrentInference(
130144
fail("Thread timeout after 10 minutes");
131145
}
132146
}
147+
Assert.assertTrue(allPass.get());
133148
}
134149

135150
public static void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device)

iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,9 @@ def _step(self):
120120
grouped_requests = list(grouped_requests.values())
121121

122122
for requests in grouped_requests:
123-
batch_inputs = self._batcher.batch_request(requests).to(self.device)
123+
batch_inputs = self._batcher.batch_request(requests).to(
124+
"cpu"
125+
) # The input data should first load to CPU in current version
124126
if isinstance(self._inference_pipeline, ForecastPipeline):
125127
batch_output = self._inference_pipeline.forecast(
126128
batch_inputs,

0 commit comments

Comments
 (0)