diff --git a/.github/workflows/cluster-it-1c1d1a.yml b/.github/workflows/cluster-it-1c1d1a.yml
index d4c40fa7ad889..67be8f1a5f6c8 100644
--- a/.github/workflows/cluster-it-1c1d1a.yml
+++ b/.github/workflows/cluster-it-1c1d1a.yml
@@ -41,9 +41,6 @@ jobs:
steps:
- uses: actions/checkout@v4
- - name: Build AINode
- shell: bash
- run: mvn clean package -DskipTests -P with-ainode
- name: IT Test
shell: bash
run: |
diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppDataNodeConfig.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppDataNodeConfig.java
index 357c15c785697..5e418072a7d97 100644
--- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppDataNodeConfig.java
+++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppDataNodeConfig.java
@@ -137,4 +137,10 @@ public DataNodeConfig setDataNodeMemoryProportion(String dataNodeMemoryProportio
setProperty("datanode_memory_proportion", dataNodeMemoryProportion);
return this;
}
+
+ @Override
+ public DataNodeConfig setQueryCostStatWindow(int queryCostStatWindow) {
+ setProperty("query_cost_stat_window", String.valueOf(queryCostStatWindow));
+ return this;
+ }
}
diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
index e118d6c3a98ff..34fd7e85240cf 100644
--- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
+++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java
@@ -59,7 +59,7 @@ public class AINodeWrapper extends AbstractNodeWrapper {
private static final String PROPERTIES_FILE = "iotdb-ainode.properties";
public static final String CONFIG_PATH = "conf";
public static final String SCRIPT_PATH = "sbin";
- public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/weights";
+ public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/builtin";
public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models/weights";
private void replaceAttribute(String[] keys, String[] values, String filePath) {
diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteDataNodeConfig.java b/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteDataNodeConfig.java
index 1af7cb8f613a8..bba4c964f9578 100644
--- a/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteDataNodeConfig.java
+++ b/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteDataNodeConfig.java
@@ -93,4 +93,9 @@ public DataNodeConfig setDeleteWalFilesPeriodInMs(long deleteWalFilesPeriodInMs)
public DataNodeConfig setDataNodeMemoryProportion(String dataNodeMemoryProportion) {
return this;
}
+
+ @Override
+ public DataNodeConfig setQueryCostStatWindow(int queryCostStatWindow) {
+ return this;
+ }
}
diff --git a/integration-test/src/main/java/org/apache/iotdb/itbase/env/DataNodeConfig.java b/integration-test/src/main/java/org/apache/iotdb/itbase/env/DataNodeConfig.java
index 0ae46ffc70f27..d57015b13964e 100644
--- a/integration-test/src/main/java/org/apache/iotdb/itbase/env/DataNodeConfig.java
+++ b/integration-test/src/main/java/org/apache/iotdb/itbase/env/DataNodeConfig.java
@@ -51,4 +51,6 @@ DataNodeConfig setLoadTsFileAnalyzeSchemaMemorySizeInBytes(
DataNodeConfig setDeleteWalFilesPeriodInMs(long deleteWalFilesPeriodInMs);
DataNodeConfig setDataNodeMemoryProportion(String dataNodeMemoryProportion);
+
+ DataNodeConfig setQueryCostStatWindow(int queryCostStatWindow);
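+
+ // A hedged usage sketch (accessor chain assumed from the IT framework, value illustrative):
+ //   EnvFactory.getEnv().getConfig().getDataNodeConfig().setQueryCostStatWindow(50);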
}
diff --git a/integration-test/src/main/java/org/apache/iotdb/itbase/runtime/ClusterTestStatement.java b/integration-test/src/main/java/org/apache/iotdb/itbase/runtime/ClusterTestStatement.java
index 3f96fdf1372fc..0523a3848289b 100644
--- a/integration-test/src/main/java/org/apache/iotdb/itbase/runtime/ClusterTestStatement.java
+++ b/integration-test/src/main/java/org/apache/iotdb/itbase/runtime/ClusterTestStatement.java
@@ -78,6 +78,13 @@ private void updateConfig(Statement statement, int timeout) throws SQLException
statement.setQueryTimeout(timeout);
}
+ /**
+ * Executes a SQL query on all read statements in parallel.
+ *
+ * <p>Note: For PreparedStatement EXECUTE queries, use the write connection directly instead,
+ * because PreparedStatements are session-scoped and this method may route queries to different
+ * nodes where the PreparedStatement doesn't exist.
+ */
@Override
public ResultSet executeQuery(String sql) throws SQLException {
return new ClusterTestResultSet(readStatements, readEndpoints, sql, queryTimeout);
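+ // Illustrative sketch (hypothetical SQL) of the hazard described above: a statement prepared
+ // on one session is invisible to queries this method fans out to other nodes.
+ //   writeStatement.execute("PREPARE ps FROM SELECT * FROM t WHERE id = ?"); // session-scoped
+ //   executeQuery("EXECUTE ps USING 42"); // may be routed to a node where ps was never prepared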
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
new file mode 100644
index 0000000000000..44e280eca169b
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.ainode.it;
+
+import org.apache.iotdb.ainode.utils.AINodeTestUtils;
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.AIClusterIT;
+import org.apache.iotdb.itbase.env.BaseEnv;
+
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import java.sql.Connection;
+import java.sql.ResultSet;
+import java.sql.ResultSetMetaData;
+import java.sql.SQLException;
+import java.sql.Statement;
+
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
+import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({AIClusterIT.class})
+public class AINodeCallInferenceIT {
+
+ private static final String[] WRITE_SQL_IN_TREE =
+ new String[] {
+ "CREATE DATABASE root.AI",
+ "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
+ "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
+ "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
+ "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
+ };
+
+ private static final String CALL_INFERENCE_SQL_TEMPLATE =
+ "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)";
+ private static final int DEFAULT_INPUT_LENGTH = 256;
+ private static final int DEFAULT_OUTPUT_LENGTH = 48;
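+
+ // For example, with model sundial and column s1, the template above expands to (using the
+ // default lengths):
+ //   CALL INFERENCE(sundial, "SELECT s1 FROM root.AI LIMIT 256", generateTime=true, outputLength=48)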
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ // Init 1C1D1A cluster environment
+ EnvFactory.getEnv().initClusterEnvironment(1, 1);
+ prepareData(WRITE_SQL_IN_TREE);
+ try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ for (int i = 0; i < 2880; i++) {
+ statement.execute(
+ String.format(
+ "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
+ i, (float) i, (double) i, i, i));
+ }
+ }
+ }
+
+ @AfterClass
+ public static void tearDown() throws Exception {
+ EnvFactory.getEnv().cleanClusterEnvironment();
+ }
+
+ @Test
+ public void callInferenceTest() throws SQLException {
+ for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) {
+ try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ callInferenceTest(statement, modelInfo);
+ }
+ }
+ }
+
+ public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo)
+ throws SQLException {
+ // Invoke CALL INFERENCE with the specified model on each column; every query should return rows.
+ for (int i = 0; i < 4; i++) {
+ String callInferenceSQL =
+ String.format(
+ CALL_INFERENCE_SQL_TEMPLATE,
+ modelInfo.getModelId(),
+ i,
+ DEFAULT_INPUT_LENGTH,
+ DEFAULT_OUTPUT_LENGTH);
+ try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) {
+ ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
+ checkHeader(resultSetMetaData, "Time,output");
+ int count = 0;
+ while (resultSet.next()) {
+ count++;
+ }
+ // Ensure CALL INFERENCE returns exactly the requested output length
+ Assert.assertEquals(DEFAULT_OUTPUT_LENGTH, count);
+ }
+ }
+ }
+}
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
new file mode 100644
index 0000000000000..64029c1e34b8e
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.ainode.it;
+
+import org.apache.iotdb.ainode.utils.AINodeTestUtils;
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.AIClusterIT;
+import org.apache.iotdb.itbase.env.BaseEnv;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.sql.Connection;
+import java.sql.SQLException;
+import java.sql.Statement;
+
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({AIClusterIT.class})
+public class AINodeConcurrentForecastIT {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);
+
+ private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
+ "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, forecast_length=>%d)";
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ // Init 1C1D1A cluster environment
+ EnvFactory.getEnv().initClusterEnvironment(1, 1);
+ prepareDataForTableModel();
+ }
+
+ @AfterClass
+ public static void tearDown() throws Exception {
+ EnvFactory.getEnv().cleanClusterEnvironment();
+ }
+
+ private static void prepareDataForTableModel() throws SQLException {
+ try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ statement.execute("CREATE DATABASE root");
+ statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)");
+ for (int i = 0; i < 2880; i++) {
+ statement.execute(
+ String.format(
+ "INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440)));
+ }
+ }
+ }
+
+ @Test
+ public void concurrentGPUForecastTest() throws SQLException, InterruptedException {
+ for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_LTSM_MAP.values()) {
+ concurrentGPUForecastTest(modelInfo);
+ }
+ }
+
+ public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo)
+ throws SQLException, InterruptedException {
+ final int forecastLength = 512;
+ try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ // Build the forecast request for this model
+ final String forecastSQL =
+ String.format(
+ FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), forecastLength);
+ final int threadCnt = 10;
+ final int loop = 100;
+ final String devices = "0,1";
+ statement.execute(
+ String.format("LOAD MODEL %s TO DEVICES '%s'", modelInfo.getModelId(), devices));
+ checkModelOnSpecifiedDevice(statement, modelInfo.getModelId(), devices);
+ long startTime = System.currentTimeMillis();
+ concurrentInference(statement, forecastSQL, threadCnt, loop, forecastLength);
+ long endTime = System.currentTimeMillis();
+ LOGGER.info(
+ String.format(
+ "Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms",
+ modelInfo.getModelId(), threadCnt * loop, threadCnt, loop, endTime - startTime));
+ statement.execute(
+ String.format("UNLOAD MODEL %s FROM DEVICES '%s'", modelInfo.getModelId(), devices));
+ checkModelNotOnSpecifiedDevice(statement, modelInfo.getModelId(), devices);
+ }
+ }
+}
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
deleted file mode 100644
index a08990d472fe6..0000000000000
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java
+++ /dev/null
@@ -1,187 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.ainode.it;
-
-import org.apache.iotdb.it.env.EnvFactory;
-import org.apache.iotdb.it.framework.IoTDBTestRunner;
-import org.apache.iotdb.itbase.category.AIClusterIT;
-import org.apache.iotdb.itbase.env.BaseEnv;
-
-import com.google.common.collect.ImmutableSet;
-import org.junit.AfterClass;
-import org.junit.Assert;
-import org.junit.BeforeClass;
-import org.junit.Test;
-import org.junit.experimental.categories.Category;
-import org.junit.runner.RunWith;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.sql.Connection;
-import java.sql.ResultSet;
-import java.sql.SQLException;
-import java.sql.Statement;
-import java.util.HashSet;
-import java.util.Set;
-import java.util.concurrent.TimeUnit;
-
-import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;
-
-@RunWith(IoTDBTestRunner.class)
-@Category({AIClusterIT.class})
-public class AINodeConcurrentInferenceIT {
-
- private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentInferenceIT.class);
-
- @BeforeClass
- public static void setUp() throws Exception {
- // Init 1C1D1A cluster environment
- EnvFactory.getEnv().initClusterEnvironment(1, 1);
- prepareDataForTreeModel();
- prepareDataForTableModel();
- }
-
- @AfterClass
- public static void tearDown() throws Exception {
- EnvFactory.getEnv().cleanClusterEnvironment();
- }
-
- private static void prepareDataForTreeModel() throws SQLException {
- try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
- Statement statement = connection.createStatement()) {
- statement.execute("CREATE DATABASE root.AI");
- statement.execute("CREATE TIMESERIES root.AI.s WITH DATATYPE=DOUBLE, ENCODING=RLE");
- for (int i = 0; i < 2880; i++) {
- statement.execute(
- String.format(
- "INSERT INTO root.AI(timestamp, s) VALUES(%d, %f)",
- i, Math.sin(i * Math.PI / 1440)));
- }
- }
- }
-
- private static void prepareDataForTableModel() throws SQLException {
- try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
- Statement statement = connection.createStatement()) {
- statement.execute("CREATE DATABASE root");
- statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)");
- for (int i = 0; i < 2880; i++) {
- statement.execute(
- String.format(
- "INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440)));
- }
- }
- }
-
- // @Test
- public void concurrentGPUCallInferenceTest() throws SQLException, InterruptedException {
- concurrentGPUCallInferenceTest("timer_xl");
- concurrentGPUCallInferenceTest("sundial");
- }
-
- private void concurrentGPUCallInferenceTest(String modelId)
- throws SQLException, InterruptedException {
- try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
- Statement statement = connection.createStatement()) {
- final int threadCnt = 10;
- final int loop = 100;
- final int predictLength = 512;
- final String devices = "0,1";
- statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices));
- checkModelOnSpecifiedDevice(statement, modelId, devices);
- concurrentInference(
- statement,
- String.format(
- "CALL INFERENCE(%s, 'SELECT s FROM root.AI', predict_length=%d)",
- modelId, predictLength),
- threadCnt,
- loop,
- predictLength);
- statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId));
- }
- }
-
- String forecastTableFunctionSql =
- "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d";
- String forecastUDTFSql =
- "SELECT forecast(s, 'MODEL_ID'='%s', 'PREDICT_LENGTH'='%d') FROM root.AI";
-
- @Test
- public void concurrentGPUForecastTest() throws SQLException, InterruptedException {
- concurrentGPUForecastTest("timer_xl", forecastUDTFSql);
- concurrentGPUForecastTest("sundial", forecastUDTFSql);
- concurrentGPUForecastTest("timer_xl", forecastTableFunctionSql);
- concurrentGPUForecastTest("sundial", forecastTableFunctionSql);
- }
-
- public void concurrentGPUForecastTest(String modelId, String selectSql)
- throws SQLException, InterruptedException {
- try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
- Statement statement = connection.createStatement()) {
- final int threadCnt = 10;
- final int loop = 100;
- final int predictLength = 512;
- final String devices = "0,1";
- statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices));
- checkModelOnSpecifiedDevice(statement, modelId, devices);
- long startTime = System.currentTimeMillis();
- concurrentInference(
- statement,
- String.format(selectSql, modelId, predictLength),
- threadCnt,
- loop,
- predictLength);
- long endTime = System.currentTimeMillis();
- LOGGER.info(
- String.format(
- "Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms",
- modelId, threadCnt * loop, threadCnt, loop, endTime - startTime));
- statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId));
- }
- }
-
- private void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device)
- throws SQLException, InterruptedException {
- Set<String> targetDevices = ImmutableSet.copyOf(device.split(","));
- LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices);
- for (int retry = 0; retry < 200; retry++) {
- Set<String> foundDevices = new HashSet<>();
- try (final ResultSet resultSet =
- statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
- while (resultSet.next()) {
- String deviceId = resultSet.getString("DeviceId");
- String loadedModelId = resultSet.getString("ModelId");
- int count = resultSet.getInt("Count(instances)");
- LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count);
- if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) {
- foundDevices.add(deviceId);
- LOGGER.info("Model {} is loaded to device {}", modelId, device);
- }
- }
- if (foundDevices.containsAll(targetDevices)) {
- LOGGER.info("Model {} is loaded to devices {}, start testing", modelId, targetDevices);
- return;
- }
- }
- TimeUnit.SECONDS.sleep(3);
- }
- Assert.fail("Model " + modelId + " is not loaded on device " + device);
- }
-}
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
new file mode 100644
index 0000000000000..a06656d4adace
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.ainode.it;
+
+import org.apache.iotdb.ainode.utils.AINodeTestUtils;
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.AIClusterIT;
+import org.apache.iotdb.itbase.env.BaseEnv;
+
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import java.sql.Connection;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Statement;
+
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({AIClusterIT.class})
+public class AINodeForecastIT {
+
+ private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
+ "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM db.AI) ORDER BY time)";
+
+ @BeforeClass
+ public static void setUp() throws Exception {
+ // Init 1C1D1A cluster environment
+ EnvFactory.getEnv().initClusterEnvironment(1, 1);
+ try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ statement.execute("CREATE DATABASE db");
+ statement.execute(
+ "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)");
+ for (int i = 0; i < 2880; i++) {
+ statement.execute(
+ String.format(
+ "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
+ i, (float) i, (double) i, i, i));
+ }
+ }
+ }
+
+ @AfterClass
+ public static void tearDown() throws Exception {
+ EnvFactory.getEnv().cleanClusterEnvironment();
+ }
+
+ @Test
+ public void forecastTableFunctionTest() throws SQLException {
+ for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) {
+ try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ forecastTableFunctionTest(statement, modelInfo);
+ }
+ }
+ }
+
+ public void forecastTableFunctionTest(
+ Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
+ // Invoke the FORECAST table function with the specified model on each column; every query should return rows.
+ for (int i = 0; i < 4; i++) {
+ String forecastTableFunctionSQL =
+ String.format(FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), i);
+ try (ResultSet resultSet = statement.executeQuery(forecastTableFunctionSQL)) {
+ int count = 0;
+ while (resultSet.next()) {
+ count++;
+ }
+ // Ensure the forecast returns results
+ Assert.assertTrue(count > 0);
+ }
+ }
+ }
+}
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java
deleted file mode 100644
index 70f7a1d9f9eb7..0000000000000
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java
+++ /dev/null
@@ -1,344 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.ainode.it;
-
-import org.apache.iotdb.it.env.EnvFactory;
-import org.apache.iotdb.it.framework.IoTDBTestRunner;
-import org.apache.iotdb.itbase.category.AIClusterIT;
-import org.apache.iotdb.itbase.env.BaseEnv;
-
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
-import org.junit.experimental.categories.Category;
-import org.junit.runner.RunWith;
-
-import java.sql.Connection;
-import java.sql.ResultSet;
-import java.sql.ResultSetMetaData;
-import java.sql.SQLException;
-import java.sql.Statement;
-
-import static org.apache.iotdb.ainode.utils.AINodeTestUtils.EXAMPLE_MODEL_PATH;
-import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
-import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
-import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;
-import static org.apache.iotdb.db.it.utils.TestUtils.prepareTableData;
-import static org.junit.Assert.assertEquals;
-
-@RunWith(IoTDBTestRunner.class)
-@Category({AIClusterIT.class})
-public class AINodeInferenceSQLIT {
-
- static String[] WRITE_SQL_IN_TREE =
- new String[] {
- "set configuration \"trusted_uri_pattern\"='.*'",
- "create model identity using uri \"" + EXAMPLE_MODEL_PATH + "\"",
- "CREATE DATABASE root.AI",
- "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
- "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
- "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
- "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
- };
-
- static String[] WRITE_SQL_IN_TABLE =
- new String[] {
- "CREATE DATABASE root",
- "CREATE TABLE root.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)",
- };
-
- @BeforeClass
- public static void setUp() throws Exception {
- // Init 1C1D1A cluster environment
- EnvFactory.getEnv().initClusterEnvironment(1, 1);
- prepareData(WRITE_SQL_IN_TREE);
- prepareTableData(WRITE_SQL_IN_TABLE);
- try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
- Statement statement = connection.createStatement()) {
- for (int i = 0; i < 2880; i++) {
- statement.execute(
- String.format(
- "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
- i, (float) i, (double) i, i, i));
- }
- }
- try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
- Statement statement = connection.createStatement()) {
- for (int i = 0; i < 2880; i++) {
- statement.execute(
- String.format(
- "INSERT INTO root.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
- i, (float) i, (double) i, i, i));
- }
- }
- }
-
- @AfterClass
- public static void tearDown() throws Exception {
- EnvFactory.getEnv().cleanClusterEnvironment();
- }
-
- // @Test
- public void callInferenceTestInTree() throws SQLException {
- try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
- Statement statement = connection.createStatement()) {
- callInferenceTest(statement);
- }
- }
-
- // TODO: Enable this test after the call inference is supported by the table model
- // @Test
- public void callInferenceTestInTable() throws SQLException {
- try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
- Statement statement = connection.createStatement()) {
- callInferenceTest(statement);
- }
- }
-
- public void callInferenceTest(Statement statement) throws SQLException {
- // SQL0: Invoke timer-sundial and timer-xl to inference, the result should success
- try (ResultSet resultSet =
- statement.executeQuery(
- "CALL INFERENCE(sundial, \"select s1 from root.AI\", generateTime=true, predict_length=720)")) {
- ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "Time,output0");
- int count = 0;
- while (resultSet.next()) {
- count++;
- }
- assertEquals(720, count);
- }
- try (ResultSet resultSet =
- statement.executeQuery(
- "CALL INFERENCE(timer_xl, \"select s2 from root.AI\", generateTime=true, predict_length=256)")) {
- ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "Time,output0");
- int count = 0;
- while (resultSet.next()) {
- count++;
- }
- assertEquals(256, count);
- }
- // SQL1: user-defined model inferences multi-columns with generateTime=true
- String sql1 =
- "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI\", generateTime=true)";
- // SQL2: user-defined model inferences multi-columns with generateTime=false
- String sql2 =
- "CALL INFERENCE(identity, \"select s2,s0,s3,s1 from root.AI\", generateTime=false)";
- // SQL3: built-in model inferences single column with given predict_length and multi-outputs
- String sql3 =
- "CALL INFERENCE(naive_forecaster, \"select s0 from root.AI\", predict_length=3, generateTime=true)";
- // SQL4: built-in model inferences single column with given predict_length
- String sql4 =
- "CALL INFERENCE(holtwinters, \"select s0 from root.AI\", predict_length=6, generateTime=true)";
- // TODO: enable following tests after refactor the CALL INFERENCE
-
- // try (ResultSet resultSet = statement.executeQuery(sql1)) {
- // ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- // checkHeader(resultSetMetaData, "Time,output0,output1,output2,output3");
- // int count = 0;
- // while (resultSet.next()) {
- // float s0 = resultSet.getFloat(2);
- // float s1 = resultSet.getFloat(3);
- // float s2 = resultSet.getFloat(4);
- // float s3 = resultSet.getFloat(5);
- //
- // assertEquals(s0, count + 1.0, 0.0001);
- // assertEquals(s1, count + 2.0, 0.0001);
- // assertEquals(s2, count + 3.0, 0.0001);
- // assertEquals(s3, count + 4.0, 0.0001);
- // count++;
- // }
- // assertEquals(7, count);
- // }
- //
- // try (ResultSet resultSet = statement.executeQuery(sql2)) {
- // ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- // checkHeader(resultSetMetaData, "output0,output1,output2");
- // int count = 0;
- // while (resultSet.next()) {
- // float s2 = resultSet.getFloat(1);
- // float s0 = resultSet.getFloat(2);
- // float s3 = resultSet.getFloat(3);
- // float s1 = resultSet.getFloat(4);
- //
- // assertEquals(s0, count + 1.0, 0.0001);
- // assertEquals(s1, count + 2.0, 0.0001);
- // assertEquals(s2, count + 3.0, 0.0001);
- // assertEquals(s3, count + 4.0, 0.0001);
- // count++;
- // }
- // assertEquals(7, count);
- // }
-
- // try (ResultSet resultSet = statement.executeQuery(sql3)) {
- // ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- // checkHeader(resultSetMetaData, "Time,output0,output1,output2");
- // int count = 0;
- // while (resultSet.next()) {
- // count++;
- // }
- // assertEquals(3, count);
- // }
-
- // try (ResultSet resultSet = statement.executeQuery(sql4)) {
- // ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- // checkHeader(resultSetMetaData, "Time,output0");
- // int count = 0;
- // while (resultSet.next()) {
- // count++;
- // }
- // assertEquals(6, count);
- // }
- }
-
- // @Test
- public void errorCallInferenceTestInTree() throws SQLException {
- try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
- Statement statement = connection.createStatement()) {
- errorCallInferenceTest(statement);
- }
- }
-
- // TODO: Enable this test after the call inference is supported by the table model
- // @Test
- public void errorCallInferenceTestInTable() throws SQLException {
- try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
- Statement statement = connection.createStatement()) {
- errorCallInferenceTest(statement);
- }
- }
-
- public void errorCallInferenceTest(Statement statement) {
- String sql = "CALL INFERENCE(notFound404, \"select s0,s1,s2 from root.AI\", window=head(5))";
- errorTest(statement, sql, "1505: model [notFound404] has not been created.");
- sql = "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI\", window=head(2))";
- // TODO: enable following tests after refactor the CALL INFERENCE
- // errorTest(statement, sql, "701: Window output 2 is not equal to input size of model 7");
- sql = "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI limit 5\")";
- // errorTest(
- // statement,
- // sql,
- // "301: The number of rows 5 in the input data does not match the model input 7. Try to
- // use LIMIT in SQL or WINDOW in CALL INFERENCE");
- sql = "CREATE MODEL 中文 USING URI \"" + EXAMPLE_MODEL_PATH + "\"";
- errorTest(statement, sql, "701: ModelId can only contain letters, numbers, and underscores");
- }
-
- @Test
- public void selectForecastTestInTable() throws SQLException {
- try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
- Statement statement = connection.createStatement()) {
- // SQL0: Invoke timer-sundial and timer-xl to forecast, the result should success
- try (ResultSet resultSet =
- statement.executeQuery(
- "SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s1 FROM root.AI) ORDER BY time, output_length=>720)")) {
- ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "time,s1");
- int count = 0;
- while (resultSet.next()) {
- count++;
- }
- assertEquals(720, count);
- }
- try (ResultSet resultSet =
- statement.executeQuery(
- "SELECT * FROM FORECAST(model_id=>'timer_xl', input=>(SELECT time,s2 FROM root.AI) ORDER BY time, output_length=>256)")) {
- ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "time,s2");
- int count = 0;
- while (resultSet.next()) {
- count++;
- }
- assertEquals(256, count);
- }
- // SQL1: user-defined model inferences multi-columns with generateTime=true
- String sql1 =
- "SELECT * FROM FORECAST(model_id=>'identity', input=>(SELECT time,s0,s1,s2,s3 FROM root.AI) ORDER BY time, output_length=>7)";
- // SQL2: user-defined model inferences multi-columns with generateTime=false
- String sql2 =
- "SELECT * FROM FORECAST(model_id=>'identity', input=>(SELECT time,s2,s0,s3,s1 FROM root.AI) ORDER BY time, output_length=>7)";
- // SQL3: built-in model inferences single column with given predict_length and multi-outputs
- String sql3 =
- "SELECT * FROM FORECAST(model_id=>'naive_forecaster', input=>(SELECT time,s0 FROM root.AI) ORDER BY time, output_length=>3)";
- // SQL4: built-in model inferences single column with given predict_length
- String sql4 =
- "SELECT * FROM FORECAST(model_id=>'holtwinters', input=>(SELECT time,s0 FROM root.AI) ORDER BY time, output_length=>6)";
- // TODO: enable following tests after refactor the FORECAST
- // try (ResultSet resultSet = statement.executeQuery(sql1)) {
- // ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- // checkHeader(resultSetMetaData, "time,s0,s1,s2,s3");
- // int count = 0;
- // while (resultSet.next()) {
- // float s0 = resultSet.getFloat(2);
- // float s1 = resultSet.getFloat(3);
- // float s2 = resultSet.getFloat(4);
- // float s3 = resultSet.getFloat(5);
- //
- // assertEquals(s0, count + 1.0, 0.0001);
- // assertEquals(s1, count + 2.0, 0.0001);
- // assertEquals(s2, count + 3.0, 0.0001);
- // assertEquals(s3, count + 4.0, 0.0001);
- // count++;
- // }
- // assertEquals(7, count);
- // }
- //
- // try (ResultSet resultSet = statement.executeQuery(sql2)) {
- // ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- // checkHeader(resultSetMetaData, "time,s2,s0,s3,s1");
- // int count = 0;
- // while (resultSet.next()) {
- // float s2 = resultSet.getFloat(1);
- // float s0 = resultSet.getFloat(2);
- // float s3 = resultSet.getFloat(3);
- // float s1 = resultSet.getFloat(4);
- //
- // assertEquals(s0, count + 1.0, 0.0001);
- // assertEquals(s1, count + 2.0, 0.0001);
- // assertEquals(s2, count + 3.0, 0.0001);
- // assertEquals(s3, count + 4.0, 0.0001);
- // count++;
- // }
- // assertEquals(7, count);
- // }
-
- // try (ResultSet resultSet = statement.executeQuery(sql3)) {
- // ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- // checkHeader(resultSetMetaData, "time,s0,s1,s2");
- // int count = 0;
- // while (resultSet.next()) {
- // count++;
- // }
- // assertEquals(3, count);
- // }
-
- // try (ResultSet resultSet = statement.executeQuery(sql4)) {
- // ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- // checkHeader(resultSetMetaData, "time,s0");
- // int count = 0;
- // while (resultSet.next()) {
- // count++;
- // }
- // assertEquals(6, count);
- // }
- }
- }
-}
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
index 93351c0178526..2ae1b860cd230 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java
@@ -35,14 +35,14 @@
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
-import java.util.concurrent.TimeUnit;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
public class AINodeInstanceManagementIT {
- private static final int WAITING_TIME_SEC = 30;
private static final Set<String> TARGET_DEVICES = new HashSet<>(Arrays.asList("cpu", "0", "1"));
@BeforeClass
@@ -85,52 +85,18 @@ private void basicManagementTest(Statement statement) throws SQLException, Inter
}
// Load sundial to each device
- statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\"");
- TimeUnit.SECONDS.sleep(WAITING_TIME_SEC);
- try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS 0")) {
- ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
- while (resultSet.next()) {
- Assert.assertEquals("0", resultSet.getString("DeviceID"));
- Assert.assertEquals("Timer-Sundial", resultSet.getString("ModelType"));
- Assert.assertTrue(resultSet.getInt("Count(instances)") > 1);
- }
- }
- try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
- ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
- final Set<String> resultDevices = new HashSet<>();
- while (resultSet.next()) {
- resultDevices.add(resultSet.getString("DeviceID"));
- }
- Assert.assertEquals(TARGET_DEVICES, resultDevices);
- }
+ statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES));
+ checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
// Load timer_xl to each device
- statement.execute("LOAD MODEL timer_xl TO DEVICES \"cpu,0,1\"");
- TimeUnit.SECONDS.sleep(WAITING_TIME_SEC);
- try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
- ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
- final Set<String> resultDevices = new HashSet<>();
- while (resultSet.next()) {
- if (resultSet.getString("ModelType").equals("Timer-XL")) {
- resultDevices.add(resultSet.getString("DeviceID"));
- }
- Assert.assertTrue(resultSet.getInt("Count(instances)") > 1);
- }
- Assert.assertEquals(TARGET_DEVICES, resultDevices);
- }
+ statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES));
+ checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString());
// Clean every device
- statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
- statement.execute("UNLOAD MODEL timer_xl FROM DEVICES \"cpu,0,1\"");
- TimeUnit.SECONDS.sleep(WAITING_TIME_SEC);
- try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
- ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
- Assert.assertFalse(resultSet.next());
- }
+ statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES));
+ statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES));
+ checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString());
+ checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
}
private static final int LOOP_CNT = 10;
@@ -141,23 +107,9 @@ public void repeatLoadAndUnloadTest() throws SQLException, InterruptedException
Statement statement = connection.createStatement()) {
for (int i = 0; i < LOOP_CNT; i++) {
statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\"");
- TimeUnit.SECONDS.sleep(WAITING_TIME_SEC);
- try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
- ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
- final Set<String> resultDevices = new HashSet<>();
- while (resultSet.next()) {
- resultDevices.add(resultSet.getString("DeviceID"));
- }
- Assert.assertEquals(TARGET_DEVICES, resultDevices);
- }
+ checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString());
statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
- TimeUnit.SECONDS.sleep(WAITING_TIME_SEC);
- try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
- ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
- Assert.assertFalse(resultSet.next());
- }
+ checkModelNotOnSpecifiedDevice(statement, "sundial", String.join(",", TARGET_DEVICES));
}
}
}
@@ -170,12 +122,7 @@ public void concurrentLoadAndUnloadTest() throws SQLException, InterruptedExcept
statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\"");
statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
}
- TimeUnit.SECONDS.sleep(WAITING_TIME_SEC * LOOP_CNT);
- try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
- ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
- checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
- Assert.assertFalse(resultSet.next());
- }
+ checkModelNotOnSpecifiedDevice(statement, "sundial", String.join(",", TARGET_DEVICES));
}
}
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
index 2a1461e4a15b6..b92b80aecf321 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
@@ -19,6 +19,7 @@
package org.apache.iotdb.ainode.it;
+import org.apache.iotdb.ainode.utils.AINodeTestUtils;
import org.apache.iotdb.ainode.utils.AINodeTestUtils.FakeModelInfo;
import org.apache.iotdb.it.env.EnvFactory;
import org.apache.iotdb.it.framework.IoTDBTestRunner;
@@ -36,13 +37,8 @@
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
-import java.util.AbstractMap;
-import java.util.Map;
import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-import static org.apache.iotdb.ainode.utils.AINodeTestUtils.EXAMPLE_MODEL_PATH;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
import static org.junit.Assert.assertEquals;
@@ -54,36 +50,6 @@
@Category({AIClusterIT.class})
public class AINodeModelManageIT {
- private static final Map<String, FakeModelInfo> BUILT_IN_MODEL_MAP =
- Stream.of(
- new AbstractMap.SimpleEntry<>(
- "arima", new FakeModelInfo("arima", "Arima", "BUILT-IN", "ACTIVE")),
- new AbstractMap.SimpleEntry<>(
- "holtwinters",
- new FakeModelInfo("holtwinters", "HoltWinters", "BUILT-IN", "ACTIVE")),
- new AbstractMap.SimpleEntry<>(
- "exponential_smoothing",
- new FakeModelInfo(
- "exponential_smoothing", "ExponentialSmoothing", "BUILT-IN", "ACTIVE")),
- new AbstractMap.SimpleEntry<>(
- "naive_forecaster",
- new FakeModelInfo("naive_forecaster", "NaiveForecaster", "BUILT-IN", "ACTIVE")),
- new AbstractMap.SimpleEntry<>(
- "stl_forecaster",
- new FakeModelInfo("stl_forecaster", "StlForecaster", "BUILT-IN", "ACTIVE")),
- new AbstractMap.SimpleEntry<>(
- "gaussian_hmm",
- new FakeModelInfo("gaussian_hmm", "GaussianHmm", "BUILT-IN", "ACTIVE")),
- new AbstractMap.SimpleEntry<>(
- "gmm_hmm", new FakeModelInfo("gmm_hmm", "GmmHmm", "BUILT-IN", "ACTIVE")),
- new AbstractMap.SimpleEntry<>(
- "stray", new FakeModelInfo("stray", "Stray", "BUILT-IN", "ACTIVE")),
- new AbstractMap.SimpleEntry<>(
- "sundial", new FakeModelInfo("sundial", "Timer-Sundial", "BUILT-IN", "ACTIVE")),
- new AbstractMap.SimpleEntry<>(
- "timer_xl", new FakeModelInfo("timer_xl", "Timer-XL", "BUILT-IN", "ACTIVE")))
- .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
-
@BeforeClass
public static void setUp() throws Exception {
// Init 1C1D1A cluster environment
@@ -95,7 +61,7 @@ public static void tearDown() throws Exception {
EnvFactory.getEnv().cleanClusterEnvironment();
}
- @Test
+ // @Test
public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
@@ -103,7 +69,7 @@ public void userDefinedModelManagementTestInTree() throws SQLException, Interrup
}
}
- @Test
+ // @Test
public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
@@ -114,8 +80,7 @@ public void userDefinedModelManagementTestInTable() throws SQLException, Interru
private void userDefinedModelManagementTest(Statement statement)
throws SQLException, InterruptedException {
final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'";
- final String registerSql =
- "create model operationTest using uri \"" + EXAMPLE_MODEL_PATH + "\"";
+ final String registerSql = "create model operationTest using uri \"" + "\"";
final String showSql = "SHOW MODELS operationTest";
final String dropSql = "DROP MODEL operationTest";
@@ -166,7 +131,7 @@ private void userDefinedModelManagementTest(Statement statement)
public void dropBuiltInModelErrorTestInTree() throws SQLException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
- errorTest(statement, "drop model sundial", "1501: Built-in model sundial can't be removed");
+ errorTest(statement, "drop model sundial", "1510: Cannot delete built-in model: sundial");
}
}
@@ -174,7 +139,7 @@ public void dropBuiltInModelErrorTestInTree() throws SQLException {
public void dropBuiltInModelErrorTestInTable() throws SQLException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
- errorTest(statement, "drop model sundial", "1501: Built-in model sundial can't be removed");
+ errorTest(statement, "drop model sundial", "1510: Cannot delete built-in model: sundial");
}
}
@@ -208,10 +173,10 @@ private void showBuiltInModelTest(Statement statement) throws SQLException {
resultSet.getString(2),
resultSet.getString(3),
resultSet.getString(4));
- assertTrue(BUILT_IN_MODEL_MAP.containsKey(modelInfo.getModelId()));
- assertEquals(BUILT_IN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo);
+ assertTrue(AINodeTestUtils.BUILTIN_MODEL_MAP.containsKey(modelInfo.getModelId()));
+ assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo);
}
}
- assertEquals(BUILT_IN_MODEL_MAP.size(), built_in_model_count);
+ assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.size(), built_in_model_count);
}
}
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
index cbb0b03b22997..0de90c42925fe 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
@@ -19,30 +19,68 @@
package org.apache.iotdb.ainode.utils;
-import java.io.File;
+import com.google.common.collect.ImmutableSet;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
+import java.util.AbstractMap;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
import java.util.Objects;
+import java.util.Set;
import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
public class AINodeTestUtils {
- public static final String EXAMPLE_MODEL_PATH =
- "file://"
- + System.getProperty("user.dir")
- + File.separator
- + "src"
- + File.separator
- + "test"
- + File.separator
- + "resources"
- + File.separator
- + "ainode-example";
+ public static final Map<String, FakeModelInfo> BUILTIN_LTSM_MAP =
+ Stream.of(
+ new AbstractMap.SimpleEntry<>(
+ "sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")),
+ new AbstractMap.SimpleEntry<>(
+ "timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")))
+ .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+
+ public static final Map<String, FakeModelInfo> BUILTIN_MODEL_MAP;
+
+ static {
+ Map<String, FakeModelInfo> tmp =
+ Stream.of(
+ new AbstractMap.SimpleEntry<>(
+ "arima", new FakeModelInfo("arima", "sktime", "builtin", "active")),
+ new AbstractMap.SimpleEntry<>(
+ "holtwinters", new FakeModelInfo("holtwinters", "sktime", "builtin", "active")),
+ new AbstractMap.SimpleEntry<>(
+ "exponential_smoothing",
+ new FakeModelInfo("exponential_smoothing", "sktime", "builtin", "active")),
+ new AbstractMap.SimpleEntry<>(
+ "naive_forecaster",
+ new FakeModelInfo("naive_forecaster", "sktime", "builtin", "active")),
+ new AbstractMap.SimpleEntry<>(
+ "stl_forecaster",
+ new FakeModelInfo("stl_forecaster", "sktime", "builtin", "active")),
+ new AbstractMap.SimpleEntry<>(
+ "gaussian_hmm",
+ new FakeModelInfo("gaussian_hmm", "sktime", "builtin", "active")),
+ new AbstractMap.SimpleEntry<>(
+ "gmm_hmm", new FakeModelInfo("gmm_hmm", "sktime", "builtin", "active")),
+ new AbstractMap.SimpleEntry<>(
+ "stray", new FakeModelInfo("stray", "sktime", "builtin", "active")))
+ .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+ tmp.putAll(BUILTIN_LTSM_MAP);
+ BUILTIN_MODEL_MAP = Collections.unmodifiableMap(tmp);
+ }
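+
+ // For example, BUILTIN_MODEL_MAP.get("sundial") yields the entry registered above:
+ //   new FakeModelInfo("sundial", "sundial", "builtin", "active")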
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(AINodeTestUtils.class);
public static void checkHeader(ResultSetMetaData resultSetMetaData, String title)
throws SQLException {
@@ -94,6 +132,63 @@ public static void concurrentInference(
}
}
+ public static void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device)
+ throws SQLException, InterruptedException {
+ Set<String> targetDevices = ImmutableSet.copyOf(device.split(","));
+ LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices);
+ for (int retry = 0; retry < 200; retry++) {
+ Set<String> foundDevices = new HashSet<>();
+ try (final ResultSet resultSet =
+ statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
+ while (resultSet.next()) {
+ String deviceId = resultSet.getString("DeviceId");
+ String loadedModelId = resultSet.getString("ModelId");
+ int count = resultSet.getInt("Count(instances)");
+ LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count);
+ if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) {
+ foundDevices.add(deviceId);
+ LOGGER.info("Model {} is loaded to device {}", modelId, device);
+ }
+ }
+ if (foundDevices.containsAll(targetDevices)) {
+ LOGGER.info("Model {} is loaded to devices {}, start testing", modelId, targetDevices);
+ return;
+ }
+ }
+ TimeUnit.SECONDS.sleep(3);
+ }
+ fail("Model " + modelId + " is not loaded on device " + device);
+ }
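+
+ // Usage sketch: poll until sundial is visible on GPU devices 0 and 1, as the ITs above do:
+ //   checkModelOnSpecifiedDevice(statement, "sundial", "0,1");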
+
+ public static void checkModelNotOnSpecifiedDevice(
+ Statement statement, String modelId, String device)
+ throws SQLException, InterruptedException {
+ Set<String> targetDevices = ImmutableSet.copyOf(device.split(","));
+ LOGGER.info("Checking model: {} not on target devices: {}", modelId, targetDevices);
+ for (int retry = 0; retry < 50; retry++) {
+ Set<String> foundDevices = new HashSet<>();
+ try (final ResultSet resultSet =
+ statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
+ while (resultSet.next()) {
+ String deviceId = resultSet.getString("DeviceId");
+ String loadedModelId = resultSet.getString("ModelId");
+ int count = resultSet.getInt("Count(instances)");
+ LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count);
+ if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) {
+ foundDevices.add(deviceId);
+ LOGGER.info("Model {} is loaded to device {}", modelId, device);
+ }
+ }
+ if (foundDevices.isEmpty()) {
+ LOGGER.info("Model {} is unloaded from devices {}.", modelId, targetDevices);
+ return;
+ }
+ }
+ TimeUnit.SECONDS.sleep(3);
+ }
+ fail("Model " + modelId + " is still loaded on device " + device);
+ }
+
public static class FakeModelInfo {
private final String modelId;
diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPreparedStatementIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPreparedStatementIT.java
new file mode 100644
index 0000000000000..f06d46201aff2
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPreparedStatementIT.java
@@ -0,0 +1,385 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.relational.it.db.it;
+
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.TableClusterIT;
+import org.apache.iotdb.itbase.category.TableLocalStandaloneIT;
+import org.apache.iotdb.itbase.runtime.ClusterTestConnection;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import java.sql.Connection;
+import java.sql.ResultSet;
+import java.sql.ResultSetMetaData;
+import java.sql.SQLException;
+import java.sql.Statement;
+
+import static org.apache.iotdb.db.it.utils.TestUtils.tableResultSetEqualTest;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({TableLocalStandaloneIT.class, TableClusterIT.class})
+public class IoTDBPreparedStatementIT {
+ private static final String DATABASE_NAME = "test";
+ private static final String[] sqls =
+ new String[] {
+ "CREATE DATABASE " + DATABASE_NAME,
+ "USE " + DATABASE_NAME,
+ "CREATE TABLE test_table(id INT64 FIELD, name STRING FIELD, value DOUBLE FIELD)",
+ "INSERT INTO test_table VALUES (2025-01-01T00:00:00, 1, 'Alice', 100.5)",
+ "INSERT INTO test_table VALUES (2025-01-01T00:01:00, 2, 'Bob', 200.3)",
+ "INSERT INTO test_table VALUES (2025-01-01T00:02:00, 3, 'Charlie', 300.7)",
+ "INSERT INTO test_table VALUES (2025-01-01T00:03:00, 4, 'David', 400.2)",
+ "INSERT INTO test_table VALUES (2025-01-01T00:04:00, 5, 'Eve', 500.9)",
+ };
+
+ protected static void insertData() {
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ for (String sql : sqls) {
+ statement.execute(sql);
+ }
+ } catch (Exception e) {
+ fail("insertData failed: " + e.getMessage());
+ }
+ }
+
+ @BeforeClass
+ public static void setUp() {
+ EnvFactory.getEnv().initClusterEnvironment();
+ insertData();
+ }
+
+ @AfterClass
+ public static void tearDown() {
+ EnvFactory.getEnv().cleanClusterEnvironment();
+ }
+
+ /**
+ * Execute a prepared statement query and verify the result. For PreparedStatement EXECUTE
+ * queries, use the write connection directly instead of tableResultSetEqualTest, because
+ * PreparedStatements are session-scoped and tableResultSetEqualTest may route queries to
+ * different nodes where the PreparedStatement doesn't exist.
+ */
+ private static void executePreparedStatementAndVerify(
+ Connection connection,
+ Statement statement,
+ String executeSql,
+ String[] expectedHeader,
+ String[] expectedRetArray)
+ throws SQLException {
+ // PreparedStatements are session-scoped: in a cluster test the EXECUTE must run in the
+ // same session that issued the PREPARE, so route it through the write connection
+ if (connection instanceof ClusterTestConnection) {
+ try (Statement writeStatement =
+ ((ClusterTestConnection) connection)
+ .writeConnection
+ .getUnderlyingConnection()
+ .createStatement()) {
+ verifyQueryResult(writeStatement, executeSql, expectedHeader, expectedRetArray);
+ }
+ } else {
+ verifyQueryResult(statement, executeSql, expectedHeader, expectedRetArray);
+ }
+ }
+
+ private static void verifyQueryResult(
+ Statement statement, String sql, String[] expectedHeader, String[] expectedRetArray)
+ throws SQLException {
+ try (ResultSet resultSet = statement.executeQuery(sql)) {
+ ResultSetMetaData metaData = resultSet.getMetaData();
+
+ // Verify header
+ assertEquals(expectedHeader.length, metaData.getColumnCount());
+ for (int i = 1; i <= metaData.getColumnCount(); i++) {
+ assertEquals(expectedHeader[i - 1], metaData.getColumnName(i));
+ }
+
+ // Verify data
+ int cnt = 0;
+ while (resultSet.next()) {
+ StringBuilder builder = new StringBuilder();
+ for (int i = 1; i <= expectedHeader.length; i++) {
+ builder.append(resultSet.getString(i)).append(",");
+ }
+ assertEquals(expectedRetArray[cnt], builder.toString());
+ cnt++;
+ }
+ assertEquals(expectedRetArray.length, cnt);
+ }
+ }
+
+ @Test
+ public void testPrepareAndExecute() {
+ String[] expectedHeader = new String[] {"time", "id", "name", "value"};
+ String[] retArray = new String[] {"2025-01-01T00:01:00.000Z,2,Bob,200.3,"};
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ statement.execute("USE " + DATABASE_NAME);
+ // Prepare a statement
+ statement.execute("PREPARE stmt1 FROM SELECT * FROM test_table WHERE id = ?");
+ // Execute with parameter using write connection directly
+ executePreparedStatementAndVerify(
+ connection, statement, "EXECUTE stmt1 USING 2", expectedHeader, retArray);
+ // Deallocate
+ statement.execute("DEALLOCATE PREPARE stmt1");
+ } catch (SQLException e) {
+ fail("testPrepareAndExecute failed: " + e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPrepareAndExecuteMultipleTimes() {
+ String[] expectedHeader = new String[] {"time", "id", "name", "value"};
+ String[] retArray1 = new String[] {"2025-01-01T00:00:00.000Z,1,Alice,100.5,"};
+ String[] retArray2 = new String[] {"2025-01-01T00:02:00.000Z,3,Charlie,300.7,"};
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ statement.execute("USE " + DATABASE_NAME);
+ // Prepare a statement
+ statement.execute("PREPARE stmt2 FROM SELECT * FROM test_table WHERE id = ?");
+ // Execute multiple times with different parameters using write connection directly
+ executePreparedStatementAndVerify(
+ connection, statement, "EXECUTE stmt2 USING 1", expectedHeader, retArray1);
+ executePreparedStatementAndVerify(
+ connection, statement, "EXECUTE stmt2 USING 3", expectedHeader, retArray2);
+ // Deallocate
+ statement.execute("DEALLOCATE PREPARE stmt2");
+ } catch (SQLException e) {
+ fail("testPrepareAndExecuteMultipleTimes failed: " + e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPrepareWithMultipleParameters() {
+ String[] expectedHeader = new String[] {"time", "id", "name", "value"};
+ String[] retArray = new String[] {"2025-01-01T00:01:00.000Z,2,Bob,200.3,"};
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ statement.execute("USE " + DATABASE_NAME);
+ // Prepare a statement with multiple parameters
+ statement.execute("PREPARE stmt3 FROM SELECT * FROM test_table WHERE id = ? AND value > ?");
+ // Execute with multiple parameters using write connection directly
+ executePreparedStatementAndVerify(
+ connection, statement, "EXECUTE stmt3 USING 2, 150.0", expectedHeader, retArray);
+ // Deallocate
+ statement.execute("DEALLOCATE PREPARE stmt3");
+ } catch (SQLException e) {
+ fail("testPrepareWithMultipleParameters failed: " + e.getMessage());
+ }
+ }
+
+ @Test
+ public void testExecuteImmediate() {
+ String[] expectedHeader = new String[] {"time", "id", "name", "value"};
+ String[] retArray = new String[] {"2025-01-01T00:03:00.000Z,4,David,400.2,"};
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ statement.execute("USE " + DATABASE_NAME);
+ // Execute immediate with SQL string and parameters
+ tableResultSetEqualTest(
+ "EXECUTE IMMEDIATE 'SELECT * FROM test_table WHERE id = ?' USING 4",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ } catch (SQLException e) {
+ fail("testExecuteImmediate failed: " + e.getMessage());
+ }
+ }
+
+ @Test
+ public void testExecuteImmediateWithoutParameters() {
+ String[] expectedHeader = new String[] {"_col0"};
+ String[] retArray = new String[] {"5,"};
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ statement.execute("USE " + DATABASE_NAME);
+ // Execute immediate without parameters
+ tableResultSetEqualTest(
+ "EXECUTE IMMEDIATE 'SELECT COUNT(*) FROM test_table'",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ } catch (SQLException e) {
+ fail("testExecuteImmediateWithoutParameters failed: " + e.getMessage());
+ }
+ }
+
+ @Test
+ public void testExecuteImmediateWithMultipleParameters() {
+ String[] expectedHeader = new String[] {"time", "id", "name", "value"};
+ String[] retArray = new String[] {"2025-01-01T00:04:00.000Z,5,Eve,500.9,"};
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ statement.execute("USE " + DATABASE_NAME);
+ // Execute immediate with multiple parameters
+ tableResultSetEqualTest(
+ "EXECUTE IMMEDIATE 'SELECT * FROM test_table WHERE id = ? AND value > ?' USING 5, 450.0",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ } catch (SQLException e) {
+ fail("testExecuteImmediateWithMultipleParameters failed: " + e.getMessage());
+ }
+ }
+
+ @Test
+ public void testDeallocateNonExistentStatement() {
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ statement.execute("USE " + DATABASE_NAME);
+ // Try to deallocate a non-existent statement
+ SQLException exception =
+ assertThrows(
+ SQLException.class, () -> statement.execute("DEALLOCATE PREPARE non_existent_stmt"));
+ assertTrue(
+ exception.getMessage().contains("does not exist")
+ || exception.getMessage().contains("Prepared statement"));
+ } catch (SQLException e) {
+ fail("testDeallocateNonExistentStatement failed: " + e.getMessage());
+ }
+ }
+
+ @Test
+ public void testExecuteNonExistentStatement() {
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ statement.execute("USE " + DATABASE_NAME);
+ // Try to execute a non-existent statement
+ SQLException exception =
+ assertThrows(
+ SQLException.class, () -> statement.execute("EXECUTE non_existent_stmt USING 1"));
+ assertTrue(
+ exception.getMessage().contains("does not exist")
+ || exception.getMessage().contains("Prepared statement"));
+ } catch (SQLException e) {
+ fail("testExecuteNonExistentStatement failed: " + e.getMessage());
+ }
+ }
+
+ @Test
+ public void testMultiplePreparedStatements() {
+ String[] expectedHeader1 = new String[] {"time", "id", "name", "value"};
+ String[] retArray1 = new String[] {"2025-01-01T00:00:00.000Z,1,Alice,100.5,"};
+ String[] expectedHeader2 = new String[] {"_col0"};
+ String[] retArray2 = new String[] {"4,"};
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ statement.execute("USE " + DATABASE_NAME);
+ // Prepare multiple statements
+ statement.execute("PREPARE stmt4 FROM SELECT * FROM test_table WHERE id = ?");
+ statement.execute("PREPARE stmt5 FROM SELECT COUNT(*) FROM test_table WHERE value > ?");
+ // Execute both statements using write connection directly
+ executePreparedStatementAndVerify(
+ connection, statement, "EXECUTE stmt4 USING 1", expectedHeader1, retArray1);
+ executePreparedStatementAndVerify(
+ connection, statement, "EXECUTE stmt5 USING 200.0", expectedHeader2, retArray2);
+ // Deallocate both
+ statement.execute("DEALLOCATE PREPARE stmt4");
+ statement.execute("DEALLOCATE PREPARE stmt5");
+ } catch (SQLException e) {
+ fail("testMultiplePreparedStatements failed: " + e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPrepareDuplicateName() {
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ statement.execute("USE " + DATABASE_NAME);
+ // Prepare a statement
+ statement.execute("PREPARE stmt6 FROM SELECT * FROM test_table WHERE id = ?");
+ // Try to prepare another statement with the same name
+ SQLException exception =
+ assertThrows(
+ SQLException.class,
+ () -> statement.execute("PREPARE stmt6 FROM SELECT * FROM test_table WHERE id = ?"));
+ assertTrue(
+ exception.getMessage().contains("already exists")
+ || exception.getMessage().contains("Prepared statement"));
+ // Cleanup
+ statement.execute("DEALLOCATE PREPARE stmt6");
+ } catch (SQLException e) {
+ fail("testPrepareDuplicateName failed: " + e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPrepareAndExecuteWithAggregation() {
+ String[] expectedHeader = new String[] {"_col0"};
+ String[] retArray = new String[] {"300.40000000000003,"};
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ statement.execute("USE " + DATABASE_NAME);
+ // Prepare a statement with aggregation
+ statement.execute(
+ "PREPARE stmt7 FROM SELECT AVG(value) FROM test_table WHERE id >= ? AND id <= ?");
+ // Execute with parameters using write connection directly
+ executePreparedStatementAndVerify(
+ connection, statement, "EXECUTE stmt7 USING 2, 4", expectedHeader, retArray);
+ // Deallocate
+ statement.execute("DEALLOCATE PREPARE stmt7");
+ } catch (SQLException e) {
+ fail("testPrepareAndExecuteWithAggregation failed: " + e.getMessage());
+ }
+ }
+
+ @Test
+ public void testPrepareAndExecuteWithStringParameter() {
+ String[] expectedHeader = new String[] {"time", "id", "name", "value"};
+ String[] retArray = new String[] {"2025-01-01T00:02:00.000Z,3,Charlie,300.7,"};
+ try (Connection connection = EnvFactory.getEnv().getTableConnection();
+ Statement statement = connection.createStatement()) {
+ statement.execute("USE " + DATABASE_NAME);
+ // Prepare a statement with string parameter
+ statement.execute("PREPARE stmt8 FROM SELECT * FROM test_table WHERE name = ?");
+ // Execute with string parameter using write connection directly
+ executePreparedStatementAndVerify(
+ connection, statement, "EXECUTE stmt8 USING 'Charlie'", expectedHeader, retArray);
+ // Deallocate
+ statement.execute("DEALLOCATE PREPARE stmt8");
+ } catch (SQLException e) {
+ fail("testPrepareAndExecuteWithStringParameter failed: " + e.getMessage());
+ }
+ }
+}
diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBCurrentQueriesIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBCurrentQueriesIT.java
new file mode 100644
index 0000000000000..66941ec784f81
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBCurrentQueriesIT.java
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.relational.it.query.recent.informationschema;
+
+import org.apache.iotdb.commons.conf.CommonDescriptor;
+import org.apache.iotdb.db.queryengine.execution.QueryState;
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.TableLocalStandaloneIT;
+import org.apache.iotdb.itbase.env.BaseEnv;
+
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+
+import java.sql.Connection;
+import java.sql.ResultSet;
+import java.sql.ResultSetMetaData;
+import java.sql.SQLException;
+import java.sql.Statement;
+
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.END_TIME_TABLE_MODEL;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.NUMS;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATEMENT_TABLE_MODEL;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATE_TABLE_MODEL;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.USER_TABLE_MODEL;
+import static org.apache.iotdb.commons.schema.table.InformationSchema.getSchemaTables;
+import static org.apache.iotdb.db.it.utils.TestUtils.createUser;
+import static org.apache.iotdb.itbase.env.BaseEnv.TABLE_SQL_DIALECT;
+import static org.junit.Assert.fail;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({TableLocalStandaloneIT.class})
+// This IT will run at least 60s, so we only run it in 1C1D
+public class IoTDBCurrentQueriesIT {
+ private static final int CURRENT_QUERIES_COLUMN_NUM =
+ getSchemaTables().get("current_queries").getColumnNum();
+ private static final int QUERIES_COSTS_HISTOGRAM_COLUMN_NUM =
+ getSchemaTables().get("queries_costs_histogram").getColumnNum();
+ private static final String ADMIN_NAME =
+ CommonDescriptor.getInstance().getConfig().getDefaultAdminName();
+ private static final String ADMIN_PWD =
+ CommonDescriptor.getInstance().getConfig().getAdminPassword();
+
+ @BeforeClass
+ public static void setUp() throws Exception {
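+ // a small query-cost stat window lets the eviction of expired QueryInfo be observed below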
+ EnvFactory.getEnv().getConfig().getDataNodeConfig().setQueryCostStatWindow(1);
+ EnvFactory.getEnv().initClusterEnvironment();
+ createUser("test", "test123123456");
+ }
+
+ @AfterClass
+ public static void tearDown() throws Exception {
+ EnvFactory.getEnv().cleanClusterEnvironment();
+ }
+
+ @Test
+ public void testCurrentQueries() {
+ try (Connection connection =
+ EnvFactory.getEnv().getConnection(ADMIN_NAME, ADMIN_PWD, BaseEnv.TABLE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ statement.execute("USE information_schema");
+
+ // 1. query current_queries table
+ String sql = "SELECT * FROM current_queries";
+ ResultSet resultSet = statement.executeQuery(sql);
+ ResultSetMetaData metaData = resultSet.getMetaData();
+ Assert.assertEquals(CURRENT_QUERIES_COLUMN_NUM, metaData.getColumnCount());
+ int rowNum = 0;
+ while (resultSet.next()) {
+ Assert.assertEquals(QueryState.RUNNING.name(), resultSet.getString(STATE_TABLE_MODEL));
+ Assert.assertNull(resultSet.getString(END_TIME_TABLE_MODEL));
+ Assert.assertEquals(sql, resultSet.getString(STATEMENT_TABLE_MODEL));
+ Assert.assertEquals(ADMIN_NAME, resultSet.getString(USER_TABLE_MODEL));
+ rowNum++;
+ }
+ Assert.assertEquals(1, rowNum);
+ resultSet.close();
+
+ // 2. query queries_costs_histogram table
+ sql = "SELECT * FROM queries_costs_histogram";
+ resultSet = statement.executeQuery(sql);
+ metaData = resultSet.getMetaData();
+ Assert.assertEquals(QUERIES_COSTS_HISTOGRAM_COLUMN_NUM, metaData.getColumnCount());
+ rowNum = 0;
+ int queriesCount = 0;
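+ // exactly one finished query (the current_queries scan in step 1) should occupy a bucket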
+ while (resultSet.next()) {
+ int nums = resultSet.getInt(NUMS);
+ if (nums > 0) {
+ queriesCount++;
+ }
+ rowNum++;
+ }
+ Assert.assertEquals(1, queriesCount);
+ Assert.assertEquals(61, rowNum);
+
+ // 3. requery current_queries table
+ sql = "SELECT * FROM current_queries";
+ resultSet = statement.executeQuery(sql);
+ metaData = resultSet.getMetaData();
+ Assert.assertEquals(CURRENT_QUERIES_COLUMN_NUM, metaData.getColumnCount());
+ rowNum = 0;
+ int finishedQueries = 0;
+ while (resultSet.next()) {
+ if (QueryState.FINISHED.name().equals(resultSet.getString(STATE_TABLE_MODEL))) {
+ finishedQueries++;
+ }
+ rowNum++;
+ }
+ // three rows in the result, 2 FINISHED and 1 RUNNING
+ Assert.assertEquals(3, rowNum);
+ Assert.assertEquals(2, finishedQueries);
+ resultSet.close();
+
+ // 4. verify that expired QueryInfo entries are evicted
+ Thread.sleep(61_001);
+ resultSet = statement.executeQuery(sql);
+ rowNum = 0;
+ while (resultSet.next()) {
+ rowNum++;
+ }
+ // one row in the result, current query
+ Assert.assertEquals(1, rowNum);
+ resultSet.close();
+
+ sql = "SELECT * FROM queries_costs_histogram";
+ resultSet = statement.executeQuery(sql);
+ queriesCount = 0;
+ while (resultSet.next()) {
+ int nums = resultSet.getInt(NUMS);
+ if (nums > 0) {
+ queriesCount++;
+ }
+ }
+ // the last current_queries table query was recorded, others are evicted
+ Assert.assertEquals(1, queriesCount);
+ } catch (Exception e) {
+ fail(e.getMessage());
+ }
+
+ // 5. test privilege
+ testPrivilege();
+ }
+
+ private void testPrivilege() {
+ // 1. test current_queries table
+ try (Connection connection =
+ EnvFactory.getEnv().getConnection("test", "test123123456", TABLE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ String sql = "SELECT * FROM information_schema.current_queries";
+
+ // another user executes a query
+ try (Connection connection2 =
+ EnvFactory.getEnv().getConnection(ADMIN_NAME, ADMIN_PWD, BaseEnv.TABLE_SQL_DIALECT)) {
+ ResultSet resultSet = connection2.createStatement().executeQuery(sql);
+ resultSet.close();
+ } catch (Exception e) {
+ fail(e.getMessage());
+ }
+
+ // current user query current_queries table
+ ResultSet resultSet = statement.executeQuery(sql);
+ int rowNum = 0;
+ while (resultSet.next()) {
+ rowNum++;
+ }
+ // only current query in the result
+ Assert.assertEquals(1, rowNum);
+ } catch (SQLException e) {
+ fail(e.getMessage());
+ }
+
+ // 2. test queries_costs_histogram table
+ try (Connection connection =
+ EnvFactory.getEnv().getConnection("test", "test123123456", TABLE_SQL_DIALECT);
+ Statement statement = connection.createStatement()) {
+ statement.executeQuery("SELECT * FROM information_schema.queries_costs_histogram");
+ } catch (SQLException e) {
+ Assert.assertEquals(
+ "803: Access Denied: No permissions for this operation, please add privilege SYSTEM",
+ e.getMessage());
+ }
+ }
+}
diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
index e04ff838819e6..4736d9b0521d9 100644
--- a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
@@ -399,15 +399,16 @@ public void testInformationSchema() throws SQLException {
"config_nodes,INF,",
"configurations,INF,",
"connections,INF,",
+ "current_queries,INF,",
"data_nodes,INF,",
"databases,INF,",
"functions,INF,",
"keywords,INF,",
- "models,INF,",
"nodes,INF,",
"pipe_plugins,INF,",
"pipes,INF,",
"queries,INF,",
+ "queries_costs_histogram,INF,",
"regions,INF,",
"subscriptions,INF,",
"tables,INF,",
@@ -504,16 +505,6 @@ public void testInformationSchema() throws SQLException {
"database,STRING,TAG,",
"table_name,STRING,TAG,",
"view_definition,STRING,ATTRIBUTE,")));
- TestUtils.assertResultSetEqual(
- statement.executeQuery("desc models"),
- "ColumnName,DataType,Category,",
- new HashSet<>(
- Arrays.asList(
- "model_id,STRING,TAG,",
- "model_type,STRING,ATTRIBUTE,",
- "state,STRING,ATTRIBUTE,",
- "configs,STRING,ATTRIBUTE,",
- "notes,STRING,ATTRIBUTE,")));
TestUtils.assertResultSetEqual(
statement.executeQuery("desc functions"),
"ColumnName,DataType,Category,",
@@ -638,7 +629,6 @@ public void testInformationSchema() throws SQLException {
"information_schema,pipes,INF,USING,null,SYSTEM VIEW,",
"information_schema,subscriptions,INF,USING,null,SYSTEM VIEW,",
"information_schema,views,INF,USING,null,SYSTEM VIEW,",
- "information_schema,models,INF,USING,null,SYSTEM VIEW,",
"information_schema,functions,INF,USING,null,SYSTEM VIEW,",
"information_schema,configurations,INF,USING,null,SYSTEM VIEW,",
"information_schema,keywords,INF,USING,null,SYSTEM VIEW,",
@@ -646,12 +636,14 @@ public void testInformationSchema() throws SQLException {
"information_schema,config_nodes,INF,USING,null,SYSTEM VIEW,",
"information_schema,data_nodes,INF,USING,null,SYSTEM VIEW,",
"information_schema,connections,INF,USING,null,SYSTEM VIEW,",
+ "information_schema,current_queries,INF,USING,null,SYSTEM VIEW,",
+ "information_schema,queries_costs_histogram,INF,USING,null,SYSTEM VIEW,",
"test,test,INF,USING,test,BASE TABLE,",
"test,view_table,100,USING,null,VIEW FROM TREE,")));
TestUtils.assertResultSetEqual(
statement.executeQuery("count devices from tables where status = 'USING'"),
"count(devices),",
- Collections.singleton("20,"));
+ Collections.singleton("21,"));
TestUtils.assertResultSetEqual(
statement.executeQuery(
"select * from columns where table_name = 'queries' or database = 'test'"),
diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBTableIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBTableIT.java
index 36dc53a4c41b3..748a6937bcf63 100644
--- a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBTableIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBTableIT.java
@@ -816,9 +816,7 @@ private void testObject4SingleIllegalPath(final String illegal) throws Exception
+ File.separator
+ "test-classes"
+ File.separator
- + "ainode-example"
- + File.separator
- + "model.pt";
+ + "object-example.pt";
List<MeasurementSchema> schemaList = new ArrayList<>();
schemaList.add(new MeasurementSchema("a", TSDataType.STRING));
diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/session/IoTDBSessionRelationalIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/session/IoTDBSessionRelationalIT.java
index 7f63d19030291..49884e3c70ba7 100644
--- a/integration-test/src/test/java/org/apache/iotdb/relational/it/session/IoTDBSessionRelationalIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/session/IoTDBSessionRelationalIT.java
@@ -861,9 +861,7 @@ public void insertObjectTest()
+ File.separator
+ "test-classes"
+ File.separator
- + "ainode-example"
- + File.separator
- + "model.pt";
+ + "object-example.pt";
File object = new File(testObject);
try (ITableSession session = EnvFactory.getEnv().getTableSessionConnection()) {
@@ -929,9 +927,7 @@ public void insertObjectSegmentsTest()
+ File.separator
+ "test-classes"
+ File.separator
- + "ainode-example"
- + File.separator
- + "model.pt";
+ + "object-example.pt";
byte[] objectBytes = Files.readAllBytes(Paths.get(testObject));
List<byte[]> objectSegments = new ArrayList<>();
for (int i = 0; i < objectBytes.length; i += 512) {
diff --git a/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBConnectionsIT.java b/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBConnectionsIT.java
index bc043c79e8a6f..80aa854e66cc1 100644
--- a/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBConnectionsIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBConnectionsIT.java
@@ -176,7 +176,7 @@ public void testSameDataNodeGetConnections() {
ResultSet resultSet =
statement.executeQuery(
- "SELECT * FROM connections WHERE data_node_id = '" + dataNodeId + "'");
+ "SELECT * FROM connections WHERE datanode_id = '" + dataNodeId + "'");
if (!resultSet.next()) {
fail();
}
@@ -256,7 +256,7 @@ public void testClosedDataNodeGetConnections() throws Exception {
ResultSet resultSet =
statement.executeQuery(
- "SELECT COUNT(*) FROM connections WHERE data_node_id = '" + closedDataNodeId + "'");
+ "SELECT COUNT(*) FROM connections WHERE datanode_id = '" + closedDataNodeId + "'");
if (!resultSet.next()) {
fail();
}
@@ -306,7 +306,7 @@ public void testClosedDataNodeGetConnections() throws Exception {
ResultSet resultSet =
statement.executeQuery(
- "SELECT COUNT(*) FROM connections WHERE data_node_id = '" + closedDataNodeId + "'");
+ "SELECT COUNT(*) FROM connections WHERE datanode_id = '" + closedDataNodeId + "'");
if (!resultSet.next()) {
fail();
}
diff --git a/integration-test/src/test/resources/ainode-example/config.yaml b/integration-test/src/test/resources/ainode-example/config.yaml
deleted file mode 100644
index 665acb8704e24..0000000000000
--- a/integration-test/src/test/resources/ainode-example/config.yaml
+++ /dev/null
@@ -1,5 +0,0 @@
-configs:
- input_shape: [7, 4]
- output_shape: [7, 4]
- input_type: ["float32", "float32", "float32", "float32"]
- output_type: ["float32", "float32", "float32", "float32"]
diff --git a/integration-test/src/test/resources/ainode-example/model.pt b/integration-test/src/test/resources/object-example.pt
similarity index 100%
rename from integration-test/src/test/resources/ainode-example/model.pt
rename to integration-test/src/test/resources/object-example.pt
diff --git a/iotdb-core/ainode/ainode.spec b/iotdb-core/ainode/ainode.spec
index a131b2bcff217..bde8b845fb8c5 100644
--- a/iotdb-core/ainode/ainode.spec
+++ b/iotdb-core/ainode/ainode.spec
@@ -41,14 +41,24 @@ all_hiddenimports = []
# Only collect essential data files and binaries for critical libraries
# This reduces startup time by avoiding unnecessary module imports
essential_libraries = {
- 'torch': True, # Keep collect_all for torch as it has many dynamic imports
- 'transformers': True, # Keep collect_all for transformers
- 'safetensors': True, # Keep collect_all for safetensors
+ 'torch': True,
+ 'transformers': True,
+ 'tokenizers': True,
+ 'huggingface_hub': True,
+ 'safetensors': True,
+ 'hf_xet': True,
+ 'numpy': True,
+ 'scipy': True,
+ 'pandas': True,
+ 'sklearn': True,
+ 'statsmodels': True,
+ 'sktime': True,
+ 'pmdarima': True,
+ 'hmmlearn': True,
+ 'accelerate': True
}
-# For other libraries, use selective collection to speed up startup
-other_libraries = ['sktime', 'scipy', 'pandas', 'sklearn', 'statsmodels', 'optuna']
-
+# Collect all libraries using collect_all (includes data files and binaries)
for lib in essential_libraries:
try:
lib_datas, lib_binaries, lib_hiddenimports = collect_all(lib)
@@ -58,48 +68,105 @@ for lib in essential_libraries:
except Exception:
pass
-# For other libraries, only collect submodules (lighter weight)
-# This relies on PyInstaller's dependency analysis to include what's actually used
-for lib in other_libraries:
+# Additionally collect ALL submodules for libraries that commonly have dynamic imports
+# This is a more aggressive approach but ensures we don't miss any modules
+# Libraries that are known to have many dynamic imports and submodules
+libraries_with_dynamic_imports = [
+ 'scipy', # Has many subpackages: stats, interpolate, optimize, linalg, sparse, signal, etc.
+ 'sklearn', # Has many submodules that may be dynamically imported
+ 'transformers', # Has dynamic model loading
+ 'torch', # Has many submodules, especially _dynamo.polyfills
+]
+
+# Collect all submodules for these libraries to ensure comprehensive coverage
+for lib in libraries_with_dynamic_imports:
try:
submodules = collect_submodules(lib)
all_hiddenimports.extend(submodules)
- # Only collect essential data files and binaries, not all submodules
- # This significantly reduces startup time
- try:
- lib_datas, lib_binaries, _ = collect_all(lib)
- all_datas.extend(lib_datas)
- all_binaries.extend(lib_binaries)
- except Exception:
- # If collect_all fails, try collect_data_files for essential data only
- try:
- lib_datas = collect_data_files(lib)
- all_datas.extend(lib_datas)
- except Exception:
- pass
- except Exception:
- pass
+ print(f"Collected {len(submodules)} submodules from {lib}")
+ except Exception as e:
+ print(f"Warning: Failed to collect submodules from {lib}: {e}")
+
+
+# Helper function to collect submodules with fallback
+def collect_submodules_with_fallback(package, fallback_modules=None, package_name=None):
+ """
+ Collect all submodules for a package, with fallback to manual module list if collection fails.
+
+ Args:
+ package: Package name to collect submodules from
+ fallback_modules: List of module names to add if collection fails (optional)
+ package_name: Display name for logging (defaults to package)
+ """
+ if package_name is None:
+ package_name = package
+ try:
+ submodules = collect_submodules(package)
+ all_hiddenimports.extend(submodules)
+ print(f"Collected {len(submodules)} submodules from {package_name}")
+ except Exception as e:
+ print(f"Warning: Failed to collect {package_name} submodules: {e}")
+ if fallback_modules:
+ all_hiddenimports.extend(fallback_modules)
+ print(f"Using fallback modules for {package_name}")
+
+
+# Additional specific packages that need submodule collection
+# Note: scipy, sklearn, transformers, torch are already collected above via libraries_with_dynamic_imports
+# This section is for more specific sub-packages that need special handling
+# Format: (package_name, fallback_modules_list, display_name)
+submodule_collection_configs = [
+ # torch._dynamo.polyfills - critical for torch dynamo functionality
+ # (torch is already collected above, but this ensures polyfills are included)
+ (
+ 'torch._dynamo.polyfills',
+ [
+ 'torch._dynamo.polyfills',
+ 'torch._dynamo.polyfills.functools',
+ 'torch._dynamo.polyfills.operator',
+ 'torch._dynamo.polyfills.collections',
+ ],
+ 'torch._dynamo.polyfills'
+ ),
+ # transformers sub-packages with dynamic imports
+ # (transformers is already collected above, but these specific sub-packages may need extra attention)
+ ('transformers.generation', None, 'transformers.generation'),
+ ('transformers.models.auto', None, 'transformers.models.auto'),
+]
+
+# Collect submodules for all configured packages
+for package, fallback_modules, display_name in submodule_collection_configs:
+ collect_submodules_with_fallback(package, fallback_modules, display_name)
# Project-specific packages that need their submodules collected
# Only list top-level packages - collect_submodules will recursively collect all submodules
-TOP_LEVEL_PACKAGES = [
+project_packages = [
'iotdb.ainode.core', # This will include all sub-packages: manager, model, inference, etc.
'iotdb.thrift', # This will include all thrift sub-packages
]
# Collect all submodules for project packages automatically
# Using top-level packages avoids duplicate collection
-for package in TOP_LEVEL_PACKAGES:
- try:
- submodules = collect_submodules(package)
- all_hiddenimports.extend(submodules)
- except Exception:
- # If package doesn't exist or collection fails, add the package itself
- all_hiddenimports.append(package)
+# If collection fails, add the package itself as fallback
+for package in project_packages:
+ collect_submodules_with_fallback(package, fallback_modules=[package], package_name=package)
# Add parent packages to ensure they are included
all_hiddenimports.extend(['iotdb', 'iotdb.ainode'])
+# Fix circular import issues in scipy.stats
+# scipy.stats has circular imports that can cause issues in PyInstaller
+# We need to ensure _stats is imported before scipy.stats tries to import it
+# This helps resolve the "partially initialized module" error
+scipy_stats_critical_modules = [
+ 'scipy.stats._stats', # Core stats module, must be imported first
+ 'scipy.stats._stats_py', # Python implementation
+ 'scipy.stats._continuous_distns', # Continuous distributions
+ 'scipy.stats._discrete_distns', # Discrete distributions
+ 'scipy.stats.distributions', # Distribution base classes
+]
+all_hiddenimports.extend(scipy_stats_critical_modules)
+
# Multiprocessing support for PyInstaller
# When using multiprocessing with PyInstaller, we need to ensure proper handling
multiprocessing_modules = [
@@ -119,9 +186,6 @@ multiprocessing_modules = [
# Additional dependencies that may need explicit import
# These are external libraries that might use dynamic imports
external_dependencies = [
- 'huggingface_hub',
- 'tokenizers',
- 'hf_xet',
'einops',
'dynaconf',
'tzlocal',
@@ -161,7 +225,9 @@ a = Analysis(
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
- noarchive=True, # Set to True to speed up startup - files are not archived into PYZ
+ noarchive=False, # Set to False to avoid circular import issues with scipy.stats
+ # When noarchive=True, modules are loaded as separate files which can cause
+ # circular import issues. Using PYZ archive helps PyInstaller handle module loading order better.
)
# Package all PYZ files
diff --git a/iotdb-core/ainode/iotdb/ainode/core/config.py b/iotdb-core/ainode/iotdb/ainode/core/config.py
index afcf0683d7d04..e465df7e36d29 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/config.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/config.py
@@ -20,7 +20,6 @@
from iotdb.ainode.core.constant import (
AINODE_BUILD_INFO,
- AINODE_BUILTIN_MODELS_DIR,
AINODE_CLUSTER_INGRESS_ADDRESS,
AINODE_CLUSTER_INGRESS_PASSWORD,
AINODE_CLUSTER_INGRESS_PORT,
@@ -33,10 +32,11 @@
AINODE_CONF_POM_FILE_NAME,
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS,
AINODE_INFERENCE_EXTRA_MEMORY_RATIO,
- AINODE_INFERENCE_MAX_PREDICT_LENGTH,
+ AINODE_INFERENCE_MAX_OUTPUT_LENGTH,
AINODE_INFERENCE_MEMORY_USAGE_RATIO,
AINODE_INFERENCE_MODEL_MEM_USAGE_MAP,
AINODE_LOG_DIR,
+ AINODE_MODELS_BUILTIN_DIR,
AINODE_MODELS_DIR,
AINODE_RPC_ADDRESS,
AINODE_RPC_PORT,
@@ -75,9 +75,7 @@ def __init__(self):
self._ain_inference_batch_interval_in_ms: int = (
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS
)
- self._ain_inference_max_predict_length: int = (
- AINODE_INFERENCE_MAX_PREDICT_LENGTH
- )
+ self._ain_inference_max_output_length: int = AINODE_INFERENCE_MAX_OUTPUT_LENGTH
self._ain_inference_model_mem_usage_map: dict[str, int] = (
AINODE_INFERENCE_MODEL_MEM_USAGE_MAP
)
@@ -95,7 +93,7 @@ def __init__(self):
# Directory to save models
self._ain_models_dir = AINODE_MODELS_DIR
- self._ain_builtin_models_dir = AINODE_BUILTIN_MODELS_DIR
+ self._ain_models_builtin_dir = AINODE_MODELS_BUILTIN_DIR
self._ain_system_dir = AINODE_SYSTEM_DIR
# Whether to enable compression for thrift
@@ -160,13 +158,13 @@ def set_ain_inference_batch_interval_in_ms(
) -> None:
self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms
- def get_ain_inference_max_predict_length(self) -> int:
- return self._ain_inference_max_predict_length
+ def get_ain_inference_max_output_length(self) -> int:
+ return self._ain_inference_max_output_length
- def set_ain_inference_max_predict_length(
- self, ain_inference_max_predict_length: int
+ def set_ain_inference_max_output_length(
+ self, ain_inference_max_output_length: int
) -> None:
- self._ain_inference_max_predict_length = ain_inference_max_predict_length
+ self._ain_inference_max_output_length = ain_inference_max_output_length
def get_ain_inference_model_mem_usage_map(self) -> dict[str, int]:
return self._ain_inference_model_mem_usage_map
@@ -204,11 +202,11 @@ def get_ain_models_dir(self) -> str:
def set_ain_models_dir(self, ain_models_dir: str) -> None:
self._ain_models_dir = ain_models_dir
- def get_ain_builtin_models_dir(self) -> str:
- return self._ain_builtin_models_dir
+ def get_ain_models_builtin_dir(self) -> str:
+ return self._ain_models_builtin_dir
- def set_ain_builtin_models_dir(self, ain_builtin_models_dir: str) -> None:
- self._ain_builtin_models_dir = ain_builtin_models_dir
+ def set_ain_models_builtin_dir(self, ain_models_builtin_dir: str) -> None:
+ self._ain_models_builtin_dir = ain_models_builtin_dir
def get_ain_system_dir(self) -> str:
return self._ain_system_dir
@@ -374,6 +372,11 @@ def _load_config_from_file(self) -> None:
if "ain_models_dir" in config_keys:
self._config.set_ain_models_dir(file_configs["ain_models_dir"])
+ if "ain_models_builtin_dir" in config_keys:
+ self._config.set_ain_models_builtin_dir(
+ file_configs["ain_models_builtin_dir"]
+ )
+
if "ain_system_dir" in config_keys:
self._config.set_ain_system_dir(file_configs["ain_system_dir"])
diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py
index b9923d3e3ee7e..d8f730c829c89 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/constant.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py
@@ -18,9 +18,7 @@
import logging
import os
from enum import Enum
-from typing import List
-from iotdb.ainode.core.model.model_enums import BuiltInModelType
from iotdb.thrift.common.ttypes import TEndPoint
IOTDB_AINODE_HOME = os.getenv("IOTDB_AINODE_HOME", "")
@@ -50,21 +48,21 @@
# AINode inference configuration
AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15
-AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880
+AINODE_INFERENCE_MAX_OUTPUT_LENGTH = 2880
+
+# TODO: Should be optimized
AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = {
- BuiltInModelType.SUNDIAL.value: 1036 * 1024**2, # 1036 MiB
- BuiltInModelType.TIMER_XL.value: 856 * 1024**2, # 856 MiB
+ "sundial": 1036 * 1024**2, # 1036 MiB
+ "timer": 856 * 1024**2, # 856 MiB
} # the memory usage of each model in bytes
+
AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.4 # the device space allocated for inference
AINODE_INFERENCE_EXTRA_MEMORY_RATIO = (
1.2 # the overhead ratio for inference, used to estimate the pool size
)
-# AINode folder structure
AINODE_MODELS_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/models")
-AINODE_BUILTIN_MODELS_DIR = os.path.join(
- IOTDB_AINODE_HOME, "data/ainode/models/weights"
-) # For built-in models, we only need to store their weights and config.
+AINODE_MODELS_BUILTIN_DIR = "iotdb.ainode.core.model"
AINODE_SYSTEM_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/system")
AINODE_LOG_DIR = os.path.join(IOTDB_AINODE_HOME, "logs")
@@ -77,11 +75,6 @@
"log_inference_rank_{}_" # example: log_inference_rank_0_all.log
)
-# AINode model management
-MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors"
-MODEL_CONFIG_FILE_IN_JSON = "config.json"
-MODEL_WEIGHTS_FILE_IN_PT = "model.pt"
-MODEL_CONFIG_FILE_IN_YAML = "config.yaml"
DEFAULT_CHUNK_SIZE = 8192
@@ -100,27 +93,6 @@ def get_status_code(self) -> int:
return self.value
-class TaskType(Enum):
- FORECAST = "forecast"
-
-
-class OptionsKey(Enum):
- # common
- TASK_TYPE = "task_type"
- MODEL_TYPE = "model_type"
- AUTO_TUNING = "auto_tuning"
- INPUT_VARS = "input_vars"
-
- # forecast
- INPUT_LENGTH = "input_length"
- PREDICT_LENGTH = "predict_length"
- PREDICT_INDEX_LIST = "predict_index_list"
- INPUT_TYPE_LIST = "input_type_list"
-
- def name(self) -> str:
- return self.value
-
-
class HyperparameterName(Enum):
# Training hyperparameter
LEARNING_RATE = "learning_rate"
@@ -139,134 +111,3 @@ class HyperparameterName(Enum):
def name(self):
return self.value
-
-
-class ForecastModelType(Enum):
- DLINEAR = "dlinear"
- DLINEAR_INDIVIDUAL = "dlinear_individual"
- NBEATS = "nbeats"
-
- @classmethod
- def values(cls) -> List[str]:
- values = []
- for item in list(cls):
- values.append(item.value)
- return values
-
-
-class ModelInputName(Enum):
- DATA_X = "data_x"
- TIME_STAMP_X = "time_stamp_x"
- TIME_STAMP_Y = "time_stamp_y"
- DEC_INP = "dec_inp"
-
-
-class AttributeName(Enum):
- # forecast Attribute
- PREDICT_LENGTH = "predict_length"
-
- # NaiveForecaster
- STRATEGY = "strategy"
- SP = "sp"
-
- # STLForecaster
- # SP = 'sp'
- SEASONAL = "seasonal"
- SEASONAL_DEG = "seasonal_deg"
- TREND_DEG = "trend_deg"
- LOW_PASS_DEG = "low_pass_deg"
- SEASONAL_JUMP = "seasonal_jump"
- TREND_JUMP = "trend_jump"
- LOSS_PASS_JUMP = "low_pass_jump"
-
- # ExponentialSmoothing
- DAMPED_TREND = "damped_trend"
- INITIALIZATION_METHOD = "initialization_method"
- OPTIMIZED = "optimized"
- REMOVE_BIAS = "remove_bias"
- USE_BRUTE = "use_brute"
-
- # Arima
- ORDER = "order"
- SEASONAL_ORDER = "seasonal_order"
- METHOD = "method"
- MAXITER = "maxiter"
- SUPPRESS_WARNINGS = "suppress_warnings"
- OUT_OF_SAMPLE_SIZE = "out_of_sample_size"
- SCORING = "scoring"
- WITH_INTERCEPT = "with_intercept"
- TIME_VARYING_REGRESSION = "time_varying_regression"
- ENFORCE_STATIONARITY = "enforce_stationarity"
- ENFORCE_INVERTIBILITY = "enforce_invertibility"
- SIMPLE_DIFFERENCING = "simple_differencing"
- MEASUREMENT_ERROR = "measurement_error"
- MLE_REGRESSION = "mle_regression"
- HAMILTON_REPRESENTATION = "hamilton_representation"
- CONCENTRATE_SCALE = "concentrate_scale"
-
- # GAUSSIAN_HMM
- N_COMPONENTS = "n_components"
- COVARIANCE_TYPE = "covariance_type"
- MIN_COVAR = "min_covar"
- STARTPROB_PRIOR = "startprob_prior"
- TRANSMAT_PRIOR = "transmat_prior"
- MEANS_PRIOR = "means_prior"
- MEANS_WEIGHT = "means_weight"
- COVARS_PRIOR = "covars_prior"
- COVARS_WEIGHT = "covars_weight"
- ALGORITHM = "algorithm"
- N_ITER = "n_iter"
- TOL = "tol"
- PARAMS = "params"
- INIT_PARAMS = "init_params"
- IMPLEMENTATION = "implementation"
-
- # GMMHMM
- # N_COMPONENTS = "n_components"
- N_MIX = "n_mix"
- # MIN_COVAR = "min_covar"
- # STARTPROB_PRIOR = "startprob_prior"
- # TRANSMAT_PRIOR = "transmat_prior"
- WEIGHTS_PRIOR = "weights_prior"
-
- # MEANS_PRIOR = "means_prior"
- # MEANS_WEIGHT = "means_weight"
- # ALGORITHM = "algorithm"
- # COVARIANCE_TYPE = "covariance_type"
- # N_ITER = "n_iter"
- # TOL = "tol"
- # INIT_PARAMS = "init_params"
- # PARAMS = "params"
- # IMPLEMENTATION = "implementation"
-
- # STRAY
- ALPHA = "alpha"
- K = "k"
- KNN_ALGORITHM = "knn_algorithm"
- P = "p"
- SIZE_THRESHOLD = "size_threshold"
- OUTLIER_TAIL = "outlier_tail"
-
- # timerxl
- INPUT_TOKEN_LEN = "input_token_len"
- HIDDEN_SIZE = "hidden_size"
- INTERMEDIATE_SIZE = "intermediate_size"
- OUTPUT_TOKEN_LENS = "output_token_lens"
- NUM_HIDDEN_LAYERS = "num_hidden_layers"
- NUM_ATTENTION_HEADS = "num_attention_heads"
- HIDDEN_ACT = "hidden_act"
- USE_CACHE = "use_cache"
- ROPE_THETA = "rope_theta"
- ATTENTION_DROPOUT = "attention_dropout"
- INITIALIZER_RANGE = "initializer_range"
- MAX_POSITION_EMBEDDINGS = "max_position_embeddings"
- CKPT_PATH = "ckpt_path"
-
- # sundial
- DROPOUT_RATE = "dropout_rate"
- FLOW_LOSS_DEPTH = "flow_loss_depth"
- NUM_SAMPLING_STEPS = "num_sampling_steps"
- DIFFUSION_BATCH_MUL = "diffusion_batch_mul"
-
- def name(self) -> str:
- return self.value
diff --git a/iotdb-core/ainode/iotdb/ainode/core/exception.py b/iotdb-core/ainode/iotdb/ainode/core/exception.py
index bc89cdc306625..30b9d54dcc7df 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/exception.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/exception.py
@@ -17,7 +17,7 @@
#
import re
-from iotdb.ainode.core.constant import (
+from iotdb.ainode.core.model.model_constants import (
MODEL_CONFIG_FILE_IN_YAML,
MODEL_WEIGHTS_FILE_IN_PT,
)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
index 82c72cc37abf5..50634914c2737 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py
@@ -15,14 +15,12 @@
# specific language governing permissions and limitations
# under the License.
#
+
import threading
from typing import Any
import torch
-from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import (
- AbstractInferencePipeline,
-)
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.util.atmoic_int import AtomicInt
@@ -41,8 +39,7 @@ def __init__(
req_id: str,
model_id: str,
inputs: torch.Tensor,
- inference_pipeline: AbstractInferencePipeline,
- max_new_tokens: int = 96,
+ output_length: int = 96,
**infer_kwargs,
):
if inputs.ndim == 1:
@@ -52,9 +49,8 @@ def __init__(
self.model_id = model_id
self.inputs = inputs
self.infer_kwargs = infer_kwargs
- self.inference_pipeline = inference_pipeline
- self.max_new_tokens = (
- max_new_tokens # Number of time series data points to generate
+ self.output_length = (
+ output_length # Number of time series data points to generate
)
self.batch_size = inputs.size(0)
@@ -65,7 +61,7 @@ def __init__(
- # Preallocate output buffer [batch_size, max_new_tokens]
+ # Preallocate output buffer [batch_size, output_length]
self.output_tensor = torch.zeros(
- self.batch_size, max_new_tokens, device="cpu"
- ) # shape: [self.batch_size, max_new_steps]
+ self.batch_size, output_length, device="cpu"
+ ) # shape: [self.batch_size, output_length]
def mark_running(self):
@@ -77,7 +73,7 @@ def mark_finished(self):
def is_finished(self) -> bool:
return (
self.state == InferenceRequestState.FINISHED
- or self.cur_step_idx >= self.max_new_tokens
+ or self.cur_step_idx >= self.output_length
)
def write_step_output(self, step_output: torch.Tensor):
@@ -87,11 +83,11 @@ def write_step_output(self, step_output: torch.Tensor):
batch_size, step_size = step_output.shape
end_idx = self.cur_step_idx + step_size
- if end_idx > self.max_new_tokens:
+ if end_idx > self.output_length:
self.output_tensor[:, self.cur_step_idx :] = step_output[
- :, : self.max_new_tokens - self.cur_step_idx
+ :, : self.output_length - self.cur_step_idx
]
- self.cur_step_idx = self.max_new_tokens
+ self.cur_step_idx = self.output_length
else:
self.output_tensor[:, self.cur_step_idx : end_idx] = step_output
self.cur_step_idx = end_idx
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
index 6b054c91fe31c..a6c415a6c848b 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py
@@ -25,19 +25,22 @@
import numpy as np
import torch
import torch.multiprocessing as mp
-from transformers import PretrainedConfig
from iotdb.ainode.core.config import AINodeDescriptor
from iotdb.ainode.core.constant import INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE
from iotdb.ainode.core.inference.batcher.basic_batcher import BasicBatcher
from iotdb.ainode.core.inference.inference_request import InferenceRequest
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import (
+ ChatPipeline,
+ ClassificationPipeline,
+ ForecastPipeline,
+)
+from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline
from iotdb.ainode.core.inference.request_scheduler.basic_request_scheduler import (
BasicRequestScheduler,
)
from iotdb.ainode.core.log import Logger
-from iotdb.ainode.core.manager.model_manager import ModelManager
-from iotdb.ainode.core.model.model_enums import BuiltInModelType
-from iotdb.ainode.core.model.model_info import ModelInfo
+from iotdb.ainode.core.model.model_storage import ModelInfo
from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device
@@ -62,7 +65,6 @@ def __init__(
pool_id: int,
model_info: ModelInfo,
device: str,
- config: PretrainedConfig,
request_queue: mp.Queue,
result_queue: mp.Queue,
ready_event,
@@ -71,7 +73,6 @@ def __init__(
super().__init__()
self.pool_id = pool_id
self.model_info = model_info
- self.config = config
self.pool_kwargs = pool_kwargs
self.ready_event = ready_event
self.device = convert_device_id_to_torch_device(device)
@@ -86,8 +87,8 @@ def __init__(
self._batcher = BasicBatcher()
self._stop_event = mp.Event()
- self._model = None
- self._model_manager = None
+ self._inference_pipeline = None
+
self._logger = None
# Fix inference seed
@@ -98,9 +99,6 @@ def __init__(
def _activate_requests(self):
requests = self._request_scheduler.schedule_activate()
for request in requests:
- request.inputs = request.inference_pipeline.preprocess_inputs(
- request.inputs
- )
request.mark_running()
self._running_queue.put(request)
self._logger.debug(
@@ -117,72 +115,51 @@ def _step(self):
grouped_requests = defaultdict(list)
for req in all_requests:
- key = (req.inputs.shape[1], req.max_new_tokens)
+ key = (req.inputs.shape[1], req.output_length)
grouped_requests[key].append(req)
grouped_requests = list(grouped_requests.values())
for requests in grouped_requests:
batch_inputs = self._batcher.batch_request(requests).to(self.device)
- if self.model_info.model_type == BuiltInModelType.SUNDIAL.value:
- batch_output = self._model.generate(
+ if isinstance(self._inference_pipeline, ForecastPipeline):
+ batch_output = self._inference_pipeline.forecast(
batch_inputs,
- max_new_tokens=requests[0].max_new_tokens,
- num_samples=10,
+ predict_length=requests[0].output_length,
revin=True,
)
-
- offset = 0
- for request in requests:
- request.output_tensor = request.output_tensor.to(self.device)
- cur_batch_size = request.batch_size
- cur_output = batch_output[offset : offset + cur_batch_size]
- offset += cur_batch_size
- request.write_step_output(cur_output.mean(dim=1))
-
- request.inference_pipeline.post_decode()
- if request.is_finished():
- request.inference_pipeline.post_inference()
- self._logger.debug(
- f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
- )
- # ensure the output tensor is on CPU before sending to result queue
- request.output_tensor = request.output_tensor.cpu()
- self._finished_queue.put(request)
- else:
- self._logger.debug(
- f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
- )
- self._waiting_queue.put(request)
-
- elif self.model_info.model_type == BuiltInModelType.TIMER_XL.value:
- batch_output = self._model.generate(
+ elif isinstance(self._inference_pipeline, ClassificationPipeline):
+ batch_output = self._inference_pipeline.classify(
batch_inputs,
- max_new_tokens=requests[0].max_new_tokens,
- revin=True,
+ # more infer kwargs can be added here
)
-
- offset = 0
- for request in requests:
- request.output_tensor = request.output_tensor.to(self.device)
- cur_batch_size = request.batch_size
- cur_output = batch_output[offset : offset + cur_batch_size]
- offset += cur_batch_size
- request.write_step_output(cur_output)
-
- request.inference_pipeline.post_decode()
- if request.is_finished():
- request.inference_pipeline.post_inference()
- self._logger.debug(
- f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
- )
- # ensure the output tensor is on CPU before sending to result queue
- request.output_tensor = request.output_tensor.cpu()
- self._finished_queue.put(request)
- else:
- self._logger.debug(
- f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
- )
- self._waiting_queue.put(request)
+ elif isinstance(self._inference_pipeline, ChatPipeline):
+ batch_output = self._inference_pipeline.chat(
+ batch_inputs,
+ # more infer kwargs can be added here
+ )
+ else:
+ self._logger.error("[Inference] Unsupported pipeline type.")
+ continue  # batch_output is undefined for unsupported pipelines, so skip this group
+ offset = 0
+ for request in requests:
+ request.output_tensor = request.output_tensor.to(self.device)
+ cur_batch_size = request.batch_size
+ cur_output = batch_output[offset : offset + cur_batch_size]
+ offset += cur_batch_size
+ request.write_step_output(cur_output)
+
+ if request.is_finished():
+ # ensure the output tensor is on CPU before sending to result queue
+ request.output_tensor = request.output_tensor.cpu()
+ self._finished_queue.put(request)
+ self._logger.debug(
+ f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished"
+ )
+ else:
+ self._waiting_queue.put(request)
+ self._logger.debug(
+ f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing"
+ )
+ return
def _requests_execute_loop(self):
while not self._stop_event.is_set():
@@ -193,11 +170,8 @@ def run(self):
self._logger = Logger(
INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device)
)
- self._model_manager = ModelManager()
self._request_scheduler.device = self.device
- self._model = self._model_manager.load_model(self.model_info.model_id, {}).to(
- self.device
- )
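+ # build the pipeline (which loads its model) inside this worker process, on the pool's device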
+ self._inference_pipeline = load_pipeline(self.model_info, str(self.device))
self.ready_event.set()
activate_daemon = threading.Thread(
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py
similarity index 100%
rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/__init__.py
rename to iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
new file mode 100644
index 0000000000000..82601e398059c
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py
@@ -0,0 +1,87 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from abc import ABC, abstractmethod
+
+import torch
+
+from iotdb.ainode.core.model.model_loader import load_model
+
+
+class BasicPipeline(ABC):
+ def __init__(self, model_info, **model_kwargs):
+ self.model_info = model_info
+ self.device = model_kwargs.pop("device", "cpu")  # pop so "device" is not forwarded twice below
+ self.model = load_model(model_info, device_map=self.device, **model_kwargs)
+
+ def _preprocess(self, inputs):
+ """
+ Preprocess the input before inference, including shape validation and value transformation.
+ """
+ return inputs
+
+ def _postprocess(self, output: torch.Tensor):
+ """
+ Post-process the outputs after the entire inference task.
+ """
+ return output
+
+
+class ForecastPipeline(BasicPipeline):
+ def __init__(self, model_info, **model_kwargs):
+ super().__init__(model_info, **model_kwargs)
+
+ def _preprocess(self, inputs):
+ return inputs
+
+ @abstractmethod
+ def forecast(self, inputs, **infer_kwargs):
+ pass
+
+ def _postprocess(self, output: torch.Tensor):
+ return output
+
+
+class ClassificationPipeline(BasicPipeline):
+ def __init__(self, model_info, **model_kwargs):
+ super().__init__(model_info, **model_kwargs)
+
+ def _preprocess(self, inputs):
+ return inputs
+
+ @abstractmethod
+ def classify(self, inputs, **kwargs):
+ pass
+
+ def _postprocess(self, output: torch.Tensor):
+ return output
+
+
+class ChatPipeline(BasicPipeline):
+ def __init__(self, model_info, **model_kwargs):
+ super().__init__(model_info, **model_kwargs)
+
+ def _preprocess(self, inputs):
+ return inputs
+
+ @abstractmethod
+ def chat(self, inputs, **kwargs):
+ pass
+
+ def _postprocess(self, output: torch.Tensor):
+ return output
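A hedged sketch of how a concrete pipeline would plug into these base classes; the subclass name and generate() signature are assumptions for illustration, not the project's actual built-in pipelines:

class ExampleForecastPipeline(ForecastPipeline):
    # assumes load_model() returned an object exposing a
    # generate(inputs, max_new_tokens=...) method (illustrative signature)
    def forecast(self, inputs, **infer_kwargs):
        inputs = self._preprocess(inputs)
        output = self.model.generate(
            inputs, max_new_tokens=infer_kwargs.get("predict_length", 96)
        )
        return self._postprocess(output)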
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py
new file mode 100644
index 0000000000000..a30038dd5feff
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+from pathlib import Path
+
+from iotdb.ainode.core.config import AINodeDescriptor
+from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.model.model_constants import ModelCategory
+from iotdb.ainode.core.model.model_storage import ModelInfo
+from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path
+
+logger = Logger()
+
+
+def load_pipeline(model_info: ModelInfo, device: str, **model_kwargs):
+ if model_info.model_type == "sktime":
+ from iotdb.ainode.core.model.sktime.pipeline_sktime import SktimePipeline
+
+ pipeline_cls = SktimePipeline
+ elif model_info.category == ModelCategory.BUILTIN:
+ module_name = (
+ AINodeDescriptor().get_config().get_ain_models_builtin_dir()
+ + "."
+ + model_info.model_id
+ )
+ pipeline_cls = import_class_from_path(module_name, model_info.pipeline_cls)
+ else:
+ model_path = os.path.join(
+ os.getcwd(),
+ AINodeDescriptor().get_config().get_ain_models_dir(),
+ model_info.category.value,
+ model_info.model_id,
+ )
+ module_parent = str(Path(model_path).parent.absolute())
+ with temporary_sys_path(module_parent):
+ pipeline_cls = import_class_from_path(
+ model_info.model_id, model_info.pipeline_cls
+ )
+
+ return pipeline_cls(model_info, device=device, **model_kwargs)
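A hedged usage sketch for the loader; the model id, input shape, and pipeline branch are illustrative:

import torch
from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline

# model_info obtained elsewhere, e.g. ModelManager().get_model_info("example_model")
pipeline = load_pipeline(model_info, device="cpu")
inputs = torch.randn(1, 512)  # [batch, context length] (illustrative shape)
if isinstance(pipeline, ForecastPipeline):
    outputs = pipeline.forecast(inputs, predict_length=96)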
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
index 54580402ec293..8ffa89ffd6752 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py
@@ -22,7 +22,6 @@
from concurrent.futures import wait
from typing import Dict, Optional
-import torch
import torch.multiprocessing as mp
from iotdb.ainode.core.exception import InferenceModelInternalError
@@ -41,9 +40,6 @@
)
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.manager.model_manager import ModelManager
-from iotdb.ainode.core.model.model_enums import BuiltInModelType
-from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig
-from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig
from iotdb.ainode.core.util.atmoic_int import AtomicInt
from iotdb.ainode.core.util.batch_executor import BatchExecutor
from iotdb.ainode.core.util.decorator import synchronized
@@ -76,9 +72,9 @@ def __init__(self, result_queue: mp.Queue):
thread_name_prefix=ThreadName.INFERENCE_POOL_CONTROLLER.value
)
- # =============== Pool Management ===============
+ # =============== Automatic Pool Management (under development) ===============
@synchronized(threading.Lock())
- def first_req_init(self, model_id: str):
+ def first_req_init(self, model_id: str, device):
"""
Initialize the pools when the first request for the given model_id arrives.
"""
@@ -107,38 +103,35 @@ def _first_pool_init(self, model_id: str, device_str: str):
Initialize the first pool for the given model_id.
Ensure the pool is ready before returning.
"""
- device = torch.device(device_str)
- device_id = device.index
-
- if model_id == "sundial":
- config = SundialConfig()
- elif model_id == "timer_xl":
- config = TimerConfig()
- first_queue = mp.Queue()
- ready_event = mp.Event()
- first_pool = InferenceRequestPool(
- pool_id=0,
- model_id=model_id,
- device=device_str,
- config=config,
- request_queue=first_queue,
- result_queue=self._result_queue,
- ready_event=ready_event,
- )
- first_pool.start()
- self._register_pool(model_id, device_str, 0, first_pool, first_queue)
-
- if not ready_event.wait(timeout=30):
- self._erase_pool(model_id, device_id, 0)
- logger.error(
- f"[Inference][Device-{device}][Pool-0] Pool failed to be ready in time"
- )
- else:
- self.set_state(model_id, device_id, 0, PoolState.RUNNING)
- logger.info(
- f"[Inference][Device-{device}][Pool-0] Pool started running for model {model_id}"
- )
+ pass
+ # device = torch.device(device_str)
+ # device_id = device.index
+ #
+ # first_queue = mp.Queue()
+ # ready_event = mp.Event()
+ # first_pool = InferenceRequestPool(
+ # pool_id=0,
+ # model_id=model_id,
+ # device=device_str,
+ # request_queue=first_queue,
+ # result_queue=self._result_queue,
+ # ready_event=ready_event,
+ # )
+ # first_pool.start()
+ # self._register_pool(model_id, device_str, 0, first_pool, first_queue)
+ #
+ # if not ready_event.wait(timeout=30):
+ # self._erase_pool(model_id, device_id, 0)
+ # logger.error(
+ # f"[Inference][Device-{device}][Pool-0] Pool failed to be ready in time"
+ # )
+ # else:
+ # self.set_state(model_id, device_id, 0, PoolState.RUNNING)
+ # logger.info(
+ # f"[Inference][Device-{device}][Pool-0] Pool started running for model {model_id}"
+ # )
+ # =============== Pool Management ===============
def load_model(self, model_id: str, device_id_list: list[str]):
"""
Load the model to the specified devices asynchronously.
@@ -255,29 +248,19 @@ def _expand_pools_on_device(self, model_id: str, device_id: str, count: int):
"""
def _expand_pool_on_device(*_):
- result_queue = mp.Queue()
+ request_queue = mp.Queue()
pool_id = self._new_pool_id.get_and_increment()
model_info = self._model_manager.get_model_info(model_id)
- model_type = model_info.model_type
- if model_type == BuiltInModelType.SUNDIAL.value:
- config = SundialConfig()
- elif model_type == BuiltInModelType.TIMER_XL.value:
- config = TimerConfig()
- else:
- raise InferenceModelInternalError(
- f"Unsupported model type {model_type} for loading model {model_id}"
- )
pool = InferenceRequestPool(
pool_id=pool_id,
model_info=model_info,
device=device_id,
- config=config,
- request_queue=result_queue,
+ request_queue=request_queue,
result_queue=self._result_queue,
ready_event=mp.Event(),
)
pool.start()
- self._register_pool(model_id, device_id, pool_id, pool, result_queue)
+ self._register_pool(model_id, device_id, pool_id, pool, request_queue)
if not pool.ready_event.wait(timeout=300):
logger.error(
f"[Inference][Device-{device_id}][Pool-{pool_id}] Pool failed to be ready in time"
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
index 6a2bd2b619aa7..d2e7292ecd8ff 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py
@@ -36,7 +36,7 @@
estimate_pool_size,
evaluate_system_resources,
)
-from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP, ModelInfo
+from iotdb.ainode.core.model.model_info import ModelInfo
from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device
logger = Logger()
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py
deleted file mode 100644
index 2300169a6ee93..0000000000000
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-from abc import ABC, abstractmethod
-
-import torch
-
-
-class AbstractInferencePipeline(ABC):
- """
- Abstract assistance strategy class for model inference.
- This class shall define the interface process for specific model.
- """
-
- def __init__(self, model_config, **infer_kwargs):
- self.model_config = model_config
- self.infer_kwargs = infer_kwargs
-
- @abstractmethod
- def preprocess_inputs(self, inputs: torch.Tensor):
- """
- Preprocess the inputs before inference, including shape validation and value transformation.
-
- Args:
- inputs (torch.Tensor): The input tensor to be preprocessed.
-
- Returns:
- torch.Tensor: The preprocessed input tensor.
- """
- # TODO: Integrate with the data processing pipeline operators
- pass
-
- @abstractmethod
- def post_decode(self):
- """
- Post-process the outputs after each decode step.
- """
- pass
-
- @abstractmethod
- def post_inference(self):
- """
- Post-process the outputs after the entire inference task.
- """
- pass
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py b/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py
index cf10b5b2cd4dc..d17f9fbcec536 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py
@@ -19,7 +19,6 @@
import string
import torch
-from transformers.modeling_outputs import MoeCausalLMOutputWithPast
def generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str:
@@ -56,25 +55,25 @@ def _slice_pkv(pkv, s, e):
return out
-def split_moe_output(batch_out: MoeCausalLMOutputWithPast, split_sizes):
- """
- split batch_out with type: MoeCausalLMOutputWithPast into len(split_sizes)
- split_sizes[i] = ith request's batch_size。
- """
- outs = []
- start = 0
- for bsz in split_sizes:
- end = start + bsz
- outs.append(
- MoeCausalLMOutputWithPast(
- loss=_slice_tensor(batch_out.loss, start, end),
- logits=batch_out.logits[start:end],
- past_key_values=_slice_pkv(batch_out.past_key_values, start, end),
- hidden_states=_slice_tuple_of_tensors(
- batch_out.hidden_states, start, end
- ),
- attentions=_slice_tuple_of_tensors(batch_out.attentions, start, end),
- )
- )
- start = end
- return outs
+# def split_moe_output(batch_out: MoeCausalLMOutputWithPast, split_sizes):
+# """
+# split a batched MoeCausalLMOutputWithPast into len(split_sizes) slices, where
+# split_sizes[i] is the i-th request's batch_size.
+# """
+# outs = []
+# start = 0
+# for bsz in split_sizes:
+# end = start + bsz
+# outs.append(
+# MoeCausalLMOutputWithPast(
+# loss=_slice_tensor(batch_out.loss, start, end),
+# logits=batch_out.logits[start:end],
+# past_key_values=_slice_pkv(batch_out.past_key_values, start, end),
+# hidden_states=_slice_tuple_of_tensors(
+# batch_out.hidden_states, start, end
+# ),
+# attentions=_slice_tuple_of_tensors(batch_out.attentions, start, end),
+# )
+# )
+# start = end
+# return outs
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
index a67d576b0ec8c..1ce2e84e05929 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py
@@ -18,7 +18,6 @@
import threading
import time
-from abc import ABC, abstractmethod
from typing import Dict
import pandas as pd
@@ -29,29 +28,22 @@
from iotdb.ainode.core.constant import TSStatusCode
from iotdb.ainode.core.exception import (
InferenceModelInternalError,
- InvalidWindowArgumentError,
NumericalRangeException,
- runtime_error_extractor,
)
from iotdb.ainode.core.inference.inference_request import (
InferenceRequest,
InferenceRequestProxy,
)
-from iotdb.ainode.core.inference.pool_controller import PoolController
-from iotdb.ainode.core.inference.strategy.timer_sundial_inference_pipeline import (
- TimerSundialInferencePipeline,
-)
-from iotdb.ainode.core.inference.strategy.timerxl_inference_pipeline import (
- TimerXLInferencePipeline,
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import (
+ ChatPipeline,
+ ClassificationPipeline,
+ ForecastPipeline,
)
+from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline
+from iotdb.ainode.core.inference.pool_controller import PoolController
from iotdb.ainode.core.inference.utils import generate_req_id
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.manager.model_manager import ModelManager
-from iotdb.ainode.core.model.model_enums import BuiltInModelType
-from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig
-from iotdb.ainode.core.model.sundial.modeling_sundial import SundialForPrediction
-from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig
-from iotdb.ainode.core.model.timerxl.modeling_timer import TimerForPrediction
from iotdb.ainode.core.rpc.status import get_status
from iotdb.ainode.core.util.gpu_mapping import get_available_devices
from iotdb.ainode.core.util.serde import convert_to_binary
@@ -71,83 +63,6 @@
logger = Logger()
-class InferenceStrategy(ABC):
- def __init__(self, model):
- self.model = model
-
- @abstractmethod
- def infer(self, full_data, **kwargs):
- pass
-
-
-# [IoTDB] full data deserialized from iotdb is composed of [timestampList, valueList, length],
-# we only get valueList currently.
-class TimerXLStrategy(InferenceStrategy):
- def infer(self, full_data, predict_length=96, **_):
- data = full_data[1][0]
- if data.dtype.byteorder not in ("=", "|"):
- np_data = data.byteswap()
- data = np_data.view(np_data.dtype.newbyteorder())
- seqs = torch.tensor(data).unsqueeze(0).float()
- # TODO: unify model inference input
- output = self.model.generate(seqs, max_new_tokens=predict_length, revin=True)
- df = pd.DataFrame(output[0])
- return convert_to_binary(df)
-
-
-class SundialStrategy(InferenceStrategy):
- def infer(self, full_data, predict_length=96, **_):
- data = full_data[1][0]
- if data.dtype.byteorder not in ("=", "|"):
- np_data = data.byteswap()
- data = np_data.view(np_data.dtype.newbyteorder())
- seqs = torch.tensor(data).unsqueeze(0).float()
- # TODO: unify model inference input
- output = self.model.generate(
- seqs, max_new_tokens=predict_length, num_samples=10, revin=True
- )
- df = pd.DataFrame(output[0].mean(dim=0))
- return convert_to_binary(df)
-
-
-class BuiltInStrategy(InferenceStrategy):
- def infer(self, full_data, **_):
- data = pd.DataFrame(full_data[1]).T
- output = self.model.inference(data)
- df = pd.DataFrame(output)
- return convert_to_binary(df)
-
-
-class RegisteredStrategy(InferenceStrategy):
- def infer(self, full_data, window_interval=None, window_step=None, **_):
- _, dataset, _, length = full_data
- if window_interval is None or window_step is None:
- window_interval = length
- window_step = float("inf")
-
- if window_interval <= 0 or window_step <= 0 or window_interval > length:
- raise InvalidWindowArgumentError(window_interval, window_step, length)
-
- data = torch.tensor(dataset, dtype=torch.float32).unsqueeze(0).permute(0, 2, 1)
-
- times = int((length - window_interval) // window_step + 1)
- results = []
- try:
- for i in range(times):
- start = 0 if window_step == float("inf") else i * window_step
- end = start + window_interval
- window = data[:, start:end, :]
- out = self.model(window)
- df = pd.DataFrame(out.squeeze(0).detach().numpy())
- results.append(df)
- except Exception as e:
- msg = runtime_error_extractor(str(e)) or str(e)
- raise InferenceModelInternalError(msg)
-
- # concatenate or return first window for forecast
- return [convert_to_binary(df) for df in results]
-
-
class InferenceManager:
WAITING_INTERVAL_IN_MS = (
AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms()
@@ -251,15 +166,6 @@ def _process_request(self, req):
with self._result_wrapper_lock:
del self._result_wrapper_map[req_id]
- def _get_strategy(self, model_id, model):
- if isinstance(model, TimerForPrediction):
- return TimerXLStrategy(model)
- if isinstance(model, SundialForPrediction):
- return SundialStrategy(model)
- if self._model_manager.model_storage.is_built_in_or_fine_tuned(model_id):
- return BuiltInStrategy(model)
- return RegisteredStrategy(model)
-
def _run(
self,
req,
@@ -272,59 +178,54 @@ def _run(
model_id = req.modelId
try:
raw = data_getter(req)
+ # full data deserialized from IoTDB is composed of [timestampList, valueList, None, length]; currently only valueList is used.
full_data = deserializer(raw)
- inference_attrs = extract_attrs(req)
+ # TODO: TSBlock -> Tensor codes should be unified
+ data = full_data[1][0]  # the valueList as an ndarray
+ if data.dtype.byteorder not in ("=", "|"):
+ np_data = data.byteswap()
+ data = np_data.view(np_data.dtype.newbyteorder())
+ # the inputs should be on CPU before passing to the inference request
+ inputs = torch.tensor(data).unsqueeze(0).float().to("cpu")
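The byte-order guard above is the standard NumPy idiom for normalizing big-endian data (as deserialized from the wire) to native order, which torch requires; a standalone illustration:

import numpy as np
import torch

data = np.arange(4, dtype=">f4")  # big-endian float32, as if freshly deserialized
if data.dtype.byteorder not in ("=", "|"):  # "=" is native, "|" is byte-order-free
    data = data.byteswap().view(data.dtype.newbyteorder())
inputs = torch.tensor(data).unsqueeze(0).float()  # safe: dtype is native order now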
- predict_length = int(inference_attrs.pop("predict_length", 96))
+ inference_attrs = extract_attrs(req)
+ output_length = int(inference_attrs.pop("output_length", 96))
if (
- predict_length
- > AINodeDescriptor().get_config().get_ain_inference_max_predict_length()
+ output_length
+ > AINodeDescriptor().get_config().get_ain_inference_max_output_length()
):
raise NumericalRangeException(
"output_length",
1,
AINodeDescriptor()
.get_config()
- .get_ain_inference_max_predict_length(),
- predict_length,
+ .get_ain_inference_max_output_length(),
+ output_length,
)
if self._pool_controller.has_request_pools(model_id):
- # use request pool to accelerate inference when the model instance is already loaded.
- # TODO: TSBlock -> Tensor codes should be unified
- data = full_data[1][0]
- if data.dtype.byteorder not in ("=", "|"):
- np_data = data.byteswap()
- data = np_data.view(np_data.dtype.newbyteorder())
- # the inputs should be on CPU before passing to the inference request
- inputs = torch.tensor(data).unsqueeze(0).float().to("cpu")
- model_type = self._model_manager.get_model_info(model_id).model_type
- if model_type == BuiltInModelType.SUNDIAL.value:
- inference_pipeline = TimerSundialInferencePipeline(SundialConfig())
- elif model_type == BuiltInModelType.TIMER_XL.value:
- inference_pipeline = TimerXLInferencePipeline(TimerConfig())
- else:
- raise InferenceModelInternalError(
- f"Unsupported model_id: {model_id}"
- )
infer_req = InferenceRequest(
req_id=generate_req_id(),
model_id=model_id,
inputs=inputs,
- inference_pipeline=inference_pipeline,
- max_new_tokens=predict_length,
+ output_length=output_length,
)
outputs = self._process_request(infer_req)
outputs = convert_to_binary(pd.DataFrame(outputs[0]))
else:
- # load model
- accel = str(inference_attrs.get("acceleration", "")).lower() == "true"
- model = self._model_manager.load_model(model_id, inference_attrs, accel)
- # inference by strategy
- strategy = self._get_strategy(model_id, model)
- outputs = strategy.infer(
- full_data, predict_length=predict_length, **inference_attrs
- )
+ model_info = self._model_manager.get_model_info(model_id)
+ inference_pipeline = load_pipeline(model_info, device="cpu")
+ if isinstance(inference_pipeline, ForecastPipeline):
+ outputs = inference_pipeline.forecast(
+ inputs, predict_length=output_length, **inference_attrs
+ )
+ elif isinstance(inference_pipeline, ClassificationPipeline):
+ outputs = inference_pipeline.classify(inputs)
+ elif isinstance(inference_pipeline, ChatPipeline):
+ outputs = inference_pipeline.chat(inputs)
+ else:
+ raise InferenceModelInternalError(
+ f"Unsupported pipeline type for model {model_id}"
+ )
+ outputs = convert_to_binary(pd.DataFrame(outputs[0]))
# construct response
status = get_status(TSStatusCode.SUCCESS_STATUS)
@@ -345,7 +246,7 @@ def forecast(self, req: TForecastReq):
data_getter=lambda r: r.inputData,
deserializer=deserialize,
extract_attrs=lambda r: {
- "predict_length": r.outputLength,
+ "output_length": r.outputLength,
**(r.options or {}),
},
resp_cls=TForecastResp,
@@ -358,8 +259,7 @@ def inference(self, req: TInferenceReq):
data_getter=lambda r: r.dataset,
deserializer=deserialize,
extract_attrs=lambda r: {
- "window_interval": getattr(r.windowParams, "windowInterval", None),
- "window_step": getattr(r.windowParams, "windowStep", None),
+ "output_length": int(r.inferenceAttributes.pop("outputLength", 96)),
**(r.inferenceAttributes or {}),
},
resp_cls=TInferenceResp,
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
index d84bca77c8430..8ffb33d91e2d2 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
@@ -15,17 +15,14 @@
# specific language governing permissions and limitations
# under the License.
#
-from typing import Callable, Dict
-from torch import nn
-from yaml import YAMLError
+from typing import Any, List, Optional
from iotdb.ainode.core.constant import TSStatusCode
-from iotdb.ainode.core.exception import BadConfigValueError, InvalidUriError
+from iotdb.ainode.core.exception import BuiltInModelDeletionError
from iotdb.ainode.core.log import Logger
-from iotdb.ainode.core.model.model_enums import BuiltInModelType, ModelStates
-from iotdb.ainode.core.model.model_info import ModelInfo
-from iotdb.ainode.core.model.model_storage import ModelStorage
+from iotdb.ainode.core.model.model_loader import load_model
+from iotdb.ainode.core.model.model_storage import ModelCategory, ModelInfo, ModelStorage
from iotdb.ainode.core.rpc.status import get_status
from iotdb.ainode.core.util.decorator import singleton
from iotdb.thrift.ainode.ttypes import (
@@ -43,127 +40,60 @@
@singleton
class ModelManager:
def __init__(self):
- self.model_storage = ModelStorage()
+ self._model_storage = ModelStorage()
- def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp:
- logger.info(f"register model {req.modelId} from {req.uri}")
+ def register_model(
+ self,
+ req: TRegisterModelReq,
+ ) -> TRegisterModelResp:
try:
- configs, attributes = self.model_storage.register_model(
- req.modelId, req.uri
- )
- return TRegisterModelResp(
- get_status(TSStatusCode.SUCCESS_STATUS), configs, attributes
- )
- except InvalidUriError as e:
- logger.warning(e)
- return TRegisterModelResp(
- get_status(TSStatusCode.INVALID_URI_ERROR, e.message)
- )
- except BadConfigValueError as e:
- logger.warning(e)
+ if self._model_storage.register_model(model_id=req.modelId, uri=req.uri):
+ return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS))
+ return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR))
+ except ValueError as e:
return TRegisterModelResp(
- get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message)
+ get_status(TSStatusCode.INVALID_URI_ERROR, str(e))
)
- except YAMLError as e:
- logger.warning(e)
- if hasattr(e, "problem_mark"):
- mark = e.problem_mark
- return TRegisterModelResp(
- get_status(
- TSStatusCode.INVALID_INFERENCE_CONFIG,
- f"An error occurred while parsing the yaml file, "
- f"at line {mark.line + 1} column {mark.column + 1}.",
- )
- )
+ except Exception as e:
return TRegisterModelResp(
- get_status(
- TSStatusCode.INVALID_INFERENCE_CONFIG,
- f"An error occurred while parsing the yaml file",
- )
+ get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
)
- except Exception as e:
- logger.warning(e)
- return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR))
+
+ def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
+ self._refresh()
+ return self._model_storage.show_models(req)
def delete_model(self, req: TDeleteModelReq) -> TSStatus:
- logger.info(f"delete model {req.modelId}")
try:
- self.model_storage.delete_model(req.modelId)
+ self._model_storage.delete_model(req.modelId)
return get_status(TSStatusCode.SUCCESS_STATUS)
- except Exception as e:
+ except BuiltInModelDeletionError as e:
logger.warning(e)
return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
-
- def load_model(
- self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool = False
- ) -> Callable:
- """
- Load the model with the given model_id.
- """
- logger.info(f"Load model {model_id}")
- try:
- model = self.model_storage.load_model(
- model_id, inference_attrs, acceleration
- )
- logger.info(f"Model {model_id} loaded")
- return model
- except Exception as e:
- logger.error(f"Failed to load model {model_id}: {e}")
- raise
-
- def save_model(self, model_id: str, model: nn.Module) -> TSStatus:
- """
- Save the model using save_pretrained
- """
- logger.info(f"Saving model {model_id}")
- try:
- self.model_storage.save_model(model_id, model)
- logger.info(f"Saving model {model_id} successfully")
- return get_status(
- TSStatusCode.SUCCESS_STATUS, f"Model {model_id} saved successfully"
- )
except Exception as e:
- logger.error(f"Save model failed: {e}")
+ logger.warning(e)
return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e))
- def get_ckpt_path(self, model_id: str) -> str:
- """
- Get the checkpoint path for a given model ID.
-
- Args:
- model_id (str): The ID of the model.
-
- Returns:
- str: The path to the checkpoint file for the model.
- """
- return self.model_storage.get_ckpt_path(model_id)
-
- def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
- return self.model_storage.show_models(req)
-
- def register_built_in_model(self, model_info: ModelInfo):
- self.model_storage.register_built_in_model(model_info)
-
- def get_model_info(self, model_id: str) -> ModelInfo:
- return self.model_storage.get_model_info(model_id)
-
- def update_model_state(self, model_id: str, state: ModelStates):
- self.model_storage.update_model_state(model_id, state)
-
- def get_built_in_model_type(self, model_id: str) -> BuiltInModelType:
- """
- Get the type of the model with the given model_id.
- """
- return self.model_storage.get_built_in_model_type(model_id)
-
- def is_built_in_or_fine_tuned(self, model_id: str) -> bool:
- """
- Check if the model_id corresponds to a built-in or fine-tuned model.
-
- Args:
- model_id (str): The ID of the model.
-
- Returns:
- bool: True if the model is built-in or fine_tuned, False otherwise.
- """
- return self.model_storage.is_built_in_or_fine_tuned(model_id)
+ def get_model_info(
+ self,
+ model_id: str,
+ category: Optional[ModelCategory] = None,
+ ) -> Optional[ModelInfo]:
+ return self._model_storage.get_model_info(model_id, category)
+
+ def get_model_infos(
+ self,
+ category: Optional[ModelCategory] = None,
+ model_type: Optional[str] = None,
+ ) -> List[ModelInfo]:
+ return self._model_storage.get_model_infos(category, model_type)
+
+ def _refresh(self):
+ """Refresh the model list (re-scan the file system)"""
+ self._model_storage.discover_all_models()
+
+ def get_registered_models(self) -> List[str]:
+ return self._model_storage.get_registered_models()
+
+ def is_model_registered(self, model_id: str) -> bool:
+ return self._model_storage.is_model_registered(model_id)
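A hedged sketch of the slimmed-down manager surface after this change; the model id is illustrative:

manager = ModelManager()  # a singleton via the @singleton decorator
info = manager.get_model_info("example_model")  # None if the id is unknown
if info is None:
    registered = manager.is_model_registered("example_model")
builtin_infos = manager.get_model_infos(category=ModelCategory.BUILTIN)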
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
index 0264e27331a86..23a98f26bbffa 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py
@@ -25,7 +25,7 @@
from iotdb.ainode.core.exception import ModelNotExistError
from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.manager.model_manager import ModelManager
-from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP
+from iotdb.ainode.core.model.model_loader import load_model
logger = Logger()
@@ -47,7 +47,8 @@ def measure_model_memory(device: torch.device, model_id: str) -> int:
torch.cuda.synchronize(device)
start = torch.cuda.memory_reserved(device)
- model = ModelManager().load_model(model_id, {}).to(device)
+ model_info = ModelManager().get_model_info(model_id)
+ model = load_model(model_info).to(device)
torch.cuda.synchronize(device)
end = torch.cuda.memory_reserved(device)
usage = end - start
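The measurement above relies on the delta of torch.cuda.memory_reserved around model load; a minimal standalone sketch of that pattern (any module works, and a CUDA device is required):

import torch
from torch import nn

def reserved_delta_bytes(device: torch.device, build_model) -> int:
    torch.cuda.synchronize(device)
    start = torch.cuda.memory_reserved(device)
    _model = build_model().to(device)  # keep a reference while measuring
    torch.cuda.synchronize(device)
    return torch.cuda.memory_reserved(device) - start

# e.g. reserved_delta_bytes(torch.device("cuda:0"), lambda: nn.Linear(1024, 1024))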
@@ -80,7 +81,7 @@ def evaluate_system_resources(device: torch.device) -> dict:
def estimate_pool_size(device: torch.device, model_id: str) -> int:
- model_info = BUILT_IN_LTSM_MAP.get(model_id, None)
+ model_info = ModelManager().get_model_info(model_id)
if model_info is None or model_info.model_type not in MODEL_MEM_USAGE_MAP:
logger.error(
f"[Inference] Cannot estimate inference pool size on device: {device}, because model: {model_id} is not supported."
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py
deleted file mode 100644
index 3b55142350bad..0000000000000
--- a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py
+++ /dev/null
@@ -1,1238 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import os
-from abc import abstractmethod
-from typing import Callable, Dict, List
-
-import numpy as np
-from huggingface_hub import hf_hub_download
-from sklearn.preprocessing import MinMaxScaler
-from sktime.detection.hmm_learn import GMMHMM, GaussianHMM
-from sktime.detection.stray import STRAY
-from sktime.forecasting.arima import ARIMA
-from sktime.forecasting.exp_smoothing import ExponentialSmoothing
-from sktime.forecasting.naive import NaiveForecaster
-from sktime.forecasting.trend import STLForecaster
-
-from iotdb.ainode.core.config import AINodeDescriptor
-from iotdb.ainode.core.constant import (
- MODEL_CONFIG_FILE_IN_JSON,
- MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
- AttributeName,
-)
-from iotdb.ainode.core.exception import (
- BuiltInModelNotSupportError,
- InferenceModelInternalError,
- ListRangeException,
- NumericalRangeException,
- StringRangeException,
- WrongAttributeTypeError,
-)
-from iotdb.ainode.core.log import Logger
-from iotdb.ainode.core.model.model_enums import BuiltInModelType
-from iotdb.ainode.core.model.model_info import TIMER_REPO_ID
-from iotdb.ainode.core.model.sundial import modeling_sundial
-from iotdb.ainode.core.model.timerxl import modeling_timer
-
-logger = Logger()
-
-
-def _download_file_from_hf_if_necessary(local_dir: str, repo_id: str) -> bool:
- weights_path = os.path.join(local_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)
- config_path = os.path.join(local_dir, MODEL_CONFIG_FILE_IN_JSON)
- if not os.path.exists(weights_path):
- logger.info(
- f"Model weights file not found at {weights_path}, downloading from HuggingFace..."
- )
- try:
- hf_hub_download(
- repo_id=repo_id,
- filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
- local_dir=local_dir,
- )
- logger.info(f"Got file to {weights_path}")
- except Exception as e:
- logger.error(
- f"Failed to download model weights file to {local_dir} due to {e}"
- )
- return False
- if not os.path.exists(config_path):
- logger.info(
- f"Model config file not found at {config_path}, downloading from HuggingFace..."
- )
- try:
- hf_hub_download(
- repo_id=repo_id,
- filename=MODEL_CONFIG_FILE_IN_JSON,
- local_dir=local_dir,
- )
- logger.info(f"Got file to {config_path}")
- except Exception as e:
- logger.error(
- f"Failed to download model config file to {local_dir} due to {e}"
- )
- return False
- return True
-
-
-def download_built_in_ltsm_from_hf_if_necessary(
- model_type: BuiltInModelType, local_dir: str
-) -> bool:
- """
- Download the built-in ltsm from HuggingFace repository when necessary.
-
- Return:
- bool: True if the model is existed or downloaded successfully, False otherwise.
- """
- repo_id = TIMER_REPO_ID[model_type]
- if not _download_file_from_hf_if_necessary(local_dir, repo_id):
- return False
- return True
-
-
-def get_model_attributes(model_type: BuiltInModelType):
- if model_type == BuiltInModelType.ARIMA:
- attribute_map = arima_attribute_map
- elif model_type == BuiltInModelType.NAIVE_FORECASTER:
- attribute_map = naive_forecaster_attribute_map
- elif (
- model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING
- or model_type == BuiltInModelType.HOLTWINTERS
- ):
- attribute_map = exponential_smoothing_attribute_map
- elif model_type == BuiltInModelType.STL_FORECASTER:
- attribute_map = stl_forecaster_attribute_map
- elif model_type == BuiltInModelType.GMM_HMM:
- attribute_map = gmmhmm_attribute_map
- elif model_type == BuiltInModelType.GAUSSIAN_HMM:
- attribute_map = gaussian_hmm_attribute_map
- elif model_type == BuiltInModelType.STRAY:
- attribute_map = stray_attribute_map
- elif model_type == BuiltInModelType.TIMER_XL:
- attribute_map = timerxl_attribute_map
- elif model_type == BuiltInModelType.SUNDIAL:
- attribute_map = sundial_attribute_map
- else:
- raise BuiltInModelNotSupportError(model_type.value)
- return attribute_map
-
-
-def fetch_built_in_model(
- model_type: BuiltInModelType, model_dir, inference_attrs: Dict[str, str]
-) -> Callable:
- """
- Fetch the built-in model according to its id and directory, not that this directory only contains model weights and config.
- Args:
- model_type: the type of the built-in model
- model_dir: for huggingface models only, the directory where the model is stored
- Returns:
- model: the built-in model
- """
- default_attributes = get_model_attributes(model_type)
- # parse the attributes from inference_attrs
- attributes = parse_attribute(inference_attrs, default_attributes)
-
- # build the built-in model
- if model_type == BuiltInModelType.ARIMA:
- model = ArimaModel(attributes)
- elif (
- model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING
- or model_type == BuiltInModelType.HOLTWINTERS
- ):
- model = ExponentialSmoothingModel(attributes)
- elif model_type == BuiltInModelType.NAIVE_FORECASTER:
- model = NaiveForecasterModel(attributes)
- elif model_type == BuiltInModelType.STL_FORECASTER:
- model = STLForecasterModel(attributes)
- elif model_type == BuiltInModelType.GMM_HMM:
- model = GMMHMMModel(attributes)
- elif model_type == BuiltInModelType.GAUSSIAN_HMM:
- model = GaussianHmmModel(attributes)
- elif model_type == BuiltInModelType.STRAY:
- model = STRAYModel(attributes)
- elif model_type == BuiltInModelType.TIMER_XL:
- model = modeling_timer.TimerForPrediction.from_pretrained(model_dir)
- elif model_type == BuiltInModelType.SUNDIAL:
- model = modeling_sundial.SundialForPrediction.from_pretrained(model_dir)
- else:
- raise BuiltInModelNotSupportError(model_type.value)
-
- return model
-
-
-class Attribute(object):
- def __init__(self, name: str):
- """
- Args:
- name: the name of the attribute
- """
- self._name = name
-
- @abstractmethod
- def get_default_value(self):
- raise NotImplementedError
-
- @abstractmethod
- def validate_value(self, value):
- raise NotImplementedError
-
- @abstractmethod
- def parse(self, string_value: str):
- raise NotImplementedError
-
-
-class IntAttribute(Attribute):
- def __init__(
- self,
- name: str,
- default_value: int,
- default_low: int,
- default_high: int,
- ):
- super(IntAttribute, self).__init__(name)
- self.__default_value = default_value
- self.__default_low = default_low
- self.__default_high = default_high
-
- def get_default_value(self):
- return self.__default_value
-
- def validate_value(self, value):
- if self.__default_low <= value <= self.__default_high:
- return True
- raise NumericalRangeException(
- self._name, value, self.__default_low, self.__default_high
- )
-
- def parse(self, string_value: str):
- try:
- int_value = int(string_value)
- except:
- raise WrongAttributeTypeError(self._name, "int")
- return int_value
-
-
-class FloatAttribute(Attribute):
- def __init__(
- self,
- name: str,
- default_value: float,
- default_low: float,
- default_high: float,
- ):
- super(FloatAttribute, self).__init__(name)
- self.__default_value = default_value
- self.__default_low = default_low
- self.__default_high = default_high
-
- def get_default_value(self):
- return self.__default_value
-
- def validate_value(self, value):
- if self.__default_low <= value <= self.__default_high:
- return True
- raise NumericalRangeException(
- self._name, value, self.__default_low, self.__default_high
- )
-
- def parse(self, string_value: str):
- try:
- float_value = float(string_value)
- except:
- raise WrongAttributeTypeError(self._name, "float")
- return float_value
-
-
-class StringAttribute(Attribute):
- def __init__(self, name: str, default_value: str, value_choices: List[str]):
- super(StringAttribute, self).__init__(name)
- self.__default_value = default_value
- self.__value_choices = value_choices
-
- def get_default_value(self):
- return self.__default_value
-
- def validate_value(self, value):
- if value in self.__value_choices:
- return True
- raise StringRangeException(self._name, value, self.__value_choices)
-
- def parse(self, string_value: str):
- return string_value
-
-
-class BooleanAttribute(Attribute):
- def __init__(self, name: str, default_value: bool):
- super(BooleanAttribute, self).__init__(name)
- self.__default_value = default_value
-
- def get_default_value(self):
- return self.__default_value
-
- def validate_value(self, value):
- if isinstance(value, bool):
- return True
- raise WrongAttributeTypeError(self._name, "bool")
-
- def parse(self, string_value: str):
- if string_value.lower() == "true":
- return True
- elif string_value.lower() == "false":
- return False
- else:
- raise WrongAttributeTypeError(self._name, "bool")
-
-
-class ListAttribute(Attribute):
- def __init__(self, name: str, default_value: List, value_type):
- """
- value_type is the type of the elements in the list, e.g. int, float, str
- """
- super(ListAttribute, self).__init__(name)
- self.__default_value = default_value
- self.__value_type = value_type
- self.__type_to_str = {str: "str", int: "int", float: "float"}
-
- def get_default_value(self):
- return self.__default_value
-
- def validate_value(self, value):
- if not isinstance(value, list):
- raise WrongAttributeTypeError(self._name, "list")
- for value_item in value:
- if not isinstance(value_item, self.__value_type):
- raise WrongAttributeTypeError(self._name, self.__value_type)
- return True
-
- def parse(self, string_value: str):
- try:
- list_value = eval(string_value)
- except:
- raise WrongAttributeTypeError(self._name, "list")
- if not isinstance(list_value, list):
- raise WrongAttributeTypeError(self._name, "list")
- for i in range(len(list_value)):
- try:
- list_value[i] = self.__value_type(list_value[i])
- except:
- raise ListRangeException(
- self._name, list_value, self.__type_to_str[self.__value_type]
- )
- return list_value
-
-
-class TupleAttribute(Attribute):
- def __init__(self, name: str, default_value: tuple, value_type):
- """
- value_type is the type of the elements in the list, e.g. int, float, str
- """
- super(TupleAttribute, self).__init__(name)
- self.__default_value = default_value
- self.__value_type = value_type
- self.__type_to_str = {str: "str", int: "int", float: "float"}
-
- def get_default_value(self):
- return self.__default_value
-
- def validate_value(self, value):
- if not isinstance(value, tuple):
- raise WrongAttributeTypeError(self._name, "tuple")
- for value_item in value:
- if not isinstance(value_item, self.__value_type):
- raise WrongAttributeTypeError(self._name, self.__value_type)
- return True
-
- def parse(self, string_value: str):
- try:
- tuple_value = eval(string_value)
- except:
- raise WrongAttributeTypeError(self._name, "tuple")
- if not isinstance(tuple_value, tuple):
- raise WrongAttributeTypeError(self._name, "tuple")
- list_value = list(tuple_value)
- for i in range(len(list_value)):
- try:
- list_value[i] = self.__value_type(list_value[i])
- except:
- raise ListRangeException(
- self._name, list_value, self.__type_to_str[self.__value_type]
- )
- tuple_value = tuple(list_value)
- return tuple_value
-
-
-def parse_attribute(
- input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute]
-):
- """
- Args:
- input_attributes: a dict of attributes, where the key is the attribute name, the value is the string value of
- the attribute
- attribute_map: a dict of hyperparameters, where the key is the attribute name, the value is the Attribute
- object
- Returns:
- a dict of attributes, where the key is the attribute name, the value is the parsed value of the attribute
- """
- attributes = {}
- for attribute_name in attribute_map:
- # user specified the attribute
- if attribute_name in input_attributes:
- attribute = attribute_map[attribute_name]
- value = attribute.parse(input_attributes[attribute_name])
- attribute.validate_value(value)
- attributes[attribute_name] = value
- # user did not specify the attribute, use the default value
- else:
- try:
- attributes[attribute_name] = attribute_map[
- attribute_name
- ].get_default_value()
- except NotImplementedError as e:
- logger.error(f"attribute {attribute_name} is not implemented.")
- raise e
- return attributes
-
-
-sundial_attribute_map = {
- AttributeName.INPUT_TOKEN_LEN.value: IntAttribute(
- name=AttributeName.INPUT_TOKEN_LEN.value,
- default_value=16,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.HIDDEN_SIZE.value: IntAttribute(
- name=AttributeName.HIDDEN_SIZE.value,
- default_value=768,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.INTERMEDIATE_SIZE.value: IntAttribute(
- name=AttributeName.INTERMEDIATE_SIZE.value,
- default_value=3072,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute(
- name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[720], value_type=int
- ),
- AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute(
- name=AttributeName.NUM_HIDDEN_LAYERS.value,
- default_value=12,
- default_low=1,
- default_high=16,
- ),
- AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute(
- name=AttributeName.NUM_ATTENTION_HEADS.value,
- default_value=12,
- default_low=1,
- default_high=192,
- ),
- AttributeName.HIDDEN_ACT.value: StringAttribute(
- name=AttributeName.HIDDEN_ACT.value,
- default_value="silu",
- value_choices=["relu", "gelu", "silu", "tanh"],
- ),
- AttributeName.USE_CACHE.value: BooleanAttribute(
- name=AttributeName.USE_CACHE.value,
- default_value=True,
- ),
- AttributeName.ROPE_THETA.value: IntAttribute(
- name=AttributeName.ROPE_THETA.value,
- default_value=10000,
- default_low=1000,
- default_high=50000,
- ),
- AttributeName.DROPOUT_RATE.value: FloatAttribute(
- name=AttributeName.DROPOUT_RATE.value,
- default_value=0.1,
- default_low=0.0,
- default_high=1.0,
- ),
- AttributeName.INITIALIZER_RANGE.value: FloatAttribute(
- name=AttributeName.INITIALIZER_RANGE.value,
- default_value=0.02,
- default_low=0.0,
- default_high=1.0,
- ),
- AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute(
- name=AttributeName.MAX_POSITION_EMBEDDINGS.value,
- default_value=10000,
- default_low=1,
- default_high=50000,
- ),
- AttributeName.FLOW_LOSS_DEPTH.value: IntAttribute(
- name=AttributeName.FLOW_LOSS_DEPTH.value,
- default_value=3,
- default_low=1,
- default_high=50,
- ),
- AttributeName.NUM_SAMPLING_STEPS.value: IntAttribute(
- name=AttributeName.NUM_SAMPLING_STEPS.value,
- default_value=50,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.DIFFUSION_BATCH_MUL.value: IntAttribute(
- name=AttributeName.DIFFUSION_BATCH_MUL.value,
- default_value=4,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.CKPT_PATH.value: StringAttribute(
- name=AttributeName.CKPT_PATH.value,
- default_value=os.path.join(
- os.getcwd(),
- AINodeDescriptor().get_config().get_ain_models_dir(),
- "weights",
- "sundial",
- ),
- value_choices=[""],
- ),
-}
-
-timerxl_attribute_map = {
- AttributeName.INPUT_TOKEN_LEN.value: IntAttribute(
- name=AttributeName.INPUT_TOKEN_LEN.value,
- default_value=96,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.HIDDEN_SIZE.value: IntAttribute(
- name=AttributeName.HIDDEN_SIZE.value,
- default_value=1024,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.INTERMEDIATE_SIZE.value: IntAttribute(
- name=AttributeName.INTERMEDIATE_SIZE.value,
- default_value=2048,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute(
- name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[96], value_type=int
- ),
- AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute(
- name=AttributeName.NUM_HIDDEN_LAYERS.value,
- default_value=8,
- default_low=1,
- default_high=16,
- ),
- AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute(
- name=AttributeName.NUM_ATTENTION_HEADS.value,
- default_value=8,
- default_low=1,
- default_high=192,
- ),
- AttributeName.HIDDEN_ACT.value: StringAttribute(
- name=AttributeName.HIDDEN_ACT.value,
- default_value="silu",
- value_choices=["relu", "gelu", "silu", "tanh"],
- ),
- AttributeName.USE_CACHE.value: BooleanAttribute(
- name=AttributeName.USE_CACHE.value,
- default_value=True,
- ),
- AttributeName.ROPE_THETA.value: IntAttribute(
- name=AttributeName.ROPE_THETA.value,
- default_value=10000,
- default_low=1000,
- default_high=50000,
- ),
- AttributeName.ATTENTION_DROPOUT.value: FloatAttribute(
- name=AttributeName.ATTENTION_DROPOUT.value,
- default_value=0.0,
- default_low=0.0,
- default_high=1.0,
- ),
- AttributeName.INITIALIZER_RANGE.value: FloatAttribute(
- name=AttributeName.INITIALIZER_RANGE.value,
- default_value=0.02,
- default_low=0.0,
- default_high=1.0,
- ),
- AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute(
- name=AttributeName.MAX_POSITION_EMBEDDINGS.value,
- default_value=10000,
- default_low=1,
- default_high=50000,
- ),
- AttributeName.CKPT_PATH.value: StringAttribute(
- name=AttributeName.CKPT_PATH.value,
- default_value=os.path.join(
- os.getcwd(),
- AINodeDescriptor().get_config().get_ain_models_dir(),
- "weights",
- "timerxl",
- "model.safetensors",
- ),
- value_choices=[""],
- ),
-}
-
-# built-in sktime model attributes
-# NaiveForecaster
-naive_forecaster_attribute_map = {
- AttributeName.PREDICT_LENGTH.value: IntAttribute(
- name=AttributeName.PREDICT_LENGTH.value,
- default_value=1,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.STRATEGY.value: StringAttribute(
- name=AttributeName.STRATEGY.value,
- default_value="last",
- value_choices=["last", "mean"],
- ),
- AttributeName.SP.value: IntAttribute(
- name=AttributeName.SP.value, default_value=1, default_low=1, default_high=5000
- ),
-}
-# ExponentialSmoothing
-exponential_smoothing_attribute_map = {
- AttributeName.PREDICT_LENGTH.value: IntAttribute(
- name=AttributeName.PREDICT_LENGTH.value,
- default_value=1,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.DAMPED_TREND.value: BooleanAttribute(
- name=AttributeName.DAMPED_TREND.value,
- default_value=False,
- ),
- AttributeName.INITIALIZATION_METHOD.value: StringAttribute(
- name=AttributeName.INITIALIZATION_METHOD.value,
- default_value="estimated",
- value_choices=["estimated", "heuristic", "legacy-heuristic", "known"],
- ),
- AttributeName.OPTIMIZED.value: BooleanAttribute(
- name=AttributeName.OPTIMIZED.value,
- default_value=True,
- ),
- AttributeName.REMOVE_BIAS.value: BooleanAttribute(
- name=AttributeName.REMOVE_BIAS.value,
- default_value=False,
- ),
- AttributeName.USE_BRUTE.value: BooleanAttribute(
- name=AttributeName.USE_BRUTE.value,
- default_value=False,
- ),
-}
-# Arima
-arima_attribute_map = {
- AttributeName.PREDICT_LENGTH.value: IntAttribute(
- name=AttributeName.PREDICT_LENGTH.value,
- default_value=1,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.ORDER.value: TupleAttribute(
- name=AttributeName.ORDER.value, default_value=(1, 0, 0), value_type=int
- ),
- AttributeName.SEASONAL_ORDER.value: TupleAttribute(
- name=AttributeName.SEASONAL_ORDER.value,
- default_value=(0, 0, 0, 0),
- value_type=int,
- ),
- AttributeName.METHOD.value: StringAttribute(
- name=AttributeName.METHOD.value,
- default_value="lbfgs",
- value_choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"],
- ),
- AttributeName.MAXITER.value: IntAttribute(
- name=AttributeName.MAXITER.value,
- default_value=1,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.SUPPRESS_WARNINGS.value: BooleanAttribute(
- name=AttributeName.SUPPRESS_WARNINGS.value,
- default_value=True,
- ),
- AttributeName.OUT_OF_SAMPLE_SIZE.value: IntAttribute(
- name=AttributeName.OUT_OF_SAMPLE_SIZE.value,
- default_value=0,
- default_low=0,
- default_high=5000,
- ),
- AttributeName.SCORING.value: StringAttribute(
- name=AttributeName.SCORING.value,
- default_value="mse",
- value_choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"],
- ),
- AttributeName.WITH_INTERCEPT.value: BooleanAttribute(
- name=AttributeName.WITH_INTERCEPT.value,
- default_value=True,
- ),
- AttributeName.TIME_VARYING_REGRESSION.value: BooleanAttribute(
- name=AttributeName.TIME_VARYING_REGRESSION.value,
- default_value=False,
- ),
- AttributeName.ENFORCE_STATIONARITY.value: BooleanAttribute(
- name=AttributeName.ENFORCE_STATIONARITY.value,
- default_value=True,
- ),
- AttributeName.ENFORCE_INVERTIBILITY.value: BooleanAttribute(
- name=AttributeName.ENFORCE_INVERTIBILITY.value,
- default_value=True,
- ),
- AttributeName.SIMPLE_DIFFERENCING.value: BooleanAttribute(
- name=AttributeName.SIMPLE_DIFFERENCING.value,
- default_value=False,
- ),
- AttributeName.MEASUREMENT_ERROR.value: BooleanAttribute(
- name=AttributeName.MEASUREMENT_ERROR.value,
- default_value=False,
- ),
- AttributeName.MLE_REGRESSION.value: BooleanAttribute(
- name=AttributeName.MLE_REGRESSION.value,
- default_value=True,
- ),
- AttributeName.HAMILTON_REPRESENTATION.value: BooleanAttribute(
- name=AttributeName.HAMILTON_REPRESENTATION.value,
- default_value=False,
- ),
- AttributeName.CONCENTRATE_SCALE.value: BooleanAttribute(
- name=AttributeName.CONCENTRATE_SCALE.value,
- default_value=False,
- ),
-}
-# STLForecaster
-stl_forecaster_attribute_map = {
- AttributeName.PREDICT_LENGTH.value: IntAttribute(
- name=AttributeName.PREDICT_LENGTH.value,
- default_value=1,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.SP.value: IntAttribute(
- name=AttributeName.SP.value, default_value=2, default_low=1, default_high=5000
- ),
- AttributeName.SEASONAL.value: IntAttribute(
- name=AttributeName.SEASONAL.value,
- default_value=7,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.SEASONAL_DEG.value: IntAttribute(
- name=AttributeName.SEASONAL_DEG.value,
- default_value=1,
- default_low=0,
- default_high=5000,
- ),
- AttributeName.TREND_DEG.value: IntAttribute(
- name=AttributeName.TREND_DEG.value,
- default_value=1,
- default_low=0,
- default_high=5000,
- ),
- AttributeName.LOW_PASS_DEG.value: IntAttribute(
- name=AttributeName.LOW_PASS_DEG.value,
- default_value=1,
- default_low=0,
- default_high=5000,
- ),
- AttributeName.SEASONAL_JUMP.value: IntAttribute(
- name=AttributeName.SEASONAL_JUMP.value,
- default_value=1,
- default_low=0,
- default_high=5000,
- ),
- AttributeName.TREND_JUMP.value: IntAttribute(
- name=AttributeName.TREND_JUMP.value,
- default_value=1,
- default_low=0,
- default_high=5000,
- ),
- AttributeName.LOSS_PASS_JUMP.value: IntAttribute(
- name=AttributeName.LOSS_PASS_JUMP.value,
- default_value=1,
- default_low=0,
- default_high=5000,
- ),
-}
-
-# GAUSSIAN_HMM
-gaussian_hmm_attribute_map = {
- AttributeName.N_COMPONENTS.value: IntAttribute(
- name=AttributeName.N_COMPONENTS.value,
- default_value=1,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.COVARIANCE_TYPE.value: StringAttribute(
- name=AttributeName.COVARIANCE_TYPE.value,
- default_value="diag",
- value_choices=["spherical", "diag", "full", "tied"],
- ),
- AttributeName.MIN_COVAR.value: FloatAttribute(
- name=AttributeName.MIN_COVAR.value,
- default_value=1e-3,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.STARTPROB_PRIOR.value: FloatAttribute(
- name=AttributeName.STARTPROB_PRIOR.value,
- default_value=1.0,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.TRANSMAT_PRIOR.value: FloatAttribute(
- name=AttributeName.TRANSMAT_PRIOR.value,
- default_value=1.0,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.MEANS_PRIOR.value: FloatAttribute(
- name=AttributeName.MEANS_PRIOR.value,
- default_value=0.0,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.MEANS_WEIGHT.value: FloatAttribute(
- name=AttributeName.MEANS_WEIGHT.value,
- default_value=0.0,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.COVARS_PRIOR.value: FloatAttribute(
- name=AttributeName.COVARS_PRIOR.value,
- default_value=1e-2,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.COVARS_WEIGHT.value: FloatAttribute(
- name=AttributeName.COVARS_WEIGHT.value,
- default_value=1.0,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.ALGORITHM.value: StringAttribute(
- name=AttributeName.ALGORITHM.value,
- default_value="viterbi",
- value_choices=["viterbi", "map"],
- ),
- AttributeName.N_ITER.value: IntAttribute(
- name=AttributeName.N_ITER.value,
- default_value=10,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.TOL.value: FloatAttribute(
- name=AttributeName.TOL.value,
- default_value=1e-2,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.PARAMS.value: StringAttribute(
- name=AttributeName.PARAMS.value,
- default_value="stmc",
- value_choices=["stmc", "stm"],
- ),
- AttributeName.INIT_PARAMS.value: StringAttribute(
- name=AttributeName.INIT_PARAMS.value,
- default_value="stmc",
- value_choices=["stmc", "stm"],
- ),
- AttributeName.IMPLEMENTATION.value: StringAttribute(
- name=AttributeName.IMPLEMENTATION.value,
- default_value="log",
- value_choices=["log", "scaling"],
- ),
-}
-
-# GMMHMM
-gmmhmm_attribute_map = {
- AttributeName.N_COMPONENTS.value: IntAttribute(
- name=AttributeName.N_COMPONENTS.value,
- default_value=1,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.N_MIX.value: IntAttribute(
- name=AttributeName.N_MIX.value,
- default_value=1,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.MIN_COVAR.value: FloatAttribute(
- name=AttributeName.MIN_COVAR.value,
- default_value=1e-3,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.STARTPROB_PRIOR.value: FloatAttribute(
- name=AttributeName.STARTPROB_PRIOR.value,
- default_value=1.0,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.TRANSMAT_PRIOR.value: FloatAttribute(
- name=AttributeName.TRANSMAT_PRIOR.value,
- default_value=1.0,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.WEIGHTS_PRIOR.value: FloatAttribute(
- name=AttributeName.WEIGHTS_PRIOR.value,
- default_value=1.0,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.MEANS_PRIOR.value: FloatAttribute(
- name=AttributeName.MEANS_PRIOR.value,
- default_value=0.0,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.MEANS_WEIGHT.value: FloatAttribute(
- name=AttributeName.MEANS_WEIGHT.value,
- default_value=0.0,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.ALGORITHM.value: StringAttribute(
- name=AttributeName.ALGORITHM.value,
- default_value="viterbi",
- value_choices=["viterbi", "map"],
- ),
- AttributeName.COVARIANCE_TYPE.value: StringAttribute(
- name=AttributeName.COVARIANCE_TYPE.value,
- default_value="diag",
- value_choices=["sperical", "diag", "full", "tied"],
- ),
- AttributeName.N_ITER.value: IntAttribute(
- name=AttributeName.N_ITER.value,
- default_value=10,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.TOL.value: FloatAttribute(
- name=AttributeName.TOL.value,
- default_value=1e-2,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.INIT_PARAMS.value: StringAttribute(
- name=AttributeName.INIT_PARAMS.value,
- default_value="stmcw",
- value_choices=[
- "s",
- "t",
- "m",
- "c",
- "w",
- "st",
- "sm",
- "sc",
- "sw",
- "tm",
- "tc",
- "tw",
- "mc",
- "mw",
- "cw",
- "stm",
- "stc",
- "stw",
- "smc",
- "smw",
- "scw",
- "tmc",
- "tmw",
- "tcw",
- "mcw",
- "stmc",
- "stmw",
- "stcw",
- "smcw",
- "tmcw",
- "stmcw",
- ],
- ),
- AttributeName.PARAMS.value: StringAttribute(
- name=AttributeName.PARAMS.value,
- default_value="stmcw",
- value_choices=[
- "s",
- "t",
- "m",
- "c",
- "w",
- "st",
- "sm",
- "sc",
- "sw",
- "tm",
- "tc",
- "tw",
- "mc",
- "mw",
- "cw",
- "stm",
- "stc",
- "stw",
- "smc",
- "smw",
- "scw",
- "tmc",
- "tmw",
- "tcw",
- "mcw",
- "stmc",
- "stmw",
- "stcw",
- "smcw",
- "tmcw",
- "stmcw",
- ],
- ),
- AttributeName.IMPLEMENTATION.value: StringAttribute(
- name=AttributeName.IMPLEMENTATION.value,
- default_value="log",
- value_choices=["log", "scaling"],
- ),
-}
-
-# STRAY
-stray_attribute_map = {
- AttributeName.ALPHA.value: FloatAttribute(
- name=AttributeName.ALPHA.value,
- default_value=0.01,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.K.value: IntAttribute(
- name=AttributeName.K.value, default_value=10, default_low=1, default_high=5000
- ),
- AttributeName.KNN_ALGORITHM.value: StringAttribute(
- name=AttributeName.KNN_ALGORITHM.value,
- default_value="brute",
- value_choices=["brute", "kd_tree", "ball_tree", "auto"],
- ),
- AttributeName.P.value: FloatAttribute(
- name=AttributeName.P.value,
- default_value=0.5,
- default_low=-1e10,
- default_high=1e10,
- ),
- AttributeName.SIZE_THRESHOLD.value: IntAttribute(
- name=AttributeName.SIZE_THRESHOLD.value,
- default_value=50,
- default_low=1,
- default_high=5000,
- ),
- AttributeName.OUTLIER_TAIL.value: StringAttribute(
- name=AttributeName.OUTLIER_TAIL.value,
- default_value="max",
- value_choices=["min", "max"],
- ),
-}
-
-
-class BuiltInModel(object):
- def __init__(self, attributes):
- self._attributes = attributes
- self._model = None
-
- @abstractmethod
- def inference(self, data):
- raise NotImplementedError
-
-
-class ArimaModel(BuiltInModel):
- def __init__(self, attributes):
- super(ArimaModel, self).__init__(attributes)
- self._model = ARIMA(
- order=attributes["order"],
- seasonal_order=attributes["seasonal_order"],
- method=attributes["method"],
- suppress_warnings=attributes["suppress_warnings"],
- maxiter=attributes["maxiter"],
- out_of_sample_size=attributes["out_of_sample_size"],
- scoring=attributes["scoring"],
- with_intercept=attributes["with_intercept"],
- time_varying_regression=attributes["time_varying_regression"],
- enforce_stationarity=attributes["enforce_stationarity"],
- enforce_invertibility=attributes["enforce_invertibility"],
- simple_differencing=attributes["simple_differencing"],
- measurement_error=attributes["measurement_error"],
- mle_regression=attributes["mle_regression"],
- hamilton_representation=attributes["hamilton_representation"],
- concentrate_scale=attributes["concentrate_scale"],
- )
-
- def inference(self, data):
- try:
- predict_length = self._attributes["predict_length"]
- self._model.fit(data)
- output = self._model.predict(fh=range(predict_length))
- output = np.array(output, dtype=np.float64)
- return output
- except Exception as e:
- raise InferenceModelInternalError(str(e))
-
-
-class ExponentialSmoothingModel(BuiltInModel):
- def __init__(self, attributes):
- super(ExponentialSmoothingModel, self).__init__(attributes)
- self._model = ExponentialSmoothing(
- damped_trend=attributes["damped_trend"],
- initialization_method=attributes["initialization_method"],
- optimized=attributes["optimized"],
- remove_bias=attributes["remove_bias"],
- use_brute=attributes["use_brute"],
- )
-
- def inference(self, data):
- try:
- predict_length = self._attributes["predict_length"]
- self._model.fit(data)
- output = self._model.predict(fh=range(predict_length))
- output = np.array(output, dtype=np.float64)
- return output
- except Exception as e:
- raise InferenceModelInternalError(str(e))
-
-
-class NaiveForecasterModel(BuiltInModel):
- def __init__(self, attributes):
- super(NaiveForecasterModel, self).__init__(attributes)
- self._model = NaiveForecaster(
- strategy=attributes["strategy"], sp=attributes["sp"]
- )
-
- def inference(self, data):
- try:
- predict_length = self._attributes["predict_length"]
- self._model.fit(data)
- output = self._model.predict(fh=range(predict_length))
- output = np.array(output, dtype=np.float64)
- return output
- except Exception as e:
- raise InferenceModelInternalError(str(e))
-
-
-class STLForecasterModel(BuiltInModel):
- def __init__(self, attributes):
- super(STLForecasterModel, self).__init__(attributes)
- self._model = STLForecaster(
- sp=attributes["sp"],
- seasonal=attributes["seasonal"],
- seasonal_deg=attributes["seasonal_deg"],
- trend_deg=attributes["trend_deg"],
- low_pass_deg=attributes["low_pass_deg"],
- seasonal_jump=attributes["seasonal_jump"],
- trend_jump=attributes["trend_jump"],
- low_pass_jump=attributes["low_pass_jump"],
- )
-
- def inference(self, data):
- try:
- predict_length = self._attributes["predict_length"]
- self._model.fit(data)
- output = self._model.predict(fh=range(predict_length))
- output = np.array(output, dtype=np.float64)
- return output
- except Exception as e:
- raise InferenceModelInternalError(str(e))
-
-
-class GMMHMMModel(BuiltInModel):
- def __init__(self, attributes):
- super(GMMHMMModel, self).__init__(attributes)
- self._model = GMMHMM(
- n_components=attributes["n_components"],
- n_mix=attributes["n_mix"],
- min_covar=attributes["min_covar"],
- startprob_prior=attributes["startprob_prior"],
- transmat_prior=attributes["transmat_prior"],
- means_prior=attributes["means_prior"],
- means_weight=attributes["means_weight"],
- weights_prior=attributes["weights_prior"],
- algorithm=attributes["algorithm"],
- covariance_type=attributes["covariance_type"],
- n_iter=attributes["n_iter"],
- tol=attributes["tol"],
- params=attributes["params"],
- init_params=attributes["init_params"],
- implementation=attributes["implementation"],
- )
-
- def inference(self, data):
- try:
- self._model.fit(data)
- output = self._model.predict(data)
- output = np.array(output, dtype=np.int32)
- return output
- except Exception as e:
- raise InferenceModelInternalError(str(e))
-
-
-class GaussianHmmModel(BuiltInModel):
- def __init__(self, attributes):
- super(GaussianHmmModel, self).__init__(attributes)
- self._model = GaussianHMM(
- n_components=attributes["n_components"],
- covariance_type=attributes["covariance_type"],
- min_covar=attributes["min_covar"],
- startprob_prior=attributes["startprob_prior"],
- transmat_prior=attributes["transmat_prior"],
- means_prior=attributes["means_prior"],
- means_weight=attributes["means_weight"],
- covars_prior=attributes["covars_prior"],
- covars_weight=attributes["covars_weight"],
- algorithm=attributes["algorithm"],
- n_iter=attributes["n_iter"],
- tol=attributes["tol"],
- params=attributes["params"],
- init_params=attributes["init_params"],
- implementation=attributes["implementation"],
- )
-
- def inference(self, data):
- try:
- self._model.fit(data)
- output = self._model.predict(data)
- output = np.array(output, dtype=np.int32)
- return output
- except Exception as e:
- raise InferenceModelInternalError(str(e))
-
-
-class STRAYModel(BuiltInModel):
- def __init__(self, attributes):
- super(STRAYModel, self).__init__(attributes)
- self._model = STRAY(
- alpha=attributes["alpha"],
- k=attributes["k"],
- knn_algorithm=attributes["knn_algorithm"],
- p=attributes["p"],
- size_threshold=attributes["size_threshold"],
- outlier_tail=attributes["outlier_tail"],
- )
-
- def inference(self, data):
- try:
- data = MinMaxScaler().fit_transform(data)
- output = self._model.fit_transform(data)
- # change the output to int
- output = np.array(output, dtype=np.int32)
- return output
- except Exception as e:
- raise InferenceModelInternalError(str(e))
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py
new file mode 100644
index 0000000000000..c42ec98551b83
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from enum import Enum
+
+# Model file constants
+MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors"
+MODEL_CONFIG_FILE_IN_JSON = "config.json"
+MODEL_WEIGHTS_FILE_IN_PT = "model.pt"
+MODEL_CONFIG_FILE_IN_YAML = "config.yaml"
+
+
+class ModelCategory(Enum):
+ BUILTIN = "builtin"
+ USER_DEFINED = "user_defined"
+
+
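+# Lifecycle as driven by ModelStorage: INACTIVE -> ACTIVATING (weights being fetched)
+# -> ACTIVE on success, back to INACTIVE on failure; DROPPING marks a model that is
+# being removed.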
+class ModelStates(Enum):
+ INACTIVE = "inactive"
+ ACTIVATING = "activating"
+ ACTIVE = "active"
+ DROPPING = "dropping"
+
+
+class UriType(Enum):
+ REPO = "repo"
+ FILE = "file"
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py
deleted file mode 100644
index 348f9924316b6..0000000000000
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-from enum import Enum
-from typing import List
-
-
-class BuiltInModelType(Enum):
- # forecast models
- ARIMA = "Arima"
- HOLTWINTERS = "HoltWinters"
- EXPONENTIAL_SMOOTHING = "ExponentialSmoothing"
- NAIVE_FORECASTER = "NaiveForecaster"
- STL_FORECASTER = "StlForecaster"
-
- # anomaly detection models
- GAUSSIAN_HMM = "GaussianHmm"
- GMM_HMM = "GmmHmm"
- STRAY = "Stray"
-
- # large time series models (LTSM)
- TIMER_XL = "Timer-XL"
- # sundial
- SUNDIAL = "Timer-Sundial"
-
- @classmethod
- def values(cls) -> List[str]:
- return [item.value for item in cls]
-
- @staticmethod
- def is_built_in_model(model_type: str) -> bool:
- """
- Check if the given model type corresponds to a built-in model.
- """
- return model_type in BuiltInModelType.values()
-
-
-class ModelFileType(Enum):
- SAFETENSORS = "safetensors"
- PYTORCH = "pytorch"
- UNKNOWN = "unknown"
-
-
-class ModelCategory(Enum):
- BUILT_IN = "BUILT-IN"
- FINE_TUNED = "FINE-TUNED"
- USER_DEFINED = "USER-DEFINED"
-
-
-class ModelStates(Enum):
- ACTIVE = "ACTIVE"
- INACTIVE = "INACTIVE"
- LOADING = "LOADING"
- DROPPING = "DROPPING"
- TRAINING = "TRAINING"
- FAILED = "FAILED"
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py
deleted file mode 100644
index 26d863156f379..0000000000000
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py
+++ /dev/null
@@ -1,291 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import glob
-import os
-import shutil
-from urllib.parse import urljoin
-
-import yaml
-
-from iotdb.ainode.core.constant import (
- MODEL_CONFIG_FILE_IN_YAML,
- MODEL_WEIGHTS_FILE_IN_PT,
-)
-from iotdb.ainode.core.exception import BadConfigValueError, InvalidUriError
-from iotdb.ainode.core.log import Logger
-from iotdb.ainode.core.model.model_enums import ModelFileType
-from iotdb.ainode.core.model.uri_utils import (
- UriType,
- download_file,
- download_snapshot_from_hf,
-)
-from iotdb.ainode.core.util.serde import get_data_type_byte_from_str
-from iotdb.thrift.ainode.ttypes import TConfigs
-
-logger = Logger()
-
-
-def fetch_model_by_uri(
- uri_type: UriType, uri: str, storage_path: str, model_file_type: ModelFileType
-):
- """
- Fetch the model files from the specified URI.
-
- Args:
- uri_type (UriType): type of the URI, either repo, file, http or https
- uri (str): a network or a local path of the model to be registered
- storage_path (str): path to save the whole model, including weights, config, codes, etc.
- model_file_type (ModelFileType): The type of model file, either safetensors or pytorch
- Returns: TODO: Will be removed in future
- configs: TConfigs
- attributes: str
- """
- if uri_type == UriType.REPO or uri_type in [UriType.HTTP, UriType.HTTPS]:
- return _fetch_model_from_network(uri, storage_path, model_file_type)
- elif uri_type == UriType.FILE:
- return _fetch_model_from_local(uri, storage_path, model_file_type)
- else:
- raise InvalidUriError(f"Invalid URI type: {uri_type}")
-
-
-def _fetch_model_from_network(
- uri: str, storage_path: str, model_file_type: ModelFileType
-):
- """
- Returns: TODO: Will be removed in future
- configs: TConfigs
- attributes: str
- """
- if model_file_type == ModelFileType.SAFETENSORS:
- download_snapshot_from_hf(uri, storage_path)
- return _process_huggingface_files(storage_path)
-
- # TODO: The following codes might be refactored in future
- # concat uri to get complete url
- uri = uri if uri.endswith("/") else uri + "/"
- target_model_path = urljoin(uri, MODEL_WEIGHTS_FILE_IN_PT)
- target_config_path = urljoin(uri, MODEL_CONFIG_FILE_IN_YAML)
-
- # download config file
- config_storage_path = os.path.join(storage_path, MODEL_CONFIG_FILE_IN_YAML)
- download_file(target_config_path, config_storage_path)
-
- # read and parse config dict from config.yaml
- with open(config_storage_path, "r", encoding="utf-8") as file:
- config_dict = yaml.safe_load(file)
- configs, attributes = _parse_inference_config(config_dict)
-
- # if config.yaml is correct, download model file
- model_storage_path = os.path.join(storage_path, MODEL_WEIGHTS_FILE_IN_PT)
- download_file(target_model_path, model_storage_path)
- return configs, attributes
-
-
-def _fetch_model_from_local(
- uri: str, storage_path: str, model_file_type: ModelFileType
-):
- """
- Returns: TODO: Will be removed in future
- configs: TConfigs
- attributes: str
- """
- if model_file_type == ModelFileType.SAFETENSORS:
- # copy anything in the uri to local_dir
- for file in os.listdir(uri):
- shutil.copy(os.path.join(uri, file), storage_path)
- return _process_huggingface_files(storage_path)
- # concat uri to get complete path
- target_model_path = os.path.join(uri, MODEL_WEIGHTS_FILE_IN_PT)
- model_storage_path = os.path.join(storage_path, MODEL_WEIGHTS_FILE_IN_PT)
- target_config_path = os.path.join(uri, MODEL_CONFIG_FILE_IN_YAML)
- config_storage_path = os.path.join(storage_path, MODEL_CONFIG_FILE_IN_YAML)
-
- # check if file exist
- exist_model_file = os.path.exists(target_model_path)
- exist_config_file = os.path.exists(target_config_path)
-
- configs = None
- attributes = None
- if exist_model_file and exist_config_file:
- # copy config.yaml
- shutil.copy(target_config_path, config_storage_path)
- logger.info(
- f"copy file from {target_config_path} to {config_storage_path} success"
- )
-
- # read and parse config dict from config.yaml
- with open(config_storage_path, "r", encoding="utf-8") as file:
- config_dict = yaml.safe_load(file)
- configs, attributes = _parse_inference_config(config_dict)
-
- # if config.yaml is correct, copy model file
- shutil.copy(target_model_path, model_storage_path)
- logger.info(
- f"copy file from {target_model_path} to {model_storage_path} success"
- )
-
- elif not exist_model_file or not exist_config_file:
- raise InvalidUriError(uri)
-
- return configs, attributes
-
-
-def _parse_inference_config(config_dict):
- """
- Args:
- config_dict: dict
- - configs: dict
- - input_shape (list): input shape of the model and needs to be two-dimensional array like [96, 2]
- - output_shape (list): output shape of the model and needs to be two-dimensional array like [96, 2]
- - input_type (list): input type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float64
- - output_type (list): output type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float64
- - attributes: dict
- Returns:
- configs: TConfigs
- attributes: str
- """
- configs = config_dict["configs"]
-
- # check if input_shape and output_shape are two-dimensional array
- if not (
- isinstance(configs["input_shape"], list) and len(configs["input_shape"]) == 2
- ):
- raise BadConfigValueError(
- "input_shape",
- configs["input_shape"],
- "input_shape should be a two-dimensional array.",
- )
- if not (
- isinstance(configs["output_shape"], list) and len(configs["output_shape"]) == 2
- ):
- raise BadConfigValueError(
- "output_shape",
- configs["output_shape"],
- "output_shape should be a two-dimensional array.",
- )
-
- # check if input_shape and output_shape are positive integer
- input_shape_is_positive_number = (
- isinstance(configs["input_shape"][0], int)
- and isinstance(configs["input_shape"][1], int)
- and configs["input_shape"][0] > 0
- and configs["input_shape"][1] > 0
- )
- if not input_shape_is_positive_number:
- raise BadConfigValueError(
- "input_shape",
- configs["input_shape"],
- "element in input_shape should be positive integer.",
- )
-
- output_shape_is_positive_number = (
- isinstance(configs["output_shape"][0], int)
- and isinstance(configs["output_shape"][1], int)
- and configs["output_shape"][0] > 0
- and configs["output_shape"][1] > 0
- )
- if not output_shape_is_positive_number:
- raise BadConfigValueError(
- "output_shape",
- configs["output_shape"],
- "element in output_shape should be positive integer.",
- )
-
- # check if input_type and output_type are one-dimensional array with right length
- if "input_type" in configs and not (
- isinstance(configs["input_type"], list)
- and len(configs["input_type"]) == configs["input_shape"][1]
- ):
- raise BadConfigValueError(
- "input_type",
- configs["input_type"],
- "input_type should be a one-dimensional array and length of it should be equal to input_shape[1].",
- )
-
- if "output_type" in configs and not (
- isinstance(configs["output_type"], list)
- and len(configs["output_type"]) == configs["output_shape"][1]
- ):
- raise BadConfigValueError(
- "output_type",
- configs["output_type"],
- "output_type should be a one-dimensional array and length of it should be equal to output_shape[1].",
- )
-
- # parse input_type and output_type to byte
- if "input_type" in configs:
- input_type = [get_data_type_byte_from_str(x) for x in configs["input_type"]]
- else:
- input_type = [get_data_type_byte_from_str("float32")] * configs["input_shape"][
- 1
- ]
-
- if "output_type" in configs:
- output_type = [get_data_type_byte_from_str(x) for x in configs["output_type"]]
- else:
- output_type = [get_data_type_byte_from_str("float32")] * configs[
- "output_shape"
- ][1]
-
- # parse attributes
- attributes = ""
- if "attributes" in config_dict:
- attributes = str(config_dict["attributes"])
-
- return (
- TConfigs(
- configs["input_shape"], configs["output_shape"], input_type, output_type
- ),
- attributes,
- )
-
-
-def _process_huggingface_files(local_dir: str):
- """
- TODO: Currently, we use this function to convert the model config from huggingface, we will refactor this in the future.
- """
- config_file = None
- for config_name in ["config.json", "model_config.json"]:
- config_path = os.path.join(local_dir, config_name)
- if os.path.exists(config_path):
- config_file = config_path
- break
-
- if not config_file:
- raise InvalidUriError(f"No config.json found in {local_dir}")
-
- safetensors_files = glob.glob(os.path.join(local_dir, "*.safetensors"))
- if not safetensors_files:
- raise InvalidUriError(f"No .safetensors files found in {local_dir}")
-
- simple_config = {
- "configs": {
- "input_shape": [96, 1],
- "output_shape": [96, 1],
- "input_type": ["float32"],
- "output_type": ["float32"],
- },
- "attributes": {
- "model_type": "huggingface_model",
- "source_dir": local_dir,
- "files": [os.path.basename(f) for f in safetensors_files],
- },
- }
-
- configs, attributes = _parse_inference_config(simple_config)
- return configs, attributes
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
index 167bfd76640d1..718ead530dd2c 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
@@ -15,140 +15,121 @@
# specific language governing permissions and limitations
# under the License.
#
-import glob
-import os
-from iotdb.ainode.core.constant import (
- MODEL_CONFIG_FILE_IN_JSON,
- MODEL_CONFIG_FILE_IN_YAML,
- MODEL_WEIGHTS_FILE_IN_PT,
- MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
-)
-from iotdb.ainode.core.model.model_enums import (
- BuiltInModelType,
- ModelCategory,
- ModelFileType,
- ModelStates,
-)
+from typing import Dict, Optional
-
-def get_model_file_type(model_path: str) -> ModelFileType:
- """
- Determine the file type of the specified model directory.
- """
- if _has_safetensors_format(model_path):
- return ModelFileType.SAFETENSORS
- elif _has_pytorch_format(model_path):
- return ModelFileType.PYTORCH
- else:
- return ModelFileType.UNKNOWN
-
-
-def _has_safetensors_format(path: str) -> bool:
- """Check if directory contains safetensors files."""
- safetensors_files = glob.glob(os.path.join(path, MODEL_WEIGHTS_FILE_IN_SAFETENSORS))
- json_files = glob.glob(os.path.join(path, MODEL_CONFIG_FILE_IN_JSON))
- return len(safetensors_files) > 0 and len(json_files) > 0
-
-
-def _has_pytorch_format(path: str) -> bool:
- """Check if directory contains pytorch files."""
- pt_files = glob.glob(os.path.join(path, MODEL_WEIGHTS_FILE_IN_PT))
- yaml_files = glob.glob(os.path.join(path, MODEL_CONFIG_FILE_IN_YAML))
- return len(pt_files) > 0 and len(yaml_files) > 0
-
-
-def get_built_in_model_type(model_type: str) -> BuiltInModelType:
- if not BuiltInModelType.is_built_in_model(model_type):
- raise ValueError(f"Invalid built-in model type: {model_type}")
- return BuiltInModelType(model_type)
+from iotdb.ainode.core.model.model_constants import ModelCategory, ModelStates
class ModelInfo:
def __init__(
self,
model_id: str,
- model_type: str,
category: ModelCategory,
state: ModelStates,
+ model_type: str = "",
+ config_cls: str = "",
+ model_cls: str = "",
+ pipeline_cls: str = "",
+ repo_id: str = "",
+ auto_map: Optional[Dict] = None,
+ _transformers_registered: bool = False,
):
self.model_id = model_id
self.model_type = model_type
self.category = category
self.state = state
+ self.config_cls = config_cls
+ self.model_cls = model_cls
+ self.pipeline_cls = pipeline_cls
+ self.repo_id = repo_id
+        self.auto_map = auto_map  # If set, indicates a Transformers model
+        self._transformers_registered = _transformers_registered  # Internal flag: registered with Transformers yet?
+ def __repr__(self):
+ return (
+ f"ModelInfo(model_id={self.model_id}, model_type={self.model_type}, "
+ f"category={self.category.value}, state={self.state.value}, "
+ f"has_auto_map={self.auto_map is not None})"
+ )
-TIMER_REPO_ID = {
- BuiltInModelType.TIMER_XL: "thuml/timer-base-84m",
- BuiltInModelType.SUNDIAL: "thuml/sundial-base-128m",
-}
-# Built-in machine learning models, they can be employed directly
-BUILT_IN_MACHINE_LEARNING_MODEL_MAP = {
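+# Built-in sktime machine learning models; these ship with AINode and are usable directly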
+BUILTIN_SKTIME_MODEL_MAP = {
# forecast models
"arima": ModelInfo(
model_id="arima",
- model_type=BuiltInModelType.ARIMA.value,
- category=ModelCategory.BUILT_IN,
+ category=ModelCategory.BUILTIN,
state=ModelStates.ACTIVE,
+ model_type="sktime",
),
"holtwinters": ModelInfo(
model_id="holtwinters",
- model_type=BuiltInModelType.HOLTWINTERS.value,
- category=ModelCategory.BUILT_IN,
+ category=ModelCategory.BUILTIN,
state=ModelStates.ACTIVE,
+ model_type="sktime",
),
"exponential_smoothing": ModelInfo(
model_id="exponential_smoothing",
- model_type=BuiltInModelType.EXPONENTIAL_SMOOTHING.value,
- category=ModelCategory.BUILT_IN,
+ category=ModelCategory.BUILTIN,
state=ModelStates.ACTIVE,
+ model_type="sktime",
),
"naive_forecaster": ModelInfo(
model_id="naive_forecaster",
- model_type=BuiltInModelType.NAIVE_FORECASTER.value,
- category=ModelCategory.BUILT_IN,
+ category=ModelCategory.BUILTIN,
state=ModelStates.ACTIVE,
+ model_type="sktime",
),
"stl_forecaster": ModelInfo(
model_id="stl_forecaster",
- model_type=BuiltInModelType.STL_FORECASTER.value,
- category=ModelCategory.BUILT_IN,
+ category=ModelCategory.BUILTIN,
state=ModelStates.ACTIVE,
+ model_type="sktime",
),
# anomaly detection models
"gaussian_hmm": ModelInfo(
model_id="gaussian_hmm",
- model_type=BuiltInModelType.GAUSSIAN_HMM.value,
- category=ModelCategory.BUILT_IN,
+ category=ModelCategory.BUILTIN,
state=ModelStates.ACTIVE,
+ model_type="sktime",
),
"gmm_hmm": ModelInfo(
model_id="gmm_hmm",
- model_type=BuiltInModelType.GMM_HMM.value,
- category=ModelCategory.BUILT_IN,
+ category=ModelCategory.BUILTIN,
state=ModelStates.ACTIVE,
+ model_type="sktime",
),
"stray": ModelInfo(
model_id="stray",
- model_type=BuiltInModelType.STRAY.value,
- category=ModelCategory.BUILT_IN,
+ category=ModelCategory.BUILTIN,
state=ModelStates.ACTIVE,
+ model_type="sktime",
),
}
-# Built-in large time series models (LTSM), their weights are not included in AINode by default
-BUILT_IN_LTSM_MAP = {
+# Built-in HuggingFace Transformers models; their weights are not bundled with AINode by default
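+# config_cls/model_cls/pipeline_cls are dotted "module_file.ClassName" paths that
+# model_loader resolves inside the model's builtin package.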
+BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
"timer_xl": ModelInfo(
model_id="timer_xl",
- model_type=BuiltInModelType.TIMER_XL.value,
- category=ModelCategory.BUILT_IN,
- state=ModelStates.LOADING,
+ category=ModelCategory.BUILTIN,
+ state=ModelStates.INACTIVE,
+ model_type="timer",
+ config_cls="configuration_timer.TimerConfig",
+ model_cls="modeling_timer.TimerForPrediction",
+ pipeline_cls="pipeline_timer.TimerPipeline",
+ repo_id="thuml/timer-base-84m",
),
"sundial": ModelInfo(
model_id="sundial",
- model_type=BuiltInModelType.SUNDIAL.value,
- category=ModelCategory.BUILT_IN,
- state=ModelStates.LOADING,
+ category=ModelCategory.BUILTIN,
+ state=ModelStates.INACTIVE,
+ model_type="sundial",
+ config_cls="configuration_sundial.SundialConfig",
+ model_cls="modeling_sundial.SundialForPrediction",
+ pipeline_cls="pipeline_sundial.SundialPipeline",
+ repo_id="thuml/sundial-base-128m",
),
}
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
new file mode 100644
index 0000000000000..a6e3b1f7b5e38
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py
@@ -0,0 +1,163 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+from pathlib import Path
+from typing import Any
+
+import torch
+from transformers import (
+ AutoConfig,
+ AutoModelForCausalLM,
+ AutoModelForNextSentencePrediction,
+ AutoModelForSeq2SeqLM,
+ AutoModelForSequenceClassification,
+ AutoModelForTimeSeriesPrediction,
+ AutoModelForTokenClassification,
+)
+
+from iotdb.ainode.core.config import AINodeDescriptor
+from iotdb.ainode.core.exception import ModelNotExistError
+from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.model.model_constants import ModelCategory
+from iotdb.ainode.core.model.model_info import ModelInfo
+from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model
+from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path
+
+logger = Logger()
+
+
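+# Dispatch summary (a sketch of the behavior below, not a spec): models whose
+# config.json carries an auto_map load through the Transformers loaders; the
+# "sktime" model_type builds the built-in sktime wrapper; everything else falls
+# back to a TorchScript checkpoint (model.pt).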
+def load_model(model_info: ModelInfo, **model_kwargs) -> Any:
+ if model_info.auto_map is not None:
+ model = load_model_from_transformers(model_info, **model_kwargs)
+ else:
+ if model_info.model_type == "sktime":
+ model = create_sktime_model(model_info.model_id)
+ else:
+ model = load_model_from_pt(model_info, **model_kwargs)
+
+ logger.info(
+ f"Model {model_info.model_id} loaded to device {model.device if model_info.model_type != 'sktime' else 'cpu'} successfully."
+ )
+ return model
+
+
+def load_model_from_transformers(model_info: ModelInfo, **model_kwargs):
+ device_map = model_kwargs.get("device_map", "cpu")
+ trust_remote_code = model_kwargs.get("trust_remote_code", True)
+ train_from_scratch = model_kwargs.get("train_from_scratch", False)
+
+ model_path = os.path.join(
+ os.getcwd(),
+ AINodeDescriptor().get_config().get_ain_models_dir(),
+ model_info.category.value,
+ model_info.model_id,
+ )
+
+ if model_info.category == ModelCategory.BUILTIN:
+ module_name = (
+ AINodeDescriptor().get_config().get_ain_models_builtin_dir()
+ + "."
+ + model_info.model_id
+ )
+ config_cls = import_class_from_path(module_name, model_info.config_cls)
+ model_cls = import_class_from_path(module_name, model_info.model_cls)
+ elif model_info.model_cls and model_info.config_cls:
+ module_parent = str(Path(model_path).parent.absolute())
+ with temporary_sys_path(module_parent):
+ config_cls = import_class_from_path(
+ model_info.model_id, model_info.config_cls
+ )
+ model_cls = import_class_from_path(
+ model_info.model_id, model_info.model_cls
+ )
+ else:
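+        # No explicit config/model classes configured: probe each Auto* registry with the
+        # pretrained config's type and fall back to AutoModelForCausalLM.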
+ config_cls = AutoConfig.from_pretrained(model_path)
+ if type(config_cls) in AutoModelForTimeSeriesPrediction._model_mapping.keys():
+ model_cls = AutoModelForTimeSeriesPrediction
+ elif (
+ type(config_cls) in AutoModelForNextSentencePrediction._model_mapping.keys()
+ ):
+ model_cls = AutoModelForNextSentencePrediction
+ elif type(config_cls) in AutoModelForSeq2SeqLM._model_mapping.keys():
+ model_cls = AutoModelForSeq2SeqLM
+ elif (
+ type(config_cls) in AutoModelForSequenceClassification._model_mapping.keys()
+ ):
+ model_cls = AutoModelForSequenceClassification
+ elif type(config_cls) in AutoModelForTokenClassification._model_mapping.keys():
+ model_cls = AutoModelForTokenClassification
+ else:
+ model_cls = AutoModelForCausalLM
+
+ if train_from_scratch:
+ model = model_cls.from_config(
+ config_cls, trust_remote_code=trust_remote_code, device_map=device_map
+ )
+ else:
+ model = model_cls.from_pretrained(
+ model_path,
+ trust_remote_code=trust_remote_code,
+ device_map=device_map,
+ )
+
+ return model
+
+
+def load_model_from_pt(model_info: ModelInfo, **kwargs):
+ device_map = kwargs.get("device_map", "cpu")
+ acceleration = kwargs.get("acceleration", False)
+ model_path = os.path.join(
+ os.getcwd(),
+ AINodeDescriptor().get_config().get_ain_models_dir(),
+ model_info.category.value,
+ model_info.model_id,
+ )
+ model_file = os.path.join(model_path, "model.pt")
+ if not os.path.exists(model_file):
+ logger.error(f"Model file not found at {model_file}.")
+ raise ModelNotExistError(model_file)
+ model = torch.jit.load(model_file)
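+    # Already-compiled modules skip torch.compile, as does acceleration=False.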
+ if isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not acceleration:
+ return model
+ try:
+ model = torch.compile(model)
+ except Exception as e:
+ logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}")
+ return model.to(device_map)
+
+
+def load_model_for_efficient_inference():
+ # TODO: An efficient model loading method for inference based on model_arguments
+ pass
+
+
+def load_model_for_powerful_finetune():
+    # TODO: A powerful model loading method for fine-tuning based on model_arguments
+ pass
+
+
+def unload_model():
+ pass
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
index e346f569102e3..5194ed4df1bd7 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
@@ -20,43 +20,37 @@
import json
import os
import shutil
-from collections.abc import Callable
-from typing import Dict
+from pathlib import Path
+from typing import Dict, List, Optional
-import torch
-from torch import nn
+from huggingface_hub import hf_hub_download, snapshot_download
+from transformers import AutoConfig, AutoModelForCausalLM
from iotdb.ainode.core.config import AINodeDescriptor
-from iotdb.ainode.core.constant import (
- MODEL_CONFIG_FILE_IN_JSON,
- MODEL_WEIGHTS_FILE_IN_PT,
- TSStatusCode,
-)
-from iotdb.ainode.core.exception import (
- BuiltInModelDeletionError,
- ModelNotExistError,
- UnsupportedError,
-)
+from iotdb.ainode.core.constant import TSStatusCode
+from iotdb.ainode.core.exception import BuiltInModelDeletionError
from iotdb.ainode.core.log import Logger
-from iotdb.ainode.core.model.built_in_model_factory import (
- download_built_in_ltsm_from_hf_if_necessary,
- fetch_built_in_model,
-)
-from iotdb.ainode.core.model.model_enums import (
- BuiltInModelType,
+from iotdb.ainode.core.model.model_constants import (
+ MODEL_CONFIG_FILE_IN_JSON,
+ MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
ModelCategory,
- ModelFileType,
ModelStates,
+ UriType,
)
-from iotdb.ainode.core.model.model_factory import fetch_model_by_uri
from iotdb.ainode.core.model.model_info import (
- BUILT_IN_LTSM_MAP,
- BUILT_IN_MACHINE_LEARNING_MODEL_MAP,
+ BUILTIN_HF_TRANSFORMERS_MODEL_MAP,
+ BUILTIN_SKTIME_MODEL_MAP,
ModelInfo,
- get_built_in_model_type,
- get_model_file_type,
)
-from iotdb.ainode.core.model.uri_utils import get_model_register_strategy
+from iotdb.ainode.core.model.utils import (
+ ensure_init_file,
+ get_parsed_uri,
+ import_class_from_path,
+ load_model_config_in_json,
+ parse_uri_type,
+ temporary_sys_path,
+ validate_model_files,
+)
from iotdb.ainode.core.util.lock import ModelLockPool
from iotdb.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp
from iotdb.thrift.common.ttypes import TSStatus
@@ -64,320 +58,373 @@
logger = Logger()
-class ModelStorage(object):
+class ModelStorage:
+ """Model storage class - unified management of model discovery and registration"""
+
def __init__(self):
- self._model_dir = os.path.join(
+ self._models_dir = os.path.join(
os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir()
)
- if not os.path.exists(self._model_dir):
- try:
- os.makedirs(self._model_dir)
- except PermissionError as e:
- logger.error(e)
- raise e
- self._builtin_model_dir = os.path.join(
- os.getcwd(), AINodeDescriptor().get_config().get_ain_builtin_models_dir()
- )
- if not os.path.exists(self._builtin_model_dir):
- try:
- os.makedirs(self._builtin_model_dir)
- except PermissionError as e:
- logger.error(e)
- raise e
+ # Unified storage: category -> {model_id -> ModelInfo}
+ self._models: Dict[str, Dict[str, ModelInfo]] = {
+ ModelCategory.BUILTIN.value: {},
+ ModelCategory.USER_DEFINED.value: {},
+ }
+ # Async download executor
+ self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=2)
+ # Thread lock pool for protecting concurrent access to model information
self._lock_pool = ModelLockPool()
- self._executor = concurrent.futures.ThreadPoolExecutor(
- max_workers=1
- ) # TODO: Here we set the work_num=1 cause we found that the hf download interface is not stable for concurrent downloading.
- self._model_info_map: Dict[str, ModelInfo] = {}
- self._init_model_info_map()
+ self._initialize_directories()
+ self.discover_all_models()
+
+ def _initialize_directories(self):
+ """Initialize directory structure and ensure __init__.py files exist"""
+ os.makedirs(self._models_dir, exist_ok=True)
+ ensure_init_file(self._models_dir)
+ for category in ModelCategory:
+ category_path = os.path.join(self._models_dir, category.value)
+ os.makedirs(category_path, exist_ok=True)
+ ensure_init_file(category_path)
+
+ # ==================== Discovery Methods ====================
+
+ def discover_all_models(self):
+ """Scan file system to discover all models"""
+ self._discover_category(ModelCategory.BUILTIN)
+ self._discover_category(ModelCategory.USER_DEFINED)
+
+ def _discover_category(self, category: ModelCategory):
+ """Discover all models in a category directory"""
+ category_path = os.path.join(self._models_dir, category.value)
+ if category == ModelCategory.BUILTIN:
+ self._discover_builtin_models(category_path)
+ elif category == ModelCategory.USER_DEFINED:
+ for model_id in os.listdir(category_path):
+ if os.path.isdir(os.path.join(category_path, model_id)):
+ self._process_user_defined_model_directory(
+ os.path.join(category_path, model_id), model_id
+ )
- def _init_model_info_map(self):
- """
- Initialize the model info map.
- """
- # 1. initialize built-in and ready-to-use models
- for model_id in BUILT_IN_MACHINE_LEARNING_MODEL_MAP:
- self._model_info_map[model_id] = BUILT_IN_MACHINE_LEARNING_MODEL_MAP[
- model_id
- ]
- # 2. retrieve fine-tuned models from the built-in model directory
- fine_tuned_models = self._retrieve_fine_tuned_models()
- for model_id in fine_tuned_models:
- self._model_info_map[model_id] = fine_tuned_models[model_id]
- # 3. automatically downloading the weights of built-in LSTM models when necessary
- for model_id in BUILT_IN_LTSM_MAP:
- if model_id not in self._model_info_map:
- self._model_info_map[model_id] = BUILT_IN_LTSM_MAP[model_id]
- future = self._executor.submit(
- self._download_built_in_model_if_necessary, model_id
- )
- future.add_done_callback(
- lambda f, mid=model_id: self._callback_model_download_result(f, mid)
- )
- # 4. retrieve user-defined models from the model directory
- user_defined_models = self._retrieve_user_defined_models()
- for model_id in user_defined_models:
- self._model_info_map[model_id] = user_defined_models[model_id]
+ def _discover_builtin_models(self, category_path: str):
+        # Register sktime models directly from the map
+ for model_id in BUILTIN_SKTIME_MODEL_MAP.keys():
+ with self._lock_pool.get_lock(model_id).write_lock():
+ self._models[ModelCategory.BUILTIN.value][model_id] = (
+ BUILTIN_SKTIME_MODEL_MAP[model_id]
+ )
- def _retrieve_fine_tuned_models(self):
- """
- Retrieve fine-tuned models from the built-in model directory.
+ # Process HuggingFace Transformers models
+ for model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys():
+ model_dir = os.path.join(category_path, model_id)
+ os.makedirs(model_dir, exist_ok=True)
+ self._process_builtin_model_directory(model_dir, model_id)
- Returns:
- {"model_id": ModelInfo}
- """
- result = {}
- build_in_dirs = [
- d
- for d in os.listdir(self._builtin_model_dir)
- if os.path.isdir(os.path.join(self._builtin_model_dir, d))
- ]
- for model_id in build_in_dirs:
- config_file_path = os.path.join(
- self._builtin_model_dir, model_id, MODEL_CONFIG_FILE_IN_JSON
+ def _process_builtin_model_directory(self, model_dir: str, model_id: str):
+ """Handling the discovery logic for a builtin model directory."""
+ ensure_init_file(model_dir)
+ with self._lock_pool.get_lock(model_id).write_lock():
+ # Check if model already exists and is in a valid state
+ existing_model = self._models[ModelCategory.BUILTIN.value].get(model_id)
+ if existing_model:
+ # If model is already ACTIVATING or ACTIVE, skip duplicate download
+ if existing_model.state in (ModelStates.ACTIVATING, ModelStates.ACTIVE):
+ return
+
+            # If the model does not exist yet or is INACTIVE, refresh its info and download its weights
+ self._models[ModelCategory.BUILTIN.value][model_id] = (
+ BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id]
)
- if os.path.isfile(config_file_path):
- with open(config_file_path, "r") as f:
- model_config = json.load(f)
- if "model_type" in model_config:
- model_type = model_config["model_type"]
- model_info = ModelInfo(
- model_id=model_id,
- model_type=model_type,
- category=ModelCategory.FINE_TUNED,
- state=ModelStates.ACTIVE,
+ self._models[ModelCategory.BUILTIN.value][
+ model_id
+ ].state = ModelStates.ACTIVATING
+
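+        # Only the two required files (model.safetensors and config.json) are fetched
+        # individually; the rest of the HF repo is not mirrored locally.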
+ def _download_model_if_necessary() -> bool:
+ """Returns: True if the model is existed or downloaded successfully, False otherwise."""
+ repo_id = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id].repo_id
+ weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)
+ config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON)
+ if not os.path.exists(weights_path):
+ try:
+ hf_hub_download(
+ repo_id=repo_id,
+ filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
+ local_dir=model_dir,
)
- # Refactor the built-in model category
- if "timer_xl" == model_id:
- model_info.category = ModelCategory.BUILT_IN
- if "sundial" == model_id:
- model_info.category = ModelCategory.BUILT_IN
- # Compatible patch with the codes in HuggingFace
- if "timer" == model_type:
- model_info.model_type = BuiltInModelType.TIMER_XL.value
- if "sundial" == model_type:
- model_info.model_type = BuiltInModelType.SUNDIAL.value
- result[model_id] = model_info
- return result
-
- def _download_built_in_model_if_necessary(self, model_id: str) -> bool:
- """
- Download the built-in model if it is not already downloaded.
-
- Args:
- model_id (str): The ID of the model to download.
+ except Exception as e:
+ logger.error(
+ f"Failed to download model weights from HuggingFace: {e}"
+ )
+ return False
+ if not os.path.exists(config_path):
+ try:
+ hf_hub_download(
+ repo_id=repo_id,
+ filename=MODEL_CONFIG_FILE_IN_JSON,
+ local_dir=model_dir,
+ )
+ except Exception as e:
+ logger.error(
+ f"Failed to download model config from HuggingFace: {e}"
+ )
+ return False
+ return True
- Return:
- bool: True if the model is existed or downloaded successfully, False otherwise.
- """
- with self._lock_pool.get_lock(model_id).write_lock():
- local_dir = os.path.join(self._builtin_model_dir, model_id)
- return download_built_in_ltsm_from_hf_if_necessary(
- get_built_in_model_type(self._model_info_map[model_id].model_type),
- local_dir,
- )
+ future = self._executor.submit(_download_model_if_necessary)
+ future.add_done_callback(
+ lambda f, mid=model_id: self._callback_model_download_result(f, mid)
+ )
def _callback_model_download_result(self, future, model_id: str):
+ """Callback function for handling model download results"""
with self._lock_pool.get_lock(model_id).write_lock():
- if future.result():
- self._model_info_map[model_id].state = ModelStates.ACTIVE
- logger.info(
- f"The built-in model: {model_id} is active and ready to use."
- )
- else:
- self._model_info_map[model_id].state = ModelStates.INACTIVE
-
- def _retrieve_user_defined_models(self):
- """
- Retrieve user_defined models from the model directory.
+ try:
+ if future.result():
+ model_info = self._models[ModelCategory.BUILTIN.value][model_id]
+ model_info.state = ModelStates.ACTIVE
+ config_path = os.path.join(
+ self._models_dir,
+ ModelCategory.BUILTIN.value,
+ model_id,
+ MODEL_CONFIG_FILE_IN_JSON,
+ )
+ if os.path.exists(config_path):
+ with open(config_path, "r", encoding="utf-8") as f:
+ config = json.load(f)
+ if model_info.model_type == "":
+ model_info.model_type = config.get("model_type", "")
+ model_info.auto_map = config.get("auto_map", None)
+ logger.info(
+ f"Model {model_id} downloaded successfully and is ready to use."
+ )
+ else:
+ self._models[ModelCategory.BUILTIN.value][
+ model_id
+ ].state = ModelStates.INACTIVE
+ logger.warning(f"Failed to download model {model_id}.")
+ except Exception as e:
+ logger.error(f"Error in download callback for model {model_id}: {e}")
+ self._models[ModelCategory.BUILTIN.value][
+ model_id
+ ].state = ModelStates.INACTIVE
+
+ def _process_user_defined_model_directory(self, model_dir: str, model_id: str):
+ """Handling the discovery logic for a user-defined model directory."""
+ config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON)
+ model_type = ""
+ auto_map = {}
+ pipeline_cls = ""
+ if os.path.exists(config_path):
+ config = load_model_config_in_json(config_path)
+ model_type = config.get("model_type", "")
+ auto_map = config.get("auto_map", None)
+ pipeline_cls = config.get("pipeline_cls", "")
- Returns:
- {"model_id": ModelInfo}
- """
- result = {}
- user_dirs = [
- d
- for d in os.listdir(self._model_dir)
- if os.path.isdir(os.path.join(self._model_dir, d)) and d != "weights"
- ]
- for model_id in user_dirs:
- result[model_id] = ModelInfo(
+ with self._lock_pool.get_lock(model_id).write_lock():
+ model_info = ModelInfo(
model_id=model_id,
- model_type="",
+ model_type=model_type,
category=ModelCategory.USER_DEFINED,
state=ModelStates.ACTIVE,
+ pipeline_cls=pipeline_cls,
+ auto_map=auto_map,
+ _transformers_registered=False, # Lazy registration
)
- return result
+ self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info
+
+ # ==================== Registration Methods ====================
- def register_model(self, model_id: str, uri: str):
+ def register_model(self, model_id: str, uri: str) -> bool:
"""
- Args:
- model_id: id of model to register
- uri: network or local dir path of the model to register
- Returns:
- configs: TConfigs
- attributes: str
+ Supported URI formats:
+        - repo:// (may be supported in the future)
+ - file://
"""
+ uri_type = parse_uri_type(uri)
+ parsed_uri = get_parsed_uri(uri)
+
+ model_dir = os.path.join(
+ self._models_dir, ModelCategory.USER_DEFINED.value, model_id
+ )
+ os.makedirs(model_dir, exist_ok=True)
+ ensure_init_file(model_dir)
+
+ if uri_type == UriType.REPO:
+ self._fetch_model_from_hf_repo(parsed_uri, model_dir)
+ else:
+ self._fetch_model_from_local(os.path.expanduser(parsed_uri), model_dir)
+
+ config_path, _ = validate_model_files(model_dir)
+ config = load_model_config_in_json(config_path)
+ model_type = config.get("model_type", "")
+ auto_map = config.get("auto_map")
+ pipeline_cls = config.get("pipeline_cls", "")
+
with self._lock_pool.get_lock(model_id).write_lock():
- storage_path = os.path.join(self._model_dir, f"{model_id}")
- # create storage dir if not exist
- if not os.path.exists(storage_path):
- os.makedirs(storage_path)
- uri_type, parsed_uri, model_file_type = get_model_register_strategy(uri)
- self._model_info_map[model_id] = ModelInfo(
+ model_info = ModelInfo(
model_id=model_id,
- model_type="",
+ model_type=model_type,
category=ModelCategory.USER_DEFINED,
- state=ModelStates.LOADING,
+ state=ModelStates.ACTIVE,
+ pipeline_cls=pipeline_cls,
+ auto_map=auto_map,
+ _transformers_registered=False, # Register later
+ )
+ self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info
+
+ if auto_map:
+            # Transformers model: register immediately with the Transformers auto-loading mechanism
+ success = self._register_transformers_model(model_info)
+ if success:
+ with self._lock_pool.get_lock(model_id).write_lock():
+ model_info._transformers_registered = True
+ else:
+ with self._lock_pool.get_lock(model_id).write_lock():
+ model_info.state = ModelStates.INACTIVE
+ logger.error(f"Failed to register Transformers model {model_id}")
+ return False
+ else:
+            # Other model types: nothing to register, just log it
+ self._register_other_model(model_info)
+
+ logger.info(f"Successfully registered model {model_id} from URI: {uri}")
+ return True
+
+ def _fetch_model_from_hf_repo(self, repo_id: str, storage_path: str):
+ logger.info(
+ f"Downloading model from HuggingFace repository: {repo_id} -> {storage_path}"
+ )
+        # Use snapshot_download to fetch the entire repository (including config.json and model.safetensors)
+ try:
+ snapshot_download(
+ repo_id=repo_id,
+ local_dir=storage_path,
+ local_dir_use_symlinks=False,
+ )
+ except Exception as e:
+ logger.error(f"Failed to download model from HuggingFace: {e}")
+ raise
+
+ def _fetch_model_from_local(self, source_path: str, storage_path: str):
+ logger.info(f"Copying model from local path: {source_path} -> {storage_path}")
+ source_dir = Path(source_path)
+ if not source_dir.is_dir():
+ raise ValueError(
+ f"Source path does not exist or is not a directory: {source_path}"
)
- try:
- # TODO: The uri should be fetched asynchronously
- configs, attributes = fetch_model_by_uri(
- uri_type, parsed_uri, storage_path, model_file_type
- )
- self._model_info_map[model_id].state = ModelStates.ACTIVE
- return configs, attributes
- except Exception as e:
- logger.error(f"Failed to register model {model_id}: {e}")
- self._model_info_map[model_id].state = ModelStates.INACTIVE
- raise e
- def delete_model(self, model_id: str) -> None:
- """
- Args:
- model_id: id of model to delete
- Returns:
- None
- """
- # check if the model is built-in
- with self._lock_pool.get_lock(model_id).read_lock():
- if self._is_built_in(model_id):
- raise BuiltInModelDeletionError(model_id)
+ storage_dir = Path(storage_path)
+ for file in source_dir.iterdir():
+ if file.is_file():
+ shutil.copy2(file, storage_dir / file.name)
+ return
- # delete the user-defined or fine-tuned model
- with self._lock_pool.get_lock(model_id).write_lock():
- storage_path = os.path.join(self._model_dir, f"{model_id}")
- if os.path.exists(storage_path):
- shutil.rmtree(storage_path)
- storage_path = os.path.join(self._builtin_model_dir, f"{model_id}")
- if os.path.exists(storage_path):
- shutil.rmtree(storage_path)
- if model_id in self._model_info_map:
- del self._model_info_map[model_id]
- logger.info(f"Model {model_id} deleted successfully.")
-
- def _is_built_in(self, model_id: str) -> bool:
+ def _register_transformers_model(self, model_info: ModelInfo) -> bool:
"""
- Check if the model_id corresponds to a built-in model.
+        Register a Transformers model with the auto-loading mechanism (internal method)
+ """
+ auto_map = model_info.auto_map
+ if not auto_map:
+ return False
- Args:
- model_id (str): The ID of the model.
+ auto_config_path = auto_map.get("AutoConfig")
+ auto_model_path = auto_map.get("AutoModelForCausalLM")
- Returns:
- bool: True if the model is built-in, False otherwise.
- """
- return (
- model_id in self._model_info_map
- and self._model_info_map[model_id].category == ModelCategory.BUILT_IN
- )
+ try:
+ model_path = os.path.join(
+ self._models_dir, model_info.category.value, model_info.model_id
+ )
+ module_parent = str(Path(model_path).parent.absolute())
+ with temporary_sys_path(module_parent):
+ config_class = import_class_from_path(
+ model_info.model_id, auto_config_path
+ )
+ AutoConfig.register(model_info.model_type, config_class)
+ logger.info(
+ f"Registered AutoConfig: {model_info.model_type} -> {auto_config_path}"
+ )
- def is_built_in_or_fine_tuned(self, model_id: str) -> bool:
- """
- Check if the model_id corresponds to a built-in or fine-tuned model.
+ model_class = import_class_from_path(
+ model_info.model_id, auto_model_path
+ )
+ AutoModelForCausalLM.register(config_class, model_class)
+ logger.info(
+ f"Registered AutoModelForCausalLM: {config_class.__name__} -> {auto_model_path}"
+ )
- Args:
- model_id (str): The ID of the model.
+ return True
+ except Exception as e:
+ logger.warning(
+ f"Failed to register Transformers model {model_info.model_id}: {e}. Model may still work via auto_map, but ensure module path is correct."
+ )
+ return False
- Returns:
- bool: True if the model is built-in or fine_tuned, False otherwise.
- """
- return model_id in self._model_info_map and (
- self._model_info_map[model_id].category == ModelCategory.BUILT_IN
- or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED
+ def _register_other_model(self, model_info: ModelInfo):
+ """Register other type models (non-Transformers models)"""
+ logger.info(
+ f"Registered other type model: {model_info.model_id} ({model_info.model_type})"
)
- def load_model(
- self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool
- ) -> Callable:
+ def ensure_transformers_registered(self, model_id: str) -> ModelInfo:
"""
- Load a model with automatic detection of .safetensors or .pt format
-
+        Ensure the Transformers model is registered (used for lazy registration).
+        Thread-safe: all checks and the registration itself run under the model's write lock.
Returns:
- model: The model instance corresponding to specific model_id
- """
- with self._lock_pool.get_lock(model_id).read_lock():
- if self.is_built_in_or_fine_tuned(model_id):
- model_dir = os.path.join(self._builtin_model_dir, f"{model_id}")
- return fetch_built_in_model(
- get_built_in_model_type(self._model_info_map[model_id].model_type),
- model_dir,
- inference_attrs,
- )
- else:
- # load the user-defined model
- model_dir = os.path.join(self._model_dir, f"{model_id}")
- model_file_type = get_model_file_type(model_dir)
- if model_file_type == ModelFileType.SAFETENSORS:
- # TODO: Support this function
- raise UnsupportedError("SAFETENSORS format")
- else:
- model_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_PT)
-
- if not os.path.exists(model_path):
- raise ModelNotExistError(model_path)
- model = torch.jit.load(model_path)
- if (
- isinstance(model, torch._dynamo.eval_frame.OptimizedModule)
- or not acceleration
- ):
- return model
-
- try:
- model = torch.compile(model)
- except Exception as e:
- logger.warning(
- f"acceleration failed, fallback to normal mode: {str(e)}"
- )
- return model
-
- def save_model(self, model_id: str, model: nn.Module):
- """
- Save the model using save_pretrained
-
- Returns:
- Whether saving succeeded
+            ModelInfo: The registered ModelInfo, or None if registration failed
"""
+        # Use a write lock to protect the entire check-and-register process
with self._lock_pool.get_lock(model_id).write_lock():
- if self.is_built_in_or_fine_tuned(model_id):
- model_dir = os.path.join(self._builtin_model_dir, f"{model_id}")
- model.save_pretrained(model_dir)
- else:
- # save the user-defined model
- model_dir = os.path.join(self._model_dir, f"{model_id}")
- os.makedirs(model_dir, exist_ok=True)
- model_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_PT)
- try:
- scripted_model = (
- model
- if isinstance(model, torch.jit.ScriptModule)
- else torch.jit.script(model)
+ # Directly access _models dictionary (avoid calling get_model_info which may cause deadlock)
+ model_info = None
+ for category_dict in self._models.values():
+ if model_id in category_dict:
+ model_info = category_dict[model_id]
+ break
+
+ if not model_info:
+ logger.warning(f"Model {model_id} does not exist, cannot register")
+ return None
+
+ # If already registered, return directly
+ if model_info._transformers_registered:
+ return model_info
+
+            # If there is no auto_map, or the model is a built-in HF Transformers
+            # model, no dynamic registration is needed; mark it as registered to
+            # avoid re-checking
+            if (
+                not model_info.auto_map
+                or model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP
+            ):
+                model_info._transformers_registered = True
+                return model_info
+
+ # Execute registration (under lock protection)
+ try:
+ success = self._register_transformers_model(model_info)
+ if success:
+ model_info._transformers_registered = True
+ logger.info(
+ f"Model {model_id} successfully registered to Transformers"
)
- torch.jit.save(scripted_model, model_path)
- except Exception as e:
- logger.error(f"Failed to save scripted model: {e}")
-
- def get_ckpt_path(self, model_id: str) -> str:
- """
- Get the checkpoint path for a given model ID.
+ return model_info
+ else:
+ model_info.state = ModelStates.INACTIVE
+ logger.error(f"Model {model_id} failed to register to Transformers")
+ return None
- Args:
- model_id (str): The ID of the model.
+ except Exception as e:
+ # Ensure state consistency in exception cases
+ model_info.state = ModelStates.INACTIVE
+ model_info._transformers_registered = False
+ logger.error(
+ f"Exception occurred while registering model {model_id} to Transformers: {e}"
+ )
+ return None
- Returns:
- str: The path to the checkpoint file for the model.
- """
- # Only support built-in models for now
- return os.path.join(self._builtin_model_dir, f"{model_id}")
+ # ==================== Show and Delete Models ====================
def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
resp_status = TSStatus(
@@ -385,8 +427,14 @@ def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
message="Show models successfully",
)
if req.modelId:
- if req.modelId in self._model_info_map:
- model_info = self._model_info_map[req.modelId]
+ # Find specified model
+ model_info = None
+ for category_dict in self._models.values():
+ if req.modelId in category_dict:
+ model_info = category_dict[req.modelId]
+ break
+
+ if model_info:
return TShowModelsResp(
status=resp_status,
modelIdList=[req.modelId],
@@ -402,55 +450,133 @@ def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
categoryMap={},
stateMap={},
)
+ # Return all models
+ model_id_list = []
+ model_type_map = {}
+ category_map = {}
+ state_map = {}
+
+ for category_dict in self._models.values():
+ for model_id, model_info in category_dict.items():
+ model_id_list.append(model_id)
+ model_type_map[model_id] = model_info.model_type
+ category_map[model_id] = model_info.category.value
+ state_map[model_id] = model_info.state.value
+
return TShowModelsResp(
status=resp_status,
- modelIdList=list(self._model_info_map.keys()),
- modelTypeMap=dict(
- (model_id, model_info.model_type)
- for model_id, model_info in self._model_info_map.items()
- ),
- categoryMap=dict(
- (model_id, model_info.category.value)
- for model_id, model_info in self._model_info_map.items()
- ),
- stateMap=dict(
- (model_id, model_info.state.value)
- for model_id, model_info in self._model_info_map.items()
- ),
+ modelIdList=model_id_list,
+ modelTypeMap=model_type_map,
+ categoryMap=category_map,
+ stateMap=state_map,
)
- def register_built_in_model(self, model_info: ModelInfo):
- with self._lock_pool.get_lock(model_info.model_id).write_lock():
- self._model_info_map[model_info.model_id] = model_info
+ def delete_model(self, model_id: str) -> None:
+ # Use write lock to protect entire deletion process
+ with self._lock_pool.get_lock(model_id).write_lock():
+ model_info = None
+ category_value = None
+ for cat_value, category_dict in self._models.items():
+ if model_id in category_dict:
+ model_info = category_dict[model_id]
+ category_value = cat_value
+ break
+
+ if not model_info:
+ logger.warning(f"Model {model_id} does not exist, cannot delete")
+ return
+
+ if model_info.category == ModelCategory.BUILTIN:
+ raise BuiltInModelDeletionError(model_id)
+ model_info.state = ModelStates.DROPPING
+ model_path = os.path.join(
+ self._models_dir, model_info.category.value, model_id
+ )
+            if os.path.exists(model_path):
+ try:
+ shutil.rmtree(model_path)
+ logger.info(f"Deleted model directory: {model_path}")
+ except Exception as e:
+ logger.error(f"Failed to delete model directory {model_path}: {e}")
+ raise
- def get_model_info(self, model_id: str) -> ModelInfo:
- with self._lock_pool.get_lock(model_id).read_lock():
- if model_id in self._model_info_map:
- return self._model_info_map[model_id]
- else:
- raise ValueError(f"Model {model_id} does not exist.")
+ if category_value and model_id in self._models[category_value]:
+ del self._models[category_value][model_id]
+ logger.info(f"Model {model_id} has been removed from storage")
- def update_model_state(self, model_id: str, state: ModelStates):
- with self._lock_pool.get_lock(model_id).write_lock():
- if model_id in self._model_info_map:
- self._model_info_map[model_id].state = state
- else:
- raise ValueError(f"Model {model_id} does not exist.")
- def get_built_in_model_type(self, model_id: str) -> BuiltInModelType:
+ # ==================== Query Methods ====================
+
+ def get_model_info(
+ self, model_id: str, category: Optional[ModelCategory] = None
+ ) -> Optional[ModelInfo]:
"""
- Get the type of the model with the given model_id.
+        Get information for a single model.
- Args:
- model_id (str): The ID of the model.
+        If category is specified, only that category's dictionary is accessed, so the model_id lock suffices.
+        If category is not specified, all dictionaries must be traversed, so the global lock is used.
+ """
+ if category:
+ # Category specified, only need to access specific dictionary, use model_id's lock
+ with self._lock_pool.get_lock(model_id).read_lock():
+ return self._models[category.value].get(model_id)
+ else:
+ # Category not specified, need to traverse all dictionaries, use global lock
+ with self._lock_pool.get_lock("").read_lock():
+ for category_dict in self._models.values():
+ if model_id in category_dict:
+ return category_dict[model_id]
+ return None
+
+ def get_model_infos(
+ self, category: Optional[ModelCategory] = None, model_type: Optional[str] = None
+ ) -> List[ModelInfo]:
+ """
+        Get a list of matching model information entries.
- Returns:
- str: The type of the model.
+        Note: since this traverses all models, a global lock protects the entire dictionary structure.
+        For single-model access, the model_id-based lock used by get_model_info is more efficient.
"""
- with self._lock_pool.get_lock(model_id).read_lock():
- if model_id in self._model_info_map:
- return get_built_in_model_type(
- self._model_info_map[model_id].model_type
- )
+ matching_models = []
+
+ # For traversal operations, we need to protect the entire dictionary structure
+ # Use a special lock (using empty string as key) to protect the entire dictionary
+ with self._lock_pool.get_lock("").read_lock():
+ if category and model_type:
+ for model_info in self._models[category.value].values():
+ if model_info.model_type == model_type:
+ matching_models.append(model_info)
+ return matching_models
+ elif category:
+ return list(self._models[category.value].values())
+ elif model_type:
+ for category_dict in self._models.values():
+ for model_info in category_dict.values():
+ if model_info.model_type == model_type:
+ matching_models.append(model_info)
+ return matching_models
else:
- raise ValueError(f"Model {model_id} does not exist.")
+ for category_dict in self._models.values():
+ matching_models.extend(category_dict.values())
+ return matching_models
+
+ def is_model_registered(self, model_id: str) -> bool:
+ """Check if model is registered (search in _models)"""
+ # Lazy registration: if it's a Transformers model and not registered, register it first
+ if self.ensure_transformers_registered(model_id) is None:
+ return False
+
+ with self._lock_pool.get_lock("").read_lock():
+ for category_dict in self._models.values():
+ if model_id in category_dict:
+ return True
+ return False
+
+ def get_registered_models(self) -> List[str]:
+ """Get list of all registered model IDs"""
+ with self._lock_pool.get_lock("").read_lock():
+ model_ids = []
+ for category_dict in self._models.values():
+ model_ids.extend(category_dict.keys())
+ return model_ids
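
Reviewer note: for anyone unfamiliar with the dynamic-registration path in
_register_transformers_model above, here is a minimal standalone sketch of the
mechanism it relies on. All class and type names below are hypothetical; the
real ones come from the model's auto_map and model_type fields.

    from transformers import (
        AutoConfig,
        AutoModelForCausalLM,
        PretrainedConfig,
        PreTrainedModel,
    )

    # Hypothetical custom model classes, normally imported from the model
    # directory via temporary_sys_path + import_class_from_path.
    class MyTSConfig(PretrainedConfig):
        model_type = "my_ts_model"  # must match model_info.model_type

    class MyTSModel(PreTrainedModel):
        config_class = MyTSConfig

    # The two registration calls the manager performs; afterwards the generic
    # Auto* entry points can resolve the custom type by name.
    AutoConfig.register("my_ts_model", MyTSConfig)
    AutoModelForCausalLM.register(MyTSConfig, MyTSModel)

    config = AutoConfig.for_model("my_ts_model")
    assert isinstance(config, MyTSConfig)
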
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py
similarity index 100%
rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py
rename to iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json
new file mode 100644
index 0000000000000..1561124badd12
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json
@@ -0,0 +1,25 @@
+{
+ "model_type": "sktime",
+ "model_id": "arima",
+ "predict_length": 1,
+ "order": [1, 0, 0],
+ "seasonal_order": [0, 0, 0, 0],
+ "start_params": null,
+ "method": "lbfgs",
+ "maxiter": 50,
+ "suppress_warnings": false,
+ "out_of_sample_size": 0,
+ "scoring": "mse",
+ "scoring_args": null,
+ "trend": null,
+ "with_intercept": true,
+ "time_varying_regression": false,
+ "enforce_stationarity": true,
+ "enforce_invertibility": true,
+ "simple_differencing": false,
+ "measurement_error": false,
+ "mle_regression": true,
+ "hamilton_representation": false,
+ "concentrate_scale": false
+}
+
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py
new file mode 100644
index 0000000000000..261de3c9abe7c
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py
@@ -0,0 +1,409 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import ast
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Union
+
+from iotdb.ainode.core.exception import (
+ BuiltInModelNotSupportError,
+ ListRangeException,
+ NumericalRangeException,
+ StringRangeException,
+ WrongAttributeTypeError,
+)
+from iotdb.ainode.core.log import Logger
+
+logger = Logger()
+
+
+@dataclass
+class AttributeConfig:
+ """Base class for attribute configuration"""
+
+ name: str
+ default: Any
+ type: str # 'int', 'float', 'str', 'bool', 'list', 'tuple'
+ low: Union[int, float, None] = None
+ high: Union[int, float, None] = None
+ choices: List[str] = field(default_factory=list)
+ value_type: type = None # Element type for list and tuple
+
+ def validate_value(self, value):
+ """Validate if the value meets the requirements"""
+ if self.type == "int":
+ if value is None:
+ return True # Allow None for optional int parameters
+ if not isinstance(value, int):
+ raise WrongAttributeTypeError(self.name, "int")
+ if self.low is not None and self.high is not None:
+ if not (self.low <= value <= self.high):
+ raise NumericalRangeException(self.name, value, self.low, self.high)
+ elif self.type == "float":
+ if value is None:
+ return True # Allow None for optional float parameters
+ if not isinstance(value, (int, float)):
+ raise WrongAttributeTypeError(self.name, "float")
+ value = float(value)
+ if self.low is not None and self.high is not None:
+ if not (self.low <= value <= self.high):
+ raise NumericalRangeException(self.name, value, self.low, self.high)
+ elif self.type == "str":
+ if value is None:
+ return True # Allow None for optional str parameters
+ if not isinstance(value, str):
+ raise WrongAttributeTypeError(self.name, "str")
+ if self.choices and value not in self.choices:
+ raise StringRangeException(self.name, value, self.choices)
+ elif self.type == "bool":
+ if value is None:
+ return True # Allow None for optional bool parameters
+ if not isinstance(value, bool):
+ raise WrongAttributeTypeError(self.name, "bool")
+ elif self.type == "list":
+ if not isinstance(value, list):
+ raise WrongAttributeTypeError(self.name, "list")
+ for item in value:
+ if not isinstance(item, self.value_type):
+                    raise WrongAttributeTypeError(self.name, self.value_type.__name__)
+ elif self.type == "tuple":
+ if not isinstance(value, tuple):
+ raise WrongAttributeTypeError(self.name, "tuple")
+ for item in value:
+ if not isinstance(item, self.value_type):
+                    raise WrongAttributeTypeError(self.name, self.value_type.__name__)
+ return True
+
+ def parse(self, string_value: str):
+ """Parse string value to corresponding type"""
+ if self.type == "int":
+ if string_value.lower() == "none" or string_value.strip() == "":
+ return None
+ try:
+ return int(string_value)
+            except Exception:
+ raise WrongAttributeTypeError(self.name, "int")
+ elif self.type == "float":
+ if string_value.lower() == "none" or string_value.strip() == "":
+ return None
+ try:
+ return float(string_value)
+            except Exception:
+ raise WrongAttributeTypeError(self.name, "float")
+ elif self.type == "str":
+ if string_value.lower() == "none" or string_value.strip() == "":
+ return None
+ return string_value
+ elif self.type == "bool":
+ if string_value.lower() == "true":
+ return True
+ elif string_value.lower() == "false":
+ return False
+ elif string_value.lower() == "none" or string_value.strip() == "":
+ return None
+ else:
+ raise WrongAttributeTypeError(self.name, "bool")
+ elif self.type == "list":
+ try:
+                # literal_eval only accepts Python literals, unlike eval
+                list_value = ast.literal_eval(string_value)
+            except Exception:
+ raise WrongAttributeTypeError(self.name, "list")
+ if not isinstance(list_value, list):
+ raise WrongAttributeTypeError(self.name, "list")
+ for i in range(len(list_value)):
+ try:
+ list_value[i] = self.value_type(list_value[i])
+                except Exception:
+ raise ListRangeException(
+ self.name, list_value, str(self.value_type)
+ )
+ return list_value
+ elif self.type == "tuple":
+ try:
+                tuple_value = ast.literal_eval(string_value)
+            except Exception:
+ raise WrongAttributeTypeError(self.name, "tuple")
+ if not isinstance(tuple_value, tuple):
+ raise WrongAttributeTypeError(self.name, "tuple")
+ list_value = list(tuple_value)
+ for i in range(len(list_value)):
+ try:
+ list_value[i] = self.value_type(list_value[i])
+                except Exception:
+ raise ListRangeException(
+ self.name, list_value, str(self.value_type)
+ )
+ return tuple(list_value)
+
+
+# Model configuration definitions - using concise dictionary format
+MODEL_CONFIGS = {
+ "NAIVE_FORECASTER": {
+ "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),
+ "strategy": AttributeConfig(
+ "strategy", "last", "str", choices=["last", "mean", "drift"]
+ ),
+ "window_length": AttributeConfig("window_length", None, "int"),
+ "sp": AttributeConfig("sp", 1, "int", 1, 5000),
+ },
+ "EXPONENTIAL_SMOOTHING": {
+ "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),
+ "damped_trend": AttributeConfig("damped_trend", False, "bool"),
+ "initialization_method": AttributeConfig(
+ "initialization_method",
+ "estimated",
+ "str",
+ choices=["estimated", "heuristic", "legacy-heuristic", "known"],
+ ),
+ "optimized": AttributeConfig("optimized", True, "bool"),
+ "remove_bias": AttributeConfig("remove_bias", False, "bool"),
+ "use_brute": AttributeConfig("use_brute", False, "bool"),
+ },
+ "ARIMA": {
+ "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),
+ "order": AttributeConfig("order", (1, 0, 0), "tuple", value_type=int),
+ "seasonal_order": AttributeConfig(
+ "seasonal_order", (0, 0, 0, 0), "tuple", value_type=int
+ ),
+ "start_params": AttributeConfig("start_params", None, "str"),
+ "method": AttributeConfig(
+ "method",
+ "lbfgs",
+ "str",
+ choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"],
+ ),
+ "maxiter": AttributeConfig("maxiter", 50, "int", 1, 5000),
+ "suppress_warnings": AttributeConfig("suppress_warnings", False, "bool"),
+ "out_of_sample_size": AttributeConfig("out_of_sample_size", 0, "int", 0, 5000),
+ "scoring": AttributeConfig(
+ "scoring",
+ "mse",
+ "str",
+ choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"],
+ ),
+ "scoring_args": AttributeConfig("scoring_args", None, "str"),
+ "trend": AttributeConfig("trend", None, "str"),
+ "with_intercept": AttributeConfig("with_intercept", True, "bool"),
+ "time_varying_regression": AttributeConfig(
+ "time_varying_regression", False, "bool"
+ ),
+ "enforce_stationarity": AttributeConfig("enforce_stationarity", True, "bool"),
+ "enforce_invertibility": AttributeConfig("enforce_invertibility", True, "bool"),
+ "simple_differencing": AttributeConfig("simple_differencing", False, "bool"),
+ "measurement_error": AttributeConfig("measurement_error", False, "bool"),
+ "mle_regression": AttributeConfig("mle_regression", True, "bool"),
+ "hamilton_representation": AttributeConfig(
+ "hamilton_representation", False, "bool"
+ ),
+ "concentrate_scale": AttributeConfig("concentrate_scale", False, "bool"),
+ },
+ "STL_FORECASTER": {
+ "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000),
+ "sp": AttributeConfig("sp", 2, "int", 1, 5000),
+ "seasonal": AttributeConfig("seasonal", 7, "int", 1, 5000),
+ "trend": AttributeConfig("trend", None, "int"),
+ "low_pass": AttributeConfig("low_pass", None, "int"),
+ "seasonal_deg": AttributeConfig("seasonal_deg", 1, "int", 0, 5000),
+ "trend_deg": AttributeConfig("trend_deg", 1, "int", 0, 5000),
+ "low_pass_deg": AttributeConfig("low_pass_deg", 1, "int", 0, 5000),
+ "robust": AttributeConfig("robust", False, "bool"),
+ "seasonal_jump": AttributeConfig("seasonal_jump", 1, "int", 0, 5000),
+ "trend_jump": AttributeConfig("trend_jump", 1, "int", 0, 5000),
+ "low_pass_jump": AttributeConfig("low_pass_jump", 1, "int", 0, 5000),
+ "inner_iter": AttributeConfig("inner_iter", None, "int"),
+ "outer_iter": AttributeConfig("outer_iter", None, "int"),
+ "forecaster_trend": AttributeConfig("forecaster_trend", None, "str"),
+ "forecaster_seasonal": AttributeConfig("forecaster_seasonal", None, "str"),
+ "forecaster_resid": AttributeConfig("forecaster_resid", None, "str"),
+ },
+ "GAUSSIAN_HMM": {
+ "n_components": AttributeConfig("n_components", 1, "int", 1, 5000),
+ "covariance_type": AttributeConfig(
+ "covariance_type",
+ "diag",
+ "str",
+ choices=["spherical", "diag", "full", "tied"],
+ ),
+ "min_covar": AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10),
+ "startprob_prior": AttributeConfig(
+ "startprob_prior", 1.0, "float", -1e10, 1e10
+ ),
+ "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10),
+ "means_prior": AttributeConfig("means_prior", 0, "float", -1e10, 1e10),
+ "means_weight": AttributeConfig("means_weight", 0, "float", -1e10, 1e10),
+ "covars_prior": AttributeConfig("covars_prior", 0.01, "float", -1e10, 1e10),
+ "covars_weight": AttributeConfig("covars_weight", 1, "float", -1e10, 1e10),
+ "algorithm": AttributeConfig(
+ "algorithm", "viterbi", "str", choices=["viterbi", "map"]
+ ),
+ "random_state": AttributeConfig("random_state", None, "float"),
+ "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000),
+ "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10),
+ "verbose": AttributeConfig("verbose", False, "bool"),
+ "params": AttributeConfig("params", "stmc", "str", choices=["stmc", "stm"]),
+ "init_params": AttributeConfig(
+ "init_params", "stmc", "str", choices=["stmc", "stm"]
+ ),
+ "implementation": AttributeConfig(
+ "implementation", "log", "str", choices=["log", "scaling"]
+ ),
+ },
+ "GMM_HMM": {
+ "n_components": AttributeConfig("n_components", 1, "int", 1, 5000),
+ "n_mix": AttributeConfig("n_mix", 1, "int", 1, 5000),
+ "min_covar": AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10),
+ "startprob_prior": AttributeConfig(
+ "startprob_prior", 1.0, "float", -1e10, 1e10
+ ),
+ "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10),
+ "weights_prior": AttributeConfig("weights_prior", 1.0, "float", -1e10, 1e10),
+ "means_prior": AttributeConfig("means_prior", 0.0, "float", -1e10, 1e10),
+ "means_weight": AttributeConfig("means_weight", 0.0, "float", -1e10, 1e10),
+ "covars_prior": AttributeConfig("covars_prior", None, "float"),
+ "covars_weight": AttributeConfig("covars_weight", None, "float"),
+ "algorithm": AttributeConfig(
+ "algorithm", "viterbi", "str", choices=["viterbi", "map"]
+ ),
+ "covariance_type": AttributeConfig(
+ "covariance_type",
+ "diag",
+ "str",
+ choices=["spherical", "diag", "full", "tied"],
+ ),
+ "random_state": AttributeConfig("random_state", None, "int"),
+ "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000),
+ "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10),
+ "verbose": AttributeConfig("verbose", False, "bool"),
+ "init_params": AttributeConfig(
+ "init_params",
+ "stmcw",
+ "str",
+ choices=[
+ "s",
+ "t",
+ "m",
+ "c",
+ "w",
+ "st",
+ "sm",
+ "sc",
+ "sw",
+ "tm",
+ "tc",
+ "tw",
+ "mc",
+ "mw",
+ "cw",
+ "stm",
+ "stc",
+ "stw",
+ "smc",
+ "smw",
+ "scw",
+ "tmc",
+ "tmw",
+ "tcw",
+ "mcw",
+ "stmc",
+ "stmw",
+ "stcw",
+ "smcw",
+ "tmcw",
+ "stmcw",
+ ],
+ ),
+ "params": AttributeConfig(
+ "params",
+ "stmcw",
+ "str",
+ choices=[
+ "s",
+ "t",
+ "m",
+ "c",
+ "w",
+ "st",
+ "sm",
+ "sc",
+ "sw",
+ "tm",
+ "tc",
+ "tw",
+ "mc",
+ "mw",
+ "cw",
+ "stm",
+ "stc",
+ "stw",
+ "smc",
+ "smw",
+ "scw",
+ "tmc",
+ "tmw",
+ "tcw",
+ "mcw",
+ "stmc",
+ "stmw",
+ "stcw",
+ "smcw",
+ "tmcw",
+ "stmcw",
+ ],
+ ),
+ "implementation": AttributeConfig(
+ "implementation", "log", "str", choices=["log", "scaling"]
+ ),
+ },
+ "STRAY": {
+ "alpha": AttributeConfig("alpha", 0.01, "float", -1e10, 1e10),
+ "k": AttributeConfig("k", 10, "int", 1, 5000),
+ "knn_algorithm": AttributeConfig(
+ "knn_algorithm",
+ "brute",
+ "str",
+ choices=["brute", "kd_tree", "ball_tree", "auto"],
+ ),
+ "p": AttributeConfig("p", 0.5, "float", -1e10, 1e10),
+ "size_threshold": AttributeConfig("size_threshold", 50, "int", 1, 5000),
+ "outlier_tail": AttributeConfig(
+ "outlier_tail", "max", "str", choices=["min", "max"]
+ ),
+ },
+}
+
+
+def get_attributes(model_id: str) -> Dict[str, AttributeConfig]:
+ """Get attribute configuration for Sktime model"""
+ model_id = "EXPONENTIAL_SMOOTHING" if model_id == "HOLTWINTERS" else model_id
+ if model_id not in MODEL_CONFIGS:
+ raise BuiltInModelNotSupportError(model_id)
+ return MODEL_CONFIGS[model_id]
+
+
+def update_attribute(
+ input_attributes: Dict[str, str], attribute_map: Dict[str, AttributeConfig]
+) -> Dict[str, Any]:
+ """Update Sktime model attributes using input attributes"""
+ attributes = {}
+ for name, config in attribute_map.items():
+ if name in input_attributes:
+ value = config.parse(input_attributes[name])
+ config.validate_value(value)
+ attributes[name] = value
+ else:
+ attributes[name] = config.default
+ return attributes
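
Reviewer note: a quick sanity-check sketch of the attribute pipeline above.
String attributes (as they would arrive from a user statement) are parsed,
validated, and merged with the defaults; the input values here are arbitrary.

    from iotdb.ainode.core.model.sktime.configuration_sktime import (
        get_attributes,
        update_attribute,
    )

    raw = {"predict_length": "24", "order": "(2, 1, 0)", "with_intercept": "false"}

    attrs = update_attribute(raw, get_attributes("ARIMA"))
    assert attrs["predict_length"] == 24     # "int" -> int, range-checked 1..5000
    assert attrs["order"] == (2, 1, 0)       # "tuple" -> tuple of int
    assert attrs["with_intercept"] is False  # "bool" -> bool
    assert attrs["maxiter"] == 50            # unspecified keys fall back to defaults
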
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json
new file mode 100644
index 0000000000000..4126d9de857a6
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json
@@ -0,0 +1,11 @@
+{
+ "model_type": "sktime",
+ "model_id": "exponential_smoothing",
+ "predict_length": 1,
+ "damped_trend": false,
+ "initialization_method": "estimated",
+ "optimized": true,
+ "remove_bias": false,
+ "use_brute": false
+}
+
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json
new file mode 100644
index 0000000000000..94f7d7ec659fc
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json
@@ -0,0 +1,22 @@
+{
+ "model_type": "sktime",
+ "model_id": "gaussian_hmm",
+ "n_components": 1,
+ "covariance_type": "diag",
+ "min_covar": 0.001,
+ "startprob_prior": 1.0,
+ "transmat_prior": 1.0,
+ "means_prior": 0,
+ "means_weight": 0,
+ "covars_prior": 0.01,
+ "covars_weight": 1,
+ "algorithm": "viterbi",
+ "random_state": null,
+ "n_iter": 10,
+ "tol": 0.01,
+ "verbose": false,
+ "params": "stmc",
+ "init_params": "stmc",
+ "implementation": "log"
+}
+
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json
new file mode 100644
index 0000000000000..fb19d1aaf86d9
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json
@@ -0,0 +1,24 @@
+{
+ "model_type": "sktime",
+ "model_id": "gmm_hmm",
+ "n_components": 1,
+ "n_mix": 1,
+ "min_covar": 0.001,
+ "startprob_prior": 1.0,
+ "transmat_prior": 1.0,
+ "weights_prior": 1.0,
+ "means_prior": 0.0,
+ "means_weight": 0.0,
+ "covars_prior": null,
+ "covars_weight": null,
+ "algorithm": "viterbi",
+ "covariance_type": "diag",
+ "random_state": null,
+ "n_iter": 10,
+ "tol": 0.01,
+ "verbose": false,
+ "init_params": "stmcw",
+ "params": "stmcw",
+ "implementation": "log"
+}
+
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py
new file mode 100644
index 0000000000000..eca812d35ec9a
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py
@@ -0,0 +1,180 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from abc import abstractmethod
+from typing import Any, Dict
+
+import numpy as np
+import pandas as pd
+from sklearn.preprocessing import MinMaxScaler
+from sktime.detection.hmm_learn import GMMHMM, GaussianHMM
+from sktime.detection.stray import STRAY
+from sktime.forecasting.arima import ARIMA
+from sktime.forecasting.exp_smoothing import ExponentialSmoothing
+from sktime.forecasting.naive import NaiveForecaster
+from sktime.forecasting.trend import STLForecaster
+
+from iotdb.ainode.core.exception import (
+ BuiltInModelNotSupportError,
+ InferenceModelInternalError,
+)
+from iotdb.ainode.core.log import Logger
+
+from .configuration_sktime import get_attributes, update_attribute
+
+logger = Logger()
+
+
+class SktimeModel:
+ """Base class for Sktime models"""
+
+ def __init__(self, attributes: Dict[str, Any]):
+ self._attributes = attributes
+ self._model = None
+
+ @abstractmethod
+ def generate(self, data, **kwargs):
+ """Execute generation/inference"""
+ raise NotImplementedError
+
+
+class ForecastingModel(SktimeModel):
+ """Base class for forecasting models"""
+
+ def generate(self, data, **kwargs):
+ """Execute forecasting"""
+ try:
+ predict_length = kwargs.get(
+ "predict_length", self._attributes["predict_length"]
+ )
+ self._model.fit(data)
+ output = self._model.predict(fh=range(predict_length))
+ return np.array(output, dtype=np.float64)
+ except Exception as e:
+ raise InferenceModelInternalError(str(e))
+
+
+class DetectionModel(SktimeModel):
+ """Base class for detection models"""
+
+ def generate(self, data, **kwargs):
+ """Execute detection"""
+ try:
+ predict_length = kwargs.get("predict_length", data.size)
+ output = self._model.fit_transform(data[:predict_length])
+ if isinstance(output, pd.DataFrame):
+ return np.array(output["labels"], dtype=np.int32)
+ else:
+ return np.array(output, dtype=np.int32)
+ except Exception as e:
+ raise InferenceModelInternalError(str(e))
+
+
+class ArimaModel(ForecastingModel):
+ """ARIMA model"""
+
+ def __init__(self, attributes: Dict[str, Any]):
+ super().__init__(attributes)
+ self._model = ARIMA(
+ **{k: v for k, v in attributes.items() if k != "predict_length"}
+ )
+
+
+class ExponentialSmoothingModel(ForecastingModel):
+ """Exponential smoothing model"""
+
+ def __init__(self, attributes: Dict[str, Any]):
+ super().__init__(attributes)
+ self._model = ExponentialSmoothing(
+ **{k: v for k, v in attributes.items() if k != "predict_length"}
+ )
+
+
+class NaiveForecasterModel(ForecastingModel):
+ """Naive forecaster model"""
+
+ def __init__(self, attributes: Dict[str, Any]):
+ super().__init__(attributes)
+ self._model = NaiveForecaster(
+ **{k: v for k, v in attributes.items() if k != "predict_length"}
+ )
+
+
+class STLForecasterModel(ForecastingModel):
+ """STL forecaster model"""
+
+ def __init__(self, attributes: Dict[str, Any]):
+ super().__init__(attributes)
+ self._model = STLForecaster(
+ **{k: v for k, v in attributes.items() if k != "predict_length"}
+ )
+
+
+class GMMHMMModel(DetectionModel):
+ """GMM HMM model"""
+
+ def __init__(self, attributes: Dict[str, Any]):
+ super().__init__(attributes)
+ self._model = GMMHMM(**attributes)
+
+
+class GaussianHmmModel(DetectionModel):
+ """Gaussian HMM model"""
+
+ def __init__(self, attributes: Dict[str, Any]):
+ super().__init__(attributes)
+ self._model = GaussianHMM(**attributes)
+
+
+class STRAYModel(DetectionModel):
+ """STRAY anomaly detection model"""
+
+ def __init__(self, attributes: Dict[str, Any]):
+ super().__init__(attributes)
+ self._model = STRAY(**{k: v for k, v in attributes.items() if v is not None})
+
+ def generate(self, data, **kwargs):
+ """STRAY requires special handling: normalize first"""
+ try:
+ scaled_data = MinMaxScaler().fit_transform(data.values.reshape(-1, 1))
+ scaled_data = pd.Series(scaled_data.flatten())
+ return super().generate(scaled_data, **kwargs)
+ except Exception as e:
+ raise InferenceModelInternalError(str(e))
+
+
+# Model factory mapping
+_MODEL_FACTORY = {
+ "ARIMA": ArimaModel,
+ "EXPONENTIAL_SMOOTHING": ExponentialSmoothingModel,
+ "HOLTWINTERS": ExponentialSmoothingModel, # Use the same model class
+ "NAIVE_FORECASTER": NaiveForecasterModel,
+ "STL_FORECASTER": STLForecasterModel,
+ "GMM_HMM": GMMHMMModel,
+ "GAUSSIAN_HMM": GaussianHmmModel,
+ "STRAY": STRAYModel,
+}
+
+
+def create_sktime_model(model_id: str, **kwargs) -> SktimeModel:
+ """Create a Sktime model instance"""
+ attributes = update_attribute({**kwargs}, get_attributes(model_id.upper()))
+ model_class = _MODEL_FACTORY.get(model_id.upper())
+ if model_class is None:
+ raise BuiltInModelNotSupportError(model_id)
+ return model_class(attributes)
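
Reviewer note: a minimal end-to-end sketch of the factory above, mirroring how
callers are expected to use it. The toy series and attribute values are
arbitrary, and sktime must be installed for this to run.

    import numpy as np
    import pandas as pd

    from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model

    series = pd.Series(np.sin(np.linspace(0, 10, 100)))  # toy univariate history

    # Attribute values arrive as strings and are parsed by configuration_sktime
    model = create_sktime_model("naive_forecaster", predict_length="8", strategy="mean")
    forecast = model.generate(series)  # fit on the series, then forecast
    print(forecast)  # np.ndarray of float64 predictions from ForecastingModel.generate
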
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json
new file mode 100644
index 0000000000000..3dadd7c3b1e5d
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json
@@ -0,0 +1,9 @@
+{
+ "model_type": "sktime",
+ "model_id": "naive_forecaster",
+ "predict_length": 1,
+ "strategy": "last",
+ "window_length": null,
+ "sp": 1
+}
+
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
new file mode 100644
index 0000000000000..ced21f29a2b82
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import numpy as np
+import pandas as pd
+import torch
+
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
+
+
+class SktimePipeline(ForecastPipeline):
+ def __init__(self, model_info, **model_kwargs):
+ model_kwargs.pop("device", None) # sktime models run on CPU
+ super().__init__(model_info, model_kwargs=model_kwargs)
+
+ def _preprocess(self, inputs):
+ return inputs
+
+ def forecast(self, inputs, **infer_kwargs):
+ predict_length = infer_kwargs.get("predict_length", 96)
+ input_ids = self._preprocess(inputs)
+
+ # Convert to pandas Series for sktime (sktime expects Series or DataFrame)
+ # Handle batch dimension: if batch_size > 1, process each sample separately
+ if len(input_ids.shape) == 2 and input_ids.shape[0] > 1:
+ # Batch processing: convert each row to Series
+ outputs = []
+ for i in range(input_ids.shape[0]):
+ series = pd.Series(
+ input_ids[i].cpu().numpy()
+ if isinstance(input_ids, torch.Tensor)
+ else input_ids[i]
+ )
+ output = self.model.generate(series, predict_length=predict_length)
+ outputs.append(output)
+ output = np.array(outputs)
+ else:
+ # Single sample: convert to Series
+ if isinstance(input_ids, torch.Tensor):
+ series = pd.Series(input_ids.squeeze().cpu().numpy())
+ else:
+ series = pd.Series(input_ids.squeeze())
+ output = self.model.generate(series, predict_length=predict_length)
+ # Add batch dimension if needed
+ if len(output.shape) == 1:
+ output = output[np.newaxis, :]
+
+ return self._postprocess(output)
+
+ def _postprocess(self, output):
+ if isinstance(output, np.ndarray):
+ return torch.from_numpy(output).float()
+ return output
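
Reviewer note: the batch handling in forecast is the subtle part of this
pipeline, so here is a self-contained sketch of the same control flow with a
stub in place of self.model.generate (everything named stub_* is hypothetical).

    import numpy as np
    import pandas as pd
    import torch

    def stub_generate(series: pd.Series, predict_length: int) -> np.ndarray:
        # stand-in for SktimeModel.generate: repeat the last observed value
        return np.full(predict_length, series.iloc[-1], dtype=np.float64)

    def forecast(inputs: torch.Tensor, predict_length: int = 96) -> torch.Tensor:
        if len(inputs.shape) == 2 and inputs.shape[0] > 1:
            # batch: one sktime fit/predict per row
            outputs = [
                stub_generate(pd.Series(inputs[i].cpu().numpy()), predict_length)
                for i in range(inputs.shape[0])
            ]
            output = np.array(outputs)
        else:
            series = pd.Series(inputs.squeeze().cpu().numpy())
            output = stub_generate(series, predict_length)
            if len(output.shape) == 1:
                output = output[np.newaxis, :]  # restore the batch dimension
        return torch.from_numpy(output).float()

    assert forecast(torch.randn(4, 32), predict_length=8).shape == (4, 8)
    assert forecast(torch.randn(1, 32), predict_length=8).shape == (1, 8)
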
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json
new file mode 100644
index 0000000000000..bfe71dbc48614
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json
@@ -0,0 +1,22 @@
+{
+ "model_type": "sktime",
+ "model_id": "stl_forecaster",
+ "predict_length": 1,
+ "sp": 2,
+ "seasonal": 7,
+ "trend": null,
+ "low_pass": null,
+ "seasonal_deg": 1,
+ "trend_deg": 1,
+ "low_pass_deg": 1,
+ "robust": false,
+ "seasonal_jump": 1,
+ "trend_jump": 1,
+ "low_pass_jump": 1,
+ "inner_iter": null,
+ "outer_iter": null,
+ "forecaster_trend": null,
+ "forecaster_seasonal": null,
+ "forecaster_resid": null
+}
+
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json
new file mode 100644
index 0000000000000..e5bcc03cd0714
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json
@@ -0,0 +1,11 @@
+{
+ "model_type": "sktime",
+ "model_id": "stray",
+ "alpha": 0.01,
+ "k": 10,
+ "knn_algorithm": "brute",
+ "p": 0.5,
+ "size_threshold": 50,
+ "outlier_tail": "max"
+}
+
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
index 3ebf516f705e0..dc1de32506e57 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py
@@ -16,13 +16,10 @@
# under the License.
#
-import os
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file as load_safetensors
from torch import nn
from transformers import Cache, DynamicCache, PreTrainedModel
from transformers.activations import ACT2FN
@@ -32,13 +29,10 @@
MoeModelOutputWithPast,
)
-from iotdb.ainode.core.log import Logger
from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig
from iotdb.ainode.core.model.sundial.flow_loss import FlowLoss
from iotdb.ainode.core.model.sundial.ts_generation_mixin import TSGenerationMixin
-logger = Logger()
-
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
@@ -616,11 +610,7 @@ def prepare_inputs_for_generation(
if attention_mask is not None and attention_mask.shape[1] > (
input_ids.shape[1] // self.config.input_token_len
):
- input_ids = input_ids[
- :,
- -(attention_mask.shape[1] - past_length)
- * self.config.input_token_len :,
- ]
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < (input_ids.shape[1] // self.config.input_token_len):
@@ -633,10 +623,9 @@ def prepare_inputs_for_generation(
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
- token_num = (
- input_ids.shape[1] + self.config.input_token_len - 1
- ) // self.config.input_token_len
- position_ids = position_ids[:, -token_num:]
+ position_ids = position_ids[
+ :, -(input_ids.shape[1] // self.config.input_token_len) :
+ ]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
similarity index 56%
rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py
rename to iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
index 17c88e32fb5a0..85b6f7db2ffef 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py
@@ -19,33 +19,33 @@
import torch
from iotdb.ainode.core.exception import InferenceModelInternalError
-from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import (
- AbstractInferencePipeline,
-)
-from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
-class TimerSundialInferencePipeline(AbstractInferencePipeline):
- """
- Strategy for Timer-Sundial model inference.
- """
+class SundialPipeline(ForecastPipeline):
+ def __init__(self, model_info, **model_kwargs):
+ super().__init__(model_info, model_kwargs=model_kwargs)
- def __init__(self, model_config: SundialConfig, **infer_kwargs):
- super().__init__(model_config, infer_kwargs=infer_kwargs)
-
- def preprocess_inputs(self, inputs: torch.Tensor):
- super().preprocess_inputs(inputs)
+ def _preprocess(self, inputs):
if len(inputs.shape) != 2:
raise InferenceModelInternalError(
f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}"
)
- # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py
return inputs
- def post_decode(self):
- # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py
- pass
-
- def post_inference(self):
- # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py
- pass
+ def forecast(self, inputs, **infer_kwargs):
+ predict_length = infer_kwargs.get("predict_length", 96)
+ num_samples = infer_kwargs.get("num_samples", 10)
+ revin = infer_kwargs.get("revin", True)
+
+ input_ids = self._preprocess(inputs)
+ output = self.model.generate(
+ input_ids,
+ max_new_tokens=predict_length,
+ num_samples=num_samples,
+ revin=revin,
+ )
+ return self._postprocess(output)
+
+ def _postprocess(self, output: torch.Tensor):
+ return output.mean(dim=1)
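
Reviewer note: the Sundial postprocessing step collapses the sampled
trajectories into a point forecast; a tiny numeric sketch of that reduction,
assuming generate() yields [batch, num_samples, predict_length] (the shapes
here are illustrative).

    import torch

    output = torch.arange(2 * 10 * 96, dtype=torch.float32).reshape(2, 10, 96)

    point_forecast = output.mean(dim=1)  # average over the sample dimension
    assert point_forecast.shape == (2, 96)
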
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py
new file mode 100644
index 0000000000000..2a1e720805f29
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/configuration_timer.py
similarity index 100%
rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py
rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/configuration_timer.py
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py
similarity index 98%
rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py
rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py
index 0a33c682742aa..fc9d7b41388bc 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py
@@ -16,7 +16,7 @@
# under the License.
#
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
import torch
import torch.nn.functional as F
@@ -29,11 +29,8 @@
MoeModelOutputWithPast,
)
-from iotdb.ainode.core.log import Logger
-from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig
-from iotdb.ainode.core.model.timerxl.ts_generation_mixin import TSGenerationMixin
-
-logger = Logger()
+from iotdb.ainode.core.model.timer_xl.configuration_timer import TimerConfig
+from iotdb.ainode.core.model.timer_xl.ts_generation_mixin import TSGenerationMixin
def rotate_half(x):
@@ -606,11 +603,7 @@ def prepare_inputs_for_generation(
if attention_mask is not None and attention_mask.shape[1] > (
input_ids.shape[1] // self.config.input_token_len
):
- input_ids = input_ids[
- :,
- -(attention_mask.shape[1] - past_length)
- * self.config.input_token_len :,
- ]
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
# 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
# input_ids based on the past_length.
elif past_length < (input_ids.shape[1] // self.config.input_token_len):
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
similarity index 52%
rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py
rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
index dc1dd304f68e8..c0f00b1f5caf3 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py
@@ -19,33 +19,29 @@
import torch
from iotdb.ainode.core.exception import InferenceModelInternalError
-from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import (
- AbstractInferencePipeline,
-)
-from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline
-class TimerXLInferencePipeline(AbstractInferencePipeline):
- """
- Strategy for Timer-XL model inference.
- """
+class TimerPipeline(ForecastPipeline):
+ def __init__(self, model_info, **model_kwargs):
+ super().__init__(model_info, model_kwargs=model_kwargs)
- def __init__(self, model_config: TimerConfig, **infer_kwargs):
- super().__init__(model_config, infer_kwargs=infer_kwargs)
-
- def preprocess_inputs(self, inputs: torch.Tensor):
- super().preprocess_inputs(inputs)
+ def _preprocess(self, inputs):
if len(inputs.shape) != 2:
raise InferenceModelInternalError(
f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}"
)
- # Considering that we are currently using the generate function interface, it seems that no pre-processing is required
return inputs
- def post_decode(self):
- # Considering that we are currently using the generate function interface, it seems that no post-processing is required
- pass
+ def forecast(self, inputs, **infer_kwargs):
+ predict_length = infer_kwargs.get("predict_length", 96)
+ revin = infer_kwargs.get("revin", True)
+
+ input_ids = self._preprocess(inputs)
+ output = self.model.generate(
+ input_ids, max_new_tokens=predict_length, revin=revin
+ )
+ return self._postprocess(output)
- def post_inference(self):
- # Considering that we are currently using the generate function interface, it seems that no post-processing is required
- pass
+ def _postprocess(self, output: torch.Tensor):
+ return output
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/ts_generation_mixin.py
similarity index 100%
rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py
rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/ts_generation_mixin.py
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py
deleted file mode 100644
index b2e759e00ce07..0000000000000
--- a/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import os
-from enum import Enum
-from typing import List
-
-from huggingface_hub import snapshot_download
-from requests import Session
-from requests.adapters import HTTPAdapter
-
-from iotdb.ainode.core.constant import (
- DEFAULT_CHUNK_SIZE,
- DEFAULT_RECONNECT_TIMEOUT,
- DEFAULT_RECONNECT_TIMES,
-)
-from iotdb.ainode.core.exception import UnsupportedError
-from iotdb.ainode.core.log import Logger
-from iotdb.ainode.core.model.model_enums import ModelFileType
-from iotdb.ainode.core.model.model_info import get_model_file_type
-
-HTTP_PREFIX = "http://"
-HTTPS_PREFIX = "https://"
-
-logger = Logger()
-
-
-class UriType(Enum):
- REPO = "repo"
- FILE = "file"
- HTTP = "http"
- HTTPS = "https"
-
- @classmethod
- def values(cls) -> List[str]:
- return [item.value for item in cls]
-
- @staticmethod
- def parse_uri_type(uri: str):
- """
- Parse the URI type from the given string.
- """
- if uri.startswith("repo://"):
- return UriType.REPO
- elif uri.startswith("file://"):
- return UriType.FILE
- elif uri.startswith("http://"):
- return UriType.HTTP
- elif uri.startswith("https://"):
- return UriType.HTTPS
- else:
- raise ValueError(f"Invalid URI type for {uri}")
-
-
-def get_model_register_strategy(uri: str):
- """
- Determine the loading strategy for a model based on its URI/path.
-
- Args:
- uri (str): The URI of the model to be registered.
-
- Returns:
- uri_type (UriType): The type of the URI, which can be one of: REPO, FILE, HTTP, or HTTPS.
- parsed_uri (str): Parsed uri to get related file
- model_file_type (ModelFileType): The type of the model file, which can be one of: SAFETENSORS, PYTORCH, or UNKNOWN.
- """
-
- uri_type = UriType.parse_uri_type(uri)
- if uri_type in (UriType.HTTP, UriType.HTTPS):
- # TODO: support HTTP(S) URI
- raise UnsupportedError("CREATE MODEL FROM HTTP(S) URI")
- else:
- parsed_uri = uri[7:]
- if uri_type == UriType.FILE:
- # handle ~ in URI
- parsed_uri = os.path.expanduser(parsed_uri)
- model_file_type = get_model_file_type(uri)
- elif uri_type == UriType.REPO:
- # Currently, UriType.REPO only corresponds to huggingface repository with SAFETENSORS format
- model_file_type = ModelFileType.SAFETENSORS
- else:
- raise ValueError(f"Invalid URI type for {uri}")
- return uri_type, parsed_uri, model_file_type
-
-
-def download_snapshot_from_hf(repo_id: str, local_dir: str):
- """
- Download everything from a HuggingFace repository.
-
- Args:
- repo_id (str): The HuggingFace repository ID.
- local_dir (str): The local directory to save the downloaded files.
- """
- try:
- snapshot_download(
- repo_id=repo_id,
- local_dir=local_dir,
- )
- except Exception as e:
- logger.error(f"Failed to download HuggingFace model {repo_id}: {e}")
- raise e
-
-
-def download_file(url: str, storage_path: str) -> None:
- """
- Args:
- url: url of file to download
- storage_path: path to save the file
- Returns:
- None
- """
- logger.info(f"Start Downloading file from {url} to {storage_path}")
- session = Session()
- adapter = HTTPAdapter(max_retries=DEFAULT_RECONNECT_TIMES)
- session.mount(HTTP_PREFIX, adapter)
- session.mount(HTTPS_PREFIX, adapter)
- response = session.get(url, timeout=DEFAULT_RECONNECT_TIMEOUT, stream=True)
- response.raise_for_status()
- with open(storage_path, "wb") as file:
- for chunk in response.iter_content(chunk_size=DEFAULT_CHUNK_SIZE):
- if chunk:
- file.write(chunk)
- logger.info(f"Download file from {url} to {storage_path} success")
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py
new file mode 100644
index 0000000000000..1cd0ee44912d5
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py
@@ -0,0 +1,98 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import importlib
+import json
+import os.path
+import sys
+from contextlib import contextmanager
+from typing import Dict, Tuple
+
+from iotdb.ainode.core.model.model_constants import (
+ MODEL_CONFIG_FILE_IN_JSON,
+ MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
+ UriType,
+)
+
+
+def parse_uri_type(uri: str) -> UriType:
+ if uri.startswith("repo://"):
+ return UriType.REPO
+ elif uri.startswith("file://"):
+ return UriType.FILE
+ else:
+ raise ValueError(
+ f"Unsupported URI type: {uri}. Supported formats: repo:// or file://"
+ )
+
+
+def get_parsed_uri(uri: str) -> str:
+ return uri[7:] # Remove "repo://" or "file://" prefix
+
+
+@contextmanager
+def temporary_sys_path(path: str):
+ """Context manager for temporarily adding a path to sys.path"""
+ path_added = path not in sys.path
+ if path_added:
+ sys.path.insert(0, path)
+ try:
+ yield
+ finally:
+ if path_added and path in sys.path:
+ sys.path.remove(path)
+
+
+def load_model_config_in_json(config_path: str) -> Dict:
+ with open(config_path, "r", encoding="utf-8") as f:
+ return json.load(f)
+
+
+def validate_model_files(model_dir: str) -> Tuple[str, str]:
+    """Validate that the model config and weights files exist and return their paths.
+
+    Side effect: creates an __init__.py so the model directory can be imported as a module.
+    """
+
+ config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON)
+ weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)
+
+ if not os.path.exists(config_path):
+ raise ValueError(f"Model config file does not exist: {config_path}")
+ if not os.path.exists(weights_path):
+ raise ValueError(f"Model weights file does not exist: {weights_path}")
+
+ # Create __init__.py file to ensure model directory can be imported as a module
+ init_file = os.path.join(model_dir, "__init__.py")
+ if not os.path.exists(init_file):
+ with open(init_file, "w"):
+ pass
+
+ return config_path, weights_path
+
+
+def import_class_from_path(module_name: str, class_path: str):
+ file_name, class_name = class_path.rsplit(".", 1)
+ module = importlib.import_module(module_name + "." + file_name)
+ return getattr(module, class_name)
+
+
+def ensure_init_file(dir_path: str):
+ """Ensure __init__.py file exists in the given dir path"""
+ init_file = os.path.join(dir_path, "__init__.py")
+ os.makedirs(dir_path, exist_ok=True)
+ if not os.path.exists(init_file):
+ with open(init_file, "w"):
+ pass
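
Reviewer note: a small usage sketch of the helpers above working together,
assuming a hypothetical model laid out as <parent>/my_model/modeling_my_model.py
containing a MyModel class (none of these names exist in the repository).

    from iotdb.ainode.core.model.utils import (
        ensure_init_file,
        import_class_from_path,
        temporary_sys_path,
    )

    model_parent = "/data/ainode/models/user_defined"  # hypothetical layout
    model_id = "my_model"

    # make the package importable: creates user_defined/my_model/__init__.py
    ensure_init_file(f"{model_parent}/{model_id}")

    with temporary_sys_path(model_parent):
        # resolves my_model.modeling_my_model.MyModel while the parent is on sys.path
        cls = import_class_from_path(model_id, "modeling_my_model.MyModel")
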
diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py
index e2be6459508b7..ea6362ef080af 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py
@@ -38,7 +38,6 @@
     TAINodeRemoveReq,
     TAINodeRestartReq,
     TNodeVersionInfo,
-    TUpdateModelInfoReq,
 )
 logger = Logger()
@@ -155,13 +154,8 @@ def _wait_and_reconnect(self) -> None:
             self._try_to_connect()
         except TException:
             # can not connect to each config node
-            self._sync_latest_config_node_list()
             self._try_to_connect()
-    def _sync_latest_config_node_list(self) -> None:
-        # TODO
-        pass
-
     def _update_config_node_leader(self, status: TSStatus) -> bool:
         if status.code == TSStatusCode.REDIRECTION_RECOMMEND.get_status_code():
             if status.redirectNode is not None:
@@ -271,36 +265,3 @@ def get_ainode_configuration(self, node_id: int) -> map:
                 self._config_leader = None
                 self._wait_and_reconnect()
         raise TException(self._MSG_RECONNECTION_FAIL)
-
-    def update_model_info(
-        self,
-        model_id: str,
-        model_status: int,
-        attribute: str = "",
-        ainode_id=None,
-        input_length=0,
-        output_length=0,
-    ) -> None:
-        if ainode_id is None:
-            ainode_id = []
-        for _ in range(0, self._RETRY_NUM):
-            try:
-                req = TUpdateModelInfoReq(model_id, model_status, attribute)
-                if ainode_id is not None:
-                    req.aiNodeIds = ainode_id
-                req.inputLength = input_length
-                req.outputLength = output_length
-                status = self._client.updateModelInfo(req)
-                if not self._update_config_node_leader(status):
-                    verify_success(
-                        status, "An error occurs when calling update model info"
-                    )
-                return status
-            except TTransport.TException:
-                logger.warning(
-                    "Failed to connect to ConfigNode {} from AINode when executing update model info",
-                    self._config_leader,
-                )
-                self._config_leader = None
-                self._wait_and_reconnect()
-        raise TException(self._MSG_RECONNECTION_FAIL)
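
The removed update_model_info followed the retry template the remaining client methods still use: attempt the RPC, and on a transport failure drop the cached ConfigNode leader, reconnect, and try again up to _RETRY_NUM times before giving up. Reduced to a standalone sketch (the callables are stand-ins, not the client's real API):

    def call_with_retry(do_rpc, reset_leader, reconnect, retry_num=5):
        # Try the RPC; on transport errors drop the cached leader,
        # reconnect, and retry until the attempts are exhausted.
        last_error = None
        for _ in range(retry_num):
            try:
                return do_rpc()
            except Exception as e:  # TTransport.TException in the real client
                last_error = e
                reset_leader()
                reconnect()
        raise ConnectionError("fail to connect to ConfigNode") from last_error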
diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
index f01e1594f0698..6c4eedeb99f7f 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py
@@ -29,6 +29,7 @@
     TAIHeartbeatResp,
     TDeleteModelReq,
     TForecastReq,
+    TForecastResp,
     TInferenceReq,
     TInferenceResp,
     TLoadModelReq,
@@ -78,8 +79,14 @@ def stopAINode(self) -> TSStatus:
     def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp:
         return self._model_manager.register_model(req)
+    def deleteModel(self, req: TDeleteModelReq) -> TSStatus:
+        return self._model_manager.delete_model(req)
+
+    def showModels(self, req: TShowModelsReq) -> TShowModelsResp:
+        return self._model_manager.show_models(req)
+
     def loadModel(self, req: TLoadModelReq) -> TSStatus:
-        status = self._ensure_model_is_built_in_or_fine_tuned(req.existingModelId)
+        status = self._ensure_model_is_registered(req.existingModelId)
         if status.code != TSStatusCode.SUCCESS_STATUS.value:
             return status
         status = _ensure_device_id_is_available(req.deviceIdList)
@@ -88,7 +95,7 @@ def loadModel(self, req: TLoadModelReq) -> TSStatus:
         return self._inference_manager.load_model(req)
     def unloadModel(self, req: TUnloadModelReq) -> TSStatus:
-        status = self._ensure_model_is_built_in_or_fine_tuned(req.modelId)
+        status = self._ensure_model_is_registered(req.modelId)
         if status.code != TSStatusCode.SUCCESS_STATUS.value:
             return status
         status = _ensure_device_id_is_available(req.deviceIdList)
@@ -96,21 +103,6 @@ def unloadModel(self, req: TUnloadModelReq) -> TSStatus:
             return status
         return self._inference_manager.unload_model(req)
-    def deleteModel(self, req: TDeleteModelReq) -> TSStatus:
-        return self._model_manager.delete_model(req)
-
-    def inference(self, req: TInferenceReq) -> TInferenceResp:
-        return self._inference_manager.inference(req)
-
-    def forecast(self, req: TForecastReq) -> TSStatus:
-        return self._inference_manager.forecast(req)
-
-    def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
-        return ClusterManager.get_heart_beat(req)
-
-    def showModels(self, req: TShowModelsReq) -> TShowModelsResp:
-        return self._model_manager.show_models(req)
-
     def showLoadedModels(self, req: TShowLoadedModelsReq) -> TShowLoadedModelsResp:
         status = _ensure_device_id_is_available(req.deviceIdList)
         if status.code != TSStatusCode.SUCCESS_STATUS.value:
@@ -123,13 +115,28 @@ def showAIDevices(self) -> TShowAIDevicesResp:
             deviceIdList=get_available_devices(),
         )
+    def inference(self, req: TInferenceReq) -> TInferenceResp:
+        status = self._ensure_model_is_registered(req.modelId)
+        if status.code != TSStatusCode.SUCCESS_STATUS.value:
+            return TInferenceResp(status, [])
+        return self._inference_manager.inference(req)
+
+    def forecast(self, req: TForecastReq) -> TForecastResp:
+        status = self._ensure_model_is_registered(req.modelId)
+        if status.code != TSStatusCode.SUCCESS_STATUS.value:
+            return TForecastResp(status, [])
+        return self._inference_manager.forecast(req)
+
+    def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp:
+        return ClusterManager.get_heart_beat(req)
+
     def createTrainingTask(self, req: TTrainingReq) -> TSStatus:
         pass
-    def _ensure_model_is_built_in_or_fine_tuned(self, model_id: str) -> TSStatus:
-        if not self._model_manager.is_built_in_or_fine_tuned(model_id):
+    def _ensure_model_is_registered(self, model_id: str) -> TSStatus:
+        if not self._model_manager.is_model_registered(model_id):
             return TSStatus(
                 code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value,
-                message=f"Model [{model_id}] is not a built-in or fine-tuned model. You can use 'SHOW MODELS' to retrieve the available models.",
+                message=f"Model [{model_id}] is not registered yet. You can use 'SHOW MODELS' to retrieve the available models.",
             )
         return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value)
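
inference and forecast now share the guard already used by loadModel and unloadModel: check that the model is registered before dispatching to the inference manager, otherwise return MODEL_NOT_FOUND_ERROR with an empty result. The shape of that guard as a standalone sketch (the numeric codes and tuple statuses are illustrative, not the real TSStatus API):

    SUCCESS_STATUS, MODEL_NOT_FOUND_ERROR = 200, 507  # illustrative codes

    def ensure_model_is_registered(model_id, registered_models):
        if model_id not in registered_models:
            return MODEL_NOT_FOUND_ERROR, f"Model [{model_id}] is not registered yet."
        return SUCCESS_STATUS, ""

    def inference(model_id, req, registered_models, run):
        code, message = ensure_model_is_registered(model_id, registered_models)
        if code != SUCCESS_STATUS:
            return code, message, []  # empty payload, like TInferenceResp(status, [])
        return run(req)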
diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index 331cb8ab32a34..e93bb3dfdaf10 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -79,7 +79,7 @@ exclude = [
python = ">=3.11.0,<3.14.0"
# ---- DL / HF stack ----
-torch = ">=2.7.0"
+torch = "^2.8.0,<2.9.0"
torchmetrics = "^1.8.0"
transformers = "==4.56.2"
tokenizers = ">=0.22.0,<=0.23.0"
@@ -93,7 +93,10 @@ scipy = "^1.12.0"
pandas = "^2.3.2"
scikit-learn = "^1.7.1"
statsmodels = "^0.14.5"
-sktime = "0.38.5"
+sktime = "0.40.1"
+pmdarima = "2.1.1"
+hmmlearn = "0.3.2"
+accelerate = "^1.10.1"
# ---- Optimizers / utils ----
optuna = "^4.4.0"
@@ -112,7 +115,7 @@ black = "25.1.0"
isort = "6.0.1"
setuptools = ">=75.3.0"
joblib = ">=1.4.2"
-urllib3 = ">=2.2.3"
+urllib3 = "2.6.0"
[tool.poetry.scripts]
ainode = "iotdb.ainode.core.script:main"
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java
index 2721fedafb1e6..8d9081f435273 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java
@@ -21,21 +21,28 @@
import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatReq;
import org.apache.iotdb.common.rpc.thrift.TEndPoint;
+import org.apache.iotdb.commons.client.ClientPoolFactory;
import org.apache.iotdb.commons.client.IClientManager;
+import org.apache.iotdb.commons.client.async.AsyncAINodeInternalServiceClient;
import org.apache.iotdb.confignode.client.async.handlers.heartbeat.AINodeHeartbeatHandler;
-import org.apache.iotdb.db.protocol.client.AINodeClientFactory;
-import org.apache.iotdb.db.protocol.client.ainode.AsyncAINodeServiceClient;
+/** Asynchronously send RPC requests to AINodes. */
public class AsyncAINodeHeartbeatClientPool {
- private final IClientManager<TEndPoint, AsyncAINodeServiceClient> clientManager;
+ private final IClientManager<TEndPoint, AsyncAINodeInternalServiceClient> clientManager;
 private AsyncAINodeHeartbeatClientPool() {
 clientManager =
- new IClientManager.Factory<TEndPoint, AsyncAINodeServiceClient>()
- .createClientManager(new AINodeClientFactory.AINodeHeartbeatClientPoolFactory());
+ new IClientManager.Factory<TEndPoint, AsyncAINodeInternalServiceClient>()
+ .createClientManager(
+ new ClientPoolFactory.AsyncAINodeHeartbeatServiceClientPoolFactory());
}
+ /**
+ * Only used in LoadManager.
+ *
+ * @param endPoint The specific AINode
+ */
public void getAINodeHeartBeat(
TEndPoint endPoint, TAIHeartbeatReq req, AINodeHeartbeatHandler handler) {
try {
@@ -56,6 +63,6 @@ private AsyncAINodeHeartbeatClientPoolHolder() {
}
public static AsyncAINodeHeartbeatClientPool getInstance() {
- return AsyncAINodeHeartbeatClientPool.AsyncAINodeHeartbeatClientPoolHolder.INSTANCE;
+ return AsyncAINodeHeartbeatClientPoolHolder.INSTANCE;
}
}
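
The getInstance cleanup keeps the initialization-on-demand holder idiom: the singleton is constructed lazily when the nested holder class is first loaded, and JVM class initialization makes that construction thread-safe. A rough Python analog for readers unfamiliar with the Java idiom (purely illustrative; unlike the JVM holder, lru_cache does not guarantee a single construction under a first-call race):

    from functools import lru_cache

    class HeartbeatClientPool:
        """Stand-in for the Java pool; construction is assumed expensive."""

    @lru_cache(maxsize=1)
    def get_instance() -> HeartbeatClientPool:
        # Constructed lazily on the first call and cached afterwards.
        return HeartbeatClientPool()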
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java
index ccc19f1a9f382..324e351302787 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java
@@ -63,7 +63,6 @@ public void writeAuditLog(
}
}
- // TODO: Is the AsyncDataNodeHeartbeatClientPool must be a singleton?
private static class AsyncDataNodeHeartbeatClientPoolHolder {
private static final AsyncDataNodeHeartbeatClientPool INSTANCE =
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java
index e0b2c144c0eac..65c1ee0a9fed5 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java
@@ -21,8 +21,6 @@
import org.apache.iotdb.commons.exception.runtime.SerializationRunTimeException;
import org.apache.iotdb.confignode.consensus.request.read.ainode.GetAINodeConfigurationPlan;
-import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
-import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
import org.apache.iotdb.confignode.consensus.request.read.subscription.ShowTopicPlan;
import org.apache.iotdb.confignode.consensus.request.write.ainode.RegisterAINodePlan;
import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan;
@@ -52,10 +50,6 @@
import org.apache.iotdb.confignode.consensus.request.write.function.DropTableModelFunctionPlan;
import org.apache.iotdb.confignode.consensus.request.write.function.DropTreeModelFunctionPlan;
import org.apache.iotdb.confignode.consensus.request.write.function.UpdateFunctionPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
import org.apache.iotdb.confignode.consensus.request.write.partition.AddRegionLocationPlan;
import org.apache.iotdb.confignode.consensus.request.write.partition.AutoCleanPartitionTablePlan;
import org.apache.iotdb.confignode.consensus.request.write.partition.CreateDataPartitionPlan;
@@ -574,24 +568,6 @@ public static ConfigPhysicalPlan create(final ByteBuffer buffer) throws IOExcept
case UPDATE_CQ_LAST_EXEC_TIME:
plan = new UpdateCQLastExecTimePlan();
break;
- case CreateModel:
- plan = new CreateModelPlan();
- break;
- case UpdateModelInfo:
- plan = new UpdateModelInfoPlan();
- break;
- case DropModel:
- plan = new DropModelPlan();
- break;
- case ShowModel:
- plan = new ShowModelPlan();
- break;
- case DropModelInNode:
- plan = new DropModelInNodePlan();
- break;
- case GetModelInfo:
- plan = new GetModelInfoPlan();
- break;
case CreatePipePlugin:
plan = new CreatePipePluginPlan();
break;
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java
deleted file mode 100644
index dd79910e51fa8..0000000000000
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.confignode.consensus.request.read.model;
-
-import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType;
-import org.apache.iotdb.confignode.consensus.request.read.ConfigPhysicalReadPlan;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
-
-import java.util.Objects;
-
-public class GetModelInfoPlan extends ConfigPhysicalReadPlan {
-
- private String modelId;
-
- public GetModelInfoPlan() {
- super(ConfigPhysicalPlanType.GetModelInfo);
- }
-
- public GetModelInfoPlan(final TGetModelInfoReq getModelInfoReq) {
- super(ConfigPhysicalPlanType.GetModelInfo);
- this.modelId = getModelInfoReq.getModelId();
- }
-
- public String getModelId() {
- return modelId;
- }
-
- @Override
- public boolean equals(final Object o) {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- if (!super.equals(o)) {
- return false;
- }
- final GetModelInfoPlan that = (GetModelInfoPlan) o;
- return Objects.equals(modelId, that.modelId);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(super.hashCode(), modelId);
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java
deleted file mode 100644
index eca00e8827d96..0000000000000
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.confignode.consensus.request.read.model;
-
-import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
-import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType;
-import org.apache.iotdb.confignode.consensus.request.read.ConfigPhysicalReadPlan;
-
-import java.util.Objects;
-
-public class ShowModelPlan extends ConfigPhysicalReadPlan {
-
- private String modelName;
-
- public ShowModelPlan() {
- super(ConfigPhysicalPlanType.ShowModel);
- }
-
- public ShowModelPlan(final TShowModelsReq showModelReq) {
- super(ConfigPhysicalPlanType.ShowModel);
- if (showModelReq.isSetModelId()) {
- this.modelName = showModelReq.getModelId();
- }
- }
-
- public boolean isSetModelName() {
- return modelName != null;
- }
-
- public String getModelName() {
- return modelName;
- }
-
- @Override
- public boolean equals(final Object o) {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- if (!super.equals(o)) {
- return false;
- }
- final ShowModelPlan that = (ShowModelPlan) o;
- return Objects.equals(modelName, that.modelName);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(super.hashCode(), modelName);
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java
deleted file mode 100644
index 61e37cdd21877..0000000000000
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.confignode.consensus.request.write.model;
-
-import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan;
-import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType;
-
-import org.apache.tsfile.utils.ReadWriteIOUtils;
-
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.Objects;
-
-public class CreateModelPlan extends ConfigPhysicalPlan {
-
- private String modelName;
-
- public CreateModelPlan() {
- super(ConfigPhysicalPlanType.CreateModel);
- }
-
- public CreateModelPlan(String modelName) {
- super(ConfigPhysicalPlanType.CreateModel);
- this.modelName = modelName;
- }
-
- public String getModelName() {
- return modelName;
- }
-
- @Override
- protected void serializeImpl(DataOutputStream stream) throws IOException {
- stream.writeShort(getType().getPlanType());
- ReadWriteIOUtils.write(modelName, stream);
- }
-
- @Override
- protected void deserializeImpl(ByteBuffer buffer) throws IOException {
- modelName = ReadWriteIOUtils.readString(buffer);
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- if (!super.equals(o)) {
- return false;
- }
- CreateModelPlan that = (CreateModelPlan) o;
- return Objects.equals(modelName, that.modelName);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(super.hashCode(), modelName);
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java
deleted file mode 100644
index 885543f84e156..0000000000000
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java
+++ /dev/null
@@ -1,70 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.confignode.consensus.request.write.model;
-
-import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan;
-import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType;
-
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.Objects;
-
-public class DropModelInNodePlan extends ConfigPhysicalPlan {
-
- private int nodeId;
-
- public DropModelInNodePlan() {
- super(ConfigPhysicalPlanType.DropModelInNode);
- }
-
- public DropModelInNodePlan(int nodeId) {
- super(ConfigPhysicalPlanType.DropModelInNode);
- this.nodeId = nodeId;
- }
-
- public int getNodeId() {
- return nodeId;
- }
-
- @Override
- protected void serializeImpl(DataOutputStream stream) throws IOException {
- stream.writeShort(getType().getPlanType());
- stream.writeInt(nodeId);
- }
-
- @Override
- protected void deserializeImpl(ByteBuffer buffer) throws IOException {
- nodeId = buffer.getInt();
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (!(o instanceof DropModelInNodePlan)) return false;
- DropModelInNodePlan that = (DropModelInNodePlan) o;
- return nodeId == that.nodeId;
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(super.hashCode(), nodeId);
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java
deleted file mode 100644
index 813b116c645c5..0000000000000
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.confignode.consensus.request.write.model;
-
-import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan;
-import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType;
-
-import org.apache.tsfile.utils.ReadWriteIOUtils;
-
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.Objects;
-
-public class DropModelPlan extends ConfigPhysicalPlan {
-
- private String modelName;
-
- public DropModelPlan() {
- super(ConfigPhysicalPlanType.DropModel);
- }
-
- public DropModelPlan(String modelName) {
- super(ConfigPhysicalPlanType.DropModel);
- this.modelName = modelName;
- }
-
- public String getModelName() {
- return modelName;
- }
-
- @Override
- protected void serializeImpl(DataOutputStream stream) throws IOException {
- stream.writeShort(getType().getPlanType());
- ReadWriteIOUtils.write(modelName, stream);
- }
-
- @Override
- protected void deserializeImpl(ByteBuffer buffer) throws IOException {
- modelName = ReadWriteIOUtils.readString(buffer);
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- if (!super.equals(o)) {
- return false;
- }
- DropModelPlan that = (DropModelPlan) o;
- return modelName.equals(that.modelName);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(super.hashCode(), modelName);
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java
deleted file mode 100644
index ce7219e428139..0000000000000
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.confignode.consensus.request.write.model;
-
-import org.apache.iotdb.commons.model.ModelInformation;
-import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan;
-import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType;
-
-import org.apache.tsfile.utils.ReadWriteIOUtils;
-
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Objects;
-
-public class UpdateModelInfoPlan extends ConfigPhysicalPlan {
-
- private String modelName;
- private ModelInformation modelInformation;
-
- // The node which has the model which is only updated in model registration
- private List<Integer> nodeIds;
-
- public UpdateModelInfoPlan() {
- super(ConfigPhysicalPlanType.UpdateModelInfo);
- }
-
- public UpdateModelInfoPlan(String modelName, ModelInformation modelInformation) {
- super(ConfigPhysicalPlanType.UpdateModelInfo);
- this.modelName = modelName;
- this.modelInformation = modelInformation;
- this.nodeIds = Collections.emptyList();
- }
-
- public UpdateModelInfoPlan(
- String modelName, ModelInformation modelInformation, List<Integer> nodeIds) {
- super(ConfigPhysicalPlanType.UpdateModelInfo);
- this.modelName = modelName;
- this.modelInformation = modelInformation;
- this.nodeIds = nodeIds;
- }
-
- public String getModelName() {
- return modelName;
- }
-
- public ModelInformation getModelInformation() {
- return modelInformation;
- }
-
- public List<Integer> getNodeIds() {
- return nodeIds;
- }
-
- public void setNodeIds(List<Integer> nodeIds) {
- this.nodeIds = nodeIds;
- }
-
- @Override
- protected void serializeImpl(DataOutputStream stream) throws IOException {
- stream.writeShort(getType().getPlanType());
- ReadWriteIOUtils.write(modelName, stream);
- this.modelInformation.serialize(stream);
- ReadWriteIOUtils.write(nodeIds.size(), stream);
- for (Integer nodeId : nodeIds) {
- ReadWriteIOUtils.write(nodeId, stream);
- }
- }
-
- @Override
- protected void deserializeImpl(ByteBuffer buffer) throws IOException {
- this.modelName = ReadWriteIOUtils.readString(buffer);
- this.modelInformation = ModelInformation.deserialize(buffer);
- int size = ReadWriteIOUtils.readInt(buffer);
- this.nodeIds = new ArrayList<>();
- for (int i = 0; i < size; i++) {
- this.nodeIds.add(ReadWriteIOUtils.readInt(buffer));
- }
- }
-
- @Override
- public boolean equals(Object o) {
- if (this == o) {
- return true;
- }
- if (o == null || getClass() != o.getClass()) {
- return false;
- }
- if (!super.equals(o)) {
- return false;
- }
- UpdateModelInfoPlan that = (UpdateModelInfoPlan) o;
- return modelName.equals(that.modelName)
- && modelInformation.equals(that.modelInformation)
- && nodeIds.equals(that.nodeIds);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(super.hashCode(), modelName, modelInformation, nodeIds);
- }
-}
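
The serializeImpl/deserializeImpl pair above follows the convention shared by all config physical plans: write the plan type, then each field in order, encoding lists as a size prefix followed by the elements, and read everything back in the same order. The list encoding reduces to this sketch (4-byte big-endian ints, matching ReadWriteIOUtils; the function names are mine):

    import struct

    def write_int_list(values):
        # size prefix, then each element, as 4-byte big-endian ints
        out = struct.pack(">i", len(values))
        for v in values:
            out += struct.pack(">i", v)
        return out

    def read_int_list(buf, offset=0):
        (size,) = struct.unpack_from(">i", buf, offset)
        offset += 4
        values = []
        for _ in range(size):
            (v,) = struct.unpack_from(">i", buf, offset)
            values.append(v)
            offset += 4
        return values, offset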
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java
deleted file mode 100644
index cebc1301b8912..0000000000000
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java
+++ /dev/null
@@ -1,63 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.confignode.consensus.response.model;
-
-import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
-import org.apache.iotdb.common.rpc.thrift.TEndPoint;
-import org.apache.iotdb.common.rpc.thrift.TSStatus;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
-import org.apache.iotdb.consensus.common.DataSet;
-
-public class GetModelInfoResp implements DataSet {
-
- private final TSStatus status;
-
- private int targetAINodeId;
- private TEndPoint targetAINodeAddress;
-
- public TSStatus getStatus() {
- return status;
- }
-
- public GetModelInfoResp(TSStatus status) {
- this.status = status;
- }
-
- public int getTargetAINodeId() {
- return targetAINodeId;
- }
-
- public void setTargetAINodeId(int targetAINodeId) {
- this.targetAINodeId = targetAINodeId;
- }
-
- public void setTargetAINodeAddress(TAINodeConfiguration aiNodeConfiguration) {
- if (aiNodeConfiguration.getLocation() == null) {
- return;
- }
- this.targetAINodeAddress = aiNodeConfiguration.getLocation().getInternalEndPoint();
- }
-
- public TGetModelInfoResp convertToThriftResponse() {
- TGetModelInfoResp resp = new TGetModelInfoResp(status);
- resp.setAiNodeAddress(targetAINodeAddress);
- return resp;
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java
deleted file mode 100644
index 7490a53a01c57..0000000000000
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java
+++ /dev/null
@@ -1,62 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.confignode.consensus.response.model;
-
-import org.apache.iotdb.common.rpc.thrift.TSStatus;
-import org.apache.iotdb.commons.model.ModelInformation;
-import org.apache.iotdb.consensus.common.DataSet;
-
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-
-// TODO: Will be removed in the future
-public class ModelTableResp implements DataSet {
-
- private final TSStatus status;
- private final List<ByteBuffer> serializedAllModelInformation;
- private Map modelTypeMap;
- private Map algorithmMap;
-
- public ModelTableResp(TSStatus status) {
- this.status = status;
- this.serializedAllModelInformation = new ArrayList<>();
- }
-
- public void addModelInformation(List<ModelInformation> modelInformationList) throws IOException {
- for (ModelInformation modelInformation : modelInformationList) {
- this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult());
- }
- }
-
- public void addModelInformation(ModelInformation modelInformation) throws IOException {
- this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult());
- }
-
- public void setModelTypeMap(Map modelTypeMap) {
- this.modelTypeMap = modelTypeMap;
- }
-
- public void setAlgorithmMap(Map algorithmMap) {
- this.algorithmMap = algorithmMap;
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
index 5d4b09adfc710..9d7151a8d20e3 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java
@@ -19,22 +19,12 @@
package org.apache.iotdb.confignode.manager;
-import org.apache.iotdb.ainode.rpc.thrift.IDataSchema;
-import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq;
-import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp;
-import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq;
-import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp;
-import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
-import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
-import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq;
-import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq;
import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
import org.apache.iotdb.common.rpc.thrift.TAINodeLocation;
import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation;
import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId;
import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration;
import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation;
-import org.apache.iotdb.common.rpc.thrift.TEndPoint;
import org.apache.iotdb.common.rpc.thrift.TFlushReq;
import org.apache.iotdb.common.rpc.thrift.TPipeHeartbeatResp;
import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet;
@@ -58,7 +48,6 @@
import org.apache.iotdb.commons.conf.TrimProperties;
import org.apache.iotdb.commons.exception.IllegalPathException;
import org.apache.iotdb.commons.exception.MetadataException;
-import org.apache.iotdb.commons.model.ModelStatus;
import org.apache.iotdb.commons.path.PartialPath;
import org.apache.iotdb.commons.path.PathPatternTree;
import org.apache.iotdb.commons.path.PathPatternUtil;
@@ -97,7 +86,6 @@
import org.apache.iotdb.confignode.consensus.request.write.database.SetTTLPlan;
import org.apache.iotdb.confignode.consensus.request.write.database.SetTimePartitionIntervalPlan;
import org.apache.iotdb.confignode.consensus.request.write.datanode.RemoveDataNodePlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
import org.apache.iotdb.confignode.consensus.request.write.template.CreateSchemaTemplatePlan;
import org.apache.iotdb.confignode.consensus.response.ainode.AINodeRegisterResp;
import org.apache.iotdb.confignode.consensus.response.auth.PermissionInfoResp;
@@ -129,7 +117,6 @@
import org.apache.iotdb.confignode.manager.schema.ClusterSchemaQuotaStatistics;
import org.apache.iotdb.confignode.manager.subscription.SubscriptionManager;
import org.apache.iotdb.confignode.persistence.ClusterInfo;
-import org.apache.iotdb.confignode.persistence.ModelInfo;
import org.apache.iotdb.confignode.persistence.ProcedureInfo;
import org.apache.iotdb.confignode.persistence.TTLInfo;
import org.apache.iotdb.confignode.persistence.TriggerInfo;
@@ -144,7 +131,6 @@
import org.apache.iotdb.confignode.persistence.schema.ClusterSchemaInfo;
import org.apache.iotdb.confignode.persistence.subscription.SubscriptionInfo;
import org.apache.iotdb.confignode.procedure.impl.schema.SchemaUtils;
-import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo;
import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq;
import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq;
import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp;
@@ -163,13 +149,11 @@
import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq;
-import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateTableViewReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateTopicReq;
-import org.apache.iotdb.confignode.rpc.thrift.TCreateTrainingReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateTriggerReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRegisterReq;
import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRestartReq;
@@ -186,7 +170,6 @@
import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp;
import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq;
-import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq;
@@ -203,8 +186,6 @@
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp;
@@ -257,11 +238,8 @@
import org.apache.iotdb.confignode.rpc.thrift.TTimeSlotList;
import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq;
import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq;
-import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq;
import org.apache.iotdb.consensus.common.DataSet;
import org.apache.iotdb.consensus.exception.ConsensusException;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClient;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager;
import org.apache.iotdb.db.schemaengine.template.Template;
import org.apache.iotdb.db.schemaengine.template.TemplateAlterOperationType;
import org.apache.iotdb.db.schemaengine.template.alter.TemplateAlterOperationUtil;
@@ -340,9 +318,6 @@ public class ConfigManager implements IManager {
/** CQ. */
private final CQManager cqManager;
- /** AI Model. */
- private final ModelManager modelManager;
-
/** Pipe */
private final PipeManager pipeManager;
@@ -362,8 +337,6 @@ public class ConfigManager implements IManager {
private static final String DATABASE = "\tDatabase=";
- private static final String DOT = ".";
-
public ConfigManager() throws IOException {
// Build the persistence module
ClusterInfo clusterInfo = new ClusterInfo();
@@ -375,7 +348,6 @@ public ConfigManager() throws IOException {
UDFInfo udfInfo = new UDFInfo();
TriggerInfo triggerInfo = new TriggerInfo();
CQInfo cqInfo = new CQInfo();
- ModelInfo modelInfo = new ModelInfo();
PipeInfo pipeInfo = new PipeInfo();
QuotaInfo quotaInfo = new QuotaInfo();
TTLInfo ttlInfo = new TTLInfo();
@@ -393,7 +365,6 @@ public ConfigManager() throws IOException {
udfInfo,
triggerInfo,
cqInfo,
- modelInfo,
pipeInfo,
subscriptionInfo,
quotaInfo,
@@ -415,7 +386,6 @@ public ConfigManager() throws IOException {
this.udfManager = new UDFManager(this, udfInfo);
this.triggerManager = new TriggerManager(this, triggerInfo);
this.cqManager = new CQManager(this);
- this.modelManager = new ModelManager(this, modelInfo);
this.pipeManager = new PipeManager(this, pipeInfo);
this.subscriptionManager = new SubscriptionManager(this, subscriptionInfo);
this.auditLogger = new CNAuditLogger(this);
@@ -1289,11 +1259,6 @@ public TriggerManager getTriggerManager() {
return triggerManager;
}
- @Override
- public ModelManager getModelManager() {
- return modelManager;
- }
-
@Override
public PipeManager getPipeManager() {
return pipeManager;
@@ -2757,150 +2722,6 @@ public TSStatus transfer(List<TDataNodeLocation> newUnknownDataList) {
return transferResult;
}
- @Override
- public TSStatus createModel(TCreateModelReq req) {
- TSStatus status = confirmLeader();
- if (nodeManager.getRegisteredAINodes().isEmpty()) {
- return new TSStatus(TSStatusCode.NO_REGISTERED_AI_NODE_ERROR.getStatusCode())
- .setMessage("There is no available AINode! Try to start one.");
- }
- return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
- ? modelManager.createModel(req)
- : status;
- }
-
- private List<IDataSchema> fetchSchemaForTreeModel(TCreateTrainingReq req) {
- List<IDataSchema> dataSchemaList = new ArrayList<>();
- for (int i = 0; i < req.getDataSchemaForTree().getPathSize(); i++) {
- IDataSchema dataSchema = new IDataSchema(req.getDataSchemaForTree().getPath().get(i));
- dataSchema.setTimeRange(req.getTimeRanges().get(i));
- dataSchemaList.add(dataSchema);
- }
- return dataSchemaList;
- }
-
- private List<IDataSchema> fetchSchemaForTableModel(TCreateTrainingReq req) {
- return Collections.singletonList(new IDataSchema(req.getDataSchemaForTable().getTargetSql()));
- }
-
- public TSStatus createTraining(TCreateTrainingReq req) {
- TSStatus status = confirmLeader();
- if (nodeManager.getRegisteredAINodes().isEmpty()) {
- return new TSStatus(TSStatusCode.NO_REGISTERED_AI_NODE_ERROR.getStatusCode())
- .setMessage("There is no available AINode! Try to start one.");
- }
-
- TTrainingReq trainingReq = new TTrainingReq();
- trainingReq.setModelId(req.getModelId());
- if (req.isSetExistingModelId()) {
- trainingReq.setExistingModelId(req.getExistingModelId());
- }
- if (req.isSetParameters() && !req.getParameters().isEmpty()) {
- trainingReq.setParameters(req.getParameters());
- }
-
- try {
- status = getConsensusManager().write(new CreateModelPlan(req.getModelId()));
- if (status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- throw new MetadataException("Can't init model " + req.getModelId());
- }
-
- List<IDataSchema> dataSchema;
- if (req.isTableModel) {
- dataSchema = fetchSchemaForTableModel(req);
- trainingReq.setDbType("iotdb.table");
- } else {
- dataSchema = fetchSchemaForTreeModel(req);
- trainingReq.setDbType("iotdb.tree");
- }
- updateModelInfo(new TUpdateModelInfoReq(req.modelId, ModelStatus.TRAINING.ordinal()));
- trainingReq.setTargetDataSchema(dataSchema);
-
- TAINodeInfo registeredAINode = getNodeManager().getRegisteredAINodeInfoList().get(0);
- TEndPoint targetAINodeEndPoint =
- new TEndPoint(registeredAINode.getInternalAddress(), registeredAINode.getInternalPort());
- try (AINodeClient client =
- AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) {
- status = client.createTrainingTask(trainingReq);
- if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- throw new IllegalArgumentException(status.message);
- }
- }
- } catch (final Exception e) {
- status.setCode(TSStatusCode.CAN_NOT_CONNECT_CONFIGNODE.getStatusCode());
- status.setMessage(e.getMessage());
- try {
- updateModelInfo(new TUpdateModelInfoReq(req.modelId, ModelStatus.UNAVAILABLE.ordinal()));
- } catch (Exception e2) {
- LOGGER.error(e2.getMessage());
- }
- }
- return status;
- }
-
- @Override
- public TSStatus dropModel(TDropModelReq req) {
- TSStatus status = confirmLeader();
- return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
- ? modelManager.dropModel(req)
- : status;
- }
-
- @Override
- public TSStatus loadModel(TLoadModelReq req) {
- TSStatus status = confirmLeader();
- return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
- ? modelManager.loadModel(req)
- : status;
- }
-
- @Override
- public TSStatus unloadModel(TUnloadModelReq req) {
- TSStatus status = confirmLeader();
- return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
- ? modelManager.unloadModel(req)
- : status;
- }
-
- @Override
- public TShowModelsResp showModel(TShowModelsReq req) {
- TSStatus status = confirmLeader();
- return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
- ? modelManager.showModel(req)
- : new TShowModelsResp(status);
- }
-
- @Override
- public TShowLoadedModelsResp showLoadedModel(TShowLoadedModelsReq req) {
- TSStatus status = confirmLeader();
- return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
- ? modelManager.showLoadedModel(req)
- : new TShowLoadedModelsResp(status, Collections.emptyMap());
- }
-
- @Override
- public TShowAIDevicesResp showAIDevices() {
- TSStatus status = confirmLeader();
- return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
- ? modelManager.showAIDevices()
- : new TShowAIDevicesResp(status, Collections.emptyList());
- }
-
- @Override
- public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) {
- TSStatus status = confirmLeader();
- return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
- ? modelManager.getModelInfo(req)
- : new TGetModelInfoResp(status);
- }
-
- public TSStatus updateModelInfo(TUpdateModelInfoReq req) {
- TSStatus status = confirmLeader();
- return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()
- ? modelManager.updateModelInfo(req)
- : status;
- }
-
@Override
public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) {
TSStatus status = confirmLeader();
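
Every removed model entry point followed the same gate as the surviving ConfigManager methods: confirm this node is still the ConfigNode leader, and only delegate on success; otherwise return the non-success (redirect) status unchanged. As a standalone sketch (confirm_leader and delegate are stand-ins, and the status is reduced to an int):

    SUCCESS_STATUS = 200  # illustrative code

    def gated(confirm_leader, delegate):
        # Delegate only when this node is the leader; otherwise pass the
        # redirect status straight back to the caller.
        def call(req):
            status = confirm_leader()
            return delegate(req) if status == SUCCESS_STATUS else status
        return call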
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java
index 33e77db24907d..dff994d70e7e3 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java
@@ -19,13 +19,6 @@
package org.apache.iotdb.confignode.manager;
-import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq;
-import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp;
-import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq;
-import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp;
-import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
-import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
-import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq;
import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation;
import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId;
import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation;
@@ -82,7 +75,6 @@
import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq;
-import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq;
@@ -103,7 +95,6 @@
import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp;
import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq;
-import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq;
@@ -120,8 +111,6 @@
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp;
@@ -255,13 +244,6 @@ public interface IManager {
*/
CQManager getCQManager();
- /**
- * Get {@link ModelManager}.
- *
- * @return {@link ModelManager} instance
- */
- ModelManager getModelManager();
-
/**
* Get {@link PipeManager}.
*
@@ -880,30 +862,6 @@ TDataPartitionTableResp getOrCreateDataPartition(
 TSStatus transfer(List<TDataNodeLocation> newUnknownDataList);
- /** Create a model. */
- TSStatus createModel(TCreateModelReq req);
-
- /** Drop a model. */
- TSStatus dropModel(TDropModelReq req);
-
- /** Load the specific model to the specific devices. */
- TSStatus loadModel(TLoadModelReq req);
-
- /** Unload the specific model from the specific devices. */
- TSStatus unloadModel(TUnloadModelReq req);
-
- /** Return the model table. */
- TShowModelsResp showModel(TShowModelsReq req);
-
- /** Return the loaded model instances. */
- TShowLoadedModelsResp showLoadedModel(TShowLoadedModelsReq req);
-
- /** Return all available AI devices. */
- TShowAIDevicesResp showAIDevices();
-
- /** Update the model state */
- TGetModelInfoResp getModelInfo(TGetModelInfoReq req);
-
/** Set space quota. */
TSStatus setSpaceQuota(TSetSpaceQuotaReq req);
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
deleted file mode 100644
index 3efdbc222b6d2..0000000000000
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java
+++ /dev/null
@@ -1,245 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.confignode.manager;
-
-import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq;
-import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp;
-import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq;
-import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp;
-import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq;
-import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp;
-import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq;
-import org.apache.iotdb.common.rpc.thrift.TEndPoint;
-import org.apache.iotdb.common.rpc.thrift.TSStatus;
-import org.apache.iotdb.commons.client.exception.ClientManagerException;
-import org.apache.iotdb.commons.model.ModelInformation;
-import org.apache.iotdb.commons.model.ModelStatus;
-import org.apache.iotdb.commons.model.ModelType;
-import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
-import org.apache.iotdb.confignode.exception.NoAvailableAINodeException;
-import org.apache.iotdb.confignode.persistence.ModelInfo;
-import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo;
-import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq;
-import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
-import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq;
-import org.apache.iotdb.consensus.exception.ConsensusException;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClient;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager;
-import org.apache.iotdb.rpc.TSStatusCode;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.List;
-
-public class ModelManager {
-
- private static final Logger LOGGER = LoggerFactory.getLogger(ModelManager.class);
-
- private final ConfigManager configManager;
- private final ModelInfo modelInfo;
-
- public ModelManager(ConfigManager configManager, ModelInfo modelInfo) {
- this.configManager = configManager;
- this.modelInfo = modelInfo;
- }
-
- public TSStatus createModel(TCreateModelReq req) {
- if (modelInfo.contain(req.modelName)) {
- return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode())
- .setMessage(String.format("Model name %s already exists", req.modelName));
- }
- try {
- if (req.uri.isEmpty()) {
- return configManager.getConsensusManager().write(new CreateModelPlan(req.modelName));
- }
- return configManager.getProcedureManager().createModel(req.modelName, req.uri);
- } catch (ConsensusException e) {
- LOGGER.warn("Unexpected error happened while getting model: ", e);
- // consensus layer related errors
- TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode());
- res.setMessage(e.getMessage());
- return res;
- }
- }
-
- public TSStatus dropModel(TDropModelReq req) {
- if (modelInfo.checkModelType(req.getModelId()) != ModelType.USER_DEFINED) {
- return new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode())
- .setMessage(String.format("Built-in model %s can't be removed", req.modelId));
- }
- if (!modelInfo.contain(req.modelId)) {
- return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode())
- .setMessage(String.format("Model name %s doesn't exists", req.modelId));
- }
- return configManager.getProcedureManager().dropModel(req.getModelId());
- }
-
- public TSStatus loadModel(TLoadModelReq req) {
- try (AINodeClient client = getAINodeClient()) {
- org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq loadModelReq =
- new org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq(
- req.existingModelId, req.deviceIdList);
- return client.loadModel(loadModelReq);
- } catch (Exception e) {
- LOGGER.warn("Failed to load model due to", e);
- return new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode())
- .setMessage(e.getMessage());
- }
- }
-
- public TSStatus unloadModel(TUnloadModelReq req) {
- try (AINodeClient client = getAINodeClient()) {
- org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq unloadModelReq =
- new org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq(req.modelId, req.deviceIdList);
- return client.unloadModel(unloadModelReq);
- } catch (Exception e) {
- LOGGER.warn("Failed to unload model due to", e);
- return new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode())
- .setMessage(e.getMessage());
- }
- }
-
- public TShowModelsResp showModel(final TShowModelsReq req) {
- try (AINodeClient client = getAINodeClient()) {
- TShowModelsReq showModelsReq = new TShowModelsReq();
- if (req.isSetModelId()) {
- showModelsReq.setModelId(req.getModelId());
- }
- TShowModelsResp resp = client.showModels(showModelsReq);
- TShowModelsResp res =
- new TShowModelsResp()
- .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
- res.setModelIdList(resp.getModelIdList());
- res.setModelTypeMap(resp.getModelTypeMap());
- res.setCategoryMap(resp.getCategoryMap());
- res.setStateMap(resp.getStateMap());
- return res;
- } catch (Exception e) {
- LOGGER.warn("Failed to show models due to", e);
- return new TShowModelsResp()
- .setStatus(
- new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode())
- .setMessage(e.getMessage()));
- }
- }
-
- public TShowLoadedModelsResp showLoadedModel(final TShowLoadedModelsReq req) {
- try (AINodeClient client = getAINodeClient()) {
- TShowLoadedModelsReq showModelsReq =
- new TShowLoadedModelsReq().setDeviceIdList(req.getDeviceIdList());
- TShowLoadedModelsResp resp = client.showLoadedModels(showModelsReq);
- TShowLoadedModelsResp res =
- new TShowLoadedModelsResp()
- .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
- res.setDeviceLoadedModelsMap(resp.getDeviceLoadedModelsMap());
- return res;
- } catch (Exception e) {
- LOGGER.warn("Failed to show loaded models due to", e);
- return new TShowLoadedModelsResp()
- .setStatus(
- new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode())
- .setMessage(e.getMessage()));
- }
- }
-
- public TShowAIDevicesResp showAIDevices() {
- try (AINodeClient client = getAINodeClient()) {
- org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp resp = client.showAIDevices();
- TShowAIDevicesResp res =
- new TShowAIDevicesResp()
- .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
- res.setDeviceIdList(resp.getDeviceIdList());
- return res;
- } catch (Exception e) {
- LOGGER.warn("Failed to show AI devices due to", e);
- return new TShowAIDevicesResp()
- .setStatus(
- new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode())
- .setMessage(e.getMessage()));
- }
- }
-
- public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) {
- return new TGetModelInfoResp()
- .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()))
- .setAiNodeAddress(
- configManager
- .getNodeManager()
- .getRegisteredAINodes()
- .get(0)
- .getLocation()
- .getInternalEndPoint());
- }
-
- // Currently this method is only used by built-in timer_xl
- public TSStatus updateModelInfo(TUpdateModelInfoReq req) {
- if (!modelInfo.contain(req.getModelId())) {
- return new TSStatus(TSStatusCode.MODEL_NOT_FOUND_ERROR.getStatusCode())
- .setMessage(String.format("Model %s doesn't exists", req.getModelId()));
- }
- try {
- ModelInformation modelInformation =
- new ModelInformation(ModelType.USER_DEFINED, req.getModelId());
- modelInformation.updateStatus(ModelStatus.values()[req.getModelStatus()]);
- modelInformation.setAttribute(req.getAttributes());
- modelInformation.setInputColumnSize(1);
- if (req.isSetOutputLength()) {
- modelInformation.setOutputLength(req.getOutputLength());
- }
- if (req.isSetInputLength()) {
- modelInformation.setInputLength(req.getInputLength());
- }
- UpdateModelInfoPlan updateModelInfoPlan =
- new UpdateModelInfoPlan(req.getModelId(), modelInformation);
- if (req.isSetAiNodeIds()) {
- updateModelInfoPlan.setNodeIds(req.getAiNodeIds());
- }
- return configManager.getConsensusManager().write(updateModelInfoPlan);
- } catch (ConsensusException e) {
- LOGGER.warn("Unexpected error happened while updating model info: ", e);
- // consensus layer related errors
- TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode());
- res.setMessage(e.getMessage());
- return res;
- }
- }
-
- private AINodeClient getAINodeClient() throws NoAvailableAINodeException, ClientManagerException {
- List<TAINodeInfo> aiNodeInfo = configManager.getNodeManager().getRegisteredAINodeInfoList();
- if (aiNodeInfo.isEmpty()) {
- throw new NoAvailableAINodeException();
- }
- TEndPoint targetAINodeEndPoint =
- new TEndPoint(aiNodeInfo.get(0).getInternalAddress(), aiNodeInfo.get(0).getInternalPort());
- try {
- return AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
-
- public List<Integer> getModelDistributions(String modelName) {
- return modelInfo.getNodeIds(modelName);
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java
index 2e4227af3fc8b..d67e7721eef83 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java
@@ -61,8 +61,6 @@
import org.apache.iotdb.confignode.procedure.env.RegionMaintainHandler;
import org.apache.iotdb.confignode.procedure.env.RemoveDataNodeHandler;
import org.apache.iotdb.confignode.procedure.impl.cq.CreateCQProcedure;
-import org.apache.iotdb.confignode.procedure.impl.model.CreateModelProcedure;
-import org.apache.iotdb.confignode.procedure.impl.model.DropModelProcedure;
import org.apache.iotdb.confignode.procedure.impl.node.AddConfigNodeProcedure;
import org.apache.iotdb.confignode.procedure.impl.node.RemoveAINodeProcedure;
import org.apache.iotdb.confignode.procedure.impl.node.RemoveConfigNodeProcedure;
@@ -1414,24 +1412,6 @@ public TSStatus createCQ(TCreateCQReq req, ScheduledExecutorService scheduledExe
return waitingProcedureFinished(procedure);
}
- public TSStatus createModel(String modelName, String uri) {
- long procedureId = executor.submitProcedure(new CreateModelProcedure(modelName, uri));
- LOGGER.info("CreateModelProcedure was submitted, procedureId: {}.", procedureId);
- return RpcUtils.SUCCESS_STATUS;
- }
-
- public TSStatus dropModel(String modelId) {
- DropModelProcedure procedure = new DropModelProcedure(modelId);
- executor.submitProcedure(procedure);
- TSStatus status = waitingProcedureFinished(procedure);
- if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- return status;
- } else {
- return new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode())
- .setMessage(status.getMessage());
- }
- }
-
public TSStatus createPipePlugin(
PipePluginMeta pipePluginMeta, byte[] jarFile, boolean isSetIfNotExistsCondition) {
final CreatePipePluginProcedure createPipePluginProcedure =
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
deleted file mode 100644
index aeada03d15cc3..0000000000000
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java
+++ /dev/null
@@ -1,378 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.confignode.persistence;
-
-import org.apache.iotdb.common.rpc.thrift.TSStatus;
-import org.apache.iotdb.commons.model.ModelInformation;
-import org.apache.iotdb.commons.model.ModelStatus;
-import org.apache.iotdb.commons.model.ModelTable;
-import org.apache.iotdb.commons.model.ModelType;
-import org.apache.iotdb.commons.snapshot.SnapshotProcessor;
-import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
-import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
-import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp;
-import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp;
-import org.apache.iotdb.rpc.TSStatusCode;
-
-import org.apache.thrift.TException;
-import org.apache.tsfile.utils.PublicBAOS;
-import org.apache.tsfile.utils.ReadWriteIOUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import javax.annotation.concurrent.ThreadSafe;
-
-import java.io.DataOutputStream;
-import java.io.File;
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.concurrent.locks.ReadWriteLock;
-import java.util.concurrent.locks.ReentrantReadWriteLock;
-
-@ThreadSafe
-public class ModelInfo implements SnapshotProcessor {
-
- private static final Logger LOGGER = LoggerFactory.getLogger(ModelInfo.class);
-
- private static final String SNAPSHOT_FILENAME = "model_info.snapshot";
-
- private ModelTable modelTable;
-
- private final Map<String, List<Integer>> modelNameToNodes;
-
- private final ReadWriteLock modelTableLock = new ReentrantReadWriteLock();
-
- private static final Set<String> builtInForecastModel = new HashSet<>();
-
- private static final Set<String> builtInAnomalyDetectionModel = new HashSet<>();
-
- static {
- builtInForecastModel.add("arima");
- builtInForecastModel.add("naive_forecaster");
- builtInForecastModel.add("stl_forecaster");
- builtInForecastModel.add("holtwinters");
- builtInForecastModel.add("exponential_smoothing");
- builtInForecastModel.add("timer_xl");
- builtInForecastModel.add("sundial");
- builtInAnomalyDetectionModel.add("gaussian_hmm");
- builtInAnomalyDetectionModel.add("gmm_hmm");
- builtInAnomalyDetectionModel.add("stray");
- }
-
- public ModelInfo() {
- this.modelTable = new ModelTable();
- this.modelNameToNodes = new HashMap<>();
- }
-
- public boolean contain(String modelName) {
- return modelTable.containsModel(modelName);
- }
-
- public void acquireModelTableReadLock() {
- LOGGER.info("acquire ModelTableReadLock");
- modelTableLock.readLock().lock();
- }
-
- public void releaseModelTableReadLock() {
- LOGGER.info("release ModelTableReadLock");
- modelTableLock.readLock().unlock();
- }
-
- public void acquireModelTableWriteLock() {
- LOGGER.info("acquire ModelTableWriteLock");
- modelTableLock.writeLock().lock();
- }
-
- public void releaseModelTableWriteLock() {
- LOGGER.info("release ModelTableWriteLock");
- modelTableLock.writeLock().unlock();
- }
-
- // Init the model in modelInfo; this won't update the detailed information of the model
- public TSStatus createModel(CreateModelPlan plan) {
- try {
- acquireModelTableWriteLock();
- String modelName = plan.getModelName();
- modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING));
- return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
- } catch (Exception e) {
- final String errorMessage =
- String.format(
- "Failed to add model [%s] in ModelTable on Config Nodes, because of %s",
- plan.getModelName(), e);
- LOGGER.warn(errorMessage, e);
- return new TSStatus(TSStatusCode.CREATE_MODEL_ERROR.getStatusCode()).setMessage(errorMessage);
- } finally {
- releaseModelTableWriteLock();
- }
- }
-
- public TSStatus dropModelInNode(int aiNodeId) {
- acquireModelTableWriteLock();
- try {
- for (Map.Entry<String, List<Integer>> entry : modelNameToNodes.entrySet()) {
- entry.getValue().remove(Integer.valueOf(aiNodeId));
- // if list is empty, remove this model totally
- if (entry.getValue().isEmpty()) {
- modelTable.removeModel(entry.getKey());
- modelNameToNodes.remove(entry.getKey());
- }
- }
- // currently, we only have one AINode at a time, so we can just clear failed model.
- modelTable.clearFailedModel();
- return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
- } finally {
- releaseModelTableWriteLock();
- }
- }
-
- public TSStatus dropModel(String modelName) {
- acquireModelTableWriteLock();
- TSStatus status;
- if (modelTable.containsModel(modelName)) {
- modelTable.removeModel(modelName);
- modelNameToNodes.remove(modelName);
- status = new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
- } else {
- status =
- new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode())
- .setMessage(String.format("model [%s] has not been created.", modelName));
- }
- releaseModelTableWriteLock();
- return status;
- }
-
- public List<Integer> getNodeIds(String modelName) {
- return modelNameToNodes.getOrDefault(modelName, Collections.emptyList());
- }
-
- private ModelInformation getModelByName(String modelName) {
- ModelType modelType = checkModelType(modelName);
- if (modelType != ModelType.USER_DEFINED) {
- if (modelType == ModelType.BUILT_IN_FORECAST && builtInForecastModel.contains(modelName)) {
- return new ModelInformation(ModelType.BUILT_IN_FORECAST, modelName);
- } else if (modelType == ModelType.BUILT_IN_ANOMALY_DETECTION
- && builtInAnomalyDetectionModel.contains(modelName)) {
- return new ModelInformation(ModelType.BUILT_IN_ANOMALY_DETECTION, modelName);
- }
- } else {
- return modelTable.getModelInformationById(modelName);
- }
- return null;
- }
-
- public ModelTableResp showModel(ShowModelPlan plan) {
- acquireModelTableReadLock();
- try {
- ModelTableResp modelTableResp =
- new ModelTableResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
- if (plan.isSetModelName()) {
- ModelInformation modelInformation = getModelByName(plan.getModelName());
- if (modelInformation != null) {
- modelTableResp.addModelInformation(modelInformation);
- }
- } else {
- modelTableResp.addModelInformation(modelTable.getAllModelInformation());
- for (String modelName : builtInForecastModel) {
- modelTableResp.addModelInformation(
- new ModelInformation(ModelType.BUILT_IN_FORECAST, modelName));
- }
- for (String modelName : builtInAnomalyDetectionModel) {
- modelTableResp.addModelInformation(
- new ModelInformation(ModelType.BUILT_IN_ANOMALY_DETECTION, modelName));
- }
- }
- return modelTableResp;
- } catch (IOException e) {
- LOGGER.warn("Fail to get ModelTable", e);
- return new ModelTableResp(
- new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
- .setMessage(e.getMessage()));
- } finally {
- releaseModelTableReadLock();
- }
- }
-
- private boolean containsBuiltInModelName(Set<String> builtInModelSet, String modelName) {
- // ignore the case
- for (String builtInModelName : builtInModelSet) {
- if (builtInModelName.equalsIgnoreCase(modelName)) {
- return true;
- }
- }
- return false;
- }
-
- public ModelType checkModelType(String modelName) {
- if (containsBuiltInModelName(builtInForecastModel, modelName)) {
- return ModelType.BUILT_IN_FORECAST;
- } else if (containsBuiltInModelName(builtInAnomalyDetectionModel, modelName)) {
- return ModelType.BUILT_IN_ANOMALY_DETECTION;
- } else {
- return ModelType.USER_DEFINED;
- }
- }
-
- private int getAvailableAINodeForModel(String modelName, ModelType modelType) {
- if (modelType == ModelType.USER_DEFINED) {
- List<Integer> aiNodeIds = modelNameToNodes.get(modelName);
- if (aiNodeIds != null) {
- return aiNodeIds.get(0);
- }
- } else {
- // any AINode is fine for a built-in model
- // 0 is always the nodeId of the ConfigNode, so it's safe to use 0 as a special value
- return 0;
- }
- return -1;
- }
-
- // This method is used by DataNodes to get the schema of the model for inference
- public GetModelInfoResp getModelInfo(GetModelInfoPlan plan) {
- acquireModelTableReadLock();
- try {
- String modelName = plan.getModelId();
- GetModelInfoResp getModelInfoResp;
- ModelInformation modelInformation;
- ModelType modelType;
- // check if it's a built-in model
- if ((modelType = checkModelType(modelName)) != ModelType.USER_DEFINED) {
- modelInformation = new ModelInformation(modelType, modelName);
- } else {
- modelInformation = modelTable.getModelInformationById(modelName);
- }
-
- if (modelInformation != null) {
- getModelInfoResp =
- new GetModelInfoResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()));
- } else {
- TSStatus errorStatus = new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
- errorStatus.setMessage(String.format("model [%s] has not been created.", modelName));
- getModelInfoResp = new GetModelInfoResp(errorStatus);
- return getModelInfoResp;
- }
- PublicBAOS buffer = new PublicBAOS();
- DataOutputStream stream = new DataOutputStream(buffer);
- modelInformation.serialize(stream);
- // Select the nodeId to process the task; currently we default to the first one.
- int aiNodeId = getAvailableAINodeForModel(modelName, modelType);
- if (aiNodeId == -1) {
- TSStatus errorStatus = new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode());
- errorStatus.setMessage(String.format("There is no AINode with %s available", modelName));
- getModelInfoResp = new GetModelInfoResp(errorStatus);
- return getModelInfoResp;
- } else {
- getModelInfoResp.setTargetAINodeId(aiNodeId);
- }
- return getModelInfoResp;
- } catch (IOException e) {
- LOGGER.warn("Fail to get model info", e);
- return new GetModelInfoResp(
- new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode())
- .setMessage(e.getMessage()));
- } finally {
- releaseModelTableReadLock();
- }
- }
-
- public TSStatus updateModelInfo(UpdateModelInfoPlan plan) {
- acquireModelTableWriteLock();
- try {
- String modelName = plan.getModelName();
- if (modelTable.containsModel(modelName)) {
- modelTable.updateModel(modelName, plan.getModelInformation());
- }
- if (!plan.getNodeIds().isEmpty()) {
- // only used in model registration, so we can just put the nodeIds in the map without
- // checking
- modelNameToNodes.put(modelName, plan.getNodeIds());
- }
- return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode());
- } finally {
- releaseModelTableWriteLock();
- }
- }
-
- @Override
- public boolean processTakeSnapshot(File snapshotDir) throws TException, IOException {
- File snapshotFile = new File(snapshotDir, SNAPSHOT_FILENAME);
- if (snapshotFile.exists() && snapshotFile.isFile()) {
- LOGGER.error(
- "Failed to take snapshot of ModelInfo, because snapshot file [{}] is already exist.",
- snapshotFile.getAbsolutePath());
- return false;
- }
-
- acquireModelTableReadLock();
- try (FileOutputStream fileOutputStream = new FileOutputStream(snapshotFile)) {
- modelTable.serialize(fileOutputStream);
- ReadWriteIOUtils.write(modelNameToNodes.size(), fileOutputStream);
- for (Map.Entry<String, List<Integer>> entry : modelNameToNodes.entrySet()) {
- ReadWriteIOUtils.write(entry.getKey(), fileOutputStream);
- ReadWriteIOUtils.write(entry.getValue().size(), fileOutputStream);
- for (Integer nodeId : entry.getValue()) {
- ReadWriteIOUtils.write(nodeId, fileOutputStream);
- }
- }
- fileOutputStream.getFD().sync();
- return true;
- } finally {
- releaseModelTableReadLock();
- }
- }
-
- @Override
- public void processLoadSnapshot(File snapshotDir) throws TException, IOException {
- File snapshotFile = new File(snapshotDir, SNAPSHOT_FILENAME);
- if (!snapshotFile.exists() || !snapshotFile.isFile()) {
- LOGGER.error(
- "Failed to load snapshot of ModelInfo, snapshot file [{}] does not exist.",
- snapshotFile.getAbsolutePath());
- return;
- }
- acquireModelTableWriteLock();
- try (FileInputStream fileInputStream = new FileInputStream(snapshotFile)) {
- modelTable.clear();
- modelTable = ModelTable.deserialize(fileInputStream);
- int size = ReadWriteIOUtils.readInt(fileInputStream);
- for (int i = 0; i < size; i++) {
- String modelName = ReadWriteIOUtils.readString(fileInputStream);
- int nodeSize = ReadWriteIOUtils.readInt(fileInputStream);
- List<Integer> nodes = new LinkedList<>();
- for (int j = 0; j < nodeSize; j++) {
- nodes.add(ReadWriteIOUtils.readInt(fileInputStream));
- }
- modelNameToNodes.put(modelName, nodes);
- }
- } finally {
- releaseModelTableWriteLock();
- }
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
index fe8b28c4da2e5..d6bad518f6f4b 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java
@@ -35,8 +35,6 @@
import org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan;
import org.apache.iotdb.confignode.consensus.request.read.function.GetFunctionTablePlan;
import org.apache.iotdb.confignode.consensus.request.read.function.GetUDFJarPlan;
-import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan;
-import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan;
import org.apache.iotdb.confignode.consensus.request.read.partition.CountTimeSlotListPlan;
import org.apache.iotdb.confignode.consensus.request.read.partition.GetDataPartitionPlan;
import org.apache.iotdb.confignode.consensus.request.read.partition.GetNodePathsPartitionPlan;
@@ -84,10 +82,6 @@
import org.apache.iotdb.confignode.consensus.request.write.function.DropTableModelFunctionPlan;
import org.apache.iotdb.confignode.consensus.request.write.function.DropTreeModelFunctionPlan;
import org.apache.iotdb.confignode.consensus.request.write.function.UpdateFunctionPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
import org.apache.iotdb.confignode.consensus.request.write.partition.AddRegionLocationPlan;
import org.apache.iotdb.confignode.consensus.request.write.partition.AutoCleanPartitionTablePlan;
import org.apache.iotdb.confignode.consensus.request.write.partition.CreateDataPartitionPlan;
@@ -150,7 +144,6 @@
import org.apache.iotdb.confignode.exception.physical.UnknownPhysicalPlanTypeException;
import org.apache.iotdb.confignode.manager.pipe.agent.PipeConfigNodeAgent;
import org.apache.iotdb.confignode.persistence.ClusterInfo;
-import org.apache.iotdb.confignode.persistence.ModelInfo;
import org.apache.iotdb.confignode.persistence.ProcedureInfo;
import org.apache.iotdb.confignode.persistence.TTLInfo;
import org.apache.iotdb.confignode.persistence.TriggerInfo;
@@ -210,8 +203,6 @@ public class ConfigPlanExecutor {
private final CQInfo cqInfo;
- private final ModelInfo modelInfo;
-
private final PipeInfo pipeInfo;
private final SubscriptionInfo subscriptionInfo;
@@ -230,7 +221,6 @@ public ConfigPlanExecutor(
UDFInfo udfInfo,
TriggerInfo triggerInfo,
CQInfo cqInfo,
- ModelInfo modelInfo,
PipeInfo pipeInfo,
SubscriptionInfo subscriptionInfo,
QuotaInfo quotaInfo,
@@ -262,9 +252,6 @@ public ConfigPlanExecutor(
this.cqInfo = cqInfo;
this.snapshotProcessorList.add(cqInfo);
- this.modelInfo = modelInfo;
- this.snapshotProcessorList.add(modelInfo);
-
this.pipeInfo = pipeInfo;
this.snapshotProcessorList.add(pipeInfo);
@@ -362,10 +349,6 @@ public DataSet executeQueryPlan(final ConfigPhysicalReadPlan req)
return udfInfo.getUDFJar((GetUDFJarPlan) req);
case GetAllFunctionTable:
return udfInfo.getAllUDFTable();
- case ShowModel:
- return modelInfo.showModel((ShowModelPlan) req);
- case GetModelInfo:
- return modelInfo.getModelInfo((GetModelInfoPlan) req);
case GetPipePluginTable:
return pipeInfo.getPipePluginInfo().showPipePlugins();
case GetPipePluginJar:
@@ -648,14 +631,6 @@ public TSStatus executeNonQueryPlan(ConfigPhysicalPlan physicalPlan)
return cqInfo.activeCQ((ActiveCQPlan) physicalPlan);
case UPDATE_CQ_LAST_EXEC_TIME:
return cqInfo.updateCQLastExecutionTime((UpdateCQLastExecTimePlan) physicalPlan);
- case CreateModel:
- return modelInfo.createModel((CreateModelPlan) physicalPlan);
- case UpdateModelInfo:
- return modelInfo.updateModelInfo((UpdateModelInfoPlan) physicalPlan);
- case DropModel:
- return modelInfo.dropModel(((DropModelPlan) physicalPlan).getModelName());
- case DropModelInNode:
- return modelInfo.dropModelInNode(((DropModelInNodePlan) physicalPlan).getNodeId());
case CreatePipePlugin:
return pipeInfo.getPipePluginInfo().createPipePlugin((CreatePipePluginPlan) physicalPlan);
case DropPipePlugin:
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/pipe/PipePluginInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/pipe/PipePluginInfo.java
index 280a891b598d7..27cc3cc4cbf5a 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/pipe/PipePluginInfo.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/pipe/PipePluginInfo.java
@@ -161,19 +161,19 @@ public boolean isJarNeededToBeSavedWhenCreatingPipePlugin(final String jarName)
}
public void checkPipePluginExistence(
- final Map<String, String> extractorAttributes,
+ final Map<String, String> sourceAttributes,
final Map<String, String> processorAttributes,
- final Map<String, String> connectorAttributes) {
- final PipeParameters extractorParameters = new PipeParameters(extractorAttributes);
- final String extractorPluginName =
- extractorParameters.getStringOrDefault(
+ final Map<String, String> sinkAttributes) {
+ final PipeParameters sourceParameters = new PipeParameters(sourceAttributes);
+ final String sourcePluginName =
+ sourceParameters.getStringOrDefault(
Arrays.asList(PipeSourceConstant.EXTRACTOR_KEY, PipeSourceConstant.SOURCE_KEY),
IOTDB_EXTRACTOR.getPipePluginName());
- if (!pipePluginMetaKeeper.containsPipePlugin(extractorPluginName)) {
+ if (!pipePluginMetaKeeper.containsPipePlugin(sourcePluginName)) {
final String exceptionMessage =
String.format(
"Failed to create or alter pipe, the pipe extractor plugin %s does not exist",
- extractorPluginName);
+ sourcePluginName);
LOGGER.warn(exceptionMessage);
throw new PipeException(exceptionMessage);
}
@@ -191,16 +191,16 @@ public void checkPipePluginExistence(
throw new PipeException(exceptionMessage);
}
- final PipeParameters connectorParameters = new PipeParameters(connectorAttributes);
- final String connectorPluginName =
- connectorParameters.getStringOrDefault(
+ final PipeParameters sinkParameters = new PipeParameters(sinkAttributes);
+ final String sinkPluginName =
+ sinkParameters.getStringOrDefault(
Arrays.asList(PipeSinkConstant.CONNECTOR_KEY, PipeSinkConstant.SINK_KEY),
IOTDB_THRIFT_CONNECTOR.getPipePluginName());
- if (!pipePluginMetaKeeper.containsPipePlugin(connectorPluginName)) {
+ if (!pipePluginMetaKeeper.containsPipePlugin(sinkPluginName)) {
final String exceptionMessage =
String.format(
"Failed to create or alter pipe, the pipe connector plugin %s does not exist",
- connectorPluginName);
+ sinkPluginName);
LOGGER.warn(exceptionMessage);
throw new PipeException(exceptionMessage);
}
@@ -212,34 +212,29 @@ public TSStatus createPipePlugin(final CreatePipePluginPlan createPipePluginPlan
try {
final PipePluginMeta pipePluginMeta = createPipePluginPlan.getPipePluginMeta();
final String pluginName = pipePluginMeta.getPluginName();
+ final String className = pipePluginMeta.getClassName();
+ final String jarName = pipePluginMeta.getJarName();
// try to drop the old pipe plugin if it exists to reduce the effect of inconsistency
dropPipePlugin(new DropPipePluginPlan(pluginName));
pipePluginMetaKeeper.addPipePluginMeta(pluginName, pipePluginMeta);
- pipePluginMetaKeeper.addJarNameAndMd5(
- pipePluginMeta.getJarName(), pipePluginMeta.getJarMD5());
+ pipePluginMetaKeeper.addJarNameAndMd5(jarName, pipePluginMeta.getJarMD5());
if (createPipePluginPlan.getJarFile() != null) {
pipePluginExecutableManager.savePluginToInstallDir(
- ByteBuffer.wrap(createPipePluginPlan.getJarFile().getValues()),
- pluginName,
- pipePluginMeta.getJarName());
- final String pluginDirPath = pipePluginExecutableManager.getPluginsDirPath(pluginName);
- final PipePluginClassLoader pipePluginClassLoader =
- classLoaderManager.createPipePluginClassLoader(pluginDirPath);
- try {
- final Class<?> pluginClass =
- Class.forName(pipePluginMeta.getClassName(), true, pipePluginClassLoader);
- pipePluginMetaKeeper.addPipePluginVisibility(
- pluginName, VisibilityUtils.calculateFromPluginClass(pluginClass));
- classLoaderManager.addPluginAndClassLoader(pluginName, pipePluginClassLoader);
- } catch (final Exception e) {
- try {
- pipePluginClassLoader.close();
- } catch (final Exception ignored) {
- }
- throw e;
+ ByteBuffer.wrap(createPipePluginPlan.getJarFile().getValues()), pluginName, jarName);
+ computeFromPluginClass(pluginName, className);
+ } else {
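+ // Hedged reading of the branch below: when the plan carries no jar payload,
+ // reuse the jar of an existing plugin registered under the same jarName
+ // (linkExistedPlugin presumably shares the already-installed jar between plugins).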
+ final String existed = pipePluginMetaKeeper.getPluginNameByJarName(jarName);
+ if (Objects.nonNull(existed)) {
+ pipePluginExecutableManager.linkExistedPlugin(existed, pluginName, jarName);
+ computeFromPluginClass(pluginName, className);
+ } else {
+ throw new PipeException(
+ String.format(
+ "The %s's creation has not passed in jarName, which does not exist in other pipePlugins. Please check",
+ pluginName));
}
}
@@ -255,6 +250,25 @@ public TSStatus createPipePlugin(final CreatePipePluginPlan createPipePluginPlan
}
}
+ private void computeFromPluginClass(final String pluginName, final String className)
+ throws Exception {
+ final String pluginDirPath = pipePluginExecutableManager.getPluginsDirPath(pluginName);
+ final PipePluginClassLoader pipePluginClassLoader =
+ classLoaderManager.createPipePluginClassLoader(pluginDirPath);
+ try {
+ final Class<?> pluginClass = Class.forName(className, true, pipePluginClassLoader);
+ pipePluginMetaKeeper.addPipePluginVisibility(
+ pluginName, VisibilityUtils.calculateFromPluginClass(pluginClass));
+ classLoaderManager.addPluginAndClassLoader(pluginName, pipePluginClassLoader);
+ } catch (final Exception e) {
+ try {
+ pipePluginClassLoader.close();
+ } catch (final Exception ignored) {
+ }
+ throw e;
+ }
+ }
+
public TSStatus dropPipePlugin(final DropPipePluginPlan dropPipePluginPlan) {
final String pluginName = dropPipePluginPlan.getPluginName();
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
deleted file mode 100644
index 989061610213d..0000000000000
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java
+++ /dev/null
@@ -1,250 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.confignode.procedure.impl.model;
-
-import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
-import org.apache.iotdb.common.rpc.thrift.TSStatus;
-import org.apache.iotdb.commons.exception.ainode.LoadModelException;
-import org.apache.iotdb.commons.model.ModelInformation;
-import org.apache.iotdb.commons.model.ModelStatus;
-import org.apache.iotdb.commons.model.exception.ModelManagementException;
-import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan;
-import org.apache.iotdb.confignode.manager.ConfigManager;
-import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv;
-import org.apache.iotdb.confignode.procedure.exception.ProcedureException;
-import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure;
-import org.apache.iotdb.confignode.procedure.state.model.CreateModelState;
-import org.apache.iotdb.confignode.procedure.store.ProcedureType;
-import org.apache.iotdb.consensus.exception.ConsensusException;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClient;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager;
-import org.apache.iotdb.rpc.TSStatusCode;
-
-import org.apache.tsfile.utils.ReadWriteIOUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Objects;
-
-public class CreateModelProcedure extends AbstractNodeProcedure<CreateModelState> {
-
- private static final Logger LOGGER = LoggerFactory.getLogger(CreateModelProcedure.class);
- private static final int RETRY_THRESHOLD = 0;
-
- private String modelName;
-
- private String uri;
-
- private ModelInformation modelInformation = null;
-
- private List<Integer> aiNodeIds;
-
- private String loadErrorMsg = "";
-
- public CreateModelProcedure() {
- super();
- }
-
- public CreateModelProcedure(String modelName, String uri) {
- super();
- this.modelName = modelName;
- this.uri = uri;
- this.aiNodeIds = new ArrayList<>();
- }
-
- @Override
- protected Flow executeFromState(ConfigNodeProcedureEnv env, CreateModelState state) {
- if (modelName == null || uri == null) {
- return Flow.NO_MORE_STATE;
- }
- try {
- switch (state) {
- case LOADING:
- initModel(env);
- loadModel(env);
- setNextState(CreateModelState.ACTIVE);
- break;
- case ACTIVE:
- modelInformation.updateStatus(ModelStatus.ACTIVE);
- updateModel(env);
- return Flow.NO_MORE_STATE;
- default:
- throw new UnsupportedOperationException(
- String.format("Unknown state during executing createModelProcedure, %s", state));
- }
- } catch (Exception e) {
- if (isRollbackSupported(state)) {
- LOGGER.error("Fail in CreateModelProcedure", e);
- setFailure(new ProcedureException(e.getMessage()));
- } else {
- LOGGER.error(
- "Retrievable error trying to create model [{}], state [{}]", modelName, state, e);
- if (getCycles() > RETRY_THRESHOLD) {
- modelInformation = new ModelInformation(modelName, ModelStatus.UNAVAILABLE);
- modelInformation.setAttribute(loadErrorMsg);
- updateModel(env);
- setFailure(
- new ProcedureException(
- String.format("Fail to create model [%s] at STATE [%s]", modelName, state)));
- }
- }
- }
- return Flow.HAS_MORE_STATE;
- }
-
- private void initModel(ConfigNodeProcedureEnv env) throws ConsensusException {
- LOGGER.info("Start to add model [{}]", modelName);
-
- ConfigManager configManager = env.getConfigManager();
- TSStatus response = configManager.getConsensusManager().write(new CreateModelPlan(modelName));
- if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- throw new ModelManagementException(
- String.format(
- "Failed to add model [%s] in ModelTable on Config Nodes: %s",
- modelName, response.getMessage()));
- }
- }
-
- private void checkModelInformationEquals(ModelInformation receiveModelInfo) {
- if (modelInformation == null) {
- modelInformation = receiveModelInfo;
- } else {
- if (!modelInformation.equals(receiveModelInfo)) {
- throw new ModelManagementException(
- String.format(
- "Failed to load model [%s] on AI Nodes, model information is not equal in different nodes",
- modelName));
- }
- }
- }
-
- private void loadModel(ConfigNodeProcedureEnv env) {
- for (TAINodeConfiguration curNodeConfig :
- env.getConfigManager().getNodeManager().getRegisteredAINodes()) {
- try (AINodeClient client =
- AINodeClientManager.getInstance()
- .borrowClient(curNodeConfig.getLocation().getInternalEndPoint())) {
- ModelInformation resp = client.registerModel(modelName, uri);
- checkModelInformationEquals(resp);
- aiNodeIds.add(curNodeConfig.getLocation().aiNodeId);
- } catch (LoadModelException e) {
- LOGGER.warn(e.getMessage());
- loadErrorMsg = e.getMessage();
- } catch (Exception e) {
- LOGGER.warn(
- "Failed to load model on AINode {} from ConfigNode",
- curNodeConfig.getLocation().getInternalEndPoint());
- loadErrorMsg = e.getMessage();
- }
- }
-
- if (aiNodeIds.isEmpty()) {
- throw new ModelManagementException(
- String.format("CREATE MODEL [%s] failed on all AINodes:[%s]", modelName, loadErrorMsg));
- }
- }
-
- private void updateModel(ConfigNodeProcedureEnv env) {
- LOGGER.info("Start to update model [{}]", modelName);
-
- ConfigManager configManager = env.getConfigManager();
- try {
- TSStatus response =
- configManager
- .getConsensusManager()
- .write(new UpdateModelInfoPlan(modelName, modelInformation, aiNodeIds));
- if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- throw new ModelManagementException(
- String.format(
- "Failed to update model [%s] in ModelTable on Config Nodes: %s",
- modelName, response.getMessage()));
- }
- } catch (Exception e) {
- throw new ModelManagementException(
- String.format(
- "Failed to update model [%s] in ModelTable on Config Nodes: %s",
- modelName, e.getMessage()));
- }
- }
-
- @Override
- protected void rollbackState(ConfigNodeProcedureEnv env, CreateModelState state)
- throws IOException, InterruptedException, ProcedureException {
- // do nothing
- }
-
- @Override
- protected boolean isRollbackSupported(CreateModelState state) {
- return false;
- }
-
- @Override
- protected CreateModelState getState(int stateId) {
- return CreateModelState.values()[stateId];
- }
-
- @Override
- protected int getStateId(CreateModelState createModelState) {
- return createModelState.ordinal();
- }
-
- @Override
- protected CreateModelState getInitialState() {
- return CreateModelState.LOADING;
- }
-
- @Override
- public void serialize(DataOutputStream stream) throws IOException {
- stream.writeShort(ProcedureType.CREATE_MODEL_PROCEDURE.getTypeCode());
- super.serialize(stream);
- ReadWriteIOUtils.write(modelName, stream);
- ReadWriteIOUtils.write(uri, stream);
- }
-
- @Override
- public void deserialize(ByteBuffer byteBuffer) {
- super.deserialize(byteBuffer);
- modelName = ReadWriteIOUtils.readString(byteBuffer);
- uri = ReadWriteIOUtils.readString(byteBuffer);
- }
-
- @Override
- public boolean equals(Object that) {
- if (that instanceof CreateModelProcedure) {
- CreateModelProcedure thatProc = (CreateModelProcedure) that;
- return thatProc.getProcId() == this.getProcId()
- && thatProc.getState() == this.getState()
- && Objects.equals(thatProc.modelName, this.modelName)
- && Objects.equals(thatProc.uri, this.uri);
- }
- return false;
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(getProcId(), getState(), modelName, uri);
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
deleted file mode 100644
index daa029e04ddfd..0000000000000
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java
+++ /dev/null
@@ -1,200 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.confignode.procedure.impl.model;
-
-import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq;
-import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration;
-import org.apache.iotdb.common.rpc.thrift.TSStatus;
-import org.apache.iotdb.commons.model.exception.ModelManagementException;
-import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan;
-import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv;
-import org.apache.iotdb.confignode.procedure.exception.ProcedureException;
-import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure;
-import org.apache.iotdb.confignode.procedure.state.model.DropModelState;
-import org.apache.iotdb.confignode.procedure.store.ProcedureType;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClient;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager;
-import org.apache.iotdb.rpc.TSStatusCode;
-
-import org.apache.thrift.TException;
-import org.apache.tsfile.utils.ReadWriteIOUtils;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.io.DataOutputStream;
-import java.io.IOException;
-import java.nio.ByteBuffer;
-import java.util.List;
-import java.util.Objects;
-
-import static org.apache.iotdb.confignode.procedure.state.model.DropModelState.CONFIG_NODE_DROPPED;
-
-public class DropModelProcedure extends AbstractNodeProcedure<DropModelState> {
-
- private static final Logger LOGGER = LoggerFactory.getLogger(DropModelProcedure.class);
- private static final int RETRY_THRESHOLD = 1;
-
- private String modelName;
-
- public DropModelProcedure() {
- super();
- }
-
- public DropModelProcedure(String modelName) {
- super();
- this.modelName = modelName;
- }
-
- @Override
- protected Flow executeFromState(ConfigNodeProcedureEnv env, DropModelState state) {
- if (modelName == null) {
- return Flow.NO_MORE_STATE;
- }
- try {
- switch (state) {
- case AI_NODE_DROPPED:
- LOGGER.info("Start to drop model [{}] on AI Nodes", modelName);
- dropModelOnAINode(env);
- setNextState(CONFIG_NODE_DROPPED);
- break;
- case CONFIG_NODE_DROPPED:
- dropModelOnConfigNode(env);
- return Flow.NO_MORE_STATE;
- default:
- throw new UnsupportedOperationException(
- String.format("Unknown state during executing dropModelProcedure, %s", state));
- }
- } catch (Exception e) {
- if (isRollbackSupported(state)) {
- LOGGER.error("Fail in DropModelProcedure", e);
- setFailure(new ProcedureException(e.getMessage()));
- } else {
- LOGGER.error(
- "Retrievable error trying to drop model [{}], state [{}]", modelName, state, e);
- if (getCycles() > RETRY_THRESHOLD) {
- setFailure(
- new ProcedureException(
- String.format(
- "Fail to drop model [%s] at STATE [%s], %s",
- modelName, state, e.getMessage())));
- }
- }
- }
- return Flow.HAS_MORE_STATE;
- }
-
- private void dropModelOnAINode(ConfigNodeProcedureEnv env) {
- LOGGER.info("Start to drop model file [{}] on AI Node", modelName);
-
- List<TAINodeConfiguration> aiNodes =
- env.getConfigManager().getNodeManager().getRegisteredAINodes();
- aiNodes.forEach(
- aiNode -> {
- int nodeId = aiNode.getLocation().getAiNodeId();
- try (AINodeClient client =
- AINodeClientManager.getInstance()
- .borrowClient(
- env.getConfigManager()
- .getNodeManager()
- .getRegisteredAINode(nodeId)
- .getLocation()
- .getInternalEndPoint())) {
- TSStatus status = client.deleteModel(new TDeleteModelReq(modelName));
- if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- LOGGER.warn(
- "Failed to drop model [{}] on AINode [{}], status: {}",
- modelName,
- nodeId,
- status.getMessage());
- }
- } catch (Exception e) {
- LOGGER.warn(
- "Failed to drop model [{}] on AINode [{}], status: {}",
- modelName,
- nodeId,
- e.getMessage());
- }
- });
- }
-
- private void dropModelOnConfigNode(ConfigNodeProcedureEnv env) {
- try {
- TSStatus response =
- env.getConfigManager().getConsensusManager().write(new DropModelPlan(modelName));
- if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
- throw new TException(response.getMessage());
- }
- } catch (Exception e) {
- throw new ModelManagementException(
- String.format(
- "Fail to start training model [%s] on AI Node: %s", modelName, e.getMessage()));
- }
- }
-
- @Override
- protected void rollbackState(ConfigNodeProcedureEnv env, DropModelState state)
- throws IOException, InterruptedException, ProcedureException {
- // no need to rollback
- }
-
- @Override
- protected DropModelState getState(int stateId) {
- return DropModelState.values()[stateId];
- }
-
- @Override
- protected int getStateId(DropModelState dropModelState) {
- return dropModelState.ordinal();
- }
-
- @Override
- protected DropModelState getInitialState() {
- return DropModelState.AI_NODE_DROPPED;
- }
-
- @Override
- public void serialize(DataOutputStream stream) throws IOException {
- stream.writeShort(ProcedureType.DROP_MODEL_PROCEDURE.getTypeCode());
- super.serialize(stream);
- ReadWriteIOUtils.write(modelName, stream);
- }
-
- @Override
- public void deserialize(ByteBuffer byteBuffer) {
- super.deserialize(byteBuffer);
- modelName = ReadWriteIOUtils.readString(byteBuffer);
- }
-
- @Override
- public boolean equals(Object that) {
- if (that instanceof DropModelProcedure) {
- DropModelProcedure thatProc = (DropModelProcedure) that;
- return thatProc.getProcId() == this.getProcId()
- && thatProc.getState() == this.getState()
- && (thatProc.modelName).equals(this.modelName);
- }
- return false;
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(getProcId(), getState(), modelName);
- }
-}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java
index 2cab08c28244e..2a1c6881b1413 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java
@@ -23,13 +23,12 @@
import org.apache.iotdb.common.rpc.thrift.TSStatus;
import org.apache.iotdb.commons.utils.ThriftCommonsSerDeUtils;
import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan;
-import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan;
import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv;
import org.apache.iotdb.confignode.procedure.exception.ProcedureException;
import org.apache.iotdb.confignode.procedure.state.RemoveAINodeState;
import org.apache.iotdb.confignode.procedure.store.ProcedureType;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClient;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager;
+import org.apache.iotdb.db.protocol.client.an.AINodeClient;
+import org.apache.iotdb.db.protocol.client.an.AINodeClientManager;
import org.apache.iotdb.rpc.TSStatusCode;
import org.slf4j.Logger;
@@ -65,17 +64,11 @@ protected Flow executeFromState(ConfigNodeProcedureEnv env, RemoveAINodeState st
try {
switch (state) {
- case MODEL_DELETE:
- env.getConfigManager()
- .getConsensusManager()
- .write(new DropModelInNodePlan(removedAINode.aiNodeId));
- // Because the AINode is removed, we don't need to remove the model file.
- setNextState(RemoveAINodeState.NODE_STOP);
- break;
case NODE_STOP:
TSStatus resp = null;
try (AINodeClient client =
- AINodeClientManager.getInstance().borrowClient(removedAINode.getInternalEndPoint())) {
+ AINodeClientManager.getInstance()
+ .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) {
resp = client.stopAINode();
} catch (Exception e) {
LOGGER.warn(
@@ -148,7 +141,7 @@ protected int getStateId(RemoveAINodeState removeAINodeState) {
@Override
protected RemoveAINodeState getInitialState() {
- return RemoveAINodeState.MODEL_DELETE;
+ return RemoveAINodeState.NODE_STOP;
}
@Override
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/pipe/task/AlterPipeProcedureV2.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/pipe/task/AlterPipeProcedureV2.java
index f6b84cb1f6e3a..53f908bf4fc11 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/pipe/task/AlterPipeProcedureV2.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/pipe/task/AlterPipeProcedureV2.java
@@ -187,6 +187,7 @@ public void executeFromCalculateInfoForTask(final ConfigNodeProcedureEnv env) {
&& !databaseName.startsWith(SchemaConstant.SYSTEM_DATABASE + ".")
&& !databaseName.equals(SchemaConstant.AUDIT_DATABASE)
&& !databaseName.startsWith(SchemaConstant.AUDIT_DATABASE + ".")
+ && !Objects.isNull(currentPipeTaskMeta)
&& !(PipeTaskAgent.isHistoryOnlyPipe(
currentPipeStaticMeta.getSourceParameters())
&& PipeTaskAgent.isHistoryOnlyPipe(
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java
index 8a1a6a1bb03b5..49820df663616 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java
@@ -20,7 +20,6 @@
package org.apache.iotdb.confignode.procedure.state;
public enum RemoveAINodeState {
- MODEL_DELETE,
NODE_STOP,
NODE_REMOVE
}
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java
index e023171f4fa88..f20a6999d5936 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java
@@ -22,8 +22,6 @@
import org.apache.iotdb.commons.exception.runtime.ThriftSerDeException;
import org.apache.iotdb.confignode.procedure.Procedure;
import org.apache.iotdb.confignode.procedure.impl.cq.CreateCQProcedure;
-import org.apache.iotdb.confignode.procedure.impl.model.CreateModelProcedure;
-import org.apache.iotdb.confignode.procedure.impl.model.DropModelProcedure;
import org.apache.iotdb.confignode.procedure.impl.node.AddConfigNodeProcedure;
import org.apache.iotdb.confignode.procedure.impl.node.RemoveAINodeProcedure;
import org.apache.iotdb.confignode.procedure.impl.node.RemoveConfigNodeProcedure;
@@ -263,12 +261,6 @@ public Procedure create(ByteBuffer buffer) throws IOException {
case DROP_PIPE_PLUGIN_PROCEDURE:
procedure = new DropPipePluginProcedure();
break;
- case CREATE_MODEL_PROCEDURE:
- procedure = new CreateModelProcedure();
- break;
- case DROP_MODEL_PROCEDURE:
- procedure = new DropModelProcedure();
- break;
case AUTH_OPERATE_PROCEDURE:
procedure = new AuthOperationProcedure(false);
break;
@@ -494,10 +486,6 @@ public static ProcedureType getProcedureType(final Procedure<?> procedure) {
return ProcedureType.CREATE_PIPE_PLUGIN_PROCEDURE;
} else if (procedure instanceof DropPipePluginProcedure) {
return ProcedureType.DROP_PIPE_PLUGIN_PROCEDURE;
- } else if (procedure instanceof CreateModelProcedure) {
- return ProcedureType.CREATE_MODEL_PROCEDURE;
- } else if (procedure instanceof DropModelProcedure) {
- return ProcedureType.DROP_MODEL_PROCEDURE;
} else if (procedure instanceof CreatePipeProcedureV2) {
return ProcedureType.CREATE_PIPE_PROCEDURE_V2;
} else if (procedure instanceof StartPipeProcedureV2) {
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java
index 65ac1fb24ad5a..d076a7d9d926e 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java
@@ -85,7 +85,9 @@ public enum ProcedureType {
RENAME_VIEW_PROCEDURE((short) 764),
/** AI Model */
+ @Deprecated // Since 2.0.6, all models are managed by AINode
CREATE_MODEL_PROCEDURE((short) 800),
+ @Deprecated // Since 2.0.6, all models are managed by AINode
DROP_MODEL_PROCEDURE((short) 801),
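+ // Assumption, based on the codes being retained while the procedures were removed:
+ // keeping the deprecated type codes lets procedure logs serialized by older versions
+ // still be identified during deserialization instead of colliding with new codes.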
REMOVE_AI_NODE_PROCEDURE((short) 802),
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
index 59ce7352312f0..6582a5bfff8e1 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
@@ -115,7 +115,6 @@
import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq;
-import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq;
@@ -144,7 +143,6 @@
import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp;
import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq;
-import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq;
import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq;
@@ -163,8 +161,6 @@
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq;
import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp;
import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp;
@@ -226,7 +222,6 @@
import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp;
import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq;
import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq;
-import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq;
import org.apache.iotdb.confignode.service.ConfigNode;
import org.apache.iotdb.consensus.exception.ConsensusException;
import org.apache.iotdb.db.queryengine.plan.relational.type.AuthorRType;
@@ -1362,26 +1357,6 @@ public TShowCQResp showCQ() {
return configManager.showCQ();
}
- @Override
- public TSStatus createModel(TCreateModelReq req) {
- return configManager.createModel(req);
- }
-
- @Override
- public TSStatus dropModel(TDropModelReq req) {
- return configManager.dropModel(req);
- }
-
- @Override
- public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) {
- return configManager.getModelInfo(req);
- }
-
- @Override
- public TSStatus updateModelInfo(TUpdateModelInfoReq req) throws TException {
- return configManager.updateModelInfo(req);
- }
-
@Override
public TSStatus setSpaceQuota(final TSetSpaceQuotaReq req) throws TException {
return configManager.setSpaceQuota(req);
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
index 375285dfa2fb9..3f19a2101a805 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java
@@ -814,6 +814,9 @@ public class IoTDBConfig {
/** time cost(ms) threshold for slow query. Unit: millisecond */
private long slowQueryThreshold = 10000;
+ /** Time window for recording the cost statistics of historical queries. Unit: minute */
+ private int queryCostStatWindow = 0;
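+ // Illustration with assumed semantics: query_cost_stat_window=60 would keep cost
+ // statistics for queries from the most recent 60 minutes; the default 0 disables it.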
+
private int patternMatchingThreshold = 1000000;
/**
@@ -1121,6 +1124,21 @@ public class IoTDBConfig {
private int loadTsFileSpiltPartitionMaxSize = 10;
+ /**
+ * The threshold for splitting a statement when loading multiple TsFiles. When the number of TsFiles
+ * exceeds this threshold, the statement will be split into multiple sub-statements for batch
+ * execution to limit resource consumption during statement analysis. Default value is 10, which
+ * means splitting will occur when there are more than 10 files.
+ */
+ private int loadTsFileStatementSplitThreshold = 10;
+
+ /**
+ * The number of TsFiles that each sub-statement handles when splitting a statement. This
+ * parameter controls how many files are grouped together in each sub-statement during batch
+ * execution. Default value is 10, which means each sub-statement handles 10 files.
+ */
+ private int loadTsFileSubStatementBatchSize = 10;
+
private String[] loadActiveListeningDirs =
new String[] {
IoTDBConstant.EXT_FOLDER_NAME
@@ -2613,6 +2631,14 @@ public void setSlowQueryThreshold(long slowQueryThreshold) {
this.slowQueryThreshold = slowQueryThreshold;
}
+ public int getQueryCostStatWindow() {
+ return queryCostStatWindow;
+ }
+
+ public void setQueryCostStatWindow(int queryCostStatWindow) {
+ this.queryCostStatWindow = queryCostStatWindow;
+ }
+
public boolean isEnableIndex() {
return enableIndex;
}
@@ -4057,6 +4083,46 @@ public void setLoadTsFileSpiltPartitionMaxSize(int loadTsFileSpiltPartitionMaxSi
this.loadTsFileSpiltPartitionMaxSize = loadTsFileSpiltPartitionMaxSize;
}
+ public int getLoadTsFileStatementSplitThreshold() {
+ return loadTsFileStatementSplitThreshold;
+ }
+
+ public void setLoadTsFileStatementSplitThreshold(final int loadTsFileStatementSplitThreshold) {
+ if (loadTsFileStatementSplitThreshold < 0) {
+ logger.warn(
+ "Invalid loadTsFileStatementSplitThreshold value: {}. Using default value: 10",
+ loadTsFileStatementSplitThreshold);
+ return;
+ }
+ if (this.loadTsFileStatementSplitThreshold != loadTsFileStatementSplitThreshold) {
+ logger.info(
+ "loadTsFileStatementSplitThreshold changed from {} to {}",
+ this.loadTsFileStatementSplitThreshold,
+ loadTsFileStatementSplitThreshold);
+ }
+ this.loadTsFileStatementSplitThreshold = loadTsFileStatementSplitThreshold;
+ }
+
+ public int getLoadTsFileSubStatementBatchSize() {
+ return loadTsFileSubStatementBatchSize;
+ }
+
+ public void setLoadTsFileSubStatementBatchSize(final int loadTsFileSubStatementBatchSize) {
+ if (loadTsFileSubStatementBatchSize <= 0) {
+ logger.warn(
+ "Invalid loadTsFileSubStatementBatchSize value: {}. Using default value: 10",
+ loadTsFileSubStatementBatchSize);
+ return;
+ }
+ if (this.loadTsFileSubStatementBatchSize != loadTsFileSubStatementBatchSize) {
+ logger.info(
+ "loadTsFileSubStatementBatchSize changed from {} to {}",
+ this.loadTsFileSubStatementBatchSize,
+ loadTsFileSubStatementBatchSize);
+ }
+ this.loadTsFileSubStatementBatchSize = loadTsFileSubStatementBatchSize;
+ }
+
public String[] getPipeReceiverFileDirs() {
return (Objects.isNull(this.pipeReceiverFileDirs) || this.pipeReceiverFileDirs.length == 0)
? new String[] {systemDir + File.separator + "pipe" + File.separator + "receiver"}
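The two knobs above work together: splitting happens only once the number of files exceeds loadTsFileStatementSplitThreshold, and each resulting sub-statement then carries at most loadTsFileSubStatementBatchSize files. A minimal sketch of that grouping logic, assuming a hypothetical splitter helper (the class and method names below are illustrative, not part of this patch):

import java.util.ArrayList;
import java.util.List;

final class LoadSplitSketch {
  // Group `files` into sub-statement batches according to the two new knobs.
  static <T> List<List<T>> split(final List<T> files, final int threshold, final int batchSize) {
    final List<List<T>> batches = new ArrayList<>();
    if (files.size() <= threshold) {
      // At or below the threshold: keep a single statement.
      batches.add(files);
      return batches;
    }
    // Above the threshold: each sub-statement handles at most batchSize files.
    for (int i = 0; i < files.size(); i += batchSize) {
      batches.add(files.subList(i, Math.min(i + batchSize, files.size())));
    }
    return batches;
  }
}

With the defaults (threshold 10, batch size 10), 25 files yield three sub-statements of 10, 10, and 5 files.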
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java
index c194f304e7f6a..e51f445c56784 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java
@@ -813,6 +813,11 @@ public void loadProperties(TrimProperties properties) throws BadNodeUrlException
properties.getProperty(
"slow_query_threshold", String.valueOf(conf.getSlowQueryThreshold()))));
+ conf.setQueryCostStatWindow(
+ Integer.parseInt(
+ properties.getProperty(
+ "query_cost_stat_window", String.valueOf(conf.getQueryCostStatWindow()))));
+
conf.setDataRegionNum(
Integer.parseInt(
properties.getProperty("data_region_num", String.valueOf(conf.getDataRegionNum()))));
@@ -2406,6 +2411,18 @@ private void loadLoadTsFileProps(TrimProperties properties) {
properties.getProperty(
"skip_failed_table_schema_check",
String.valueOf(conf.isSkipFailedTableSchemaCheck()))));
+
+ conf.setLoadTsFileStatementSplitThreshold(
+ Integer.parseInt(
+ properties.getProperty(
+ "load_tsfile_statement_split_threshold",
+ Integer.toString(conf.getLoadTsFileStatementSplitThreshold()))));
+
+ conf.setLoadTsFileSubStatementBatchSize(
+ Integer.parseInt(
+ properties.getProperty(
+ "load_tsfile_sub_statement_batch_size",
+ Integer.toString(conf.getLoadTsFileSubStatementBatchSize()))));
}
private void loadLoadTsFileHotModifiedProp(TrimProperties properties) throws IOException {
@@ -2454,6 +2471,18 @@ private void loadLoadTsFileHotModifiedProp(TrimProperties properties) throws IOE
"load_tsfile_split_partition_max_size",
Integer.toString(conf.getLoadTsFileSpiltPartitionMaxSize()))));
+ conf.setLoadTsFileStatementSplitThreshold(
+ Integer.parseInt(
+ properties.getProperty(
+ "load_tsfile_statement_split_threshold",
+ Integer.toString(conf.getLoadTsFileStatementSplitThreshold()))));
+
+ conf.setLoadTsFileSubStatementBatchSize(
+ Integer.parseInt(
+ properties.getProperty(
+ "load_tsfile_sub_statement_batch_size",
+ Integer.toString(conf.getLoadTsFileSubStatementBatchSize()))));
+
conf.setSkipFailedTableSchemaCheck(
Boolean.parseBoolean(
properties.getProperty(
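Because the setters validate their arguments, an out-of-range value in iotdb-datanode.properties degrades gracefully instead of failing the reload: the invalid value is logged and the current value is kept. A minimal sketch of the observable behavior, assuming the usual IoTDBDescriptor singleton accessor:

final IoTDBConfig conf = IoTDBDescriptor.getInstance().getConfig();
conf.setLoadTsFileSubStatementBatchSize(0); // invalid (<= 0): warning logged, value kept
assert conf.getLoadTsFileSubStatementBatchSize() == 10; // still the default
conf.setLoadTsFileStatementSplitThreshold(50); // valid: applied, change logged
assert conf.getLoadTsFileStatementSplitThreshold() == 50;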
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletBatchReqV2.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletBatchReqV2.java
index d9c3fabfae8ac..c136ffbe7d3e3 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletBatchReqV2.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletBatchReqV2.java
@@ -33,7 +33,6 @@
import org.apache.tsfile.utils.PublicBAOS;
import org.apache.tsfile.utils.ReadWriteIOUtils;
-import org.apache.tsfile.write.record.Tablet;
import java.io.DataOutputStream;
import java.io.IOException;
@@ -247,11 +246,7 @@ public static PipeTransferTabletBatchReqV2 fromTPipeTransferReq(
size = ReadWriteIOUtils.readInt(transferReq.body);
for (int i = 0; i < size; ++i) {
- batchReq.tabletReqs.add(
- PipeTransferTabletRawReqV2.toTPipeTransferRawReq(
- Tablet.deserialize(transferReq.body),
- ReadWriteIOUtils.readBool(transferReq.body),
- ReadWriteIOUtils.readString(transferReq.body)));
+ batchReq.tabletReqs.add(PipeTransferTabletRawReqV2.toTPipeTransferRawReq(transferReq.body));
}
batchReq.version = transferReq.version;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReq.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReq.java
index 7da44f297d140..af9b37edbf6ff 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReq.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReq.java
@@ -22,6 +22,7 @@
import org.apache.iotdb.commons.exception.MetadataException;
import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.IoTDBSinkRequestVersion;
import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.PipeRequestType;
+import org.apache.iotdb.db.pipe.sink.util.TabletStatementConverter;
import org.apache.iotdb.db.pipe.sink.util.sorter.PipeTreeModelTabletEventSorter;
import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement;
import org.apache.iotdb.service.rpc.thrift.TPipeTransferReq;
@@ -43,10 +44,25 @@ public class PipeTransferTabletRawReq extends TPipeTransferReq {
private static final Logger LOGGER = LoggerFactory.getLogger(PipeTransferTabletRawReq.class);
- protected transient Tablet tablet;
+ protected transient InsertTabletStatement statement;
+
protected transient boolean isAligned;
+ protected transient Tablet tablet;
+ /**
+ * Get the tablet. If the tablet is null, it is lazily converted from the statement.
+ *
+ * @return the Tablet, or null if the conversion from the statement fails
+ */
public Tablet getTablet() {
+ if (tablet == null && statement != null) {
+ try {
+ tablet = statement.convertToTablet();
+ } catch (final MetadataException e) {
+ LOGGER.warn("Failed to convert statement to tablet.", e);
+ return null;
+ }
+ }
return tablet;
}
@@ -54,16 +70,29 @@ public boolean getIsAligned() {
return isAligned;
}
+ /**
+ * Construct the statement. If a statement already exists, sort it and return it; otherwise,
+ * convert it from the tablet.
+ *
+ * @return the InsertTabletStatement, or null if conversion from the tablet fails
+ */
public InsertTabletStatement constructStatement() {
+ if (statement != null) {
+ new PipeTreeModelTabletEventSorter(statement).deduplicateAndSortTimestampsIfNecessary();
+ return statement;
+ }
+
+ // Sort and deduplicate tablet before converting
new PipeTreeModelTabletEventSorter(tablet).deduplicateAndSortTimestampsIfNecessary();
try {
if (isTabletEmpty(tablet)) {
// Empty statement, will be filtered after construction
- return new InsertTabletStatement();
+ statement = new InsertTabletStatement();
+ return statement;
}
- return new InsertTabletStatement(tablet, isAligned, null);
+ statement = new InsertTabletStatement(tablet, isAligned, null);
+ return statement;
} catch (final MetadataException e) {
LOGGER.warn("Generate Statement from tablet {} error.", tablet, e);
return null;
@@ -107,8 +136,20 @@ public static PipeTransferTabletRawReq toTPipeTransferReq(
public static PipeTransferTabletRawReq fromTPipeTransferReq(final TPipeTransferReq transferReq) {
final PipeTransferTabletRawReq tabletReq = new PipeTransferTabletRawReq();
- tabletReq.tablet = Tablet.deserialize(transferReq.body);
- tabletReq.isAligned = ReadWriteIOUtils.readBool(transferReq.body);
+ final ByteBuffer buffer = transferReq.body;
+ final int startPosition = buffer.position();
+ try {
+ // V1: no databaseName, readDatabaseName = false
+ final InsertTabletStatement insertTabletStatement =
+ TabletStatementConverter.deserializeStatementFromTabletFormat(buffer, false);
+ tabletReq.isAligned = insertTabletStatement.isAligned();
+ // devicePath is already set in deserializeStatementFromTabletFormat for V1 format
+ tabletReq.statement = insertTabletStatement;
+ } catch (final Exception e) {
+ buffer.position(startPosition);
+ tabletReq.tablet = Tablet.deserialize(buffer);
+ tabletReq.isAligned = ReadWriteIOUtils.readBool(buffer);
+ }
tabletReq.version = transferReq.version;
tabletReq.type = transferReq.type;
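The deserialization above is deliberately version-tolerant: it first tries the new statement-based format and, on any failure, rewinds the buffer and re-reads the same bytes as a plain Tablet. The same save-position/rewind idiom in isolation, as a generic sketch (names are illustrative, not part of the patch):

import java.nio.ByteBuffer;
import java.util.function.Function;

final class BufferFallbackSketch {
  static <T> T readWithFallback(
      final ByteBuffer buffer,
      final Function<ByteBuffer, T> preferred,
      final Function<ByteBuffer, T> fallback) {
    final int startPosition = buffer.position(); // remember where this record starts
    try {
      return preferred.apply(buffer);
    } catch (final Exception e) {
      buffer.position(startPosition); // rewind so the fallback re-reads the same bytes
      return fallback.apply(buffer);
    }
  }
}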
@@ -118,18 +159,56 @@ public static PipeTransferTabletRawReq fromTPipeTransferReq(final TPipeTransferR
/////////////////////////////// Air Gap ///////////////////////////////
- public static byte[] toTPipeTransferBytes(final Tablet tablet, final boolean isAligned)
- throws IOException {
+ /**
+ * Serialize this request to bytes. If the tablet is null, it is first converted from the statement.
+ *
+ * @return serialized bytes
+ * @throws IOException if serialization fails
+ */
+ public byte[] toTPipeTransferBytes() throws IOException {
+ Tablet tabletToSerialize = tablet;
+ boolean isAlignedToSerialize = isAligned;
+
+ // If tablet is null, convert from statement
+ if (tabletToSerialize == null && statement != null) {
+ try {
+ tabletToSerialize = statement.convertToTablet();
+ isAlignedToSerialize = statement.isAligned();
+ } catch (final MetadataException e) {
+ throw new IOException("Failed to convert statement to tablet for serialization", e);
+ }
+ }
+
+ if (tabletToSerialize == null) {
+ throw new IOException("Cannot serialize: both tablet and statement are null");
+ }
+
try (final PublicBAOS byteArrayOutputStream = new PublicBAOS();
final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream)) {
ReadWriteIOUtils.write(IoTDBSinkRequestVersion.VERSION_1.getVersion(), outputStream);
ReadWriteIOUtils.write(PipeRequestType.TRANSFER_TABLET_RAW.getType(), outputStream);
- tablet.serialize(outputStream);
- ReadWriteIOUtils.write(isAligned, outputStream);
+ tabletToSerialize.serialize(outputStream);
+ ReadWriteIOUtils.write(isAlignedToSerialize, outputStream);
return byteArrayOutputStream.toByteArray();
}
}
+ /**
+ * Static variant kept for backward compatibility. Creates a temporary instance and serializes it.
+ *
+ * @param tablet Tablet to serialize
+ * @param isAligned whether aligned
+ * @return serialized bytes
+ * @throws IOException if serialization fails
+ */
+ public static byte[] toTPipeTransferBytes(final Tablet tablet, final boolean isAligned)
+ throws IOException {
+ final PipeTransferTabletRawReq req = new PipeTransferTabletRawReq();
+ req.tablet = tablet;
+ req.isAligned = isAligned;
+ return req.toTPipeTransferBytes();
+ }
+
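A hypothetical call site for the static overload, assuming a Tablet built elsewhere:

final byte[] payload = PipeTransferTabletRawReq.toTPipeTransferBytes(tablet, /* isAligned */ true);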
/////////////////////////////// Object ///////////////////////////////
@Override
@@ -141,7 +220,16 @@ public boolean equals(final Object obj) {
return false;
}
final PipeTransferTabletRawReq that = (PipeTransferTabletRawReq) obj;
- return Objects.equals(tablet, that.tablet)
+ // Compare statement if both have it, otherwise compare tablet
+ if (statement != null && that.statement != null) {
+ return Objects.equals(statement, that.statement)
+ && isAligned == that.isAligned
+ && version == that.version
+ && type == that.type
+ && Objects.equals(body, that.body);
+ }
+ // Fallback to tablet comparison
+ return Objects.equals(getTablet(), that.getTablet())
&& isAligned == that.isAligned
&& version == that.version
&& type == that.type
@@ -150,6 +238,6 @@ public boolean equals(final Object obj) {
@Override
public int hashCode() {
- return Objects.hash(tablet, isAligned, version, type, body);
+ return Objects.hash(getTablet(), isAligned, version, type, body);
}
}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReqV2.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReqV2.java
index 43d8501252c5a..3c5f420a317fd 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReqV2.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReqV2.java
@@ -22,6 +22,7 @@
import org.apache.iotdb.commons.exception.MetadataException;
import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.IoTDBSinkRequestVersion;
import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.PipeRequestType;
+import org.apache.iotdb.db.pipe.sink.util.TabletStatementConverter;
import org.apache.iotdb.db.pipe.sink.util.sorter.PipeTableModelTabletEventSorter;
import org.apache.iotdb.db.pipe.sink.util.sorter.PipeTreeModelTabletEventSorter;
import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement;
@@ -52,6 +53,16 @@ public String getDataBaseName() {
@Override
public InsertTabletStatement constructStatement() {
+ if (statement != null) {
+ if (Objects.isNull(dataBaseName)) {
+ new PipeTreeModelTabletEventSorter(statement).deduplicateAndSortTimestampsIfNecessary();
+ } else {
+ new PipeTableModelTabletEventSorter(statement).sortByTimestampIfNecessary();
+ }
+
+ return statement;
+ }
+
if (Objects.isNull(dataBaseName)) {
new PipeTreeModelTabletEventSorter(tablet).deduplicateAndSortTimestampsIfNecessary();
} else {
@@ -86,6 +97,16 @@ public static PipeTransferTabletRawReqV2 toTPipeTransferRawReq(
return tabletReq;
}
+ public static PipeTransferTabletRawReqV2 toTPipeTransferRawReq(final ByteBuffer buffer) {
+ final PipeTransferTabletRawReqV2 tabletReq = new PipeTransferTabletRawReqV2();
+
+ tabletReq.deserializeTPipeTransferRawReq(buffer);
+ tabletReq.version = IoTDBSinkRequestVersion.VERSION_1.getVersion();
+ tabletReq.type = PipeRequestType.TRANSFER_TABLET_RAW_V2.getType();
+
+ return tabletReq;
+ }
+
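This buffer-based overload is what PipeTransferTabletBatchReqV2.fromTPipeTransferReq now calls once per batched element, as in the loop earlier in this patch:

batchReq.tabletReqs.add(PipeTransferTabletRawReqV2.toTPipeTransferRawReq(transferReq.body));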
/////////////////////////////// Thrift ///////////////////////////////
public static PipeTransferTabletRawReqV2 toTPipeTransferReq(
@@ -114,13 +135,11 @@ public static PipeTransferTabletRawReqV2 fromTPipeTransferReq(
final TPipeTransferReq transferReq) {
final PipeTransferTabletRawReqV2 tabletReq = new PipeTransferTabletRawReqV2();
- tabletReq.tablet = Tablet.deserialize(transferReq.body);
- tabletReq.isAligned = ReadWriteIOUtils.readBool(transferReq.body);
- tabletReq.dataBaseName = ReadWriteIOUtils.readString(transferReq.body);
+ tabletReq.deserializeTPipeTransferRawReq(transferReq.body);
+ tabletReq.body = transferReq.body;
tabletReq.version = transferReq.version;
tabletReq.type = transferReq.type;
- tabletReq.body = transferReq.body;
return tabletReq;
}
@@ -161,4 +180,27 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hash(super.hashCode(), dataBaseName);
}
+
+ /////////////////////////////// Util ///////////////////////////////
+
+ public void deserializeTPipeTransferRawReq(final ByteBuffer buffer) {
+ final int startPosition = buffer.position();
+ try {
+ // V2: read databaseName, readDatabaseName = true
+ final InsertTabletStatement insertTabletStatement =
+ TabletStatementConverter.deserializeStatementFromTabletFormat(buffer, true);
+ this.isAligned = insertTabletStatement.isAligned();
+ // databaseName is already set in deserializeStatementFromTabletFormat when
+ // readDatabaseName=true
+ this.dataBaseName = insertTabletStatement.getDatabaseName().orElse(null);
+ this.statement = insertTabletStatement;
+ } catch (final Exception e) {
+ // If Statement deserialization fails, fallback to Tablet format
+ // Reset buffer position for Tablet deserialization
+ buffer.position(startPosition);
+ this.tablet = Tablet.deserialize(buffer);
+ this.isAligned = ReadWriteIOUtils.readBool(buffer);
+ this.dataBaseName = ReadWriteIOUtils.readString(buffer);
+ }
+ }
}
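For reference, the legacy V2 tablet-format payload that the fallback branch parses is laid out as the serialized Tablet, then the aligned flag, then the database name. A sketch of a writer producing that layout, mirroring the read order of the fallback above (stream setup follows the pattern used elsewhere in this patch; error handling and the return of the bytes are omitted):

try (final PublicBAOS byteArrayOutputStream = new PublicBAOS();
    final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream)) {
  tablet.serialize(outputStream); // 1. the tablet itself
  ReadWriteIOUtils.write(isAligned, outputStream); // 2. aligned flag
  ReadWriteIOUtils.write(dataBaseName, outputStream); // 3. database name (V2 only)
}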
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverter.java
new file mode 100644
index 0000000000000..c5b9ebed4d5f2
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverter.java
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.pipe.sink.util;
+
+import org.apache.iotdb.commons.exception.IllegalPathException;
+import org.apache.iotdb.commons.path.PartialPath;
+import org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory;
+import org.apache.iotdb.db.pipe.resource.memory.InsertNodeMemoryEstimator;
+import org.apache.iotdb.db.queryengine.plan.analyze.cache.schema.DataNodeDevicePathCache;
+import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement;
+
+import org.apache.tsfile.enums.ColumnCategory;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.BitMap;
+import org.apache.tsfile.utils.BytesUtils;
+import org.apache.tsfile.utils.Pair;
+import org.apache.tsfile.utils.ReadWriteIOUtils;
+import org.apache.tsfile.write.UnSupportedDataTypeException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.nio.ByteBuffer;
+import java.util.Arrays;
+
+/**
+ * Utility class for converting between InsertTabletStatement and the Tablet-format ByteBuffer.
+ * Converting directly between the two formats avoids creating intermediate Tablet objects and
+ * materializes only the fields that are needed.
+ */
+public class TabletStatementConverter {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(TabletStatementConverter.class);
+
+ // Memory calculation constants - extracted from RamUsageEstimator for better performance
+ private static final long NUM_BYTES_ARRAY_HEADER =
+ org.apache.tsfile.utils.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
+ private static final long NUM_BYTES_OBJECT_REF =
+ org.apache.tsfile.utils.RamUsageEstimator.NUM_BYTES_OBJECT_REF;
+ private static final long NUM_BYTES_OBJECT_HEADER =
+ org.apache.tsfile.utils.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
+ private static final long SIZE_OF_ARRAYLIST =
+ org.apache.tsfile.utils.RamUsageEstimator.shallowSizeOfInstance(java.util.ArrayList.class);
+ private static final long SIZE_OF_BITMAP =
+ org.apache.tsfile.utils.RamUsageEstimator.shallowSizeOfInstance(
+ org.apache.tsfile.utils.BitMap.class);
+
+ private TabletStatementConverter() {
+ // Utility class, no instantiation
+ }
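+
+ // Worked example of the layout arithmetic used below, assuming a 64-bit JVM
+ // with compressed oops (16-byte array header, 4-byte references, 8-byte
+ // object alignment): a String[100] has a shallow size of
+ // alignObjectSize(16 + 4 * 100) = 416 bytes, before counting the strings it
+ // references.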
+
+ /**
+ * Deserialize InsertTabletStatement from Tablet format ByteBuffer.
+ *
+ * @param byteBuffer ByteBuffer containing serialized data
+ * @param readDatabaseName whether to read databaseName from buffer (for V2 format)
+ * @return InsertTabletStatement with all fields set, including devicePath
+ */
+ public static InsertTabletStatement deserializeStatementFromTabletFormat(
+ final ByteBuffer byteBuffer, final boolean readDatabaseName) throws IllegalPathException {
+ final InsertTabletStatement statement = new InsertTabletStatement();
+
+ // Calculate memory size during deserialization, starting from the INSTANCE_SIZE constant
+ long memorySize = InsertTabletStatement.getInstanceSize();
+
+ final String insertTargetName = ReadWriteIOUtils.readString(byteBuffer);
+
+ final int rowSize = ReadWriteIOUtils.readInt(byteBuffer);
+
+ // deserialize schemas
+ final int schemaSize =
+ BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer))
+ ? ReadWriteIOUtils.readInt(byteBuffer)
+ : 0;
+ final String[] measurement = new String[schemaSize];
+ final TsTableColumnCategory[] columnCategories = new TsTableColumnCategory[schemaSize];
+ final TSDataType[] dataTypes = new TSDataType[schemaSize];
+
+ // Calculate memory for arrays headers and references during deserialization
+ // measurements array: array header + object references
+ long measurementMemorySize =
+ org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize(
+ NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize);
+
+ // dataTypes array: shallow size (array header + references)
+ long dataTypesMemorySize =
+ org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize(
+ NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize);
+
+ // columnCategories array: shallow size (array header + references)
+ long columnCategoriesMemorySize =
+ org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize(
+ NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize);
+
+ // tagColumnIndices (TAG columns): ArrayList base + array header
+ long tagColumnIndicesSize = SIZE_OF_ARRAYLIST;
+ tagColumnIndicesSize += NUM_BYTES_ARRAY_HEADER;
+
+ // Deserialize and calculate memory in the same loop
+ for (int i = 0; i < schemaSize; i++) {
+ final boolean hasSchema = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer));
+ if (hasSchema) {
+ final Pair<String, TSDataType> pair = readMeasurement(byteBuffer);
+ measurement[i] = pair.getLeft();
+ dataTypes[i] = pair.getRight();
+ columnCategories[i] =
+ TsTableColumnCategory.fromTsFileColumnCategory(
+ ColumnCategory.values()[byteBuffer.get()]);
+
+ // Calculate memory for each measurement string
+ if (measurement[i] != null) {
+ measurementMemorySize += org.apache.tsfile.utils.RamUsageEstimator.sizeOf(measurement[i]);
+ }
+
+ // Calculate memory for TAG column indices
+ if (columnCategories[i] != null && columnCategories[i].equals(TsTableColumnCategory.TAG)) {
+ tagColumnIndicesSize +=
+ org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize(
+ Integer.BYTES + NUM_BYTES_OBJECT_HEADER)
+ + NUM_BYTES_OBJECT_REF;
+ }
+ }
+ }
+
+ // Add all calculated memory to total
+ memorySize += measurementMemorySize;
+ memorySize += dataTypesMemorySize;
+
+ // deserialize times and calculate memory during deserialization
+ final long[] times = new long[rowSize];
+ // Calculate memory: array header + long size * rowSize
+ final long timesMemorySize =
+ org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize(
+ NUM_BYTES_ARRAY_HEADER + (long) Long.BYTES * rowSize);
+
+ final boolean isTimesNotNull = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer));
+ if (isTimesNotNull) {
+ for (int i = 0; i < rowSize; i++) {
+ times[i] = ReadWriteIOUtils.readLong(byteBuffer);
+ }
+ }
+
+ // Add times memory to total
+ memorySize += timesMemorySize;
+
+ // deserialize bitmaps and calculate memory during deserialization
+ final BitMap[] bitMaps;
+ final long bitMapsMemorySize;
+
+ final boolean isBitMapsNotNull = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer));
+ if (isBitMapsNotNull) {
+ // Use the method that returns both BitMap array and memory size
+ final Pair<BitMap[], Long> bitMapsAndMemory =
+ readBitMapsFromBufferWithMemory(byteBuffer, schemaSize);
+ bitMaps = bitMapsAndMemory.getLeft();
+ bitMapsMemorySize = bitMapsAndMemory.getRight();
+ } else {
+ // Calculate memory for empty BitMap array: array header + references
+ bitMaps = new BitMap[schemaSize];
+ bitMapsMemorySize =
+ org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize(
+ NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize);
+ }
+
+ // Add bitMaps memory to total
+ memorySize += bitMapsMemorySize;
+
+ // Deserialize values and calculate memory during deserialization
+ final Object[] values;
+ final long valuesMemorySize;
+
+ final boolean isValuesNotNull = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer));
+ if (isValuesNotNull) {
+ // Use the method that returns both values array and memory size
+ final Pair