diff --git a/.github/workflows/cluster-it-1c1d1a.yml b/.github/workflows/cluster-it-1c1d1a.yml index d4c40fa7ad889..67be8f1a5f6c8 100644 --- a/.github/workflows/cluster-it-1c1d1a.yml +++ b/.github/workflows/cluster-it-1c1d1a.yml @@ -41,9 +41,6 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Build AINode - shell: bash - run: mvn clean package -DskipTests -P with-ainode - name: IT Test shell: bash run: | diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppDataNodeConfig.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppDataNodeConfig.java index 357c15c785697..5e418072a7d97 100644 --- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppDataNodeConfig.java +++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppDataNodeConfig.java @@ -137,4 +137,10 @@ public DataNodeConfig setDataNodeMemoryProportion(String dataNodeMemoryProportio setProperty("datanode_memory_proportion", dataNodeMemoryProportion); return this; } + + @Override + public DataNodeConfig setQueryCostStatWindow(int queryCostStatWindow) { + setProperty("query_cost_stat_window", String.valueOf(queryCostStatWindow)); + return this; + } } diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java index e118d6c3a98ff..34fd7e85240cf 100644 --- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java +++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java @@ -59,7 +59,7 @@ public class AINodeWrapper extends AbstractNodeWrapper { private static final String PROPERTIES_FILE = "iotdb-ainode.properties"; public static final String CONFIG_PATH = "conf"; public static final String SCRIPT_PATH = "sbin"; - public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/weights"; + public static final String 
BUILT_IN_MODEL_PATH = "data/ainode/models/builtin"; public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models/weights"; private void replaceAttribute(String[] keys, String[] values, String filePath) { diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteDataNodeConfig.java b/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteDataNodeConfig.java index 1af7cb8f613a8..bba4c964f9578 100644 --- a/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteDataNodeConfig.java +++ b/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteDataNodeConfig.java @@ -93,4 +93,9 @@ public DataNodeConfig setDeleteWalFilesPeriodInMs(long deleteWalFilesPeriodInMs) public DataNodeConfig setDataNodeMemoryProportion(String dataNodeMemoryProportion) { return this; } + + @Override + public DataNodeConfig setQueryCostStatWindow(int queryCostStatWindow) { + return this; + } } diff --git a/integration-test/src/main/java/org/apache/iotdb/itbase/env/DataNodeConfig.java b/integration-test/src/main/java/org/apache/iotdb/itbase/env/DataNodeConfig.java index 0ae46ffc70f27..d57015b13964e 100644 --- a/integration-test/src/main/java/org/apache/iotdb/itbase/env/DataNodeConfig.java +++ b/integration-test/src/main/java/org/apache/iotdb/itbase/env/DataNodeConfig.java @@ -51,4 +51,6 @@ DataNodeConfig setLoadTsFileAnalyzeSchemaMemorySizeInBytes( DataNodeConfig setDeleteWalFilesPeriodInMs(long deleteWalFilesPeriodInMs); DataNodeConfig setDataNodeMemoryProportion(String dataNodeMemoryProportion); + + DataNodeConfig setQueryCostStatWindow(int queryCostStatWindow); } diff --git a/integration-test/src/main/java/org/apache/iotdb/itbase/runtime/ClusterTestStatement.java b/integration-test/src/main/java/org/apache/iotdb/itbase/runtime/ClusterTestStatement.java index 3f96fdf1372fc..0523a3848289b 100644 --- a/integration-test/src/main/java/org/apache/iotdb/itbase/runtime/ClusterTestStatement.java +++ 
b/integration-test/src/main/java/org/apache/iotdb/itbase/runtime/ClusterTestStatement.java @@ -78,6 +78,13 @@ private void updateConfig(Statement statement, int timeout) throws SQLException statement.setQueryTimeout(timeout); } + /** + * Executes a SQL query on all read statements in parallel. + * + *

Note: For PreparedStatement EXECUTE queries, use the write connection directly instead, + * because PreparedStatements are session-scoped and this method may route queries to different + * nodes where the PreparedStatement doesn't exist. + */ @Override public ResultSet executeQuery(String sql) throws SQLException { return new ClusterTestResultSet(readStatements, readEndpoints, sql, queryTimeout); diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java new file mode 100644 index 0000000000000..44e280eca169b --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.ainode.it; + +import org.apache.iotdb.ainode.utils.AINodeTestUtils; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; + +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) +public class AINodeCallInferenceIT { + + private static final String[] WRITE_SQL_IN_TREE = + new String[] { + "CREATE DATABASE root.AI", + "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE", + }; + + private static final String CALL_INFERENCE_SQL_TEMPLATE = + "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)"; + private static final int DEFAULT_INPUT_LENGTH = 256; + private static final int DEFAULT_OUTPUT_LENGTH = 48; + + @BeforeClass + public static void setUp() throws Exception { + // Init 1C1D1A cluster environment + EnvFactory.getEnv().initClusterEnvironment(1, 1); + prepareData(WRITE_SQL_IN_TREE); + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + for (int i = 0; i < 2880; i++) { 
+ statement.execute( + String.format( + "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", + i, (float) i, (double) i, i, i)); + } + } + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void callInferenceTest() throws SQLException { + for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + callInferenceTest(statement, modelInfo); + } + } + } + + public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) + throws SQLException { + // Invoke call inference for specified models, there should exist result. + for (int i = 0; i < 4; i++) { + String callInferenceSQL = + String.format( + CALL_INFERENCE_SQL_TEMPLATE, + modelInfo.getModelId(), + i, + DEFAULT_INPUT_LENGTH, + DEFAULT_OUTPUT_LENGTH); + try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "Time,output"); + int count = 0; + while (resultSet.next()) { + count++; + } + // Ensure the call inference return results + Assert.assertEquals(DEFAULT_OUTPUT_LENGTH, count); + } + } + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java new file mode 100644 index 0000000000000..64029c1e34b8e --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.ainode.it; + +import org.apache.iotdb.ainode.utils.AINodeTestUtils; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference; + +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) +public class AINodeConcurrentForecastIT { + + private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class); + + private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, forecast_length=>%d)"; + + @BeforeClass + public static void 
setUp() throws Exception { + // Init 1C1D1A cluster environment + EnvFactory.getEnv().initClusterEnvironment(1, 1); + prepareDataForTableModel(); + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + private static void prepareDataForTableModel() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + statement.execute("CREATE DATABASE root"); + statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)"); + for (int i = 0; i < 2880; i++) { + statement.execute( + String.format( + "INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440))); + } + } + } + + @Test + public void concurrentGPUForecastTest() throws SQLException, InterruptedException { + for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_LTSM_MAP.values()) { + concurrentGPUForecastTest(modelInfo); + } + } + + public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo) + throws SQLException, InterruptedException { + final int forecastLength = 512; + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + // Single forecast request can be processed successfully + final String forecastSQL = + String.format( + FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), forecastLength); + final int threadCnt = 10; + final int loop = 100; + final String devices = "0,1"; + statement.execute( + String.format("LOAD MODEL %s TO DEVICES '%s'", modelInfo.getModelId(), devices)); + checkModelOnSpecifiedDevice(statement, modelInfo.getModelId(), devices); + long startTime = System.currentTimeMillis(); + concurrentInference(statement, forecastSQL, threadCnt, loop, forecastLength); + long endTime = System.currentTimeMillis(); + LOGGER.info( + String.format( + "Model %s concurrent inference %d reqs 
(%d threads, %d loops) in GPU takes time: %dms", + modelInfo.getModelId(), threadCnt * loop, threadCnt, loop, endTime - startTime)); + statement.execute( + String.format("UNLOAD MODEL %s FROM DEVICES '%s'", modelInfo.getModelId(), devices)); + checkModelNotOnSpecifiedDevice(statement, modelInfo.getModelId(), devices); + } + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java deleted file mode 100644 index a08990d472fe6..0000000000000 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.ainode.it; - -import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; -import org.apache.iotdb.itbase.env.BaseEnv; - -import com.google.common.collect.ImmutableSet; -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.HashSet; -import java.util.Set; -import java.util.concurrent.TimeUnit; - -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference; - -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) -public class AINodeConcurrentInferenceIT { - - private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentInferenceIT.class); - - @BeforeClass - public static void setUp() throws Exception { - // Init 1C1D1A cluster environment - EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareDataForTreeModel(); - prepareDataForTableModel(); - } - - @AfterClass - public static void tearDown() throws Exception { - EnvFactory.getEnv().cleanClusterEnvironment(); - } - - private static void prepareDataForTreeModel() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - statement.execute("CREATE DATABASE root.AI"); - statement.execute("CREATE TIMESERIES root.AI.s WITH DATATYPE=DOUBLE, ENCODING=RLE"); - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(timestamp, s) VALUES(%d, %f)", - i, Math.sin(i * Math.PI / 1440))); - } - } - } - - private static void prepareDataForTableModel() throws SQLException { - try 
(Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - statement.execute("CREATE DATABASE root"); - statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)"); - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440))); - } - } - } - - // @Test - public void concurrentGPUCallInferenceTest() throws SQLException, InterruptedException { - concurrentGPUCallInferenceTest("timer_xl"); - concurrentGPUCallInferenceTest("sundial"); - } - - private void concurrentGPUCallInferenceTest(String modelId) - throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - final int threadCnt = 10; - final int loop = 100; - final int predictLength = 512; - final String devices = "0,1"; - statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices)); - checkModelOnSpecifiedDevice(statement, modelId, devices); - concurrentInference( - statement, - String.format( - "CALL INFERENCE(%s, 'SELECT s FROM root.AI', predict_length=%d)", - modelId, predictLength), - threadCnt, - loop, - predictLength); - statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId)); - } - } - - String forecastTableFunctionSql = - "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d"; - String forecastUDTFSql = - "SELECT forecast(s, 'MODEL_ID'='%s', 'PREDICT_LENGTH'='%d') FROM root.AI"; - - @Test - public void concurrentGPUForecastTest() throws SQLException, InterruptedException { - concurrentGPUForecastTest("timer_xl", forecastUDTFSql); - concurrentGPUForecastTest("sundial", forecastUDTFSql); - concurrentGPUForecastTest("timer_xl", forecastTableFunctionSql); - concurrentGPUForecastTest("sundial", 
forecastTableFunctionSql); - } - - public void concurrentGPUForecastTest(String modelId, String selectSql) - throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - final int threadCnt = 10; - final int loop = 100; - final int predictLength = 512; - final String devices = "0,1"; - statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices)); - checkModelOnSpecifiedDevice(statement, modelId, devices); - long startTime = System.currentTimeMillis(); - concurrentInference( - statement, - String.format(selectSql, modelId, predictLength), - threadCnt, - loop, - predictLength); - long endTime = System.currentTimeMillis(); - LOGGER.info( - String.format( - "Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms", - modelId, threadCnt * loop, threadCnt, loop, endTime - startTime)); - statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId)); - } - } - - private void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device) - throws SQLException, InterruptedException { - Set targetDevices = ImmutableSet.copyOf(device.split(",")); - LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices); - for (int retry = 0; retry < 200; retry++) { - Set foundDevices = new HashSet<>(); - try (final ResultSet resultSet = - statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { - while (resultSet.next()) { - String deviceId = resultSet.getString("DeviceId"); - String loadedModelId = resultSet.getString("ModelId"); - int count = resultSet.getInt("Count(instances)"); - LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count); - if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) { - foundDevices.add(deviceId); - LOGGER.info("Model {} is loaded to device {}", 
modelId, device); - } - } - if (foundDevices.containsAll(targetDevices)) { - LOGGER.info("Model {} is loaded to devices {}, start testing", modelId, targetDevices); - return; - } - } - TimeUnit.SECONDS.sleep(3); - } - Assert.fail("Model " + modelId + " is not loaded on device " + device); - } -} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java new file mode 100644 index 0000000000000..a06656d4adace --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.ainode.it; + +import org.apache.iotdb.ainode.utils.AINodeTestUtils; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; + +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) +public class AINodeForecastIT { + + private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM db.AI) ORDER BY time)"; + + @BeforeClass + public static void setUp() throws Exception { + // Init 1C1D1A cluster environment + EnvFactory.getEnv().initClusterEnvironment(1, 1); + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + statement.execute("CREATE DATABASE db"); + statement.execute( + "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)"); + for (int i = 0; i < 2880; i++) { + statement.execute( + String.format( + "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", + i, (float) i, (double) i, i, i)); + } + } + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void forecastTableFunctionTest() throws SQLException { + for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = 
connection.createStatement()) { + forecastTableFunctionTest(statement, modelInfo); + } + } + } + + public void forecastTableFunctionTest( + Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException { + // Invoke call inference for specified models, there should exist result. + for (int i = 0; i < 4; i++) { + String forecastTableFunctionSQL = + String.format(FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), i); + try (ResultSet resultSet = statement.executeQuery(forecastTableFunctionSQL)) { + int count = 0; + while (resultSet.next()) { + count++; + } + // Ensure the call inference return results + Assert.assertTrue(count > 0); + } + } + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java deleted file mode 100644 index 70f7a1d9f9eb7..0000000000000 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java +++ /dev/null @@ -1,344 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.ainode.it; - -import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; -import org.apache.iotdb.itbase.env.BaseEnv; - -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.SQLException; -import java.sql.Statement; - -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.EXAMPLE_MODEL_PATH; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; -import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; -import static org.apache.iotdb.db.it.utils.TestUtils.prepareTableData; -import static org.junit.Assert.assertEquals; - -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) -public class AINodeInferenceSQLIT { - - static String[] WRITE_SQL_IN_TREE = - new String[] { - "set configuration \"trusted_uri_pattern\"='.*'", - "create model identity using uri \"" + EXAMPLE_MODEL_PATH + "\"", - "CREATE DATABASE root.AI", - "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE", - }; - - static String[] WRITE_SQL_IN_TABLE = - new String[] { - "CREATE DATABASE root", - "CREATE TABLE root.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)", - }; - - @BeforeClass - public static void setUp() throws Exception { - // Init 1C1D1A cluster environment - EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareData(WRITE_SQL_IN_TREE); - prepareTableData(WRITE_SQL_IN_TABLE); - try (Connection 
connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", - i, (float) i, (double) i, i, i)); - } - } - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", - i, (float) i, (double) i, i, i)); - } - } - } - - @AfterClass - public static void tearDown() throws Exception { - EnvFactory.getEnv().cleanClusterEnvironment(); - } - - // @Test - public void callInferenceTestInTree() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - callInferenceTest(statement); - } - } - - // TODO: Enable this test after the call inference is supported by the table model - // @Test - public void callInferenceTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - callInferenceTest(statement); - } - } - - public void callInferenceTest(Statement statement) throws SQLException { - // SQL0: Invoke timer-sundial and timer-xl to inference, the result should success - try (ResultSet resultSet = - statement.executeQuery( - "CALL INFERENCE(sundial, \"select s1 from root.AI\", generateTime=true, predict_length=720)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "Time,output0"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(720, count); - } - try (ResultSet resultSet = - statement.executeQuery( - "CALL INFERENCE(timer_xl, 
\"select s2 from root.AI\", generateTime=true, predict_length=256)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "Time,output0"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(256, count); - } - // SQL1: user-defined model inferences multi-columns with generateTime=true - String sql1 = - "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI\", generateTime=true)"; - // SQL2: user-defined model inferences multi-columns with generateTime=false - String sql2 = - "CALL INFERENCE(identity, \"select s2,s0,s3,s1 from root.AI\", generateTime=false)"; - // SQL3: built-in model inferences single column with given predict_length and multi-outputs - String sql3 = - "CALL INFERENCE(naive_forecaster, \"select s0 from root.AI\", predict_length=3, generateTime=true)"; - // SQL4: built-in model inferences single column with given predict_length - String sql4 = - "CALL INFERENCE(holtwinters, \"select s0 from root.AI\", predict_length=6, generateTime=true)"; - // TODO: enable following tests after refactor the CALL INFERENCE - - // try (ResultSet resultSet = statement.executeQuery(sql1)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "Time,output0,output1,output2,output3"); - // int count = 0; - // while (resultSet.next()) { - // float s0 = resultSet.getFloat(2); - // float s1 = resultSet.getFloat(3); - // float s2 = resultSet.getFloat(4); - // float s3 = resultSet.getFloat(5); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - // - // try (ResultSet resultSet = statement.executeQuery(sql2)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "output0,output1,output2"); - // int count = 0; - // 
while (resultSet.next()) { - // float s2 = resultSet.getFloat(1); - // float s0 = resultSet.getFloat(2); - // float s3 = resultSet.getFloat(3); - // float s1 = resultSet.getFloat(4); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql3)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "Time,output0,output1,output2"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(3, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql4)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "Time,output0"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(6, count); - // } - } - - // @Test - public void errorCallInferenceTestInTree() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - errorCallInferenceTest(statement); - } - } - - // TODO: Enable this test after the call inference is supported by the table model - // @Test - public void errorCallInferenceTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - errorCallInferenceTest(statement); - } - } - - public void errorCallInferenceTest(Statement statement) { - String sql = "CALL INFERENCE(notFound404, \"select s0,s1,s2 from root.AI\", window=head(5))"; - errorTest(statement, sql, "1505: model [notFound404] has not been created."); - sql = "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI\", window=head(2))"; - // 
TODO: enable following tests after refactor the CALL INFERENCE - // errorTest(statement, sql, "701: Window output 2 is not equal to input size of model 7"); - sql = "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI limit 5\")"; - // errorTest( - // statement, - // sql, - // "301: The number of rows 5 in the input data does not match the model input 7. Try to - // use LIMIT in SQL or WINDOW in CALL INFERENCE"); - sql = "CREATE MODEL 中文 USING URI \"" + EXAMPLE_MODEL_PATH + "\""; - errorTest(statement, sql, "701: ModelId can only contain letters, numbers, and underscores"); - } - - @Test - public void selectForecastTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - // SQL0: Invoke timer-sundial and timer-xl to forecast, the result should success - try (ResultSet resultSet = - statement.executeQuery( - "SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s1 FROM root.AI) ORDER BY time, output_length=>720)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "time,s1"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(720, count); - } - try (ResultSet resultSet = - statement.executeQuery( - "SELECT * FROM FORECAST(model_id=>'timer_xl', input=>(SELECT time,s2 FROM root.AI) ORDER BY time, output_length=>256)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "time,s2"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(256, count); - } - // SQL1: user-defined model inferences multi-columns with generateTime=true - String sql1 = - "SELECT * FROM FORECAST(model_id=>'identity', input=>(SELECT time,s0,s1,s2,s3 FROM root.AI) ORDER BY time, output_length=>7)"; - // SQL2: user-defined model inferences multi-columns with generateTime=false - String sql2 = - "SELECT * FROM 
FORECAST(model_id=>'identity', input=>(SELECT time,s2,s0,s3,s1 FROM root.AI) ORDER BY time, output_length=>7)"; - // SQL3: built-in model inferences single column with given predict_length and multi-outputs - String sql3 = - "SELECT * FROM FORECAST(model_id=>'naive_forecaster', input=>(SELECT time,s0 FROM root.AI) ORDER BY time, output_length=>3)"; - // SQL4: built-in model inferences single column with given predict_length - String sql4 = - "SELECT * FROM FORECAST(model_id=>'holtwinters', input=>(SELECT time,s0 FROM root.AI) ORDER BY time, output_length=>6)"; - // TODO: enable following tests after refactor the FORECAST - // try (ResultSet resultSet = statement.executeQuery(sql1)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s0,s1,s2,s3"); - // int count = 0; - // while (resultSet.next()) { - // float s0 = resultSet.getFloat(2); - // float s1 = resultSet.getFloat(3); - // float s2 = resultSet.getFloat(4); - // float s3 = resultSet.getFloat(5); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - // - // try (ResultSet resultSet = statement.executeQuery(sql2)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s2,s0,s3,s1"); - // int count = 0; - // while (resultSet.next()) { - // float s2 = resultSet.getFloat(1); - // float s0 = resultSet.getFloat(2); - // float s3 = resultSet.getFloat(3); - // float s1 = resultSet.getFloat(4); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql3)) { - // ResultSetMetaData 
resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s0,s1,s2"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(3, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql4)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s0"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(6, count); - // } - } - } -} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java index 93351c0178526..2ae1b860cd230 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java @@ -35,14 +35,14 @@ import java.util.Arrays; import java.util.HashSet; import java.util.Set; -import java.util.concurrent.TimeUnit; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; public class AINodeInstanceManagementIT { - private static final int WAITING_TIME_SEC = 30; private static final Set TARGET_DEVICES = new HashSet<>(Arrays.asList("cpu", "0", "1")); @BeforeClass @@ -85,52 +85,18 @@ private void basicManagementTest(Statement statement) throws SQLException, Inter } // Load sundial to each device - statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\""); - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC); - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS 0")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - 
checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - while (resultSet.next()) { - Assert.assertEquals("0", resultSet.getString("DeviceID")); - Assert.assertEquals("Timer-Sundial", resultSet.getString("ModelType")); - Assert.assertTrue(resultSet.getInt("Count(instances)") > 1); - } - } - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - final Set resultDevices = new HashSet<>(); - while (resultSet.next()) { - resultDevices.add(resultSet.getString("DeviceID")); - } - Assert.assertEquals(TARGET_DEVICES, resultDevices); - } + statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES)); + checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); // Load timer_xl to each device - statement.execute("LOAD MODEL timer_xl TO DEVICES \"cpu,0,1\""); - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC); - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - final Set resultDevices = new HashSet<>(); - while (resultSet.next()) { - if (resultSet.getString("ModelType").equals("Timer-XL")) { - resultDevices.add(resultSet.getString("DeviceID")); - } - Assert.assertTrue(resultSet.getInt("Count(instances)") > 1); - } - Assert.assertEquals(TARGET_DEVICES, resultDevices); - } + statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES)); + checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString()); // Clean every device - statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\""); - statement.execute("UNLOAD MODEL timer_xl FROM DEVICES \"cpu,0,1\""); - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC); - try (ResultSet resultSet = statement.executeQuery("SHOW 
LOADED MODELS")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - Assert.assertFalse(resultSet.next()); - } + statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES)); + statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES)); + checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES.toString()); + checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); } private static final int LOOP_CNT = 10; @@ -141,23 +107,9 @@ public void repeatLoadAndUnloadTest() throws SQLException, InterruptedException Statement statement = connection.createStatement()) { for (int i = 0; i < LOOP_CNT; i++) { statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\""); - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC); - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - final Set resultDevices = new HashSet<>(); - while (resultSet.next()) { - resultDevices.add(resultSet.getString("DeviceID")); - } - Assert.assertEquals(TARGET_DEVICES, resultDevices); - } + checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\""); - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC); - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - Assert.assertFalse(resultSet.next()); - } + checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); } } } @@ -170,12 +122,7 @@ public void concurrentLoadAndUnloadTest() throws SQLException, InterruptedExcept statement.execute("LOAD MODEL sundial TO DEVICES 
\"cpu,0,1\""); statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\""); } - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC * LOOP_CNT); - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - Assert.assertFalse(resultSet.next()); - } + checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES.toString()); } } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java index 2a1461e4a15b6..b92b80aecf321 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java @@ -19,6 +19,7 @@ package org.apache.iotdb.ainode.it; +import org.apache.iotdb.ainode.utils.AINodeTestUtils; import org.apache.iotdb.ainode.utils.AINodeTestUtils.FakeModelInfo; import org.apache.iotdb.it.env.EnvFactory; import org.apache.iotdb.it.framework.IoTDBTestRunner; @@ -36,13 +37,8 @@ import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; -import java.util.AbstractMap; -import java.util.Map; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.EXAMPLE_MODEL_PATH; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; import static org.junit.Assert.assertEquals; @@ -54,36 +50,6 @@ @Category({AIClusterIT.class}) public class AINodeModelManageIT { - private static final Map BUILT_IN_MODEL_MAP = - Stream.of( - new AbstractMap.SimpleEntry<>( - "arima", new FakeModelInfo("arima", "Arima", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - 
"holtwinters", - new FakeModelInfo("holtwinters", "HoltWinters", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "exponential_smoothing", - new FakeModelInfo( - "exponential_smoothing", "ExponentialSmoothing", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "naive_forecaster", - new FakeModelInfo("naive_forecaster", "NaiveForecaster", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "stl_forecaster", - new FakeModelInfo("stl_forecaster", "StlForecaster", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "gaussian_hmm", - new FakeModelInfo("gaussian_hmm", "GaussianHmm", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "gmm_hmm", new FakeModelInfo("gmm_hmm", "GmmHmm", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "stray", new FakeModelInfo("stray", "Stray", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "sundial", new FakeModelInfo("sundial", "Timer-Sundial", "BUILT-IN", "ACTIVE")), - new AbstractMap.SimpleEntry<>( - "timer_xl", new FakeModelInfo("timer_xl", "Timer-XL", "BUILT-IN", "ACTIVE"))) - .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); - @BeforeClass public static void setUp() throws Exception { // Init 1C1D1A cluster environment @@ -95,7 +61,7 @@ public static void tearDown() throws Exception { EnvFactory.getEnv().cleanClusterEnvironment(); } - @Test + // @Test public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { @@ -103,7 +69,7 @@ public void userDefinedModelManagementTestInTree() throws SQLException, Interrup } } - @Test + // @Test public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { @@ -114,8 
+80,7 @@ public void userDefinedModelManagementTestInTable() throws SQLException, Interru private void userDefinedModelManagementTest(Statement statement) throws SQLException, InterruptedException { final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'"; - final String registerSql = - "create model operationTest using uri \"" + EXAMPLE_MODEL_PATH + "\""; + final String registerSql = "create model operationTest using uri \"" + "\""; final String showSql = "SHOW MODELS operationTest"; final String dropSql = "DROP MODEL operationTest"; @@ -166,7 +131,7 @@ private void userDefinedModelManagementTest(Statement statement) public void dropBuiltInModelErrorTestInTree() throws SQLException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { - errorTest(statement, "drop model sundial", "1501: Built-in model sundial can't be removed"); + errorTest(statement, "drop model sundial", "1510: Cannot delete built-in model: sundial"); } } @@ -174,7 +139,7 @@ public void dropBuiltInModelErrorTestInTree() throws SQLException { public void dropBuiltInModelErrorTestInTable() throws SQLException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { - errorTest(statement, "drop model sundial", "1501: Built-in model sundial can't be removed"); + errorTest(statement, "drop model sundial", "1510: Cannot delete built-in model: sundial"); } } @@ -208,10 +173,10 @@ private void showBuiltInModelTest(Statement statement) throws SQLException { resultSet.getString(2), resultSet.getString(3), resultSet.getString(4)); - assertTrue(BUILT_IN_MODEL_MAP.containsKey(modelInfo.getModelId())); - assertEquals(BUILT_IN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo); + assertTrue(AINodeTestUtils.BUILTIN_MODEL_MAP.containsKey(modelInfo.getModelId())); + 
assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo); } } - assertEquals(BUILT_IN_MODEL_MAP.size(), built_in_model_count); + assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.size(), built_in_model_count); } } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index cbb0b03b22997..0de90c42925fe 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -19,30 +19,68 @@ package org.apache.iotdb.ainode.utils; -import java.io.File; +import com.google.common.collect.ImmutableSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; +import java.util.AbstractMap; +import java.util.Collections; +import java.util.HashSet; +import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; public class AINodeTestUtils { - public static final String EXAMPLE_MODEL_PATH = - "file://" - + System.getProperty("user.dir") - + File.separator - + "src" - + File.separator - + "test" - + File.separator - + "resources" - + File.separator - + "ainode-example"; + public static final Map BUILTIN_LTSM_MAP = + Stream.of( + new AbstractMap.SimpleEntry<>( + "sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active"))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + public static final Map BUILTIN_MODEL_MAP; + + static { + Map tmp = + Stream.of( + new 
AbstractMap.SimpleEntry<>( + "arima", new FakeModelInfo("arima", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "holtwinters", new FakeModelInfo("holtwinters", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "exponential_smoothing", + new FakeModelInfo("exponential_smoothing", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "naive_forecaster", + new FakeModelInfo("naive_forecaster", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "stl_forecaster", + new FakeModelInfo("stl_forecaster", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "gaussian_hmm", + new FakeModelInfo("gaussian_hmm", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "gmm_hmm", new FakeModelInfo("gmm_hmm", "sktime", "builtin", "active")), + new AbstractMap.SimpleEntry<>( + "stray", new FakeModelInfo("stray", "sktime", "builtin", "active"))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + tmp.putAll(BUILTIN_LTSM_MAP); + BUILTIN_MODEL_MAP = Collections.unmodifiableMap(tmp); + } + + private static final Logger LOGGER = LoggerFactory.getLogger(AINodeTestUtils.class); public static void checkHeader(ResultSetMetaData resultSetMetaData, String title) throws SQLException { @@ -94,6 +132,63 @@ public static void concurrentInference( } } + public static void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device) + throws SQLException, InterruptedException { + Set targetDevices = ImmutableSet.copyOf(device.split(",")); + LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices); + for (int retry = 0; retry < 200; retry++) { + Set foundDevices = new HashSet<>(); + try (final ResultSet resultSet = + statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { + while (resultSet.next()) { + String deviceId = resultSet.getString("DeviceId"); + String loadedModelId = resultSet.getString("ModelId"); + int count = 
resultSet.getInt("Count(instances)"); + LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count); + if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) { + foundDevices.add(deviceId); + LOGGER.info("Model {} is loaded to device {}", modelId, device); + } + } + if (foundDevices.containsAll(targetDevices)) { + LOGGER.info("Model {} is loaded to devices {}, start testing", modelId, targetDevices); + return; + } + } + TimeUnit.SECONDS.sleep(3); + } + fail("Model " + modelId + " is not loaded on device " + device); + } + + public static void checkModelNotOnSpecifiedDevice( + Statement statement, String modelId, String device) + throws SQLException, InterruptedException { + Set targetDevices = ImmutableSet.copyOf(device.split(",")); + LOGGER.info("Checking model: {} not on target devices: {}", modelId, targetDevices); + for (int retry = 0; retry < 50; retry++) { + Set foundDevices = new HashSet<>(); + try (final ResultSet resultSet = + statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { + while (resultSet.next()) { + String deviceId = resultSet.getString("DeviceId"); + String loadedModelId = resultSet.getString("ModelId"); + int count = resultSet.getInt("Count(instances)"); + LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count); + if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) { + foundDevices.add(deviceId); + LOGGER.info("Model {} is loaded to device {}", modelId, device); + } + } + if (foundDevices.isEmpty()) { + LOGGER.info("Model {} is unloaded from devices {}.", modelId, targetDevices); + return; + } + } + TimeUnit.SECONDS.sleep(3); + } + fail("Model " + modelId + " is still loaded on device " + device); + } + public static class FakeModelInfo { private final String modelId; diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPreparedStatementIT.java 
b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPreparedStatementIT.java new file mode 100644 index 0000000000000..f06d46201aff2 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPreparedStatementIT.java @@ -0,0 +1,385 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.relational.it.db.it; + +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.TableClusterIT; +import org.apache.iotdb.itbase.category.TableLocalStandaloneIT; +import org.apache.iotdb.itbase.runtime.ClusterTestConnection; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.db.it.utils.TestUtils.tableResultSetEqualTest; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@RunWith(IoTDBTestRunner.class) +@Category({TableLocalStandaloneIT.class, TableClusterIT.class}) +public class IoTDBPreparedStatementIT { + private static final String DATABASE_NAME = "test"; + private static final String[] sqls = + new String[] { + "CREATE DATABASE " + DATABASE_NAME, + "USE " + DATABASE_NAME, + "CREATE TABLE test_table(id INT64 FIELD, name STRING FIELD, value DOUBLE FIELD)", + "INSERT INTO test_table VALUES (2025-01-01T00:00:00, 1, 'Alice', 100.5)", + "INSERT INTO test_table VALUES (2025-01-01T00:01:00, 2, 'Bob', 200.3)", + "INSERT INTO test_table VALUES (2025-01-01T00:02:00, 3, 'Charlie', 300.7)", + "INSERT INTO test_table VALUES (2025-01-01T00:03:00, 4, 'David', 400.2)", + "INSERT INTO test_table VALUES (2025-01-01T00:04:00, 5, 'Eve', 500.9)", + }; + + protected static void insertData() { + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + for (String sql : sqls) { + statement.execute(sql); + } + } catch (Exception e) { + fail("insertData failed: " + 
e.getMessage()); + } + } + + @BeforeClass + public static void setUp() { + EnvFactory.getEnv().initClusterEnvironment(); + insertData(); + } + + @AfterClass + public static void tearDown() { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + /** + * Execute a prepared statement query and verify the result. For PreparedStatement EXECUTE + * queries, use the write connection directly instead of tableResultSetEqualTest, because + * PreparedStatements are session-scoped and tableResultSetEqualTest may route queries to + * different nodes where the PreparedStatement doesn't exist. + */ + private static void executePreparedStatementAndVerify( + Connection connection, + Statement statement, + String executeSql, + String[] expectedHeader, + String[] expectedRetArray) + throws SQLException { + // Execute with parameters using write connection directly + // In cluster test, we need to use write connection to ensure same session + if (connection instanceof ClusterTestConnection) { + // Use write connection directly for PreparedStatement queries + try (Statement writeStatement = + ((ClusterTestConnection) connection) + .writeConnection + .getUnderlyingConnection() + .createStatement(); + ResultSet resultSet = writeStatement.executeQuery(executeSql)) { + ResultSetMetaData metaData = resultSet.getMetaData(); + + // Verify header + assertEquals(expectedHeader.length, metaData.getColumnCount()); + for (int i = 1; i <= metaData.getColumnCount(); i++) { + assertEquals(expectedHeader[i - 1], metaData.getColumnName(i)); + } + + // Verify data + int cnt = 0; + while (resultSet.next()) { + StringBuilder builder = new StringBuilder(); + for (int i = 1; i <= expectedHeader.length; i++) { + builder.append(resultSet.getString(i)).append(","); + } + assertEquals(expectedRetArray[cnt], builder.toString()); + cnt++; + } + assertEquals(expectedRetArray.length, cnt); + } + } else { + try (ResultSet resultSet = statement.executeQuery(executeSql)) { + ResultSetMetaData metaData = 
resultSet.getMetaData(); + + // Verify header + assertEquals(expectedHeader.length, metaData.getColumnCount()); + for (int i = 1; i <= metaData.getColumnCount(); i++) { + assertEquals(expectedHeader[i - 1], metaData.getColumnName(i)); + } + + // Verify data + int cnt = 0; + while (resultSet.next()) { + StringBuilder builder = new StringBuilder(); + for (int i = 1; i <= expectedHeader.length; i++) { + builder.append(resultSet.getString(i)).append(","); + } + assertEquals(expectedRetArray[cnt], builder.toString()); + cnt++; + } + assertEquals(expectedRetArray.length, cnt); + } + } + } + + @Test + public void testPrepareAndExecute() { + String[] expectedHeader = new String[] {"time", "id", "name", "value"}; + String[] retArray = new String[] {"2025-01-01T00:01:00.000Z,2,Bob,200.3,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare a statement + statement.execute("PREPARE stmt1 FROM SELECT * FROM test_table WHERE id = ?"); + // Execute with parameter using write connection directly + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt1 USING 2", expectedHeader, retArray); + // Deallocate + statement.execute("DEALLOCATE PREPARE stmt1"); + } catch (SQLException e) { + fail("testPrepareAndExecute failed: " + e.getMessage()); + } + } + + @Test + public void testPrepareAndExecuteMultipleTimes() { + String[] expectedHeader = new String[] {"time", "id", "name", "value"}; + String[] retArray1 = new String[] {"2025-01-01T00:00:00.000Z,1,Alice,100.5,"}; + String[] retArray2 = new String[] {"2025-01-01T00:02:00.000Z,3,Charlie,300.7,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare a statement + statement.execute("PREPARE stmt2 FROM SELECT * FROM test_table WHERE id = ?"); + // 
Execute multiple times with different parameters using write connection directly + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt2 USING 1", expectedHeader, retArray1); + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt2 USING 3", expectedHeader, retArray2); + // Deallocate + statement.execute("DEALLOCATE PREPARE stmt2"); + } catch (SQLException e) { + fail("testPrepareAndExecuteMultipleTimes failed: " + e.getMessage()); + } + } + + @Test + public void testPrepareWithMultipleParameters() { + String[] expectedHeader = new String[] {"time", "id", "name", "value"}; + String[] retArray = new String[] {"2025-01-01T00:01:00.000Z,2,Bob,200.3,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare a statement with multiple parameters + statement.execute("PREPARE stmt3 FROM SELECT * FROM test_table WHERE id = ? AND value > ?"); + // Execute with multiple parameters using write connection directly + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt3 USING 2, 150.0", expectedHeader, retArray); + // Deallocate + statement.execute("DEALLOCATE PREPARE stmt3"); + } catch (SQLException e) { + fail("testPrepareWithMultipleParameters failed: " + e.getMessage()); + } + } + + @Test + public void testExecuteImmediate() { + String[] expectedHeader = new String[] {"time", "id", "name", "value"}; + String[] retArray = new String[] {"2025-01-01T00:03:00.000Z,4,David,400.2,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Execute immediate with SQL string and parameters + tableResultSetEqualTest( + "EXECUTE IMMEDIATE 'SELECT * FROM test_table WHERE id = ?' 
USING 4", + expectedHeader, + retArray, + DATABASE_NAME); + } catch (SQLException e) { + fail("testExecuteImmediate failed: " + e.getMessage()); + } + } + + @Test + public void testExecuteImmediateWithoutParameters() { + String[] expectedHeader = new String[] {"_col0"}; + String[] retArray = new String[] {"5,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Execute immediate without parameters + tableResultSetEqualTest( + "EXECUTE IMMEDIATE 'SELECT COUNT(*) FROM test_table'", + expectedHeader, + retArray, + DATABASE_NAME); + } catch (SQLException e) { + fail("testExecuteImmediateWithoutParameters failed: " + e.getMessage()); + } + } + + @Test + public void testExecuteImmediateWithMultipleParameters() { + String[] expectedHeader = new String[] {"time", "id", "name", "value"}; + String[] retArray = new String[] {"2025-01-01T00:04:00.000Z,5,Eve,500.9,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Execute immediate with multiple parameters + tableResultSetEqualTest( + "EXECUTE IMMEDIATE 'SELECT * FROM test_table WHERE id = ? AND value > ?' 
USING 5, 450.0", + expectedHeader, + retArray, + DATABASE_NAME); + } catch (SQLException e) { + fail("testExecuteImmediateWithMultipleParameters failed: " + e.getMessage()); + } + } + + @Test + public void testDeallocateNonExistentStatement() { + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Try to deallocate a non-existent statement + SQLException exception = + assertThrows( + SQLException.class, () -> statement.execute("DEALLOCATE PREPARE non_existent_stmt")); + assertTrue( + exception.getMessage().contains("does not exist") + || exception.getMessage().contains("Prepared statement")); + } catch (SQLException e) { + fail("testDeallocateNonExistentStatement failed: " + e.getMessage()); + } + } + + @Test + public void testExecuteNonExistentStatement() { + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Try to execute a non-existent statement + SQLException exception = + assertThrows( + SQLException.class, () -> statement.execute("EXECUTE non_existent_stmt USING 1")); + assertTrue( + exception.getMessage().contains("does not exist") + || exception.getMessage().contains("Prepared statement")); + } catch (SQLException e) { + fail("testExecuteNonExistentStatement failed: " + e.getMessage()); + } + } + + @Test + public void testMultiplePreparedStatements() { + String[] expectedHeader1 = new String[] {"time", "id", "name", "value"}; + String[] retArray1 = new String[] {"2025-01-01T00:00:00.000Z,1,Alice,100.5,"}; + String[] expectedHeader2 = new String[] {"_col0"}; + String[] retArray2 = new String[] {"4,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare multiple statements + 
statement.execute("PREPARE stmt4 FROM SELECT * FROM test_table WHERE id = ?"); + statement.execute("PREPARE stmt5 FROM SELECT COUNT(*) FROM test_table WHERE value > ?"); + // Execute both statements using write connection directly + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt4 USING 1", expectedHeader1, retArray1); + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt5 USING 200.0", expectedHeader2, retArray2); + // Deallocate both + statement.execute("DEALLOCATE PREPARE stmt4"); + statement.execute("DEALLOCATE PREPARE stmt5"); + } catch (SQLException e) { + fail("testMultiplePreparedStatements failed: " + e.getMessage()); + } + } + + @Test + public void testPrepareDuplicateName() { + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare a statement + statement.execute("PREPARE stmt6 FROM SELECT * FROM test_table WHERE id = ?"); + // Try to prepare another statement with the same name + SQLException exception = + assertThrows( + SQLException.class, + () -> statement.execute("PREPARE stmt6 FROM SELECT * FROM test_table WHERE id = ?")); + assertTrue( + exception.getMessage().contains("already exists") + || exception.getMessage().contains("Prepared statement")); + // Cleanup + statement.execute("DEALLOCATE PREPARE stmt6"); + } catch (SQLException e) { + fail("testPrepareDuplicateName failed: " + e.getMessage()); + } + } + + @Test + public void testPrepareAndExecuteWithAggregation() { + String[] expectedHeader = new String[] {"_col0"}; + String[] retArray = new String[] {"300.40000000000003,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare a statement with aggregation + statement.execute( + "PREPARE stmt7 FROM SELECT AVG(value) FROM test_table WHERE 
id >= ? AND id <= ?"); + // Execute with parameters using write connection directly + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt7 USING 2, 4", expectedHeader, retArray); + // Deallocate + statement.execute("DEALLOCATE PREPARE stmt7"); + } catch (SQLException e) { + fail("testPrepareAndExecuteWithAggregation failed: " + e.getMessage()); + } + } + + @Test + public void testPrepareAndExecuteWithStringParameter() { + String[] expectedHeader = new String[] {"time", "id", "name", "value"}; + String[] retArray = new String[] {"2025-01-01T00:02:00.000Z,3,Charlie,300.7,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare a statement with string parameter + statement.execute("PREPARE stmt8 FROM SELECT * FROM test_table WHERE name = ?"); + // Execute with string parameter using write connection directly + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt8 USING 'Charlie'", expectedHeader, retArray); + // Deallocate + statement.execute("DEALLOCATE PREPARE stmt8"); + } catch (SQLException e) { + fail("testPrepareAndExecuteWithStringParameter failed: " + e.getMessage()); + } + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBCurrentQueriesIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBCurrentQueriesIT.java new file mode 100644 index 0000000000000..66941ec784f81 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBCurrentQueriesIT.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.relational.it.query.recent.informationschema; + +import org.apache.iotdb.commons.conf.CommonDescriptor; +import org.apache.iotdb.db.queryengine.execution.QueryState; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.TableLocalStandaloneIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.END_TIME_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.NUMS; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATEMENT_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATE_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.USER_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.table.InformationSchema.getSchemaTables; +import static org.apache.iotdb.db.it.utils.TestUtils.createUser; +import static 
org.apache.iotdb.itbase.env.BaseEnv.TABLE_SQL_DIALECT; +import static org.junit.Assert.fail; + +@RunWith(IoTDBTestRunner.class) +@Category({TableLocalStandaloneIT.class}) +// This IT will run at least 60s, so we only run it in 1C1D +public class IoTDBCurrentQueriesIT { + private static final int CURRENT_QUERIES_COLUMN_NUM = + getSchemaTables().get("current_queries").getColumnNum(); + private static final int QUERIES_COSTS_HISTOGRAM_COLUMN_NUM = + getSchemaTables().get("queries_costs_histogram").getColumnNum(); + private static final String ADMIN_NAME = + CommonDescriptor.getInstance().getConfig().getDefaultAdminName(); + private static final String ADMIN_PWD = + CommonDescriptor.getInstance().getConfig().getAdminPassword(); + + @BeforeClass + public static void setUp() throws Exception { + EnvFactory.getEnv().getConfig().getDataNodeConfig().setQueryCostStatWindow(1); + EnvFactory.getEnv().initClusterEnvironment(); + createUser("test", "test123123456"); + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void testCurrentQueries() { + try { + Connection connection = + EnvFactory.getEnv().getConnection(ADMIN_NAME, ADMIN_PWD, BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement(); + statement.execute("USE information_schema"); + + // 1. 
query current_queries table + String sql = "SELECT * FROM current_queries"; + ResultSet resultSet = statement.executeQuery(sql); + ResultSetMetaData metaData = resultSet.getMetaData(); + Assert.assertEquals(CURRENT_QUERIES_COLUMN_NUM, metaData.getColumnCount()); + int rowNum = 0; + while (resultSet.next()) { + Assert.assertEquals(QueryState.RUNNING.name(), resultSet.getString(STATE_TABLE_MODEL)); + Assert.assertEquals(null, resultSet.getString(END_TIME_TABLE_MODEL)); + Assert.assertEquals(sql, resultSet.getString(STATEMENT_TABLE_MODEL)); + Assert.assertEquals(ADMIN_NAME, resultSet.getString(USER_TABLE_MODEL)); + rowNum++; + } + Assert.assertEquals(1, rowNum); + resultSet.close(); + + // 2. query queries_costs_histogram table + sql = "SELECT * FROM queries_costs_histogram"; + resultSet = statement.executeQuery(sql); + metaData = resultSet.getMetaData(); + Assert.assertEquals(QUERIES_COSTS_HISTOGRAM_COLUMN_NUM, metaData.getColumnCount()); + rowNum = 0; + int queriesCount = 0; + while (resultSet.next()) { + int nums = resultSet.getInt(NUMS); + if (nums > 0) { + queriesCount++; + } + rowNum++; + } + Assert.assertEquals(1, queriesCount); + Assert.assertEquals(61, rowNum); + + // 3. requery current_queries table + sql = "SELECT * FROM current_queries"; + resultSet = statement.executeQuery(sql); + metaData = resultSet.getMetaData(); + Assert.assertEquals(CURRENT_QUERIES_COLUMN_NUM, metaData.getColumnCount()); + rowNum = 0; + int finishedQueries = 0; + while (resultSet.next()) { + if (QueryState.FINISHED.name().equals(resultSet.getString(STATE_TABLE_MODEL))) { + finishedQueries++; + } + rowNum++; + } + // three rows in the result, 2 FINISHED and 1 RUNNING + Assert.assertEquals(3, rowNum); + Assert.assertEquals(2, finishedQueries); + resultSet.close(); + + // 4. 
test the expired QueryInfo was evicted + Thread.sleep(61_001); + resultSet = statement.executeQuery(sql); + rowNum = 0; + while (resultSet.next()) { + rowNum++; + } + // one row in the result, current query + Assert.assertEquals(1, rowNum); + resultSet.close(); + + sql = "SELECT * FROM queries_costs_histogram"; + resultSet = statement.executeQuery(sql); + queriesCount = 0; + while (resultSet.next()) { + int nums = resultSet.getInt(NUMS); + if (nums > 0) { + queriesCount++; + } + } + // the last current_queries table query was recorded, others are evicted + Assert.assertEquals(1, queriesCount); + } catch (Exception e) { + fail(e.getMessage()); + } + + // 5. test privilege + testPrivilege(); + } + + private void testPrivilege() { + // 1. test current_queries table + try (Connection connection = + EnvFactory.getEnv().getConnection("test", "test123123456", TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + String sql = "SELECT * FROM information_schema.current_queries"; + + // another user executes a query + try (Connection connection2 = + EnvFactory.getEnv().getConnection(ADMIN_NAME, ADMIN_PWD, BaseEnv.TABLE_SQL_DIALECT)) { + ResultSet resultSet = connection2.createStatement().executeQuery(sql); + resultSet.close(); + } catch (Exception e) { + fail(e.getMessage()); + } + + // current user query current_queries table + ResultSet resultSet = statement.executeQuery(sql); + int rowNum = 0; + while (resultSet.next()) { + rowNum++; + } + // only current query in the result + Assert.assertEquals(1, rowNum); + } catch (SQLException e) { + fail(e.getMessage()); + } + + // 2. 
test queries_costs_histogram table + try (Connection connection = + EnvFactory.getEnv().getConnection("test", "test123123456", TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + statement.executeQuery("SELECT * FROM information_schema.queries_costs_histogram"); + } catch (SQLException e) { + Assert.assertEquals( + "803: Access Denied: No permissions for this operation, please add privilege SYSTEM", + e.getMessage()); + } + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java index e04ff838819e6..4736d9b0521d9 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java @@ -399,15 +399,16 @@ public void testInformationSchema() throws SQLException { "config_nodes,INF,", "configurations,INF,", "connections,INF,", + "current_queries,INF,", "data_nodes,INF,", "databases,INF,", "functions,INF,", "keywords,INF,", - "models,INF,", "nodes,INF,", "pipe_plugins,INF,", "pipes,INF,", "queries,INF,", + "queries_costs_histogram,INF,", "regions,INF,", "subscriptions,INF,", "tables,INF,", @@ -504,16 +505,6 @@ public void testInformationSchema() throws SQLException { "database,STRING,TAG,", "table_name,STRING,TAG,", "view_definition,STRING,ATTRIBUTE,"))); - TestUtils.assertResultSetEqual( - statement.executeQuery("desc models"), - "ColumnName,DataType,Category,", - new HashSet<>( - Arrays.asList( - "model_id,STRING,TAG,", - "model_type,STRING,ATTRIBUTE,", - "state,STRING,ATTRIBUTE,", - "configs,STRING,ATTRIBUTE,", - "notes,STRING,ATTRIBUTE,"))); TestUtils.assertResultSetEqual( statement.executeQuery("desc functions"), "ColumnName,DataType,Category,", @@ -638,7 +629,6 @@ public void testInformationSchema() throws SQLException { "information_schema,pipes,INF,USING,null,SYSTEM 
VIEW,", "information_schema,subscriptions,INF,USING,null,SYSTEM VIEW,", "information_schema,views,INF,USING,null,SYSTEM VIEW,", - "information_schema,models,INF,USING,null,SYSTEM VIEW,", "information_schema,functions,INF,USING,null,SYSTEM VIEW,", "information_schema,configurations,INF,USING,null,SYSTEM VIEW,", "information_schema,keywords,INF,USING,null,SYSTEM VIEW,", @@ -646,12 +636,14 @@ public void testInformationSchema() throws SQLException { "information_schema,config_nodes,INF,USING,null,SYSTEM VIEW,", "information_schema,data_nodes,INF,USING,null,SYSTEM VIEW,", "information_schema,connections,INF,USING,null,SYSTEM VIEW,", + "information_schema,current_queries,INF,USING,null,SYSTEM VIEW,", + "information_schema,queries_costs_histogram,INF,USING,null,SYSTEM VIEW,", "test,test,INF,USING,test,BASE TABLE,", "test,view_table,100,USING,null,VIEW FROM TREE,"))); TestUtils.assertResultSetEqual( statement.executeQuery("count devices from tables where status = 'USING'"), "count(devices),", - Collections.singleton("20,")); + Collections.singleton("21,")); TestUtils.assertResultSetEqual( statement.executeQuery( "select * from columns where table_name = 'queries' or database = 'test'"), diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBTableIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBTableIT.java index 36dc53a4c41b3..748a6937bcf63 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBTableIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBTableIT.java @@ -816,9 +816,7 @@ private void testObject4SingleIllegalPath(final String illegal) throws Exception + File.separator + "test-classes" + File.separator - + "ainode-example" - + File.separator - + "model.pt"; + + "object-example.pt"; List schemaList = new ArrayList<>(); schemaList.add(new MeasurementSchema("a", TSDataType.STRING)); diff --git 
a/integration-test/src/test/java/org/apache/iotdb/relational/it/session/IoTDBSessionRelationalIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/session/IoTDBSessionRelationalIT.java index 7f63d19030291..49884e3c70ba7 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/session/IoTDBSessionRelationalIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/session/IoTDBSessionRelationalIT.java @@ -861,9 +861,7 @@ public void insertObjectTest() + File.separator + "test-classes" + File.separator - + "ainode-example" - + File.separator - + "model.pt"; + + "object-example.pt"; File object = new File(testObject); try (ITableSession session = EnvFactory.getEnv().getTableSessionConnection()) { @@ -929,9 +927,7 @@ public void insertObjectSegmentsTest() + File.separator + "test-classes" + File.separator - + "ainode-example" - + File.separator - + "model.pt"; + + "object-example.pt"; byte[] objectBytes = Files.readAllBytes(Paths.get(testObject)); List objectSegments = new ArrayList<>(); for (int i = 0; i < objectBytes.length; i += 512) { diff --git a/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBConnectionsIT.java b/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBConnectionsIT.java index bc043c79e8a6f..80aa854e66cc1 100644 --- a/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBConnectionsIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBConnectionsIT.java @@ -176,7 +176,7 @@ public void testSameDataNodeGetConnections() { ResultSet resultSet = statement.executeQuery( - "SELECT * FROM connections WHERE data_node_id = '" + dataNodeId + "'"); + "SELECT * FROM connections WHERE datanode_id = '" + dataNodeId + "'"); if (!resultSet.next()) { fail(); } @@ -256,7 +256,7 @@ public void testClosedDataNodeGetConnections() throws Exception { ResultSet resultSet = statement.executeQuery( - "SELECT COUNT(*) FROM connections WHERE data_node_id = '" + 
closedDataNodeId + "'"); + "SELECT COUNT(*) FROM connections WHERE datanode_id = '" + closedDataNodeId + "'"); if (!resultSet.next()) { fail(); } @@ -306,7 +306,7 @@ public void testClosedDataNodeGetConnections() throws Exception { ResultSet resultSet = statement.executeQuery( - "SELECT COUNT(*) FROM connections WHERE data_node_id = '" + closedDataNodeId + "'"); + "SELECT COUNT(*) FROM connections WHERE datanode_id = '" + closedDataNodeId + "'"); if (!resultSet.next()) { fail(); } diff --git a/integration-test/src/test/resources/ainode-example/config.yaml b/integration-test/src/test/resources/ainode-example/config.yaml deleted file mode 100644 index 665acb8704e24..0000000000000 --- a/integration-test/src/test/resources/ainode-example/config.yaml +++ /dev/null @@ -1,5 +0,0 @@ -configs: - input_shape: [7, 4] - output_shape: [7, 4] - input_type: ["float32", "float32", "float32", "float32"] - output_type: ["float32", "float32", "float32", "float32"] diff --git a/integration-test/src/test/resources/ainode-example/model.pt b/integration-test/src/test/resources/object-example.pt similarity index 100% rename from integration-test/src/test/resources/ainode-example/model.pt rename to integration-test/src/test/resources/object-example.pt diff --git a/iotdb-core/ainode/ainode.spec b/iotdb-core/ainode/ainode.spec index a131b2bcff217..bde8b845fb8c5 100644 --- a/iotdb-core/ainode/ainode.spec +++ b/iotdb-core/ainode/ainode.spec @@ -41,14 +41,24 @@ all_hiddenimports = [] # Only collect essential data files and binaries for critical libraries # This reduces startup time by avoiding unnecessary module imports essential_libraries = { - 'torch': True, # Keep collect_all for torch as it has many dynamic imports - 'transformers': True, # Keep collect_all for transformers - 'safetensors': True, # Keep collect_all for safetensors + 'torch': True, + 'transformers': True, + 'tokenizers': True, + 'huggingface_hub': True, + 'safetensors': True, + 'hf_xet': True, + 'numpy': True, + 'scipy': 
True, + 'pandas': True, + 'sklearn': True, + 'statsmodels': True, + 'sktime': True, + 'pmdarima': True, + 'hmmlearn': True, + 'accelerate': True } -# For other libraries, use selective collection to speed up startup -other_libraries = ['sktime', 'scipy', 'pandas', 'sklearn', 'statsmodels', 'optuna'] - +# Collect all libraries using collect_all (includes data files and binaries) for lib in essential_libraries: try: lib_datas, lib_binaries, lib_hiddenimports = collect_all(lib) @@ -58,48 +68,105 @@ for lib in essential_libraries: except Exception: pass -# For other libraries, only collect submodules (lighter weight) -# This relies on PyInstaller's dependency analysis to include what's actually used -for lib in other_libraries: +# Additionally collect ALL submodules for libraries that commonly have dynamic imports +# This is a more aggressive approach but ensures we don't miss any modules +# Libraries that are known to have many dynamic imports and submodules +libraries_with_dynamic_imports = [ + 'scipy', # Has many subpackages: stats, interpolate, optimize, linalg, sparse, signal, etc. 
+ 'sklearn', # Has many submodules that may be dynamically imported + 'transformers', # Has dynamic model loading + 'torch', # Has many submodules, especially _dynamo.polyfills +] + +# Collect all submodules for these libraries to ensure comprehensive coverage +for lib in libraries_with_dynamic_imports: try: submodules = collect_submodules(lib) all_hiddenimports.extend(submodules) - # Only collect essential data files and binaries, not all submodules - # This significantly reduces startup time - try: - lib_datas, lib_binaries, _ = collect_all(lib) - all_datas.extend(lib_datas) - all_binaries.extend(lib_binaries) - except Exception: - # If collect_all fails, try collect_data_files for essential data only - try: - lib_datas = collect_data_files(lib) - all_datas.extend(lib_datas) - except Exception: - pass - except Exception: - pass + print(f"Collected {len(submodules)} submodules from {lib}") + except Exception as e: + print(f"Warning: Failed to collect submodules from {lib}: {e}") + + +# Helper function to collect submodules with fallback +def collect_submodules_with_fallback(package, fallback_modules=None, package_name=None): + """ + Collect all submodules for a package, with fallback to manual module list if collection fails. 
+ + Args: + package: Package name to collect submodules from + fallback_modules: List of module names to add if collection fails (optional) + package_name: Display name for logging (defaults to package) + """ + if package_name is None: + package_name = package + try: + submodules = collect_submodules(package) + all_hiddenimports.extend(submodules) + print(f"Collected {len(submodules)} submodules from {package_name}") + except Exception as e: + print(f"Warning: Failed to collect {package_name} submodules: {e}") + if fallback_modules: + all_hiddenimports.extend(fallback_modules) + print(f"Using fallback modules for {package_name}") + + +# Additional specific packages that need submodule collection +# Note: scipy, sklearn, transformers, torch are already collected above via libraries_with_dynamic_imports +# This section is for more specific sub-packages that need special handling +# Format: (package_name, fallback_modules_list, display_name) +submodule_collection_configs = [ + # torch._dynamo.polyfills - critical for torch dynamo functionality + # (torch is already collected above, but this ensures polyfills are included) + ( + 'torch._dynamo.polyfills', + [ + 'torch._dynamo.polyfills', + 'torch._dynamo.polyfills.functools', + 'torch._dynamo.polyfills.operator', + 'torch._dynamo.polyfills.collections', + ], + 'torch._dynamo.polyfills' + ), + # transformers sub-packages with dynamic imports + # (transformers is already collected above, but these specific sub-packages may need extra attention) + ('transformers.generation', None, 'transformers.generation'), + ('transformers.models.auto', None, 'transformers.models.auto'), +] + +# Collect submodules for all configured packages +for package, fallback_modules, display_name in submodule_collection_configs: + collect_submodules_with_fallback(package, fallback_modules, display_name) # Project-specific packages that need their submodules collected # Only list top-level packages - collect_submodules will recursively collect all 
submodules -TOP_LEVEL_PACKAGES = [ +project_packages = [ 'iotdb.ainode.core', # This will include all sub-packages: manager, model, inference, etc. 'iotdb.thrift', # This will include all thrift sub-packages ] # Collect all submodules for project packages automatically # Using top-level packages avoids duplicate collection -for package in TOP_LEVEL_PACKAGES: - try: - submodules = collect_submodules(package) - all_hiddenimports.extend(submodules) - except Exception: - # If package doesn't exist or collection fails, add the package itself - all_hiddenimports.append(package) +# If collection fails, add the package itself as fallback +for package in project_packages: + collect_submodules_with_fallback(package, fallback_modules=[package], package_name=package) # Add parent packages to ensure they are included all_hiddenimports.extend(['iotdb', 'iotdb.ainode']) +# Fix circular import issues in scipy.stats +# scipy.stats has circular imports that can cause issues in PyInstaller +# We need to ensure _stats is imported before scipy.stats tries to import it +# This helps resolve the "partially initialized module" error +scipy_stats_critical_modules = [ + 'scipy.stats._stats', # Core stats module, must be imported first + 'scipy.stats._stats_py', # Python implementation + 'scipy.stats._continuous_distns', # Continuous distributions + 'scipy.stats._discrete_distns', # Discrete distributions + 'scipy.stats.distributions', # Distribution base classes +] +all_hiddenimports.extend(scipy_stats_critical_modules) + # Multiprocessing support for PyInstaller # When using multiprocessing with PyInstaller, we need to ensure proper handling multiprocessing_modules = [ @@ -119,9 +186,6 @@ multiprocessing_modules = [ # Additional dependencies that may need explicit import # These are external libraries that might use dynamic imports external_dependencies = [ - 'huggingface_hub', - 'tokenizers', - 'hf_xet', 'einops', 'dynaconf', 'tzlocal', @@ -161,7 +225,9 @@ a = Analysis( 
win_no_prefer_redirects=False, win_private_assemblies=False, cipher=block_cipher, - noarchive=True, # Set to True to speed up startup - files are not archived into PYZ + noarchive=False, # Set to False to avoid circular import issues with scipy.stats + # When noarchive=True, modules are loaded as separate files which can cause + # circular import issues. Using PYZ archive helps PyInstaller handle module loading order better. ) # Package all PYZ files diff --git a/iotdb-core/ainode/iotdb/ainode/core/config.py b/iotdb-core/ainode/iotdb/ainode/core/config.py index afcf0683d7d04..e465df7e36d29 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/config.py +++ b/iotdb-core/ainode/iotdb/ainode/core/config.py @@ -20,7 +20,6 @@ from iotdb.ainode.core.constant import ( AINODE_BUILD_INFO, - AINODE_BUILTIN_MODELS_DIR, AINODE_CLUSTER_INGRESS_ADDRESS, AINODE_CLUSTER_INGRESS_PASSWORD, AINODE_CLUSTER_INGRESS_PORT, @@ -33,10 +32,11 @@ AINODE_CONF_POM_FILE_NAME, AINODE_INFERENCE_BATCH_INTERVAL_IN_MS, AINODE_INFERENCE_EXTRA_MEMORY_RATIO, - AINODE_INFERENCE_MAX_PREDICT_LENGTH, + AINODE_INFERENCE_MAX_OUTPUT_LENGTH, AINODE_INFERENCE_MEMORY_USAGE_RATIO, AINODE_INFERENCE_MODEL_MEM_USAGE_MAP, AINODE_LOG_DIR, + AINODE_MODELS_BUILTIN_DIR, AINODE_MODELS_DIR, AINODE_RPC_ADDRESS, AINODE_RPC_PORT, @@ -75,9 +75,7 @@ def __init__(self): self._ain_inference_batch_interval_in_ms: int = ( AINODE_INFERENCE_BATCH_INTERVAL_IN_MS ) - self._ain_inference_max_predict_length: int = ( - AINODE_INFERENCE_MAX_PREDICT_LENGTH - ) + self._ain_inference_max_output_length: int = AINODE_INFERENCE_MAX_OUTPUT_LENGTH self._ain_inference_model_mem_usage_map: dict[str, int] = ( AINODE_INFERENCE_MODEL_MEM_USAGE_MAP ) @@ -95,7 +93,7 @@ def __init__(self): # Directory to save models self._ain_models_dir = AINODE_MODELS_DIR - self._ain_builtin_models_dir = AINODE_BUILTIN_MODELS_DIR + self._ain_models_builtin_dir = AINODE_MODELS_BUILTIN_DIR self._ain_system_dir = AINODE_SYSTEM_DIR # Whether to enable compression for thrift @@ 
-160,13 +158,13 @@ def set_ain_inference_batch_interval_in_ms( ) -> None: self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms - def get_ain_inference_max_predict_length(self) -> int: - return self._ain_inference_max_predict_length + def get_ain_inference_max_output_length(self) -> int: + return self._ain_inference_max_output_length - def set_ain_inference_max_predict_length( - self, ain_inference_max_predict_length: int + def set_ain_inference_max_output_length( + self, ain_inference_max_output_length: int ) -> None: - self._ain_inference_max_predict_length = ain_inference_max_predict_length + self._ain_inference_max_output_length = ain_inference_max_output_length def get_ain_inference_model_mem_usage_map(self) -> dict[str, int]: return self._ain_inference_model_mem_usage_map @@ -204,11 +202,11 @@ def get_ain_models_dir(self) -> str: def set_ain_models_dir(self, ain_models_dir: str) -> None: self._ain_models_dir = ain_models_dir - def get_ain_builtin_models_dir(self) -> str: - return self._ain_builtin_models_dir + def get_ain_models_builtin_dir(self) -> str: + return self._ain_models_builtin_dir - def set_ain_builtin_models_dir(self, ain_builtin_models_dir: str) -> None: - self._ain_builtin_models_dir = ain_builtin_models_dir + def set_ain_models_builtin_dir(self, ain_models_builtin_dir: str) -> None: + self._ain_models_builtin_dir = ain_models_builtin_dir def get_ain_system_dir(self) -> str: return self._ain_system_dir @@ -374,6 +372,11 @@ def _load_config_from_file(self) -> None: if "ain_models_dir" in config_keys: self._config.set_ain_models_dir(file_configs["ain_models_dir"]) + if "ain_models_builtin_dir" in config_keys: + self._config.set_ain_models_builtin_dir( + file_configs["ain_models_builtin_dir"] + ) + if "ain_system_dir" in config_keys: self._config.set_ain_system_dir(file_configs["ain_system_dir"]) diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py index 
b9923d3e3ee7e..d8f730c829c89 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/constant.py +++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py @@ -18,9 +18,7 @@ import logging import os from enum import Enum -from typing import List -from iotdb.ainode.core.model.model_enums import BuiltInModelType from iotdb.thrift.common.ttypes import TEndPoint IOTDB_AINODE_HOME = os.getenv("IOTDB_AINODE_HOME", "") @@ -50,21 +48,21 @@ # AINode inference configuration AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15 -AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880 +AINODE_INFERENCE_MAX_OUTPUT_LENGTH = 2880 + +# TODO: Should be optimized AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = { - BuiltInModelType.SUNDIAL.value: 1036 * 1024**2, # 1036 MiB - BuiltInModelType.TIMER_XL.value: 856 * 1024**2, # 856 MiB + "sundial": 1036 * 1024**2, # 1036 MiB + "timer": 856 * 1024**2, # 856 MiB } # the memory usage of each model in bytes + AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.4 # the device space allocated for inference AINODE_INFERENCE_EXTRA_MEMORY_RATIO = ( 1.2 # the overhead ratio for inference, used to estimate the pool size ) -# AINode folder structure AINODE_MODELS_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/models") -AINODE_BUILTIN_MODELS_DIR = os.path.join( - IOTDB_AINODE_HOME, "data/ainode/models/weights" -) # For built-in models, we only need to store their weights and config. 
+AINODE_MODELS_BUILTIN_DIR = "iotdb.ainode.core.model" AINODE_SYSTEM_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/system") AINODE_LOG_DIR = os.path.join(IOTDB_AINODE_HOME, "logs") @@ -77,11 +75,6 @@ "log_inference_rank_{}_" # example: log_inference_rank_0_all.log ) -# AINode model management -MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors" -MODEL_CONFIG_FILE_IN_JSON = "config.json" -MODEL_WEIGHTS_FILE_IN_PT = "model.pt" -MODEL_CONFIG_FILE_IN_YAML = "config.yaml" DEFAULT_CHUNK_SIZE = 8192 @@ -100,27 +93,6 @@ def get_status_code(self) -> int: return self.value -class TaskType(Enum): - FORECAST = "forecast" - - -class OptionsKey(Enum): - # common - TASK_TYPE = "task_type" - MODEL_TYPE = "model_type" - AUTO_TUNING = "auto_tuning" - INPUT_VARS = "input_vars" - - # forecast - INPUT_LENGTH = "input_length" - PREDICT_LENGTH = "predict_length" - PREDICT_INDEX_LIST = "predict_index_list" - INPUT_TYPE_LIST = "input_type_list" - - def name(self) -> str: - return self.value - - class HyperparameterName(Enum): # Training hyperparameter LEARNING_RATE = "learning_rate" @@ -139,134 +111,3 @@ class HyperparameterName(Enum): def name(self): return self.value - - -class ForecastModelType(Enum): - DLINEAR = "dlinear" - DLINEAR_INDIVIDUAL = "dlinear_individual" - NBEATS = "nbeats" - - @classmethod - def values(cls) -> List[str]: - values = [] - for item in list(cls): - values.append(item.value) - return values - - -class ModelInputName(Enum): - DATA_X = "data_x" - TIME_STAMP_X = "time_stamp_x" - TIME_STAMP_Y = "time_stamp_y" - DEC_INP = "dec_inp" - - -class AttributeName(Enum): - # forecast Attribute - PREDICT_LENGTH = "predict_length" - - # NaiveForecaster - STRATEGY = "strategy" - SP = "sp" - - # STLForecaster - # SP = 'sp' - SEASONAL = "seasonal" - SEASONAL_DEG = "seasonal_deg" - TREND_DEG = "trend_deg" - LOW_PASS_DEG = "low_pass_deg" - SEASONAL_JUMP = "seasonal_jump" - TREND_JUMP = "trend_jump" - LOSS_PASS_JUMP = "low_pass_jump" - - # ExponentialSmoothing - 
DAMPED_TREND = "damped_trend" - INITIALIZATION_METHOD = "initialization_method" - OPTIMIZED = "optimized" - REMOVE_BIAS = "remove_bias" - USE_BRUTE = "use_brute" - - # Arima - ORDER = "order" - SEASONAL_ORDER = "seasonal_order" - METHOD = "method" - MAXITER = "maxiter" - SUPPRESS_WARNINGS = "suppress_warnings" - OUT_OF_SAMPLE_SIZE = "out_of_sample_size" - SCORING = "scoring" - WITH_INTERCEPT = "with_intercept" - TIME_VARYING_REGRESSION = "time_varying_regression" - ENFORCE_STATIONARITY = "enforce_stationarity" - ENFORCE_INVERTIBILITY = "enforce_invertibility" - SIMPLE_DIFFERENCING = "simple_differencing" - MEASUREMENT_ERROR = "measurement_error" - MLE_REGRESSION = "mle_regression" - HAMILTON_REPRESENTATION = "hamilton_representation" - CONCENTRATE_SCALE = "concentrate_scale" - - # GAUSSIAN_HMM - N_COMPONENTS = "n_components" - COVARIANCE_TYPE = "covariance_type" - MIN_COVAR = "min_covar" - STARTPROB_PRIOR = "startprob_prior" - TRANSMAT_PRIOR = "transmat_prior" - MEANS_PRIOR = "means_prior" - MEANS_WEIGHT = "means_weight" - COVARS_PRIOR = "covars_prior" - COVARS_WEIGHT = "covars_weight" - ALGORITHM = "algorithm" - N_ITER = "n_iter" - TOL = "tol" - PARAMS = "params" - INIT_PARAMS = "init_params" - IMPLEMENTATION = "implementation" - - # GMMHMM - # N_COMPONENTS = "n_components" - N_MIX = "n_mix" - # MIN_COVAR = "min_covar" - # STARTPROB_PRIOR = "startprob_prior" - # TRANSMAT_PRIOR = "transmat_prior" - WEIGHTS_PRIOR = "weights_prior" - - # MEANS_PRIOR = "means_prior" - # MEANS_WEIGHT = "means_weight" - # ALGORITHM = "algorithm" - # COVARIANCE_TYPE = "covariance_type" - # N_ITER = "n_iter" - # TOL = "tol" - # INIT_PARAMS = "init_params" - # PARAMS = "params" - # IMPLEMENTATION = "implementation" - - # STRAY - ALPHA = "alpha" - K = "k" - KNN_ALGORITHM = "knn_algorithm" - P = "p" - SIZE_THRESHOLD = "size_threshold" - OUTLIER_TAIL = "outlier_tail" - - # timerxl - INPUT_TOKEN_LEN = "input_token_len" - HIDDEN_SIZE = "hidden_size" - INTERMEDIATE_SIZE = "intermediate_size" - 
OUTPUT_TOKEN_LENS = "output_token_lens" - NUM_HIDDEN_LAYERS = "num_hidden_layers" - NUM_ATTENTION_HEADS = "num_attention_heads" - HIDDEN_ACT = "hidden_act" - USE_CACHE = "use_cache" - ROPE_THETA = "rope_theta" - ATTENTION_DROPOUT = "attention_dropout" - INITIALIZER_RANGE = "initializer_range" - MAX_POSITION_EMBEDDINGS = "max_position_embeddings" - CKPT_PATH = "ckpt_path" - - # sundial - DROPOUT_RATE = "dropout_rate" - FLOW_LOSS_DEPTH = "flow_loss_depth" - NUM_SAMPLING_STEPS = "num_sampling_steps" - DIFFUSION_BATCH_MUL = "diffusion_batch_mul" - - def name(self) -> str: - return self.value diff --git a/iotdb-core/ainode/iotdb/ainode/core/exception.py b/iotdb-core/ainode/iotdb/ainode/core/exception.py index bc89cdc306625..30b9d54dcc7df 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/exception.py +++ b/iotdb-core/ainode/iotdb/ainode/core/exception.py @@ -17,7 +17,7 @@ # import re -from iotdb.ainode.core.constant import ( +from iotdb.ainode.core.model.model_constants import ( MODEL_CONFIG_FILE_IN_YAML, MODEL_WEIGHTS_FILE_IN_PT, ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py index 82c72cc37abf5..50634914c2737 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py @@ -15,14 +15,12 @@ # specific language governing permissions and limitations # under the License. 
# + import threading from typing import Any import torch -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) from iotdb.ainode.core.log import Logger from iotdb.ainode.core.util.atmoic_int import AtomicInt @@ -41,8 +39,7 @@ def __init__( req_id: str, model_id: str, inputs: torch.Tensor, - inference_pipeline: AbstractInferencePipeline, - max_new_tokens: int = 96, + output_length: int = 96, **infer_kwargs, ): if inputs.ndim == 1: @@ -52,9 +49,8 @@ def __init__( self.model_id = model_id self.inputs = inputs self.infer_kwargs = infer_kwargs - self.inference_pipeline = inference_pipeline - self.max_new_tokens = ( - max_new_tokens # Number of time series data points to generate + self.output_length = ( + output_length # Number of time series data points to generate ) self.batch_size = inputs.size(0) @@ -65,7 +61,7 @@ def __init__( # Preallocate output buffer [batch_size, max_new_tokens] self.output_tensor = torch.zeros( - self.batch_size, max_new_tokens, device="cpu" + self.batch_size, output_length, device="cpu" ) # shape: [self.batch_size, max_new_steps] def mark_running(self): @@ -77,7 +73,7 @@ def mark_finished(self): def is_finished(self) -> bool: return ( self.state == InferenceRequestState.FINISHED - or self.cur_step_idx >= self.max_new_tokens + or self.cur_step_idx >= self.output_length ) def write_step_output(self, step_output: torch.Tensor): @@ -87,11 +83,11 @@ def write_step_output(self, step_output: torch.Tensor): batch_size, step_size = step_output.shape end_idx = self.cur_step_idx + step_size - if end_idx > self.max_new_tokens: + if end_idx > self.output_length: self.output_tensor[:, self.cur_step_idx :] = step_output[ - :, : self.max_new_tokens - self.cur_step_idx + :, : self.output_length - self.cur_step_idx ] - self.cur_step_idx = self.max_new_tokens + self.cur_step_idx = self.output_length else: self.output_tensor[:, self.cur_step_idx : end_idx] = step_output self.cur_step_idx = end_idx diff 
--git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py index 6b054c91fe31c..a6c415a6c848b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py @@ -25,19 +25,22 @@ import numpy as np import torch import torch.multiprocessing as mp -from transformers import PretrainedConfig from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.constant import INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE from iotdb.ainode.core.inference.batcher.basic_batcher import BasicBatcher from iotdb.ainode.core.inference.inference_request import InferenceRequest +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ( + ChatPipeline, + ClassificationPipeline, + ForecastPipeline, +) +from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline from iotdb.ainode.core.inference.request_scheduler.basic_request_scheduler import ( BasicRequestScheduler, ) from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.model_info import ModelInfo +from iotdb.ainode.core.model.model_storage import ModelInfo from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device @@ -62,7 +65,6 @@ def __init__( pool_id: int, model_info: ModelInfo, device: str, - config: PretrainedConfig, request_queue: mp.Queue, result_queue: mp.Queue, ready_event, @@ -71,7 +73,6 @@ def __init__( super().__init__() self.pool_id = pool_id self.model_info = model_info - self.config = config self.pool_kwargs = pool_kwargs self.ready_event = ready_event self.device = convert_device_id_to_torch_device(device) @@ -86,8 +87,8 @@ def __init__( self._batcher = BasicBatcher() self._stop_event = mp.Event() - self._model = None - 
self._model_manager = None + self._inference_pipeline = None + self._logger = None # Fix inference seed @@ -98,9 +99,6 @@ def __init__( def _activate_requests(self): requests = self._request_scheduler.schedule_activate() for request in requests: - request.inputs = request.inference_pipeline.preprocess_inputs( - request.inputs - ) request.mark_running() self._running_queue.put(request) self._logger.debug( @@ -117,72 +115,51 @@ def _step(self): grouped_requests = defaultdict(list) for req in all_requests: - key = (req.inputs.shape[1], req.max_new_tokens) + key = (req.inputs.shape[1], req.output_length) grouped_requests[key].append(req) grouped_requests = list(grouped_requests.values()) for requests in grouped_requests: batch_inputs = self._batcher.batch_request(requests).to(self.device) - if self.model_info.model_type == BuiltInModelType.SUNDIAL.value: - batch_output = self._model.generate( + if isinstance(self._inference_pipeline, ForecastPipeline): + batch_output = self._inference_pipeline.forecast( batch_inputs, - max_new_tokens=requests[0].max_new_tokens, - num_samples=10, + predict_length=requests[0].output_length, revin=True, ) - - offset = 0 - for request in requests: - request.output_tensor = request.output_tensor.to(self.device) - cur_batch_size = request.batch_size - cur_output = batch_output[offset : offset + cur_batch_size] - offset += cur_batch_size - request.write_step_output(cur_output.mean(dim=1)) - - request.inference_pipeline.post_decode() - if request.is_finished(): - request.inference_pipeline.post_inference() - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" - ) - # ensure the output tensor is on CPU before sending to result queue - request.output_tensor = request.output_tensor.cpu() - self._finished_queue.put(request) - else: - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" - ) - 
self._waiting_queue.put(request) - - elif self.model_info.model_type == BuiltInModelType.TIMER_XL.value: - batch_output = self._model.generate( + elif isinstance(self._inference_pipeline, ClassificationPipeline): + batch_output = self._inference_pipeline.classify( batch_inputs, - max_new_tokens=requests[0].max_new_tokens, - revin=True, + # more infer kwargs can be added here ) - - offset = 0 - for request in requests: - request.output_tensor = request.output_tensor.to(self.device) - cur_batch_size = request.batch_size - cur_output = batch_output[offset : offset + cur_batch_size] - offset += cur_batch_size - request.write_step_output(cur_output) - - request.inference_pipeline.post_decode() - if request.is_finished(): - request.inference_pipeline.post_inference() - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" - ) - # ensure the output tensor is on CPU before sending to result queue - request.output_tensor = request.output_tensor.cpu() - self._finished_queue.put(request) - else: - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" - ) - self._waiting_queue.put(request) + elif isinstance(self._inference_pipeline, ChatPipeline): + batch_output = self._inference_pipeline.chat( + batch_inputs, + # more infer kwargs can be added here + ) + else: + self._logger.error("[Inference] Unsupported pipeline type.") + offset = 0 + for request in requests: + request.output_tensor = request.output_tensor.to(self.device) + cur_batch_size = request.batch_size + cur_output = batch_output[offset : offset + cur_batch_size] + offset += cur_batch_size + request.write_step_output(cur_output) + + if request.is_finished(): + # ensure the output tensor is on CPU before sending to result queue + request.output_tensor = request.output_tensor.cpu() + self._finished_queue.put(request) + self._logger.debug( + 
f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" + ) + else: + self._waiting_queue.put(request) + self._logger.debug( + f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" + ) + return def _requests_execute_loop(self): while not self._stop_event.is_set(): @@ -193,11 +170,8 @@ def run(self): self._logger = Logger( INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device) ) - self._model_manager = ModelManager() self._request_scheduler.device = self.device - self._model = self._model_manager.load_model(self.model_info.model_id, {}).to( - self.device - ) + self._inference_pipeline = load_pipeline(self.model_info, str(self.device)) self.ready_event.set() activate_daemon = threading.Thread( diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/__init__.py rename to iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py new file mode 100644 index 0000000000000..82601e398059c --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from abc import ABC, abstractmethod + +import torch + +from iotdb.ainode.core.model.model_loader import load_model + + +class BasicPipeline(ABC): + def __init__(self, model_info, **model_kwargs): + self.model_info = model_info + self.device = model_kwargs.get("device", "cpu") + self.model = load_model(model_info, device_map=self.device, **model_kwargs) + + def _preprocess(self, inputs): + """ + Preprocess the input before inference, including shape validation and value transformation. + """ + return inputs + + def _postprocess(self, output: torch.Tensor): + """ + Post-process the outputs after the entire inference task. 
+ """ + return output + + +class ForecastPipeline(BasicPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) + + def _preprocess(self, inputs): + return inputs + + @abstractmethod + def forecast(self, inputs, **infer_kwargs): + pass + + def _postprocess(self, output: torch.Tensor): + return output + + +class ClassificationPipeline(BasicPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) + + def _preprocess(self, inputs): + return inputs + + @abstractmethod + def classify(self, inputs, **kwargs): + pass + + def _postprocess(self, output: torch.Tensor): + return output + + +class ChatPipeline(BasicPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) + + def _preprocess(self, inputs): + return inputs + + @abstractmethod + def chat(self, inputs, **kwargs): + pass + + def _postprocess(self, output: torch.Tensor): + return output diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py new file mode 100644 index 0000000000000..a30038dd5feff --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import os +from pathlib import Path + +from iotdb.ainode.core.config import AINodeDescriptor +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.model_constants import ModelCategory +from iotdb.ainode.core.model.model_storage import ModelInfo +from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path + +logger = Logger() + + +def load_pipeline(model_info: ModelInfo, device: str, **model_kwargs): + if model_info.model_type == "sktime": + from iotdb.ainode.core.model.sktime.pipeline_sktime import SktimePipeline + + pipeline_cls = SktimePipeline + elif model_info.category == ModelCategory.BUILTIN: + module_name = ( + AINodeDescriptor().get_config().get_ain_models_builtin_dir() + + "." 
+ + model_info.model_id + ) + pipeline_cls = import_class_from_path(module_name, model_info.pipeline_cls) + else: + model_path = os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + model_info.category.value, + model_info.model_id, + ) + module_parent = str(Path(model_path).parent.absolute()) + with temporary_sys_path(module_parent): + pipeline_cls = import_class_from_path( + model_info.model_id, model_info.pipeline_cls + ) + + return pipeline_cls(model_info, device=device, **model_kwargs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py index 54580402ec293..8ffa89ffd6752 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py @@ -22,7 +22,6 @@ from concurrent.futures import wait from typing import Dict, Optional -import torch import torch.multiprocessing as mp from iotdb.ainode.core.exception import InferenceModelInternalError @@ -41,9 +40,6 @@ ) from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig from iotdb.ainode.core.util.atmoic_int import AtomicInt from iotdb.ainode.core.util.batch_executor import BatchExecutor from iotdb.ainode.core.util.decorator import synchronized @@ -76,9 +72,9 @@ def __init__(self, result_queue: mp.Queue): thread_name_prefix=ThreadName.INFERENCE_POOL_CONTROLLER.value ) - # =============== Pool Management =============== + # =============== Automatic Pool Management (Developing) =============== @synchronized(threading.Lock()) - def first_req_init(self, model_id: str): + def first_req_init(self, model_id: str, device): """ Initialize the pools when 
the first request for the given model_id arrives. """ @@ -107,38 +103,35 @@ def _first_pool_init(self, model_id: str, device_str: str): Initialize the first pool for the given model_id. Ensure the pool is ready before returning. """ - device = torch.device(device_str) - device_id = device.index - - if model_id == "sundial": - config = SundialConfig() - elif model_id == "timer_xl": - config = TimerConfig() - first_queue = mp.Queue() - ready_event = mp.Event() - first_pool = InferenceRequestPool( - pool_id=0, - model_id=model_id, - device=device_str, - config=config, - request_queue=first_queue, - result_queue=self._result_queue, - ready_event=ready_event, - ) - first_pool.start() - self._register_pool(model_id, device_str, 0, first_pool, first_queue) - - if not ready_event.wait(timeout=30): - self._erase_pool(model_id, device_id, 0) - logger.error( - f"[Inference][Device-{device}][Pool-0] Pool failed to be ready in time" - ) - else: - self.set_state(model_id, device_id, 0, PoolState.RUNNING) - logger.info( - f"[Inference][Device-{device}][Pool-0] Pool started running for model {model_id}" - ) + pass + # device = torch.device(device_str) + # device_id = device.index + # + # first_queue = mp.Queue() + # ready_event = mp.Event() + # first_pool = InferenceRequestPool( + # pool_id=0, + # model_id=model_id, + # device=device_str, + # request_queue=first_queue, + # result_queue=self._result_queue, + # ready_event=ready_event, + # ) + # first_pool.start() + # self._register_pool(model_id, device_str, 0, first_pool, first_queue) + # + # if not ready_event.wait(timeout=30): + # self._erase_pool(model_id, device_id, 0) + # logger.error( + # f"[Inference][Device-{device}][Pool-0] Pool failed to be ready in time" + # ) + # else: + # self.set_state(model_id, device_id, 0, PoolState.RUNNING) + # logger.info( + # f"[Inference][Device-{device}][Pool-0] Pool started running for model {model_id}" + # ) + # =============== Pool Management =============== def load_model(self, model_id: 
str, device_id_list: list[str]): """ Load the model to the specified devices asynchronously. @@ -255,29 +248,19 @@ def _expand_pools_on_device(self, model_id: str, device_id: str, count: int): """ def _expand_pool_on_device(*_): - result_queue = mp.Queue() + request_queue = mp.Queue() pool_id = self._new_pool_id.get_and_increment() model_info = self._model_manager.get_model_info(model_id) - model_type = model_info.model_type - if model_type == BuiltInModelType.SUNDIAL.value: - config = SundialConfig() - elif model_type == BuiltInModelType.TIMER_XL.value: - config = TimerConfig() - else: - raise InferenceModelInternalError( - f"Unsupported model type {model_type} for loading model {model_id}" - ) pool = InferenceRequestPool( pool_id=pool_id, model_info=model_info, device=device_id, - config=config, - request_queue=result_queue, + request_queue=request_queue, result_queue=self._result_queue, ready_event=mp.Event(), ) pool.start() - self._register_pool(model_id, device_id, pool_id, pool, result_queue) + self._register_pool(model_id, device_id, pool_id, pool, request_queue) if not pool.ready_event.wait(timeout=300): logger.error( f"[Inference][Device-{device_id}][Pool-{pool_id}] Pool failed to be ready in time" diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py index 6a2bd2b619aa7..d2e7292ecd8ff 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py @@ -36,7 +36,7 @@ estimate_pool_size, evaluate_system_resources, ) -from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP, ModelInfo +from iotdb.ainode.core.model.model_info import ModelInfo from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device logger = Logger() diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py deleted file mode 100644 index 2300169a6ee93..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py +++ /dev/null @@ -1,60 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from abc import ABC, abstractmethod - -import torch - - -class AbstractInferencePipeline(ABC): - """ - Abstract assistance strategy class for model inference. - This class shall define the interface process for specific model. - """ - - def __init__(self, model_config, **infer_kwargs): - self.model_config = model_config - self.infer_kwargs = infer_kwargs - - @abstractmethod - def preprocess_inputs(self, inputs: torch.Tensor): - """ - Preprocess the inputs before inference, including shape validation and value transformation. - - Args: - inputs (torch.Tensor): The input tensor to be preprocessed. - - Returns: - torch.Tensor: The preprocessed input tensor. 
- """ - # TODO: Integrate with the data processing pipeline operators - pass - - @abstractmethod - def post_decode(self): - """ - Post-process the outputs after each decode step. - """ - pass - - @abstractmethod - def post_inference(self): - """ - Post-process the outputs after the entire inference task. - """ - pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py b/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py index cf10b5b2cd4dc..d17f9fbcec536 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py @@ -19,7 +19,6 @@ import string import torch -from transformers.modeling_outputs import MoeCausalLMOutputWithPast def generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str: @@ -56,25 +55,25 @@ def _slice_pkv(pkv, s, e): return out -def split_moe_output(batch_out: MoeCausalLMOutputWithPast, split_sizes): - """ - split batch_out with type: MoeCausalLMOutputWithPast into len(split_sizes) - split_sizes[i] = ith request's batch_size。 - """ - outs = [] - start = 0 - for bsz in split_sizes: - end = start + bsz - outs.append( - MoeCausalLMOutputWithPast( - loss=_slice_tensor(batch_out.loss, start, end), - logits=batch_out.logits[start:end], - past_key_values=_slice_pkv(batch_out.past_key_values, start, end), - hidden_states=_slice_tuple_of_tensors( - batch_out.hidden_states, start, end - ), - attentions=_slice_tuple_of_tensors(batch_out.attentions, start, end), - ) - ) - start = end - return outs +# def split_moe_output(batch_out: MoeCausalLMOutputWithPast, split_sizes): +# """ +# split batch_out with type: MoeCausalLMOutputWithPast into len(split_sizes) +# split_sizes[i] = ith request's batch_size。 +# """ +# outs = [] +# start = 0 +# for bsz in split_sizes: +# end = start + bsz +# outs.append( +# MoeCausalLMOutputWithPast( +# loss=_slice_tensor(batch_out.loss, start, end), +# logits=batch_out.logits[start:end], +# 
past_key_values=_slice_pkv(batch_out.past_key_values, start, end), +# hidden_states=_slice_tuple_of_tensors( +# batch_out.hidden_states, start, end +# ), +# attentions=_slice_tuple_of_tensors(batch_out.attentions, start, end), +# ) +# ) +# start = end +# return outs diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index a67d576b0ec8c..1ce2e84e05929 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -18,7 +18,6 @@ import threading import time -from abc import ABC, abstractmethod from typing import Dict import pandas as pd @@ -29,29 +28,22 @@ from iotdb.ainode.core.constant import TSStatusCode from iotdb.ainode.core.exception import ( InferenceModelInternalError, - InvalidWindowArgumentError, NumericalRangeException, - runtime_error_extractor, ) from iotdb.ainode.core.inference.inference_request import ( InferenceRequest, InferenceRequestProxy, ) -from iotdb.ainode.core.inference.pool_controller import PoolController -from iotdb.ainode.core.inference.strategy.timer_sundial_inference_pipeline import ( - TimerSundialInferencePipeline, -) -from iotdb.ainode.core.inference.strategy.timerxl_inference_pipeline import ( - TimerXLInferencePipeline, +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ( + ChatPipeline, + ClassificationPipeline, + ForecastPipeline, ) +from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline +from iotdb.ainode.core.inference.pool_controller import PoolController from iotdb.ainode.core.inference.utils import generate_req_id from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig -from 
iotdb.ainode.core.model.sundial.modeling_sundial import SundialForPrediction -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig -from iotdb.ainode.core.model.timerxl.modeling_timer import TimerForPrediction from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.gpu_mapping import get_available_devices from iotdb.ainode.core.util.serde import convert_to_binary @@ -71,83 +63,6 @@ logger = Logger() -class InferenceStrategy(ABC): - def __init__(self, model): - self.model = model - - @abstractmethod - def infer(self, full_data, **kwargs): - pass - - -# [IoTDB] full data deserialized from iotdb is composed of [timestampList, valueList, length], -# we only get valueList currently. -class TimerXLStrategy(InferenceStrategy): - def infer(self, full_data, predict_length=96, **_): - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - seqs = torch.tensor(data).unsqueeze(0).float() - # TODO: unify model inference input - output = self.model.generate(seqs, max_new_tokens=predict_length, revin=True) - df = pd.DataFrame(output[0]) - return convert_to_binary(df) - - -class SundialStrategy(InferenceStrategy): - def infer(self, full_data, predict_length=96, **_): - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - seqs = torch.tensor(data).unsqueeze(0).float() - # TODO: unify model inference input - output = self.model.generate( - seqs, max_new_tokens=predict_length, num_samples=10, revin=True - ) - df = pd.DataFrame(output[0].mean(dim=0)) - return convert_to_binary(df) - - -class BuiltInStrategy(InferenceStrategy): - def infer(self, full_data, **_): - data = pd.DataFrame(full_data[1]).T - output = self.model.inference(data) - df = pd.DataFrame(output) - return convert_to_binary(df) - - -class RegisteredStrategy(InferenceStrategy): - def 
infer(self, full_data, window_interval=None, window_step=None, **_): - _, dataset, _, length = full_data - if window_interval is None or window_step is None: - window_interval = length - window_step = float("inf") - - if window_interval <= 0 or window_step <= 0 or window_interval > length: - raise InvalidWindowArgumentError(window_interval, window_step, length) - - data = torch.tensor(dataset, dtype=torch.float32).unsqueeze(0).permute(0, 2, 1) - - times = int((length - window_interval) // window_step + 1) - results = [] - try: - for i in range(times): - start = 0 if window_step == float("inf") else i * window_step - end = start + window_interval - window = data[:, start:end, :] - out = self.model(window) - df = pd.DataFrame(out.squeeze(0).detach().numpy()) - results.append(df) - except Exception as e: - msg = runtime_error_extractor(str(e)) or str(e) - raise InferenceModelInternalError(msg) - - # concatenate or return first window for forecast - return [convert_to_binary(df) for df in results] - - class InferenceManager: WAITING_INTERVAL_IN_MS = ( AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms() @@ -251,15 +166,6 @@ def _process_request(self, req): with self._result_wrapper_lock: del self._result_wrapper_map[req_id] - def _get_strategy(self, model_id, model): - if isinstance(model, TimerForPrediction): - return TimerXLStrategy(model) - if isinstance(model, SundialForPrediction): - return SundialStrategy(model) - if self._model_manager.model_storage.is_built_in_or_fine_tuned(model_id): - return BuiltInStrategy(model) - return RegisteredStrategy(model) - def _run( self, req, @@ -272,59 +178,54 @@ def _run( model_id = req.modelId try: raw = data_getter(req) + # full data deserialized from iotdb is composed of [timestampList, valueList, None, length], we only get valueList currently. 
full_data = deserializer(raw) - inference_attrs = extract_attrs(req) + # TODO: TSBlock -> Tensor codes should be unified + data = full_data[1][0] # get valueList in ndarray + if data.dtype.byteorder not in ("=", "|"): + np_data = data.byteswap() + data = np_data.view(np_data.dtype.newbyteorder()) + # the inputs should be on CPU before passing to the inference request + inputs = torch.tensor(data).unsqueeze(0).float().to("cpu") - predict_length = int(inference_attrs.pop("predict_length", 96)) + inference_attrs = extract_attrs(req) + output_length = int(inference_attrs.pop("output_length", 96)) if ( - predict_length - > AINodeDescriptor().get_config().get_ain_inference_max_predict_length() + output_length + > AINodeDescriptor().get_config().get_ain_inference_max_output_length() ): raise NumericalRangeException( "output_length", 1, AINodeDescriptor() .get_config() - .get_ain_inference_max_predict_length(), - predict_length, + .get_ain_inference_max_output_length(), + output_length, ) if self._pool_controller.has_request_pools(model_id): - # use request pool to accelerate inference when the model instance is already loaded. 
- # TODO: TSBlock -> Tensor codes should be unified - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - # the inputs should be on CPU before passing to the inference request - inputs = torch.tensor(data).unsqueeze(0).float().to("cpu") - model_type = self._model_manager.get_model_info(model_id).model_type - if model_type == BuiltInModelType.SUNDIAL.value: - inference_pipeline = TimerSundialInferencePipeline(SundialConfig()) - elif model_type == BuiltInModelType.TIMER_XL.value: - inference_pipeline = TimerXLInferencePipeline(TimerConfig()) - else: - raise InferenceModelInternalError( - f"Unsupported model_id: {model_id}" - ) infer_req = InferenceRequest( req_id=generate_req_id(), model_id=model_id, inputs=inputs, - inference_pipeline=inference_pipeline, - max_new_tokens=predict_length, + output_length=output_length, ) outputs = self._process_request(infer_req) outputs = convert_to_binary(pd.DataFrame(outputs[0])) else: - # load model - accel = str(inference_attrs.get("acceleration", "")).lower() == "true" - model = self._model_manager.load_model(model_id, inference_attrs, accel) - # inference by strategy - strategy = self._get_strategy(model_id, model) - outputs = strategy.infer( - full_data, predict_length=predict_length, **inference_attrs - ) + model_info = self._model_manager.get_model_info(model_id) + inference_pipeline = load_pipeline(model_info, device="cpu") + if isinstance(inference_pipeline, ForecastPipeline): + outputs = inference_pipeline.forecast( + inputs, predict_length=output_length, **inference_attrs + ) + elif isinstance(inference_pipeline, ClassificationPipeline): + outputs = inference_pipeline.classify(inputs) + elif isinstance(inference_pipeline, ChatPipeline): + outputs = inference_pipeline.chat(inputs) + else: + logger.error("[Inference] Unsupported pipeline type.") + outputs = convert_to_binary(pd.DataFrame(outputs[0])) # construct response 
status = get_status(TSStatusCode.SUCCESS_STATUS) @@ -345,7 +246,7 @@ def forecast(self, req: TForecastReq): data_getter=lambda r: r.inputData, deserializer=deserialize, extract_attrs=lambda r: { - "predict_length": r.outputLength, + "output_length": r.outputLength, **(r.options or {}), }, resp_cls=TForecastResp, @@ -358,8 +259,7 @@ def inference(self, req: TInferenceReq): data_getter=lambda r: r.dataset, deserializer=deserialize, extract_attrs=lambda r: { - "window_interval": getattr(r.windowParams, "windowInterval", None), - "window_step": getattr(r.windowParams, "windowStep", None), + "output_length": int(r.inferenceAttributes.pop("outputLength", 96)), **(r.inferenceAttributes or {}), }, resp_cls=TInferenceResp, diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py index d84bca77c8430..8ffb33d91e2d2 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py @@ -15,17 +15,14 @@ # specific language governing permissions and limitations # under the License. 
# -from typing import Callable, Dict -from torch import nn -from yaml import YAMLError +from typing import Any, List, Optional from iotdb.ainode.core.constant import TSStatusCode -from iotdb.ainode.core.exception import BadConfigValueError, InvalidUriError +from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import BuiltInModelType, ModelStates -from iotdb.ainode.core.model.model_info import ModelInfo -from iotdb.ainode.core.model.model_storage import ModelStorage +from iotdb.ainode.core.model.model_loader import load_model +from iotdb.ainode.core.model.model_storage import ModelCategory, ModelInfo, ModelStorage from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.decorator import singleton from iotdb.thrift.ainode.ttypes import ( @@ -43,127 +40,60 @@ @singleton class ModelManager: def __init__(self): - self.model_storage = ModelStorage() + self._model_storage = ModelStorage() - def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp: - logger.info(f"register model {req.modelId} from {req.uri}") + def register_model( + self, + req: TRegisterModelReq, + ) -> TRegisterModelResp: try: - configs, attributes = self.model_storage.register_model( - req.modelId, req.uri - ) - return TRegisterModelResp( - get_status(TSStatusCode.SUCCESS_STATUS), configs, attributes - ) - except InvalidUriError as e: - logger.warning(e) - return TRegisterModelResp( - get_status(TSStatusCode.INVALID_URI_ERROR, e.message) - ) - except BadConfigValueError as e: - logger.warning(e) + if self._model_storage.register_model(model_id=req.modelId, uri=req.uri): + return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS)) + return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) + except ValueError as e: return TRegisterModelResp( - get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message) + get_status(TSStatusCode.INVALID_URI_ERROR, 
str(e)) ) - except YAMLError as e: - logger.warning(e) - if hasattr(e, "problem_mark"): - mark = e.problem_mark - return TRegisterModelResp( - get_status( - TSStatusCode.INVALID_INFERENCE_CONFIG, - f"An error occurred while parsing the yaml file, " - f"at line {mark.line + 1} column {mark.column + 1}.", - ) - ) + except Exception as e: return TRegisterModelResp( - get_status( - TSStatusCode.INVALID_INFERENCE_CONFIG, - f"An error occurred while parsing the yaml file", - ) + get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) ) - except Exception as e: - logger.warning(e) - return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) + + def show_models(self, req: TShowModelsReq) -> TShowModelsResp: + self._refresh() + return self._model_storage.show_models(req) def delete_model(self, req: TDeleteModelReq) -> TSStatus: - logger.info(f"delete model {req.modelId}") try: - self.model_storage.delete_model(req.modelId) + self._model_storage.delete_model(req.modelId) return get_status(TSStatusCode.SUCCESS_STATUS) - except Exception as e: + except BuiltInModelDeletionError as e: logger.warning(e) return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) - - def load_model( - self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool = False - ) -> Callable: - """ - Load the model with the given model_id. 
- """ - logger.info(f"Load model {model_id}") - try: - model = self.model_storage.load_model( - model_id, inference_attrs, acceleration - ) - logger.info(f"Model {model_id} loaded") - return model - except Exception as e: - logger.error(f"Failed to load model {model_id}: {e}") - raise - - def save_model(self, model_id: str, model: nn.Module) -> TSStatus: - """ - Save the model using save_pretrained - """ - logger.info(f"Saving model {model_id}") - try: - self.model_storage.save_model(model_id, model) - logger.info(f"Saving model {model_id} successfully") - return get_status( - TSStatusCode.SUCCESS_STATUS, f"Model {model_id} saved successfully" - ) except Exception as e: - logger.error(f"Save model failed: {e}") + logger.warning(e) return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) - def get_ckpt_path(self, model_id: str) -> str: - """ - Get the checkpoint path for a given model ID. - - Args: - model_id (str): The ID of the model. - - Returns: - str: The path to the checkpoint file for the model. - """ - return self.model_storage.get_ckpt_path(model_id) - - def show_models(self, req: TShowModelsReq) -> TShowModelsResp: - return self.model_storage.show_models(req) - - def register_built_in_model(self, model_info: ModelInfo): - self.model_storage.register_built_in_model(model_info) - - def get_model_info(self, model_id: str) -> ModelInfo: - return self.model_storage.get_model_info(model_id) - - def update_model_state(self, model_id: str, state: ModelStates): - self.model_storage.update_model_state(model_id, state) - - def get_built_in_model_type(self, model_id: str) -> BuiltInModelType: - """ - Get the type of the model with the given model_id. - """ - return self.model_storage.get_built_in_model_type(model_id) - - def is_built_in_or_fine_tuned(self, model_id: str) -> bool: - """ - Check if the model_id corresponds to a built-in or fine-tuned model. - - Args: - model_id (str): The ID of the model. 
- - Returns: - bool: True if the model is built-in or fine_tuned, False otherwise. - """ - return self.model_storage.is_built_in_or_fine_tuned(model_id) + def get_model_info( + self, + model_id: str, + category: Optional[ModelCategory] = None, + ) -> Optional[ModelInfo]: + return self._model_storage.get_model_info(model_id, category) + + def get_model_infos( + self, + category: Optional[ModelCategory] = None, + model_type: Optional[str] = None, + ) -> List[ModelInfo]: + return self._model_storage.get_model_infos(category, model_type) + + def _refresh(self): + """Refresh the model list (re-scan the file system)""" + self._model_storage.discover_all_models() + + def get_registered_models(self) -> List[str]: + return self._model_storage.get_registered_models() + + def is_model_registered(self, model_id: str) -> bool: + return self._model_storage.is_model_registered(model_id) diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py index 0264e27331a86..23a98f26bbffa 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py @@ -25,7 +25,7 @@ from iotdb.ainode.core.exception import ModelNotExistError from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP +from iotdb.ainode.core.model.model_loader import load_model logger = Logger() @@ -47,7 +47,8 @@ def measure_model_memory(device: torch.device, model_id: str) -> int: torch.cuda.synchronize(device) start = torch.cuda.memory_reserved(device) - model = ModelManager().load_model(model_id, {}).to(device) + model_info = ModelManager().get_model_info(model_id) + model = load_model(model_info).to(device) torch.cuda.synchronize(device) end = torch.cuda.memory_reserved(device) usage = end - start @@ -80,7 +81,7 @@ def evaluate_system_resources(device: torch.device) -> dict: def 
estimate_pool_size(device: torch.device, model_id: str) -> int: - model_info = BUILT_IN_LTSM_MAP.get(model_id, None) + model_info = ModelManager().get_model_info(model_id) if model_info is None or model_info.model_type not in MODEL_MEM_USAGE_MAP: logger.error( f"[Inference] Cannot estimate inference pool size on device: {device}, because model: {model_id} is not supported." diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py deleted file mode 100644 index 3b55142350bad..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py +++ /dev/null @@ -1,1238 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -import os -from abc import abstractmethod -from typing import Callable, Dict, List - -import numpy as np -from huggingface_hub import hf_hub_download -from sklearn.preprocessing import MinMaxScaler -from sktime.detection.hmm_learn import GMMHMM, GaussianHMM -from sktime.detection.stray import STRAY -from sktime.forecasting.arima import ARIMA -from sktime.forecasting.exp_smoothing import ExponentialSmoothing -from sktime.forecasting.naive import NaiveForecaster -from sktime.forecasting.trend import STLForecaster - -from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, - AttributeName, -) -from iotdb.ainode.core.exception import ( - BuiltInModelNotSupportError, - InferenceModelInternalError, - ListRangeException, - NumericalRangeException, - StringRangeException, - WrongAttributeTypeError, -) -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.model_info import TIMER_REPO_ID -from iotdb.ainode.core.model.sundial import modeling_sundial -from iotdb.ainode.core.model.timerxl import modeling_timer - -logger = Logger() - - -def _download_file_from_hf_if_necessary(local_dir: str, repo_id: str) -> bool: - weights_path = os.path.join(local_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) - config_path = os.path.join(local_dir, MODEL_CONFIG_FILE_IN_JSON) - if not os.path.exists(weights_path): - logger.info( - f"Model weights file not found at {weights_path}, downloading from HuggingFace..." 
- ) - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS, - local_dir=local_dir, - ) - logger.info(f"Got file to {weights_path}") - except Exception as e: - logger.error( - f"Failed to download model weights file to {local_dir} due to {e}" - ) - return False - if not os.path.exists(config_path): - logger.info( - f"Model config file not found at {config_path}, downloading from HuggingFace..." - ) - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_CONFIG_FILE_IN_JSON, - local_dir=local_dir, - ) - logger.info(f"Got file to {config_path}") - except Exception as e: - logger.error( - f"Failed to download model config file to {local_dir} due to {e}" - ) - return False - return True - - -def download_built_in_ltsm_from_hf_if_necessary( - model_type: BuiltInModelType, local_dir: str -) -> bool: - """ - Download the built-in ltsm from HuggingFace repository when necessary. - - Return: - bool: True if the model is existed or downloaded successfully, False otherwise. 
- """ - repo_id = TIMER_REPO_ID[model_type] - if not _download_file_from_hf_if_necessary(local_dir, repo_id): - return False - return True - - -def get_model_attributes(model_type: BuiltInModelType): - if model_type == BuiltInModelType.ARIMA: - attribute_map = arima_attribute_map - elif model_type == BuiltInModelType.NAIVE_FORECASTER: - attribute_map = naive_forecaster_attribute_map - elif ( - model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING - or model_type == BuiltInModelType.HOLTWINTERS - ): - attribute_map = exponential_smoothing_attribute_map - elif model_type == BuiltInModelType.STL_FORECASTER: - attribute_map = stl_forecaster_attribute_map - elif model_type == BuiltInModelType.GMM_HMM: - attribute_map = gmmhmm_attribute_map - elif model_type == BuiltInModelType.GAUSSIAN_HMM: - attribute_map = gaussian_hmm_attribute_map - elif model_type == BuiltInModelType.STRAY: - attribute_map = stray_attribute_map - elif model_type == BuiltInModelType.TIMER_XL: - attribute_map = timerxl_attribute_map - elif model_type == BuiltInModelType.SUNDIAL: - attribute_map = sundial_attribute_map - else: - raise BuiltInModelNotSupportError(model_type.value) - return attribute_map - - -def fetch_built_in_model( - model_type: BuiltInModelType, model_dir, inference_attrs: Dict[str, str] -) -> Callable: - """ - Fetch the built-in model according to its id and directory, not that this directory only contains model weights and config. 
- Args: - model_type: the type of the built-in model - model_dir: for huggingface models only, the directory where the model is stored - Returns: - model: the built-in model - """ - default_attributes = get_model_attributes(model_type) - # parse the attributes from inference_attrs - attributes = parse_attribute(inference_attrs, default_attributes) - - # build the built-in model - if model_type == BuiltInModelType.ARIMA: - model = ArimaModel(attributes) - elif ( - model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING - or model_type == BuiltInModelType.HOLTWINTERS - ): - model = ExponentialSmoothingModel(attributes) - elif model_type == BuiltInModelType.NAIVE_FORECASTER: - model = NaiveForecasterModel(attributes) - elif model_type == BuiltInModelType.STL_FORECASTER: - model = STLForecasterModel(attributes) - elif model_type == BuiltInModelType.GMM_HMM: - model = GMMHMMModel(attributes) - elif model_type == BuiltInModelType.GAUSSIAN_HMM: - model = GaussianHmmModel(attributes) - elif model_type == BuiltInModelType.STRAY: - model = STRAYModel(attributes) - elif model_type == BuiltInModelType.TIMER_XL: - model = modeling_timer.TimerForPrediction.from_pretrained(model_dir) - elif model_type == BuiltInModelType.SUNDIAL: - model = modeling_sundial.SundialForPrediction.from_pretrained(model_dir) - else: - raise BuiltInModelNotSupportError(model_type.value) - - return model - - -class Attribute(object): - def __init__(self, name: str): - """ - Args: - name: the name of the attribute - """ - self._name = name - - @abstractmethod - def get_default_value(self): - raise NotImplementedError - - @abstractmethod - def validate_value(self, value): - raise NotImplementedError - - @abstractmethod - def parse(self, string_value: str): - raise NotImplementedError - - -class IntAttribute(Attribute): - def __init__( - self, - name: str, - default_value: int, - default_low: int, - default_high: int, - ): - super(IntAttribute, self).__init__(name) - self.__default_value = default_value - 
self.__default_low = default_low - self.__default_high = default_high - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if self.__default_low <= value <= self.__default_high: - return True - raise NumericalRangeException( - self._name, value, self.__default_low, self.__default_high - ) - - def parse(self, string_value: str): - try: - int_value = int(string_value) - except: - raise WrongAttributeTypeError(self._name, "int") - return int_value - - -class FloatAttribute(Attribute): - def __init__( - self, - name: str, - default_value: float, - default_low: float, - default_high: float, - ): - super(FloatAttribute, self).__init__(name) - self.__default_value = default_value - self.__default_low = default_low - self.__default_high = default_high - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if self.__default_low <= value <= self.__default_high: - return True - raise NumericalRangeException( - self._name, value, self.__default_low, self.__default_high - ) - - def parse(self, string_value: str): - try: - float_value = float(string_value) - except: - raise WrongAttributeTypeError(self._name, "float") - return float_value - - -class StringAttribute(Attribute): - def __init__(self, name: str, default_value: str, value_choices: List[str]): - super(StringAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_choices = value_choices - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if value in self.__value_choices: - return True - raise StringRangeException(self._name, value, self.__value_choices) - - def parse(self, string_value: str): - return string_value - - -class BooleanAttribute(Attribute): - def __init__(self, name: str, default_value: bool): - super(BooleanAttribute, self).__init__(name) - self.__default_value = default_value - - def get_default_value(self): - return 
self.__default_value - - def validate_value(self, value): - if isinstance(value, bool): - return True - raise WrongAttributeTypeError(self._name, "bool") - - def parse(self, string_value: str): - if string_value.lower() == "true": - return True - elif string_value.lower() == "false": - return False - else: - raise WrongAttributeTypeError(self._name, "bool") - - -class ListAttribute(Attribute): - def __init__(self, name: str, default_value: List, value_type): - """ - value_type is the type of the elements in the list, e.g. int, float, str - """ - super(ListAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_type = value_type - self.__type_to_str = {str: "str", int: "int", float: "float"} - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if not isinstance(value, list): - raise WrongAttributeTypeError(self._name, "list") - for value_item in value: - if not isinstance(value_item, self.__value_type): - raise WrongAttributeTypeError(self._name, self.__value_type) - return True - - def parse(self, string_value: str): - try: - list_value = eval(string_value) - except: - raise WrongAttributeTypeError(self._name, "list") - if not isinstance(list_value, list): - raise WrongAttributeTypeError(self._name, "list") - for i in range(len(list_value)): - try: - list_value[i] = self.__value_type(list_value[i]) - except: - raise ListRangeException( - self._name, list_value, self.__type_to_str[self.__value_type] - ) - return list_value - - -class TupleAttribute(Attribute): - def __init__(self, name: str, default_value: tuple, value_type): - """ - value_type is the type of the elements in the list, e.g. 
int, float, str - """ - super(TupleAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_type = value_type - self.__type_to_str = {str: "str", int: "int", float: "float"} - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if not isinstance(value, tuple): - raise WrongAttributeTypeError(self._name, "tuple") - for value_item in value: - if not isinstance(value_item, self.__value_type): - raise WrongAttributeTypeError(self._name, self.__value_type) - return True - - def parse(self, string_value: str): - try: - tuple_value = eval(string_value) - except: - raise WrongAttributeTypeError(self._name, "tuple") - if not isinstance(tuple_value, tuple): - raise WrongAttributeTypeError(self._name, "tuple") - list_value = list(tuple_value) - for i in range(len(list_value)): - try: - list_value[i] = self.__value_type(list_value[i]) - except: - raise ListRangeException( - self._name, list_value, self.__type_to_str[self.__value_type] - ) - tuple_value = tuple(list_value) - return tuple_value - - -def parse_attribute( - input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute] -): - """ - Args: - input_attributes: a dict of attributes, where the key is the attribute name, the value is the string value of - the attribute - attribute_map: a dict of hyperparameters, where the key is the attribute name, the value is the Attribute - object - Returns: - a dict of attributes, where the key is the attribute name, the value is the parsed value of the attribute - """ - attributes = {} - for attribute_name in attribute_map: - # user specified the attribute - if attribute_name in input_attributes: - attribute = attribute_map[attribute_name] - value = attribute.parse(input_attributes[attribute_name]) - attribute.validate_value(value) - attributes[attribute_name] = value - # user did not specify the attribute, use the default value - else: - try: - attributes[attribute_name] = attribute_map[ - 
attribute_name - ].get_default_value() - except NotImplementedError as e: - logger.error(f"attribute {attribute_name} is not implemented.") - raise e - return attributes - - -sundial_attribute_map = { - AttributeName.INPUT_TOKEN_LEN.value: IntAttribute( - name=AttributeName.INPUT_TOKEN_LEN.value, - default_value=16, - default_low=1, - default_high=5000, - ), - AttributeName.HIDDEN_SIZE.value: IntAttribute( - name=AttributeName.HIDDEN_SIZE.value, - default_value=768, - default_low=1, - default_high=5000, - ), - AttributeName.INTERMEDIATE_SIZE.value: IntAttribute( - name=AttributeName.INTERMEDIATE_SIZE.value, - default_value=3072, - default_low=1, - default_high=5000, - ), - AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute( - name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[720], value_type=int - ), - AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute( - name=AttributeName.NUM_HIDDEN_LAYERS.value, - default_value=12, - default_low=1, - default_high=16, - ), - AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute( - name=AttributeName.NUM_ATTENTION_HEADS.value, - default_value=12, - default_low=1, - default_high=192, - ), - AttributeName.HIDDEN_ACT.value: StringAttribute( - name=AttributeName.HIDDEN_ACT.value, - default_value="silu", - value_choices=["relu", "gelu", "silu", "tanh"], - ), - AttributeName.USE_CACHE.value: BooleanAttribute( - name=AttributeName.USE_CACHE.value, - default_value=True, - ), - AttributeName.ROPE_THETA.value: IntAttribute( - name=AttributeName.ROPE_THETA.value, - default_value=10000, - default_low=1000, - default_high=50000, - ), - AttributeName.DROPOUT_RATE.value: FloatAttribute( - name=AttributeName.DROPOUT_RATE.value, - default_value=0.1, - default_low=0.0, - default_high=1.0, - ), - AttributeName.INITIALIZER_RANGE.value: FloatAttribute( - name=AttributeName.INITIALIZER_RANGE.value, - default_value=0.02, - default_low=0.0, - default_high=1.0, - ), - AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute( - 
name=AttributeName.MAX_POSITION_EMBEDDINGS.value, - default_value=10000, - default_low=1, - default_high=50000, - ), - AttributeName.FLOW_LOSS_DEPTH.value: IntAttribute( - name=AttributeName.FLOW_LOSS_DEPTH.value, - default_value=3, - default_low=1, - default_high=50, - ), - AttributeName.NUM_SAMPLING_STEPS.value: IntAttribute( - name=AttributeName.NUM_SAMPLING_STEPS.value, - default_value=50, - default_low=1, - default_high=5000, - ), - AttributeName.DIFFUSION_BATCH_MUL.value: IntAttribute( - name=AttributeName.DIFFUSION_BATCH_MUL.value, - default_value=4, - default_low=1, - default_high=5000, - ), - AttributeName.CKPT_PATH.value: StringAttribute( - name=AttributeName.CKPT_PATH.value, - default_value=os.path.join( - os.getcwd(), - AINodeDescriptor().get_config().get_ain_models_dir(), - "weights", - "sundial", - ), - value_choices=[""], - ), -} - -timerxl_attribute_map = { - AttributeName.INPUT_TOKEN_LEN.value: IntAttribute( - name=AttributeName.INPUT_TOKEN_LEN.value, - default_value=96, - default_low=1, - default_high=5000, - ), - AttributeName.HIDDEN_SIZE.value: IntAttribute( - name=AttributeName.HIDDEN_SIZE.value, - default_value=1024, - default_low=1, - default_high=5000, - ), - AttributeName.INTERMEDIATE_SIZE.value: IntAttribute( - name=AttributeName.INTERMEDIATE_SIZE.value, - default_value=2048, - default_low=1, - default_high=5000, - ), - AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute( - name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[96], value_type=int - ), - AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute( - name=AttributeName.NUM_HIDDEN_LAYERS.value, - default_value=8, - default_low=1, - default_high=16, - ), - AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute( - name=AttributeName.NUM_ATTENTION_HEADS.value, - default_value=8, - default_low=1, - default_high=192, - ), - AttributeName.HIDDEN_ACT.value: StringAttribute( - name=AttributeName.HIDDEN_ACT.value, - default_value="silu", - value_choices=["relu", "gelu", "silu", "tanh"], 
- ), - AttributeName.USE_CACHE.value: BooleanAttribute( - name=AttributeName.USE_CACHE.value, - default_value=True, - ), - AttributeName.ROPE_THETA.value: IntAttribute( - name=AttributeName.ROPE_THETA.value, - default_value=10000, - default_low=1000, - default_high=50000, - ), - AttributeName.ATTENTION_DROPOUT.value: FloatAttribute( - name=AttributeName.ATTENTION_DROPOUT.value, - default_value=0.0, - default_low=0.0, - default_high=1.0, - ), - AttributeName.INITIALIZER_RANGE.value: FloatAttribute( - name=AttributeName.INITIALIZER_RANGE.value, - default_value=0.02, - default_low=0.0, - default_high=1.0, - ), - AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute( - name=AttributeName.MAX_POSITION_EMBEDDINGS.value, - default_value=10000, - default_low=1, - default_high=50000, - ), - AttributeName.CKPT_PATH.value: StringAttribute( - name=AttributeName.CKPT_PATH.value, - default_value=os.path.join( - os.getcwd(), - AINodeDescriptor().get_config().get_ain_models_dir(), - "weights", - "timerxl", - "model.safetensors", - ), - value_choices=[""], - ), -} - -# built-in sktime model attributes -# NaiveForecaster -naive_forecaster_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.STRATEGY.value: StringAttribute( - name=AttributeName.STRATEGY.value, - default_value="last", - value_choices=["last", "mean"], - ), - AttributeName.SP.value: IntAttribute( - name=AttributeName.SP.value, default_value=1, default_low=1, default_high=5000 - ), -} -# ExponentialSmoothing -exponential_smoothing_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.DAMPED_TREND.value: BooleanAttribute( - name=AttributeName.DAMPED_TREND.value, - default_value=False, - ), - AttributeName.INITIALIZATION_METHOD.value: 
StringAttribute( - name=AttributeName.INITIALIZATION_METHOD.value, - default_value="estimated", - value_choices=["estimated", "heuristic", "legacy-heuristic", "known"], - ), - AttributeName.OPTIMIZED.value: BooleanAttribute( - name=AttributeName.OPTIMIZED.value, - default_value=True, - ), - AttributeName.REMOVE_BIAS.value: BooleanAttribute( - name=AttributeName.REMOVE_BIAS.value, - default_value=False, - ), - AttributeName.USE_BRUTE.value: BooleanAttribute( - name=AttributeName.USE_BRUTE.value, - default_value=False, - ), -} -# Arima -arima_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.ORDER.value: TupleAttribute( - name=AttributeName.ORDER.value, default_value=(1, 0, 0), value_type=int - ), - AttributeName.SEASONAL_ORDER.value: TupleAttribute( - name=AttributeName.SEASONAL_ORDER.value, - default_value=(0, 0, 0, 0), - value_type=int, - ), - AttributeName.METHOD.value: StringAttribute( - name=AttributeName.METHOD.value, - default_value="lbfgs", - value_choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"], - ), - AttributeName.MAXITER.value: IntAttribute( - name=AttributeName.MAXITER.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.SUPPRESS_WARNINGS.value: BooleanAttribute( - name=AttributeName.SUPPRESS_WARNINGS.value, - default_value=True, - ), - AttributeName.OUT_OF_SAMPLE_SIZE.value: IntAttribute( - name=AttributeName.OUT_OF_SAMPLE_SIZE.value, - default_value=0, - default_low=0, - default_high=5000, - ), - AttributeName.SCORING.value: StringAttribute( - name=AttributeName.SCORING.value, - default_value="mse", - value_choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"], - ), - AttributeName.WITH_INTERCEPT.value: BooleanAttribute( - name=AttributeName.WITH_INTERCEPT.value, - default_value=True, - ), - AttributeName.TIME_VARYING_REGRESSION.value: BooleanAttribute( - 
name=AttributeName.TIME_VARYING_REGRESSION.value, - default_value=False, - ), - AttributeName.ENFORCE_STATIONARITY.value: BooleanAttribute( - name=AttributeName.ENFORCE_STATIONARITY.value, - default_value=True, - ), - AttributeName.ENFORCE_INVERTIBILITY.value: BooleanAttribute( - name=AttributeName.ENFORCE_INVERTIBILITY.value, - default_value=True, - ), - AttributeName.SIMPLE_DIFFERENCING.value: BooleanAttribute( - name=AttributeName.SIMPLE_DIFFERENCING.value, - default_value=False, - ), - AttributeName.MEASUREMENT_ERROR.value: BooleanAttribute( - name=AttributeName.MEASUREMENT_ERROR.value, - default_value=False, - ), - AttributeName.MLE_REGRESSION.value: BooleanAttribute( - name=AttributeName.MLE_REGRESSION.value, - default_value=True, - ), - AttributeName.HAMILTON_REPRESENTATION.value: BooleanAttribute( - name=AttributeName.HAMILTON_REPRESENTATION.value, - default_value=False, - ), - AttributeName.CONCENTRATE_SCALE.value: BooleanAttribute( - name=AttributeName.CONCENTRATE_SCALE.value, - default_value=False, - ), -} -# STLForecaster -stl_forecaster_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.SP.value: IntAttribute( - name=AttributeName.SP.value, default_value=2, default_low=1, default_high=5000 - ), - AttributeName.SEASONAL.value: IntAttribute( - name=AttributeName.SEASONAL.value, - default_value=7, - default_low=1, - default_high=5000, - ), - AttributeName.SEASONAL_DEG.value: IntAttribute( - name=AttributeName.SEASONAL_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.TREND_DEG.value: IntAttribute( - name=AttributeName.TREND_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.LOW_PASS_DEG.value: IntAttribute( - name=AttributeName.LOW_PASS_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - 
AttributeName.SEASONAL_JUMP.value: IntAttribute( - name=AttributeName.SEASONAL_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.TREND_JUMP.value: IntAttribute( - name=AttributeName.TREND_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.LOSS_PASS_JUMP.value: IntAttribute( - name=AttributeName.LOSS_PASS_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), -} - -# GAUSSIAN_HMM -gaussian_hmm_attribute_map = { - AttributeName.N_COMPONENTS.value: IntAttribute( - name=AttributeName.N_COMPONENTS.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.COVARIANCE_TYPE.value: StringAttribute( - name=AttributeName.COVARIANCE_TYPE.value, - default_value="diag", - value_choices=["spherical", "diag", "full", "tied"], - ), - AttributeName.MIN_COVAR.value: FloatAttribute( - name=AttributeName.MIN_COVAR.value, - default_value=1e-3, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.STARTPROB_PRIOR.value: FloatAttribute( - name=AttributeName.STARTPROB_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.TRANSMAT_PRIOR.value: FloatAttribute( - name=AttributeName.TRANSMAT_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_PRIOR.value: FloatAttribute( - name=AttributeName.MEANS_PRIOR.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_WEIGHT.value: FloatAttribute( - name=AttributeName.MEANS_WEIGHT.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.COVARS_PRIOR.value: FloatAttribute( - name=AttributeName.COVARS_PRIOR.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.COVARS_WEIGHT.value: FloatAttribute( - name=AttributeName.COVARS_WEIGHT.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - 
AttributeName.ALGORITHM.value: StringAttribute( - name=AttributeName.ALGORITHM.value, - default_value="viterbi", - value_choices=["viterbi", "map"], - ), - AttributeName.N_ITER.value: IntAttribute( - name=AttributeName.N_ITER.value, - default_value=10, - default_low=1, - default_high=5000, - ), - AttributeName.TOL.value: FloatAttribute( - name=AttributeName.TOL.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.PARAMS.value: StringAttribute( - name=AttributeName.PARAMS.value, - default_value="stmc", - value_choices=["stmc", "stm"], - ), - AttributeName.INIT_PARAMS.value: StringAttribute( - name=AttributeName.INIT_PARAMS.value, - default_value="stmc", - value_choices=["stmc", "stm"], - ), - AttributeName.IMPLEMENTATION.value: StringAttribute( - name=AttributeName.IMPLEMENTATION.value, - default_value="log", - value_choices=["log", "scaling"], - ), -} - -# GMMHMM -gmmhmm_attribute_map = { - AttributeName.N_COMPONENTS.value: IntAttribute( - name=AttributeName.N_COMPONENTS.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.N_MIX.value: IntAttribute( - name=AttributeName.N_MIX.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.MIN_COVAR.value: FloatAttribute( - name=AttributeName.MIN_COVAR.value, - default_value=1e-3, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.STARTPROB_PRIOR.value: FloatAttribute( - name=AttributeName.STARTPROB_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.TRANSMAT_PRIOR.value: FloatAttribute( - name=AttributeName.TRANSMAT_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.WEIGHTS_PRIOR.value: FloatAttribute( - name=AttributeName.WEIGHTS_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_PRIOR.value: FloatAttribute( - name=AttributeName.MEANS_PRIOR.value, - default_value=0.0, - 
default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_WEIGHT.value: FloatAttribute( - name=AttributeName.MEANS_WEIGHT.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.ALGORITHM.value: StringAttribute( - name=AttributeName.ALGORITHM.value, - default_value="viterbi", - value_choices=["viterbi", "map"], - ), - AttributeName.COVARIANCE_TYPE.value: StringAttribute( - name=AttributeName.COVARIANCE_TYPE.value, - default_value="diag", - value_choices=["sperical", "diag", "full", "tied"], - ), - AttributeName.N_ITER.value: IntAttribute( - name=AttributeName.N_ITER.value, - default_value=10, - default_low=1, - default_high=5000, - ), - AttributeName.TOL.value: FloatAttribute( - name=AttributeName.TOL.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.INIT_PARAMS.value: StringAttribute( - name=AttributeName.INIT_PARAMS.value, - default_value="stmcw", - value_choices=[ - "s", - "t", - "m", - "c", - "w", - "st", - "sm", - "sc", - "sw", - "tm", - "tc", - "tw", - "mc", - "mw", - "cw", - "stm", - "stc", - "stw", - "smc", - "smw", - "scw", - "tmc", - "tmw", - "tcw", - "mcw", - "stmc", - "stmw", - "stcw", - "smcw", - "tmcw", - "stmcw", - ], - ), - AttributeName.PARAMS.value: StringAttribute( - name=AttributeName.PARAMS.value, - default_value="stmcw", - value_choices=[ - "s", - "t", - "m", - "c", - "w", - "st", - "sm", - "sc", - "sw", - "tm", - "tc", - "tw", - "mc", - "mw", - "cw", - "stm", - "stc", - "stw", - "smc", - "smw", - "scw", - "tmc", - "tmw", - "tcw", - "mcw", - "stmc", - "stmw", - "stcw", - "smcw", - "tmcw", - "stmcw", - ], - ), - AttributeName.IMPLEMENTATION.value: StringAttribute( - name=AttributeName.IMPLEMENTATION.value, - default_value="log", - value_choices=["log", "scaling"], - ), -} - -# STRAY -stray_attribute_map = { - AttributeName.ALPHA.value: FloatAttribute( - name=AttributeName.ALPHA.value, - default_value=0.01, - default_low=-1e10, - default_high=1e10, - ), - 
AttributeName.K.value: IntAttribute( - name=AttributeName.K.value, default_value=10, default_low=1, default_high=5000 - ), - AttributeName.KNN_ALGORITHM.value: StringAttribute( - name=AttributeName.KNN_ALGORITHM.value, - default_value="brute", - value_choices=["brute", "kd_tree", "ball_tree", "auto"], - ), - AttributeName.P.value: FloatAttribute( - name=AttributeName.P.value, - default_value=0.5, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.SIZE_THRESHOLD.value: IntAttribute( - name=AttributeName.SIZE_THRESHOLD.value, - default_value=50, - default_low=1, - default_high=5000, - ), - AttributeName.OUTLIER_TAIL.value: StringAttribute( - name=AttributeName.OUTLIER_TAIL.value, - default_value="max", - value_choices=["min", "max"], - ), -} - - -class BuiltInModel(object): - def __init__(self, attributes): - self._attributes = attributes - self._model = None - - @abstractmethod - def inference(self, data): - raise NotImplementedError - - -class ArimaModel(BuiltInModel): - def __init__(self, attributes): - super(ArimaModel, self).__init__(attributes) - self._model = ARIMA( - order=attributes["order"], - seasonal_order=attributes["seasonal_order"], - method=attributes["method"], - suppress_warnings=attributes["suppress_warnings"], - maxiter=attributes["maxiter"], - out_of_sample_size=attributes["out_of_sample_size"], - scoring=attributes["scoring"], - with_intercept=attributes["with_intercept"], - time_varying_regression=attributes["time_varying_regression"], - enforce_stationarity=attributes["enforce_stationarity"], - enforce_invertibility=attributes["enforce_invertibility"], - simple_differencing=attributes["simple_differencing"], - measurement_error=attributes["measurement_error"], - mle_regression=attributes["mle_regression"], - hamilton_representation=attributes["hamilton_representation"], - concentrate_scale=attributes["concentrate_scale"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - 
self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class ExponentialSmoothingModel(BuiltInModel): - def __init__(self, attributes): - super(ExponentialSmoothingModel, self).__init__(attributes) - self._model = ExponentialSmoothing( - damped_trend=attributes["damped_trend"], - initialization_method=attributes["initialization_method"], - optimized=attributes["optimized"], - remove_bias=attributes["remove_bias"], - use_brute=attributes["use_brute"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class NaiveForecasterModel(BuiltInModel): - def __init__(self, attributes): - super(NaiveForecasterModel, self).__init__(attributes) - self._model = NaiveForecaster( - strategy=attributes["strategy"], sp=attributes["sp"] - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class STLForecasterModel(BuiltInModel): - def __init__(self, attributes): - super(STLForecasterModel, self).__init__(attributes) - self._model = STLForecaster( - sp=attributes["sp"], - seasonal=attributes["seasonal"], - seasonal_deg=attributes["seasonal_deg"], - trend_deg=attributes["trend_deg"], - low_pass_deg=attributes["low_pass_deg"], - seasonal_jump=attributes["seasonal_jump"], - trend_jump=attributes["trend_jump"], - low_pass_jump=attributes["low_pass_jump"], - ) - - def inference(self, data): - try: - predict_length = 
self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class GMMHMMModel(BuiltInModel): - def __init__(self, attributes): - super(GMMHMMModel, self).__init__(attributes) - self._model = GMMHMM( - n_components=attributes["n_components"], - n_mix=attributes["n_mix"], - min_covar=attributes["min_covar"], - startprob_prior=attributes["startprob_prior"], - transmat_prior=attributes["transmat_prior"], - means_prior=attributes["means_prior"], - means_weight=attributes["means_weight"], - weights_prior=attributes["weights_prior"], - algorithm=attributes["algorithm"], - covariance_type=attributes["covariance_type"], - n_iter=attributes["n_iter"], - tol=attributes["tol"], - params=attributes["params"], - init_params=attributes["init_params"], - implementation=attributes["implementation"], - ) - - def inference(self, data): - try: - self._model.fit(data) - output = self._model.predict(data) - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class GaussianHmmModel(BuiltInModel): - def __init__(self, attributes): - super(GaussianHmmModel, self).__init__(attributes) - self._model = GaussianHMM( - n_components=attributes["n_components"], - covariance_type=attributes["covariance_type"], - min_covar=attributes["min_covar"], - startprob_prior=attributes["startprob_prior"], - transmat_prior=attributes["transmat_prior"], - means_prior=attributes["means_prior"], - means_weight=attributes["means_weight"], - covars_prior=attributes["covars_prior"], - covars_weight=attributes["covars_weight"], - algorithm=attributes["algorithm"], - n_iter=attributes["n_iter"], - tol=attributes["tol"], - params=attributes["params"], - init_params=attributes["init_params"], - implementation=attributes["implementation"], - ) - - def 
inference(self, data): - try: - self._model.fit(data) - output = self._model.predict(data) - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class STRAYModel(BuiltInModel): - def __init__(self, attributes): - super(STRAYModel, self).__init__(attributes) - self._model = STRAY( - alpha=attributes["alpha"], - k=attributes["k"], - knn_algorithm=attributes["knn_algorithm"], - p=attributes["p"], - size_threshold=attributes["size_threshold"], - outlier_tail=attributes["outlier_tail"], - ) - - def inference(self, data): - try: - data = MinMaxScaler().fit_transform(data) - output = self._model.fit_transform(data) - # change the output to int - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py new file mode 100644 index 0000000000000..c42ec98551b83 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py @@ -0,0 +1,48 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +from enum import Enum + +# Model file constants +MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors" +MODEL_CONFIG_FILE_IN_JSON = "config.json" +MODEL_WEIGHTS_FILE_IN_PT = "model.pt" +MODEL_CONFIG_FILE_IN_YAML = "config.yaml" + + +# Model file constants +MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors" +MODEL_CONFIG_FILE_IN_JSON = "config.json" +MODEL_WEIGHTS_FILE_IN_PT = "model.pt" +MODEL_CONFIG_FILE_IN_YAML = "config.yaml" + + +class ModelCategory(Enum): + BUILTIN = "builtin" + USER_DEFINED = "user_defined" + + +class ModelStates(Enum): + INACTIVE = "inactive" + ACTIVATING = "activating" + ACTIVE = "active" + DROPPING = "dropping" + + +class UriType(Enum): + REPO = "repo" + FILE = "file" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py deleted file mode 100644 index 348f9924316b6..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py +++ /dev/null @@ -1,70 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -from enum import Enum -from typing import List - - -class BuiltInModelType(Enum): - # forecast models - ARIMA = "Arima" - HOLTWINTERS = "HoltWinters" - EXPONENTIAL_SMOOTHING = "ExponentialSmoothing" - NAIVE_FORECASTER = "NaiveForecaster" - STL_FORECASTER = "StlForecaster" - - # anomaly detection models - GAUSSIAN_HMM = "GaussianHmm" - GMM_HMM = "GmmHmm" - STRAY = "Stray" - - # large time series models (LTSM) - TIMER_XL = "Timer-XL" - # sundial - SUNDIAL = "Timer-Sundial" - - @classmethod - def values(cls) -> List[str]: - return [item.value for item in cls] - - @staticmethod - def is_built_in_model(model_type: str) -> bool: - """ - Check if the given model type corresponds to a built-in model. - """ - return model_type in BuiltInModelType.values() - - -class ModelFileType(Enum): - SAFETENSORS = "safetensors" - PYTORCH = "pytorch" - UNKNOWN = "unknown" - - -class ModelCategory(Enum): - BUILT_IN = "BUILT-IN" - FINE_TUNED = "FINE-TUNED" - USER_DEFINED = "USER-DEFINED" - - -class ModelStates(Enum): - ACTIVE = "ACTIVE" - INACTIVE = "INACTIVE" - LOADING = "LOADING" - DROPPING = "DROPPING" - TRAINING = "TRAINING" - FAILED = "FAILED" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py deleted file mode 100644 index 26d863156f379..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py +++ /dev/null @@ -1,291 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import glob -import os -import shutil -from urllib.parse import urljoin - -import yaml - -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_YAML, - MODEL_WEIGHTS_FILE_IN_PT, -) -from iotdb.ainode.core.exception import BadConfigValueError, InvalidUriError -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import ModelFileType -from iotdb.ainode.core.model.uri_utils import ( - UriType, - download_file, - download_snapshot_from_hf, -) -from iotdb.ainode.core.util.serde import get_data_type_byte_from_str -from iotdb.thrift.ainode.ttypes import TConfigs - -logger = Logger() - - -def fetch_model_by_uri( - uri_type: UriType, uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Fetch the model files from the specified URI. - - Args: - uri_type (UriType): type of the URI, either repo, file, http or https - uri (str): a network or a local path of the model to be registered - storage_path (str): path to save the whole model, including weights, config, codes, etc. 
- model_file_type (ModelFileType): The type of model file, either safetensors or pytorch - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if uri_type == UriType.REPO or uri_type in [UriType.HTTP, UriType.HTTPS]: - return _fetch_model_from_network(uri, storage_path, model_file_type) - elif uri_type == UriType.FILE: - return _fetch_model_from_local(uri, storage_path, model_file_type) - else: - raise InvalidUriError(f"Invalid URI type: {uri_type}") - - -def _fetch_model_from_network( - uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if model_file_type == ModelFileType.SAFETENSORS: - download_snapshot_from_hf(uri, storage_path) - return _process_huggingface_files(storage_path) - - # TODO: The following codes might be refactored in future - # concat uri to get complete url - uri = uri if uri.endswith("/") else uri + "/" - target_model_path = urljoin(uri, MODEL_WEIGHTS_FILE_IN_PT) - target_config_path = urljoin(uri, MODEL_CONFIG_FILE_IN_YAML) - - # download config file - config_storage_path = os.path.join(storage_path, MODEL_CONFIG_FILE_IN_YAML) - download_file(target_config_path, config_storage_path) - - # read and parse config dict from config.yaml - with open(config_storage_path, "r", encoding="utf-8") as file: - config_dict = yaml.safe_load(file) - configs, attributes = _parse_inference_config(config_dict) - - # if config.yaml is correct, download model file - model_storage_path = os.path.join(storage_path, MODEL_WEIGHTS_FILE_IN_PT) - download_file(target_model_path, model_storage_path) - return configs, attributes - - -def _fetch_model_from_local( - uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if model_file_type == ModelFileType.SAFETENSORS: - # copy anything in the uri to local_dir - for file in 
os.listdir(uri): - shutil.copy(os.path.join(uri, file), storage_path) - return _process_huggingface_files(storage_path) - # concat uri to get complete path - target_model_path = os.path.join(uri, MODEL_WEIGHTS_FILE_IN_PT) - model_storage_path = os.path.join(storage_path, MODEL_WEIGHTS_FILE_IN_PT) - target_config_path = os.path.join(uri, MODEL_CONFIG_FILE_IN_YAML) - config_storage_path = os.path.join(storage_path, MODEL_CONFIG_FILE_IN_YAML) - - # check if file exist - exist_model_file = os.path.exists(target_model_path) - exist_config_file = os.path.exists(target_config_path) - - configs = None - attributes = None - if exist_model_file and exist_config_file: - # copy config.yaml - shutil.copy(target_config_path, config_storage_path) - logger.info( - f"copy file from {target_config_path} to {config_storage_path} success" - ) - - # read and parse config dict from config.yaml - with open(config_storage_path, "r", encoding="utf-8") as file: - config_dict = yaml.safe_load(file) - configs, attributes = _parse_inference_config(config_dict) - - # if config.yaml is correct, copy model file - shutil.copy(target_model_path, model_storage_path) - logger.info( - f"copy file from {target_model_path} to {model_storage_path} success" - ) - - elif not exist_model_file or not exist_config_file: - raise InvalidUriError(uri) - - return configs, attributes - - -def _parse_inference_config(config_dict): - """ - Args: - config_dict: dict - - configs: dict - - input_shape (list): input shape of the model and needs to be two-dimensional array like [96, 2] - - output_shape (list): output shape of the model and needs to be two-dimensional array like [96, 2] - - input_type (list): input type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float64 - - output_type (list): output type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float64 - - attributes: dict - Returns: - 
configs: TConfigs - attributes: str - """ - configs = config_dict["configs"] - - # check if input_shape and output_shape are two-dimensional array - if not ( - isinstance(configs["input_shape"], list) and len(configs["input_shape"]) == 2 - ): - raise BadConfigValueError( - "input_shape", - configs["input_shape"], - "input_shape should be a two-dimensional array.", - ) - if not ( - isinstance(configs["output_shape"], list) and len(configs["output_shape"]) == 2 - ): - raise BadConfigValueError( - "output_shape", - configs["output_shape"], - "output_shape should be a two-dimensional array.", - ) - - # check if input_shape and output_shape are positive integer - input_shape_is_positive_number = ( - isinstance(configs["input_shape"][0], int) - and isinstance(configs["input_shape"][1], int) - and configs["input_shape"][0] > 0 - and configs["input_shape"][1] > 0 - ) - if not input_shape_is_positive_number: - raise BadConfigValueError( - "input_shape", - configs["input_shape"], - "element in input_shape should be positive integer.", - ) - - output_shape_is_positive_number = ( - isinstance(configs["output_shape"][0], int) - and isinstance(configs["output_shape"][1], int) - and configs["output_shape"][0] > 0 - and configs["output_shape"][1] > 0 - ) - if not output_shape_is_positive_number: - raise BadConfigValueError( - "output_shape", - configs["output_shape"], - "element in output_shape should be positive integer.", - ) - - # check if input_type and output_type are one-dimensional array with right length - if "input_type" in configs and not ( - isinstance(configs["input_type"], list) - and len(configs["input_type"]) == configs["input_shape"][1] - ): - raise BadConfigValueError( - "input_type", - configs["input_type"], - "input_type should be a one-dimensional array and length of it should be equal to input_shape[1].", - ) - - if "output_type" in configs and not ( - isinstance(configs["output_type"], list) - and len(configs["output_type"]) == configs["output_shape"][1] - ): 
- raise BadConfigValueError( - "output_type", - configs["output_type"], - "output_type should be a one-dimensional array and length of it should be equal to output_shape[1].", - ) - - # parse input_type and output_type to byte - if "input_type" in configs: - input_type = [get_data_type_byte_from_str(x) for x in configs["input_type"]] - else: - input_type = [get_data_type_byte_from_str("float32")] * configs["input_shape"][ - 1 - ] - - if "output_type" in configs: - output_type = [get_data_type_byte_from_str(x) for x in configs["output_type"]] - else: - output_type = [get_data_type_byte_from_str("float32")] * configs[ - "output_shape" - ][1] - - # parse attributes - attributes = "" - if "attributes" in config_dict: - attributes = str(config_dict["attributes"]) - - return ( - TConfigs( - configs["input_shape"], configs["output_shape"], input_type, output_type - ), - attributes, - ) - - -def _process_huggingface_files(local_dir: str): - """ - TODO: Currently, we use this function to convert the model config from huggingface, we will refactor this in the future. 
- """ - config_file = None - for config_name in ["config.json", "model_config.json"]: - config_path = os.path.join(local_dir, config_name) - if os.path.exists(config_path): - config_file = config_path - break - - if not config_file: - raise InvalidUriError(f"No config.json found in {local_dir}") - - safetensors_files = glob.glob(os.path.join(local_dir, "*.safetensors")) - if not safetensors_files: - raise InvalidUriError(f"No .safetensors files found in {local_dir}") - - simple_config = { - "configs": { - "input_shape": [96, 1], - "output_shape": [96, 1], - "input_type": ["float32"], - "output_type": ["float32"], - }, - "attributes": { - "model_type": "huggingface_model", - "source_dir": local_dir, - "files": [os.path.basename(f) for f in safetensors_files], - }, - } - - configs, attributes = _parse_inference_config(simple_config) - return configs, attributes diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index 167bfd76640d1..718ead530dd2c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -15,140 +15,118 @@ # specific language governing permissions and limitations # under the License. # -import glob -import os -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_CONFIG_FILE_IN_YAML, - MODEL_WEIGHTS_FILE_IN_PT, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, -) -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, - ModelCategory, - ModelFileType, - ModelStates, -) +from typing import Dict, Optional - -def get_model_file_type(model_path: str) -> ModelFileType: - """ - Determine the file type of the specified model directory. 
- """ - if _has_safetensors_format(model_path): - return ModelFileType.SAFETENSORS - elif _has_pytorch_format(model_path): - return ModelFileType.PYTORCH - else: - return ModelFileType.UNKNOWN - - -def _has_safetensors_format(path: str) -> bool: - """Check if directory contains safetensors files.""" - safetensors_files = glob.glob(os.path.join(path, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)) - json_files = glob.glob(os.path.join(path, MODEL_CONFIG_FILE_IN_JSON)) - return len(safetensors_files) > 0 and len(json_files) > 0 - - -def _has_pytorch_format(path: str) -> bool: - """Check if directory contains pytorch files.""" - pt_files = glob.glob(os.path.join(path, MODEL_WEIGHTS_FILE_IN_PT)) - yaml_files = glob.glob(os.path.join(path, MODEL_CONFIG_FILE_IN_YAML)) - return len(pt_files) > 0 and len(yaml_files) > 0 - - -def get_built_in_model_type(model_type: str) -> BuiltInModelType: - if not BuiltInModelType.is_built_in_model(model_type): - raise ValueError(f"Invalid built-in model type: {model_type}") - return BuiltInModelType(model_type) +from iotdb.ainode.core.model.model_constants import ModelCategory, ModelStates class ModelInfo: def __init__( self, model_id: str, - model_type: str, category: ModelCategory, state: ModelStates, + model_type: str = "", + config_cls: str = "", + model_cls: str = "", + pipeline_cls: str = "", + repo_id: str = "", + auto_map: Optional[Dict] = None, + _transformers_registered: bool = False, ): self.model_id = model_id self.model_type = model_type self.category = category self.state = state + self.config_cls = config_cls + self.model_cls = model_cls + self.pipeline_cls = pipeline_cls + self.repo_id = repo_id + self.auto_map = auto_map # If exists, indicates it's a Transformers model + self._transformers_registered = _transformers_registered # Internal flag: whether registered to Transformers + def __repr__(self): + return ( + f"ModelInfo(model_id={self.model_id}, model_type={self.model_type}, " + f"category={self.category.value}, 
state={self.state.value}, " + f"has_auto_map={self.auto_map is not None})" + ) -TIMER_REPO_ID = { - BuiltInModelType.TIMER_XL: "thuml/timer-base-84m", - BuiltInModelType.SUNDIAL: "thuml/sundial-base-128m", -} -# Built-in machine learning models, they can be employed directly -BUILT_IN_MACHINE_LEARNING_MODEL_MAP = { +BUILTIN_SKTIME_MODEL_MAP = { # forecast models "arima": ModelInfo( model_id="arima", - model_type=BuiltInModelType.ARIMA.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "holtwinters": ModelInfo( model_id="holtwinters", - model_type=BuiltInModelType.HOLTWINTERS.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "exponential_smoothing": ModelInfo( model_id="exponential_smoothing", - model_type=BuiltInModelType.EXPONENTIAL_SMOOTHING.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "naive_forecaster": ModelInfo( model_id="naive_forecaster", - model_type=BuiltInModelType.NAIVE_FORECASTER.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "stl_forecaster": ModelInfo( model_id="stl_forecaster", - model_type=BuiltInModelType.STL_FORECASTER.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), # anomaly detection models "gaussian_hmm": ModelInfo( model_id="gaussian_hmm", - model_type=BuiltInModelType.GAUSSIAN_HMM.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "gmm_hmm": ModelInfo( model_id="gmm_hmm", - model_type=BuiltInModelType.GMM_HMM.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "stray": ModelInfo( model_id="stray", - 
model_type=BuiltInModelType.STRAY.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), } -# Built-in large time series models (LTSM), their weights are not included in AINode by default -BUILT_IN_LTSM_MAP = { +# Built-in huggingface transformers models, their weights are not included in AINode by default +BUILTIN_HF_TRANSFORMERS_MODEL_MAP = { "timer_xl": ModelInfo( model_id="timer_xl", - model_type=BuiltInModelType.TIMER_XL.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.LOADING, + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="timer", + config_cls="configuration_timer.TimerConfig", + model_cls="modeling_timer.TimerForPrediction", + pipeline_cls="pipeline_timer.TimerPipeline", + repo_id="thuml/timer-base-84m", ), "sundial": ModelInfo( model_id="sundial", - model_type=BuiltInModelType.SUNDIAL.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.LOADING, + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="sundial", + config_cls="configuration_sundial.SundialConfig", + model_cls="modeling_sundial.SundialForPrediction", + pipeline_cls="pipeline_sundial.SundialPipeline", + repo_id="thuml/sundial-base-128m", ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py new file mode 100644 index 0000000000000..a6e3b1f7b5e38 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import os +from pathlib import Path +from typing import Any + +import torch +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForNextSentencePrediction, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoModelForTimeSeriesPrediction, + AutoModelForTokenClassification, +) + +from iotdb.ainode.core.config import AINodeDescriptor +from iotdb.ainode.core.exception import ModelNotExistError +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.model_constants import ModelCategory +from iotdb.ainode.core.model.model_info import ModelInfo +from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model +from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path + +logger = Logger() + + +def load_model(model_info: ModelInfo, **model_kwargs) -> Any: + if model_info.auto_map is not None: + model = load_model_from_transformers(model_info, **model_kwargs) + else: + if model_info.model_type == "sktime": + model = create_sktime_model(model_info.model_id) + else: + model = load_model_from_pt(model_info, **model_kwargs) + + logger.info( + f"Model {model_info.model_id} loaded to device {model.device if model_info.model_type != 'sktime' else 'cpu'} successfully." 
+ ) + return model + + +def load_model_from_transformers(model_info: ModelInfo, **model_kwargs): + device_map = model_kwargs.get("device_map", "cpu") + trust_remote_code = model_kwargs.get("trust_remote_code", True) + train_from_scratch = model_kwargs.get("train_from_scratch", False) + + model_path = os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + model_info.category.value, + model_info.model_id, + ) + + if model_info.category == ModelCategory.BUILTIN: + module_name = ( + AINodeDescriptor().get_config().get_ain_models_builtin_dir() + + "." + + model_info.model_id + ) + config_cls = import_class_from_path(module_name, model_info.config_cls) + model_cls = import_class_from_path(module_name, model_info.model_cls) + elif model_info.model_cls and model_info.config_cls: + module_parent = str(Path(model_path).parent.absolute()) + with temporary_sys_path(module_parent): + config_cls = import_class_from_path( + model_info.model_id, model_info.config_cls + ) + model_cls = import_class_from_path( + model_info.model_id, model_info.model_cls + ) + else: + config_cls = AutoConfig.from_pretrained(model_path) + if type(config_cls) in AutoModelForTimeSeriesPrediction._model_mapping.keys(): + model_cls = AutoModelForTimeSeriesPrediction + elif ( + type(config_cls) in AutoModelForNextSentencePrediction._model_mapping.keys() + ): + model_cls = AutoModelForNextSentencePrediction + elif type(config_cls) in AutoModelForSeq2SeqLM._model_mapping.keys(): + model_cls = AutoModelForSeq2SeqLM + elif ( + type(config_cls) in AutoModelForSequenceClassification._model_mapping.keys() + ): + model_cls = AutoModelForSequenceClassification + elif type(config_cls) in AutoModelForTokenClassification._model_mapping.keys(): + model_cls = AutoModelForTokenClassification + else: + model_cls = AutoModelForCausalLM + + if train_from_scratch: + model = model_cls.from_config( + config_cls, trust_remote_code=trust_remote_code, device_map=device_map + ) + else: + model = 
model_cls.from_pretrained( + model_path, + trust_remote_code=trust_remote_code, + device_map=device_map, + ) + + return model + + +def load_model_from_pt(model_info: ModelInfo, **kwargs): + device_map = kwargs.get("device_map", "cpu") + acceleration = kwargs.get("acceleration", False) + model_path = os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + model_info.category.value, + model_info.model_id, + ) + model_file = os.path.join(model_path, "model.pt") + if not os.path.exists(model_file): + logger.error(f"Model file not found at {model_file}.") + raise ModelNotExistError(model_file) + model = torch.jit.load(model_file) + if isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not acceleration: + return model + try: + model = torch.compile(model) + except Exception as e: + logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}") + return model.to(device_map) + + +def load_model_for_efficient_inference(): + # TODO: An efficient model loading method for inference based on model_arguments + pass + + +def load_model_for_powerful_finetune(): + # TODO: An powerful model loading method for finetune based on model_arguments + pass + + +def unload_model(): + pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index e346f569102e3..5194ed4df1bd7 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -20,43 +20,37 @@ import json import os import shutil -from collections.abc import Callable -from typing import Dict +from pathlib import Path +from typing import Dict, List, Optional -import torch -from torch import nn +from huggingface_hub import hf_hub_download, snapshot_download +from transformers import AutoConfig, AutoModelForCausalLM from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.constant import ( - 
MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_PT, - TSStatusCode, -) -from iotdb.ainode.core.exception import ( - BuiltInModelDeletionError, - ModelNotExistError, - UnsupportedError, -) +from iotdb.ainode.core.constant import TSStatusCode +from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.built_in_model_factory import ( - download_built_in_ltsm_from_hf_if_necessary, - fetch_built_in_model, -) -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, +from iotdb.ainode.core.model.model_constants import ( + MODEL_CONFIG_FILE_IN_JSON, + MODEL_WEIGHTS_FILE_IN_SAFETENSORS, ModelCategory, - ModelFileType, ModelStates, + UriType, ) -from iotdb.ainode.core.model.model_factory import fetch_model_by_uri from iotdb.ainode.core.model.model_info import ( - BUILT_IN_LTSM_MAP, - BUILT_IN_MACHINE_LEARNING_MODEL_MAP, + BUILTIN_HF_TRANSFORMERS_MODEL_MAP, + BUILTIN_SKTIME_MODEL_MAP, ModelInfo, - get_built_in_model_type, - get_model_file_type, ) -from iotdb.ainode.core.model.uri_utils import get_model_register_strategy +from iotdb.ainode.core.model.utils import ( + ensure_init_file, + get_parsed_uri, + import_class_from_path, + load_model_config_in_json, + parse_uri_type, + temporary_sys_path, + validate_model_files, +) from iotdb.ainode.core.util.lock import ModelLockPool from iotdb.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp from iotdb.thrift.common.ttypes import TSStatus @@ -64,320 +58,368 @@ logger = Logger() -class ModelStorage(object): +class ModelStorage: + """Model storage class - unified management of model discovery and registration""" + def __init__(self): - self._model_dir = os.path.join( + self._models_dir = os.path.join( os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir() ) - if not os.path.exists(self._model_dir): - try: - os.makedirs(self._model_dir) - except PermissionError as e: - logger.error(e) - raise e - self._builtin_model_dir = 
os.path.join( - os.getcwd(), AINodeDescriptor().get_config().get_ain_builtin_models_dir() - ) - if not os.path.exists(self._builtin_model_dir): - try: - os.makedirs(self._builtin_model_dir) - except PermissionError as e: - logger.error(e) - raise e + # Unified storage: category -> {model_id -> ModelInfo} + self._models: Dict[str, Dict[str, ModelInfo]] = { + ModelCategory.BUILTIN.value: {}, + ModelCategory.USER_DEFINED.value: {}, + } + # Async download executor + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) + # Thread lock pool for protecting concurrent access to model information self._lock_pool = ModelLockPool() - self._executor = concurrent.futures.ThreadPoolExecutor( - max_workers=1 - ) # TODO: Here we set the work_num=1 cause we found that the hf download interface is not stable for concurrent downloading. - self._model_info_map: Dict[str, ModelInfo] = {} - self._init_model_info_map() + self._initialize_directories() + self.discover_all_models() + + def _initialize_directories(self): + """Initialize directory structure and ensure __init__.py files exist""" + os.makedirs(self._models_dir, exist_ok=True) + ensure_init_file(self._models_dir) + for category in ModelCategory: + category_path = os.path.join(self._models_dir, category.value) + os.makedirs(category_path, exist_ok=True) + ensure_init_file(category_path) + + # ==================== Discovery Methods ==================== + + def discover_all_models(self): + """Scan file system to discover all models""" + self._discover_category(ModelCategory.BUILTIN) + self._discover_category(ModelCategory.USER_DEFINED) + + def _discover_category(self, category: ModelCategory): + """Discover all models in a category directory""" + category_path = os.path.join(self._models_dir, category.value) + if category == ModelCategory.BUILTIN: + self._discover_builtin_models(category_path) + elif category == ModelCategory.USER_DEFINED: + for model_id in os.listdir(category_path): + if 
os.path.isdir(os.path.join(category_path, model_id)): + self._process_user_defined_model_directory( + os.path.join(category_path, model_id), model_id + ) - def _init_model_info_map(self): - """ - Initialize the model info map. - """ - # 1. initialize built-in and ready-to-use models - for model_id in BUILT_IN_MACHINE_LEARNING_MODEL_MAP: - self._model_info_map[model_id] = BUILT_IN_MACHINE_LEARNING_MODEL_MAP[ - model_id - ] - # 2. retrieve fine-tuned models from the built-in model directory - fine_tuned_models = self._retrieve_fine_tuned_models() - for model_id in fine_tuned_models: - self._model_info_map[model_id] = fine_tuned_models[model_id] - # 3. automatically downloading the weights of built-in LSTM models when necessary - for model_id in BUILT_IN_LTSM_MAP: - if model_id not in self._model_info_map: - self._model_info_map[model_id] = BUILT_IN_LTSM_MAP[model_id] - future = self._executor.submit( - self._download_built_in_model_if_necessary, model_id - ) - future.add_done_callback( - lambda f, mid=model_id: self._callback_model_download_result(f, mid) - ) - # 4. retrieve user-defined models from the model directory - user_defined_models = self._retrieve_user_defined_models() - for model_id in user_defined_models: - self._model_info_map[model_id] = user_defined_models[model_id] + def _discover_builtin_models(self, category_path: str): + # Register SKTIME models directly from map + for model_id in BUILTIN_SKTIME_MODEL_MAP.keys(): + with self._lock_pool.get_lock(model_id).write_lock(): + self._models[ModelCategory.BUILTIN.value][model_id] = ( + BUILTIN_SKTIME_MODEL_MAP[model_id] + ) - def _retrieve_fine_tuned_models(self): - """ - Retrieve fine-tuned models from the built-in model directory. 
+ # Process HuggingFace Transformers models + for model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys(): + model_dir = os.path.join(category_path, model_id) + os.makedirs(model_dir, exist_ok=True) + self._process_builtin_model_directory(model_dir, model_id) - Returns: - {"model_id": ModelInfo} - """ - result = {} - build_in_dirs = [ - d - for d in os.listdir(self._builtin_model_dir) - if os.path.isdir(os.path.join(self._builtin_model_dir, d)) - ] - for model_id in build_in_dirs: - config_file_path = os.path.join( - self._builtin_model_dir, model_id, MODEL_CONFIG_FILE_IN_JSON + def _process_builtin_model_directory(self, model_dir: str, model_id: str): + """Handling the discovery logic for a builtin model directory.""" + ensure_init_file(model_dir) + with self._lock_pool.get_lock(model_id).write_lock(): + # Check if model already exists and is in a valid state + existing_model = self._models[ModelCategory.BUILTIN.value].get(model_id) + if existing_model: + # If model is already ACTIVATING or ACTIVE, skip duplicate download + if existing_model.state in (ModelStates.ACTIVATING, ModelStates.ACTIVE): + return + + # If model not exists or is INACTIVE, we'll try to update its info and download its weights + self._models[ModelCategory.BUILTIN.value][model_id] = ( + BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id] ) - if os.path.isfile(config_file_path): - with open(config_file_path, "r") as f: - model_config = json.load(f) - if "model_type" in model_config: - model_type = model_config["model_type"] - model_info = ModelInfo( - model_id=model_id, - model_type=model_type, - category=ModelCategory.FINE_TUNED, - state=ModelStates.ACTIVE, + self._models[ModelCategory.BUILTIN.value][ + model_id + ].state = ModelStates.ACTIVATING + + def _download_model_if_necessary() -> bool: + """Returns: True if the model is existed or downloaded successfully, False otherwise.""" + repo_id = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id].repo_id + weights_path = os.path.join(model_dir, 
MODEL_WEIGHTS_FILE_IN_SAFETENSORS) + config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) + if not os.path.exists(weights_path): + try: + hf_hub_download( + repo_id=repo_id, + filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS, + local_dir=model_dir, ) - # Refactor the built-in model category - if "timer_xl" == model_id: - model_info.category = ModelCategory.BUILT_IN - if "sundial" == model_id: - model_info.category = ModelCategory.BUILT_IN - # Compatible patch with the codes in HuggingFace - if "timer" == model_type: - model_info.model_type = BuiltInModelType.TIMER_XL.value - if "sundial" == model_type: - model_info.model_type = BuiltInModelType.SUNDIAL.value - result[model_id] = model_info - return result - - def _download_built_in_model_if_necessary(self, model_id: str) -> bool: - """ - Download the built-in model if it is not already downloaded. - - Args: - model_id (str): The ID of the model to download. + except Exception as e: + logger.error( + f"Failed to download model weights from HuggingFace: {e}" + ) + return False + if not os.path.exists(config_path): + try: + hf_hub_download( + repo_id=repo_id, + filename=MODEL_CONFIG_FILE_IN_JSON, + local_dir=model_dir, + ) + except Exception as e: + logger.error( + f"Failed to download model config from HuggingFace: {e}" + ) + return False + return True - Return: - bool: True if the model is existed or downloaded successfully, False otherwise. 
- """ - with self._lock_pool.get_lock(model_id).write_lock(): - local_dir = os.path.join(self._builtin_model_dir, model_id) - return download_built_in_ltsm_from_hf_if_necessary( - get_built_in_model_type(self._model_info_map[model_id].model_type), - local_dir, - ) + future = self._executor.submit(_download_model_if_necessary) + future.add_done_callback( + lambda f, mid=model_id: self._callback_model_download_result(f, mid) + ) def _callback_model_download_result(self, future, model_id: str): + """Callback function for handling model download results""" with self._lock_pool.get_lock(model_id).write_lock(): - if future.result(): - self._model_info_map[model_id].state = ModelStates.ACTIVE - logger.info( - f"The built-in model: {model_id} is active and ready to use." - ) - else: - self._model_info_map[model_id].state = ModelStates.INACTIVE - - def _retrieve_user_defined_models(self): - """ - Retrieve user_defined models from the model directory. + try: + if future.result(): + model_info = self._models[ModelCategory.BUILTIN.value][model_id] + model_info.state = ModelStates.ACTIVE + config_path = os.path.join( + self._models_dir, + ModelCategory.BUILTIN.value, + model_id, + MODEL_CONFIG_FILE_IN_JSON, + ) + if os.path.exists(config_path): + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + if model_info.model_type == "": + model_info.model_type = config.get("model_type", "") + model_info.auto_map = config.get("auto_map", None) + logger.info( + f"Model {model_id} downloaded successfully and is ready to use." 
+ ) + else: + self._models[ModelCategory.BUILTIN.value][ + model_id + ].state = ModelStates.INACTIVE + logger.warning(f"Failed to download model {model_id}.") + except Exception as e: + logger.error(f"Error in download callback for model {model_id}: {e}") + self._models[ModelCategory.BUILTIN.value][ + model_id + ].state = ModelStates.INACTIVE + + def _process_user_defined_model_directory(self, model_dir: str, model_id: str): + """Handling the discovery logic for a user-defined model directory.""" + config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) + model_type = "" + auto_map = {} + pipeline_cls = "" + if os.path.exists(config_path): + config = load_model_config_in_json(config_path) + model_type = config.get("model_type", "") + auto_map = config.get("auto_map", None) + pipeline_cls = config.get("pipeline_cls", "") - Returns: - {"model_id": ModelInfo} - """ - result = {} - user_dirs = [ - d - for d in os.listdir(self._model_dir) - if os.path.isdir(os.path.join(self._model_dir, d)) and d != "weights" - ] - for model_id in user_dirs: - result[model_id] = ModelInfo( + with self._lock_pool.get_lock(model_id).write_lock(): + model_info = ModelInfo( model_id=model_id, - model_type="", + model_type=model_type, category=ModelCategory.USER_DEFINED, state=ModelStates.ACTIVE, + pipeline_cls=pipeline_cls, + auto_map=auto_map, + _transformers_registered=False, # Lazy registration ) - return result + self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info + + # ==================== Registration Methods ==================== - def register_model(self, model_id: str, uri: str): + def register_model(self, model_id: str, uri: str) -> bool: """ - Args: - model_id: id of model to register - uri: network or local dir path of the model to register - Returns: - configs: TConfigs - attributes: str + Supported URI formats: + - repo:// (Maybe in the future) + - file:// """ + uri_type = parse_uri_type(uri) + parsed_uri = get_parsed_uri(uri) + + model_dir = 
os.path.join( + self._models_dir, ModelCategory.USER_DEFINED.value, model_id + ) + os.makedirs(model_dir, exist_ok=True) + ensure_init_file(model_dir) + + if uri_type == UriType.REPO: + self._fetch_model_from_hf_repo(parsed_uri, model_dir) + else: + self._fetch_model_from_local(os.path.expanduser(parsed_uri), model_dir) + + config_path, _ = validate_model_files(model_dir) + config = load_model_config_in_json(config_path) + model_type = config.get("model_type", "") + auto_map = config.get("auto_map") + pipeline_cls = config.get("pipeline_cls", "") + with self._lock_pool.get_lock(model_id).write_lock(): - storage_path = os.path.join(self._model_dir, f"{model_id}") - # create storage dir if not exist - if not os.path.exists(storage_path): - os.makedirs(storage_path) - uri_type, parsed_uri, model_file_type = get_model_register_strategy(uri) - self._model_info_map[model_id] = ModelInfo( + model_info = ModelInfo( model_id=model_id, - model_type="", + model_type=model_type, category=ModelCategory.USER_DEFINED, - state=ModelStates.LOADING, + state=ModelStates.ACTIVE, + pipeline_cls=pipeline_cls, + auto_map=auto_map, + _transformers_registered=False, # Register later + ) + self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info + + if auto_map: + # Transformers model: immediately register to Transformers auto-loading mechanism + success = self._register_transformers_model(model_info) + if success: + with self._lock_pool.get_lock(model_id).write_lock(): + model_info._transformers_registered = True + else: + with self._lock_pool.get_lock(model_id).write_lock(): + model_info.state = ModelStates.INACTIVE + logger.error(f"Failed to register Transformers model {model_id}") + return False + else: + # Other type models: only log + self._register_other_model(model_info) + + logger.info(f"Successfully registered model {model_id} from URI: {uri}") + return True + + def _fetch_model_from_hf_repo(self, repo_id: str, storage_path: str): + logger.info( + f"Downloading model 
from HuggingFace repository: {repo_id} -> {storage_path}" + ) + # Use snapshot_download to download entire repository (including config.json and model.safetensors) + try: + snapshot_download( + repo_id=repo_id, + local_dir=storage_path, + local_dir_use_symlinks=False, + ) + except Exception as e: + logger.error(f"Failed to download model from HuggingFace: {e}") + raise + + def _fetch_model_from_local(self, source_path: str, storage_path: str): + logger.info(f"Copying model from local path: {source_path} -> {storage_path}") + source_dir = Path(source_path) + if not source_dir.is_dir(): + raise ValueError( + f"Source path does not exist or is not a directory: {source_path}" ) - try: - # TODO: The uri should be fetched asynchronously - configs, attributes = fetch_model_by_uri( - uri_type, parsed_uri, storage_path, model_file_type - ) - self._model_info_map[model_id].state = ModelStates.ACTIVE - return configs, attributes - except Exception as e: - logger.error(f"Failed to register model {model_id}: {e}") - self._model_info_map[model_id].state = ModelStates.INACTIVE - raise e - def delete_model(self, model_id: str) -> None: - """ - Args: - model_id: id of model to delete - Returns: - None - """ - # check if the model is built-in - with self._lock_pool.get_lock(model_id).read_lock(): - if self._is_built_in(model_id): - raise BuiltInModelDeletionError(model_id) + storage_dir = Path(storage_path) + for file in source_dir.iterdir(): + if file.is_file(): + shutil.copy2(file, storage_dir / file.name) + return - # delete the user-defined or fine-tuned model - with self._lock_pool.get_lock(model_id).write_lock(): - storage_path = os.path.join(self._model_dir, f"{model_id}") - if os.path.exists(storage_path): - shutil.rmtree(storage_path) - storage_path = os.path.join(self._builtin_model_dir, f"{model_id}") - if os.path.exists(storage_path): - shutil.rmtree(storage_path) - if model_id in self._model_info_map: - del self._model_info_map[model_id] - logger.info(f"Model {model_id} 
deleted successfully.") - - def _is_built_in(self, model_id: str) -> bool: + def _register_transformers_model(self, model_info: ModelInfo) -> bool: """ - Check if the model_id corresponds to a built-in model. + Register Transformers model to auto-loading mechanism (internal method) + """ + auto_map = model_info.auto_map + if not auto_map: + return False - Args: - model_id (str): The ID of the model. + auto_config_path = auto_map.get("AutoConfig") + auto_model_path = auto_map.get("AutoModelForCausalLM") - Returns: - bool: True if the model is built-in, False otherwise. - """ - return ( - model_id in self._model_info_map - and self._model_info_map[model_id].category == ModelCategory.BUILT_IN - ) + try: + model_path = os.path.join( + self._models_dir, model_info.category.value, model_info.model_id + ) + module_parent = str(Path(model_path).parent.absolute()) + with temporary_sys_path(module_parent): + config_class = import_class_from_path( + model_info.model_id, auto_config_path + ) + AutoConfig.register(model_info.model_type, config_class) + logger.info( + f"Registered AutoConfig: {model_info.model_type} -> {auto_config_path}" + ) - def is_built_in_or_fine_tuned(self, model_id: str) -> bool: - """ - Check if the model_id corresponds to a built-in or fine-tuned model. + model_class = import_class_from_path( + model_info.model_id, auto_model_path + ) + AutoModelForCausalLM.register(config_class, model_class) + logger.info( + f"Registered AutoModelForCausalLM: {config_class.__name__} -> {auto_model_path}" + ) - Args: - model_id (str): The ID of the model. + return True + except Exception as e: + logger.warning( + f"Failed to register Transformers model {model_info.model_id}: {e}. Model may still work via auto_map, but ensure module path is correct." + ) + return False - Returns: - bool: True if the model is built-in or fine_tuned, False otherwise. 
- """ - return model_id in self._model_info_map and ( - self._model_info_map[model_id].category == ModelCategory.BUILT_IN - or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED + def _register_other_model(self, model_info: ModelInfo): + """Register other type models (non-Transformers models)""" + logger.info( + f"Registered other type model: {model_info.model_id} ({model_info.model_type})" ) - def load_model( - self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool - ) -> Callable: + def ensure_transformers_registered(self, model_id: str) -> ModelInfo: """ - Load a model with automatic detection of .safetensors or .pt format - + Ensure Transformers model is registered (called for lazy registration) + This method uses locks to ensure thread safety. All check logic is within lock protection. Returns: - model: The model instance corresponding to specific model_id - """ - with self._lock_pool.get_lock(model_id).read_lock(): - if self.is_built_in_or_fine_tuned(model_id): - model_dir = os.path.join(self._builtin_model_dir, f"{model_id}") - return fetch_built_in_model( - get_built_in_model_type(self._model_info_map[model_id].model_type), - model_dir, - inference_attrs, - ) - else: - # load the user-defined model - model_dir = os.path.join(self._model_dir, f"{model_id}") - model_file_type = get_model_file_type(model_dir) - if model_file_type == ModelFileType.SAFETENSORS: - # TODO: Support this function - raise UnsupportedError("SAFETENSORS format") - else: - model_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_PT) - - if not os.path.exists(model_path): - raise ModelNotExistError(model_path) - model = torch.jit.load(model_path) - if ( - isinstance(model, torch._dynamo.eval_frame.OptimizedModule) - or not acceleration - ): - return model - - try: - model = torch.compile(model) - except Exception as e: - logger.warning( - f"acceleration failed, fallback to normal mode: {str(e)}" - ) - return model - - def save_model(self, model_id: 
str, model: nn.Module): - """ - Save the model using save_pretrained - - Returns: - Whether saving succeeded + str: If None, registration failed, otherwise returns model path """ + # Use lock to protect entire check-execute process with self._lock_pool.get_lock(model_id).write_lock(): - if self.is_built_in_or_fine_tuned(model_id): - model_dir = os.path.join(self._builtin_model_dir, f"{model_id}") - model.save_pretrained(model_dir) - else: - # save the user-defined model - model_dir = os.path.join(self._model_dir, f"{model_id}") - os.makedirs(model_dir, exist_ok=True) - model_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_PT) - try: - scripted_model = ( - model - if isinstance(model, torch.jit.ScriptModule) - else torch.jit.script(model) + # Directly access _models dictionary (avoid calling get_model_info which may cause deadlock) + model_info = None + for category_dict in self._models.values(): + if model_id in category_dict: + model_info = category_dict[model_id] + break + + if not model_info: + logger.warning(f"Model {model_id} does not exist, cannot register") + return None + + # If already registered, return directly + if model_info._transformers_registered: + return model_info + + # If no auto_map, not a Transformers model, mark as registered (avoid duplicate checks) + if ( + not model_info.auto_map + or model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys() + ): + model_info._transformers_registered = True + return model_info + + # Execute registration (under lock protection) + try: + success = self._register_transformers_model(model_info) + if success: + model_info._transformers_registered = True + logger.info( + f"Model {model_id} successfully registered to Transformers" ) - torch.jit.save(scripted_model, model_path) - except Exception as e: - logger.error(f"Failed to save scripted model: {e}") - - def get_ckpt_path(self, model_id: str) -> str: - """ - Get the checkpoint path for a given model ID. 
+ return model_info + else: + model_info.state = ModelStates.INACTIVE + logger.error(f"Model {model_id} failed to register to Transformers") + return None - Args: - model_id (str): The ID of the model. + except Exception as e: + # Ensure state consistency in exception cases + model_info.state = ModelStates.INACTIVE + model_info._transformers_registered = False + logger.error( + f"Exception occurred while registering model {model_id} to Transformers: {e}" + ) + return None - Returns: - str: The path to the checkpoint file for the model. - """ - # Only support built-in models for now - return os.path.join(self._builtin_model_dir, f"{model_id}") + # ==================== Show and Delete Models ==================== def show_models(self, req: TShowModelsReq) -> TShowModelsResp: resp_status = TSStatus( @@ -385,8 +427,14 @@ def show_models(self, req: TShowModelsReq) -> TShowModelsResp: message="Show models successfully", ) if req.modelId: - if req.modelId in self._model_info_map: - model_info = self._model_info_map[req.modelId] + # Find specified model + model_info = None + for category_dict in self._models.values(): + if req.modelId in category_dict: + model_info = category_dict[req.modelId] + break + + if model_info: return TShowModelsResp( status=resp_status, modelIdList=[req.modelId], @@ -402,55 +450,133 @@ def show_models(self, req: TShowModelsReq) -> TShowModelsResp: categoryMap={}, stateMap={}, ) + # Return all models + model_id_list = [] + model_type_map = {} + category_map = {} + state_map = {} + + for category_dict in self._models.values(): + for model_id, model_info in category_dict.items(): + model_id_list.append(model_id) + model_type_map[model_id] = model_info.model_type + category_map[model_id] = model_info.category.value + state_map[model_id] = model_info.state.value + return TShowModelsResp( status=resp_status, - modelIdList=list(self._model_info_map.keys()), - modelTypeMap=dict( - (model_id, model_info.model_type) - for model_id, model_info in 
self._model_info_map.items() - ), - categoryMap=dict( - (model_id, model_info.category.value) - for model_id, model_info in self._model_info_map.items() - ), - stateMap=dict( - (model_id, model_info.state.value) - for model_id, model_info in self._model_info_map.items() - ), + modelIdList=model_id_list, + modelTypeMap=model_type_map, + categoryMap=category_map, + stateMap=state_map, ) - def register_built_in_model(self, model_info: ModelInfo): - with self._lock_pool.get_lock(model_info.model_id).write_lock(): - self._model_info_map[model_info.model_id] = model_info + def delete_model(self, model_id: str) -> None: + # Use write lock to protect entire deletion process + with self._lock_pool.get_lock(model_id).write_lock(): + model_info = None + category_value = None + for cat_value, category_dict in self._models.items(): + if model_id in category_dict: + model_info = category_dict[model_id] + category_value = cat_value + break + + if not model_info: + logger.warning(f"Model {model_id} does not exist, cannot delete") + return + + if model_info.category == ModelCategory.BUILTIN: + raise BuiltInModelDeletionError(model_id) + model_info.state = ModelStates.DROPPING + model_path = os.path.join( + self._models_dir, model_info.category.value, model_id + ) + if model_path.exists(): + try: + shutil.rmtree(model_path) + logger.info(f"Deleted model directory: {model_path}") + except Exception as e: + logger.error(f"Failed to delete model directory {model_path}: {e}") + raise - def get_model_info(self, model_id: str) -> ModelInfo: - with self._lock_pool.get_lock(model_id).read_lock(): - if model_id in self._model_info_map: - return self._model_info_map[model_id] - else: - raise ValueError(f"Model {model_id} does not exist.") + if category_value and model_id in self._models[category_value]: + del self._models[category_value][model_id] + logger.info(f"Model {model_id} has been removed from storage") - def update_model_state(self, model_id: str, state: ModelStates): - with 
self._lock_pool.get_lock(model_id).write_lock(): - if model_id in self._model_info_map: - self._model_info_map[model_id].state = state - else: - raise ValueError(f"Model {model_id} does not exist.") + return - def get_built_in_model_type(self, model_id: str) -> BuiltInModelType: + # ==================== Query Methods ==================== + + def get_model_info( + self, model_id: str, category: Optional[ModelCategory] = None + ) -> Optional[ModelInfo]: """ - Get the type of the model with the given model_id. + Get single model information - Args: - model_id (str): The ID of the model. + If category is specified, use model_id's lock + If category is not specified, need to traverse all dictionaries, use global lock + """ + if category: + # Category specified, only need to access specific dictionary, use model_id's lock + with self._lock_pool.get_lock(model_id).read_lock(): + return self._models[category.value].get(model_id) + else: + # Category not specified, need to traverse all dictionaries, use global lock + with self._lock_pool.get_lock("").read_lock(): + for category_dict in self._models.values(): + if model_id in category_dict: + return category_dict[model_id] + return None + + def get_model_infos( + self, category: Optional[ModelCategory] = None, model_type: Optional[str] = None + ) -> List[ModelInfo]: + """ + Get model information list - Returns: - str: The type of the model. 
+ Note: Since we need to traverse all models, use a global lock to protect the entire dictionary structure + For single model access, using model_id-based lock would be more efficient """ - with self._lock_pool.get_lock(model_id).read_lock(): - if model_id in self._model_info_map: - return get_built_in_model_type( - self._model_info_map[model_id].model_type - ) + matching_models = [] + + # For traversal operations, we need to protect the entire dictionary structure + # Use a special lock (using empty string as key) to protect the entire dictionary + with self._lock_pool.get_lock("").read_lock(): + if category and model_type: + for model_info in self._models[category.value].values(): + if model_info.model_type == model_type: + matching_models.append(model_info) + return matching_models + elif category: + return list(self._models[category.value].values()) + elif model_type: + for category_dict in self._models.values(): + for model_info in category_dict.values(): + if model_info.model_type == model_type: + matching_models.append(model_info) + return matching_models else: - raise ValueError(f"Model {model_id} does not exist.") + for category_dict in self._models.values(): + matching_models.extend(category_dict.values()) + return matching_models + + def is_model_registered(self, model_id: str) -> bool: + """Check if model is registered (search in _models)""" + # Lazy registration: if it's a Transformers model and not registered, register it first + if self.ensure_transformers_registered(model_id) is None: + return False + + with self._lock_pool.get_lock("").read_lock(): + for category_dict in self._models.values(): + if model_id in category_dict: + return True + return False + + def get_registered_models(self) -> List[str]: + """Get list of all registered model IDs""" + with self._lock_pool.get_lock("").read_lock(): + model_ids = [] + for category_dict in self._models.values(): + model_ids.extend(category_dict.keys()) + return model_ids diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py rename to iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json new file mode 100644 index 0000000000000..1561124badd12 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json @@ -0,0 +1,25 @@ +{ + "model_type": "sktime", + "model_id": "arima", + "predict_length": 1, + "order": [1, 0, 0], + "seasonal_order": [0, 0, 0, 0], + "start_params": null, + "method": "lbfgs", + "maxiter": 50, + "suppress_warnings": false, + "out_of_sample_size": 0, + "scoring": "mse", + "scoring_args": null, + "trend": null, + "with_intercept": true, + "time_varying_regression": false, + "enforce_stationarity": true, + "enforce_invertibility": true, + "simple_differencing": false, + "measurement_error": false, + "mle_regression": true, + "hamilton_representation": false, + "concentrate_scale": false +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py new file mode 100644 index 0000000000000..261de3c9abe7c --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py @@ -0,0 +1,409 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Union + +from iotdb.ainode.core.exception import ( + BuiltInModelNotSupportError, + ListRangeException, + NumericalRangeException, + StringRangeException, + WrongAttributeTypeError, +) +from iotdb.ainode.core.log import Logger + +logger = Logger() + + +@dataclass +class AttributeConfig: + """Base class for attribute configuration""" + + name: str + default: Any + type: str # 'int', 'float', 'str', 'bool', 'list', 'tuple' + low: Union[int, float, None] = None + high: Union[int, float, None] = None + choices: List[str] = field(default_factory=list) + value_type: type = None # Element type for list and tuple + + def validate_value(self, value): + """Validate if the value meets the requirements""" + if self.type == "int": + if value is None: + return True # Allow None for optional int parameters + if not isinstance(value, int): + raise WrongAttributeTypeError(self.name, "int") + if self.low is not None and self.high is not None: + if not (self.low <= value <= self.high): + raise NumericalRangeException(self.name, value, self.low, self.high) + elif self.type == "float": + if value is None: + return True # Allow None for optional float parameters + if not isinstance(value, (int, float)): + raise WrongAttributeTypeError(self.name, "float") + value = float(value) + if self.low is not None and self.high is not None: + if not (self.low <= value <= self.high): + raise NumericalRangeException(self.name, value, self.low, self.high) + elif self.type == "str": + if value is None: + return 
True # Allow None for optional str parameters + if not isinstance(value, str): + raise WrongAttributeTypeError(self.name, "str") + if self.choices and value not in self.choices: + raise StringRangeException(self.name, value, self.choices) + elif self.type == "bool": + if value is None: + return True # Allow None for optional bool parameters + if not isinstance(value, bool): + raise WrongAttributeTypeError(self.name, "bool") + elif self.type == "list": + if not isinstance(value, list): + raise WrongAttributeTypeError(self.name, "list") + for item in value: + if not isinstance(item, self.value_type): + raise WrongAttributeTypeError(self.name, self.value_type) + elif self.type == "tuple": + if not isinstance(value, tuple): + raise WrongAttributeTypeError(self.name, "tuple") + for item in value: + if not isinstance(item, self.value_type): + raise WrongAttributeTypeError(self.name, self.value_type) + return True + + def parse(self, string_value: str): + """Parse string value to corresponding type""" + if self.type == "int": + if string_value.lower() == "none" or string_value.strip() == "": + return None + try: + return int(string_value) + except: + raise WrongAttributeTypeError(self.name, "int") + elif self.type == "float": + if string_value.lower() == "none" or string_value.strip() == "": + return None + try: + return float(string_value) + except: + raise WrongAttributeTypeError(self.name, "float") + elif self.type == "str": + if string_value.lower() == "none" or string_value.strip() == "": + return None + return string_value + elif self.type == "bool": + if string_value.lower() == "true": + return True + elif string_value.lower() == "false": + return False + elif string_value.lower() == "none" or string_value.strip() == "": + return None + else: + raise WrongAttributeTypeError(self.name, "bool") + elif self.type == "list": + try: + list_value = eval(string_value) + except: + raise WrongAttributeTypeError(self.name, "list") + if not isinstance(list_value, list): + 
raise WrongAttributeTypeError(self.name, "list") + for i in range(len(list_value)): + try: + list_value[i] = self.value_type(list_value[i]) + except: + raise ListRangeException( + self.name, list_value, str(self.value_type) + ) + return list_value + elif self.type == "tuple": + try: + tuple_value = eval(string_value) + except: + raise WrongAttributeTypeError(self.name, "tuple") + if not isinstance(tuple_value, tuple): + raise WrongAttributeTypeError(self.name, "tuple") + list_value = list(tuple_value) + for i in range(len(list_value)): + try: + list_value[i] = self.value_type(list_value[i]) + except: + raise ListRangeException( + self.name, list_value, str(self.value_type) + ) + return tuple(list_value) + + +# Model configuration definitions - using concise dictionary format +MODEL_CONFIGS = { + "NAIVE_FORECASTER": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "strategy": AttributeConfig( + "strategy", "last", "str", choices=["last", "mean", "drift"] + ), + "window_length": AttributeConfig("window_length", None, "int"), + "sp": AttributeConfig("sp", 1, "int", 1, 5000), + }, + "EXPONENTIAL_SMOOTHING": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "damped_trend": AttributeConfig("damped_trend", False, "bool"), + "initialization_method": AttributeConfig( + "initialization_method", + "estimated", + "str", + choices=["estimated", "heuristic", "legacy-heuristic", "known"], + ), + "optimized": AttributeConfig("optimized", True, "bool"), + "remove_bias": AttributeConfig("remove_bias", False, "bool"), + "use_brute": AttributeConfig("use_brute", False, "bool"), + }, + "ARIMA": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "order": AttributeConfig("order", (1, 0, 0), "tuple", value_type=int), + "seasonal_order": AttributeConfig( + "seasonal_order", (0, 0, 0, 0), "tuple", value_type=int + ), + "start_params": AttributeConfig("start_params", None, "str"), + "method": AttributeConfig( 
+ "method", + "lbfgs", + "str", + choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"], + ), + "maxiter": AttributeConfig("maxiter", 50, "int", 1, 5000), + "suppress_warnings": AttributeConfig("suppress_warnings", False, "bool"), + "out_of_sample_size": AttributeConfig("out_of_sample_size", 0, "int", 0, 5000), + "scoring": AttributeConfig( + "scoring", + "mse", + "str", + choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"], + ), + "scoring_args": AttributeConfig("scoring_args", None, "str"), + "trend": AttributeConfig("trend", None, "str"), + "with_intercept": AttributeConfig("with_intercept", True, "bool"), + "time_varying_regression": AttributeConfig( + "time_varying_regression", False, "bool" + ), + "enforce_stationarity": AttributeConfig("enforce_stationarity", True, "bool"), + "enforce_invertibility": AttributeConfig("enforce_invertibility", True, "bool"), + "simple_differencing": AttributeConfig("simple_differencing", False, "bool"), + "measurement_error": AttributeConfig("measurement_error", False, "bool"), + "mle_regression": AttributeConfig("mle_regression", True, "bool"), + "hamilton_representation": AttributeConfig( + "hamilton_representation", False, "bool" + ), + "concentrate_scale": AttributeConfig("concentrate_scale", False, "bool"), + }, + "STL_FORECASTER": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "sp": AttributeConfig("sp", 2, "int", 1, 5000), + "seasonal": AttributeConfig("seasonal", 7, "int", 1, 5000), + "trend": AttributeConfig("trend", None, "int"), + "low_pass": AttributeConfig("low_pass", None, "int"), + "seasonal_deg": AttributeConfig("seasonal_deg", 1, "int", 0, 5000), + "trend_deg": AttributeConfig("trend_deg", 1, "int", 0, 5000), + "low_pass_deg": AttributeConfig("low_pass_deg", 1, "int", 0, 5000), + "robust": AttributeConfig("robust", False, "bool"), + "seasonal_jump": AttributeConfig("seasonal_jump", 1, "int", 0, 5000), + "trend_jump": AttributeConfig("trend_jump", 1, "int", 0, 
5000), + "low_pass_jump": AttributeConfig("low_pass_jump", 1, "int", 0, 5000), + "inner_iter": AttributeConfig("inner_iter", None, "int"), + "outer_iter": AttributeConfig("outer_iter", None, "int"), + "forecaster_trend": AttributeConfig("forecaster_trend", None, "str"), + "forecaster_seasonal": AttributeConfig("forecaster_seasonal", None, "str"), + "forecaster_resid": AttributeConfig("forecaster_resid", None, "str"), + }, + "GAUSSIAN_HMM": { + "n_components": AttributeConfig("n_components", 1, "int", 1, 5000), + "covariance_type": AttributeConfig( + "covariance_type", + "diag", + "str", + choices=["spherical", "diag", "full", "tied"], + ), + "min_covar": AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10), + "startprob_prior": AttributeConfig( + "startprob_prior", 1.0, "float", -1e10, 1e10 + ), + "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10), + "means_prior": AttributeConfig("means_prior", 0, "float", -1e10, 1e10), + "means_weight": AttributeConfig("means_weight", 0, "float", -1e10, 1e10), + "covars_prior": AttributeConfig("covars_prior", 0.01, "float", -1e10, 1e10), + "covars_weight": AttributeConfig("covars_weight", 1, "float", -1e10, 1e10), + "algorithm": AttributeConfig( + "algorithm", "viterbi", "str", choices=["viterbi", "map"] + ), + "random_state": AttributeConfig("random_state", None, "float"), + "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000), + "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10), + "verbose": AttributeConfig("verbose", False, "bool"), + "params": AttributeConfig("params", "stmc", "str", choices=["stmc", "stm"]), + "init_params": AttributeConfig( + "init_params", "stmc", "str", choices=["stmc", "stm"] + ), + "implementation": AttributeConfig( + "implementation", "log", "str", choices=["log", "scaling"] + ), + }, + "GMM_HMM": { + "n_components": AttributeConfig("n_components", 1, "int", 1, 5000), + "n_mix": AttributeConfig("n_mix", 1, "int", 1, 5000), + "min_covar": 
AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10), + "startprob_prior": AttributeConfig( + "startprob_prior", 1.0, "float", -1e10, 1e10 + ), + "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10), + "weights_prior": AttributeConfig("weights_prior", 1.0, "float", -1e10, 1e10), + "means_prior": AttributeConfig("means_prior", 0.0, "float", -1e10, 1e10), + "means_weight": AttributeConfig("means_weight", 0.0, "float", -1e10, 1e10), + "covars_prior": AttributeConfig("covars_prior", None, "float"), + "covars_weight": AttributeConfig("covars_weight", None, "float"), + "algorithm": AttributeConfig( + "algorithm", "viterbi", "str", choices=["viterbi", "map"] + ), + "covariance_type": AttributeConfig( + "covariance_type", + "diag", + "str", + choices=["spherical", "diag", "full", "tied"], + ), + "random_state": AttributeConfig("random_state", None, "int"), + "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000), + "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10), + "verbose": AttributeConfig("verbose", False, "bool"), + "init_params": AttributeConfig( + "init_params", + "stmcw", + "str", + choices=[ + "s", + "t", + "m", + "c", + "w", + "st", + "sm", + "sc", + "sw", + "tm", + "tc", + "tw", + "mc", + "mw", + "cw", + "stm", + "stc", + "stw", + "smc", + "smw", + "scw", + "tmc", + "tmw", + "tcw", + "mcw", + "stmc", + "stmw", + "stcw", + "smcw", + "tmcw", + "stmcw", + ], + ), + "params": AttributeConfig( + "params", + "stmcw", + "str", + choices=[ + "s", + "t", + "m", + "c", + "w", + "st", + "sm", + "sc", + "sw", + "tm", + "tc", + "tw", + "mc", + "mw", + "cw", + "stm", + "stc", + "stw", + "smc", + "smw", + "scw", + "tmc", + "tmw", + "tcw", + "mcw", + "stmc", + "stmw", + "stcw", + "smcw", + "tmcw", + "stmcw", + ], + ), + "implementation": AttributeConfig( + "implementation", "log", "str", choices=["log", "scaling"] + ), + }, + "STRAY": { + "alpha": AttributeConfig("alpha", 0.01, "float", -1e10, 1e10), + "k": AttributeConfig("k", 10, "int", 
1, 5000), + "knn_algorithm": AttributeConfig( + "knn_algorithm", + "brute", + "str", + choices=["brute", "kd_tree", "ball_tree", "auto"], + ), + "p": AttributeConfig("p", 0.5, "float", -1e10, 1e10), + "size_threshold": AttributeConfig("size_threshold", 50, "int", 1, 5000), + "outlier_tail": AttributeConfig( + "outlier_tail", "max", "str", choices=["min", "max"] + ), + }, +} + + +def get_attributes(model_id: str) -> Dict[str, AttributeConfig]: + """Get attribute configuration for Sktime model""" + model_id = "EXPONENTIAL_SMOOTHING" if model_id == "HOLTWINTERS" else model_id + if model_id not in MODEL_CONFIGS: + raise BuiltInModelNotSupportError(model_id) + return MODEL_CONFIGS[model_id] + + +def update_attribute( + input_attributes: Dict[str, str], attribute_map: Dict[str, AttributeConfig] +) -> Dict[str, Any]: + """Update Sktime model attributes using input attributes""" + attributes = {} + for name, config in attribute_map.items(): + if name in input_attributes: + value = config.parse(input_attributes[name]) + config.validate_value(value) + attributes[name] = value + else: + attributes[name] = config.default + return attributes diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json new file mode 100644 index 0000000000000..4126d9de857a6 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json @@ -0,0 +1,11 @@ +{ + "model_type": "sktime", + "model_id": "exponential_smoothing", + "predict_length": 1, + "damped_trend": false, + "initialization_method": "estimated", + "optimized": true, + "remove_bias": false, + "use_brute": false +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json new file mode 100644 index 0000000000000..94f7d7ec659fc --- /dev/null +++ 
b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json @@ -0,0 +1,22 @@ +{ + "model_type": "sktime", + "model_id": "gaussian_hmm", + "n_components": 1, + "covariance_type": "diag", + "min_covar": 0.001, + "startprob_prior": 1.0, + "transmat_prior": 1.0, + "means_prior": 0, + "means_weight": 0, + "covars_prior": 0.01, + "covars_weight": 1, + "algorithm": "viterbi", + "random_state": null, + "n_iter": 10, + "tol": 0.01, + "verbose": false, + "params": "stmc", + "init_params": "stmc", + "implementation": "log" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json new file mode 100644 index 0000000000000..fb19d1aaf86d9 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json @@ -0,0 +1,24 @@ +{ + "model_type": "sktime", + "model_id": "gmm_hmm", + "n_components": 1, + "n_mix": 1, + "min_covar": 0.001, + "startprob_prior": 1.0, + "transmat_prior": 1.0, + "weights_prior": 1.0, + "means_prior": 0.0, + "means_weight": 0.0, + "covars_prior": null, + "covars_weight": null, + "algorithm": "viterbi", + "covariance_type": "diag", + "random_state": null, + "n_iter": 10, + "tol": 0.01, + "verbose": false, + "init_params": "stmcw", + "params": "stmcw", + "implementation": "log" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py new file mode 100644 index 0000000000000..eca812d35ec9a --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from abc import abstractmethod +from typing import Any, Dict + +import numpy as np +import pandas as pd +from sklearn.preprocessing import MinMaxScaler +from sktime.detection.hmm_learn import GMMHMM, GaussianHMM +from sktime.detection.stray import STRAY +from sktime.forecasting.arima import ARIMA +from sktime.forecasting.exp_smoothing import ExponentialSmoothing +from sktime.forecasting.naive import NaiveForecaster +from sktime.forecasting.trend import STLForecaster + +from iotdb.ainode.core.exception import ( + BuiltInModelNotSupportError, + InferenceModelInternalError, +) +from iotdb.ainode.core.log import Logger + +from .configuration_sktime import get_attributes, update_attribute + +logger = Logger() + + +class SktimeModel: + """Base class for Sktime models""" + + def __init__(self, attributes: Dict[str, Any]): + self._attributes = attributes + self._model = None + + @abstractmethod + def generate(self, data, **kwargs): + """Execute generation/inference""" + raise NotImplementedError + + +class ForecastingModel(SktimeModel): + """Base class for forecasting models""" + + def generate(self, data, **kwargs): + """Execute forecasting""" + try: + predict_length = kwargs.get( + "predict_length", self._attributes["predict_length"] + ) + self._model.fit(data) + output = self._model.predict(fh=range(predict_length)) + return np.array(output, dtype=np.float64) + except Exception as e: + raise 
InferenceModelInternalError(str(e)) + + +class DetectionModel(SktimeModel): + """Base class for detection models""" + + def generate(self, data, **kwargs): + """Execute detection""" + try: + predict_length = kwargs.get("predict_length", data.size) + output = self._model.fit_transform(data[:predict_length]) + if isinstance(output, pd.DataFrame): + return np.array(output["labels"], dtype=np.int32) + else: + return np.array(output, dtype=np.int32) + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class ArimaModel(ForecastingModel): + """ARIMA model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = ARIMA( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class ExponentialSmoothingModel(ForecastingModel): + """Exponential smoothing model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = ExponentialSmoothing( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class NaiveForecasterModel(ForecastingModel): + """Naive forecaster model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = NaiveForecaster( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class STLForecasterModel(ForecastingModel): + """STL forecaster model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = STLForecaster( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class GMMHMMModel(DetectionModel): + """GMM HMM model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = GMMHMM(**attributes) + + +class GaussianHmmModel(DetectionModel): + """Gaussian HMM model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = GaussianHMM(**attributes) + + +class STRAYModel(DetectionModel): + 
"""STRAY anomaly detection model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = STRAY(**{k: v for k, v in attributes.items() if v is not None}) + + def generate(self, data, **kwargs): + """STRAY requires special handling: normalize first""" + try: + scaled_data = MinMaxScaler().fit_transform(data.values.reshape(-1, 1)) + scaled_data = pd.Series(scaled_data.flatten()) + return super().generate(scaled_data, **kwargs) + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +# Model factory mapping +_MODEL_FACTORY = { + "ARIMA": ArimaModel, + "EXPONENTIAL_SMOOTHING": ExponentialSmoothingModel, + "HOLTWINTERS": ExponentialSmoothingModel, # Use the same model class + "NAIVE_FORECASTER": NaiveForecasterModel, + "STL_FORECASTER": STLForecasterModel, + "GMM_HMM": GMMHMMModel, + "GAUSSIAN_HMM": GaussianHmmModel, + "STRAY": STRAYModel, +} + + +def create_sktime_model(model_id: str, **kwargs) -> SktimeModel: + """Create a Sktime model instance""" + attributes = update_attribute({**kwargs}, get_attributes(model_id.upper())) + model_class = _MODEL_FACTORY.get(model_id.upper()) + if model_class is None: + raise BuiltInModelNotSupportError(model_id) + return model_class(attributes) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json new file mode 100644 index 0000000000000..3dadd7c3b1e5d --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json @@ -0,0 +1,9 @@ +{ + "model_type": "sktime", + "model_id": "naive_forecaster", + "predict_length": 1, + "strategy": "last", + "window_length": null, + "sp": 1 +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py new file mode 100644 index 0000000000000..ced21f29a2b82 --- /dev/null +++ 
b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import numpy as np +import pandas as pd +import torch + +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline + + +class SktimePipeline(ForecastPipeline): + def __init__(self, model_info, **model_kwargs): + model_kwargs.pop("device", None) # sktime models run on CPU + super().__init__(model_info, model_kwargs=model_kwargs) + + def _preprocess(self, inputs): + return inputs + + def forecast(self, inputs, **infer_kwargs): + predict_length = infer_kwargs.get("predict_length", 96) + input_ids = self._preprocess(inputs) + + # Convert to pandas Series for sktime (sktime expects Series or DataFrame) + # Handle batch dimension: if batch_size > 1, process each sample separately + if len(input_ids.shape) == 2 and input_ids.shape[0] > 1: + # Batch processing: convert each row to Series + outputs = [] + for i in range(input_ids.shape[0]): + series = pd.Series( + input_ids[i].cpu().numpy() + if isinstance(input_ids, torch.Tensor) + else input_ids[i] + ) + output = self.model.generate(series, predict_length=predict_length) + outputs.append(output) + output = np.array(outputs) + 
else: + # Single sample: convert to Series + if isinstance(input_ids, torch.Tensor): + series = pd.Series(input_ids.squeeze().cpu().numpy()) + else: + series = pd.Series(input_ids.squeeze()) + output = self.model.generate(series, predict_length=predict_length) + # Add batch dimension if needed + if len(output.shape) == 1: + output = output[np.newaxis, :] + + return self._postprocess(output) + + def _postprocess(self, output): + if isinstance(output, np.ndarray): + return torch.from_numpy(output).float() + return output diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json new file mode 100644 index 0000000000000..bfe71dbc48614 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json @@ -0,0 +1,22 @@ +{ + "model_type": "sktime", + "model_id": "stl_forecaster", + "predict_length": 1, + "sp": 2, + "seasonal": 7, + "trend": null, + "low_pass": null, + "seasonal_deg": 1, + "trend_deg": 1, + "low_pass_deg": 1, + "robust": false, + "seasonal_jump": 1, + "trend_jump": 1, + "low_pass_jump": 1, + "inner_iter": null, + "outer_iter": null, + "forecaster_trend": null, + "forecaster_seasonal": null, + "forecaster_resid": null +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json new file mode 100644 index 0000000000000..e5bcc03cd0714 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json @@ -0,0 +1,11 @@ +{ + "model_type": "sktime", + "model_id": "stray", + "alpha": 0.01, + "k": 10, + "knn_algorithm": "brute", + "p": 0.5, + "size_threshold": 50, + "outlier_tail": "max" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py index 3ebf516f705e0..dc1de32506e57 100644 --- 
a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py @@ -16,13 +16,10 @@ # under the License. # -import os -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F -from huggingface_hub import hf_hub_download -from safetensors.torch import load_file as load_safetensors from torch import nn from transformers import Cache, DynamicCache, PreTrainedModel from transformers.activations import ACT2FN @@ -32,13 +29,10 @@ MoeModelOutputWithPast, ) -from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig from iotdb.ainode.core.model.sundial.flow_loss import FlowLoss from iotdb.ainode.core.model.sundial.ts_generation_mixin import TSGenerationMixin -logger = Logger() - def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] @@ -616,11 +610,7 @@ def prepare_inputs_for_generation( if attention_mask is not None and attention_mask.shape[1] > ( input_ids.shape[1] // self.config.input_token_len ): - input_ids = input_ids[ - :, - -(attention_mask.shape[1] - past_length) - * self.config.input_token_len :, - ] + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. 
elif past_length < (input_ids.shape[1] // self.config.input_token_len): @@ -633,10 +623,9 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - token_num = ( - input_ids.shape[1] + self.config.input_token_len - 1 - ) // self.config.input_token_len - position_ids = position_ids[:, -token_num:] + position_ids = position_ids[ + :, -(input_ids.shape[1] // self.config.input_token_len) : + ] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py similarity index 56% rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py rename to iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py index 17c88e32fb5a0..85b6f7db2ffef 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py @@ -19,33 +19,33 @@ import torch from iotdb.ainode.core.exception import InferenceModelInternalError -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline -class TimerSundialInferencePipeline(AbstractInferencePipeline): - """ - Strategy for Timer-Sundial model inference. 
- """ +class SundialPipeline(ForecastPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) - def __init__(self, model_config: SundialConfig, **infer_kwargs): - super().__init__(model_config, infer_kwargs=infer_kwargs) - - def preprocess_inputs(self, inputs: torch.Tensor): - super().preprocess_inputs(inputs) + def _preprocess(self, inputs): if len(inputs.shape) != 2: raise InferenceModelInternalError( f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" ) - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py return inputs - def post_decode(self): - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py - pass - - def post_inference(self): - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py - pass + def forecast(self, inputs, **infer_kwargs): + predict_length = infer_kwargs.get("predict_length", 96) + num_samples = infer_kwargs.get("num_samples", 10) + revin = infer_kwargs.get("revin", True) + + input_ids = self._preprocess(inputs) + output = self.model.generate( + input_ids, + max_new_tokens=predict_length, + num_samples=num_samples, + revin=revin, + ) + return self._postprocess(output) + + def _postprocess(self, output: torch.Tensor): + return output.mean(dim=1) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py new file mode 100644 index 0000000000000..2a1e720805f29 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/configuration_timer.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/configuration_timer.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py similarity index 98% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py index 0a33c682742aa..fc9d7b41388bc 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py @@ -16,7 +16,7 @@ # under the License. 
# -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F @@ -29,11 +29,8 @@ MoeModelOutputWithPast, ) -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig -from iotdb.ainode.core.model.timerxl.ts_generation_mixin import TSGenerationMixin - -logger = Logger() +from iotdb.ainode.core.model.timer_xl.configuration_timer import TimerConfig +from iotdb.ainode.core.model.timer_xl.ts_generation_mixin import TSGenerationMixin def rotate_half(x): @@ -606,11 +603,7 @@ def prepare_inputs_for_generation( if attention_mask is not None and attention_mask.shape[1] > ( input_ids.shape[1] // self.config.input_token_len ): - input_ids = input_ids[ - :, - -(attention_mask.shape[1] - past_length) - * self.config.input_token_len :, - ] + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. 
elif past_length < (input_ids.shape[1] // self.config.input_token_len): diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py similarity index 52% rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py index dc1dd304f68e8..c0f00b1f5caf3 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py @@ -19,33 +19,29 @@ import torch from iotdb.ainode.core.exception import InferenceModelInternalError -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline -class TimerXLInferencePipeline(AbstractInferencePipeline): - """ - Strategy for Timer-XL model inference. 
- """ +class TimerPipeline(ForecastPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) - def __init__(self, model_config: TimerConfig, **infer_kwargs): - super().__init__(model_config, infer_kwargs=infer_kwargs) - - def preprocess_inputs(self, inputs: torch.Tensor): - super().preprocess_inputs(inputs) + def _preprocess(self, inputs): if len(inputs.shape) != 2: raise InferenceModelInternalError( f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" ) - # Considering that we are currently using the generate function interface, it seems that no pre-processing is required return inputs - def post_decode(self): - # Considering that we are currently using the generate function interface, it seems that no post-processing is required - pass + def forecast(self, inputs, **infer_kwargs): + predict_length = infer_kwargs.get("predict_length", 96) + revin = infer_kwargs.get("revin", True) + + input_ids = self._preprocess(inputs) + output = self.model.generate( + input_ids, max_new_tokens=predict_length, revin=revin + ) + return self._postprocess(output) - def post_inference(self): - # Considering that we are currently using the generate function interface, it seems that no post-processing is required - pass + def _postprocess(self, output: torch.Tensor): + return output diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/ts_generation_mixin.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/ts_generation_mixin.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py deleted file mode 100644 index b2e759e00ce07..0000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py +++ /dev/null @@ 
-1,137 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import os -from enum import Enum -from typing import List - -from huggingface_hub import snapshot_download -from requests import Session -from requests.adapters import HTTPAdapter - -from iotdb.ainode.core.constant import ( - DEFAULT_CHUNK_SIZE, - DEFAULT_RECONNECT_TIMEOUT, - DEFAULT_RECONNECT_TIMES, -) -from iotdb.ainode.core.exception import UnsupportedError -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import ModelFileType -from iotdb.ainode.core.model.model_info import get_model_file_type - -HTTP_PREFIX = "http://" -HTTPS_PREFIX = "https://" - -logger = Logger() - - -class UriType(Enum): - REPO = "repo" - FILE = "file" - HTTP = "http" - HTTPS = "https" - - @classmethod - def values(cls) -> List[str]: - return [item.value for item in cls] - - @staticmethod - def parse_uri_type(uri: str): - """ - Parse the URI type from the given string. 
- """ - if uri.startswith("repo://"): - return UriType.REPO - elif uri.startswith("file://"): - return UriType.FILE - elif uri.startswith("http://"): - return UriType.HTTP - elif uri.startswith("https://"): - return UriType.HTTPS - else: - raise ValueError(f"Invalid URI type for {uri}") - - -def get_model_register_strategy(uri: str): - """ - Determine the loading strategy for a model based on its URI/path. - - Args: - uri (str): The URI of the model to be registered. - - Returns: - uri_type (UriType): The type of the URI, which can be one of: REPO, FILE, HTTP, or HTTPS. - parsed_uri (str): Parsed uri to get related file - model_file_type (ModelFileType): The type of the model file, which can be one of: SAFETENSORS, PYTORCH, or UNKNOWN. - """ - - uri_type = UriType.parse_uri_type(uri) - if uri_type in (UriType.HTTP, UriType.HTTPS): - # TODO: support HTTP(S) URI - raise UnsupportedError("CREATE MODEL FROM HTTP(S) URI") - else: - parsed_uri = uri[7:] - if uri_type == UriType.FILE: - # handle ~ in URI - parsed_uri = os.path.expanduser(parsed_uri) - model_file_type = get_model_file_type(uri) - elif uri_type == UriType.REPO: - # Currently, UriType.REPO only corresponds to huggingface repository with SAFETENSORS format - model_file_type = ModelFileType.SAFETENSORS - else: - raise ValueError(f"Invalid URI type for {uri}") - return uri_type, parsed_uri, model_file_type - - -def download_snapshot_from_hf(repo_id: str, local_dir: str): - """ - Download everything from a HuggingFace repository. - - Args: - repo_id (str): The HuggingFace repository ID. - local_dir (str): The local directory to save the downloaded files. 
- """ - try: - snapshot_download( - repo_id=repo_id, - local_dir=local_dir, - ) - except Exception as e: - logger.error(f"Failed to download HuggingFace model {repo_id}: {e}") - raise e - - -def download_file(url: str, storage_path: str) -> None: - """ - Args: - url: url of file to download - storage_path: path to save the file - Returns: - None - """ - logger.info(f"Start Downloading file from {url} to {storage_path}") - session = Session() - adapter = HTTPAdapter(max_retries=DEFAULT_RECONNECT_TIMES) - session.mount(HTTP_PREFIX, adapter) - session.mount(HTTPS_PREFIX, adapter) - response = session.get(url, timeout=DEFAULT_RECONNECT_TIMEOUT, stream=True) - response.raise_for_status() - with open(storage_path, "wb") as file: - for chunk in response.iter_content(chunk_size=DEFAULT_CHUNK_SIZE): - if chunk: - file.write(chunk) - logger.info(f"Download file from {url} to {storage_path} success") diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py new file mode 100644 index 0000000000000..1cd0ee44912d5 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +import importlib +import json +import os.path +import sys +from contextlib import contextmanager +from typing import Dict, Tuple + +from iotdb.ainode.core.model.model_constants import ( + MODEL_CONFIG_FILE_IN_JSON, + MODEL_WEIGHTS_FILE_IN_SAFETENSORS, + UriType, +) + + +def parse_uri_type(uri: str) -> UriType: + if uri.startswith("repo://"): + return UriType.REPO + elif uri.startswith("file://"): + return UriType.FILE + else: + raise ValueError( + f"Unsupported URI type: {uri}. Supported formats: repo:// or file://" + ) + + +def get_parsed_uri(uri: str) -> str: + return uri[7:] # Remove "repo://" or "file://" prefix + + +@contextmanager +def temporary_sys_path(path: str): + """Context manager for temporarily adding a path to sys.path""" + path_added = path not in sys.path + if path_added: + sys.path.insert(0, path) + try: + yield + finally: + if path_added and path in sys.path: + sys.path.remove(path) + + +def load_model_config_in_json(config_path: str) -> Dict: + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) + + +def validate_model_files(model_dir: str) -> Tuple[str, str]: + """Validate model files exist, return config and weights file paths""" + + config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) + weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) + + if not os.path.exists(config_path): + raise ValueError(f"Model config file does not exist: {config_path}") + if not os.path.exists(weights_path): + raise ValueError(f"Model weights file does not exist: {weights_path}") + + # Create __init__.py file to ensure model directory can be imported as a module + init_file = os.path.join(model_dir, "__init__.py") + if not os.path.exists(init_file): + with open(init_file, "w"): + pass + + return config_path, weights_path + + +def import_class_from_path(module_name, class_path: str): + file_name, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_name + "." 
+ file_name) + return getattr(module, class_name) + + +def ensure_init_file(dir_path: str): + """Ensure __init__.py file exists in the given dir path""" + init_file = os.path.join(dir_path, "__init__.py") + os.makedirs(dir_path, exist_ok=True) + if not os.path.exists(init_file): + with open(init_file, "w"): + pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py index e2be6459508b7..ea6362ef080af 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py @@ -38,7 +38,6 @@ TAINodeRemoveReq, TAINodeRestartReq, TNodeVersionInfo, - TUpdateModelInfoReq, ) logger = Logger() @@ -155,13 +154,8 @@ def _wait_and_reconnect(self) -> None: self._try_to_connect() except TException: # can not connect to each config node - self._sync_latest_config_node_list() self._try_to_connect() - def _sync_latest_config_node_list(self) -> None: - # TODO - pass - def _update_config_node_leader(self, status: TSStatus) -> bool: if status.code == TSStatusCode.REDIRECTION_RECOMMEND.get_status_code(): if status.redirectNode is not None: @@ -271,36 +265,3 @@ def get_ainode_configuration(self, node_id: int) -> map: self._config_leader = None self._wait_and_reconnect() raise TException(self._MSG_RECONNECTION_FAIL) - - def update_model_info( - self, - model_id: str, - model_status: int, - attribute: str = "", - ainode_id=None, - input_length=0, - output_length=0, - ) -> None: - if ainode_id is None: - ainode_id = [] - for _ in range(0, self._RETRY_NUM): - try: - req = TUpdateModelInfoReq(model_id, model_status, attribute) - if ainode_id is not None: - req.aiNodeIds = ainode_id - req.inputLength = input_length - req.outputLength = output_length - status = self._client.updateModelInfo(req) - if not self._update_config_node_leader(status): - verify_success( - status, "An error occurs when calling update model info" - ) - return status - except TTransport.TException: - 
logger.warning( - "Failed to connect to ConfigNode {} from AINode when executing update model info", - self._config_leader, - ) - self._config_leader = None - self._wait_and_reconnect() - raise TException(self._MSG_RECONNECTION_FAIL) diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py index f01e1594f0698..6c4eedeb99f7f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py @@ -29,6 +29,7 @@ TAIHeartbeatResp, TDeleteModelReq, TForecastReq, + TForecastResp, TInferenceReq, TInferenceResp, TLoadModelReq, @@ -78,8 +79,14 @@ def stopAINode(self) -> TSStatus: def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp: return self._model_manager.register_model(req) + def deleteModel(self, req: TDeleteModelReq) -> TSStatus: + return self._model_manager.delete_model(req) + + def showModels(self, req: TShowModelsReq) -> TShowModelsResp: + return self._model_manager.show_models(req) + def loadModel(self, req: TLoadModelReq) -> TSStatus: - status = self._ensure_model_is_built_in_or_fine_tuned(req.existingModelId) + status = self._ensure_model_is_registered(req.existingModelId) if status.code != TSStatusCode.SUCCESS_STATUS.value: return status status = _ensure_device_id_is_available(req.deviceIdList) @@ -88,7 +95,7 @@ def loadModel(self, req: TLoadModelReq) -> TSStatus: return self._inference_manager.load_model(req) def unloadModel(self, req: TUnloadModelReq) -> TSStatus: - status = self._ensure_model_is_built_in_or_fine_tuned(req.modelId) + status = self._ensure_model_is_registered(req.modelId) if status.code != TSStatusCode.SUCCESS_STATUS.value: return status status = _ensure_device_id_is_available(req.deviceIdList) @@ -96,21 +103,6 @@ def unloadModel(self, req: TUnloadModelReq) -> TSStatus: return status return self._inference_manager.unload_model(req) - def deleteModel(self, req: TDeleteModelReq) -> TSStatus: - return 
self._model_manager.delete_model(req) - - def inference(self, req: TInferenceReq) -> TInferenceResp: - return self._inference_manager.inference(req) - - def forecast(self, req: TForecastReq) -> TSStatus: - return self._inference_manager.forecast(req) - - def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: - return ClusterManager.get_heart_beat(req) - - def showModels(self, req: TShowModelsReq) -> TShowModelsResp: - return self._model_manager.show_models(req) - def showLoadedModels(self, req: TShowLoadedModelsReq) -> TShowLoadedModelsResp: status = _ensure_device_id_is_available(req.deviceIdList) if status.code != TSStatusCode.SUCCESS_STATUS.value: @@ -123,13 +115,28 @@ def showAIDevices(self) -> TShowAIDevicesResp: deviceIdList=get_available_devices(), ) + def inference(self, req: TInferenceReq) -> TInferenceResp: + status = self._ensure_model_is_registered(req.modelId) + if status.code != TSStatusCode.SUCCESS_STATUS.value: + return TInferenceResp(status, []) + return self._inference_manager.inference(req) + + def forecast(self, req: TForecastReq) -> TForecastResp: + status = self._ensure_model_is_registered(req.modelId) + if status.code != TSStatusCode.SUCCESS_STATUS.value: + return TForecastResp(status, []) + return self._inference_manager.forecast(req) + + def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: + return ClusterManager.get_heart_beat(req) + def createTrainingTask(self, req: TTrainingReq) -> TSStatus: pass - def _ensure_model_is_built_in_or_fine_tuned(self, model_id: str) -> TSStatus: - if not self._model_manager.is_built_in_or_fine_tuned(model_id): + def _ensure_model_is_registered(self, model_id: str) -> TSStatus: + if not self._model_manager.is_model_registered(model_id): return TSStatus( code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, - message=f"Model [{model_id}] is not a built-in or fine-tuned model. You can use 'SHOW MODELS' to retrieve the available models.", + message=f"Model [{model_id}] is not registered yet. 
You can use 'SHOW MODELS' to retrieve the available models.", ) return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml index 331cb8ab32a34..e93bb3dfdaf10 100644 --- a/iotdb-core/ainode/pyproject.toml +++ b/iotdb-core/ainode/pyproject.toml @@ -79,7 +79,7 @@ exclude = [ python = ">=3.11.0,<3.14.0" # ---- DL / HF stack ---- -torch = ">=2.7.0" +torch = "^2.8.0,<2.9.0" torchmetrics = "^1.8.0" transformers = "==4.56.2" tokenizers = ">=0.22.0,<=0.23.0" @@ -93,7 +93,10 @@ scipy = "^1.12.0" pandas = "^2.3.2" scikit-learn = "^1.7.1" statsmodels = "^0.14.5" -sktime = "0.38.5" +sktime = "0.40.1" +pmdarima = "2.1.1" +hmmlearn = "0.3.2" +accelerate = "^1.10.1" # ---- Optimizers / utils ---- optuna = "^4.4.0" @@ -112,7 +115,7 @@ black = "25.1.0" isort = "6.0.1" setuptools = ">=75.3.0" joblib = ">=1.4.2" -urllib3 = ">=2.2.3" +urllib3 = "2.6.0" [tool.poetry.scripts] ainode = "iotdb.ainode.core.script:main" diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java index 2721fedafb1e6..8d9081f435273 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java @@ -21,21 +21,28 @@ import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatReq; import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.ClientPoolFactory; import org.apache.iotdb.commons.client.IClientManager; +import org.apache.iotdb.commons.client.async.AsyncAINodeInternalServiceClient; import org.apache.iotdb.confignode.client.async.handlers.heartbeat.AINodeHeartbeatHandler; -import org.apache.iotdb.db.protocol.client.AINodeClientFactory; -import 
org.apache.iotdb.db.protocol.client.ainode.AsyncAINodeServiceClient; +/** Asynchronously send RPC requests to AINodes. */ public class AsyncAINodeHeartbeatClientPool { - private final IClientManager clientManager; + private final IClientManager clientManager; private AsyncAINodeHeartbeatClientPool() { clientManager = - new IClientManager.Factory() - .createClientManager(new AINodeClientFactory.AINodeHeartbeatClientPoolFactory()); + new IClientManager.Factory() + .createClientManager( + new ClientPoolFactory.AsyncAINodeHeartbeatServiceClientPoolFactory()); } + /** + * Only used in LoadManager. + * + * @param endPoint The specific DataNode + */ public void getAINodeHeartBeat( TEndPoint endPoint, TAIHeartbeatReq req, AINodeHeartbeatHandler handler) { try { @@ -56,6 +63,6 @@ private AsyncAINodeHeartbeatClientPoolHolder() { } public static AsyncAINodeHeartbeatClientPool getInstance() { - return AsyncAINodeHeartbeatClientPool.AsyncAINodeHeartbeatClientPoolHolder.INSTANCE; + return AsyncAINodeHeartbeatClientPoolHolder.INSTANCE; } } diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java index ccc19f1a9f382..324e351302787 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java @@ -63,7 +63,6 @@ public void writeAuditLog( } } - // TODO: Is the AsyncDataNodeHeartbeatClientPool must be a singleton? 
private static class AsyncDataNodeHeartbeatClientPoolHolder { private static final AsyncDataNodeHeartbeatClientPool INSTANCE = diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java index e0b2c144c0eac..65c1ee0a9fed5 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java @@ -21,8 +21,6 @@ import org.apache.iotdb.commons.exception.runtime.SerializationRunTimeException; import org.apache.iotdb.confignode.consensus.request.read.ainode.GetAINodeConfigurationPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; import org.apache.iotdb.confignode.consensus.request.read.subscription.ShowTopicPlan; import org.apache.iotdb.confignode.consensus.request.write.ainode.RegisterAINodePlan; import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; @@ -52,10 +50,6 @@ import org.apache.iotdb.confignode.consensus.request.write.function.DropTableModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.DropTreeModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.UpdateFunctionPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.AddRegionLocationPlan; import 
org.apache.iotdb.confignode.consensus.request.write.partition.AutoCleanPartitionTablePlan; import org.apache.iotdb.confignode.consensus.request.write.partition.CreateDataPartitionPlan; @@ -574,24 +568,6 @@ public static ConfigPhysicalPlan create(final ByteBuffer buffer) throws IOExcept case UPDATE_CQ_LAST_EXEC_TIME: plan = new UpdateCQLastExecTimePlan(); break; - case CreateModel: - plan = new CreateModelPlan(); - break; - case UpdateModelInfo: - plan = new UpdateModelInfoPlan(); - break; - case DropModel: - plan = new DropModelPlan(); - break; - case ShowModel: - plan = new ShowModelPlan(); - break; - case DropModelInNode: - plan = new DropModelInNodePlan(); - break; - case GetModelInfo: - plan = new GetModelInfoPlan(); - break; case CreatePipePlugin: plan = new CreatePipePluginPlan(); break; diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java deleted file mode 100644 index dd79910e51fa8..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. 
See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.request.read.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; -import org.apache.iotdb.confignode.consensus.request.read.ConfigPhysicalReadPlan; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; - -import java.util.Objects; - -public class GetModelInfoPlan extends ConfigPhysicalReadPlan { - - private String modelId; - - public GetModelInfoPlan() { - super(ConfigPhysicalPlanType.GetModelInfo); - } - - public GetModelInfoPlan(final TGetModelInfoReq getModelInfoReq) { - super(ConfigPhysicalPlanType.GetModelInfo); - this.modelId = getModelInfoReq.getModelId(); - } - - public String getModelId() { - return modelId; - } - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - final GetModelInfoPlan that = (GetModelInfoPlan) o; - return Objects.equals(modelId, that.modelId); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelId); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java deleted file mode 100644 index eca00e8827d96..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.request.read.model; - -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; -import org.apache.iotdb.confignode.consensus.request.read.ConfigPhysicalReadPlan; - -import java.util.Objects; - -public class ShowModelPlan extends ConfigPhysicalReadPlan { - - private String modelName; - - public ShowModelPlan() { - super(ConfigPhysicalPlanType.ShowModel); - } - - public ShowModelPlan(final TShowModelsReq showModelReq) { - super(ConfigPhysicalPlanType.ShowModel); - if (showModelReq.isSetModelId()) { - this.modelName = showModelReq.getModelId(); - } - } - - public boolean isSetModelName() { - return modelName != null; - } - - public String getModelName() { - return modelName; - } - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - final ShowModelPlan that = (ShowModelPlan) o; - return Objects.equals(modelName, that.modelName); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java 
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java deleted file mode 100644 index 61e37cdd21877..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import org.apache.tsfile.utils.ReadWriteIOUtils; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Objects; - -public class CreateModelPlan extends ConfigPhysicalPlan { - - private String modelName; - - public CreateModelPlan() { - super(ConfigPhysicalPlanType.CreateModel); - } - - public CreateModelPlan(String modelName) { - super(ConfigPhysicalPlanType.CreateModel); - this.modelName = modelName; - } - - public String getModelName() { - return modelName; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - ReadWriteIOUtils.write(modelName, stream); - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - modelName = ReadWriteIOUtils.readString(buffer); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - CreateModelPlan that = (CreateModelPlan) o; - return Objects.equals(modelName, that.modelName); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java deleted file mode 100644 index 885543f84e156..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software 
Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Objects; - -public class DropModelInNodePlan extends ConfigPhysicalPlan { - - private int nodeId; - - public DropModelInNodePlan() { - super(ConfigPhysicalPlanType.DropModelInNode); - } - - public DropModelInNodePlan(int nodeId) { - super(ConfigPhysicalPlanType.DropModelInNode); - this.nodeId = nodeId; - } - - public int getNodeId() { - return nodeId; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - stream.writeInt(nodeId); - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - nodeId = buffer.getInt(); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof DropModelInNodePlan)) return false; - DropModelInNodePlan that = (DropModelInNodePlan) o; - return nodeId == that.nodeId; - } - - 
@Override - public int hashCode() { - return Objects.hash(super.hashCode(), nodeId); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java deleted file mode 100644 index 813b116c645c5..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import org.apache.tsfile.utils.ReadWriteIOUtils; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Objects; - -public class DropModelPlan extends ConfigPhysicalPlan { - - private String modelName; - - public DropModelPlan() { - super(ConfigPhysicalPlanType.DropModel); - } - - public DropModelPlan(String modelName) { - super(ConfigPhysicalPlanType.DropModel); - this.modelName = modelName; - } - - public String getModelName() { - return modelName; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - ReadWriteIOUtils.write(modelName, stream); - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - modelName = ReadWriteIOUtils.readString(buffer); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - DropModelPlan that = (DropModelPlan) o; - return modelName.equals(that.modelName); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java deleted file mode 100644 index ce7219e428139..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under 
one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import org.apache.tsfile.utils.ReadWriteIOUtils; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Objects; - -public class UpdateModelInfoPlan extends ConfigPhysicalPlan { - - private String modelName; - private ModelInformation modelInformation; - - // The node which has the model which is only updated in model registration - private List nodeIds; - - public UpdateModelInfoPlan() { - super(ConfigPhysicalPlanType.UpdateModelInfo); - } - - public UpdateModelInfoPlan(String modelName, ModelInformation modelInformation) { - super(ConfigPhysicalPlanType.UpdateModelInfo); - this.modelName = modelName; - this.modelInformation = modelInformation; - this.nodeIds = Collections.emptyList(); - } - - public UpdateModelInfoPlan( - String modelName, ModelInformation modelInformation, List nodeIds) { - 
super(ConfigPhysicalPlanType.UpdateModelInfo); - this.modelName = modelName; - this.modelInformation = modelInformation; - this.nodeIds = nodeIds; - } - - public String getModelName() { - return modelName; - } - - public ModelInformation getModelInformation() { - return modelInformation; - } - - public List getNodeIds() { - return nodeIds; - } - - public void setNodeIds(List nodeIds) { - this.nodeIds = nodeIds; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - ReadWriteIOUtils.write(modelName, stream); - this.modelInformation.serialize(stream); - ReadWriteIOUtils.write(nodeIds.size(), stream); - for (Integer nodeId : nodeIds) { - ReadWriteIOUtils.write(nodeId, stream); - } - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - this.modelName = ReadWriteIOUtils.readString(buffer); - this.modelInformation = ModelInformation.deserialize(buffer); - int size = ReadWriteIOUtils.readInt(buffer); - this.nodeIds = new ArrayList<>(); - for (int i = 0; i < size; i++) { - this.nodeIds.add(ReadWriteIOUtils.readInt(buffer)); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - UpdateModelInfoPlan that = (UpdateModelInfoPlan) o; - return modelName.equals(that.modelName) - && modelInformation.equals(that.modelInformation) - && nodeIds.equals(that.nodeIds); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName, modelInformation, nodeIds); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java deleted file mode 100644 index cebc1301b8912..0000000000000 --- 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.response.model; - -import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; -import org.apache.iotdb.consensus.common.DataSet; - -public class GetModelInfoResp implements DataSet { - - private final TSStatus status; - - private int targetAINodeId; - private TEndPoint targetAINodeAddress; - - public TSStatus getStatus() { - return status; - } - - public GetModelInfoResp(TSStatus status) { - this.status = status; - } - - public int getTargetAINodeId() { - return targetAINodeId; - } - - public void setTargetAINodeId(int targetAINodeId) { - this.targetAINodeId = targetAINodeId; - } - - public void setTargetAINodeAddress(TAINodeConfiguration aiNodeConfiguration) { - if (aiNodeConfiguration.getLocation() == null) { - return; - } - this.targetAINodeAddress = 
aiNodeConfiguration.getLocation().getInternalEndPoint(); - } - - public TGetModelInfoResp convertToThriftResponse() { - TGetModelInfoResp resp = new TGetModelInfoResp(status); - resp.setAiNodeAddress(targetAINodeAddress); - return resp; - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java deleted file mode 100644 index 7490a53a01c57..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.response.model; - -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.consensus.common.DataSet; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -// TODO: Will be removed in the future -public class ModelTableResp implements DataSet { - - private final TSStatus status; - private final List serializedAllModelInformation; - private Map modelTypeMap; - private Map algorithmMap; - - public ModelTableResp(TSStatus status) { - this.status = status; - this.serializedAllModelInformation = new ArrayList<>(); - } - - public void addModelInformation(List modelInformationList) throws IOException { - for (ModelInformation modelInformation : modelInformationList) { - this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult()); - } - } - - public void addModelInformation(ModelInformation modelInformation) throws IOException { - this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult()); - } - - public void setModelTypeMap(Map modelTypeMap) { - this.modelTypeMap = modelTypeMap; - } - - public void setAlgorithmMap(Map algorithmMap) { - this.algorithmMap = algorithmMap; - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java index 5d4b09adfc710..9d7151a8d20e3 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java @@ -19,22 +19,12 @@ package org.apache.iotdb.confignode.manager; -import org.apache.iotdb.ainode.rpc.thrift.IDataSchema; -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import 
org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId; import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.common.rpc.thrift.TFlushReq; import org.apache.iotdb.common.rpc.thrift.TPipeHeartbeatResp; import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet; @@ -58,7 +48,6 @@ import org.apache.iotdb.commons.conf.TrimProperties; import org.apache.iotdb.commons.exception.IllegalPathException; import org.apache.iotdb.commons.exception.MetadataException; -import org.apache.iotdb.commons.model.ModelStatus; import org.apache.iotdb.commons.path.PartialPath; import org.apache.iotdb.commons.path.PathPatternTree; import org.apache.iotdb.commons.path.PathPatternUtil; @@ -97,7 +86,6 @@ import org.apache.iotdb.confignode.consensus.request.write.database.SetTTLPlan; import org.apache.iotdb.confignode.consensus.request.write.database.SetTimePartitionIntervalPlan; import org.apache.iotdb.confignode.consensus.request.write.datanode.RemoveDataNodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; import org.apache.iotdb.confignode.consensus.request.write.template.CreateSchemaTemplatePlan; import org.apache.iotdb.confignode.consensus.response.ainode.AINodeRegisterResp; import 
org.apache.iotdb.confignode.consensus.response.auth.PermissionInfoResp; @@ -129,7 +117,6 @@ import org.apache.iotdb.confignode.manager.schema.ClusterSchemaQuotaStatistics; import org.apache.iotdb.confignode.manager.subscription.SubscriptionManager; import org.apache.iotdb.confignode.persistence.ClusterInfo; -import org.apache.iotdb.confignode.persistence.ModelInfo; import org.apache.iotdb.confignode.persistence.ProcedureInfo; import org.apache.iotdb.confignode.persistence.TTLInfo; import org.apache.iotdb.confignode.persistence.TriggerInfo; @@ -144,7 +131,6 @@ import org.apache.iotdb.confignode.persistence.schema.ClusterSchemaInfo; import org.apache.iotdb.confignode.persistence.subscription.SubscriptionInfo; import org.apache.iotdb.confignode.procedure.impl.schema.SchemaUtils; -import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp; @@ -163,13 +149,11 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTableViewReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTopicReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateTrainingReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTriggerReq; import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRegisterReq; import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRestartReq; @@ -186,7 +170,6 @@ import 
org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -203,8 +186,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -257,11 +238,8 @@ import org.apache.iotdb.confignode.rpc.thrift.TTimeSlotList; import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; import org.apache.iotdb.consensus.common.DataSet; import org.apache.iotdb.consensus.exception.ConsensusException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; import org.apache.iotdb.db.schemaengine.template.Template; import org.apache.iotdb.db.schemaengine.template.TemplateAlterOperationType; import org.apache.iotdb.db.schemaengine.template.alter.TemplateAlterOperationUtil; @@ -340,9 +318,6 @@ public class ConfigManager implements IManager { /** CQ. */ private final CQManager cqManager; - /** AI Model. 
*/ - private final ModelManager modelManager; - /** Pipe */ private final PipeManager pipeManager; @@ -362,8 +337,6 @@ public class ConfigManager implements IManager { private static final String DATABASE = "\tDatabase="; - private static final String DOT = "."; - public ConfigManager() throws IOException { // Build the persistence module ClusterInfo clusterInfo = new ClusterInfo(); @@ -375,7 +348,6 @@ public ConfigManager() throws IOException { UDFInfo udfInfo = new UDFInfo(); TriggerInfo triggerInfo = new TriggerInfo(); CQInfo cqInfo = new CQInfo(); - ModelInfo modelInfo = new ModelInfo(); PipeInfo pipeInfo = new PipeInfo(); QuotaInfo quotaInfo = new QuotaInfo(); TTLInfo ttlInfo = new TTLInfo(); @@ -393,7 +365,6 @@ public ConfigManager() throws IOException { udfInfo, triggerInfo, cqInfo, - modelInfo, pipeInfo, subscriptionInfo, quotaInfo, @@ -415,7 +386,6 @@ public ConfigManager() throws IOException { this.udfManager = new UDFManager(this, udfInfo); this.triggerManager = new TriggerManager(this, triggerInfo); this.cqManager = new CQManager(this); - this.modelManager = new ModelManager(this, modelInfo); this.pipeManager = new PipeManager(this, pipeInfo); this.subscriptionManager = new SubscriptionManager(this, subscriptionInfo); this.auditLogger = new CNAuditLogger(this); @@ -1289,11 +1259,6 @@ public TriggerManager getTriggerManager() { return triggerManager; } - @Override - public ModelManager getModelManager() { - return modelManager; - } - @Override public PipeManager getPipeManager() { return pipeManager; @@ -2757,150 +2722,6 @@ public TSStatus transfer(List newUnknownDataList) { return transferResult; } - @Override - public TSStatus createModel(TCreateModelReq req) { - TSStatus status = confirmLeader(); - if (nodeManager.getRegisteredAINodes().isEmpty()) { - return new TSStatus(TSStatusCode.NO_REGISTERED_AI_NODE_ERROR.getStatusCode()) - .setMessage("There is no available AINode! 
Try to start one."); - } - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.createModel(req) - : status; - } - - private List fetchSchemaForTreeModel(TCreateTrainingReq req) { - List dataSchemaList = new ArrayList<>(); - for (int i = 0; i < req.getDataSchemaForTree().getPathSize(); i++) { - IDataSchema dataSchema = new IDataSchema(req.getDataSchemaForTree().getPath().get(i)); - dataSchema.setTimeRange(req.getTimeRanges().get(i)); - dataSchemaList.add(dataSchema); - } - return dataSchemaList; - } - - private List fetchSchemaForTableModel(TCreateTrainingReq req) { - return Collections.singletonList(new IDataSchema(req.getDataSchemaForTable().getTargetSql())); - } - - public TSStatus createTraining(TCreateTrainingReq req) { - TSStatus status = confirmLeader(); - if (nodeManager.getRegisteredAINodes().isEmpty()) { - return new TSStatus(TSStatusCode.NO_REGISTERED_AI_NODE_ERROR.getStatusCode()) - .setMessage("There is no available AINode! Try to start one."); - } - - TTrainingReq trainingReq = new TTrainingReq(); - trainingReq.setModelId(req.getModelId()); - if (req.isSetExistingModelId()) { - trainingReq.setExistingModelId(req.getExistingModelId()); - } - if (req.isSetParameters() && !req.getParameters().isEmpty()) { - trainingReq.setParameters(req.getParameters()); - } - - try { - status = getConsensusManager().write(new CreateModelPlan(req.getModelId())); - if (status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new MetadataException("Can't init model " + req.getModelId()); - } - - List dataSchema; - if (req.isTableModel) { - dataSchema = fetchSchemaForTableModel(req); - trainingReq.setDbType("iotdb.table"); - } else { - dataSchema = fetchSchemaForTreeModel(req); - trainingReq.setDbType("iotdb.tree"); - } - updateModelInfo(new TUpdateModelInfoReq(req.modelId, ModelStatus.TRAINING.ordinal())); - trainingReq.setTargetDataSchema(dataSchema); - - TAINodeInfo registeredAINode = 
getNodeManager().getRegisteredAINodeInfoList().get(0); - TEndPoint targetAINodeEndPoint = - new TEndPoint(registeredAINode.getInternalAddress(), registeredAINode.getInternalPort()); - try (AINodeClient client = - AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) { - status = client.createTrainingTask(trainingReq); - if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new IllegalArgumentException(status.message); - } - } - } catch (final Exception e) { - status.setCode(TSStatusCode.CAN_NOT_CONNECT_CONFIGNODE.getStatusCode()); - status.setMessage(e.getMessage()); - try { - updateModelInfo(new TUpdateModelInfoReq(req.modelId, ModelStatus.UNAVAILABLE.ordinal())); - } catch (Exception e2) { - LOGGER.error(e2.getMessage()); - } - } - return status; - } - - @Override - public TSStatus dropModel(TDropModelReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.dropModel(req) - : status; - } - - @Override - public TSStatus loadModel(TLoadModelReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.loadModel(req) - : status; - } - - @Override - public TSStatus unloadModel(TUnloadModelReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.unloadModel(req) - : status; - } - - @Override - public TShowModelsResp showModel(TShowModelsReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.showModel(req) - : new TShowModelsResp(status); - } - - @Override - public TShowLoadedModelsResp showLoadedModel(TShowLoadedModelsReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? 
modelManager.showLoadedModel(req) - : new TShowLoadedModelsResp(status, Collections.emptyMap()); - } - - @Override - public TShowAIDevicesResp showAIDevices() { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.showAIDevices() - : new TShowAIDevicesResp(status, Collections.emptyList()); - } - - @Override - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.getModelInfo(req) - : new TGetModelInfoResp(status); - } - - public TSStatus updateModelInfo(TUpdateModelInfoReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.updateModelInfo(req) - : status; - } - @Override public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) { TSStatus status = confirmLeader(); diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java index 33e77db24907d..dff994d70e7e3 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java @@ -19,13 +19,6 @@ package org.apache.iotdb.confignode.manager; -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId; import 
org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; @@ -82,7 +75,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -103,7 +95,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -120,8 +111,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -255,13 +244,6 @@ public interface IManager { */ CQManager getCQManager(); - /** - * Get {@link ModelManager}. - * - * @return {@link ModelManager} instance - */ - ModelManager getModelManager(); - /** * Get {@link PipeManager}. * @@ -880,30 +862,6 @@ TDataPartitionTableResp getOrCreateDataPartition( TSStatus transfer(List newUnknownDataList); - /** Create a model. */ - TSStatus createModel(TCreateModelReq req); - - /** Drop a model. 
*/ - TSStatus dropModel(TDropModelReq req); - - /** Load the specific model to the specific devices. */ - TSStatus loadModel(TLoadModelReq req); - - /** Unload the specific model from the specific devices. */ - TSStatus unloadModel(TUnloadModelReq req); - - /** Return the model table. */ - TShowModelsResp showModel(TShowModelsReq req); - - /** Return the loaded model instances. */ - TShowLoadedModelsResp showLoadedModel(TShowLoadedModelsReq req); - - /** Return all available AI devices. */ - TShowAIDevicesResp showAIDevices(); - - /** Update the model state */ - TGetModelInfoResp getModelInfo(TGetModelInfoReq req); - /** Set space quota. */ TSStatus setSpaceQuota(TSetSpaceQuotaReq req); diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java deleted file mode 100644 index 3efdbc222b6d2..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.manager; - -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.client.exception.ClientManagerException; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.commons.model.ModelStatus; -import org.apache.iotdb.commons.model.ModelType; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; -import org.apache.iotdb.confignode.exception.NoAvailableAINodeException; -import org.apache.iotdb.confignode.persistence.ModelInfo; -import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; -import org.apache.iotdb.consensus.exception.ConsensusException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.List; - -public class ModelManager { - - private static final Logger LOGGER = LoggerFactory.getLogger(ModelManager.class); - - private final ConfigManager 
configManager; - private final ModelInfo modelInfo; - - public ModelManager(ConfigManager configManager, ModelInfo modelInfo) { - this.configManager = configManager; - this.modelInfo = modelInfo; - } - - public TSStatus createModel(TCreateModelReq req) { - if (modelInfo.contain(req.modelName)) { - return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode()) - .setMessage(String.format("Model name %s already exists", req.modelName)); - } - try { - if (req.uri.isEmpty()) { - return configManager.getConsensusManager().write(new CreateModelPlan(req.modelName)); - } - return configManager.getProcedureManager().createModel(req.modelName, req.uri); - } catch (ConsensusException e) { - LOGGER.warn("Unexpected error happened while getting model: ", e); - // consensus layer related errors - TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()); - res.setMessage(e.getMessage()); - return res; - } - } - - public TSStatus dropModel(TDropModelReq req) { - if (modelInfo.checkModelType(req.getModelId()) != ModelType.USER_DEFINED) { - return new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode()) - .setMessage(String.format("Built-in model %s can't be removed", req.modelId)); - } - if (!modelInfo.contain(req.modelId)) { - return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode()) - .setMessage(String.format("Model name %s doesn't exists", req.modelId)); - } - return configManager.getProcedureManager().dropModel(req.getModelId()); - } - - public TSStatus loadModel(TLoadModelReq req) { - try (AINodeClient client = getAINodeClient()) { - org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq loadModelReq = - new org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq( - req.existingModelId, req.deviceIdList); - return client.loadModel(loadModelReq); - } catch (Exception e) { - LOGGER.warn("Failed to load model due to", e); - return new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage()); - } - } - - public 
TSStatus unloadModel(TUnloadModelReq req) { - try (AINodeClient client = getAINodeClient()) { - org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq unloadModelReq = - new org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq(req.modelId, req.deviceIdList); - return client.unloadModel(unloadModelReq); - } catch (Exception e) { - LOGGER.warn("Failed to unload model due to", e); - return new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage()); - } - } - - public TShowModelsResp showModel(final TShowModelsReq req) { - try (AINodeClient client = getAINodeClient()) { - TShowModelsReq showModelsReq = new TShowModelsReq(); - if (req.isSetModelId()) { - showModelsReq.setModelId(req.getModelId()); - } - TShowModelsResp resp = client.showModels(showModelsReq); - TShowModelsResp res = - new TShowModelsResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - res.setModelIdList(resp.getModelIdList()); - res.setModelTypeMap(resp.getModelTypeMap()); - res.setCategoryMap(resp.getCategoryMap()); - res.setStateMap(resp.getStateMap()); - return res; - } catch (Exception e) { - LOGGER.warn("Failed to show models due to", e); - return new TShowModelsResp() - .setStatus( - new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } - } - - public TShowLoadedModelsResp showLoadedModel(final TShowLoadedModelsReq req) { - try (AINodeClient client = getAINodeClient()) { - TShowLoadedModelsReq showModelsReq = - new TShowLoadedModelsReq().setDeviceIdList(req.getDeviceIdList()); - TShowLoadedModelsResp resp = client.showLoadedModels(showModelsReq); - TShowLoadedModelsResp res = - new TShowLoadedModelsResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - res.setDeviceLoadedModelsMap(resp.getDeviceLoadedModelsMap()); - return res; - } catch (Exception e) { - LOGGER.warn("Failed to show loaded models due to", e); - return new TShowLoadedModelsResp() - .setStatus( - new 
TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } - } - - public TShowAIDevicesResp showAIDevices() { - try (AINodeClient client = getAINodeClient()) { - org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp resp = client.showAIDevices(); - TShowAIDevicesResp res = - new TShowAIDevicesResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - res.setDeviceIdList(resp.getDeviceIdList()); - return res; - } catch (Exception e) { - LOGGER.warn("Failed to show AI devices due to", e); - return new TShowAIDevicesResp() - .setStatus( - new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } - } - - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { - return new TGetModelInfoResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())) - .setAiNodeAddress( - configManager - .getNodeManager() - .getRegisteredAINodes() - .get(0) - .getLocation() - .getInternalEndPoint()); - } - - // Currently this method is only used by built-in timer_xl - public TSStatus updateModelInfo(TUpdateModelInfoReq req) { - if (!modelInfo.contain(req.getModelId())) { - return new TSStatus(TSStatusCode.MODEL_NOT_FOUND_ERROR.getStatusCode()) - .setMessage(String.format("Model %s doesn't exists", req.getModelId())); - } - try { - ModelInformation modelInformation = - new ModelInformation(ModelType.USER_DEFINED, req.getModelId()); - modelInformation.updateStatus(ModelStatus.values()[req.getModelStatus()]); - modelInformation.setAttribute(req.getAttributes()); - modelInformation.setInputColumnSize(1); - if (req.isSetOutputLength()) { - modelInformation.setOutputLength(req.getOutputLength()); - } - if (req.isSetInputLength()) { - modelInformation.setInputLength(req.getInputLength()); - } - UpdateModelInfoPlan updateModelInfoPlan = - new UpdateModelInfoPlan(req.getModelId(), modelInformation); - if (req.isSetAiNodeIds()) { - 
updateModelInfoPlan.setNodeIds(req.getAiNodeIds()); - } - return configManager.getConsensusManager().write(updateModelInfoPlan); - } catch (ConsensusException e) { - LOGGER.warn("Unexpected error happened while updating model info: ", e); - // consensus layer related errors - TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()); - res.setMessage(e.getMessage()); - return res; - } - } - - private AINodeClient getAINodeClient() throws NoAvailableAINodeException, ClientManagerException { - List aiNodeInfo = configManager.getNodeManager().getRegisteredAINodeInfoList(); - if (aiNodeInfo.isEmpty()) { - throw new NoAvailableAINodeException(); - } - TEndPoint targetAINodeEndPoint = - new TEndPoint(aiNodeInfo.get(0).getInternalAddress(), aiNodeInfo.get(0).getInternalPort()); - try { - return AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - public List getModelDistributions(String modelName) { - return modelInfo.getNodeIds(modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java index 2e4227af3fc8b..d67e7721eef83 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java @@ -61,8 +61,6 @@ import org.apache.iotdb.confignode.procedure.env.RegionMaintainHandler; import org.apache.iotdb.confignode.procedure.env.RemoveDataNodeHandler; import org.apache.iotdb.confignode.procedure.impl.cq.CreateCQProcedure; -import org.apache.iotdb.confignode.procedure.impl.model.CreateModelProcedure; -import org.apache.iotdb.confignode.procedure.impl.model.DropModelProcedure; import org.apache.iotdb.confignode.procedure.impl.node.AddConfigNodeProcedure; import 
org.apache.iotdb.confignode.procedure.impl.node.RemoveAINodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveConfigNodeProcedure; @@ -1414,24 +1412,6 @@ public TSStatus createCQ(TCreateCQReq req, ScheduledExecutorService scheduledExe return waitingProcedureFinished(procedure); } - public TSStatus createModel(String modelName, String uri) { - long procedureId = executor.submitProcedure(new CreateModelProcedure(modelName, uri)); - LOGGER.info("CreateModelProcedure was submitted, procedureId: {}.", procedureId); - return RpcUtils.SUCCESS_STATUS; - } - - public TSStatus dropModel(String modelId) { - DropModelProcedure procedure = new DropModelProcedure(modelId); - executor.submitProcedure(procedure); - TSStatus status = waitingProcedureFinished(procedure); - if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - return status; - } else { - return new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode()) - .setMessage(status.getMessage()); - } - } - public TSStatus createPipePlugin( PipePluginMeta pipePluginMeta, byte[] jarFile, boolean isSetIfNotExistsCondition) { final CreatePipePluginProcedure createPipePluginProcedure = diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java deleted file mode 100644 index aeada03d15cc3..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java +++ /dev/null @@ -1,378 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.persistence; - -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.commons.model.ModelStatus; -import org.apache.iotdb.commons.model.ModelTable; -import org.apache.iotdb.commons.model.ModelType; -import org.apache.iotdb.commons.snapshot.SnapshotProcessor; -import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; -import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp; -import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.thrift.TException; -import org.apache.tsfile.utils.PublicBAOS; -import org.apache.tsfile.utils.ReadWriteIOUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import javax.annotation.concurrent.ThreadSafe; - -import java.io.DataOutputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.locks.ReadWriteLock; -import 
java.util.concurrent.locks.ReentrantReadWriteLock; - -@ThreadSafe -public class ModelInfo implements SnapshotProcessor { - - private static final Logger LOGGER = LoggerFactory.getLogger(ModelInfo.class); - - private static final String SNAPSHOT_FILENAME = "model_info.snapshot"; - - private ModelTable modelTable; - - private final Map> modelNameToNodes; - - private final ReadWriteLock modelTableLock = new ReentrantReadWriteLock(); - - private static final Set builtInForecastModel = new HashSet<>(); - - private static final Set builtInAnomalyDetectionModel = new HashSet<>(); - - static { - builtInForecastModel.add("arima"); - builtInForecastModel.add("naive_forecaster"); - builtInForecastModel.add("stl_forecaster"); - builtInForecastModel.add("holtwinters"); - builtInForecastModel.add("exponential_smoothing"); - builtInForecastModel.add("timer_xl"); - builtInForecastModel.add("sundial"); - builtInAnomalyDetectionModel.add("gaussian_hmm"); - builtInAnomalyDetectionModel.add("gmm_hmm"); - builtInAnomalyDetectionModel.add("stray"); - } - - public ModelInfo() { - this.modelTable = new ModelTable(); - this.modelNameToNodes = new HashMap<>(); - } - - public boolean contain(String modelName) { - return modelTable.containsModel(modelName); - } - - public void acquireModelTableReadLock() { - LOGGER.info("acquire ModelTableReadLock"); - modelTableLock.readLock().lock(); - } - - public void releaseModelTableReadLock() { - LOGGER.info("release ModelTableReadLock"); - modelTableLock.readLock().unlock(); - } - - public void acquireModelTableWriteLock() { - LOGGER.info("acquire ModelTableWriteLock"); - modelTableLock.writeLock().lock(); - } - - public void releaseModelTableWriteLock() { - LOGGER.info("release ModelTableWriteLock"); - modelTableLock.writeLock().unlock(); - } - - // init the model in modeInfo, it won't update the details information of the model - public TSStatus createModel(CreateModelPlan plan) { - try { - acquireModelTableWriteLock(); - String modelName = 
plan.getModelName(); - modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING)); - return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } catch (Exception e) { - final String errorMessage = - String.format( - "Failed to add model [%s] in ModelTable on Config Nodes, because of %s", - plan.getModelName(), e); - LOGGER.warn(errorMessage, e); - return new TSStatus(TSStatusCode.CREATE_MODEL_ERROR.getStatusCode()).setMessage(errorMessage); - } finally { - releaseModelTableWriteLock(); - } - } - - public TSStatus dropModelInNode(int aiNodeId) { - acquireModelTableWriteLock(); - try { - for (Map.Entry> entry : modelNameToNodes.entrySet()) { - entry.getValue().remove(Integer.valueOf(aiNodeId)); - // if list is empty, remove this model totally - if (entry.getValue().isEmpty()) { - modelTable.removeModel(entry.getKey()); - modelNameToNodes.remove(entry.getKey()); - } - } - // currently, we only have one AINode at a time, so we can just clear failed model. - modelTable.clearFailedModel(); - return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } finally { - releaseModelTableWriteLock(); - } - } - - public TSStatus dropModel(String modelName) { - acquireModelTableWriteLock(); - TSStatus status; - if (modelTable.containsModel(modelName)) { - modelTable.removeModel(modelName); - modelNameToNodes.remove(modelName); - status = new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } else { - status = - new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode()) - .setMessage(String.format("model [%s] has not been created.", modelName)); - } - releaseModelTableWriteLock(); - return status; - } - - public List getNodeIds(String modelName) { - return modelNameToNodes.getOrDefault(modelName, Collections.emptyList()); - } - - private ModelInformation getModelByName(String modelName) { - ModelType modelType = checkModelType(modelName); - if (modelType != ModelType.USER_DEFINED) { - if (modelType == ModelType.BUILT_IN_FORECAST && 
builtInForecastModel.contains(modelName)) { - return new ModelInformation(ModelType.BUILT_IN_FORECAST, modelName); - } else if (modelType == ModelType.BUILT_IN_ANOMALY_DETECTION - && builtInAnomalyDetectionModel.contains(modelName)) { - return new ModelInformation(ModelType.BUILT_IN_ANOMALY_DETECTION, modelName); - } - } else { - return modelTable.getModelInformationById(modelName); - } - return null; - } - - public ModelTableResp showModel(ShowModelPlan plan) { - acquireModelTableReadLock(); - try { - ModelTableResp modelTableResp = - new ModelTableResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - if (plan.isSetModelName()) { - ModelInformation modelInformation = getModelByName(plan.getModelName()); - if (modelInformation != null) { - modelTableResp.addModelInformation(modelInformation); - } - } else { - modelTableResp.addModelInformation(modelTable.getAllModelInformation()); - for (String modelName : builtInForecastModel) { - modelTableResp.addModelInformation( - new ModelInformation(ModelType.BUILT_IN_FORECAST, modelName)); - } - for (String modelName : builtInAnomalyDetectionModel) { - modelTableResp.addModelInformation( - new ModelInformation(ModelType.BUILT_IN_ANOMALY_DETECTION, modelName)); - } - } - return modelTableResp; - } catch (IOException e) { - LOGGER.warn("Fail to get ModelTable", e); - return new ModelTableResp( - new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } finally { - releaseModelTableReadLock(); - } - } - - private boolean containsBuiltInModelName(Set builtInModelSet, String modelName) { - // ignore the case - for (String builtInModelName : builtInModelSet) { - if (builtInModelName.equalsIgnoreCase(modelName)) { - return true; - } - } - return false; - } - - public ModelType checkModelType(String modelName) { - if (containsBuiltInModelName(builtInForecastModel, modelName)) { - return ModelType.BUILT_IN_FORECAST; - } else if 
(containsBuiltInModelName(builtInAnomalyDetectionModel, modelName)) { - return ModelType.BUILT_IN_ANOMALY_DETECTION; - } else { - return ModelType.USER_DEFINED; - } - } - - private int getAvailableAINodeForModel(String modelName, ModelType modelType) { - if (modelType == ModelType.USER_DEFINED) { - List aiNodeIds = modelNameToNodes.get(modelName); - if (aiNodeIds != null) { - return aiNodeIds.get(0); - } - } else { - // any AINode is fine for built-in model - // 0 is always the nodeId for configNode, so it's fine to use 0 as special value - return 0; - } - return -1; - } - - // This method will be used by dataNode to get schema of the model for inference - public GetModelInfoResp getModelInfo(GetModelInfoPlan plan) { - acquireModelTableReadLock(); - try { - String modelName = plan.getModelId(); - GetModelInfoResp getModelInfoResp; - ModelInformation modelInformation; - ModelType modelType; - // check if it's a built-in model - if ((modelType = checkModelType(modelName)) != ModelType.USER_DEFINED) { - modelInformation = new ModelInformation(modelType, modelName); - } else { - modelInformation = modelTable.getModelInformationById(modelName); - } - - if (modelInformation != null) { - getModelInfoResp = - new GetModelInfoResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - } else { - TSStatus errorStatus = new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); - errorStatus.setMessage(String.format("model [%s] has not been created.", modelName)); - getModelInfoResp = new GetModelInfoResp(errorStatus); - return getModelInfoResp; - } - PublicBAOS buffer = new PublicBAOS(); - DataOutputStream stream = new DataOutputStream(buffer); - modelInformation.serialize(stream); - // select the nodeId to process the task, currently we default use the first one. 
- int aiNodeId = getAvailableAINodeForModel(modelName, modelType); - if (aiNodeId == -1) { - TSStatus errorStatus = new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); - errorStatus.setMessage(String.format("There is no AINode with %s available", modelName)); - getModelInfoResp = new GetModelInfoResp(errorStatus); - return getModelInfoResp; - } else { - getModelInfoResp.setTargetAINodeId(aiNodeId); - } - return getModelInfoResp; - } catch (IOException e) { - LOGGER.warn("Fail to get model info", e); - return new GetModelInfoResp( - new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } finally { - releaseModelTableReadLock(); - } - } - - public TSStatus updateModelInfo(UpdateModelInfoPlan plan) { - acquireModelTableWriteLock(); - try { - String modelName = plan.getModelName(); - if (modelTable.containsModel(modelName)) { - modelTable.updateModel(modelName, plan.getModelInformation()); - } - if (!plan.getNodeIds().isEmpty()) { - // only used in model registration, so we can just put the nodeIds in the map without - // checking - modelNameToNodes.put(modelName, plan.getNodeIds()); - } - return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } finally { - releaseModelTableWriteLock(); - } - } - - @Override - public boolean processTakeSnapshot(File snapshotDir) throws TException, IOException { - File snapshotFile = new File(snapshotDir, SNAPSHOT_FILENAME); - if (snapshotFile.exists() && snapshotFile.isFile()) { - LOGGER.error( - "Failed to take snapshot of ModelInfo, because snapshot file [{}] is already exist.", - snapshotFile.getAbsolutePath()); - return false; - } - - acquireModelTableReadLock(); - try (FileOutputStream fileOutputStream = new FileOutputStream(snapshotFile)) { - modelTable.serialize(fileOutputStream); - ReadWriteIOUtils.write(modelNameToNodes.size(), fileOutputStream); - for (Map.Entry> entry : modelNameToNodes.entrySet()) { - ReadWriteIOUtils.write(entry.getKey(), 
fileOutputStream); - ReadWriteIOUtils.write(entry.getValue().size(), fileOutputStream); - for (Integer nodeId : entry.getValue()) { - ReadWriteIOUtils.write(nodeId, fileOutputStream); - } - } - fileOutputStream.getFD().sync(); - return true; - } finally { - releaseModelTableReadLock(); - } - } - - @Override - public void processLoadSnapshot(File snapshotDir) throws TException, IOException { - File snapshotFile = new File(snapshotDir, SNAPSHOT_FILENAME); - if (!snapshotFile.exists() || !snapshotFile.isFile()) { - LOGGER.error( - "Failed to load snapshot of ModelInfo, snapshot file [{}] does not exist.", - snapshotFile.getAbsolutePath()); - return; - } - acquireModelTableWriteLock(); - try (FileInputStream fileInputStream = new FileInputStream(snapshotFile)) { - modelTable.clear(); - modelTable = ModelTable.deserialize(fileInputStream); - int size = ReadWriteIOUtils.readInt(fileInputStream); - for (int i = 0; i < size; i++) { - String modelName = ReadWriteIOUtils.readString(fileInputStream); - int nodeSize = ReadWriteIOUtils.readInt(fileInputStream); - List nodes = new LinkedList<>(); - for (int j = 0; j < nodeSize; j++) { - nodes.add(ReadWriteIOUtils.readInt(fileInputStream)); - } - modelNameToNodes.put(modelName, nodes); - } - } finally { - releaseModelTableWriteLock(); - } - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java index fe8b28c4da2e5..d6bad518f6f4b 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java @@ -35,8 +35,6 @@ import org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan; import 
org.apache.iotdb.confignode.consensus.request.read.function.GetFunctionTablePlan; import org.apache.iotdb.confignode.consensus.request.read.function.GetUDFJarPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.CountTimeSlotListPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.GetDataPartitionPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.GetNodePathsPartitionPlan; @@ -84,10 +82,6 @@ import org.apache.iotdb.confignode.consensus.request.write.function.DropTableModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.DropTreeModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.UpdateFunctionPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.AddRegionLocationPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.AutoCleanPartitionTablePlan; import org.apache.iotdb.confignode.consensus.request.write.partition.CreateDataPartitionPlan; @@ -150,7 +144,6 @@ import org.apache.iotdb.confignode.exception.physical.UnknownPhysicalPlanTypeException; import org.apache.iotdb.confignode.manager.pipe.agent.PipeConfigNodeAgent; import org.apache.iotdb.confignode.persistence.ClusterInfo; -import org.apache.iotdb.confignode.persistence.ModelInfo; import org.apache.iotdb.confignode.persistence.ProcedureInfo; import org.apache.iotdb.confignode.persistence.TTLInfo; import 
org.apache.iotdb.confignode.persistence.TriggerInfo; @@ -210,8 +203,6 @@ public class ConfigPlanExecutor { private final CQInfo cqInfo; - private final ModelInfo modelInfo; - private final PipeInfo pipeInfo; private final SubscriptionInfo subscriptionInfo; @@ -230,7 +221,6 @@ public ConfigPlanExecutor( UDFInfo udfInfo, TriggerInfo triggerInfo, CQInfo cqInfo, - ModelInfo modelInfo, PipeInfo pipeInfo, SubscriptionInfo subscriptionInfo, QuotaInfo quotaInfo, @@ -262,9 +252,6 @@ public ConfigPlanExecutor( this.cqInfo = cqInfo; this.snapshotProcessorList.add(cqInfo); - this.modelInfo = modelInfo; - this.snapshotProcessorList.add(modelInfo); - this.pipeInfo = pipeInfo; this.snapshotProcessorList.add(pipeInfo); @@ -362,10 +349,6 @@ public DataSet executeQueryPlan(final ConfigPhysicalReadPlan req) return udfInfo.getUDFJar((GetUDFJarPlan) req); case GetAllFunctionTable: return udfInfo.getAllUDFTable(); - case ShowModel: - return modelInfo.showModel((ShowModelPlan) req); - case GetModelInfo: - return modelInfo.getModelInfo((GetModelInfoPlan) req); case GetPipePluginTable: return pipeInfo.getPipePluginInfo().showPipePlugins(); case GetPipePluginJar: @@ -648,14 +631,6 @@ public TSStatus executeNonQueryPlan(ConfigPhysicalPlan physicalPlan) return cqInfo.activeCQ((ActiveCQPlan) physicalPlan); case UPDATE_CQ_LAST_EXEC_TIME: return cqInfo.updateCQLastExecutionTime((UpdateCQLastExecTimePlan) physicalPlan); - case CreateModel: - return modelInfo.createModel((CreateModelPlan) physicalPlan); - case UpdateModelInfo: - return modelInfo.updateModelInfo((UpdateModelInfoPlan) physicalPlan); - case DropModel: - return modelInfo.dropModel(((DropModelPlan) physicalPlan).getModelName()); - case DropModelInNode: - return modelInfo.dropModelInNode(((DropModelInNodePlan) physicalPlan).getNodeId()); case CreatePipePlugin: return pipeInfo.getPipePluginInfo().createPipePlugin((CreatePipePluginPlan) physicalPlan); case DropPipePlugin: diff --git 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/pipe/PipePluginInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/pipe/PipePluginInfo.java index 280a891b598d7..27cc3cc4cbf5a 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/pipe/PipePluginInfo.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/pipe/PipePluginInfo.java @@ -161,19 +161,19 @@ public boolean isJarNeededToBeSavedWhenCreatingPipePlugin(final String jarName) } public void checkPipePluginExistence( - final Map extractorAttributes, + final Map sourceAttributes, final Map processorAttributes, - final Map connectorAttributes) { - final PipeParameters extractorParameters = new PipeParameters(extractorAttributes); - final String extractorPluginName = - extractorParameters.getStringOrDefault( + final Map sinkAttributes) { + final PipeParameters sourceParameters = new PipeParameters(sourceAttributes); + final String sourcePluginName = + sourceParameters.getStringOrDefault( Arrays.asList(PipeSourceConstant.EXTRACTOR_KEY, PipeSourceConstant.SOURCE_KEY), IOTDB_EXTRACTOR.getPipePluginName()); - if (!pipePluginMetaKeeper.containsPipePlugin(extractorPluginName)) { + if (!pipePluginMetaKeeper.containsPipePlugin(sourcePluginName)) { final String exceptionMessage = String.format( "Failed to create or alter pipe, the pipe extractor plugin %s does not exist", - extractorPluginName); + sourcePluginName); LOGGER.warn(exceptionMessage); throw new PipeException(exceptionMessage); } @@ -191,16 +191,16 @@ public void checkPipePluginExistence( throw new PipeException(exceptionMessage); } - final PipeParameters connectorParameters = new PipeParameters(connectorAttributes); - final String connectorPluginName = - connectorParameters.getStringOrDefault( + final PipeParameters sinkParameters = new PipeParameters(sinkAttributes); + final String sinkPluginName = + sinkParameters.getStringOrDefault( 
Arrays.asList(PipeSinkConstant.CONNECTOR_KEY, PipeSinkConstant.SINK_KEY), IOTDB_THRIFT_CONNECTOR.getPipePluginName()); - if (!pipePluginMetaKeeper.containsPipePlugin(connectorPluginName)) { + if (!pipePluginMetaKeeper.containsPipePlugin(sinkPluginName)) { final String exceptionMessage = String.format( "Failed to create or alter pipe, the pipe connector plugin %s does not exist", - connectorPluginName); + sinkPluginName); LOGGER.warn(exceptionMessage); throw new PipeException(exceptionMessage); } @@ -212,34 +212,29 @@ public TSStatus createPipePlugin(final CreatePipePluginPlan createPipePluginPlan try { final PipePluginMeta pipePluginMeta = createPipePluginPlan.getPipePluginMeta(); final String pluginName = pipePluginMeta.getPluginName(); + final String className = pipePluginMeta.getClassName(); + final String jarName = pipePluginMeta.getJarName(); // try to drop the old pipe plugin if exists to reduce the effect of the inconsistency dropPipePlugin(new DropPipePluginPlan(pluginName)); pipePluginMetaKeeper.addPipePluginMeta(pluginName, pipePluginMeta); - pipePluginMetaKeeper.addJarNameAndMd5( - pipePluginMeta.getJarName(), pipePluginMeta.getJarMD5()); + pipePluginMetaKeeper.addJarNameAndMd5(jarName, pipePluginMeta.getJarMD5()); if (createPipePluginPlan.getJarFile() != null) { pipePluginExecutableManager.savePluginToInstallDir( - ByteBuffer.wrap(createPipePluginPlan.getJarFile().getValues()), - pluginName, - pipePluginMeta.getJarName()); - final String pluginDirPath = pipePluginExecutableManager.getPluginsDirPath(pluginName); - final PipePluginClassLoader pipePluginClassLoader = - classLoaderManager.createPipePluginClassLoader(pluginDirPath); - try { - final Class pluginClass = - Class.forName(pipePluginMeta.getClassName(), true, pipePluginClassLoader); - pipePluginMetaKeeper.addPipePluginVisibility( - pluginName, VisibilityUtils.calculateFromPluginClass(pluginClass)); - classLoaderManager.addPluginAndClassLoader(pluginName, pipePluginClassLoader); - } catch (final 
Exception e) { - try { - pipePluginClassLoader.close(); - } catch (final Exception ignored) { - } - throw e; + ByteBuffer.wrap(createPipePluginPlan.getJarFile().getValues()), pluginName, jarName); + computeFromPluginClass(pluginName, className); + } else { + final String existed = pipePluginMetaKeeper.getPluginNameByJarName(jarName); + if (Objects.nonNull(existed)) { + pipePluginExecutableManager.linkExistedPlugin(existed, pluginName, jarName); + computeFromPluginClass(pluginName, className); + } else { + throw new PipeException( + String.format( + "The %s's creation has not passed in jarName, which does not exist in other pipePlugins. Please check", + pluginName)); } } @@ -255,6 +250,25 @@ public TSStatus createPipePlugin(final CreatePipePluginPlan createPipePluginPlan } } + private void computeFromPluginClass(final String pluginName, final String className) + throws Exception { + final String pluginDirPath = pipePluginExecutableManager.getPluginsDirPath(pluginName); + final PipePluginClassLoader pipePluginClassLoader = + classLoaderManager.createPipePluginClassLoader(pluginDirPath); + try { + final Class pluginClass = Class.forName(className, true, pipePluginClassLoader); + pipePluginMetaKeeper.addPipePluginVisibility( + pluginName, VisibilityUtils.calculateFromPluginClass(pluginClass)); + classLoaderManager.addPluginAndClassLoader(pluginName, pipePluginClassLoader); + } catch (final Exception e) { + try { + pipePluginClassLoader.close(); + } catch (final Exception ignored) { + } + throw e; + } + } + public TSStatus dropPipePlugin(final DropPipePluginPlan dropPipePluginPlan) { final String pluginName = dropPipePluginPlan.getPluginName(); diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java deleted file mode 100644 index 989061610213d..0000000000000 --- 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.procedure.impl.model; - -import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.exception.ainode.LoadModelException; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.commons.model.ModelStatus; -import org.apache.iotdb.commons.model.exception.ModelManagementException; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; -import org.apache.iotdb.confignode.manager.ConfigManager; -import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; -import org.apache.iotdb.confignode.procedure.exception.ProcedureException; -import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure; -import org.apache.iotdb.confignode.procedure.state.model.CreateModelState; -import 
org.apache.iotdb.confignode.procedure.store.ProcedureType; -import org.apache.iotdb.consensus.exception.ConsensusException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.tsfile.utils.ReadWriteIOUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; - -public class CreateModelProcedure extends AbstractNodeProcedure { - - private static final Logger LOGGER = LoggerFactory.getLogger(CreateModelProcedure.class); - private static final int RETRY_THRESHOLD = 0; - - private String modelName; - - private String uri; - - private ModelInformation modelInformation = null; - - private List aiNodeIds; - - private String loadErrorMsg = ""; - - public CreateModelProcedure() { - super(); - } - - public CreateModelProcedure(String modelName, String uri) { - super(); - this.modelName = modelName; - this.uri = uri; - this.aiNodeIds = new ArrayList<>(); - } - - @Override - protected Flow executeFromState(ConfigNodeProcedureEnv env, CreateModelState state) { - if (modelName == null || uri == null) { - return Flow.NO_MORE_STATE; - } - try { - switch (state) { - case LOADING: - initModel(env); - loadModel(env); - setNextState(CreateModelState.ACTIVE); - break; - case ACTIVE: - modelInformation.updateStatus(ModelStatus.ACTIVE); - updateModel(env); - return Flow.NO_MORE_STATE; - default: - throw new UnsupportedOperationException( - String.format("Unknown state during executing createModelProcedure, %s", state)); - } - } catch (Exception e) { - if (isRollbackSupported(state)) { - LOGGER.error("Fail in CreateModelProcedure", e); - setFailure(new ProcedureException(e.getMessage())); - } else { - LOGGER.error( - "Retrievable error trying to create model [{}], 
state [{}]", modelName, state, e); - if (getCycles() > RETRY_THRESHOLD) { - modelInformation = new ModelInformation(modelName, ModelStatus.UNAVAILABLE); - modelInformation.setAttribute(loadErrorMsg); - updateModel(env); - setFailure( - new ProcedureException( - String.format("Fail to create model [%s] at STATE [%s]", modelName, state))); - } - } - } - return Flow.HAS_MORE_STATE; - } - - private void initModel(ConfigNodeProcedureEnv env) throws ConsensusException { - LOGGER.info("Start to add model [{}]", modelName); - - ConfigManager configManager = env.getConfigManager(); - TSStatus response = configManager.getConsensusManager().write(new CreateModelPlan(modelName)); - if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new ModelManagementException( - String.format( - "Failed to add model [%s] in ModelTable on Config Nodes: %s", - modelName, response.getMessage())); - } - } - - private void checkModelInformationEquals(ModelInformation receiveModelInfo) { - if (modelInformation == null) { - modelInformation = receiveModelInfo; - } else { - if (!modelInformation.equals(receiveModelInfo)) { - throw new ModelManagementException( - String.format( - "Failed to load model [%s] on AI Nodes, model information is not equal in different nodes", - modelName)); - } - } - } - - private void loadModel(ConfigNodeProcedureEnv env) { - for (TAINodeConfiguration curNodeConfig : - env.getConfigManager().getNodeManager().getRegisteredAINodes()) { - try (AINodeClient client = - AINodeClientManager.getInstance() - .borrowClient(curNodeConfig.getLocation().getInternalEndPoint())) { - ModelInformation resp = client.registerModel(modelName, uri); - checkModelInformationEquals(resp); - aiNodeIds.add(curNodeConfig.getLocation().aiNodeId); - } catch (LoadModelException e) { - LOGGER.warn(e.getMessage()); - loadErrorMsg = e.getMessage(); - } catch (Exception e) { - LOGGER.warn( - "Failed to load model on AINode {} from ConfigNode", - 
curNodeConfig.getLocation().getInternalEndPoint()); - loadErrorMsg = e.getMessage(); - } - } - - if (aiNodeIds.isEmpty()) { - throw new ModelManagementException( - String.format("CREATE MODEL [%s] failed on all AINodes:[%s]", modelName, loadErrorMsg)); - } - } - - private void updateModel(ConfigNodeProcedureEnv env) { - LOGGER.info("Start to update model [{}]", modelName); - - ConfigManager configManager = env.getConfigManager(); - try { - TSStatus response = - configManager - .getConsensusManager() - .write(new UpdateModelInfoPlan(modelName, modelInformation, aiNodeIds)); - if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new ModelManagementException( - String.format( - "Failed to update model [%s] in ModelTable on Config Nodes: %s", - modelName, response.getMessage())); - } - } catch (Exception e) { - throw new ModelManagementException( - String.format( - "Failed to update model [%s] in ModelTable on Config Nodes: %s", - modelName, e.getMessage())); - } - } - - @Override - protected void rollbackState(ConfigNodeProcedureEnv env, CreateModelState state) - throws IOException, InterruptedException, ProcedureException { - // do nothing - } - - @Override - protected boolean isRollbackSupported(CreateModelState state) { - return false; - } - - @Override - protected CreateModelState getState(int stateId) { - return CreateModelState.values()[stateId]; - } - - @Override - protected int getStateId(CreateModelState createModelState) { - return createModelState.ordinal(); - } - - @Override - protected CreateModelState getInitialState() { - return CreateModelState.LOADING; - } - - @Override - public void serialize(DataOutputStream stream) throws IOException { - stream.writeShort(ProcedureType.CREATE_MODEL_PROCEDURE.getTypeCode()); - super.serialize(stream); - ReadWriteIOUtils.write(modelName, stream); - ReadWriteIOUtils.write(uri, stream); - } - - @Override - public void deserialize(ByteBuffer byteBuffer) { - super.deserialize(byteBuffer); - 
modelName = ReadWriteIOUtils.readString(byteBuffer); - uri = ReadWriteIOUtils.readString(byteBuffer); - } - - @Override - public boolean equals(Object that) { - if (that instanceof CreateModelProcedure) { - CreateModelProcedure thatProc = (CreateModelProcedure) that; - return thatProc.getProcId() == this.getProcId() - && thatProc.getState() == this.getState() - && Objects.equals(thatProc.modelName, this.modelName) - && Objects.equals(thatProc.uri, this.uri); - } - return false; - } - - @Override - public int hashCode() { - return Objects.hash(getProcId(), getState(), modelName, uri); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java deleted file mode 100644 index daa029e04ddfd..0000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.procedure.impl.model; - -import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; -import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.model.exception.ModelManagementException; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; -import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; -import org.apache.iotdb.confignode.procedure.exception.ProcedureException; -import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure; -import org.apache.iotdb.confignode.procedure.state.model.DropModelState; -import org.apache.iotdb.confignode.procedure.store.ProcedureType; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.thrift.TException; -import org.apache.tsfile.utils.ReadWriteIOUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.List; -import java.util.Objects; - -import static org.apache.iotdb.confignode.procedure.state.model.DropModelState.CONFIG_NODE_DROPPED; - -public class DropModelProcedure extends AbstractNodeProcedure { - - private static final Logger LOGGER = LoggerFactory.getLogger(DropModelProcedure.class); - private static final int RETRY_THRESHOLD = 1; - - private String modelName; - - public DropModelProcedure() { - super(); - } - - public DropModelProcedure(String modelName) { - super(); - this.modelName = modelName; - } - - @Override - protected Flow executeFromState(ConfigNodeProcedureEnv env, DropModelState state) { - if (modelName == null) { - return Flow.NO_MORE_STATE; - } - try { - switch (state) { - case AI_NODE_DROPPED: - LOGGER.info("Start to drop model [{}] on AI Nodes", 
modelName); - dropModelOnAINode(env); - setNextState(CONFIG_NODE_DROPPED); - break; - case CONFIG_NODE_DROPPED: - dropModelOnConfigNode(env); - return Flow.NO_MORE_STATE; - default: - throw new UnsupportedOperationException( - String.format("Unknown state during executing dropModelProcedure, %s", state)); - } - } catch (Exception e) { - if (isRollbackSupported(state)) { - LOGGER.error("Fail in DropModelProcedure", e); - setFailure(new ProcedureException(e.getMessage())); - } else { - LOGGER.error( - "Retrievable error trying to drop model [{}], state [{}]", modelName, state, e); - if (getCycles() > RETRY_THRESHOLD) { - setFailure( - new ProcedureException( - String.format( - "Fail to drop model [%s] at STATE [%s], %s", - modelName, state, e.getMessage()))); - } - } - } - return Flow.HAS_MORE_STATE; - } - - private void dropModelOnAINode(ConfigNodeProcedureEnv env) { - LOGGER.info("Start to drop model file [{}] on AI Node", modelName); - - List aiNodes = - env.getConfigManager().getNodeManager().getRegisteredAINodes(); - aiNodes.forEach( - aiNode -> { - int nodeId = aiNode.getLocation().getAiNodeId(); - try (AINodeClient client = - AINodeClientManager.getInstance() - .borrowClient( - env.getConfigManager() - .getNodeManager() - .getRegisteredAINode(nodeId) - .getLocation() - .getInternalEndPoint())) { - TSStatus status = client.deleteModel(new TDeleteModelReq(modelName)); - if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - LOGGER.warn( - "Failed to drop model [{}] on AINode [{}], status: {}", - modelName, - nodeId, - status.getMessage()); - } - } catch (Exception e) { - LOGGER.warn( - "Failed to drop model [{}] on AINode [{}], status: {}", - modelName, - nodeId, - e.getMessage()); - } - }); - } - - private void dropModelOnConfigNode(ConfigNodeProcedureEnv env) { - try { - TSStatus response = - env.getConfigManager().getConsensusManager().write(new DropModelPlan(modelName)); - if (response.getCode() != 
TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new TException(response.getMessage()); - } - } catch (Exception e) { - throw new ModelManagementException( - String.format( - "Fail to start training model [%s] on AI Node: %s", modelName, e.getMessage())); - } - } - - @Override - protected void rollbackState(ConfigNodeProcedureEnv env, DropModelState state) - throws IOException, InterruptedException, ProcedureException { - // no need to rollback - } - - @Override - protected DropModelState getState(int stateId) { - return DropModelState.values()[stateId]; - } - - @Override - protected int getStateId(DropModelState dropModelState) { - return dropModelState.ordinal(); - } - - @Override - protected DropModelState getInitialState() { - return DropModelState.AI_NODE_DROPPED; - } - - @Override - public void serialize(DataOutputStream stream) throws IOException { - stream.writeShort(ProcedureType.DROP_MODEL_PROCEDURE.getTypeCode()); - super.serialize(stream); - ReadWriteIOUtils.write(modelName, stream); - } - - @Override - public void deserialize(ByteBuffer byteBuffer) { - super.deserialize(byteBuffer); - modelName = ReadWriteIOUtils.readString(byteBuffer); - } - - @Override - public boolean equals(Object that) { - if (that instanceof DropModelProcedure) { - DropModelProcedure thatProc = (DropModelProcedure) that; - return thatProc.getProcId() == this.getProcId() - && thatProc.getState() == this.getState() - && (thatProc.modelName).equals(this.modelName); - } - return false; - } - - @Override - public int hashCode() { - return Objects.hash(getProcId(), getState(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java index 2cab08c28244e..2a1c6881b1413 100644 --- 
a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java @@ -23,13 +23,12 @@ import org.apache.iotdb.common.rpc.thrift.TSStatus; import org.apache.iotdb.commons.utils.ThriftCommonsSerDeUtils; import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan; import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; import org.apache.iotdb.confignode.procedure.exception.ProcedureException; import org.apache.iotdb.confignode.procedure.state.RemoveAINodeState; import org.apache.iotdb.confignode.procedure.store.ProcedureType; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.rpc.TSStatusCode; import org.slf4j.Logger; @@ -65,17 +64,11 @@ protected Flow executeFromState(ConfigNodeProcedureEnv env, RemoveAINodeState st try { switch (state) { - case MODEL_DELETE: - env.getConfigManager() - .getConsensusManager() - .write(new DropModelInNodePlan(removedAINode.aiNodeId)); - // Cause the AINode is removed, so we don't need to remove the model file. 
- setNextState(RemoveAINodeState.NODE_STOP); - break; case NODE_STOP: TSStatus resp = null; try (AINodeClient client = - AINodeClientManager.getInstance().borrowClient(removedAINode.getInternalEndPoint())) { + AINodeClientManager.getInstance() + .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { resp = client.stopAINode(); } catch (Exception e) { LOGGER.warn( @@ -148,7 +141,7 @@ protected int getStateId(RemoveAINodeState removeAINodeState) { @Override protected RemoveAINodeState getInitialState() { - return RemoveAINodeState.MODEL_DELETE; + return RemoveAINodeState.NODE_STOP; } @Override diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/pipe/task/AlterPipeProcedureV2.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/pipe/task/AlterPipeProcedureV2.java index f6b84cb1f6e3a..53f908bf4fc11 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/pipe/task/AlterPipeProcedureV2.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/pipe/task/AlterPipeProcedureV2.java @@ -187,6 +187,7 @@ public void executeFromCalculateInfoForTask(final ConfigNodeProcedureEnv env) { && !databaseName.startsWith(SchemaConstant.SYSTEM_DATABASE + ".") && !databaseName.equals(SchemaConstant.AUDIT_DATABASE) && !databaseName.startsWith(SchemaConstant.AUDIT_DATABASE + ".") + && !Objects.isNull(currentPipeTaskMeta) && !(PipeTaskAgent.isHistoryOnlyPipe( currentPipeStaticMeta.getSourceParameters()) && PipeTaskAgent.isHistoryOnlyPipe( diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java index 8a1a6a1bb03b5..49820df663616 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java +++ 
b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java @@ -20,7 +20,6 @@ package org.apache.iotdb.confignode.procedure.state; public enum RemoveAINodeState { - MODEL_DELETE, NODE_STOP, NODE_REMOVE } diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java index e023171f4fa88..f20a6999d5936 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java @@ -22,8 +22,6 @@ import org.apache.iotdb.commons.exception.runtime.ThriftSerDeException; import org.apache.iotdb.confignode.procedure.Procedure; import org.apache.iotdb.confignode.procedure.impl.cq.CreateCQProcedure; -import org.apache.iotdb.confignode.procedure.impl.model.CreateModelProcedure; -import org.apache.iotdb.confignode.procedure.impl.model.DropModelProcedure; import org.apache.iotdb.confignode.procedure.impl.node.AddConfigNodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveAINodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveConfigNodeProcedure; @@ -263,12 +261,6 @@ public Procedure create(ByteBuffer buffer) throws IOException { case DROP_PIPE_PLUGIN_PROCEDURE: procedure = new DropPipePluginProcedure(); break; - case CREATE_MODEL_PROCEDURE: - procedure = new CreateModelProcedure(); - break; - case DROP_MODEL_PROCEDURE: - procedure = new DropModelProcedure(); - break; case AUTH_OPERATE_PROCEDURE: procedure = new AuthOperationProcedure(false); break; @@ -494,10 +486,6 @@ public static ProcedureType getProcedureType(final Procedure procedure) { return ProcedureType.CREATE_PIPE_PLUGIN_PROCEDURE; } else if (procedure instanceof DropPipePluginProcedure) { return 
ProcedureType.DROP_PIPE_PLUGIN_PROCEDURE; - } else if (procedure instanceof CreateModelProcedure) { - return ProcedureType.CREATE_MODEL_PROCEDURE; - } else if (procedure instanceof DropModelProcedure) { - return ProcedureType.DROP_MODEL_PROCEDURE; } else if (procedure instanceof CreatePipeProcedureV2) { return ProcedureType.CREATE_PIPE_PROCEDURE_V2; } else if (procedure instanceof StartPipeProcedureV2) { diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java index 65ac1fb24ad5a..d076a7d9d926e 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java @@ -85,7 +85,9 @@ public enum ProcedureType { RENAME_VIEW_PROCEDURE((short) 764), /** AI Model */ + @Deprecated // Since 2.0.6, all models are managed by AINode CREATE_MODEL_PROCEDURE((short) 800), + @Deprecated // Since 2.0.6, all models are managed by AINode DROP_MODEL_PROCEDURE((short) 801), REMOVE_AI_NODE_PROCEDURE((short) 802), diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java index 59ce7352312f0..6582a5bfff8e1 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java @@ -115,7 +115,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import 
org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -144,7 +143,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -163,8 +161,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -226,7 +222,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp; import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; import org.apache.iotdb.confignode.service.ConfigNode; import org.apache.iotdb.consensus.exception.ConsensusException; import org.apache.iotdb.db.queryengine.plan.relational.type.AuthorRType; @@ -1362,26 +1357,6 @@ public TShowCQResp showCQ() { return configManager.showCQ(); } - @Override - public TSStatus createModel(TCreateModelReq req) { - return configManager.createModel(req); - } - - @Override - public TSStatus 
dropModel(TDropModelReq req) { - return configManager.dropModel(req); - } - - @Override - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { - return configManager.getModelInfo(req); - } - - @Override - public TSStatus updateModelInfo(TUpdateModelInfoReq req) throws TException { - return configManager.updateModelInfo(req); - } - @Override public TSStatus setSpaceQuota(final TSetSpaceQuotaReq req) throws TException { return configManager.setSpaceQuota(req); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java index 375285dfa2fb9..3f19a2101a805 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java @@ -814,6 +814,9 @@ public class IoTDBConfig { /** time cost(ms) threshold for slow query. Unit: millisecond */ private long slowQueryThreshold = 10000; + /** time window threshold for record of history queries. Unit: minute */ + private int queryCostStatWindow = 0; + private int patternMatchingThreshold = 1000000; /** @@ -1121,6 +1124,21 @@ public class IoTDBConfig { private int loadTsFileSpiltPartitionMaxSize = 10; + /** + * The threshold for splitting statement when loading multiple TsFiles. When the number of TsFiles + * exceeds this threshold, the statement will be split into multiple sub-statements for batch + * execution to limit resource consumption during statement analysis. Default value is 10, which + * means splitting will occur when there are more than 10 files. + */ + private int loadTsFileStatementSplitThreshold = 10; + + /** + * The number of TsFiles that each sub-statement handles when splitting a statement. This + * parameter controls how many files are grouped together in each sub-statement during batch + * execution. Default value is 10, which means each sub-statement handles 10 files. 
+ */ + private int loadTsFileSubStatementBatchSize = 10; + private String[] loadActiveListeningDirs = new String[] { IoTDBConstant.EXT_FOLDER_NAME @@ -2613,6 +2631,14 @@ public void setSlowQueryThreshold(long slowQueryThreshold) { this.slowQueryThreshold = slowQueryThreshold; } + public int getQueryCostStatWindow() { + return queryCostStatWindow; + } + + public void setQueryCostStatWindow(int queryCostStatWindow) { + this.queryCostStatWindow = queryCostStatWindow; + } + public boolean isEnableIndex() { return enableIndex; } @@ -4057,6 +4083,46 @@ public void setLoadTsFileSpiltPartitionMaxSize(int loadTsFileSpiltPartitionMaxSi this.loadTsFileSpiltPartitionMaxSize = loadTsFileSpiltPartitionMaxSize; } + public int getLoadTsFileStatementSplitThreshold() { + return loadTsFileStatementSplitThreshold; + } + + public void setLoadTsFileStatementSplitThreshold(final int loadTsFileStatementSplitThreshold) { + if (loadTsFileStatementSplitThreshold < 0) { + logger.warn( + "Invalid loadTsFileStatementSplitThreshold value: {}. Using default value: 10", + loadTsFileStatementSplitThreshold); + return; + } + if (this.loadTsFileStatementSplitThreshold != loadTsFileStatementSplitThreshold) { + logger.info( + "loadTsFileStatementSplitThreshold changed from {} to {}", + this.loadTsFileStatementSplitThreshold, + loadTsFileStatementSplitThreshold); + } + this.loadTsFileStatementSplitThreshold = loadTsFileStatementSplitThreshold; + } + + public int getLoadTsFileSubStatementBatchSize() { + return loadTsFileSubStatementBatchSize; + } + + public void setLoadTsFileSubStatementBatchSize(final int loadTsFileSubStatementBatchSize) { + if (loadTsFileSubStatementBatchSize <= 0) { + logger.warn( + "Invalid loadTsFileSubStatementBatchSize value: {}. 
Using default value: 10", + loadTsFileSubStatementBatchSize); + return; + } + if (this.loadTsFileSubStatementBatchSize != loadTsFileSubStatementBatchSize) { + logger.info( + "loadTsFileSubStatementBatchSize changed from {} to {}", + this.loadTsFileSubStatementBatchSize, + loadTsFileSubStatementBatchSize); + } + this.loadTsFileSubStatementBatchSize = loadTsFileSubStatementBatchSize; + } + public String[] getPipeReceiverFileDirs() { return (Objects.isNull(this.pipeReceiverFileDirs) || this.pipeReceiverFileDirs.length == 0) ? new String[] {systemDir + File.separator + "pipe" + File.separator + "receiver"} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java index c194f304e7f6a..e51f445c56784 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java @@ -813,6 +813,11 @@ public void loadProperties(TrimProperties properties) throws BadNodeUrlException properties.getProperty( "slow_query_threshold", String.valueOf(conf.getSlowQueryThreshold())))); + conf.setQueryCostStatWindow( + Integer.parseInt( + properties.getProperty( + "query_cost_stat_window", String.valueOf(conf.getQueryCostStatWindow())))); + conf.setDataRegionNum( Integer.parseInt( properties.getProperty("data_region_num", String.valueOf(conf.getDataRegionNum())))); @@ -2406,6 +2411,18 @@ private void loadLoadTsFileProps(TrimProperties properties) { properties.getProperty( "skip_failed_table_schema_check", String.valueOf(conf.isSkipFailedTableSchemaCheck())))); + + conf.setLoadTsFileStatementSplitThreshold( + Integer.parseInt( + properties.getProperty( + "load_tsfile_statement_split_threshold", + Integer.toString(conf.getLoadTsFileStatementSplitThreshold())))); + + conf.setLoadTsFileSubStatementBatchSize( + Integer.parseInt( + properties.getProperty( + 
"load_tsfile_sub_statement_batch_size", + Integer.toString(conf.getLoadTsFileSubStatementBatchSize())))); } private void loadLoadTsFileHotModifiedProp(TrimProperties properties) throws IOException { @@ -2454,6 +2471,18 @@ private void loadLoadTsFileHotModifiedProp(TrimProperties properties) throws IOE "load_tsfile_split_partition_max_size", Integer.toString(conf.getLoadTsFileSpiltPartitionMaxSize())))); + conf.setLoadTsFileStatementSplitThreshold( + Integer.parseInt( + properties.getProperty( + "load_tsfile_statement_split_threshold", + Integer.toString(conf.getLoadTsFileStatementSplitThreshold())))); + + conf.setLoadTsFileSubStatementBatchSize( + Integer.parseInt( + properties.getProperty( + "load_tsfile_sub_statement_batch_size", + Integer.toString(conf.getLoadTsFileSubStatementBatchSize())))); + conf.setSkipFailedTableSchemaCheck( Boolean.parseBoolean( properties.getProperty( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletBatchReqV2.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletBatchReqV2.java index d9c3fabfae8ac..c136ffbe7d3e3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletBatchReqV2.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletBatchReqV2.java @@ -33,7 +33,6 @@ import org.apache.tsfile.utils.PublicBAOS; import org.apache.tsfile.utils.ReadWriteIOUtils; -import org.apache.tsfile.write.record.Tablet; import java.io.DataOutputStream; import java.io.IOException; @@ -247,11 +246,7 @@ public static PipeTransferTabletBatchReqV2 fromTPipeTransferReq( size = ReadWriteIOUtils.readInt(transferReq.body); for (int i = 0; i < size; ++i) { - batchReq.tabletReqs.add( - PipeTransferTabletRawReqV2.toTPipeTransferRawReq( - Tablet.deserialize(transferReq.body), - 
ReadWriteIOUtils.readBool(transferReq.body), - ReadWriteIOUtils.readString(transferReq.body))); + batchReq.tabletReqs.add(PipeTransferTabletRawReqV2.toTPipeTransferRawReq(transferReq.body)); } batchReq.version = transferReq.version; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReq.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReq.java index 7da44f297d140..af9b37edbf6ff 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReq.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReq.java @@ -22,6 +22,7 @@ import org.apache.iotdb.commons.exception.MetadataException; import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.IoTDBSinkRequestVersion; import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.PipeRequestType; +import org.apache.iotdb.db.pipe.sink.util.TabletStatementConverter; import org.apache.iotdb.db.pipe.sink.util.sorter.PipeTreeModelTabletEventSorter; import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; import org.apache.iotdb.service.rpc.thrift.TPipeTransferReq; @@ -43,10 +44,25 @@ public class PipeTransferTabletRawReq extends TPipeTransferReq { private static final Logger LOGGER = LoggerFactory.getLogger(PipeTransferTabletRawReq.class); - protected transient Tablet tablet; + protected transient InsertTabletStatement statement; + protected transient boolean isAligned; + protected transient Tablet tablet; + /** + * Get Tablet. If tablet is null, convert from statement. 
+ * + * @return Tablet object + */ public Tablet getTablet() { + if (tablet == null && statement != null) { + try { + tablet = statement.convertToTablet(); + } catch (final MetadataException e) { + LOGGER.warn("Failed to convert statement to tablet.", e); + return null; + } + } return tablet; } @@ -54,16 +70,29 @@ public boolean getIsAligned() { return isAligned; } + /** + * Construct Statement. If statement already exists, return it. Otherwise, convert from tablet. + * + * @return InsertTabletStatement + */ public InsertTabletStatement constructStatement() { + if (statement != null) { + new PipeTreeModelTabletEventSorter(statement).deduplicateAndSortTimestampsIfNecessary(); + return statement; + } + + // Sort and deduplicate tablet before converting new PipeTreeModelTabletEventSorter(tablet).deduplicateAndSortTimestampsIfNecessary(); try { if (isTabletEmpty(tablet)) { // Empty statement, will be filtered after construction - return new InsertTabletStatement(); + statement = new InsertTabletStatement(); + return statement; } - return new InsertTabletStatement(tablet, isAligned, null); + statement = new InsertTabletStatement(tablet, isAligned, null); + return statement; } catch (final MetadataException e) { LOGGER.warn("Generate Statement from tablet {} error.", tablet, e); return null; @@ -107,8 +136,20 @@ public static PipeTransferTabletRawReq toTPipeTransferReq( public static PipeTransferTabletRawReq fromTPipeTransferReq(final TPipeTransferReq transferReq) { final PipeTransferTabletRawReq tabletReq = new PipeTransferTabletRawReq(); - tabletReq.tablet = Tablet.deserialize(transferReq.body); - tabletReq.isAligned = ReadWriteIOUtils.readBool(transferReq.body); + final ByteBuffer buffer = transferReq.body; + final int startPosition = buffer.position(); + try { + // V1: no databaseName, readDatabaseName = false + final InsertTabletStatement insertTabletStatement = + TabletStatementConverter.deserializeStatementFromTabletFormat(buffer, false); + tabletReq.isAligned = 
insertTabletStatement.isAligned(); + // devicePath is already set in deserializeStatementFromTabletFormat for V1 format + tabletReq.statement = insertTabletStatement; + } catch (final Exception e) { + buffer.position(startPosition); + tabletReq.tablet = Tablet.deserialize(buffer); + tabletReq.isAligned = ReadWriteIOUtils.readBool(buffer); + } tabletReq.version = transferReq.version; tabletReq.type = transferReq.type; @@ -118,18 +159,56 @@ public static PipeTransferTabletRawReq fromTPipeTransferReq(final TPipeTransferR /////////////////////////////// Air Gap /////////////////////////////// - public static byte[] toTPipeTransferBytes(final Tablet tablet, final boolean isAligned) - throws IOException { + /** + * Serialize to bytes. If tablet is null, convert from statement first. + * + * @return serialized bytes + * @throws IOException if serialization fails + */ + public byte[] toTPipeTransferBytes() throws IOException { + Tablet tabletToSerialize = tablet; + boolean isAlignedToSerialize = isAligned; + + // If tablet is null, convert from statement + if (tabletToSerialize == null && statement != null) { + try { + tabletToSerialize = statement.convertToTablet(); + isAlignedToSerialize = statement.isAligned(); + } catch (final MetadataException e) { + throw new IOException("Failed to convert statement to tablet for serialization", e); + } + } + + if (tabletToSerialize == null) { + throw new IOException("Cannot serialize: both tablet and statement are null"); + } + try (final PublicBAOS byteArrayOutputStream = new PublicBAOS(); final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream)) { ReadWriteIOUtils.write(IoTDBSinkRequestVersion.VERSION_1.getVersion(), outputStream); ReadWriteIOUtils.write(PipeRequestType.TRANSFER_TABLET_RAW.getType(), outputStream); - tablet.serialize(outputStream); - ReadWriteIOUtils.write(isAligned, outputStream); + tabletToSerialize.serialize(outputStream); + ReadWriteIOUtils.write(isAlignedToSerialize, outputStream); 
return byteArrayOutputStream.toByteArray(); } } + /** + * Static method for backward compatibility. Creates a temporary instance and serializes. + * + * @param tablet Tablet to serialize + * @param isAligned whether aligned + * @return serialized bytes + * @throws IOException if serialization fails + */ + public static byte[] toTPipeTransferBytes(final Tablet tablet, final boolean isAligned) + throws IOException { + final PipeTransferTabletRawReq req = new PipeTransferTabletRawReq(); + req.tablet = tablet; + req.isAligned = isAligned; + return req.toTPipeTransferBytes(); + } + /////////////////////////////// Object /////////////////////////////// @Override @@ -141,7 +220,16 @@ public boolean equals(final Object obj) { return false; } final PipeTransferTabletRawReq that = (PipeTransferTabletRawReq) obj; - return Objects.equals(tablet, that.tablet) + // Compare statement if both have it, otherwise compare tablet + if (statement != null && that.statement != null) { + return Objects.equals(statement, that.statement) + && isAligned == that.isAligned + && version == that.version + && type == that.type + && Objects.equals(body, that.body); + } + // Fallback to tablet comparison + return Objects.equals(getTablet(), that.getTablet()) && isAligned == that.isAligned && version == that.version && type == that.type @@ -150,6 +238,6 @@ public boolean equals(final Object obj) { @Override public int hashCode() { - return Objects.hash(tablet, isAligned, version, type, body); + return Objects.hash(getTablet(), isAligned, version, type, body); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReqV2.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReqV2.java index 43d8501252c5a..3c5f420a317fd 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReqV2.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReqV2.java @@ -22,6 +22,7 @@ import org.apache.iotdb.commons.exception.MetadataException; import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.IoTDBSinkRequestVersion; import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.PipeRequestType; +import org.apache.iotdb.db.pipe.sink.util.TabletStatementConverter; import org.apache.iotdb.db.pipe.sink.util.sorter.PipeTableModelTabletEventSorter; import org.apache.iotdb.db.pipe.sink.util.sorter.PipeTreeModelTabletEventSorter; import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; @@ -52,6 +53,16 @@ public String getDataBaseName() { @Override public InsertTabletStatement constructStatement() { + if (statement != null) { + if (Objects.isNull(dataBaseName)) { + new PipeTreeModelTabletEventSorter(statement).deduplicateAndSortTimestampsIfNecessary(); + } else { + new PipeTableModelTabletEventSorter(statement).sortByTimestampIfNecessary(); + } + + return statement; + } + if (Objects.isNull(dataBaseName)) { new PipeTreeModelTabletEventSorter(tablet).deduplicateAndSortTimestampsIfNecessary(); } else { @@ -86,6 +97,16 @@ public static PipeTransferTabletRawReqV2 toTPipeTransferRawReq( return tabletReq; } + public static PipeTransferTabletRawReqV2 toTPipeTransferRawReq(final ByteBuffer buffer) { + final PipeTransferTabletRawReqV2 tabletReq = new PipeTransferTabletRawReqV2(); + + tabletReq.deserializeTPipeTransferRawReq(buffer); + tabletReq.version = IoTDBSinkRequestVersion.VERSION_1.getVersion(); + tabletReq.type = PipeRequestType.TRANSFER_TABLET_RAW_V2.getType(); + + return tabletReq; + } + /////////////////////////////// Thrift /////////////////////////////// public static PipeTransferTabletRawReqV2 toTPipeTransferReq( @@ -114,13 +135,11 @@ public static PipeTransferTabletRawReqV2 fromTPipeTransferReq( final TPipeTransferReq transferReq) { final PipeTransferTabletRawReqV2 
tabletReq = new PipeTransferTabletRawReqV2(); - tabletReq.tablet = Tablet.deserialize(transferReq.body); - tabletReq.isAligned = ReadWriteIOUtils.readBool(transferReq.body); - tabletReq.dataBaseName = ReadWriteIOUtils.readString(transferReq.body); + tabletReq.deserializeTPipeTransferRawReq(transferReq.body); + tabletReq.body = transferReq.body; tabletReq.version = transferReq.version; tabletReq.type = transferReq.type; - tabletReq.body = transferReq.body; return tabletReq; } @@ -161,4 +180,27 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(super.hashCode(), dataBaseName); } + + /////////////////////////////// Util /////////////////////////////// + + public void deserializeTPipeTransferRawReq(final ByteBuffer buffer) { + final int startPosition = buffer.position(); + try { + // V2: read databaseName, readDatabaseName = true + final InsertTabletStatement insertTabletStatement = + TabletStatementConverter.deserializeStatementFromTabletFormat(buffer, true); + this.isAligned = insertTabletStatement.isAligned(); + // databaseName is already set in deserializeStatementFromTabletFormat when + // readDatabaseName=true + this.dataBaseName = insertTabletStatement.getDatabaseName().orElse(null); + this.statement = insertTabletStatement; + } catch (final Exception e) { + // If Statement deserialization fails, fallback to Tablet format + // Reset buffer position for Tablet deserialization + buffer.position(startPosition); + this.tablet = Tablet.deserialize(buffer); + this.isAligned = ReadWriteIOUtils.readBool(buffer); + this.dataBaseName = ReadWriteIOUtils.readString(buffer); + } + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverter.java new file mode 100644 index 0000000000000..c5b9ebed4d5f2 --- /dev/null +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverter.java @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.pipe.sink.util; + +import org.apache.iotdb.commons.exception.IllegalPathException; +import org.apache.iotdb.commons.path.PartialPath; +import org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory; +import org.apache.iotdb.db.pipe.resource.memory.InsertNodeMemoryEstimator; +import org.apache.iotdb.db.queryengine.plan.analyze.cache.schema.DataNodeDevicePathCache; +import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; + +import org.apache.tsfile.enums.ColumnCategory; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.BitMap; +import org.apache.tsfile.utils.BytesUtils; +import org.apache.tsfile.utils.Pair; +import org.apache.tsfile.utils.ReadWriteIOUtils; +import org.apache.tsfile.write.UnSupportedDataTypeException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.Arrays; + +/** + * Utility class for converting between InsertTabletStatement and Tablet 
format ByteBuffer. This + * avoids creating intermediate Tablet objects and directly converts between formats with only the + * fields needed. + */ +public class TabletStatementConverter { + + private static final Logger LOGGER = LoggerFactory.getLogger(TabletStatementConverter.class); + + // Memory calculation constants - extracted from RamUsageEstimator for better performance + private static final long NUM_BYTES_ARRAY_HEADER = + org.apache.tsfile.utils.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; + private static final long NUM_BYTES_OBJECT_REF = + org.apache.tsfile.utils.RamUsageEstimator.NUM_BYTES_OBJECT_REF; + private static final long NUM_BYTES_OBJECT_HEADER = + org.apache.tsfile.utils.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; + private static final long SIZE_OF_ARRAYLIST = + org.apache.tsfile.utils.RamUsageEstimator.shallowSizeOfInstance(java.util.ArrayList.class); + private static final long SIZE_OF_BITMAP = + org.apache.tsfile.utils.RamUsageEstimator.shallowSizeOfInstance( + org.apache.tsfile.utils.BitMap.class); + + private TabletStatementConverter() { + // Utility class, no instantiation + } + + /** + * Deserialize InsertTabletStatement from Tablet format ByteBuffer. 
+ * + * @param byteBuffer ByteBuffer containing serialized data + * @param readDatabaseName whether to read databaseName from buffer (for V2 format) + * @return InsertTabletStatement with all fields set, including devicePath + */ + public static InsertTabletStatement deserializeStatementFromTabletFormat( + final ByteBuffer byteBuffer, final boolean readDatabaseName) throws IllegalPathException { + final InsertTabletStatement statement = new InsertTabletStatement(); + + // Calculate memory size during deserialization, use INSTANCE_SIZE constant + long memorySize = InsertTabletStatement.getInstanceSize(); + + final String insertTargetName = ReadWriteIOUtils.readString(byteBuffer); + + final int rowSize = ReadWriteIOUtils.readInt(byteBuffer); + + // deserialize schemas + final int schemaSize = + BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)) + ? ReadWriteIOUtils.readInt(byteBuffer) + : 0; + final String[] measurement = new String[schemaSize]; + final TsTableColumnCategory[] columnCategories = new TsTableColumnCategory[schemaSize]; + final TSDataType[] dataTypes = new TSDataType[schemaSize]; + + // Calculate memory for arrays headers and references during deserialization + // measurements array: array header + object references + long measurementMemorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize); + + // dataTypes array: shallow size (array header + references) + long dataTypesMemorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize); + + // columnCategories array: shallow size (array header + references) + long columnCategoriesMemorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize); + + // tagColumnIndices (TAG columns): ArrayList base + array header + long tagColumnIndicesSize = SIZE_OF_ARRAYLIST; + 
tagColumnIndicesSize += NUM_BYTES_ARRAY_HEADER; + + // Deserialize and calculate memory in the same loop + for (int i = 0; i < schemaSize; i++) { + final boolean hasSchema = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (hasSchema) { + final Pair pair = readMeasurement(byteBuffer); + measurement[i] = pair.getLeft(); + dataTypes[i] = pair.getRight(); + columnCategories[i] = + TsTableColumnCategory.fromTsFileColumnCategory( + ColumnCategory.values()[byteBuffer.get()]); + + // Calculate memory for each measurement string + if (measurement[i] != null) { + measurementMemorySize += org.apache.tsfile.utils.RamUsageEstimator.sizeOf(measurement[i]); + } + + // Calculate memory for TAG column indices + if (columnCategories[i] != null && columnCategories[i].equals(TsTableColumnCategory.TAG)) { + tagColumnIndicesSize += + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + Integer.BYTES + NUM_BYTES_OBJECT_HEADER) + + NUM_BYTES_OBJECT_REF; + } + } + } + + // Add all calculated memory to total + memorySize += measurementMemorySize; + memorySize += dataTypesMemorySize; + + // deserialize times and calculate memory during deserialization + final long[] times = new long[rowSize]; + // Calculate memory: array header + long size * rowSize + final long timesMemorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + (long) Long.BYTES * rowSize); + + final boolean isTimesNotNull = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (isTimesNotNull) { + for (int i = 0; i < rowSize; i++) { + times[i] = ReadWriteIOUtils.readLong(byteBuffer); + } + } + + // Add times memory to total + memorySize += timesMemorySize; + + // deserialize bitmaps and calculate memory during deserialization + final BitMap[] bitMaps; + final long bitMapsMemorySize; + + final boolean isBitMapsNotNull = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (isBitMapsNotNull) { + // Use the method that returns both 
BitMap array and memory size + final Pair bitMapsAndMemory = + readBitMapsFromBufferWithMemory(byteBuffer, schemaSize); + bitMaps = bitMapsAndMemory.getLeft(); + bitMapsMemorySize = bitMapsAndMemory.getRight(); + } else { + // Calculate memory for empty BitMap array: array header + references + bitMaps = new BitMap[schemaSize]; + bitMapsMemorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize); + } + + // Add bitMaps memory to total + memorySize += bitMapsMemorySize; + + // Deserialize values and calculate memory during deserialization + final Object[] values; + final long valuesMemorySize; + + final boolean isValuesNotNull = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (isValuesNotNull) { + // Use the method that returns both values array and memory size + final Pair valuesAndMemory = + readValuesFromBufferWithMemory(byteBuffer, dataTypes, schemaSize, rowSize); + values = valuesAndMemory.getLeft(); + valuesMemorySize = valuesAndMemory.getRight(); + } else { + // Calculate memory for empty values array: array header + references + values = new Object[schemaSize]; + valuesMemorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize); + } + + // Add values memory to total + memorySize += valuesMemorySize; + + final boolean isAligned = ReadWriteIOUtils.readBoolean(byteBuffer); + + statement.setMeasurements(measurement); + statement.setTimes(times); + statement.setBitMaps(bitMaps); + statement.setDataTypes(dataTypes); + statement.setColumns(values); + statement.setRowCount(rowSize); + statement.setAligned(isAligned); + + // Read databaseName if requested (V2 format) + if (readDatabaseName) { + final String databaseName = ReadWriteIOUtils.readString(byteBuffer); + if (databaseName != null) { + statement.setDatabaseName(databaseName); + statement.setWriteToTable(true); + // For table model, 
insertTargetName is table name, convert to lowercase + statement.setDevicePath(new PartialPath(insertTargetName.toLowerCase(), false)); + // Calculate memory for databaseName + memorySize += org.apache.tsfile.utils.RamUsageEstimator.sizeOf(databaseName); + + statement.setColumnCategories(columnCategories); + + memorySize += columnCategoriesMemorySize; + memorySize += tagColumnIndicesSize; + } else { + // For tree model, use DataNodeDevicePathCache + statement.setDevicePath( + DataNodeDevicePathCache.getInstance().getPartialPath(insertTargetName)); + statement.setColumnCategories(null); + } + } else { + // V1 format: no databaseName in buffer, always use DataNodeDevicePathCache + statement.setDevicePath( + DataNodeDevicePathCache.getInstance().getPartialPath(insertTargetName)); + statement.setColumnCategories(null); + } + + // Calculate memory for devicePath + memorySize += InsertNodeMemoryEstimator.sizeOfPartialPath(statement.getDevicePath()); + + // Set the pre-calculated memory size to avoid recalculation + statement.setRamBytesUsed(memorySize); + + return statement; + } + + /** + * Deserialize InsertTabletStatement from Tablet format ByteBuffer (V1 format, no databaseName). + * + * @param byteBuffer ByteBuffer containing serialized data + * @return InsertTabletStatement with devicePath set using DataNodeDevicePathCache + */ + public static InsertTabletStatement deserializeStatementFromTabletFormat( + final ByteBuffer byteBuffer) throws IllegalPathException { + return deserializeStatementFromTabletFormat(byteBuffer, false); + } + + /** + * Skip a string in ByteBuffer without reading it. This is more efficient than reading and + * discarding the string. 
+ * + * @param buffer ByteBuffer to skip string from + */ + private static void skipString(final ByteBuffer buffer) { + final int size = ReadWriteIOUtils.readInt(buffer); + if (size > 0) { + buffer.position(buffer.position() + size); + } + } + + /** + * Read measurement name and data type from buffer, skipping other measurement schema fields + * (encoding, compression, and tags/attributes) that are not needed for InsertTabletStatement. + * + * @param buffer ByteBuffer containing serialized measurement schema + * @return Pair of measurement name and data type + */ + private static Pair readMeasurement(final ByteBuffer buffer) { + // Read measurement name and data type + final Pair pair = + new Pair<>(ReadWriteIOUtils.readString(buffer), TSDataType.deserializeFrom(buffer)); + + // Skip encoding type (byte) and compression type (byte) - 2 bytes total + buffer.position(buffer.position() + 2); + + // Skip props map (Map) + final int size = ReadWriteIOUtils.readInt(buffer); + if (size > 0) { + for (int i = 0; i < size; i++) { + // Skip key (String) and value (String) without constructing temporary objects + skipString(buffer); + skipString(buffer); + } + } + + return pair; + } + + /** + * Deserialize bitmaps and calculate memory size during deserialization. Returns a Pair of BitMap + * array and the calculated memory size. 
+ */ + private static Pair readBitMapsFromBufferWithMemory( + final ByteBuffer byteBuffer, final int columns) { + final BitMap[] bitMaps = new BitMap[columns]; + + // Calculate memory: array header + object references + long memorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * columns); + + for (int i = 0; i < columns; i++) { + final boolean hasBitMap = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (hasBitMap) { + final int size = ReadWriteIOUtils.readInt(byteBuffer); + final Binary valueBinary = ReadWriteIOUtils.readBinary(byteBuffer); + final byte[] byteArray = valueBinary.getValues(); + bitMaps[i] = new BitMap(size, byteArray); + + // Calculate memory for this BitMap: BitMap object + byte array + // BitMap shallow size + byte array (array header + array length) + memorySize += + SIZE_OF_BITMAP + + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + byteArray.length); + } + } + + return new Pair<>(bitMaps, memorySize); + } + + /** + * Deserialize values from buffer and calculate memory size during deserialization. Returns a Pair + * of values array and the calculated memory size. 
+ * + * @param byteBuffer data values + * @param types data types + * @param columns column number + * @param rowSize row number + * @return Pair of values array and memory size + */ + @SuppressWarnings("squid:S3776") // Suppress high Cognitive Complexity warning + private static Pair readValuesFromBufferWithMemory( + final ByteBuffer byteBuffer, final TSDataType[] types, final int columns, final int rowSize) { + final Object[] values = new Object[columns]; + + // Calculate memory: array header + object references + long memorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * columns); + + for (int i = 0; i < columns; i++) { + final boolean isValueColumnsNotNull = + BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (isValueColumnsNotNull && types[i] == null) { + continue; + } + + switch (types[i]) { + case BOOLEAN: + final boolean[] boolValues = new boolean[rowSize]; + if (isValueColumnsNotNull) { + for (int index = 0; index < rowSize; index++) { + boolValues[index] = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + } + } + values[i] = boolValues; + // Calculate memory for boolean array: array header + 1 byte per element (aligned) + memorySize += + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + rowSize); + break; + case INT32: + case DATE: + final int[] intValues = new int[rowSize]; + if (isValueColumnsNotNull) { + for (int index = 0; index < rowSize; index++) { + intValues[index] = ReadWriteIOUtils.readInt(byteBuffer); + } + } + values[i] = intValues; + // Calculate memory for int array: array header + 4 bytes per element (aligned) + memorySize += + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + (long) Integer.BYTES * rowSize); + break; + case INT64: + case TIMESTAMP: + final long[] longValues = new long[rowSize]; + if (isValueColumnsNotNull) { + for (int index = 0; index < rowSize; 
index++) { + longValues[index] = ReadWriteIOUtils.readLong(byteBuffer); + } + } + values[i] = longValues; + // Calculate memory for long array: array header + 8 bytes per element (aligned) + memorySize += + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + (long) Long.BYTES * rowSize); + break; + case FLOAT: + final float[] floatValues = new float[rowSize]; + if (isValueColumnsNotNull) { + for (int index = 0; index < rowSize; index++) { + floatValues[index] = ReadWriteIOUtils.readFloat(byteBuffer); + } + } + values[i] = floatValues; + // Calculate memory for float array: array header + 4 bytes per element (aligned) + memorySize += + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + (long) Float.BYTES * rowSize); + break; + case DOUBLE: + final double[] doubleValues = new double[rowSize]; + if (isValueColumnsNotNull) { + for (int index = 0; index < rowSize; index++) { + doubleValues[index] = ReadWriteIOUtils.readDouble(byteBuffer); + } + } + values[i] = doubleValues; + // Calculate memory for double array: array header + 8 bytes per element (aligned) + memorySize += + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + (long) Double.BYTES * rowSize); + break; + case TEXT: + case STRING: + case BLOB: + case OBJECT: + // Handle object array type: Binary[] is an array of objects + final Binary[] binaryValues = new Binary[rowSize]; + // Calculate memory for Binary array: array header + object references + long binaryArrayMemory = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * rowSize); + + if (isValueColumnsNotNull) { + for (int index = 0; index < rowSize; index++) { + final boolean isNotNull = + BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (isNotNull) { + binaryValues[index] = ReadWriteIOUtils.readBinary(byteBuffer); + // Calculate memory for each Binary object during 
deserialization + binaryArrayMemory += binaryValues[index].ramBytesUsed(); + } else { + binaryValues[index] = Binary.EMPTY_VALUE; + // EMPTY_VALUE also has memory cost + binaryArrayMemory += Binary.EMPTY_VALUE.ramBytesUsed(); + } + } + } else { + Arrays.fill(binaryValues, Binary.EMPTY_VALUE); + // Calculate memory for all EMPTY_VALUE + binaryArrayMemory += (long) rowSize * Binary.EMPTY_VALUE.ramBytesUsed(); + } + values[i] = binaryValues; + // Add calculated Binary array memory to total + memorySize += binaryArrayMemory; + break; + default: + throw new UnSupportedDataTypeException( + String.format("data type %s is not supported when convert data at client", types[i])); + } + } + + return new Pair<>(values, memorySize); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertEventDataAdapter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertEventDataAdapter.java new file mode 100644 index 0000000000000..62111cb4129ca --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertEventDataAdapter.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.pipe.sink.util.sorter; + +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.IDeviceID; +import org.apache.tsfile.utils.BitMap; + +/** + * Adapter interface to encapsulate common operations needed for sorting and deduplication. This + * interface allows the sorter to work with both Tablet and InsertTabletStatement. + */ +public interface InsertEventDataAdapter { + + /** + * Get the number of columns. + * + * @return number of columns + */ + int getColumnCount(); + + /** + * Get data type for a specific column. + * + * @param columnIndex column index + * @return data type of the column + */ + TSDataType getDataType(int columnIndex); + + /** + * Get bit maps for null values. + * + * @return array of bit maps, may be null + */ + BitMap[] getBitMaps(); + + /** + * Set bit maps for null values. + * + * @param bitMaps array of bit maps + */ + void setBitMaps(BitMap[] bitMaps); + + /** + * Get value arrays for all columns. + * + * @return array of value arrays (Object[]) + */ + Object[] getValues(); + + /** + * Set value array for a specific column. + * + * @param columnIndex column index + * @param value value array + */ + void setValue(int columnIndex, Object value); + + /** + * Get timestamps array. + * + * @return array of timestamps + */ + long[] getTimestamps(); + + /** + * Set timestamps array. + * + * @param timestamps array of timestamps + */ + void setTimestamps(long[] timestamps); + + /** + * Get row size/count. + * + * @return number of rows + */ + int getRowSize(); + + /** + * Set row size/count. + * + * @param rowSize number of rows + */ + void setRowSize(int rowSize); + + /** + * Get timestamp at a specific row index. + * + * @param rowIndex row index + * @return timestamp value + */ + long getTimestamp(int rowIndex); + + /** + * Get device ID at a specific row index (for table model). 
+ * + * @param rowIndex row index + * @return device ID + */ + IDeviceID getDeviceID(int rowIndex); + + /** + * Check if the DATE type column value is stored as LocalDate[] (Tablet) or int[] (Statement). + * + * @param columnIndex column index + * @return true if DATE type is stored as LocalDate[], false if stored as int[] + */ + boolean isDateStoredAsLocalDate(int columnIndex); +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertTabletStatementAdapter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertTabletStatementAdapter.java new file mode 100644 index 0000000000000..30f74a22965a8 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertTabletStatementAdapter.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.pipe.sink.util.sorter; + +import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; + +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.IDeviceID; +import org.apache.tsfile.utils.BitMap; + +/** Adapter for InsertTabletStatement to implement InsertEventDataAdapter interface. */ +public class InsertTabletStatementAdapter implements InsertEventDataAdapter { + + private final InsertTabletStatement statement; + + public InsertTabletStatementAdapter(final InsertTabletStatement statement) { + this.statement = statement; + } + + @Override + public int getColumnCount() { + final Object[] columns = statement.getColumns(); + return columns != null ? columns.length : 0; + } + + @Override + public TSDataType getDataType(int columnIndex) { + final TSDataType[] dataTypes = statement.getDataTypes(); + if (dataTypes != null && columnIndex < dataTypes.length) { + return dataTypes[columnIndex]; + } + return null; + } + + @Override + public BitMap[] getBitMaps() { + return statement.getBitMaps(); + } + + @Override + public void setBitMaps(BitMap[] bitMaps) { + statement.setBitMaps(bitMaps); + } + + @Override + public Object[] getValues() { + return statement.getColumns(); + } + + @Override + public void setValue(int columnIndex, Object value) { + Object[] columns = statement.getColumns(); + if (columns != null && columnIndex < columns.length) { + columns[columnIndex] = value; + } + } + + @Override + public long[] getTimestamps() { + return statement.getTimes(); + } + + @Override + public void setTimestamps(long[] timestamps) { + statement.setTimes(timestamps); + } + + @Override + public int getRowSize() { + return statement.getRowCount(); + } + + @Override + public void setRowSize(int rowSize) { + statement.setRowCount(rowSize); + } + + @Override + public long getTimestamp(int rowIndex) { + long[] times = statement.getTimes(); + if (times != null && rowIndex < times.length) { + return 
times[rowIndex]; + } + return 0; + } + + @Override + public IDeviceID getDeviceID(int rowIndex) { + return statement.getTableDeviceID(rowIndex); + } + + @Override + public boolean isDateStoredAsLocalDate(int columnIndex) { + // InsertTabletStatement stores DATE as int[] + return false; + } + + public InsertTabletStatement getStatement() { + return statement; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTabletEventSorter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeInsertEventSorter.java similarity index 65% rename from iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTabletEventSorter.java rename to iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeInsertEventSorter.java index 6540bf9855d62..46a3fc6df9429 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTabletEventSorter.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeInsertEventSorter.java @@ -19,19 +19,20 @@ package org.apache.iotdb.db.pipe.sink.util.sorter; +import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; + import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.utils.Binary; import org.apache.tsfile.utils.BitMap; import org.apache.tsfile.write.UnSupportedDataTypeException; import org.apache.tsfile.write.record.Tablet; -import org.apache.tsfile.write.schema.IMeasurementSchema; import java.time.LocalDate; import java.util.Objects; -public class PipeTabletEventSorter { +public class PipeInsertEventSorter { - protected final Tablet tablet; + protected final InsertEventDataAdapter dataAdapter; protected Integer[] index; protected boolean isSorted = true; @@ -39,8 +40,31 @@ public class PipeTabletEventSorter { protected int[] deDuplicatedIndex; protected int deDuplicatedSize; - public PipeTabletEventSorter(final Tablet tablet) { - 
this.tablet = tablet; + /** + * Constructor for Tablet. + * + * @param tablet the tablet to sort + */ + public PipeInsertEventSorter(final Tablet tablet) { + this.dataAdapter = new TabletAdapter(tablet); + } + + /** + * Constructor for InsertTabletStatement. + * + * @param statement the insert tablet statement to sort + */ + public PipeInsertEventSorter(final InsertTabletStatement statement) { + this.dataAdapter = new InsertTabletStatementAdapter(statement); + } + + /** + * Constructor with adapter (for internal use or advanced scenarios). + * + * @param adapter the data adapter + */ + protected PipeInsertEventSorter(final InsertEventDataAdapter adapter) { + this.dataAdapter = adapter; } // Input: @@ -54,35 +78,42 @@ public PipeTabletEventSorter(final Tablet tablet) { // (Used index: [2(3), 4(0)]) // Col: [6, 1] protected void sortAndMayDeduplicateValuesAndBitMaps() { - int columnIndex = 0; - for (int i = 0, size = tablet.getSchemas().size(); i < size; i++) { - final IMeasurementSchema schema = tablet.getSchemas().get(i); - if (schema != null) { + final int columnCount = dataAdapter.getColumnCount(); + BitMap[] bitMaps = dataAdapter.getBitMaps(); + boolean bitMapsModified = false; + + for (int columnIndex = 0; columnIndex < columnCount; columnIndex++) { + final TSDataType dataType = dataAdapter.getDataType(columnIndex); + if (dataType != null) { BitMap deDuplicatedBitMap = null; BitMap originalBitMap = null; - if (tablet.getBitMaps() != null && tablet.getBitMaps()[columnIndex] != null) { - originalBitMap = tablet.getBitMaps()[columnIndex]; + if (bitMaps != null && columnIndex < bitMaps.length && bitMaps[columnIndex] != null) { + originalBitMap = bitMaps[columnIndex]; deDuplicatedBitMap = new BitMap(originalBitMap.getSize()); } - tablet.getValues()[columnIndex] = + final Object[] values = dataAdapter.getValues(); + final Object reorderedValue = reorderValueListAndBitMap( - tablet.getValues()[columnIndex], - schema.getType(), - originalBitMap, - deDuplicatedBitMap); 
+ values[columnIndex], dataType, columnIndex, originalBitMap, deDuplicatedBitMap); + dataAdapter.setValue(columnIndex, reorderedValue); - if (tablet.getBitMaps() != null && tablet.getBitMaps()[columnIndex] != null) { - tablet.getBitMaps()[columnIndex] = deDuplicatedBitMap; + if (bitMaps != null && columnIndex < bitMaps.length && bitMaps[columnIndex] != null) { + bitMaps[columnIndex] = deDuplicatedBitMap; + bitMapsModified = true; } - columnIndex++; } } + + if (bitMapsModified) { + dataAdapter.setBitMaps(bitMaps); + } } protected Object reorderValueListAndBitMap( final Object valueList, final TSDataType dataType, + final int columnIndex, final BitMap originalBitMap, final BitMap deDuplicatedBitMap) { // Older version's sender may contain null values, we need to cover this case @@ -107,13 +138,26 @@ protected Object reorderValueListAndBitMap( } return deDuplicatedIntValues; case DATE: - final LocalDate[] dateValues = (LocalDate[]) valueList; - final LocalDate[] deDuplicatedDateValues = new LocalDate[dateValues.length]; - for (int i = 0; i < deDuplicatedSize; i++) { - deDuplicatedDateValues[i] = - dateValues[getLastNonnullIndex(i, originalBitMap, deDuplicatedBitMap)]; + // DATE type: Tablet uses LocalDate[], InsertTabletStatement uses int[] + if (dataAdapter.isDateStoredAsLocalDate(columnIndex)) { + // Tablet: LocalDate[] + final LocalDate[] dateValues = (LocalDate[]) valueList; + final LocalDate[] deDuplicatedDateValues = new LocalDate[dateValues.length]; + for (int i = 0; i < deDuplicatedSize; i++) { + deDuplicatedDateValues[i] = + dateValues[getLastNonnullIndex(i, originalBitMap, deDuplicatedBitMap)]; + } + return deDuplicatedDateValues; + } else { + // InsertTabletStatement: int[] + final int[] intDateValues = (int[]) valueList; + final int[] deDuplicatedIntDateValues = new int[intDateValues.length]; + for (int i = 0; i < deDuplicatedSize; i++) { + deDuplicatedIntDateValues[i] = + intDateValues[getLastNonnullIndex(i, originalBitMap, deDuplicatedBitMap)]; + } + 
return deDuplicatedIntDateValues; } - return deDuplicatedDateValues; case INT64: case TIMESTAMP: final long[] longValues = (long[]) valueList; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTableModelTabletEventSorter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTableModelTabletEventSorter.java index 5735b51c6d025..ba034cae3a310 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTableModelTabletEventSorter.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTableModelTabletEventSorter.java @@ -19,6 +19,8 @@ package org.apache.iotdb.db.pipe.sink.util.sorter; +import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; + import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.file.metadata.IDeviceID; import org.apache.tsfile.utils.Pair; @@ -31,14 +33,29 @@ import java.util.List; import java.util.Map; -public class PipeTableModelTabletEventSorter extends PipeTabletEventSorter { +public class PipeTableModelTabletEventSorter extends PipeInsertEventSorter { private int initIndexSize; + /** + * Constructor for Tablet. + * + * @param tablet the tablet to sort + */ public PipeTableModelTabletEventSorter(final Tablet tablet) { super(tablet); deDuplicatedSize = tablet == null ? 0 : tablet.getRowSize(); } + /** + * Constructor for InsertTabletStatement. + * + * @param statement the insert tablet statement to sort + */ + public PipeTableModelTabletEventSorter(final InsertTabletStatement statement) { + super(statement); + deDuplicatedSize = statement == null ? 0 : statement.getRowCount(); + } + /** * For the sorting and deduplication needs of the table model tablet, it is done according to the * {@link IDeviceID}. 
For sorting, it is necessary to sort the {@link IDeviceID} first, and then @@ -46,18 +63,19 @@ public PipeTableModelTabletEventSorter(final Tablet tablet) { * the same timestamp in different {@link IDeviceID} will not be processed. */ public void sortAndDeduplicateByDevIdTimestamp() { - if (tablet == null || tablet.getRowSize() < 1) { + if (dataAdapter == null || dataAdapter.getRowSize() < 1) { return; } final HashMap>> deviceIDToIndexMap = new HashMap<>(); - final long[] timestamps = tablet.getTimestamps(); + final long[] timestamps = dataAdapter.getTimestamps(); + final int rowSize = dataAdapter.getRowSize(); - IDeviceID lastDevice = tablet.getDeviceID(0); - long previousTimestamp = tablet.getTimestamp(0); + IDeviceID lastDevice = dataAdapter.getDeviceID(0); + long previousTimestamp = dataAdapter.getTimestamp(0); int lasIndex = 0; - for (int i = 1, size = tablet.getRowSize(); i < size; ++i) { - final IDeviceID deviceID = tablet.getDeviceID(i); + for (int i = 1; i < rowSize; ++i) { + final IDeviceID deviceID = dataAdapter.getDeviceID(i); final long currentTimestamp = timestamps[i]; final int deviceComparison = deviceID.compareTo(lastDevice); if (deviceComparison == 0) { @@ -92,7 +110,7 @@ public void sortAndDeduplicateByDevIdTimestamp() { if (!list.isEmpty()) { isSorted = false; } - list.add(new Pair<>(lasIndex, tablet.getRowSize())); + list.add(new Pair<>(lasIndex, rowSize)); if (isSorted && isDeDuplicated) { return; @@ -100,8 +118,8 @@ public void sortAndDeduplicateByDevIdTimestamp() { initIndexSize = 0; deDuplicatedSize = 0; - index = new Integer[tablet.getRowSize()]; - deDuplicatedIndex = new int[tablet.getRowSize()]; + index = new Integer[rowSize]; + deDuplicatedIndex = new int[rowSize]; deviceIDToIndexMap.entrySet().stream() .sorted(Map.Entry.comparingByKey()) .forEach( @@ -129,19 +147,22 @@ public void sortAndDeduplicateByDevIdTimestamp() { } private void sortAndDeduplicateValuesAndBitMapsWithTimestamp() { - tablet.setTimestamps( + // TIMESTAMP is not a 
DATE type, so columnIndex is not relevant here, use -1 + dataAdapter.setTimestamps( (long[]) - reorderValueListAndBitMap(tablet.getTimestamps(), TSDataType.TIMESTAMP, null, null)); + reorderValueListAndBitMap( + dataAdapter.getTimestamps(), TSDataType.TIMESTAMP, -1, null, null)); sortAndMayDeduplicateValuesAndBitMaps(); - tablet.setRowSize(deDuplicatedSize); + dataAdapter.setRowSize(deDuplicatedSize); } private void sortTimestamps(final int startIndex, final int endIndex) { - Arrays.sort(this.index, startIndex, endIndex, Comparator.comparingLong(tablet::getTimestamp)); + Arrays.sort( + this.index, startIndex, endIndex, Comparator.comparingLong(dataAdapter::getTimestamp)); } private void deDuplicateTimestamps(final int startIndex, final int endIndex) { - final long[] timestamps = tablet.getTimestamps(); + final long[] timestamps = dataAdapter.getTimestamps(); long lastTime = timestamps[index[startIndex]]; for (int i = startIndex + 1; i < endIndex; i++) { if (lastTime != (lastTime = timestamps[index[i]])) { @@ -153,12 +174,13 @@ private void deDuplicateTimestamps(final int startIndex, final int endIndex) { /** Sort by time only. 
*/ public void sortByTimestampIfNecessary() { - if (tablet == null || tablet.getRowSize() == 0) { + if (dataAdapter == null || dataAdapter.getRowSize() == 0) { return; } - final long[] timestamps = tablet.getTimestamps(); - for (int i = 1, size = tablet.getRowSize(); i < size; ++i) { + final long[] timestamps = dataAdapter.getTimestamps(); + final int rowSize = dataAdapter.getRowSize(); + for (int i = 1; i < rowSize; ++i) { final long currentTimestamp = timestamps[i]; final long previousTimestamp = timestamps[i - 1]; @@ -172,8 +194,8 @@ public void sortByTimestampIfNecessary() { return; } - index = new Integer[tablet.getRowSize()]; - for (int i = 0, size = tablet.getRowSize(); i < size; i++) { + index = new Integer[rowSize]; + for (int i = 0; i < rowSize; i++) { index[i] = i; } @@ -185,7 +207,8 @@ public void sortByTimestampIfNecessary() { } private void sortTimestamps() { - Arrays.sort(this.index, Comparator.comparingLong(tablet::getTimestamp)); - Arrays.sort(tablet.getTimestamps(), 0, tablet.getRowSize()); + Arrays.sort(this.index, Comparator.comparingLong(dataAdapter::getTimestamp)); + final long[] timestamps = dataAdapter.getTimestamps(); + Arrays.sort(timestamps, 0, dataAdapter.getRowSize()); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTreeModelTabletEventSorter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTreeModelTabletEventSorter.java index c26f59220f981..2a56b706463ff 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTreeModelTabletEventSorter.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTreeModelTabletEventSorter.java @@ -19,25 +19,43 @@ package org.apache.iotdb.db.pipe.sink.util.sorter; +import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; + import org.apache.tsfile.write.record.Tablet; import java.util.Arrays; import java.util.Comparator; 
-public class PipeTreeModelTabletEventSorter extends PipeTabletEventSorter { +public class PipeTreeModelTabletEventSorter extends PipeInsertEventSorter { + /** + * Constructor for Tablet. + * + * @param tablet the tablet to sort + */ public PipeTreeModelTabletEventSorter(final Tablet tablet) { super(tablet); deDuplicatedSize = tablet == null ? 0 : tablet.getRowSize(); } + /** + * Constructor for InsertTabletStatement. + * + * @param statement the insert tablet statement to sort + */ + public PipeTreeModelTabletEventSorter(final InsertTabletStatement statement) { + super(statement); + deDuplicatedSize = statement == null ? 0 : statement.getRowCount(); + } + public void deduplicateAndSortTimestampsIfNecessary() { - if (tablet == null || tablet.getRowSize() == 0) { + if (dataAdapter == null || dataAdapter.getRowSize() == 0) { return; } - long[] timestamps = tablet.getTimestamps(); - for (int i = 1, size = tablet.getRowSize(); i < size; ++i) { + long[] timestamps = dataAdapter.getTimestamps(); + final int rowSize = dataAdapter.getRowSize(); + for (int i = 1; i < rowSize; ++i) { final long currentTimestamp = timestamps[i]; final long previousTimestamp = timestamps[i - 1]; @@ -54,9 +72,9 @@ public void deduplicateAndSortTimestampsIfNecessary() { return; } - index = new Integer[tablet.getRowSize()]; - deDuplicatedIndex = new int[tablet.getRowSize()]; - for (int i = 0, size = tablet.getRowSize(); i < size; i++) { + index = new Integer[rowSize]; + deDuplicatedIndex = new int[rowSize]; + for (int i = 0; i < rowSize; i++) { index[i] = i; } @@ -78,14 +96,16 @@ public void deduplicateAndSortTimestampsIfNecessary() { private void sortTimestamps() { // Index is sorted stably because it is Integer[] - Arrays.sort(index, Comparator.comparingLong(tablet::getTimestamp)); - Arrays.sort(tablet.getTimestamps(), 0, tablet.getRowSize()); + Arrays.sort(index, Comparator.comparingLong(dataAdapter::getTimestamp)); + final long[] timestamps = dataAdapter.getTimestamps(); + 
Arrays.sort(timestamps, 0, dataAdapter.getRowSize()); } private void deduplicateTimestamps() { deDuplicatedSize = 0; - long[] timestamps = tablet.getTimestamps(); - for (int i = 1, size = tablet.getRowSize(); i < size; i++) { + long[] timestamps = dataAdapter.getTimestamps(); + final int rowSize = dataAdapter.getRowSize(); + for (int i = 1; i < rowSize; i++) { if (timestamps[i] != timestamps[i - 1]) { deDuplicatedIndex[deDuplicatedSize] = i - 1; timestamps[deDuplicatedSize] = timestamps[i - 1]; @@ -94,8 +114,8 @@ private void deduplicateTimestamps() { } } - deDuplicatedIndex[deDuplicatedSize] = tablet.getRowSize() - 1; - timestamps[deDuplicatedSize] = timestamps[tablet.getRowSize() - 1]; - tablet.setRowSize(++deDuplicatedSize); + deDuplicatedIndex[deDuplicatedSize] = rowSize - 1; + timestamps[deDuplicatedSize] = timestamps[rowSize - 1]; + dataAdapter.setRowSize(++deDuplicatedSize); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/TabletAdapter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/TabletAdapter.java new file mode 100644 index 0000000000000..b200127a5d451 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/TabletAdapter.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.pipe.sink.util.sorter; + +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.IDeviceID; +import org.apache.tsfile.utils.BitMap; +import org.apache.tsfile.write.record.Tablet; +import org.apache.tsfile.write.schema.IMeasurementSchema; + +import java.util.List; + +/** Adapter for Tablet to implement InsertEventDataAdapter interface. */ +public class TabletAdapter implements InsertEventDataAdapter { + + private final Tablet tablet; + + public TabletAdapter(final Tablet tablet) { + this.tablet = tablet; + } + + @Override + public int getColumnCount() { + final Object[] values = tablet.getValues(); + return values != null ? values.length : 0; + } + + @Override + public TSDataType getDataType(int columnIndex) { + final List schemas = tablet.getSchemas(); + if (schemas != null && columnIndex < schemas.size()) { + final IMeasurementSchema schema = schemas.get(columnIndex); + return schema != null ? 
schema.getType() : null; + } + return null; + } + + @Override + public BitMap[] getBitMaps() { + return tablet.getBitMaps(); + } + + @Override + public void setBitMaps(BitMap[] bitMaps) { + tablet.setBitMaps(bitMaps); + } + + @Override + public Object[] getValues() { + return tablet.getValues(); + } + + @Override + public void setValue(int columnIndex, Object value) { + tablet.getValues()[columnIndex] = value; + } + + @Override + public long[] getTimestamps() { + return tablet.getTimestamps(); + } + + @Override + public void setTimestamps(long[] timestamps) { + tablet.setTimestamps(timestamps); + } + + @Override + public int getRowSize() { + return tablet.getRowSize(); + } + + @Override + public void setRowSize(int rowSize) { + tablet.setRowSize(rowSize); + } + + @Override + public long getTimestamp(int rowIndex) { + return tablet.getTimestamp(rowIndex); + } + + @Override + public IDeviceID getDeviceID(int rowIndex) { + return tablet.getDeviceID(rowIndex); + } + + @Override + public boolean isDateStoredAsLocalDate(int columnIndex) { + return true; + } + + public Tablet getTablet() { + return tablet; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java deleted file mode 100644 index 0d784617c0905..0000000000000 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.db.protocol.client; - -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.commons.client.ClientManager; -import org.apache.iotdb.commons.client.ClientManagerMetrics; -import org.apache.iotdb.commons.client.IClientPoolFactory; -import org.apache.iotdb.commons.client.factory.ThriftClientFactory; -import org.apache.iotdb.commons.client.property.ClientPoolProperty; -import org.apache.iotdb.commons.client.property.ThriftClientProperty; -import org.apache.iotdb.commons.concurrent.ThreadName; -import org.apache.iotdb.commons.conf.CommonConfig; -import org.apache.iotdb.commons.conf.CommonDescriptor; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AsyncAINodeServiceClient; - -import org.apache.commons.pool2.PooledObject; -import org.apache.commons.pool2.impl.DefaultPooledObject; -import org.apache.commons.pool2.impl.GenericKeyedObjectPool; - -import java.util.Optional; - -/** Dedicated factory for AINodeClient + AINodeClientPoolFactory. 
*/ -public class AINodeClientFactory extends ThriftClientFactory { - - private static final int connectionTimeout = - CommonDescriptor.getInstance().getConfig().getDnConnectionTimeoutInMS(); - - public AINodeClientFactory( - ClientManager manager, ThriftClientProperty thriftProperty) { - super(manager, thriftProperty); - } - - @Override - public PooledObject makeObject(TEndPoint endPoint) throws Exception { - return new DefaultPooledObject<>( - new AINodeClient(thriftClientProperty, endPoint, clientManager)); - } - - @Override - public void destroyObject(TEndPoint key, PooledObject pooled) throws Exception { - pooled.getObject().invalidate(); - } - - @Override - public boolean validateObject(TEndPoint key, PooledObject pooledObject) { - return Optional.ofNullable(pooledObject.getObject().getTransport()) - .map(org.apache.thrift.transport.TTransport::isOpen) - .orElse(false); - } - - /** The PoolFactory originally inside ClientPoolFactory — now moved here. */ - public static class AINodeClientPoolFactory - implements IClientPoolFactory { - - @Override - public GenericKeyedObjectPool createClientPool( - ClientManager manager) { - - // Build thrift client properties - ThriftClientProperty thriftProperty = - new ThriftClientProperty.Builder() - .setConnectionTimeoutMs(connectionTimeout) - .setRpcThriftCompressionEnabled( - CommonDescriptor.getInstance().getConfig().isRpcThriftCompressionEnabled()) - .build(); - - GenericKeyedObjectPool pool = - new GenericKeyedObjectPool<>( - new AINodeClientFactory(manager, thriftProperty), - new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode( - CommonDescriptor.getInstance().getConfig().getMaxClientNumForEachNode()) - .build() - .getConfig()); - - ClientManagerMetrics.getInstance() - .registerClientManager(this.getClass().getSimpleName(), pool); - - return pool; - } - } - - public static class AINodeHeartbeatClientPoolFactory - implements IClientPoolFactory { - - @Override - public GenericKeyedObjectPool createClientPool( 
- ClientManager manager) { - - final CommonConfig conf = CommonDescriptor.getInstance().getConfig(); - - GenericKeyedObjectPool clientPool = - new GenericKeyedObjectPool<>( - new AsyncAINodeServiceClient.Factory( - manager, - new ThriftClientProperty.Builder() - .setConnectionTimeoutMs(conf.getCnConnectionTimeoutInMS()) - .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnabled()) - .setSelectorNumOfAsyncClientManager(conf.getSelectorNumOfClientManager()) - .setPrintLogWhenEncounterException(false) - .build(), - ThreadName.ASYNC_DATANODE_HEARTBEAT_CLIENT_POOL.getName()), - new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) - .build() - .getConfig()); - - ClientManagerMetrics.getInstance() - .registerClientManager(this.getClass().getSimpleName(), clientPool); - - return clientPool; - } - } -} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java index 2c037cf0f3e58..df80d49b502b0 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java @@ -73,7 +73,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -102,7 +101,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import 
org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -121,8 +119,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -184,7 +180,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp; import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; import org.apache.iotdb.db.conf.IoTDBConfig; import org.apache.iotdb.db.conf.IoTDBDescriptor; import org.apache.iotdb.rpc.DeepCopyRpcTransportFactory; @@ -525,7 +520,8 @@ public TAINodeRestartResp restartAINode(TAINodeRestartReq req) throws TException @Override public TGetAINodeLocationResp getAINodeLocation() throws TException { - return client.getAINodeLocation(); + return executeRemoteCallWithRetry( + () -> client.getAINodeLocation(), resp -> !updateConfigNodeLeader(resp.status)); } @Override @@ -1339,28 +1335,6 @@ public TShowCQResp showCQ() throws TException { () -> client.showCQ(), resp -> !updateConfigNodeLeader(resp.status)); } - @Override - public TSStatus createModel(TCreateModelReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.createModel(req), status -> !updateConfigNodeLeader(status)); - } - - @Override - public TSStatus 
dropModel(TDropModelReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.dropModel(req), status -> !updateConfigNodeLeader(status)); - } - - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.getModelInfo(req), resp -> !updateConfigNodeLeader(resp.getStatus())); - } - - public TSStatus updateModelInfo(TUpdateModelInfoReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.updateModelInfo(req), status -> !updateConfigNodeLeader(status)); - } - @Override public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) throws TException { return executeRemoteCallWithRetry( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java index b5f5df430129f..da0d84d8466fe 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java @@ -27,12 +27,13 @@ import org.apache.iotdb.commons.consensus.ConfigRegionId; import org.apache.iotdb.db.conf.IoTDBConfig; import org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; import org.apache.commons.pool2.impl.GenericKeyedObjectPool; public class DataNodeClientPoolFactory { - private static final IoTDBConfig conf = IoTDBDescriptor.getInstance().getConfig(); + private static final IoTDBConfig CONF = IoTDBDescriptor.getInstance().getConfig(); private DataNodeClientPoolFactory() { // Empty constructor @@ -49,11 +50,11 @@ public GenericKeyedObjectPool createClientPool new ConfigNodeClient.Factory( manager, new ThriftClientProperty.Builder() - .setConnectionTimeoutMs(conf.getConnectionTimeoutInMS()) - 
.setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnable()) + .setConnectionTimeoutMs(CONF.getConnectionTimeoutInMS()) + .setRpcThriftCompressionEnabled(CONF.isRpcThriftCompressionEnable()) .build()), new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) + .setMaxClientNumForEachNode(CONF.getMaxClientNumForEachNode()) .build() .getConfig()); ClientManagerMetrics.getInstance() @@ -73,15 +74,38 @@ public GenericKeyedObjectPool createClientPool new ConfigNodeClient.Factory( manager, new ThriftClientProperty.Builder() - .setConnectionTimeoutMs(conf.getConnectionTimeoutInMS() * 10) - .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnable()) + .setConnectionTimeoutMs(CONF.getConnectionTimeoutInMS() * 10) + .setRpcThriftCompressionEnabled(CONF.isRpcThriftCompressionEnable()) .setSelectorNumOfAsyncClientManager( - conf.getSelectorNumOfClientManager() / 10 > 0 - ? conf.getSelectorNumOfClientManager() / 10 + CONF.getSelectorNumOfClientManager() / 10 > 0 + ? 
CONF.getSelectorNumOfClientManager() / 10 : 1) .build()), new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) + .setMaxClientNumForEachNode(CONF.getMaxClientNumForEachNode()) + .build() + .getConfig()); + ClientManagerMetrics.getInstance() + .registerClientManager(this.getClass().getSimpleName(), clientPool); + return clientPool; + } + } + + public static class AINodeClientPoolFactory implements IClientPoolFactory { + + @Override + public GenericKeyedObjectPool createClientPool( + ClientManager manager) { + GenericKeyedObjectPool clientPool = + new GenericKeyedObjectPool<>( + new AINodeClient.Factory( + manager, + new ThriftClientProperty.Builder() + .setConnectionTimeoutMs(CONF.getConnectionTimeoutInMS()) + .setRpcThriftCompressionEnabled(CONF.isRpcThriftCompressionEnable()) + .build()), + new ClientPoolProperty.Builder() + .setMaxClientNumForEachNode(CONF.getMaxClientNumForEachNode()) .build() .getConfig()); ClientManagerMetrics.getInstance() diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java deleted file mode 100644 index 54150b8f3007b..0000000000000 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java +++ /dev/null @@ -1,401 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.db.protocol.client.ainode; - -import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService; -import org.apache.iotdb.ainode.rpc.thrift.TConfigs; -import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; -import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; -import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq; -import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TWindowParams; -import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.client.ClientManager; -import org.apache.iotdb.commons.client.IClientManager; -import org.apache.iotdb.commons.client.ThriftClient; -import org.apache.iotdb.commons.client.factory.ThriftClientFactory; -import 
org.apache.iotdb.commons.client.property.ThriftClientProperty; -import org.apache.iotdb.commons.conf.CommonConfig; -import org.apache.iotdb.commons.conf.CommonDescriptor; -import org.apache.iotdb.commons.consensus.ConfigRegionId; -import org.apache.iotdb.commons.exception.ainode.LoadModelException; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.confignode.rpc.thrift.TGetAINodeLocationResp; -import org.apache.iotdb.db.protocol.client.ConfigNodeClient; -import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; -import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; -import org.apache.iotdb.rpc.TConfigurationConst; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.commons.pool2.PooledObject; -import org.apache.commons.pool2.impl.DefaultPooledObject; -import org.apache.thrift.TException; -import org.apache.thrift.transport.TSSLTransportFactory; -import org.apache.thrift.transport.TSocket; -import org.apache.thrift.transport.TTransport; -import org.apache.thrift.transport.TTransportException; -import org.apache.thrift.transport.layered.TFramedTransport; -import org.apache.tsfile.enums.TSDataType; -import org.apache.tsfile.read.common.block.TsBlock; -import org.apache.tsfile.read.common.block.column.TsBlockSerde; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.atomic.AtomicReference; - -import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE; -import static org.apache.iotdb.rpc.TSStatusCode.INTERNAL_SERVER_ERROR; - -public class AINodeClient implements AutoCloseable, ThriftClient { - - private static final Logger logger = LoggerFactory.getLogger(AINodeClient.class); - - private static final CommonConfig commonConfig = CommonDescriptor.getInstance().getConfig(); - - private TEndPoint endPoint; - - private TTransport transport; - - private final ThriftClientProperty property; 
- private IAINodeRPCService.Client client; - - public static final String MSG_CONNECTION_FAIL = - "Fail to connect to AINode. Please check status of AINode"; - private static final int MAX_RETRY = 3; - - @FunctionalInterface - private interface RemoteCall { - R apply(IAINodeRPCService.Client c) throws TException; - } - - private final TsBlockSerde tsBlockSerde = new TsBlockSerde(); - - ClientManager clientManager; - - private static final IClientManager CONFIG_NODE_CLIENT_MANAGER = - ConfigNodeClientManager.getInstance(); - - private static final AtomicReference CURRENT_LOCATION = new AtomicReference<>(); - - public static TEndPoint getCurrentEndpoint() { - TAINodeLocation loc = CURRENT_LOCATION.get(); - if (loc == null) { - loc = refreshFromConfigNode(); - } - return (loc == null) ? null : pickEndpointFrom(loc); - } - - public static void updateGlobalAINodeLocation(final TAINodeLocation loc) { - if (loc != null) { - CURRENT_LOCATION.set(loc); - } - } - - private R executeRemoteCallWithRetry(RemoteCall call) throws TException { - TException last = null; - for (int attempt = 1; attempt <= MAX_RETRY; attempt++) { - try { - if (transport == null || !transport.isOpen()) { - final TEndPoint ep = getCurrentEndpoint(); - if (ep == null) { - throw new TException("AINode endpoint unavailable"); - } - this.endPoint = ep; - init(); - } - return call.apply(client); - } catch (TException e) { - last = e; - invalidate(); - final TAINodeLocation loc = refreshFromConfigNode(); - if (loc != null) { - this.endPoint = pickEndpointFrom(loc); - } - try { - Thread.sleep(1000L * attempt); - } catch (InterruptedException ie) { - Thread.currentThread().interrupt(); - } - } - } - throw (last != null ? 
last : new TException(MSG_CONNECTION_FAIL)); - } - - private static TAINodeLocation refreshFromConfigNode() { - try (final ConfigNodeClient cn = - CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - final TGetAINodeLocationResp resp = cn.getAINodeLocation(); - if (resp != null && resp.isSetAiNodeLocation()) { - final TAINodeLocation loc = resp.getAiNodeLocation(); - CURRENT_LOCATION.set(loc); - return loc; - } - } catch (Exception e) { - LoggerFactory.getLogger(AINodeClient.class) - .debug("[AINodeClient] refreshFromConfigNode failed: {}", e.toString()); - } - return null; - } - - private static TEndPoint pickEndpointFrom(final TAINodeLocation loc) { - if (loc == null) return null; - if (loc.isSetInternalEndPoint() && loc.getInternalEndPoint() != null) { - return loc.getInternalEndPoint(); - } - return null; - } - - public AINodeClient( - ThriftClientProperty property, - TEndPoint endPoint, - ClientManager clientManager) - throws TException { - this.property = property; - this.clientManager = clientManager; - // Instance default endpoint (pool key). Global location can override it on retries. 
- this.endPoint = endPoint; - init(); - } - - private void init() throws TException { - try { - if (commonConfig.isEnableInternalSSL()) { - TSSLTransportFactory.TSSLTransportParameters params = - new TSSLTransportFactory.TSSLTransportParameters(); - params.setTrustStore(commonConfig.getTrustStorePath(), commonConfig.getTrustStorePwd()); - params.setKeyStore(commonConfig.getKeyStorePath(), commonConfig.getKeyStorePwd()); - transport = - new TFramedTransport.Factory() - .getTransport( - TSSLTransportFactory.getClientSocket( - endPoint.getIp(), - endPoint.getPort(), - property.getConnectionTimeoutMs(), - params)); - } else { - transport = - new TFramedTransport.Factory() - .getTransport( - new TSocket( - TConfigurationConst.defaultTConfiguration, - endPoint.getIp(), - endPoint.getPort(), - property.getConnectionTimeoutMs())); - } - if (!transport.isOpen()) { - transport.open(); - } - } catch (TTransportException e) { - throw new TException(MSG_CONNECTION_FAIL); - } - client = new IAINodeRPCService.Client(property.getProtocolFactory().getProtocol(transport)); - } - - public TTransport getTransport() { - return transport; - } - - public TSStatus stopAINode() throws TException { - try { - TSStatus status = client.stopAINode(); - if (status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new TException(status.message); - } - return status; - } catch (TException e) { - logger.warn( - "Failed to connect to AINode from ConfigNode when executing {}: {}", - Thread.currentThread().getStackTrace()[1].getMethodName(), - e.getMessage()); - throw new TException(MSG_CONNECTION_FAIL); - } - } - - public ModelInformation registerModel(String modelName, String uri) throws LoadModelException { - try { - TRegisterModelReq req = new TRegisterModelReq(uri, modelName); - TRegisterModelResp resp = client.registerModel(req); - if (resp.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new LoadModelException(resp.status.message, resp.status.getCode()); - } - 
return parseModelInformation(modelName, resp.getAttributes(), resp.getConfigs()); - } catch (TException e) { - throw new LoadModelException( - e.getMessage(), TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()); - } - } - - private ModelInformation parseModelInformation( - String modelName, String attributes, TConfigs configs) { - int[] inputShape = configs.getInput_shape().stream().mapToInt(Integer::intValue).toArray(); - int[] outputShape = configs.getOutput_shape().stream().mapToInt(Integer::intValue).toArray(); - - TSDataType[] inputType = new TSDataType[inputShape[1]]; - TSDataType[] outputType = new TSDataType[outputShape[1]]; - for (int i = 0; i < inputShape[1]; i++) { - inputType[i] = TSDataType.values()[configs.getInput_type().get(i)]; - } - for (int i = 0; i < outputShape[1]; i++) { - outputType[i] = TSDataType.values()[configs.getOutput_type().get(i)]; - } - - return new ModelInformation( - modelName, inputShape, outputShape, inputType, outputType, attributes); - } - - public TSStatus deleteModel(TDeleteModelReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.deleteModel(req)); - } - - public TSStatus loadModel(TLoadModelReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.loadModel(req)); - } - - public TSStatus unloadModel(TUnloadModelReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.unloadModel(req)); - } - - public TShowModelsResp showModels(TShowModelsReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.showModels(req)); - } - - public TShowLoadedModelsResp showLoadedModels(TShowLoadedModelsReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.showLoadedModels(req)); - } - - public TShowAIDevicesResp showAIDevices() throws TException { - return executeRemoteCallWithRetry(IAINodeRPCService.Client::showAIDevices); - } - - public TInferenceResp inference( - String modelId, - TsBlock inputTsBlock, - Map inferenceAttributes, - TWindowParams windowParams) - 
throws TException { - try { - TInferenceReq inferenceReq = new TInferenceReq(modelId, tsBlockSerde.serialize(inputTsBlock)); - if (windowParams != null) { - inferenceReq.setWindowParams(windowParams); - } - if (inferenceAttributes != null) { - inferenceReq.setInferenceAttributes(inferenceAttributes); - } - return executeRemoteCallWithRetry(c -> c.inference(inferenceReq)); - } catch (IOException e) { - throw new TException("An exception occurred while serializing input data", e); - } catch (TException e) { - logger.warn( - "Error happens in AINode when executing {}: {}", - Thread.currentThread().getStackTrace()[1].getMethodName(), - e.getMessage()); - throw new TException(MSG_CONNECTION_FAIL); - } - } - - public TForecastResp forecast( - String modelId, TsBlock inputTsBlock, int outputLength, Map options) { - try { - TForecastReq forecastReq = - new TForecastReq(modelId, tsBlockSerde.serialize(inputTsBlock), outputLength); - forecastReq.setOptions(options); - return executeRemoteCallWithRetry(c -> c.forecast(forecastReq)); - } catch (IOException e) { - TSStatus tsStatus = new TSStatus(INTERNAL_SERVER_ERROR.getStatusCode()); - tsStatus.setMessage(String.format("Failed to serialize input tsblock %s", e.getMessage())); - return new TForecastResp(tsStatus); - } catch (TException e) { - TSStatus tsStatus = new TSStatus(CAN_NOT_CONNECT_AINODE.getStatusCode()); - tsStatus.setMessage( - String.format( - "Failed to connect to AINode when executing %s: %s", - Thread.currentThread().getStackTrace()[1].getMethodName(), e.getMessage())); - return new TForecastResp(tsStatus); - } - } - - public TSStatus createTrainingTask(TTrainingReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.createTrainingTask(req)); - } - - @Override - public void close() throws Exception { - clientManager.returnClient(endPoint, this); - } - - @Override - public void invalidate() { - Optional.ofNullable(transport).ifPresent(TTransport::close); - } - - @Override - public void 
invalidateAll() { - clientManager.clear(endPoint); - } - - @Override - public boolean printLogWhenEncounterException() { - return property.isPrintLogWhenEncounterException(); - } - - public static class Factory extends ThriftClientFactory { - - public Factory( - ClientManager clientClientManager, - ThriftClientProperty thriftClientProperty) { - super(clientClientManager, thriftClientProperty); - } - - @Override - public void destroyObject(TEndPoint tEndPoint, PooledObject pooledObject) - throws Exception { - pooledObject.getObject().invalidate(); - } - - @Override - public PooledObject makeObject(TEndPoint endPoint) throws Exception { - return new DefaultPooledObject<>( - new AINodeClient(thriftClientProperty, endPoint, clientManager)); - } - - @Override - public boolean validateObject(TEndPoint tEndPoint, PooledObject pooledObject) { - return Optional.ofNullable(pooledObject.getObject().getTransport()) - .map(TTransport::isOpen) - .orElse(false); - } - } -} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java deleted file mode 100644 index faef1c1ae7b60..0000000000000 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.db.protocol.client.ainode; - -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.commons.client.IClientManager; -import org.apache.iotdb.db.protocol.client.AINodeClientFactory; - -public class AINodeClientManager { - - public static final int DEFAULT_AINODE_ID = 0; - - private static final AINodeClientManager INSTANCE = new AINodeClientManager(); - - private final IClientManager clientManager; - - private volatile TEndPoint defaultAINodeEndPoint; - - private AINodeClientManager() { - this.clientManager = - new IClientManager.Factory() - .createClientManager(new AINodeClientFactory.AINodeClientPoolFactory()); - } - - public static AINodeClientManager getInstance() { - return INSTANCE; - } - - public void updateDefaultAINodeLocation(TEndPoint endPoint) { - this.defaultAINodeEndPoint = endPoint; - } - - public AINodeClient borrowClient(TEndPoint endPoint) throws Exception { - return clientManager.borrowClient(endPoint); - } - - public AINodeClient borrowClient(int aiNodeId) throws Exception { - if (aiNodeId != DEFAULT_AINODE_ID) { - throw new IllegalArgumentException("Unsupported AINodeId: " + aiNodeId); - } - if (defaultAINodeEndPoint == null) { - defaultAINodeEndPoint = AINodeClient.getCurrentEndpoint(); - } - return clientManager.borrowClient(defaultAINodeEndPoint); - } - - public void clear(TEndPoint endPoint) { - clientManager.clear(endPoint); - } - - public void clearAll() { - clientManager.close(); - } - - public IClientManager getRawClientManager() { - return clientManager; - } -} 
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java new file mode 100644 index 0000000000000..5eaffc40af9cd --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.protocol.client.an; + +import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService; +import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatReq; +import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatResp; +import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; +import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; +import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq; +import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; +import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; +import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; +import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; +import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; +import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; +import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; +import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; +import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; +import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; +import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.client.ClientManager; +import org.apache.iotdb.commons.client.IClientManager; +import org.apache.iotdb.commons.client.ThriftClient; +import org.apache.iotdb.commons.client.factory.ThriftClientFactory; +import org.apache.iotdb.commons.client.property.ThriftClientProperty; +import org.apache.iotdb.commons.client.sync.SyncThriftClientWithErrorHandler; +import org.apache.iotdb.commons.conf.CommonConfig; +import org.apache.iotdb.commons.conf.CommonDescriptor; +import org.apache.iotdb.commons.consensus.ConfigRegionId; +import org.apache.iotdb.confignode.rpc.thrift.TGetAINodeLocationResp; +import org.apache.iotdb.db.conf.IoTDBConfig; +import 
org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.protocol.client.ConfigNodeClient; +import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; +import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; +import org.apache.iotdb.rpc.DeepCopyRpcTransportFactory; + +import org.apache.commons.pool2.PooledObject; +import org.apache.commons.pool2.impl.DefaultPooledObject; +import org.apache.thrift.TException; +import org.apache.thrift.transport.TTransport; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLHandshakeException; + +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +public class AINodeClient implements IAINodeRPCService.Iface, AutoCloseable, ThriftClient { + + private static final Logger LOGGER = LoggerFactory.getLogger(AINodeClient.class); + + private static final CommonConfig COMMON_CONFIG = CommonDescriptor.getInstance().getConfig(); + private static final IoTDBConfig IOTDB_CONFIG = IoTDBDescriptor.getInstance().getConfig(); + + private TTransport transport; + + private final ThriftClientProperty property; + private IAINodeRPCService.Client client; + + private static final int MAX_RETRY = 5; + private static final int RETRY_INTERVAL_MS = 100; + public static final String MSG_ALL_RETRY_FAILED = + String.format( + "Failed to connect to AINode after %d retries, please check the status of AINode", + MAX_RETRY); + public static final String MSG_AINODE_CONNECTION_FAIL = + "Fail to connect to AINode from DataNode %s when executing %s."; + private static final String UNSUPPORTED_INVOCATION = + "This method is not supported for invocation by DataNode"; + + @Override + public TSStatus stopAINode() throws TException { + return executeRemoteCallWithRetry(() -> client.stopAINode()); + } + + @Override + public TShowModelsResp showModels(TShowModelsReq req) throws TException { + return executeRemoteCallWithRetry(() -> 
client.showModels(req)); + } + + @Override + public TShowLoadedModelsResp showLoadedModels(TShowLoadedModelsReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.showLoadedModels(req)); + } + + @Override + public TShowAIDevicesResp showAIDevices() throws TException { + return executeRemoteCallWithRetry(() -> client.showAIDevices()); + } + + @Override + public TSStatus deleteModel(TDeleteModelReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.deleteModel(req)); + } + + @Override + public TRegisterModelResp registerModel(TRegisterModelReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.registerModel(req)); + } + + @Override + public TAIHeartbeatResp getAIHeartbeat(TAIHeartbeatReq req) { + throw new UnsupportedOperationException(UNSUPPORTED_INVOCATION); + } + + @Override + public TSStatus createTrainingTask(TTrainingReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.createTrainingTask(req)); + } + + @Override + public TSStatus loadModel(TLoadModelReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.loadModel(req)); + } + + @Override + public TSStatus unloadModel(TUnloadModelReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.unloadModel(req)); + } + + @Override + public TInferenceResp inference(TInferenceReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.inference(req)); + } + + @Override + public TForecastResp forecast(TForecastReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.forecast(req)); + } + + @FunctionalInterface + private interface RemoteCall { + R apply() throws TException; + } + + ClientManager clientManager; + + private static final IClientManager CONFIG_NODE_CLIENT_MANAGER = + ConfigNodeClientManager.getInstance(); + + private static final AtomicReference CURRENT_LOCATION = new AtomicReference<>(); + + private R 
executeRemoteCallWithRetry(RemoteCall call) throws TException { + for (int attempt = 0; attempt < MAX_RETRY; attempt++) { + try { + return call.apply(); + } catch (TException e) { + final String message = + String.format( + MSG_AINODE_CONNECTION_FAIL, + IOTDB_CONFIG.getAddressAndPort(), + Thread.currentThread().getStackTrace()[2].getMethodName()); + LOGGER.warn(message, e); + CURRENT_LOCATION.set(null); + if (e.getCause() != null && e.getCause() instanceof SSLHandshakeException) { + throw e; + } + } + try { + TimeUnit.MILLISECONDS.sleep(RETRY_INTERVAL_MS); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + logger.warn( + "Unexpected interruption when waiting to try to connect to AINode, may because current node has been down. Will break current execution process to avoid meaningless wait."); + break; + } + tryToConnect(property.getConnectionTimeoutMs()); + } + throw new TException(MSG_ALL_RETRY_FAILED); + } + + private void tryToConnect(int timeoutMs) { + TEndPoint endpoint = getCurrentEndpoint(); + if (endpoint != null) { + try { + connect(endpoint, timeoutMs); + return; + } catch (TException e) { + LOGGER.warn("The current AINode may have been down {}, because", endpoint, e); + CURRENT_LOCATION.set(null); + } + } else { + LOGGER.warn("Cannot connect to any AINode due to there are no available ones."); + } + if (transport != null) { + transport.close(); + } + } + + public void connect(TEndPoint endpoint, int timeoutMs) throws TException { + transport = + COMMON_CONFIG.isEnableInternalSSL() + ? 
DeepCopyRpcTransportFactory.INSTANCE.getTransport( + endpoint.getIp(), + endpoint.getPort(), + timeoutMs, + COMMON_CONFIG.getTrustStorePath(), + COMMON_CONFIG.getTrustStorePwd(), + COMMON_CONFIG.getKeyStorePath(), + COMMON_CONFIG.getKeyStorePwd()) + : DeepCopyRpcTransportFactory.INSTANCE.getTransport( + // As there is a try-catch already, we do not need to use TSocket.wrap + endpoint.getIp(), endpoint.getPort(), timeoutMs); + if (!transport.isOpen()) { + transport.open(); + } + client = new IAINodeRPCService.Client(property.getProtocolFactory().getProtocol(transport)); + } + + public TEndPoint getCurrentEndpoint() { + TAINodeLocation loc = CURRENT_LOCATION.get(); + if (loc == null) { + loc = refreshFromConfigNode(); + } + return (loc == null) ? null : loc.getInternalEndPoint(); + } + + private TAINodeLocation refreshFromConfigNode() { + try (final ConfigNodeClient cn = + CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { + final TGetAINodeLocationResp resp = cn.getAINodeLocation(); + if (resp.isSetAiNodeLocation()) { + final TAINodeLocation loc = resp.getAiNodeLocation(); + CURRENT_LOCATION.set(loc); + return loc; + } + } catch (Exception e) { + LoggerFactory.getLogger(AINodeClient.class) + .debug("[AINodeClient] refreshFromConfigNode failed: {}", e.toString()); + } + return null; + } + + public AINodeClient( + ThriftClientProperty property, ClientManager clientManager) { + this.property = property; + this.clientManager = clientManager; + tryToConnect(property.getConnectionTimeoutMs()); + } + + public TTransport getTransport() { + return transport; + } + + @Override + public void close() { + clientManager.returnClient(AINodeClientManager.AINODE_ID_PLACEHOLDER, this); + } + + @Override + public void invalidate() { + Optional.ofNullable(transport).ifPresent(TTransport::close); + } + + @Override + public void invalidateAll() { + clientManager.clear(AINodeClientManager.AINODE_ID_PLACEHOLDER); + } + + @Override + public boolean 
printLogWhenEncounterException() { + return property.isPrintLogWhenEncounterException(); + } + + public static class Factory extends ThriftClientFactory { + + public Factory( + ClientManager clientClientManager, + ThriftClientProperty thriftClientProperty) { + super(clientClientManager, thriftClientProperty); + } + + @Override + public void destroyObject(Integer aiNodeId, PooledObject pooledObject) { + pooledObject.getObject().invalidate(); + } + + @Override + public PooledObject makeObject(Integer Integer) throws Exception { + return new DefaultPooledObject<>( + SyncThriftClientWithErrorHandler.newErrorHandler( + AINodeClient.class, + AINodeClient.class.getConstructor( + thriftClientProperty.getClass(), clientManager.getClass()), + thriftClientProperty, + clientManager)); + } + + @Override + public boolean validateObject(Integer Integer, PooledObject pooledObject) { + return Optional.ofNullable(pooledObject.getObject().getTransport()) + .map(TTransport::isOpen) + .orElse(false); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java new file mode 100644 index 0000000000000..698c8e7938836 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.protocol.client.an; + +import org.apache.iotdb.commons.client.IClientManager; +import org.apache.iotdb.db.protocol.client.DataNodeClientPoolFactory; + +public class AINodeClientManager { + + public static final int AINODE_ID_PLACEHOLDER = 0; + + private AINodeClientManager() { + // Empty constructor + } + + public static IClientManager getInstance() { + return AINodeClientManagerHolder.INSTANCE; + } + + private static class AINodeClientManagerHolder { + + private static final IClientManager INSTANCE = + new IClientManager.Factory() + .createClientManager(new DataNodeClientPoolFactory.AINodeClientPoolFactory()); + + private AINodeClientManagerHolder() { + // Empty constructor + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/rest/v2/impl/RestApiServiceImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/rest/v2/impl/RestApiServiceImpl.java index 66be144edefb7..b657523987dc1 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/rest/v2/impl/RestApiServiceImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/rest/v2/impl/RestApiServiceImpl.java @@ -212,7 +212,8 @@ public Response executeFastLastQueryStatement( t = e; return Response.ok().entity(ExceptionHandler.tryCatchException(e)).build(); } finally { - long costTime = System.nanoTime() - startTime; + long endTime = System.nanoTime(); + long costTime = endTime - startTime; StatementType statementType = Optional.ofNullable(statement) @@ -227,7 +228,18 @@ public 
Response executeFastLastQueryStatement( if (queryId != null) { COORDINATOR.cleanupQueryExecution(queryId); } else { - recordQueries(() -> costTime, new FastLastQueryContentSupplier(prefixPathList), t); + IClientSession clientSession = SESSION_MANAGER.getCurrSession(); + + Supplier contentOfQuerySupplier = new FastLastQueryContentSupplier(prefixPathList); + COORDINATOR.recordCurrentQueries( + null, + startTime / 1_000_000, + endTime / 1_000_000, + costTime, + contentOfQuerySupplier, + clientSession.getUsername(), + clientSession.getClientAddress()); + recordQueries(() -> costTime, contentOfQuerySupplier, t); } } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/ClientSession.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/ClientSession.java index 6aa862b9242f4..ea90fbafeccde 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/ClientSession.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/ClientSession.java @@ -33,6 +33,9 @@ public class ClientSession extends IClientSession { private final Map> statementIdToQueryId = new ConcurrentHashMap<>(); + // Map from statement name to PreparedStatementInfo + private final Map preparedStatements = new ConcurrentHashMap<>(); + public ClientSession(Socket clientSocket) { this.clientSocket = clientSocket; } @@ -103,4 +106,24 @@ public static void removeQueryId( } } } + + @Override + public void addPreparedStatement(String statementName, PreparedStatementInfo info) { + preparedStatements.put(statementName, info); + } + + @Override + public PreparedStatementInfo removePreparedStatement(String statementName) { + return preparedStatements.remove(statementName); + } + + @Override + public PreparedStatementInfo getPreparedStatement(String statementName) { + return preparedStatements.get(statementName); + } + + @Override + public Set getPreparedStatementNames() { + return preparedStatements.keySet(); + } } 
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/IClientSession.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/IClientSession.java index 351806de099c9..97585673e824f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/IClientSession.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/IClientSession.java @@ -188,6 +188,37 @@ public void setDatabaseName(@Nullable String databaseName) { this.databaseName = databaseName; } + /** + * Add a prepared statement to this session. + * + * @param statementName the name of the prepared statement + * @param info the prepared statement information + */ + public abstract void addPreparedStatement(String statementName, PreparedStatementInfo info); + + /** + * Remove a prepared statement from this session. + * + * @param statementName the name of the prepared statement + * @return the removed prepared statement info, or null if not found + */ + public abstract PreparedStatementInfo removePreparedStatement(String statementName); + + /** + * Get a prepared statement from this session. + * + * @param statementName the name of the prepared statement + * @return the prepared statement info, or null if not found + */ + public abstract PreparedStatementInfo getPreparedStatement(String statementName); + + /** + * Get all prepared statement names in this session. 
+ * + * @return set of prepared statement names + */ + public abstract Set getPreparedStatementNames(); + public long getLastActiveTime() { return lastActiveTime; } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/InternalClientSession.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/InternalClientSession.java index 3c72d083a8c74..ed87d0b0ee324 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/InternalClientSession.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/InternalClientSession.java @@ -88,4 +88,28 @@ public void addQueryId(Long statementId, long queryId) { public void removeQueryId(Long statementId, Long queryId) { ClientSession.removeQueryId(statementIdToQueryId, statementId, queryId); } + + @Override + public void addPreparedStatement(String statementName, PreparedStatementInfo info) { + throw new UnsupportedOperationException( + "InternalClientSession should never call PREPARE statement methods."); + } + + @Override + public PreparedStatementInfo removePreparedStatement(String statementName) { + throw new UnsupportedOperationException( + "InternalClientSession should never call PREPARE statement methods."); + } + + @Override + public PreparedStatementInfo getPreparedStatement(String statementName) { + throw new UnsupportedOperationException( + "InternalClientSession should never call PREPARE statement methods."); + } + + @Override + public Set getPreparedStatementNames() { + throw new UnsupportedOperationException( + "InternalClientSession should never call PREPARE statement methods."); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/MqttClientSession.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/MqttClientSession.java index ae9e2cd036165..c0b68e885a14d 100644 --- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/MqttClientSession.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/MqttClientSession.java @@ -76,4 +76,28 @@ public void addQueryId(Long statementId, long queryId) { public void removeQueryId(Long statementId, Long queryId) { throw new UnsupportedOperationException(); } + + @Override + public void addPreparedStatement(String statementName, PreparedStatementInfo info) { + throw new UnsupportedOperationException( + "MQTT client session does not support PREPARE statement."); + } + + @Override + public PreparedStatementInfo removePreparedStatement(String statementName) { + throw new UnsupportedOperationException( + "MQTT client session does not support PREPARE statement."); + } + + @Override + public PreparedStatementInfo getPreparedStatement(String statementName) { + throw new UnsupportedOperationException( + "MQTT client session does not support PREPARE statement."); + } + + @Override + public Set getPreparedStatementNames() { + throw new UnsupportedOperationException( + "MQTT client session does not support PREPARE statement."); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/PreparedStatementInfo.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/PreparedStatementInfo.java new file mode 100644 index 0000000000000..0bfc750c4ba07 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/PreparedStatementInfo.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.protocol.session; + +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +/** + * Information about a prepared statement stored in a session. The AST is cached here to avoid + * reparsing on EXECUTE. + */ +public class PreparedStatementInfo { + + private final String statementName; + private final Statement sql; // Cached AST (contains Parameter nodes) + private final long createTime; + private final long memorySizeInBytes; // Memory size allocated for this PreparedStatement + + public PreparedStatementInfo(String statementName, Statement sql, long memorySizeInBytes) { + this.statementName = requireNonNull(statementName, "statementName is null"); + this.sql = requireNonNull(sql, "sql is null"); + this.createTime = System.currentTimeMillis(); + this.memorySizeInBytes = memorySizeInBytes; + } + + public PreparedStatementInfo( + String statementName, Statement sql, long createTime, long memorySizeInBytes) { + this.statementName = requireNonNull(statementName, "statementName is null"); + this.sql = requireNonNull(sql, "sql is null"); + this.createTime = createTime; + this.memorySizeInBytes = memorySizeInBytes; + } + + public String getStatementName() { + return statementName; + } + + public Statement getSql() { + return sql; + } + + public long getCreateTime() { + return createTime; + } + + public long getMemorySizeInBytes() { + return memorySizeInBytes; + } + + @Override + public boolean equals(Object o) { + if 
(this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PreparedStatementInfo that = (PreparedStatementInfo) o; + return Objects.equals(statementName, that.statementName) && Objects.equals(sql, that.sql); + } + + @Override + public int hashCode() { + return Objects.hash(statementName, sql); + } + + @Override + public String toString() { + return "PreparedStatementInfo{" + + "statementName='" + + statementName + + '\'' + + ", sql=" + + sql + + ", createTime=" + + createTime + + '}'; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/RestClientSession.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/RestClientSession.java index fa830ace3fbcf..d122c3c7dc5f9 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/RestClientSession.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/RestClientSession.java @@ -22,12 +22,17 @@ import org.apache.iotdb.service.rpc.thrift.TSConnectionType; import java.util.Collections; +import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; public class RestClientSession extends IClientSession { private final String clientID; + // Map from statement name to PreparedStatementInfo + private final Map preparedStatements = new ConcurrentHashMap<>(); + public RestClientSession(String clientID) { this.clientID = clientID; } @@ -76,4 +81,24 @@ public void addQueryId(Long statementId, long queryId) { public void removeQueryId(Long statementId, Long queryId) { throw new UnsupportedOperationException(); } + + @Override + public void addPreparedStatement(String statementName, PreparedStatementInfo info) { + preparedStatements.put(statementName, info); + } + + @Override + public PreparedStatementInfo removePreparedStatement(String statementName) { + return preparedStatements.remove(statementName); + } + + @Override + public PreparedStatementInfo 
getPreparedStatement(String statementName) { + return preparedStatements.get(statementName); + } + + @Override + public Set getPreparedStatementNames() { + return preparedStatements.keySet(); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/SessionManager.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/SessionManager.java index a4a28efa5ec3d..e5e95d4f82ca4 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/SessionManager.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/SessionManager.java @@ -40,6 +40,7 @@ import org.apache.iotdb.db.protocol.thrift.OperationType; import org.apache.iotdb.db.queryengine.common.ConnectionInfo; import org.apache.iotdb.db.queryengine.common.SessionInfo; +import org.apache.iotdb.db.queryengine.plan.execution.config.session.PreparedStatementMemoryManager; import org.apache.iotdb.db.storageengine.dataregion.read.control.QueryResourceManager; import org.apache.iotdb.db.utils.DataNodeAuthUtils; import org.apache.iotdb.metrics.utils.MetricLevel; @@ -276,6 +277,7 @@ public boolean closeSession(IClientSession session, LongConsumer releaseByQueryI } private void releaseSessionResource(IClientSession session, LongConsumer releaseQueryResource) { + // Release query resources Iterable statementIds = session.getStatementIds(); if (statementIds != null) { for (Long statementId : statementIds) { @@ -287,6 +289,17 @@ private void releaseSessionResource(IClientSession session, LongConsumer release } } } + + // Release PreparedStatement memory resources + try { + PreparedStatementMemoryManager.getInstance().releaseAllForSession(session); + } catch (Exception e) { + LOGGER.warn( + "Failed to release PreparedStatement resources for session {}: {}", + session, + e.getMessage(), + e); + } } public TSStatus closeOperation( @@ -295,6 +308,7 @@ public TSStatus closeOperation( long statementId, boolean haveStatementId, boolean 
haveSetQueryId, + String preparedStatementName, LongConsumer releaseByQueryId) { if (!checkLogin(session)) { return RpcUtils.getStatus( @@ -307,7 +321,7 @@ public TSStatus closeOperation( if (haveSetQueryId) { this.closeDataset(session, statementId, queryId, releaseByQueryId); } else { - this.closeStatement(session, statementId, releaseByQueryId); + this.closeStatement(session, statementId, preparedStatementName, releaseByQueryId); } return RpcUtils.getStatus(TSStatusCode.SUCCESS_STATUS); } else { @@ -342,14 +356,35 @@ public long requestStatementId(IClientSession session) { } public void closeStatement( - IClientSession session, long statementId, LongConsumer releaseByQueryId) { + IClientSession session, + long statementId, + String preparedStatementName, + LongConsumer releaseByQueryId) { Set queryIdSet = session.removeStatementId(statementId); if (queryIdSet != null) { for (Long queryId : queryIdSet) { releaseByQueryId.accept(queryId); } } - session.removeStatementId(statementId); + + // If preparedStatementName is provided, release the prepared statement resources + if (preparedStatementName != null && !preparedStatementName.isEmpty()) { + try { + PreparedStatementInfo removedInfo = session.removePreparedStatement(preparedStatementName); + if (removedInfo != null) { + // Release the memory allocated for this PreparedStatement + PreparedStatementMemoryManager.getInstance().release(removedInfo.getMemorySizeInBytes()); + } + } catch (Exception e) { + LOGGER.warn( + "Failed to release PreparedStatement '{}' resources when closing statement {} for session {}: {}", + preparedStatementName, + statementId, + session, + e.getMessage(), + e); + } + } } public long requestQueryId(IClientSession session, Long statementId) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/handler/BaseServerContextHandler.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/handler/BaseServerContextHandler.java index 
e633caa45f6a4..9b7efd827806a 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/handler/BaseServerContextHandler.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/handler/BaseServerContextHandler.java @@ -71,6 +71,7 @@ public ServerContext createContext(TProtocol in, TProtocol out) { public void deleteContext(ServerContext context, TProtocol in, TProtocol out) { getSessionManager().removeCurrSession(); + if (context != null && factory != null) { ((JudgableServerContext) context).whenDisconnect(); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java index ea7646efb2b24..167a1fa914fd2 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java @@ -365,16 +365,30 @@ private TSExecuteStatementResp executeStatementInternal( queryId = SESSION_MANAGER.requestQueryId(clientSession, req.statementId); - result = - COORDINATOR.executeForTreeModel( - s, - queryId, - SESSION_MANAGER.getSessionInfo(clientSession), - statement, - partitionFetcher, - schemaFetcher, - req.getTimeout(), - true); + // Split statement if needed to limit resource consumption during statement analysis + if (s.shouldSplit()) { + result = + executeBatchStatement( + s, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + partitionFetcher, + schemaFetcher, + config.getQueryTimeoutThreshold(), + true); + } else { + result = + COORDINATOR.executeForTreeModel( + s, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + partitionFetcher, + schemaFetcher, + req.getTimeout(), + true); + } } } else { org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement s = @@ -396,17 +410,32 @@ 
private TSExecuteStatementResp executeStatementInternal( queryId = SESSION_MANAGER.requestQueryId(clientSession, req.statementId); - result = - COORDINATOR.executeForTableModel( - s, - relationSqlParser, - clientSession, - queryId, - SESSION_MANAGER.getSessionInfo(clientSession), - statement, - metadata, - req.getTimeout(), - true); + // Split statement if needed to limit resource consumption during statement analysis + if (s.shouldSplit()) { + result = + executeBatchTableStatement( + s, + relationSqlParser, + clientSession, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + metadata, + config.getQueryTimeoutThreshold(), + true); + } else { + result = + COORDINATOR.executeForTableModel( + s, + relationSqlParser, + clientSession, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + metadata, + req.getTimeout(), + true); + } } if (result.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode() @@ -1021,13 +1050,23 @@ public TSExecuteStatementResp executeFastLastDataQueryForOnePrefixPath( resp.setMoreData(false); - long costTime = System.nanoTime() - startTime; + long endTime = System.nanoTime(); + long costTime = endTime - startTime; CommonUtils.addStatementExecutionLatency( OperationType.EXECUTE_QUERY_STATEMENT, StatementType.FAST_LAST_QUERY.name(), costTime); CommonUtils.addQueryLatency(StatementType.FAST_LAST_QUERY, costTime); - recordQueries( - () -> costTime, () -> String.format("thrift fastLastQuery %s", prefixPath), null); + + String statement = String.format("thrift fastLastQuery %s", prefixPath); + COORDINATOR.recordCurrentQueries( + null, + startTime / 1_000_000, + endTime / 1_000_000, + costTime, + () -> statement, + clientSession.getUsername(), + clientSession.getClientAddress()); + recordQueries(() -> costTime, () -> statement, null); return resp; } catch (final Exception e) { return RpcUtils.getTSExecuteStatementResp( @@ -1445,6 +1484,7 @@ public TSStatus closeOperation(TSCloseOperationReq req) { 
req.statementId, req.isSetStatementId(), req.isSetQueryId(), + req.isSetPreparedStatementName() ? req.getPreparedStatementName() : null, COORDINATOR::cleanupQueryExecution); } @@ -1845,16 +1885,31 @@ public TSStatus executeBatchStatement(TSExecuteBatchStatementReq req) { queryId = SESSION_MANAGER.requestQueryId(); type = s.getType() == null ? null : s.getType().name(); // create and cache dataset - result = - COORDINATOR.executeForTreeModel( - s, - queryId, - SESSION_MANAGER.getSessionInfo(clientSession), - statement, - partitionFetcher, - schemaFetcher, - config.getQueryTimeoutThreshold(), - false); + + // Split statement if needed to limit resource consumption during statement analysis + if (s.shouldSplit()) { + result = + executeBatchStatement( + s, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + partitionFetcher, + schemaFetcher, + config.getQueryTimeoutThreshold(), + false); + } else { + result = + COORDINATOR.executeForTreeModel( + s, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + partitionFetcher, + schemaFetcher, + config.getQueryTimeoutThreshold(), + false); + } } } else { @@ -1875,17 +1930,32 @@ public TSStatus executeBatchStatement(TSExecuteBatchStatementReq req) { queryId = SESSION_MANAGER.requestQueryId(); - result = - COORDINATOR.executeForTableModel( - s, - relationSqlParser, - clientSession, - queryId, - SESSION_MANAGER.getSessionInfo(clientSession), - statement, - metadata, - config.getQueryTimeoutThreshold(), - false); + // Split statement if needed to limit resource consumption during statement analysis + if (s.shouldSplit()) { + result = + executeBatchTableStatement( + s, + relationSqlParser, + clientSession, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + metadata, + config.getQueryTimeoutThreshold(), + false); + } else { + result = + COORDINATOR.executeForTableModel( + s, + relationSqlParser, + clientSession, + queryId, + 
SESSION_MANAGER.getSessionInfo(clientSession), + statement, + metadata, + config.getQueryTimeoutThreshold(), + false); + } } results.add(result.status); @@ -3190,4 +3260,180 @@ public void handleClientExit() { PipeDataNodeAgent.receiver().legacy().handleClientExit(); SubscriptionAgent.receiver().handleClientExit(); } + + /** + * Executes tree-model Statement sub-statements in batch. + * + * @param statement the Statement to be executed + * @param queryId the query ID + * @param sessionInfo the session information + * @param statementStr the SQL statement string + * @param partitionFetcher the partition fetcher + * @param schemaFetcher the schema fetcher + * @param timeoutMs the timeout in milliseconds + * @return the execution result + */ + private ExecutionResult executeBatchStatement( + final Statement statement, + final long queryId, + final SessionInfo sessionInfo, + final String statementStr, + final IPartitionFetcher partitionFetcher, + final ISchemaFetcher schemaFetcher, + final long timeoutMs, + final boolean userQuery) { + + ExecutionResult result = null; + final List subStatements = statement.getSubStatements(); + final int totalSubStatements = subStatements.size(); + + LOGGER.info( + "Start batch executing {} sub-statement(s) in tree model, queryId: {}", + totalSubStatements, + queryId); + + for (int i = 0; i < totalSubStatements; i++) { + final Statement subStatement = subStatements.get(i); + + LOGGER.info( + "Executing sub-statement {}/{} in tree model, queryId: {}", + i + 1, + totalSubStatements, + queryId); + + result = + COORDINATOR.executeForTreeModel( + subStatement, + queryId, + sessionInfo, + statementStr, + partitionFetcher, + schemaFetcher, + timeoutMs, + userQuery); + + // Exit early if any sub-statement execution fails + if (result != null + && result.status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + final int completed = i + 1; + final int remaining = totalSubStatements - completed; + final double percentage = (completed 
* 100.0) / totalSubStatements; + LOGGER.warn( + "Failed to execute sub-statement {}/{} in tree model, queryId: {}, completed: {}, remaining: {}, progress: {}%, error: {}", + i + 1, + totalSubStatements, + queryId, + completed, + remaining, + String.format("%.2f", percentage), + result.status.getMessage()); + break; + } + + LOGGER.info( + "Successfully executed sub-statement {}/{} in tree model, queryId: {}", + i + 1, + totalSubStatements, + queryId); + } + + if (result != null && result.status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + LOGGER.info( + "Completed batch executing all {} sub-statement(s) in tree model, queryId: {}", + totalSubStatements, + queryId); + } + + return result; + } + + /** + * Executes table-model Statement sub-statements in batch. + * + * @param statement the Statement to be executed + * @param relationSqlParser the relational SQL parser + * @param clientSession the client session + * @param queryId the query ID + * @param sessionInfo the session information + * @param statementStr the SQL statement string + * @param metadata the metadata + * @param timeoutMs the timeout in milliseconds + * @return the execution result + */ + private ExecutionResult executeBatchTableStatement( + final org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement statement, + final SqlParser relationSqlParser, + final IClientSession clientSession, + final long queryId, + final SessionInfo sessionInfo, + final String statementStr, + final Metadata metadata, + final long timeoutMs, + final boolean userQuery) { + + ExecutionResult result = null; + List + subStatements = statement.getSubStatements(); + int totalSubStatements = subStatements.size(); + LOGGER.info( + "Start batch executing {} sub-statement(s) in table model, queryId: {}", + totalSubStatements, + queryId); + + for (int i = 0; i < totalSubStatements; i++) { + final org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement subStatement = + subStatements.get(i); + + 
LOGGER.info( + "Executing sub-statement {}/{} in table model, queryId: {}", + i + 1, + totalSubStatements, + queryId); + + result = + COORDINATOR.executeForTableModel( + subStatement, + relationSqlParser, + clientSession, + queryId, + sessionInfo, + statementStr, + metadata, + timeoutMs, + userQuery); + + // Exit early if any sub-statement execution fails + if (result != null + && result.status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + final int completed = i + 1; + final int remaining = totalSubStatements - completed; + final double percentage = (completed * 100.0) / totalSubStatements; + LOGGER.warn( + "Failed to execute sub-statement {}/{} in table model, queryId: {}, completed: {}, remaining: {}, progress: {}%, error: {}", + i + 1, + totalSubStatements, + queryId, + completed, + remaining, + String.format("%.2f", percentage), + result.status.getMessage()); + break; + } + + LOGGER.info( + "Successfully executed sub-statement {}/{} in table model, queryId: {}", + i + 1, + totalSubStatements, + queryId); + } + + if (result != null && result.status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + LOGGER.info( + "Completed batch executing all {} sub-statement(s) in table model, queryId: {}", + totalSubStatements, + queryId); + } + + return result; + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/QueryId.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/QueryId.java index a59ce8334bf51..44e67aa7ba608 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/QueryId.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/QueryId.java @@ -19,6 +19,7 @@ package org.apache.iotdb.db.queryengine.common; +import org.apache.iotdb.db.conf.IoTDBDescriptor; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNodeId; import org.apache.tsfile.utils.ReadWriteIOUtils; @@ -37,6 +38,8 @@ public class QueryId { 
public static final QueryId MOCK_QUERY_ID = QueryId.valueOf("mock_query_id"); + private static final int DATANODE_ID = IoTDBDescriptor.getInstance().getConfig().getDataNodeId(); + private final String id; private int nextPlanNodeIndex; @@ -67,6 +70,10 @@ public String getId() { return id; } + public static int getDataNodeId() { + return DATANODE_ID; + } + @Override public String toString() { return id; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java index 7126af78b8b51..29e5580311d0b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java @@ -19,18 +19,15 @@ package org.apache.iotdb.db.queryengine.execution.operator.process.ai; +import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq; import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; -import org.apache.iotdb.ainode.rpc.thrift.TWindowParams; import org.apache.iotdb.db.exception.runtime.ModelInferenceProcessException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper; import org.apache.iotdb.db.queryengine.execution.operator.Operator; import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext; import org.apache.iotdb.db.queryengine.execution.operator.process.ProcessOperator; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.BottomInferenceWindowParameter; -import 
org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowType; import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.rpc.TSStatusCode; @@ -75,7 +72,6 @@ public class InferenceOperator implements ProcessOperator { private int resultIndex = 0; private List results; private final TsBlockSerde serde = new TsBlockSerde(); - private InferenceWindowType windowType = null; private final boolean generateTimeColumn; private long maxTimestamp; @@ -109,10 +105,6 @@ public InferenceOperator( this.maxReturnSize = maxReturnSize; this.totalRow = 0; - if (modelInferenceDescriptor.getInferenceWindowParameter() != null) { - windowType = modelInferenceDescriptor.getInferenceWindowParameter().getWindowType(); - } - if (generateTimeColumn) { this.interval = 0; this.minTimestamp = Long.MAX_VALUE; @@ -237,62 +229,6 @@ private void appendTsBlockToBuilder(TsBlock inputTsBlock) { } } - private TWindowParams getWindowParams() { - TWindowParams windowParams; - if (windowType == null) { - return null; - } - if (windowType == InferenceWindowType.COUNT) { - CountInferenceWindowParameter countInferenceWindowParameter = - (CountInferenceWindowParameter) modelInferenceDescriptor.getInferenceWindowParameter(); - windowParams = new TWindowParams(); - windowParams.setWindowInterval((int) countInferenceWindowParameter.getInterval()); - windowParams.setWindowStep((int) countInferenceWindowParameter.getStep()); - } else { - windowParams = null; - } - return windowParams; - } - - private TsBlock preProcess(TsBlock inputTsBlock) { - // boolean notBuiltIn = !modelInferenceDescriptor.getModelInformation().isBuiltIn(); - boolean notBuiltIn = false; - if (windowType == null || windowType == InferenceWindowType.HEAD) { - if (notBuiltIn - && totalRow != modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { 
- throw new ModelInferenceProcessException( - String.format( - "The number of rows %s in the input data does not match the model input %s. Try to use LIMIT in SQL or WINDOW in CALL INFERENCE", - totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); - } - return inputTsBlock; - } else if (windowType == InferenceWindowType.COUNT) { - if (notBuiltIn - && totalRow < modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { - throw new ModelInferenceProcessException( - String.format( - "The number of rows %s in the input data is less than the model input %s. ", - totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); - } - } else if (windowType == InferenceWindowType.TAIL) { - if (notBuiltIn - && totalRow < modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { - throw new ModelInferenceProcessException( - String.format( - "The number of rows %s in the input data is less than the model input %s. ", - totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); - } - // Tail window logic: get the latest data for inference - long windowSize = - (int) - ((BottomInferenceWindowParameter) - modelInferenceDescriptor.getInferenceWindowParameter()) - .getWindowSize(); - return inputTsBlock.subTsBlock((int) (totalRow - windowSize)); - } - return inputTsBlock; - } - private void submitInferenceTask() { if (generateTimeColumn) { @@ -301,20 +237,16 @@ private void submitInferenceTask() { TsBlock inputTsBlock = inputTsBlockBuilder.build(); - TsBlock finalInputTsBlock = preProcess(inputTsBlock); - TWindowParams windowParams = getWindowParams(); - inferenceExecutionFuture = Futures.submit( () -> { try (AINodeClient client = AINodeClientManager.getInstance() - .borrowClient(modelInferenceDescriptor.getTargetAINode())) { + .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { return client.inference( - modelInferenceDescriptor.getModelName(), - finalInputTsBlock, - 
modelInferenceDescriptor.getInferenceAttributes(), - windowParams); + new TInferenceReq( + modelInferenceDescriptor.getModelId(), serde.serialize(inputTsBlock)) + .setInferenceAttributes(modelInferenceDescriptor.getInferenceAttributes())); } catch (Exception e) { throw new ModelInferenceProcessException(e.getMessage()); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java index fc68881656595..76d84c741de5f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java @@ -19,14 +19,11 @@ package org.apache.iotdb.db.queryengine.execution.operator.source.relational; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; import org.apache.iotdb.common.rpc.thrift.Model; import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TConsensusGroupType; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.commons.audit.UserEntity; import org.apache.iotdb.commons.client.exception.ClientManagerException; import org.apache.iotdb.commons.conf.IoTDBConstant; @@ -68,11 +65,10 @@ import org.apache.iotdb.db.protocol.client.ConfigNodeClient; import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; 
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; import org.apache.iotdb.db.protocol.session.IClientSession; import org.apache.iotdb.db.protocol.session.SessionManager; import org.apache.iotdb.db.queryengine.common.ConnectionInfo; +import org.apache.iotdb.db.queryengine.common.QueryId; import org.apache.iotdb.db.queryengine.plan.Coordinator; import org.apache.iotdb.db.queryengine.plan.execution.IQueryExecution; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.relational.ShowCreateViewTask; @@ -157,8 +153,6 @@ public static Iterator getSupplier( return new SubscriptionSupplier(dataTypes, userEntity); case InformationSchema.VIEWS: return new ViewsSupplier(dataTypes, userEntity); - case InformationSchema.MODELS: - return new ModelsSupplier(dataTypes); case InformationSchema.FUNCTIONS: return new FunctionsSupplier(dataTypes); case InformationSchema.CONFIGURATIONS: @@ -173,6 +167,10 @@ public static Iterator getSupplier( return new DataNodesSupplier(dataTypes, userEntity); case InformationSchema.CONNECTIONS: return new ConnectionsSupplier(dataTypes, userEntity); + case InformationSchema.CURRENT_QUERIES: + return new CurrentQueriesSupplier(dataTypes, userEntity); + case InformationSchema.QUERIES_COSTS_HISTOGRAM: + return new QueriesCostsHistogramSupplier(dataTypes, userEntity); default: throw new UnsupportedOperationException("Unknown table: " + tableName); } @@ -208,14 +206,11 @@ protected void constructLine() { final IQueryExecution queryExecution = queryExecutions.get(nextConsumedIndex); if (queryExecution.getSQLDialect().equals(IClientSession.SqlDialect.TABLE)) { - final String[] splits = queryExecution.getQueryId().split("_"); - final int dataNodeId = Integer.parseInt(splits[splits.length - 1]); - columnBuilders[0].writeBinary(BytesUtils.valueOf(queryExecution.getQueryId())); columnBuilders[1].writeLong( TimestampPrecisionUtils.convertToCurrPrecision( queryExecution.getStartExecutionTime(), TimeUnit.MILLISECONDS)); - 
columnBuilders[2].writeInt(dataNodeId); + columnBuilders[2].writeInt(QueryId.getDataNodeId()); columnBuilders[3].writeFloat( (float) (currTime - queryExecution.getStartExecutionTime()) / 1000); columnBuilders[4].writeBinary( @@ -798,112 +793,6 @@ public boolean hasNext() { } } - private static class ModelsSupplier extends TsBlockSupplier { - private final ModelIterator iterator; - - private ModelsSupplier(final List dataTypes) throws Exception { - super(dataTypes); - final TEndPoint ep = AINodeClient.getCurrentEndpoint(); - try (final AINodeClient ai = AINodeClientManager.getInstance().borrowClient(ep)) { - iterator = new ModelIterator(ai.showModels(new TShowModelsReq())); - } - } - - private static class ModelIterator implements Iterator { - - private int index = 0; - private final TShowModelsResp resp; - - private ModelIterator(TShowModelsResp resp) { - this.resp = resp; - } - - @Override - public boolean hasNext() { - return index < resp.getModelIdListSize(); - } - - @Override - public ModelInfoInString next() { - String modelId = resp.getModelIdList().get(index++); - return new ModelInfoInString( - modelId, - resp.getModelTypeMap().get(modelId), - resp.getCategoryMap().get(modelId), - resp.getStateMap().get(modelId)); - } - } - - private static class ModelInfoInString { - - private final String modelId; - private final String modelType; - private final String category; - private final String state; - - public ModelInfoInString(String modelId, String modelType, String category, String state) { - this.modelId = modelId; - this.modelType = modelType; - this.category = category; - this.state = state; - } - - public String getModelId() { - return modelId; - } - - public String getModelType() { - return modelType; - } - - public String getCategory() { - return category; - } - - public String getState() { - return state; - } - } - - @Override - protected void constructLine() { - final ModelInfoInString modelInfo = iterator.next(); - columnBuilders[0].writeBinary( - 
new Binary(modelInfo.getModelId(), TSFileConfig.STRING_CHARSET)); - columnBuilders[1].writeBinary( - new Binary(modelInfo.getModelType(), TSFileConfig.STRING_CHARSET)); - columnBuilders[2].writeBinary( - new Binary(modelInfo.getCategory(), TSFileConfig.STRING_CHARSET)); - columnBuilders[3].writeBinary(new Binary(modelInfo.getState(), TSFileConfig.STRING_CHARSET)); - // if (Objects.equals(modelType, ModelType.USER_DEFINED.toString())) { - // columnBuilders[3].writeBinary( - // new Binary( - // INPUT_SHAPE - // + ReadWriteIOUtils.readString(modelInfo) - // + OUTPUT_SHAPE - // + ReadWriteIOUtils.readString(modelInfo) - // + INPUT_DATA_TYPE - // + ReadWriteIOUtils.readString(modelInfo) - // + OUTPUT_DATA_TYPE - // + ReadWriteIOUtils.readString(modelInfo), - // TSFileConfig.STRING_CHARSET)); - // columnBuilders[4].writeBinary( - // new Binary(ReadWriteIOUtils.readString(modelInfo), - // TSFileConfig.STRING_CHARSET)); - // } else { - // columnBuilders[3].appendNull(); - // columnBuilders[4].writeBinary( - // new Binary("Built-in model in IoTDB", TSFileConfig.STRING_CHARSET)); - // } - resultBuilder.declarePosition(); - } - - @Override - public boolean hasNext() { - return iterator.hasNext(); - } - } - private static class FunctionsSupplier extends TsBlockSupplier { private final Iterator udfIterator; @@ -1294,4 +1183,140 @@ public boolean hasNext() { return sessionConnectionIterator.hasNext(); } } + + private static class CurrentQueriesSupplier extends TsBlockSupplier { + private int nextConsumedIndex; + private List queriesInfo; + + private CurrentQueriesSupplier(final List dataTypes, final UserEntity userEntity) { + super(dataTypes); + queriesInfo = Coordinator.getInstance().getCurrentQueriesInfo(); + try { + accessControl.checkUserGlobalSysPrivilege(userEntity); + } catch (final AccessDeniedException e) { + queriesInfo = + queriesInfo.stream() + .filter(iQueryInfo -> userEntity.getUsername().equals(iQueryInfo.getUser())) + .collect(Collectors.toList()); + } + } + + 
@Override + protected void constructLine() { + final Coordinator.StatedQueriesInfo queryInfo = queriesInfo.get(nextConsumedIndex); + columnBuilders[0].writeBinary(BytesUtils.valueOf(queryInfo.getQueryId())); + columnBuilders[1].writeBinary(BytesUtils.valueOf(queryInfo.getQueryState())); + columnBuilders[2].writeLong( + TimestampPrecisionUtils.convertToCurrPrecision( + queryInfo.getStartTime(), TimeUnit.MILLISECONDS)); + if (queryInfo.getEndTime() == Coordinator.QueryInfo.DEFAULT_END_TIME) { + columnBuilders[3].appendNull(); + } else { + columnBuilders[3].writeLong( + TimestampPrecisionUtils.convertToCurrPrecision( + queryInfo.getEndTime(), TimeUnit.MILLISECONDS)); + } + columnBuilders[4].writeInt(QueryId.getDataNodeId()); + columnBuilders[5].writeFloat(queryInfo.getCostTime()); + columnBuilders[6].writeBinary(BytesUtils.valueOf(queryInfo.getStatement())); + columnBuilders[7].writeBinary(BytesUtils.valueOf(queryInfo.getUser())); + columnBuilders[8].writeBinary(BytesUtils.valueOf(queryInfo.getClientHost())); + resultBuilder.declarePosition(); + nextConsumedIndex++; + } + + @Override + public boolean hasNext() { + return nextConsumedIndex < queriesInfo.size(); + } + } + + private static class QueriesCostsHistogramSupplier extends TsBlockSupplier { + private int nextConsumedIndex; + private static final Binary[] BUCKETS = + new Binary[] { + BytesUtils.valueOf("[0,1)"), + BytesUtils.valueOf("[1,2)"), + BytesUtils.valueOf("[2,3)"), + BytesUtils.valueOf("[3,4)"), + BytesUtils.valueOf("[4,5)"), + BytesUtils.valueOf("[5,6)"), + BytesUtils.valueOf("[6,7)"), + BytesUtils.valueOf("[7,8)"), + BytesUtils.valueOf("[8,9)"), + BytesUtils.valueOf("[9,10)"), + BytesUtils.valueOf("[10,11)"), + BytesUtils.valueOf("[11,12)"), + BytesUtils.valueOf("[12,13)"), + BytesUtils.valueOf("[13,14)"), + BytesUtils.valueOf("[14,15)"), + BytesUtils.valueOf("[15,16)"), + BytesUtils.valueOf("[16,17)"), + BytesUtils.valueOf("[17,18)"), + BytesUtils.valueOf("[18,19)"), + BytesUtils.valueOf("[19,20)"), + 
BytesUtils.valueOf("[20,21)"), + BytesUtils.valueOf("[21,22)"), + BytesUtils.valueOf("[22,23)"), + BytesUtils.valueOf("[23,24)"), + BytesUtils.valueOf("[24,25)"), + BytesUtils.valueOf("[25,26)"), + BytesUtils.valueOf("[26,27)"), + BytesUtils.valueOf("[27,28)"), + BytesUtils.valueOf("[28,29)"), + BytesUtils.valueOf("[29,30)"), + BytesUtils.valueOf("[30,31)"), + BytesUtils.valueOf("[31,32)"), + BytesUtils.valueOf("[32,33)"), + BytesUtils.valueOf("[33,34)"), + BytesUtils.valueOf("[34,35)"), + BytesUtils.valueOf("[35,36)"), + BytesUtils.valueOf("[36,37)"), + BytesUtils.valueOf("[37,38)"), + BytesUtils.valueOf("[38,39)"), + BytesUtils.valueOf("[39,40)"), + BytesUtils.valueOf("[40,41)"), + BytesUtils.valueOf("[41,42)"), + BytesUtils.valueOf("[42,43)"), + BytesUtils.valueOf("[43,44)"), + BytesUtils.valueOf("[44,45)"), + BytesUtils.valueOf("[45,46)"), + BytesUtils.valueOf("[46,47)"), + BytesUtils.valueOf("[47,48)"), + BytesUtils.valueOf("[48,49)"), + BytesUtils.valueOf("[49,50)"), + BytesUtils.valueOf("[50,51)"), + BytesUtils.valueOf("[51,52)"), + BytesUtils.valueOf("[52,53)"), + BytesUtils.valueOf("[53,54)"), + BytesUtils.valueOf("[54,55)"), + BytesUtils.valueOf("[55,56)"), + BytesUtils.valueOf("[56,57)"), + BytesUtils.valueOf("[57,58)"), + BytesUtils.valueOf("[58,59)"), + BytesUtils.valueOf("[59,60)"), + BytesUtils.valueOf("60+") + }; + private final int[] currentQueriesCostHistogram; + + private QueriesCostsHistogramSupplier( + final List dataTypes, final UserEntity userEntity) { + super(dataTypes); + accessControl.checkUserGlobalSysPrivilege(userEntity); + currentQueriesCostHistogram = Coordinator.getInstance().getCurrentQueriesCostHistogram(); + } + + @Override + protected void constructLine() { + columnBuilders[0].writeBinary(BUCKETS[nextConsumedIndex]); + columnBuilders[1].writeInt(currentQueriesCostHistogram[nextConsumedIndex]); + resultBuilder.declarePosition(); + nextConsumedIndex++; + } + + @Override + public boolean hasNext() { + return nextConsumedIndex < 61; 
+ } + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java index f10600cbda949..78cdee720da22 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java @@ -26,18 +26,24 @@ import org.apache.iotdb.commons.client.sync.SyncDataNodeInternalServiceClient; import org.apache.iotdb.commons.concurrent.IoTDBThreadPoolFactory; import org.apache.iotdb.commons.concurrent.ThreadName; +import org.apache.iotdb.commons.concurrent.threadpool.ScheduledExecutorUtil; import org.apache.iotdb.commons.conf.CommonConfig; import org.apache.iotdb.commons.conf.CommonDescriptor; import org.apache.iotdb.commons.conf.IoTDBConstant; +import org.apache.iotdb.commons.memory.IMemoryBlock; +import org.apache.iotdb.commons.memory.MemoryBlockType; import org.apache.iotdb.db.auth.AuthorityChecker; import org.apache.iotdb.db.conf.IoTDBConfig; import org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.protocol.session.IClientSession; +import org.apache.iotdb.db.protocol.session.PreparedStatementInfo; import org.apache.iotdb.db.queryengine.common.DataNodeEndPoints; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.QueryId; import org.apache.iotdb.db.queryengine.common.SessionInfo; import org.apache.iotdb.db.queryengine.execution.QueryIdGenerator; +import org.apache.iotdb.db.queryengine.execution.QueryState; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.lock.DataNodeSchemaLockManager; import org.apache.iotdb.db.queryengine.plan.analyze.schema.ISchemaFetcher; @@ -49,6 +55,7 @@ import 
org.apache.iotdb.db.queryengine.plan.execution.config.TreeConfigTaskVisitor; import org.apache.iotdb.db.queryengine.plan.planner.LocalExecutionPlanner; import org.apache.iotdb.db.queryengine.plan.planner.TreeModelPlanner; +import org.apache.iotdb.db.queryengine.plan.relational.analyzer.NodeRef; import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata; import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext; import org.apache.iotdb.db.queryengine.plan.relational.planner.TableModelPlanner; @@ -56,6 +63,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.DistributedOptimizeFactory; import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.LogicalOptimizeFactory; import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanOptimizer; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ParameterExtractor; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.AddColumn; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.AlterDB; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ClearCache; @@ -64,6 +72,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateModel; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateTable; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateTraining; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Deallocate; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DeleteDevice; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DescribeTable; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DropColumn; @@ -71,13 +80,19 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DropFunction; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DropModel; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DropTable; +import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Execute; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ExecuteImmediate; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ExtendRegion; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Flush; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.KillQuery; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LoadConfiguration; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LoadModel; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.MigrateRegion; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Parameter; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PipeStatement; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Prepare; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ReconstructRegion; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.RelationalAuthorStatement; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.RemoveAINode; @@ -128,21 +143,37 @@ import org.apache.iotdb.db.utils.SetThreadName; import org.apache.thrift.TBase; +import org.apache.tsfile.utils.Accountable; +import org.apache.tsfile.utils.RamUsageEstimator; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.BlockingDeque; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadPoolExecutor; +import 
java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; import java.util.function.LongSupplier; import java.util.function.Supplier; +import java.util.stream.Collectors; import static org.apache.iotdb.commons.utils.StatusUtils.needRetry; +import static org.apache.iotdb.db.queryengine.plan.Coordinator.QueryInfo.DEFAULT_END_TIME; import static org.apache.iotdb.db.utils.CommonUtils.getContentOfRequest; +import static org.apache.tsfile.utils.RamUsageEstimator.shallowSizeOfInstance; +import static org.apache.tsfile.utils.RamUsageEstimator.sizeOfCharArray; /** * The coordinator for MPP. It manages all the queries which are executed in current Node. And it @@ -184,14 +215,50 @@ public class Coordinator { private static final Coordinator INSTANCE = new Coordinator(); + private static final IMemoryBlock coordinatorMemoryBlock; + private final ConcurrentHashMap queryExecutionMap; + private final BlockingDeque currentQueriesInfo = new LinkedBlockingDeque<>(); + private final AtomicInteger[] currentQueriesCostHistogram = new AtomicInteger[61]; + private final ScheduledExecutorService retryFailTasksExecutor = + IoTDBThreadPoolFactory.newSingleThreadScheduledExecutor( + ThreadName.EXPIRED_QUERIES_INFO_CLEAR.getName()); + private final StatementRewrite statementRewrite; private final List logicalPlanOptimizers; private final List distributionPlanOptimizers; private final DataNodeLocationSupplierFactory.DataNodeLocationSupplier dataNodeLocationSupplier; private final TypeManager typeManager; + { + for (int i = 0; i < 61; i++) { + currentQueriesCostHistogram[i] = new AtomicInteger(); + } + + ScheduledExecutorUtil.safelyScheduleWithFixedDelay( + retryFailTasksExecutor, + this::clearExpiredQueriesInfoTask, + 1_000L, + 1_000L, + TimeUnit.MILLISECONDS); + LOGGER.info("Expired-Queries-Info-Clear thread is successfully started."); + } + + static { + coordinatorMemoryBlock = + IoTDBDescriptor.getInstance() + .getMemoryConfig() 
+ .getCoordinatorMemoryManager() + .exactAllocate("Coordinator", MemoryBlockType.DYNAMIC); + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "Initialized shared MemoryBlock 'Coordinator' with all available memory: {} bytes", + coordinatorMemoryBlock.getTotalMemorySizeInBytes()); + } + } + private Coordinator() { this.queryExecutionMap = new ConcurrentHashMap<>(); this.typeManager = new InternalTypeManager(); @@ -401,6 +468,8 @@ private IQueryExecution createQueryExecutionForTableModel( distributionPlanOptimizers, AuthorityChecker.getAccessControl(), dataNodeLocationSupplier, + Collections.emptyList(), + Collections.emptyMap(), typeManager); return new QueryExecution(tableModelPlanner, queryContext, executor); } @@ -475,7 +544,9 @@ private IQueryExecution createQueryExecutionForTableModel( || statement instanceof LoadModel || statement instanceof UnloadModel || statement instanceof ShowLoadedModels - || statement instanceof RemoveRegion) { + || statement instanceof RemoveRegion + || statement instanceof Prepare + || statement instanceof Deallocate) { return new ConfigExecution( queryContext, null, @@ -485,12 +556,54 @@ private IQueryExecution createQueryExecutionForTableModel( clientSession, metadata, AuthorityChecker.getAccessControl(), typeManager), queryContext)); } + // Initialize variables for TableModelPlanner + org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement statementToUse = statement; + List parameters = Collections.emptyList(); + Map, Expression> parameterLookup = Collections.emptyMap(); + + if (statement instanceof Execute) { + Execute executeStatement = (Execute) statement; + String statementName = executeStatement.getStatementName().getValue(); + + // Get prepared statement from session (contains cached AST) + PreparedStatementInfo preparedInfo = clientSession.getPreparedStatement(statementName); + if (preparedInfo == null) { + throw new SemanticException( + String.format("Prepared statement '%s' does not exist", statementName)); + } 
+ + // Use cached AST + statementToUse = preparedInfo.getSql(); + + // Bind parameters: create parameterLookup map + // Note: bindParameters() internally validates parameter count + parameterLookup = + ParameterExtractor.bindParameters(statementToUse, executeStatement.getParameters()); + parameters = new ArrayList<>(executeStatement.getParameters()); + + } else if (statement instanceof ExecuteImmediate) { + ExecuteImmediate executeImmediateStatement = (ExecuteImmediate) statement; + + // EXECUTE IMMEDIATE needs to parse SQL first + String sql = executeImmediateStatement.getSqlString(); + List literalParameters = executeImmediateStatement.getParameters(); + + statementToUse = sqlParser.createStatement(sql, clientSession.getZoneId(), clientSession); + + if (!literalParameters.isEmpty()) { + parameterLookup = ParameterExtractor.bindParameters(statementToUse, literalParameters); + parameters = new ArrayList<>(literalParameters); + } + } + if (statement instanceof WrappedInsertStatement) { ((WrappedInsertStatement) statement).setContext(queryContext); } - final TableModelPlanner tableModelPlanner = + + // Create QueryExecution with TableModelPlanner + TableModelPlanner tableModelPlanner = new TableModelPlanner( - statement, + statementToUse, sqlParser, metadata, scheduledExecutor, @@ -501,6 +614,8 @@ private IQueryExecution createQueryExecutionForTableModel( distributionPlanOptimizers, AuthorityChecker.getAccessControl(), dataNodeLocationSupplier, + parameters, + parameterLookup, typeManager); return new QueryExecution(tableModelPlanner, queryContext, executor); } @@ -546,12 +661,22 @@ public void cleanupQueryExecution( try (SetThreadName threadName = new SetThreadName(queryExecution.getQueryId())) { LOGGER.debug("[CleanUpQuery]]"); queryExecution.stopAndCleanup(t); + boolean isUserQuery = queryExecution.isQuery() && queryExecution.isUserQuery(); + Supplier contentOfQuerySupplier = + new ContentOfQuerySupplier(nativeApiRequest, queryExecution); + if (isUserQuery) { + 
recordCurrentQueries( + queryExecution.getQueryId(), + queryExecution.getStartExecutionTime(), + System.currentTimeMillis(), + queryExecution.getTotalExecutionTime(), + contentOfQuerySupplier, + queryExecution.getUser(), + queryExecution.getClientHostname()); + } queryExecutionMap.remove(queryId); - if (queryExecution.isQuery() && queryExecution.isUserQuery()) { - recordQueries( - queryExecution::getTotalExecutionTime, - new ContentOfQuerySupplier(nativeApiRequest, queryExecution), - t); + if (isUserQuery) { + recordQueries(queryExecution::getTotalExecutionTime, contentOfQuerySupplier, t); } } } @@ -609,6 +734,10 @@ public static Coordinator getInstance() { return INSTANCE; } + public static IMemoryBlock getCoordinatorMemoryBlock() { + return coordinatorMemoryBlock; + } + public void recordExecutionTime(long queryId, long executionTime) { IQueryExecution queryExecution = getQueryExecution(queryId); if (queryExecution != null) { @@ -639,4 +768,249 @@ public DataNodeLocationSupplierFactory.DataNodeLocationSupplier getDataNodeLocat public ExecutorService getDispatchExecutor() { return dispatchExecutor; } + + /** record query info in memory data structure */ + public void recordCurrentQueries( + String queryId, + long startTime, + long endTime, + long costTimeInNs, + Supplier contentOfQuerySupplier, + String user, + String clientHost) { + if (CONFIG.getQueryCostStatWindow() <= 0) { + return; + } + + if (queryId == null) { + // fast Last query API executeFastLastDataQueryForOnePrefixPath will enter this + queryId = queryIdGenerator.createNextQueryId().getId(); + } + + // ns -> s + float costTimeInSeconds = costTimeInNs * 1e-9f; + + QueryInfo queryInfo = + new QueryInfo( + queryId, + startTime, + endTime, + costTimeInSeconds, + contentOfQuerySupplier.get(), + user, + clientHost); + + while (!coordinatorMemoryBlock.allocate(RamUsageEstimator.sizeOfObject(queryInfo))) { + // try to release memory from the head of queue + QueryInfo queryInfoToRelease = 
currentQueriesInfo.poll(); + if (queryInfoToRelease == null) { + // no element in the queue and the memory is still not enough, skip this record + return; + } else { + // release memory and unrecord in histogram + coordinatorMemoryBlock.release(RamUsageEstimator.sizeOfObject(queryInfoToRelease)); + unrecordInHistogram(queryInfoToRelease.costTime); + } + } + + currentQueriesInfo.addLast(queryInfo); + recordInHistogram(costTimeInSeconds); + } + + private void recordInHistogram(float costTimeInSeconds) { + int bucket = (int) costTimeInSeconds; + if (bucket < 60) { + currentQueriesCostHistogram[bucket].getAndIncrement(); + } else { + currentQueriesCostHistogram[60].getAndIncrement(); + } + } + + private void unrecordInHistogram(float costTimeInSeconds) { + int bucket = (int) costTimeInSeconds; + if (bucket < 60) { + currentQueriesCostHistogram[bucket].getAndDecrement(); + } else { + currentQueriesCostHistogram[60].getAndDecrement(); + } + } + + private void clearExpiredQueriesInfoTask() { + int queryCostStatWindow = CONFIG.getQueryCostStatWindow(); + if (queryCostStatWindow <= 0) { + return; + } + + // the QueryInfo smaller than expired time will be cleared + long expiredTime = System.currentTimeMillis() - queryCostStatWindow * 60 * 1_000L; + // peek head, the head QueryInfo is in the time window, return directly + QueryInfo queryInfo = currentQueriesInfo.peekFirst(); + if (queryInfo == null || queryInfo.endTime >= expiredTime) { + return; + } + + queryInfo = currentQueriesInfo.poll(); + while (queryInfo != null) { + if (queryInfo.endTime < expiredTime) { + // out of time window, clear queryInfo + coordinatorMemoryBlock.release(RamUsageEstimator.sizeOfObject(queryInfo)); + unrecordInHistogram(queryInfo.costTime); + queryInfo = currentQueriesInfo.poll(); + } else { + // the head of the queue is not expired, add back + currentQueriesInfo.addFirst(queryInfo); + // there is no more candidate to clear + return; + } + } + } + + public List getCurrentQueriesInfo() { + List 
runningQueries = getAllQueryExecutions(); + Set runningQueryIdSet = + runningQueries.stream().map(IQueryExecution::getQueryId).collect(Collectors.toSet()); + List result = new ArrayList<>(); + + // add History queries (satisfy the time window) info + Iterator historyQueriesIterator = currentQueriesInfo.iterator(); + Set repetitionQueryIdSet = new HashSet<>(); + long currentTime = System.currentTimeMillis(); + long needRecordTime = currentTime - CONFIG.getQueryCostStatWindow() * 60 * 1_000L; + while (historyQueriesIterator.hasNext()) { + QueryInfo queryInfo = historyQueriesIterator.next(); + if (queryInfo.endTime < needRecordTime) { + // out of time window, ignore it + } else { + if (runningQueryIdSet.contains(queryInfo.queryId)) { + repetitionQueryIdSet.add(queryInfo.queryId); + } + result.add(new StatedQueriesInfo(QueryState.FINISHED, queryInfo)); + } + } + + // add Running queries info after remove the repetitions which has recorded in History queries + result.addAll( + runningQueries.stream() + .filter(queryExecution -> !repetitionQueryIdSet.contains(queryExecution.getQueryId())) + .map( + queryExecution -> + new StatedQueriesInfo( + QueryState.RUNNING, + queryExecution.getQueryId(), + queryExecution.getStartExecutionTime(), + DEFAULT_END_TIME, + (currentTime - queryExecution.getStartExecutionTime()) / 1000, + queryExecution.getExecuteSQL().orElse("UNKNOWN"), + queryExecution.getUser(), + queryExecution.getClientHostname())) + .collect(Collectors.toList())); + return result; + } + + public int[] getCurrentQueriesCostHistogram() { + return Arrays.stream(currentQueriesCostHistogram).mapToInt(AtomicInteger::get).toArray(); + } + + public static class QueryInfo implements Accountable { + public static final long DEFAULT_END_TIME = -1L; + private static final long INSTANCE_SIZE = shallowSizeOfInstance(QueryInfo.class); + + private final String queryId; + + // unit: millisecond + private final long startTime; + private final long endTime; + // unit: second + private 
final float costTime; + + private final String statement; + private final String user; + private final String clientHost; + + public QueryInfo( + String queryId, + long startTime, + long endTime, + float costTime, + String statement, + String user, + String clientHost) { + this.queryId = queryId; + this.startTime = startTime; + this.endTime = endTime; + this.costTime = costTime; + this.statement = statement; + this.user = user; + this.clientHost = clientHost; + } + + public String getClientHost() { + return clientHost; + } + + public String getUser() { + return user; + } + + public long getStartTime() { + return startTime; + } + + public long getEndTime() { + return endTime; + } + + public float getCostTime() { + return costTime; + } + + public String getQueryId() { + return queryId; + } + + public String getStatement() { + return statement; + } + + @Override + public long ramBytesUsed() { + return INSTANCE_SIZE + + sizeOfCharArray(statement.length()) + + sizeOfCharArray(user.length()) + + sizeOfCharArray(clientHost.length()); + } + } + + public static class StatedQueriesInfo extends QueryInfo { + private final QueryState queryState; + + private StatedQueriesInfo(QueryState queryState, QueryInfo queryInfo) { + super( + queryInfo.queryId, + queryInfo.startTime, + queryInfo.endTime, + queryInfo.costTime, + queryInfo.statement, + queryInfo.user, + queryInfo.clientHost); + this.queryState = queryState; + } + + private StatedQueriesInfo( + QueryState queryState, + String queryId, + long startTime, + long endTime, + long costTime, + String statement, + String user, + String clientHost) { + super(queryId, startTime, endTime, costTime, statement, user, clientHost); + this.queryState = queryState; + } + + public String getQueryState() { + return queryState.name(); + } + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java index 34a289b76c9da..dc56fe118b7b3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java @@ -25,7 +25,6 @@ import org.apache.iotdb.commons.conf.IoTDBConstant; import org.apache.iotdb.commons.exception.IllegalPathException; import org.apache.iotdb.commons.exception.MetadataException; -import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.commons.partition.DataPartition; import org.apache.iotdb.commons.partition.DataPartitionQueryParam; import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition; @@ -55,14 +54,6 @@ import org.apache.iotdb.db.queryengine.common.schematree.IMeasurementSchemaInfo; import org.apache.iotdb.db.queryengine.common.schematree.ISchemaTree; import org.apache.iotdb.db.queryengine.execution.operator.window.WindowType; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.BottomInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindow; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.HeadInferenceWindow; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindow; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowType; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.TailInferenceWindow; import org.apache.iotdb.db.queryengine.metric.QueryPlanCostMetricSet; import org.apache.iotdb.db.queryengine.plan.analyze.load.LoadTsFileAnalyzer; import 
org.apache.iotdb.db.queryengine.plan.analyze.lock.DataNodeSchemaLockManager; @@ -425,46 +416,14 @@ private void analyzeModelInference(Analysis analysis, QueryStatement queryStatem return; } - // Get model metadata from configNode and do some check + // Get model metadata from AINode String modelId = queryStatement.getModelId(); TSStatus status = modelFetcher.fetchModel(modelId, analysis); if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { throw new GetModelInfoException(status.getMessage()); } - // set inference window if there is - if (queryStatement.isSetInferenceWindow()) { - InferenceWindow window = queryStatement.getInferenceWindow(); - if (InferenceWindowType.HEAD == window.getType()) { - long windowSize = ((HeadInferenceWindow) window).getWindowSize(); - // checkWindowSize(windowSize, modelInformation); - if (queryStatement.hasLimit() && queryStatement.getRowLimit() < windowSize) { - throw new SemanticException( - "Limit in Sql should be larger than window size in inference"); - } - // optimize head window by limitNode - queryStatement.setRowLimit(windowSize); - } else if (InferenceWindowType.TAIL == window.getType()) { - long windowSize = ((TailInferenceWindow) window).getWindowSize(); - // checkWindowSize(windowSize, modelInformation); - InferenceWindowParameter inferenceWindowParameter = - new BottomInferenceWindowParameter(windowSize); - analysis - .getModelInferenceDescriptor() - .setInferenceWindowParameter(inferenceWindowParameter); - } else if (InferenceWindowType.COUNT == window.getType()) { - CountInferenceWindow countInferenceWindow = (CountInferenceWindow) window; - // checkWindowSize(countInferenceWindow.getInterval(), modelInformation); - InferenceWindowParameter inferenceWindowParameter = - new CountInferenceWindowParameter( - countInferenceWindow.getInterval(), countInferenceWindow.getStep()); - analysis - .getModelInferenceDescriptor() - .setInferenceWindowParameter(inferenceWindowParameter); - } - } - - // set inference 
attributes if there is + // Set inference attributes if there is if (queryStatement.hasInferenceAttributes()) { analysis .getModelInferenceDescriptor() @@ -472,12 +431,6 @@ private void analyzeModelInference(Analysis analysis, QueryStatement queryStatem } } - private void checkWindowSize(long windowSize, ModelInformation modelInformation) { - if (modelInformation.isBuiltIn()) { - return; - } - } - private ISchemaTree analyzeSchema( QueryStatement queryStatement, Analysis analysis, @@ -1717,22 +1670,11 @@ static void analyzeOutput( } if (queryStatement.hasModelInference()) { - ModelInformation modelInformation = analysis.getModelInformation(); // check input - checkInputShape(modelInformation, outputExpressions); - checkInputType(analysis, modelInformation, outputExpressions); - + checkInputType(analysis, outputExpressions); // set output List columnHeaders = new ArrayList<>(); - int[] outputShape = modelInformation.getOutputShape(); - TSDataType[] outputDataType = modelInformation.getOutputDataType(); - for (int i = 0; i < outputShape[1]; i++) { - columnHeaders.add(new ColumnHeader(INFERENCE_COLUMN_NAME + i, outputDataType[i])); - } - analysis - .getModelInferenceDescriptor() - .setOutputColumnNames( - columnHeaders.stream().map(ColumnHeader::getColumnName).collect(Collectors.toList())); + columnHeaders.add(new ColumnHeader(INFERENCE_COLUMN_NAME, TSDataType.DOUBLE)); boolean isIgnoreTimestamp = !queryStatement.isGenerateTime(); analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, isIgnoreTimestamp)); return; @@ -1756,74 +1698,16 @@ static void analyzeOutput( analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, isIgnoreTimestamp)); } - // check if the result of SQL matches the input of model - private static void checkInputShape( - ModelInformation modelInformation, List> outputExpressions) { - if (modelInformation.isBuiltIn()) { - modelInformation.setInputColumnSize(outputExpressions.size()); - return; - } - - // check inputShape - int[] 
inputShape = modelInformation.getInputShape(); - if (inputShape.length != 2) { - throw new SemanticException( - String.format( - "The input shape of model is not correct, the dimension of input shape should be 2, actual dimension is %d", - inputShape.length)); - } - int columnNumber = inputShape[1]; - if (columnNumber != outputExpressions.size()) { - throw new SemanticException( - String.format( - "The column number of SQL result does not match the number of model input [%d] for inference", - columnNumber)); - } - } - private static void checkInputType( - Analysis analysis, - ModelInformation modelInformation, - List> outputExpressions) { - - if (modelInformation.isBuiltIn()) { - TSDataType[] inputType = new TSDataType[outputExpressions.size()]; - for (int i = 0; i < outputExpressions.size(); i++) { - Expression inputExpression = outputExpressions.get(i).left; - TSDataType inputDataType = analysis.getType(inputExpression); - if (!inputDataType.isNumeric()) { - throw new SemanticException( - String.format( - "The type of SQL result column [%s in %d] should be numeric when inference", - inputDataType, i)); - } - inputType[i] = inputDataType; - } - modelInformation.setInputDataType(inputType); - return; - } - - TSDataType[] inputType = modelInformation.getInputDataType(); - if (inputType.length != modelInformation.getInputShape()[1]) { - throw new SemanticException( - String.format( - "The inputType does not match the input shape [%d] for inference", - modelInformation.getInputShape()[1])); - } - for (int i = 0; i < inputType.length; i++) { + Analysis analysis, List> outputExpressions) { + for (int i = 0; i < outputExpressions.size(); i++) { Expression inputExpression = outputExpressions.get(i).left; TSDataType inputDataType = analysis.getType(inputExpression); - boolean isExpressionNumeric = inputDataType.isNumeric(); - boolean isModelNumeric = inputType[i].isNumeric(); - if (isExpressionNumeric && isModelNumeric) { - // every model supports numeric by default - 
continue; - } - if (inputDataType != inputType[i]) { + if (!inputDataType.isNumeric()) { throw new SemanticException( String.format( - "The type of SQL result column [%s in %d] does not match the type of model input [%s] when inference", - inputDataType, i, inputType[i])); + "The type of SQL result column [%s in %d] should be numeric when inference", + inputDataType, i)); } } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java index 586e12e589ab1..1feecaefde9c5 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java @@ -20,12 +20,8 @@ package org.apache.iotdb.db.queryengine.plan.analyze; import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; public interface IModelFetcher { /** Get model information by model id from configNode. 
*/ TSStatus fetchModel(String modelId, Analysis analysis); - - // currently only used by table model - ModelInferenceDescriptor fetchModel(String modelName); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java index dbeee4e8ed4b6..b4123c237bbd3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java @@ -20,27 +20,13 @@ package org.apache.iotdb.db.queryengine.plan.analyze; import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.client.IClientManager; -import org.apache.iotdb.commons.client.exception.ClientManagerException; -import org.apache.iotdb.commons.consensus.ConfigRegionId; -import org.apache.iotdb.commons.exception.IoTDBRuntimeException; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; -import org.apache.iotdb.db.exception.ainode.ModelNotFoundException; -import org.apache.iotdb.db.exception.sql.StatementAnalyzeException; -import org.apache.iotdb.db.protocol.client.ConfigNodeClient; -import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; -import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; +import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.rpc.TSStatusCode; -import org.apache.thrift.TException; - +// TODO: This class should contact with AINode directly and cache model info in DataNode public class ModelFetcher implements IModelFetcher { - private final IClientManager configNodeClientManager = - ConfigNodeClientManager.getInstance(); - private static final class ModelFetcherHolder { private 
static final ModelFetcher INSTANCE = new ModelFetcher(); @@ -55,34 +41,9 @@ public static ModelFetcher getInstance() { private ModelFetcher() {} @Override - public TSStatus fetchModel(String modelName, Analysis analysis) { - try (ConfigNodeClient client = - configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - TGetModelInfoResp getModelInfoResp = client.getModelInfo(new TGetModelInfoReq(modelName)); - if (getModelInfoResp.getStatus().getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } else { - throw new ModelNotFoundException(getModelInfoResp.getStatus().getMessage()); - } - } catch (ClientManagerException | TException e) { - throw new StatementAnalyzeException(e.getMessage()); - } - } - - @Override - public ModelInferenceDescriptor fetchModel(String modelName) { - try (ConfigNodeClient client = - configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - TGetModelInfoResp getModelInfoResp = client.getModelInfo(new TGetModelInfoReq(modelName)); - if (getModelInfoResp.getStatus().getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - return new ModelInferenceDescriptor(getModelInfoResp.aiNodeAddress); - } else { - throw new ModelNotFoundException(getModelInfoResp.getStatus().getMessage()); - } - } catch (ClientManagerException | TException e) { - throw new IoTDBRuntimeException( - String.format("fetch model [%s] info failed: %s", modelName, e.getMessage()), - TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); - } + public TSStatus fetchModel(String modelId, Analysis analysis) { + analysis.setModelInferenceDescriptor( + new ModelInferenceDescriptor(new ModelInformation(modelId))); + return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/IQueryExecution.java 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/IQueryExecution.java index 98257c24293e8..e98f016767f6b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/IQueryExecution.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/IQueryExecution.java @@ -77,4 +77,6 @@ public interface IQueryExecution { IClientSession.SqlDialect getSQLDialect(); String getUser(); + + String getClientHostname(); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/QueryExecution.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/QueryExecution.java index 2c1657e839e24..4734db5850a4c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/QueryExecution.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/QueryExecution.java @@ -700,6 +700,11 @@ public String getUser() { return context.getSession().getUserName(); } + @Override + public String getClientHostname() { + return context.getCliHostname(); + } + public MPPQueryContext getContext() { return context; } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/ConfigExecution.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/ConfigExecution.java index 1880924f297ac..823a620820fd5 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/ConfigExecution.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/ConfigExecution.java @@ -353,4 +353,8 @@ public IClientSession.SqlDialect getSQLDialect() { public String getUser() { return context.getSession().getUserName(); } + + public String getClientHostname() { + return context.getCliHostname(); + } } diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java index 3356ed2b58490..f9d668944e008 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java @@ -101,6 +101,8 @@ import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.relational.ShowTablesDetailsTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.relational.ShowTablesTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.relational.UseDBTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.session.DeallocateTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.session.PrepareTask; import org.apache.iotdb.db.queryengine.plan.execution.config.session.SetSqlDialectTask; import org.apache.iotdb.db.queryengine.plan.execution.config.session.ShowCurrentDatabaseTask; import org.apache.iotdb.db.queryengine.plan.execution.config.session.ShowCurrentSqlDialectTask; @@ -150,6 +152,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateView; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DataType; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DatabaseStatement; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Deallocate; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DeleteDevice; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DescribeTable; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DropColumn; @@ -171,6 +174,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LongLiteral; import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.MigrateRegion; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Node; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Prepare; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Property; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.QualifiedName; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ReconstructRegion; @@ -1375,6 +1379,18 @@ protected IConfigTask visitSetSqlDialect(SetSqlDialect node, MPPQueryContext con return new SetSqlDialectTask(node.getSqlDialect()); } + @Override + protected IConfigTask visitPrepare(Prepare node, MPPQueryContext context) { + context.setQueryType(QueryType.WRITE); + return new PrepareTask(node.getStatementName().getValue(), node.getSql()); + } + + @Override + protected IConfigTask visitDeallocate(Deallocate node, MPPQueryContext context) { + context.setQueryType(QueryType.WRITE); + return new DeallocateTask(node.getStatementName().getValue()); + } + @Override protected IConfigTask visitShowCurrentDatabase( ShowCurrentDatabase node, MPPQueryContext context) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java index 34bbd57640a3c..5f27b3feb8c32 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java @@ -19,7 +19,10 @@ package org.apache.iotdb.db.queryengine.plan.execution.config.executor; +import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; +import 
org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; @@ -96,7 +99,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCountTimeSlotListResp; import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTableViewReq; @@ -114,7 +116,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -175,8 +176,8 @@ import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; import org.apache.iotdb.db.protocol.client.DataNodeClientPoolFactory; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.db.protocol.session.IClientSession; import org.apache.iotdb.db.protocol.session.SessionManager; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; @@ -381,6 +382,8 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor { private static final IClientManager 
CONFIG_NODE_CLIENT_MANAGER = ConfigNodeClientManager.getInstance(); + private static final IClientManager AI_NODE_CLIENT_MANAGER = + AINodeClientManager.getInstance(); /** FIXME Consolidate this clientManager with the upper one. */ private static final IClientManager @@ -3617,16 +3620,16 @@ public SettableFuture showContinuousQueries() { @Override public SettableFuture createModel(String modelId, String uri) { final SettableFuture future = SettableFuture.create(); - try (final ConfigNodeClient client = - CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - final TCreateModelReq req = new TCreateModelReq(modelId, uri); - final TSStatus status = client.createModel(req); - if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != status.getCode()) { - future.setException(new IoTDBException(status)); + try (final AINodeClient client = + AI_NODE_CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + final TRegisterModelReq req = new TRegisterModelReq(modelId, uri); + final TRegisterModelResp resp = client.registerModel(req); + if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != resp.getStatus().getCode()) { + future.setException(new IoTDBException(resp.getStatus())); } else { future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); } - } catch (final ClientManagerException | TException e) { + } catch (final TException | ClientManagerException e) { future.setException(e); } return future; @@ -3635,9 +3638,9 @@ public SettableFuture createModel(String modelId, String uri) @Override public SettableFuture dropModel(final String modelId) { final SettableFuture future = SettableFuture.create(); - try (final ConfigNodeClient client = - CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - final TSStatus executionStatus = client.dropModel(new TDropModelReq(modelId)); + try (final AINodeClient client = + AI_NODE_CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + final TSStatus 
executionStatus = client.deleteModel(new TDeleteModelReq(modelId)); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != executionStatus.getCode()) { future.setException(new IoTDBException(executionStatus)); } else { @@ -3653,7 +3656,7 @@ public SettableFuture dropModel(final String modelId) { public SettableFuture showModels(final String modelId) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TShowModelsReq req = new TShowModelsReq(); if (modelId != null) { req.setModelId(modelId); @@ -3674,7 +3677,7 @@ public SettableFuture showModels(final String modelId) { public SettableFuture showLoadedModels(List deviceIdList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TShowLoadedModelsReq req = new TShowLoadedModelsReq(); req.setDeviceIdList(deviceIdList != null ? 
deviceIdList : new ArrayList<>()); final TShowLoadedModelsResp resp = ai.showLoadedModels(req); @@ -3693,7 +3696,7 @@ public SettableFuture showLoadedModels(List deviceIdLi public SettableFuture showAIDevices() { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TShowAIDevicesResp resp = ai.showAIDevices(); if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { future.setException(new IoTDBException(resp.getStatus())); @@ -3711,7 +3714,7 @@ public SettableFuture loadModel( String existingModelId, List deviceIdList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TLoadModelReq req = new TLoadModelReq(existingModelId, deviceIdList); final TSStatus result = ai.loadModel(req); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != result.getCode()) { @@ -3730,7 +3733,7 @@ public SettableFuture unloadModel( String existingModelId, List deviceIdList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TUnloadModelReq req = new TUnloadModelReq(existingModelId, deviceIdList); final TSStatus result = ai.unloadModel(req); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != result.getCode()) { @@ -3755,7 +3758,7 @@ public SettableFuture createTraining( @Nullable List pathList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - 
AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TTrainingReq req = new TTrainingReq(); req.setModelId(modelId); req.setParameters(parameters); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/DeallocateTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/DeallocateTask.java new file mode 100644 index 0000000000000..6f5f3f4846154 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/DeallocateTask.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.execution.config.session; + +import org.apache.iotdb.db.exception.sql.SemanticException; +import org.apache.iotdb.db.protocol.session.IClientSession; +import org.apache.iotdb.db.protocol.session.PreparedStatementInfo; +import org.apache.iotdb.db.protocol.session.SessionManager; +import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult; +import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.executor.IConfigTaskExecutor; +import org.apache.iotdb.rpc.TSStatusCode; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; + +/** + * Task for executing DEALLOCATE PREPARE statement. Removes the prepared statement from the session + * and releases its allocated memory. + */ +public class DeallocateTask implements IConfigTask { + + private final String statementName; + + public DeallocateTask(String statementName) { + this.statementName = statementName; + } + + @Override + public ListenableFuture execute(IConfigTaskExecutor configTaskExecutor) + throws InterruptedException { + SettableFuture future = SettableFuture.create(); + IClientSession session = SessionManager.getInstance().getCurrSession(); + if (session == null) { + future.setException( + new IllegalStateException("No current session available for DEALLOCATE statement")); + return future; + } + + // Remove the prepared statement + PreparedStatementInfo removedInfo = session.removePreparedStatement(statementName); + if (removedInfo == null) { + future.setException( + new SemanticException( + String.format("Prepared statement '%s' does not exist", statementName))); + return future; + } + + // Release the memory allocated for this PreparedStatement from the shared MemoryBlock + PreparedStatementMemoryManager.getInstance().release(removedInfo.getMemorySizeInBytes()); + + future.set(new 
ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); + return future; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PrepareTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PrepareTask.java new file mode 100644 index 0000000000000..bf61e702c72d0 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PrepareTask.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.execution.config.session; + +import org.apache.iotdb.db.exception.sql.SemanticException; +import org.apache.iotdb.db.protocol.session.IClientSession; +import org.apache.iotdb.db.protocol.session.PreparedStatementInfo; +import org.apache.iotdb.db.protocol.session.SessionManager; +import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult; +import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.executor.IConfigTaskExecutor; +import org.apache.iotdb.db.queryengine.plan.relational.sql.AstMemoryEstimator; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement; +import org.apache.iotdb.rpc.TSStatusCode; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; + +/** + * Task for executing PREPARE statement. Stores the prepared statement AST in the session. The AST + * is cached to avoid reparsing on EXECUTE (skipping Parser phase). Memory is allocated from + * CoordinatorMemoryManager and shared across all sessions. 
+ */ +public class PrepareTask implements IConfigTask { + + private final String statementName; + private final Statement sql; // AST containing Parameter nodes + + public PrepareTask(String statementName, Statement sql) { + this.statementName = statementName; + this.sql = sql; + } + + @Override + public ListenableFuture execute(IConfigTaskExecutor configTaskExecutor) + throws InterruptedException { + SettableFuture future = SettableFuture.create(); + IClientSession session = SessionManager.getInstance().getCurrSession(); + if (session == null) { + future.setException( + new IllegalStateException("No current session available for PREPARE statement")); + return future; + } + + // Check if prepared statement with the same name already exists + PreparedStatementInfo existingInfo = session.getPreparedStatement(statementName); + if (existingInfo != null) { + future.setException( + new SemanticException( + String.format("Prepared statement '%s' already exists.", statementName))); + return future; + } + + // Estimate memory size of the AST + long memorySizeInBytes = AstMemoryEstimator.estimateMemorySize(sql); + + // Allocate memory from CoordinatorMemoryManager + // This memory is shared across all sessions using a single MemoryBlock + PreparedStatementMemoryManager.getInstance().allocate(statementName, memorySizeInBytes); + + // Create and store the prepared statement info (AST is cached) + PreparedStatementInfo info = new PreparedStatementInfo(statementName, sql, memorySizeInBytes); + session.addPreparedStatement(statementName, info); + + future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); + return future; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PreparedStatementMemoryManager.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PreparedStatementMemoryManager.java new file mode 100644 index 0000000000000..9d5a3fb098e5c --- /dev/null +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PreparedStatementMemoryManager.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.execution.config.session; + +import org.apache.iotdb.commons.memory.IMemoryBlock; +import org.apache.iotdb.db.exception.sql.SemanticException; +import org.apache.iotdb.db.protocol.session.IClientSession; +import org.apache.iotdb.db.protocol.session.PreparedStatementInfo; +import org.apache.iotdb.db.queryengine.plan.Coordinator; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Set; + +/** + * Memory manager for PreparedStatement. All PreparedStatements from all sessions share a single + * MemoryBlock named "Coordinator" allocated from CoordinatorMemoryManager. The MemoryBlock is + * initialized in Coordinator with all available memory. 
+ */ +public class PreparedStatementMemoryManager { + private static final Logger LOGGER = + LoggerFactory.getLogger(PreparedStatementMemoryManager.class); + + private static final PreparedStatementMemoryManager INSTANCE = + new PreparedStatementMemoryManager(); + + private static final String SHARED_MEMORY_BLOCK_NAME = "Coordinator"; + + private PreparedStatementMemoryManager() { + // singleton + } + + public static PreparedStatementMemoryManager getInstance() { + return INSTANCE; + } + + private IMemoryBlock getSharedMemoryBlock() { + return Coordinator.getCoordinatorMemoryBlock(); + } + + /** + * Allocate memory for a PreparedStatement. + * + * @param statementName the name of the prepared statement + * @param memorySizeInBytes the memory size in bytes to allocate + * @throws SemanticException if memory allocation fails + */ + public void allocate(String statementName, long memorySizeInBytes) { + IMemoryBlock sharedMemoryBlock = getSharedMemoryBlock(); + // Allocate memory from the shared block + boolean allocated = sharedMemoryBlock.allocate(memorySizeInBytes); + if (!allocated) { + LOGGER.warn( + "Failed to allocate {} bytes from shared MemoryBlock '{}' for PreparedStatement '{}'", + memorySizeInBytes, + SHARED_MEMORY_BLOCK_NAME, + statementName); + throw new SemanticException( + String.format( + "Insufficient memory for PreparedStatement '%s'. " + + "Please deallocate some PreparedStatements and try again.", + statementName)); + } + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "Allocated {} bytes for PreparedStatement '{}' from shared MemoryBlock '{}'. ", + memorySizeInBytes, + statementName, + SHARED_MEMORY_BLOCK_NAME); + } + } + + /** + * Release memory for a PreparedStatement. 
+ * + * @param memorySizeInBytes the memory size in bytes to release + */ + public void release(long memorySizeInBytes) { + if (memorySizeInBytes <= 0) { + return; + } + + IMemoryBlock sharedMemoryBlock = getSharedMemoryBlock(); + if (!sharedMemoryBlock.isReleased()) { + long releasedSize = sharedMemoryBlock.release(memorySizeInBytes); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "Released {} bytes from shared MemoryBlock '{}' for PreparedStatement. ", + releasedSize, + SHARED_MEMORY_BLOCK_NAME); + } + } else { + LOGGER.error( + "Attempted to release memory from shared MemoryBlock '{}' but it is released", + SHARED_MEMORY_BLOCK_NAME); + } + } + + /** + * Release all PreparedStatements for a session. This method should be called when a session is + * closed or connection is lost. + * + * @param session the session whose PreparedStatements should be released + */ + public void releaseAllForSession(IClientSession session) { + if (session == null) { + return; + } + + Set preparedStatementNames = session.getPreparedStatementNames(); + if (preparedStatementNames == null || preparedStatementNames.isEmpty()) { + return; + } + + int releasedCount = 0; + long totalReleasedBytes = 0; + + for (String statementName : preparedStatementNames) { + PreparedStatementInfo info = session.getPreparedStatement(statementName); + if (info != null) { + long memorySize = info.getMemorySizeInBytes(); + if (memorySize > 0) { + release(memorySize); + releasedCount++; + totalReleasedBytes += memorySize; + } + } + } + + if (releasedCount > 0 && LOGGER.isDebugEnabled()) { + LOGGER.debug( + "Released {} PreparedStatement(s) ({} bytes total) for session {}", + releasedCount, + totalReleasedBytes, + session); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java index 
09205c9eb5647..a01acf86db57f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java @@ -31,6 +31,7 @@ import java.io.DataOutputStream; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Collections; import java.util.List; import java.util.Objects; @@ -90,7 +91,7 @@ public PlanNode clone() { @Override public List getOutputColumnNames() { - return modelInferenceDescriptor.getOutputColumnNames(); + return Collections.singletonList("output"); } @Override diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java index b7c6aaa4f4b01..1301ec97eb32e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java @@ -19,9 +19,7 @@ package org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowParameter; import org.apache.tsfile.utils.ReadWriteIOUtils; @@ -36,19 +34,15 @@ public class ModelInferenceDescriptor { - private final TEndPoint targetAINode; - private ModelInformation modelInformation; + private final ModelInformation modelInformation; private List outputColumnNames; - private InferenceWindowParameter inferenceWindowParameter; private Map inferenceAttributes; - public ModelInferenceDescriptor(TEndPoint targetAINode) 
{ - this.targetAINode = targetAINode; + public ModelInferenceDescriptor(ModelInformation modelInformation) { + this.modelInformation = modelInformation; } private ModelInferenceDescriptor(ByteBuffer buffer) { - this.targetAINode = - new TEndPoint(ReadWriteIOUtils.readString(buffer), ReadWriteIOUtils.readInt(buffer)); this.modelInformation = ModelInformation.deserialize(buffer); int outputColumnNamesSize = ReadWriteIOUtils.readInt(buffer); if (outputColumnNamesSize == 0) { @@ -59,12 +53,6 @@ private ModelInferenceDescriptor(ByteBuffer buffer) { this.outputColumnNames.add(ReadWriteIOUtils.readString(buffer)); } } - boolean hasInferenceWindowParameter = ReadWriteIOUtils.readBool(buffer); - if (hasInferenceWindowParameter) { - this.inferenceWindowParameter = InferenceWindowParameter.deserialize(buffer); - } else { - this.inferenceWindowParameter = null; - } int inferenceAttributesSize = ReadWriteIOUtils.readInt(buffer); if (inferenceAttributesSize == 0) { this.inferenceAttributes = null; @@ -85,24 +73,12 @@ public Map getInferenceAttributes() { return inferenceAttributes; } - public void setInferenceWindowParameter(InferenceWindowParameter inferenceWindowParameter) { - this.inferenceWindowParameter = inferenceWindowParameter; - } - - public InferenceWindowParameter getInferenceWindowParameter() { - return inferenceWindowParameter; - } - public ModelInformation getModelInformation() { return modelInformation; } - public TEndPoint getTargetAINode() { - return targetAINode; - } - - public String getModelName() { - return modelInformation.getModelName(); + public String getModelId() { + return modelInformation.getModelId(); } public void setOutputColumnNames(List outputColumnNames) { @@ -114,8 +90,6 @@ public List getOutputColumnNames() { } public void serialize(ByteBuffer byteBuffer) { - ReadWriteIOUtils.write(targetAINode.ip, byteBuffer); - ReadWriteIOUtils.write(targetAINode.port, byteBuffer); modelInformation.serialize(byteBuffer); if (outputColumnNames == null) { 
ReadWriteIOUtils.write(0, byteBuffer); @@ -125,12 +99,6 @@ public void serialize(ByteBuffer byteBuffer) { ReadWriteIOUtils.write(outputColumnName, byteBuffer); } } - if (inferenceWindowParameter == null) { - ReadWriteIOUtils.write(false, byteBuffer); - } else { - ReadWriteIOUtils.write(true, byteBuffer); - inferenceWindowParameter.serialize(byteBuffer); - } if (inferenceAttributes == null) { ReadWriteIOUtils.write(0, byteBuffer); } else { @@ -143,8 +111,6 @@ public void serialize(ByteBuffer byteBuffer) { } public void serialize(DataOutputStream stream) throws IOException { - ReadWriteIOUtils.write(targetAINode.ip, stream); - ReadWriteIOUtils.write(targetAINode.port, stream); modelInformation.serialize(stream); if (outputColumnNames == null) { ReadWriteIOUtils.write(0, stream); @@ -154,12 +120,6 @@ public void serialize(DataOutputStream stream) throws IOException { ReadWriteIOUtils.write(outputColumnName, stream); } } - if (inferenceWindowParameter == null) { - ReadWriteIOUtils.write(false, stream); - } else { - ReadWriteIOUtils.write(true, stream); - inferenceWindowParameter.serialize(stream); - } if (inferenceAttributes == null) { ReadWriteIOUtils.write(0, stream); } else { @@ -184,20 +144,13 @@ public boolean equals(Object o) { return false; } ModelInferenceDescriptor that = (ModelInferenceDescriptor) o; - return targetAINode.equals(that.targetAINode) - && modelInformation.equals(that.modelInformation) + return modelInformation.equals(that.modelInformation) && outputColumnNames.equals(that.outputColumnNames) - && inferenceWindowParameter.equals(that.inferenceWindowParameter) && inferenceAttributes.equals(that.inferenceAttributes); } @Override public int hashCode() { - return Objects.hash( - targetAINode, - modelInformation, - outputColumnNames, - inferenceWindowParameter, - inferenceAttributes); + return Objects.hash(modelInformation, outputColumnNames, inferenceAttributes); } } diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java index 5acf353e32e50..e13c52ba8b168 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java @@ -40,13 +40,11 @@ import org.apache.iotdb.db.queryengine.plan.relational.analyzer.tablefunction.TableArgumentAnalysis; import org.apache.iotdb.db.queryengine.plan.relational.analyzer.tablefunction.TableFunctionInvocationAnalysis; import org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction; -import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction; import org.apache.iotdb.db.queryengine.plan.relational.metadata.ColumnSchema; import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata; import org.apache.iotdb.db.queryengine.plan.relational.metadata.QualifiedObjectName; import org.apache.iotdb.db.queryengine.plan.relational.metadata.TableMetadataImpl; import org.apache.iotdb.db.queryengine.plan.relational.metadata.TableSchema; -import org.apache.iotdb.db.queryengine.plan.relational.planner.IrExpressionInterpreter; import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext; import org.apache.iotdb.db.queryengine.plan.relational.planner.ScopeAware; import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; @@ -133,6 +131,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NullIfExpression; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Offset; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.OrderBy; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Parameter; import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PatternRecognitionRelation; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PipeEnriched; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Property; @@ -211,6 +210,7 @@ import org.apache.iotdb.udf.api.relational.table.specification.TableParameterSpecification; import com.google.common.base.Joiner; +import com.google.common.base.VerifyException; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -281,6 +281,7 @@ import static org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction.TIMECOL_PARAMETER_NAME; import static org.apache.iotdb.db.queryengine.plan.relational.metadata.MetadataUtil.createQualifiedObjectName; import static org.apache.iotdb.db.queryengine.plan.relational.metadata.TableMetadataImpl.isTimestampType; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.IrExpressionInterpreter.evaluateConstantExpression; import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DereferenceExpression.getQualifiedName; import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Join.Type.FULL; import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Join.Type.INNER; @@ -3981,15 +3982,12 @@ private void analyzeOffset(Offset node, Scope scope) { if (node.getRowCount() instanceof LongLiteral) { rowCount = ((LongLiteral) node.getRowCount()).getParsedValue(); } else { - // checkState( - // node.getRowCount() instanceof Parameter, - // "unexpected OFFSET rowCount: " + - // node.getRowCount().getClass().getSimpleName()); - throw new SemanticException( + checkState( + node.getRowCount() instanceof Parameter, "unexpected OFFSET rowCount: " + node.getRowCount().getClass().getSimpleName()); - // OptionalLong providedValue = - // analyzeParameterAsRowCount((Parameter) node.getRowCount(), scope, "OFFSET"); - // rowCount = 
providedValue.orElse(0); + OptionalLong providedValue = + analyzeParameterAsRowCount((Parameter) node.getRowCount(), scope, "OFFSET"); + rowCount = providedValue.orElse(0); } if (rowCount < 0) { throw new SemanticException( @@ -4050,14 +4048,10 @@ private boolean analyzeLimit(Limit node, Scope scope) { } else if (node.getRowCount() instanceof LongLiteral) { rowCount = OptionalLong.of(((LongLiteral) node.getRowCount()).getParsedValue()); } else { - // checkState( - // node.getRowCount() instanceof Parameter, - // "unexpected LIMIT rowCount: " + - // node.getRowCount().getClass().getSimpleName()); - throw new SemanticException( + checkState( + node.getRowCount() instanceof Parameter, "unexpected LIMIT rowCount: " + node.getRowCount().getClass().getSimpleName()); - // rowCount = analyzeParameterAsRowCount((Parameter) node.getRowCount(), scope, - // "LIMIT"); + rowCount = analyzeParameterAsRowCount((Parameter) node.getRowCount(), scope, "LIMIT"); } rowCount.ifPresent( count -> { @@ -4073,32 +4067,27 @@ private boolean analyzeLimit(Limit node, Scope scope) { return false; } - // private OptionalLong analyzeParameterAsRowCount( - // Parameter parameter, Scope scope, String context) { - // // validate parameter index - // analyzeExpression(parameter, scope); - // Expression providedValue = analysis.getParameters().get(NodeRef.of(parameter)); - // Object value; - // try { - // value = - // evaluateConstantExpression( - // providedValue, - // BIGINT, - // plannerContext, - // session, - // accessControl, - // analysis.getParameters()); - // } catch (VerifyException e) { - // throw new SemanticException( - // String.format("Non constant parameter value for %s: %s", context, providedValue)); - // } - // if (value == null) { - // throw new SemanticException( - // String.format("Parameter value provided for %s is NULL: %s", context, - // providedValue)); - // } - // return OptionalLong.of((long) value); - // } + private OptionalLong analyzeParameterAsRowCount( + Parameter 
parameter, Scope scope, String context) { + // validate parameter index + analyzeExpression(parameter, scope); + Expression providedValue = analysis.getParameters().get(NodeRef.of(parameter)); + Object value; + try { + value = + evaluateConstantExpression( + providedValue, new PlannerContext(metadata, typeManager), sessionContext); + + } catch (VerifyException e) { + throw new SemanticException( + String.format("Non constant parameter value for %s: %s", context, providedValue)); + } + if (value == null) { + throw new SemanticException( + String.format("Parameter value provided for %s is NULL: %s", context, providedValue)); + } + return OptionalLong.of((long) value); + } private void analyzeAggregations( QuerySpecification node, @@ -4694,11 +4683,6 @@ public Scope visitTableFunctionInvocation(TableFunctionInvocation node, Optional String functionName = node.getName().toString(); TableFunction function = metadata.getTableFunction(functionName); - // set model fetcher for ForecastTableFunction - if (function instanceof ForecastTableFunction) { - ((ForecastTableFunction) function).setModelFetcher(metadata.getModelFetcher()); - } - Node errorLocation = node; if (!node.getArguments().isEmpty()) { errorLocation = node.getArguments().get(0); @@ -5191,7 +5175,7 @@ private ArgumentAnalysis analyzeScalarArgument( Expression expression, ScalarParameterSpecification argumentSpecification) { // currently, only constant arguments are supported Object constantValue = - IrExpressionInterpreter.evaluateConstantExpression( + evaluateConstantExpression( expression, new PlannerContext(metadata, typeManager), sessionContext); if (!argumentSpecification.getType().checkObjectType(constantValue)) { if ((argumentSpecification.getType().equals(org.apache.iotdb.udf.api.type.Type.STRING) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java index 887d7c26d305e..08f7ec6c8335c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java @@ -19,14 +19,14 @@ package org.apache.iotdb.db.queryengine.plan.relational.function.tvf; +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.IClientManager; import org.apache.iotdb.commons.exception.IoTDBRuntimeException; import org.apache.iotdb.db.exception.sql.SemanticException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.rpc.TSStatusCode; import org.apache.iotdb.udf.api.relational.TableFunction; import org.apache.iotdb.udf.api.relational.access.Record; @@ -74,8 +74,9 @@ public class ForecastTableFunction implements TableFunction { + private static final TsBlockSerde SERDE = new TsBlockSerde(); + public static class ForecastTableFunctionHandle implements TableFunctionHandle { - TEndPoint targetAINode; String modelId; int maxInputLength; int outputLength; @@ -95,7 +96,6 @@ public ForecastTableFunctionHandle( int outputLength, long outputStartTime, long outputInterval, - TEndPoint targetAINode, List types) { this.keepInput = keepInput; this.maxInputLength = maxInputLength; @@ -104,7 +104,6 @@ 
public ForecastTableFunctionHandle( this.outputLength = outputLength; this.outputStartTime = outputStartTime; this.outputInterval = outputInterval; - this.targetAINode = targetAINode; this.types = types; } @@ -112,8 +111,6 @@ public ForecastTableFunctionHandle( public byte[] serialize() { try (PublicBAOS publicBAOS = new PublicBAOS(); DataOutputStream outputStream = new DataOutputStream(publicBAOS)) { - ReadWriteIOUtils.write(targetAINode.getIp(), outputStream); - ReadWriteIOUtils.write(targetAINode.getPort(), outputStream); ReadWriteIOUtils.write(modelId, outputStream); ReadWriteIOUtils.write(maxInputLength, outputStream); ReadWriteIOUtils.write(outputLength, outputStream); @@ -138,8 +135,6 @@ public byte[] serialize() { @Override public void deserialize(byte[] bytes) { ByteBuffer buffer = ByteBuffer.wrap(bytes); - this.targetAINode = - new TEndPoint(ReadWriteIOUtils.readString(buffer), ReadWriteIOUtils.readInt(buffer)); this.modelId = ReadWriteIOUtils.readString(buffer); this.maxInputLength = ReadWriteIOUtils.readInt(buffer); this.outputLength = ReadWriteIOUtils.readInt(buffer); @@ -168,7 +163,6 @@ public boolean equals(Object o) { && outputStartTime == that.outputStartTime && outputInterval == that.outputInterval && keepInput == that.keepInput - && Objects.equals(targetAINode, that.targetAINode) && Objects.equals(modelId, that.modelId) && Objects.equals(options, that.options) && Objects.equals(types, that.types); @@ -177,7 +171,6 @@ public boolean equals(Object o) { @Override public int hashCode() { return Objects.hash( - targetAINode, modelId, maxInputLength, outputLength, @@ -284,8 +277,6 @@ public TableFunctionAnalysis analyze(Map arguments) { String.format("%s should never be null or empty", MODEL_ID_PARAMETER_NAME)); } - TEndPoint targetAINode = getModelInfo(modelId).getTargetAINode(); - int outputLength = (int) ((ScalarArgument) arguments.get(OUTPUT_LENGTH_PARAMETER_NAME)).getValue(); if (outputLength <= 0) { @@ -390,7 +381,6 @@ public 
TableFunctionAnalysis analyze(Map arguments) { outputLength, outputStartTime, outputInterval, - targetAINode, predicatedColumnTypes); // outputColumnSchema @@ -417,10 +407,6 @@ public TableFunctionDataProcessor getDataProcessor() { }; } - private ModelInferenceDescriptor getModelInfo(String modelId) { - return modelFetcher.fetchModel(modelId); - } - // only allow for INT32, INT64, FLOAT, DOUBLE private void checkType(Type type, String columnName) { if (!ALLOWED_INPUT_TYPES.contains(type)) { @@ -456,9 +442,9 @@ private static Map parseOptions(String options) { private static class ForecastDataProcessor implements TableFunctionDataProcessor { private static final TsBlockSerde SERDE = new TsBlockSerde(); - private static final AINodeClientManager CLIENT_MANAGER = AINodeClientManager.getInstance(); + private static final IClientManager CLIENT_MANAGER = + AINodeClientManager.getInstance(); - private final TEndPoint targetAINode; private final String modelId; private final int maxInputLength; private final int outputLength; @@ -471,7 +457,6 @@ private static class ForecastDataProcessor implements TableFunctionDataProcessor private final TsBlockBuilder inputTsBlockBuilder; public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) { - this.targetAINode = functionHandle.targetAINode; this.modelId = functionHandle.modelId; this.maxInputLength = functionHandle.maxInputLength; this.outputLength = functionHandle.outputLength; @@ -619,8 +604,12 @@ private TsBlock forecast() { TsBlock inputData = inputTsBlockBuilder.build(); TForecastResp resp; - try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) { - resp = client.forecast(modelId, inputData, outputLength, options); + try (AINodeClient client = + CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + resp = + client.forecast( + new TForecastReq(modelId, SERDE.serialize(inputData), outputLength) + .setOptions(options)); } catch (Exception e) { throw new 
IoTDBRuntimeException(e.getMessage(), CAN_NOT_CONNECT_AINODE.getStatusCode()); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java index f0c041ad8053d..db706d4980cba 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java @@ -28,7 +28,6 @@ import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.SessionInfo; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType; import org.apache.iotdb.db.queryengine.plan.relational.metadata.fetcher.TableHeaderSchemaValidator; @@ -211,9 +210,4 @@ DataPartition getDataPartitionWithUnclosedTimeRange( final String database, final List sgNameToQueryParamsMap); TableFunction getTableFunction(final String functionName); - - /** - * @return ModelFetcher - */ - IModelFetcher getModelFetcher(); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index 9260ec8de0320..7786876e19ada 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -1498,11 +1498,6 @@ public TableFunction getTableFunction(String functionName) { } } - @Override - public 
IModelFetcher getModelFetcher() { - return modelFetcher; - } - public static boolean isTwoNumericType(List argumentTypes) { return argumentTypes.size() == 2 && isNumericType(argumentTypes.get(0)) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/TableModelPlanner.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/TableModelPlanner.java index fddff825f0d27..78f9729ece341 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/TableModelPlanner.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/TableModelPlanner.java @@ -35,13 +35,16 @@ import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan; import org.apache.iotdb.db.queryengine.plan.relational.analyzer.Analysis; import org.apache.iotdb.db.queryengine.plan.relational.analyzer.Analyzer; +import org.apache.iotdb.db.queryengine.plan.relational.analyzer.NodeRef; import org.apache.iotdb.db.queryengine.plan.relational.analyzer.StatementAnalyzerFactory; import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata; import org.apache.iotdb.db.queryengine.plan.relational.planner.distribute.TableDistributedPlanner; import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.DataNodeLocationSupplierFactory; import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanOptimizer; import org.apache.iotdb.db.queryengine.plan.relational.security.AccessControl; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LoadTsFile; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Parameter; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PipeEnriched; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement; import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.WrappedInsertStatement; @@ -57,8 +60,8 @@ import org.apache.iotdb.rpc.TSStatusCode; import java.util.ArrayList; -import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.ScheduledExecutorService; import static org.apache.iotdb.db.queryengine.metric.QueryPlanCostMetricSet.DISTRIBUTION_PLANNER; @@ -88,6 +91,9 @@ public class TableModelPlanner implements IPlanner { private final DataNodeLocationSupplierFactory.DataNodeLocationSupplier dataNodeLocationSupplier; + // Parameters for prepared statements (optional) + private final List parameters; + private final Map, Expression> parameterLookup; private final TypeManager typeManager; public TableModelPlanner( @@ -104,6 +110,8 @@ public TableModelPlanner( final List distributionPlanOptimizers, final AccessControl accessControl, final DataNodeLocationSupplierFactory.DataNodeLocationSupplier dataNodeLocationSupplier, + final List parameters, + final Map, Expression> parameterLookup, final TypeManager typeManager) { this.statement = statement; this.sqlParser = sqlParser; @@ -116,6 +124,8 @@ public TableModelPlanner( this.distributionPlanOptimizers = distributionPlanOptimizers; this.accessControl = accessControl; this.dataNodeLocationSupplier = dataNodeLocationSupplier; + this.parameters = parameters; + this.parameterLookup = parameterLookup; this.typeManager = typeManager; } @@ -125,8 +135,8 @@ public IAnalysis analyze(final MPPQueryContext context) { context, context.getSession(), new StatementAnalyzerFactory(metadata, sqlParser, accessControl, typeManager), - Collections.emptyList(), - Collections.emptyMap(), + parameters, + parameterLookup, statementRewrite, warningCollector) .analyze(statement); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java index 4676559bd7b15..d7d755ddc1da6 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java @@ -86,6 +86,8 @@ public List getDataNodeLocations(final String tableName) { switch (tableName) { case InformationSchema.QUERIES: case InformationSchema.CONNECTIONS: + case InformationSchema.CURRENT_QUERIES: + case InformationSchema.QUERIES_COSTS_HISTOGRAM: return getReadableDataNodeLocations(); case InformationSchema.DATABASES: case InformationSchema.TABLES: @@ -96,7 +98,6 @@ public List getDataNodeLocations(final String tableName) { case InformationSchema.TOPICS: case InformationSchema.SUBSCRIPTIONS: case InformationSchema.VIEWS: - case InformationSchema.MODELS: case InformationSchema.FUNCTIONS: case InformationSchema.CONFIGURATIONS: case InformationSchema.KEYWORDS: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/AstMemoryEstimator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/AstMemoryEstimator.java new file mode 100644 index 0000000000000..d45f6546e0f6a --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/AstMemoryEstimator.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.sql; + +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DefaultTraversalVisitor; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Node; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement; + +import org.apache.tsfile.utils.RamUsageEstimator; + +/** + * Utility class for estimating memory usage of AST nodes. Uses RamUsageEstimator to calculate + * approximate memory size. + */ +public final class AstMemoryEstimator { + private AstMemoryEstimator() {} + + /** + * Estimate the memory size of a Statement AST node in bytes. 
+ * + * @param statement the statement AST to estimate + * @return estimated memory size in bytes + */ + public static long estimateMemorySize(Statement statement) { + if (statement == null) { + return 0L; + } + MemoryEstimatingVisitor visitor = new MemoryEstimatingVisitor(); + visitor.process(statement, null); + return visitor.getTotalMemorySize(); + } + + private static class MemoryEstimatingVisitor extends DefaultTraversalVisitor { + private long totalMemorySize = 0L; + + public long getTotalMemorySize() { + return totalMemorySize; + } + + @Override + protected Void visitNode(Node node, Void context) { + // Estimate shallow size of the node object + long nodeSize = RamUsageEstimator.shallowSizeOfInstance(node.getClass()); + totalMemorySize += nodeSize; + + // Traverse children (DefaultTraversalVisitor handles this) + return super.visitNode(node, context); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ParameterExtractor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ParameterExtractor.java new file mode 100644 index 0000000000000..d727acdc35e2b --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ParameterExtractor.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iotdb.db.queryengine.plan.relational.sql; + +import org.apache.iotdb.db.exception.sql.SemanticException; +import org.apache.iotdb.db.queryengine.plan.relational.analyzer.NodeRef; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DefaultTraversalVisitor; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Parameter; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement; + +import com.google.common.collect.ImmutableMap; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +/** Utility class for extracting and binding parameters in prepared statements. */ +public final class ParameterExtractor { + private ParameterExtractor() {} + + /** + * Get the number of parameters in a statement. + * + * @param statement the statement to analyze + * @return the number of parameters + */ + public static int getParameterCount(Statement statement) { + return extractParameters(statement).size(); + } + + /** + * Extract all Parameter nodes from a statement in order of appearance. 
+ * + * @param statement the statement to analyze + * @return list of Parameter nodes in order of appearance + */ + public static List extractParameters(Statement statement) { + ParameterExtractingVisitor visitor = new ParameterExtractingVisitor(); + visitor.process(statement, null); + return visitor.getParameters().stream() + .sorted( + Comparator.comparing( + parameter -> + parameter + .getLocation() + .orElseThrow( + () -> new SemanticException("Parameter node must have a location")), + Comparator.comparing( + org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NodeLocation + ::getLineNumber) + .thenComparing( + org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NodeLocation + ::getColumnNumber))) + .collect(toImmutableList()); + } + + /** + * Bind parameter values to Parameter nodes in a statement. Creates a map from Parameter node + * references to their corresponding Expression values. + * + * @param statement the statement containing Parameter nodes + * @param values the parameter values (in order) + * @return map from Parameter node references to Expression values + * @throws SemanticException if the number of parameters doesn't match + */ + public static Map, Expression> bindParameters( + Statement statement, List values) { + List parametersList = extractParameters(statement); + + // Validate parameter count + if (parametersList.size() != values.size()) { + throw new SemanticException( + String.format( + "Invalid number of parameters: expected %d, got %d", + parametersList.size(), values.size())); + } + + ImmutableMap.Builder, Expression> builder = ImmutableMap.builder(); + Iterator iterator = values.iterator(); + for (Parameter parameter : parametersList) { + builder.put(NodeRef.of(parameter), iterator.next()); + } + return builder.buildOrThrow(); + } + + private static class ParameterExtractingVisitor extends DefaultTraversalVisitor { + private final List parameters = new ArrayList<>(); + + public List getParameters() { + return parameters; + } 
+ + @Override + protected Void visitParameter(Parameter node, Void context) { + parameters.add(node); + return null; + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/AstVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/AstVisitor.java index 54802a1d2f60f..2728750418fe2 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/AstVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/AstVisitor.java @@ -93,6 +93,22 @@ protected R visitUse(Use node, C context) { return visitStatement(node, context); } + protected R visitPrepare(Prepare node, C context) { + return visitStatement(node, context); + } + + protected R visitExecute(Execute node, C context) { + return visitStatement(node, context); + } + + protected R visitExecuteImmediate(ExecuteImmediate node, C context) { + return visitStatement(node, context); + } + + protected R visitDeallocate(Deallocate node, C context) { + return visitStatement(node, context); + } + protected R visitGenericLiteral(GenericLiteral node, C context) { return visitLiteral(node, context); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Deallocate.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Deallocate.java new file mode 100644 index 0000000000000..fd579bb64b5b6 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Deallocate.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.sql.ast; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +/** DEALLOCATE PREPARE statement AST node. Example: DEALLOCATE PREPARE stmt1 */ +public final class Deallocate extends Statement { + + private final Identifier statementName; + + public Deallocate(Identifier statementName) { + this(null, statementName); + } + + public Deallocate(NodeLocation location, Identifier statementName) { + super(location); + this.statementName = requireNonNull(statementName, "statementName is null"); + } + + public Identifier getStatementName() { + return statementName; + } + + @Override + public R accept(AstVisitor visitor, C context) { + return visitor.visitDeallocate(this, context); + } + + @Override + public List getChildren() { + return ImmutableList.of(statementName); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Deallocate that = (Deallocate) o; + return Objects.equals(statementName, that.statementName); + } + + @Override + public int hashCode() { + return Objects.hash(statementName); + } + + @Override + public String toString() { + return 
toStringHelper(this).add("statementName", statementName).toString(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Execute.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Execute.java new file mode 100644 index 0000000000000..d7e219faf1b98 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Execute.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.sql.ast; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +/** EXECUTE statement AST node. 
Example: EXECUTE stmt1 USING 100, 'test' */ +public final class Execute extends Statement { + + private final Identifier statementName; + private final List parameters; + + public Execute(Identifier statementName) { + this(null, statementName, ImmutableList.of()); + } + + public Execute(Identifier statementName, List parameters) { + this(null, statementName, parameters); + } + + public Execute(NodeLocation location, Identifier statementName, List parameters) { + super(location); + this.statementName = requireNonNull(statementName, "statementName is null"); + this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); + } + + public Identifier getStatementName() { + return statementName; + } + + public List getParameters() { + return parameters; + } + + @Override + public R accept(AstVisitor visitor, C context) { + return visitor.visitExecute(this, context); + } + + @Override + public List getChildren() { + ImmutableList.Builder children = ImmutableList.builder(); + children.add(statementName); + children.addAll(parameters); + return children.build(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Execute that = (Execute) o; + return Objects.equals(statementName, that.statementName) + && Objects.equals(parameters, that.parameters); + } + + @Override + public int hashCode() { + return Objects.hash(statementName, parameters); + } + + @Override + public String toString() { + return toStringHelper(this) + .add("statementName", statementName) + .add("parameters", parameters) + .toString(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/ExecuteImmediate.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/ExecuteImmediate.java new file mode 100644 index 0000000000000..955ac54e4fb8b --- /dev/null +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/ExecuteImmediate.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.sql.ast; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +/** + * EXECUTE IMMEDIATE statement AST node. 
Example: EXECUTE IMMEDIATE 'SELECT * FROM table WHERE id = + * 100' + */ +public final class ExecuteImmediate extends Statement { + + private final StringLiteral sql; + private final List parameters; + + public ExecuteImmediate(StringLiteral sql) { + this(null, sql, ImmutableList.of()); + } + + public ExecuteImmediate(StringLiteral sql, List parameters) { + this(null, sql, parameters); + } + + public ExecuteImmediate(NodeLocation location, StringLiteral sql, List parameters) { + super(location); + this.sql = requireNonNull(sql, "sql is null"); + this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); + } + + public StringLiteral getSql() { + return sql; + } + + public String getSqlString() { + return sql.getValue(); + } + + public List getParameters() { + return parameters; + } + + @Override + public R accept(AstVisitor visitor, C context) { + return visitor.visitExecuteImmediate(this, context); + } + + @Override + public List getChildren() { + ImmutableList.Builder children = ImmutableList.builder(); + children.add(sql); + children.addAll(parameters); + return children.build(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ExecuteImmediate that = (ExecuteImmediate) o; + return Objects.equals(sql, that.sql) && Objects.equals(parameters, that.parameters); + } + + @Override + public int hashCode() { + return Objects.hash(sql, parameters); + } + + @Override + public String toString() { + return toStringHelper(this).add("sql", sql).add("parameters", parameters).toString(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/LoadTsFile.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/LoadTsFile.java index 93fc8c7b58330..166f06b85e328 100644 --- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/LoadTsFile.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/LoadTsFile.java @@ -37,7 +37,7 @@ public class LoadTsFile extends Statement { - private final String filePath; + private String filePath; private int databaseLevel; // For loading to tree-model only private String database; // For loading to table-model only @@ -50,7 +50,7 @@ public class LoadTsFile extends Statement { private boolean isGeneratedByPipe = false; - private final Map loadAttributes; + private Map loadAttributes; private List tsFiles; private List resources; @@ -232,6 +232,63 @@ public boolean reconstructStatementIfMiniFileConverted(final List isMin return tsFiles == null || tsFiles.isEmpty(); } + @Override + public boolean shouldSplit() { + final int splitThreshold = + IoTDBDescriptor.getInstance().getConfig().getLoadTsFileStatementSplitThreshold(); + return tsFiles.size() > splitThreshold && !isAsyncLoad; + } + + /** + * Splits the current LoadTsFile statement into multiple sub-statements, each handling a batch of + * TsFiles. Used to limit resource consumption during statement analysis, etc. 
+ * + * @return the list of sub-statements + */ + @Override + public List getSubStatements() { + final int batchSize = + IoTDBDescriptor.getInstance().getConfig().getLoadTsFileSubStatementBatchSize(); + final int totalBatches = (tsFiles.size() + batchSize - 1) / batchSize; // Ceiling division + final List subStatements = new ArrayList<>(totalBatches); + + for (int i = 0; i < tsFiles.size(); i += batchSize) { + final int endIndex = Math.min(i + batchSize, tsFiles.size()); + final List batchFiles = tsFiles.subList(i, endIndex); + + // Use the first file's path for the sub-statement + final String filePath = batchFiles.get(0).getAbsolutePath(); + final Map properties = this.loadAttributes; + + final LoadTsFile subStatement = + new LoadTsFile(getLocation().orElse(null), filePath, properties); + + // Copy all configuration properties + subStatement.databaseLevel = this.databaseLevel; + subStatement.database = this.database; + subStatement.verify = this.verify; + subStatement.deleteAfterLoad = this.deleteAfterLoad; + subStatement.convertOnTypeMismatch = this.convertOnTypeMismatch; + subStatement.tabletConversionThresholdBytes = this.tabletConversionThresholdBytes; + subStatement.autoCreateDatabase = this.autoCreateDatabase; + subStatement.isAsyncLoad = this.isAsyncLoad; + subStatement.isGeneratedByPipe = this.isGeneratedByPipe; + + // Set all files in the batch + subStatement.tsFiles = new ArrayList<>(batchFiles); + subStatement.resources = new ArrayList<>(batchFiles.size()); + subStatement.writePointCountList = new ArrayList<>(batchFiles.size()); + subStatement.isTableModel = new ArrayList<>(batchFiles.size()); + for (int j = 0; j < batchFiles.size(); j++) { + subStatement.isTableModel.add(true); + } + + subStatements.add(subStatement); + } + + return subStatements; + } + @Override public R accept(AstVisitor visitor, C context) { return visitor.visitLoadTsFile(this, context); diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Prepare.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Prepare.java new file mode 100644 index 0000000000000..f413b8a19266e --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Prepare.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.sql.ast; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +/** PREPARE statement AST node. Example: PREPARE stmt1 FROM SELECT * FROM table WHERE id = ? 
*/ +public final class Prepare extends Statement { + + private final Identifier statementName; + private final Statement sql; + + public Prepare(Identifier statementName, Statement sql) { + super(null); + this.statementName = requireNonNull(statementName, "statementName is null"); + this.sql = requireNonNull(sql, "sql is null"); + } + + public Prepare(NodeLocation location, Identifier statementName, Statement sql) { + super(location); + this.statementName = requireNonNull(statementName, "statementName is null"); + this.sql = requireNonNull(sql, "sql is null"); + } + + public Identifier getStatementName() { + return statementName; + } + + public Statement getSql() { + return sql; + } + + @Override + public R accept(AstVisitor visitor, C context) { + return visitor.visitPrepare(this, context); + } + + @Override + public List getChildren() { + return ImmutableList.of(statementName, sql); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Prepare that = (Prepare) o; + return Objects.equals(statementName, that.statementName) && Objects.equals(sql, that.sql); + } + + @Override + public int hashCode() { + return Objects.hash(statementName, sql); + } + + @Override + public String toString() { + return toStringHelper(this).add("statementName", statementName).add("sql", sql).toString(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Statement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Statement.java index 0352c85f1eb44..7ba19b972a29d 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Statement.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Statement.java @@ -21,6 +21,9 @@ import javax.annotation.Nullable; +import java.util.Collections; +import 
java.util.List; + public abstract class Statement extends Node { protected Statement(final @Nullable NodeLocation location) { @@ -31,4 +34,26 @@ protected Statement(final @Nullable NodeLocation location) { public R accept(final AstVisitor visitor, final C context) { return visitor.visitStatement(this, context); } + + /** + * Checks whether this statement should be split into multiple sub-statements. Used to limit + * resource consumption during statement analysis, etc. + * + * @return true if the statement should be split, false otherwise. Default implementation returns + * false. + */ + public boolean shouldSplit() { + return false; + } + + /** + * Splits the current statement into multiple sub-statements. Used to limit resource consumption + * during statement analysis, etc. + * + * @return the list of sub-statements. Default implementation returns empty list. + */ + public List getSubStatements() { + return Collections.emptyList(); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java index f62b45d8b8a8a..b041c56b2bc97 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java @@ -71,6 +71,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CurrentUser; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DataType; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DataTypeParameter; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Deallocate; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Delete; import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DeleteDevice; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DereferenceExpression; @@ -89,6 +90,8 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.EmptyPattern; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Except; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ExcludedPattern; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Execute; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ExecuteImmediate; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ExistsPredicate; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Explain; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ExplainAnalyze; @@ -145,6 +148,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PatternRecognitionRelation; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PatternRecognitionRelation.RowsPerMatch; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PatternVariable; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Prepare; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ProcessingMode; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Property; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.QualifiedName; @@ -3786,6 +3790,40 @@ public Node visitDropModelStatement(RelationalSqlParser.DropModelStatementContex return new DropModel(modelId); } + @Override + public Node visitPrepareStatement(RelationalSqlParser.PrepareStatementContext ctx) { + Identifier statementName = lowerIdentifier((Identifier) visit(ctx.statementName)); + Statement sql = (Statement) visit(ctx.sql); + return new Prepare(getLocation(ctx), statementName, sql); + } + + @Override + public Node visitExecuteStatement(RelationalSqlParser.ExecuteStatementContext ctx) { + Identifier statementName = lowerIdentifier((Identifier) 
visit(ctx.statementName)); + List parameters = + ctx.literalExpression() != null && !ctx.literalExpression().isEmpty() + ? visit(ctx.literalExpression(), Literal.class) + : ImmutableList.of(); + return new Execute(getLocation(ctx), statementName, parameters); + } + + @Override + public Node visitExecuteImmediateStatement( + RelationalSqlParser.ExecuteImmediateStatementContext ctx) { + StringLiteral sql = (StringLiteral) visit(ctx.sql); + List parameters = + ctx.literalExpression() != null && !ctx.literalExpression().isEmpty() + ? visit(ctx.literalExpression(), Literal.class) + : ImmutableList.of(); + return new ExecuteImmediate(getLocation(ctx), sql, parameters); + } + + @Override + public Node visitDeallocateStatement(RelationalSqlParser.DeallocateStatementContext ctx) { + Identifier statementName = lowerIdentifier((Identifier) visit(ctx.statementName)); + return new Deallocate(getLocation(ctx), statementName); + } + // ***************** arguments ***************** @Override public Node visitGenericType(RelationalSqlParser.GenericTypeContext ctx) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/Statement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/Statement.java index 5b31f08ca677f..0b2ecff6b5bc5 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/Statement.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/Statement.java @@ -23,6 +23,7 @@ import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.plan.parser.ASTVisitor; +import java.util.Collections; import java.util.List; /** @@ -68,4 +69,26 @@ public org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement toRelat public String getPipeLoggingString() { return toString(); } + + /** + * Checks whether this statement should be split into multiple sub-statements based on the given + * async 
requirement. Used to limit resource consumption during statement analysis, etc. + * + * @return true if the statement should be split, false otherwise. Default implementation returns + * false. + */ + public boolean shouldSplit() { + return false; + } + + /** + * Splits the current statement into multiple sub-statements. Used to limit resource consumption + * during statement analysis, etc. + * + * @return the list of sub-statements. Default implementation returns empty list. + */ + public List getSubStatements() { + return Collections.emptyList(); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertBaseStatement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertBaseStatement.java index 1aae871ea0c1d..d8786e33959a5 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertBaseStatement.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertBaseStatement.java @@ -839,6 +839,16 @@ public long ramBytesUsed() { return ramBytesUsed; } + /** + * Set the pre-calculated memory size. This is used when memory size is calculated during + * deserialization to avoid recalculation. + * + * @param ramBytesUsed the calculated memory size in bytes + */ + public void setRamBytesUsed(long ramBytesUsed) { + this.ramBytesUsed = ramBytesUsed; + } + private long shallowSizeOfList(List list) { return Objects.nonNull(list) ? 
UpdateDetailContainer.LIST_SIZE diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertTabletStatement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertTabletStatement.java index d2c1d2a783a9c..12369d81bfd30 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertTabletStatement.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertTabletStatement.java @@ -44,6 +44,7 @@ import org.apache.iotdb.db.queryengine.plan.statement.StatementVisitor; import org.apache.iotdb.db.utils.CommonUtils; +import org.apache.tsfile.enums.ColumnCategory; import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.file.metadata.IDeviceID; import org.apache.tsfile.file.metadata.IDeviceID.Factory; @@ -58,6 +59,8 @@ import org.apache.tsfile.write.record.Tablet; import org.apache.tsfile.write.schema.IMeasurementSchema; import org.apache.tsfile.write.schema.MeasurementSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.time.LocalDate; import java.util.ArrayList; @@ -69,11 +72,22 @@ import java.util.Objects; public class InsertTabletStatement extends InsertBaseStatement implements ISchemaValidation { + private static final Logger LOGGER = LoggerFactory.getLogger(InsertTabletStatement.class); + private static final long INSTANCE_SIZE = RamUsageEstimator.shallowSizeOfInstance(InsertTabletStatement.class); private static final String DATATYPE_UNSUPPORTED = "Data type %s is not supported."; + /** + * Get the instance size of InsertTabletStatement for memory calculation. + * + * @return instance size in bytes + */ + public static long getInstanceSize() { + return INSTANCE_SIZE; + } + protected long[] times; // times should be sorted. It is done in the session API. 
protected BitMap[] nullBitMaps; protected Object[] columns; @@ -701,6 +715,189 @@ protected void subRemoveAttributeColumns(List columnsToKeep) { } } + /** + * Convert this InsertTabletStatement to Tablet. This method constructs a Tablet object from this + * statement, converting all necessary fields. All arrays are copied to rowSize length to ensure + * immutability. + * + * @return Tablet object + * @throws MetadataException if conversion fails + */ + public Tablet convertToTablet() throws MetadataException { + try { + // Get deviceId/tableName from devicePath + final String deviceIdOrTableName = + this.getDevicePath() != null ? this.getDevicePath().getFullPath() : ""; + + // Get schemas from measurementSchemas + final MeasurementSchema[] measurementSchemas = this.getMeasurementSchemas(); + final String[] measurements = this.getMeasurements(); + final TSDataType[] dataTypes = this.getDataTypes(); + // If measurements and dataTypes are not null, use measurements.length as the standard length + final int originalSchemaSize = measurements != null ? 
measurements.length : 0; + + // Build schemas and track valid column indices (skip null columns) + // measurements and dataTypes being null is standard - skip those columns + final List schemas = new ArrayList<>(); + final List validColumnIndices = new ArrayList<>(); + for (int i = 0; i < originalSchemaSize; i++) { + if (dataTypes != null && measurements[i] != null && dataTypes[i] != null) { + // Create MeasurementSchema if not present + schemas.add(new MeasurementSchema(measurements[i], dataTypes[i])); + validColumnIndices.add(i); + } + // Skip null columns - don't add to schemas or validColumnIndices + } + + final int schemaSize = schemas.size(); + + // Get columnTypes (for table model) - only for valid columns + final TsTableColumnCategory[] columnCategories = this.getColumnCategories(); + final List tabletColumnTypes = new ArrayList<>(); + if (columnCategories != null && columnCategories.length > 0) { + for (final int validIndex : validColumnIndices) { + if (columnCategories[validIndex] != null) { + tabletColumnTypes.add(columnCategories[validIndex].toTsFileColumnType()); + } else { + tabletColumnTypes.add(ColumnCategory.FIELD); + } + } + } else { + // Default to FIELD for all valid columns if not specified + for (int i = 0; i < schemaSize; i++) { + tabletColumnTypes.add(ColumnCategory.FIELD); + } + } + + // Get timestamps - always copy to ensure immutability + final long[] times = this.getTimes(); + final int rowSize = this.getRowCount(); + final long[] timestamps; + if (times != null && times.length >= rowSize && rowSize > 0) { + timestamps = new long[rowSize]; + System.arraycopy(times, 0, timestamps, 0, rowSize); + } else { + LOGGER.warn( + "Times array is null or too small. times.length={}, rowSize={}, deviceId={}", + times != null ? 
times.length : 0, + rowSize, + deviceIdOrTableName); + timestamps = new long[0]; + } + + // Get values - convert Statement columns to Tablet format, only for valid columns + // All arrays are truncated/copied to rowSize length + final Object[] statementColumns = this.getColumns(); + final Object[] tabletValues = new Object[schemaSize]; + if (statementColumns != null && statementColumns.length > 0) { + for (int i = 0; i < validColumnIndices.size(); i++) { + final int originalIndex = validColumnIndices.get(i); + if (statementColumns[originalIndex] != null && dataTypes[originalIndex] != null) { + tabletValues[i] = + convertColumnToTablet( + statementColumns[originalIndex], dataTypes[originalIndex], rowSize); + } else { + tabletValues[i] = null; + } + } + } + + // Get bitMaps - copy and truncate to rowSize, only for valid columns + final BitMap[] originalBitMaps = this.getBitMaps(); + final BitMap[] bitMaps; + if (originalBitMaps != null && originalBitMaps.length > 0) { + bitMaps = new BitMap[schemaSize]; + for (int i = 0; i < validColumnIndices.size(); i++) { + final int originalIndex = validColumnIndices.get(i); + if (originalBitMaps[originalIndex] != null) { + // Create a new BitMap truncated to rowSize + final byte[] truncatedBytes = + originalBitMaps[originalIndex].getTruncatedByteArray(rowSize); + bitMaps[i] = new BitMap(rowSize, truncatedBytes); + } else { + bitMaps[i] = null; + } + } + } else { + bitMaps = null; + } + + // Create Tablet using the full constructor + // Tablet(String tableName, List schemas, List + // columnTypes, + // long[] timestamps, Object[] values, BitMap[] bitMaps, int rowSize) + return new Tablet( + deviceIdOrTableName, + schemas, + tabletColumnTypes, + timestamps, + tabletValues, + bitMaps, + rowSize); + } catch (final Exception e) { + throw new MetadataException("Failed to convert InsertTabletStatement to Tablet", e); + } + } + + /** + * Convert a single column value from Statement format to Tablet format. 
Statement uses primitive + * arrays (e.g., int[], long[], float[]), while Tablet may need different format. All arrays are + * copied and truncated to rowSize length to ensure immutability - even if the original array is + * modified, the converted array remains unchanged. + * + * @param columnValue column value from Statement (primitive array) + * @param dataType data type of the column + * @param rowSize number of rows to copy (truncate to this length) + * @return column value in Tablet format (copied and truncated array) + */ + private Object convertColumnToTablet( + final Object columnValue, final TSDataType dataType, final int rowSize) { + + if (columnValue == null) { + return null; + } + + if (TSDataType.DATE.equals(dataType)) { + final int[] values = (int[]) columnValue; + // Copy and truncate to rowSize + final int[] copiedValues = Arrays.copyOf(values, Math.min(values.length, rowSize)); + final LocalDate[] localDateValue = new LocalDate[rowSize]; + for (int i = 0; i < copiedValues.length; i++) { + localDateValue[i] = DateUtils.parseIntToLocalDate(copiedValues[i]); + } + // Fill remaining with null if needed + for (int i = copiedValues.length; i < rowSize; i++) { + localDateValue[i] = null; + } + return localDateValue; + } + + // For primitive arrays, copy and truncate to rowSize + if (columnValue instanceof boolean[]) { + final boolean[] original = (boolean[]) columnValue; + return Arrays.copyOf(original, Math.min(original.length, rowSize)); + } else if (columnValue instanceof int[]) { + final int[] original = (int[]) columnValue; + return Arrays.copyOf(original, Math.min(original.length, rowSize)); + } else if (columnValue instanceof long[]) { + final long[] original = (long[]) columnValue; + return Arrays.copyOf(original, Math.min(original.length, rowSize)); + } else if (columnValue instanceof float[]) { + final float[] original = (float[]) columnValue; + return Arrays.copyOf(original, Math.min(original.length, rowSize)); + } else if (columnValue 
instanceof double[]) { + final double[] original = (double[]) columnValue; + return Arrays.copyOf(original, Math.min(original.length, rowSize)); + } else if (columnValue instanceof Binary[]) { + // For Binary arrays, create a new array and copy references, truncate to rowSize + final Binary[] original = (Binary[]) columnValue; + return Arrays.copyOf(original, Math.min(original.length, rowSize)); + } + + // For other types, return as-is (should not happen for standard types) + return columnValue; + } + @Override public String toString() { final int size = CommonDescriptor.getInstance().getConfig().getPathLogMaxSize(); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/LoadTsFileStatement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/LoadTsFileStatement.java index d1dff1bb9cf25..a51dcaf09d2b3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/LoadTsFileStatement.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/LoadTsFileStatement.java @@ -306,6 +306,54 @@ public boolean reconstructStatementIfMiniFileConverted(final List isMin return tsFiles == null || tsFiles.isEmpty(); } + @Override + public boolean shouldSplit() { + final int splitThreshold = + IoTDBDescriptor.getInstance().getConfig().getLoadTsFileStatementSplitThreshold(); + return tsFiles.size() > splitThreshold && !isAsyncLoad; + } + + /** + * Splits the current LoadTsFileStatement into multiple sub-statements, each handling a batch of + * TsFiles. Used to limit resource consumption during statement analysis, etc. 
+ * + * @return the list of sub-statements + */ + @Override + public List getSubStatements() { + final int batchSize = + IoTDBDescriptor.getInstance().getConfig().getLoadTsFileSubStatementBatchSize(); + final int totalBatches = (tsFiles.size() + batchSize - 1) / batchSize; // Ceiling division + final List subStatements = new ArrayList<>(totalBatches); + + for (int i = 0; i < tsFiles.size(); i += batchSize) { + final int endIndex = Math.min(i + batchSize, tsFiles.size()); + final List batchFiles = tsFiles.subList(i, endIndex); + + final LoadTsFileStatement statement = new LoadTsFileStatement(); + statement.databaseLevel = this.databaseLevel; + statement.verifySchema = this.verifySchema; + statement.deleteAfterLoad = this.deleteAfterLoad; + statement.convertOnTypeMismatch = this.convertOnTypeMismatch; + statement.tabletConversionThresholdBytes = this.tabletConversionThresholdBytes; + statement.autoCreateDatabase = this.autoCreateDatabase; + statement.isAsyncLoad = this.isAsyncLoad; + statement.isGeneratedByPipe = this.isGeneratedByPipe; + + statement.tsFiles = new ArrayList<>(batchFiles); + statement.resources = new ArrayList<>(batchFiles.size()); + statement.writePointCountList = new ArrayList<>(batchFiles.size()); + statement.isTableModel = new ArrayList<>(batchFiles.size()); + for (int j = 0; j < batchFiles.size(); j++) { + statement.isTableModel.add(false); + } + + subStatements.add(statement); + } + + return subStatements; + } + @Override public List getPaths() { return Collections.emptyList(); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java index 260410954d4c6..09f00f8ed672b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java @@ -19,14 +19,12 @@ package 
org.apache.iotdb.db.queryengine.plan.udf; +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.IClientManager; import org.apache.iotdb.commons.exception.IoTDBRuntimeException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; -import org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.rpc.TSStatusCode; import org.apache.iotdb.udf.api.UDTF; import org.apache.iotdb.udf.api.access.Row; @@ -54,8 +52,8 @@ public class UDTFForecast implements UDTF { private static final TsBlockSerde serde = new TsBlockSerde(); - private static final AINodeClientManager CLIENT_MANAGER = AINodeClientManager.getInstance(); - private TEndPoint targetAINode = new TEndPoint("127.0.0.1", 10810); + private static final IClientManager CLIENT_MANAGER = + AINodeClientManager.getInstance(); private String model_id; private int maxInputLength; private int outputLength; @@ -66,7 +64,6 @@ public class UDTFForecast implements UDTF { List types; private LinkedList inputRows; private TsBlockBuilder inputTsBlockBuilder; - private final IModelFetcher modelFetcher = ModelFetcher.getInstance(); private static final Set ALLOWED_INPUT_TYPES = new HashSet<>(); @@ -112,8 +109,6 @@ public void beforeStart(UDFParameters parameters, UDTFConfigurations configurati throw new IllegalArgumentException( "MODEL_ID parameter must be provided and cannot be empty."); } - ModelInferenceDescriptor descriptor = modelFetcher.fetchModel(this.model_id); - 
this.targetAINode = descriptor.getTargetAINode(); this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, DEFAULT_OUTPUT_INTERVAL); this.outputLength = @@ -211,8 +206,12 @@ private TsBlock forecast() throws Exception { TsBlock inputData = inputTsBlockBuilder.build(); TForecastResp resp; - try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) { - resp = client.forecast(model_id, inputData, outputLength, options); + try (AINodeClient client = + CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + resp = + client.forecast( + new TForecastReq(model_id, serde.serialize(inputData), outputLength) + .setOptions(options)); } catch (Exception e) { throw new IoTDBRuntimeException( e.getMessage(), TSStatusCode.CAN_NOT_CONNECT_AINODE.getStatusCode()); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java index c4b08f8f05a8a..1fac05012fe89 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java @@ -111,6 +111,7 @@ import org.apache.iotdb.db.storageengine.dataregion.memtable.TsFileProcessor; import org.apache.iotdb.db.storageengine.dataregion.wal.WALManager; import org.apache.iotdb.db.storageengine.dataregion.wal.utils.WALMode; +import org.apache.iotdb.db.storageengine.load.active.ActiveLoadAgent; import org.apache.iotdb.db.storageengine.rescon.disk.TierManager; import org.apache.iotdb.db.subscription.agent.SubscriptionAgent; import org.apache.iotdb.db.trigger.executor.TriggerExecutor; @@ -245,6 +246,8 @@ protected void start() { sendRegisterRequestToConfigNode(false); saveSecretKey(); saveHardwareCode(); + // Clean up active load listening directories on first startup + ActiveLoadAgent.cleanupListeningDirectories(); } else { /* Check encrypt magic string */ try { diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/DataRegion.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/DataRegion.java index 1124e33a7df95..4db15d48bc673 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/DataRegion.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/DataRegion.java @@ -2476,7 +2476,8 @@ private List getFileHandleListForQuery( } else { tsFileResource .getProcessor() - .queryForSeriesRegionScanWithoutLock(partialPaths, context, fileScanHandles); + .queryForSeriesRegionScanWithoutLock( + partialPaths, context, fileScanHandles, globalTimeFilter); } } return fileScanHandles; @@ -2553,7 +2554,8 @@ private List getFileHandleListForQuery( } else { tsFileResource .getProcessor() - .queryForDeviceRegionScanWithoutLock(devicePathsToContext, context, fileScanHandles); + .queryForDeviceRegionScanWithoutLock( + devicePathsToContext, context, fileScanHandles, globalTimeFilter); } } return fileScanHandles; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AbstractMemTable.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AbstractMemTable.java index 92803984cee1e..22606f232a8d1 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AbstractMemTable.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AbstractMemTable.java @@ -70,6 +70,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; @@ -488,7 +489,8 @@ public void queryForSeriesRegionScan( long ttlLowerBound, Map> chunkMetaDataMap, Map> memChunkHandleMap, - List> modsToMemTabled) { + List> modsToMemTabled, + Filter globalTimeFilter) { 
IDeviceID deviceID = fullPath.getDeviceId(); if (fullPath instanceof NonAlignedFullPath) { @@ -506,7 +508,12 @@ public void queryForSeriesRegionScan( fullPath.getDeviceId(), measurementId, this, modsToMemTabled, ttlLowerBound); } getMemChunkHandleFromMemTable( - deviceID, measurementId, chunkMetaDataMap, memChunkHandleMap, deletionList); + deviceID, + measurementId, + chunkMetaDataMap, + memChunkHandleMap, + deletionList, + globalTimeFilter); } else { // check If MemTable Contains this path if (!memTableMap.containsKey(deviceID)) { @@ -528,7 +535,8 @@ public void queryForSeriesRegionScan( ((AlignedFullPath) fullPath).getSchemaList(), chunkMetaDataMap, memChunkHandleMap, - deletionList); + deletionList, + globalTimeFilter); } } @@ -539,7 +547,8 @@ public void queryForDeviceRegionScan( long ttlLowerBound, Map> chunkMetadataMap, Map> memChunkHandleMap, - List> modsToMemTabled) { + List> modsToMemTabled, + Filter globalTimeFilter) { Map memTableMap = getMemTableMap(); @@ -556,7 +565,8 @@ public void queryForDeviceRegionScan( chunkMetadataMap, memChunkHandleMap, ttlLowerBound, - modsToMemTabled); + modsToMemTabled, + globalTimeFilter); } else { getMemChunkHandleFromMemTable( deviceID, @@ -564,7 +574,8 @@ public void queryForDeviceRegionScan( chunkMetadataMap, memChunkHandleMap, ttlLowerBound, - modsToMemTabled); + modsToMemTabled, + globalTimeFilter); } } @@ -573,24 +584,30 @@ private void getMemChunkHandleFromMemTable( String measurementId, Map> chunkMetadataMap, Map> memChunkHandleMap, - List deletionList) { + List deletionList, + Filter globalTimeFilter) { WritableMemChunk memChunk = (WritableMemChunk) memTableMap.get(deviceID).getMemChunkMap().get(measurementId); - long[] timestamps = memChunk.getFilteredTimestamp(deletionList); + if (memChunk == null) { + return; + } + Optional anySatisfiedTimestamp = + memChunk.getAnySatisfiedTimestamp(deletionList, globalTimeFilter); + if (!anySatisfiedTimestamp.isPresent()) { + return; + } + long satisfiedTimestamp = 
anySatisfiedTimestamp.get(); chunkMetadataMap .computeIfAbsent(measurementId, k -> new ArrayList<>()) .add( - buildChunkMetaDataForMemoryChunk( - measurementId, - timestamps[0], - timestamps[timestamps.length - 1], - Collections.emptyList())); + buildFakeChunkMetaDataForFakeMemoryChunk( + measurementId, satisfiedTimestamp, satisfiedTimestamp, Collections.emptyList())); memChunkHandleMap .computeIfAbsent(measurementId, k -> new ArrayList<>()) - .add(new MemChunkHandleImpl(deviceID, measurementId, timestamps)); + .add(new MemChunkHandleImpl(deviceID, measurementId, new long[] {satisfiedTimestamp})); } private void getMemAlignedChunkHandleFromMemTable( @@ -598,7 +615,8 @@ private void getMemAlignedChunkHandleFromMemTable( List schemaList, Map> chunkMetadataList, Map> memChunkHandleMap, - List> deletionList) { + List> deletionList, + Filter globalTimeFilter) { AlignedWritableMemChunk alignedMemChunk = ((AlignedWritableMemChunkGroup) memTableMap.get(deviceID)).getAlignedMemChunk(); @@ -615,7 +633,11 @@ private void getMemAlignedChunkHandleFromMemTable( } List bitMaps = new ArrayList<>(); - long[] timestamps = alignedMemChunk.getFilteredTimestamp(deletionList, bitMaps, true); + long[] timestamps = + alignedMemChunk.getAnySatisfiedTimestamp(deletionList, bitMaps, true, globalTimeFilter); + if (timestamps.length == 0) { + return; + } buildAlignedMemChunkHandle( deviceID, @@ -633,7 +655,8 @@ private void getMemAlignedChunkHandleFromMemTable( Map> chunkMetadataList, Map> memChunkHandleMap, long ttlLowerBound, - List> modsToMemTabled) { + List> modsToMemTabled, + Filter globalTimeFilter) { AlignedWritableMemChunk memChunk = writableMemChunkGroup.getAlignedMemChunk(); List schemaList = memChunk.getSchemaList(); @@ -648,7 +671,11 @@ private void getMemAlignedChunkHandleFromMemTable( } List bitMaps = new ArrayList<>(); - long[] timestamps = memChunk.getFilteredTimestamp(deletionList, bitMaps, true); + long[] timestamps = + memChunk.getAnySatisfiedTimestamp(deletionList, bitMaps, 
true, globalTimeFilter); + if (timestamps.length == 0) { + return; + } buildAlignedMemChunkHandle( deviceID, timestamps, @@ -665,7 +692,8 @@ private void getMemChunkHandleFromMemTable( Map> chunkMetadataMap, Map> memChunkHandleMap, long ttlLowerBound, - List> modsToMemTabled) { + List> modsToMemTabled, + Filter globalTimeFilter) { for (Entry entry : writableMemChunkGroup.getMemChunkMap().entrySet()) { @@ -679,18 +707,20 @@ private void getMemChunkHandleFromMemTable( ModificationUtils.constructDeletionList( deviceID, measurementId, this, modsToMemTabled, ttlLowerBound); } - long[] timestamps = writableMemChunk.getFilteredTimestamp(deletionList); + Optional anySatisfiedTimestamp = + writableMemChunk.getAnySatisfiedTimestamp(deletionList, globalTimeFilter); + if (!anySatisfiedTimestamp.isPresent()) { + return; + } + long satisfiedTimestamp = anySatisfiedTimestamp.get(); chunkMetadataMap .computeIfAbsent(measurementId, k -> new ArrayList<>()) .add( - buildChunkMetaDataForMemoryChunk( - measurementId, - timestamps[0], - timestamps[timestamps.length - 1], - Collections.emptyList())); + buildFakeChunkMetaDataForFakeMemoryChunk( + measurementId, satisfiedTimestamp, satisfiedTimestamp, Collections.emptyList())); memChunkHandleMap .computeIfAbsent(measurementId, k -> new ArrayList<>()) - .add(new MemChunkHandleImpl(deviceID, measurementId, timestamps)); + .add(new MemChunkHandleImpl(deviceID, measurementId, new long[] {satisfiedTimestamp})); } } @@ -714,7 +744,7 @@ private void buildAlignedMemChunkHandle( chunkMetadataList .computeIfAbsent(measurement, k -> new ArrayList<>()) .add( - buildChunkMetaDataForMemoryChunk( + buildFakeChunkMetaDataForFakeMemoryChunk( measurement, startEndTime[0], startEndTime[1], deletion)); chunkHandleMap .computeIfAbsent(measurement, k -> new ArrayList<>()) @@ -745,7 +775,7 @@ private long[] calculateStartEndTime(long[] timestamps, List bitMaps, in return new long[] {startTime, endTime}; } - private IChunkMetadata 
buildChunkMetaDataForMemoryChunk( + private IChunkMetadata buildFakeChunkMetaDataForFakeMemoryChunk( String measurement, long startTime, long endTime, List deletionList) { TimeStatistics timeStatistics = new TimeStatistics(); timeStatistics.setStartTime(startTime); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AlignedWritableMemChunk.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AlignedWritableMemChunk.java index 5f6af57746fe7..f42867d4772c1 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AlignedWritableMemChunk.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AlignedWritableMemChunk.java @@ -34,6 +34,7 @@ import org.apache.tsfile.encrypt.EncryptUtils; import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.read.common.TimeRange; +import org.apache.tsfile.read.filter.basic.Filter; import org.apache.tsfile.utils.Binary; import org.apache.tsfile.utils.BitMap; import org.apache.tsfile.utils.Pair; @@ -55,6 +56,7 @@ import java.util.Set; import java.util.TreeMap; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; import static org.apache.iotdb.db.utils.ModificationUtils.isPointDeleted; @@ -287,30 +289,99 @@ private Pair checkAndReorderColumnValuesInInsertPlan( return new Pair<>(reorderedColumnValues, reorderedBitMaps); } - private void filterDeletedTimeStamp( + public long[] getAnySatisfiedTimestamp( + List> deletionList, + List bitMaps, + boolean ignoreAllNullRows, + Filter globalTimeFilter) { + BitMap columnHasNonNullValue = new BitMap(schemaList.size()); + AtomicInteger hasNonNullValueColumnCount = new AtomicInteger(0); + Map timestampWithBitmap = new TreeMap<>(); + + getAnySatisfiedTimestamp( + list, + deletionList, + ignoreAllNullRows, + timestampWithBitmap, + globalTimeFilter, + columnHasNonNullValue, + 
hasNonNullValueColumnCount); + for (int i = 0; + i < sortedList.size() && hasNonNullValueColumnCount.get() < schemaList.size(); + i++) { + if (!ignoreAllNullRows && !timestampWithBitmap.isEmpty()) { + // count devices in table model + break; + } + getAnySatisfiedTimestamp( + sortedList.get(i), + deletionList, + ignoreAllNullRows, + timestampWithBitmap, + globalTimeFilter, + columnHasNonNullValue, + hasNonNullValueColumnCount); + } + + long[] timestamps = new long[timestampWithBitmap.size()]; + int idx = 0; + for (Map.Entry entry : timestampWithBitmap.entrySet()) { + timestamps[idx++] = entry.getKey(); + bitMaps.add(entry.getValue()); + } + return timestamps; + } + + private void getAnySatisfiedTimestamp( AlignedTVList alignedTVList, List> valueColumnsDeletionList, boolean ignoreAllNullRows, - Map timestampWithBitmap) { + Map timestampWithBitmap, + Filter globalTimeFilter, + BitMap columnHasNonNullValue, + AtomicInteger hasNonNullValueColumnCount) { + if (globalTimeFilter != null + && !globalTimeFilter.satisfyStartEndTime( + alignedTVList.getMinTime(), alignedTVList.getMaxTime())) { + return; + } BitMap allValueColDeletedMap = alignedTVList.getAllValueColDeletedMap(); - int rowCount = alignedTVList.rowCount(); List valueColumnDeleteCursor = new ArrayList<>(); if (valueColumnsDeletionList != null) { valueColumnsDeletionList.forEach(x -> valueColumnDeleteCursor.add(new int[] {0})); } + // example: + // globalTimeFilter:null, ignoreAllNullRows: true + // tvList: + // time s1 s2 s3 + // 1 1 null null + // 2 null 1 null + // 2 1 1 null + // 3 1 null null + // 4 1 null 1 + // timestampWithBitmap: + // timestamp: 1 bitmap: 011 + // timestamp: 2 bitmap: 101 + // timestamp: 4 bitmap: 110 for (int row = 0; row < rowCount; row++) { // the row is deleted if (allValueColDeletedMap != null && allValueColDeletedMap.isMarked(row)) { continue; } long timestamp = alignedTVList.getTime(row); + if (globalTimeFilter != null && !globalTimeFilter.satisfy(timestamp, null)) { + continue; + 
} + + // Note that this method will only perform bitmap unmarking on the first occurrence of a + // non-null value in multiple timestamps for the same column. + BitMap currentRowNullValueBitmap = null; - BitMap bitMap = new BitMap(schemaList.size()); for (int column = 0; column < schemaList.size(); column++) { if (alignedTVList.isNullValue(alignedTVList.getValueIndex(row), column)) { - bitMap.mark(column); + continue; } // skip deleted row @@ -320,33 +391,44 @@ && isPointDeleted( timestamp, valueColumnsDeletionList.get(column), valueColumnDeleteCursor.get(column))) { - bitMap.mark(column); - } - - // skip all-null row - if (ignoreAllNullRows && bitMap.isAllMarked()) { continue; } - timestampWithBitmap.put(timestamp, bitMap); + if (!columnHasNonNullValue.isMarked(column)) { + hasNonNullValueColumnCount.incrementAndGet(); + columnHasNonNullValue.mark(column); + currentRowNullValueBitmap = + currentRowNullValueBitmap != null + ? currentRowNullValueBitmap + : timestampWithBitmap.computeIfAbsent( + timestamp, k -> getAllMarkedBitmap(schemaList.size())); + currentRowNullValueBitmap.unmark(column); + } } - } - } - public long[] getFilteredTimestamp( - List> deletionList, List bitMaps, boolean ignoreAllNullRows) { - Map timestampWithBitmap = new TreeMap<>(); + if (!ignoreAllNullRows) { + timestampWithBitmap.put( + timestamp, + currentRowNullValueBitmap != null + ? 
currentRowNullValueBitmap + : getAllMarkedBitmap(schemaList.size())); + return; + } + if (currentRowNullValueBitmap == null) { + continue; + } + // found new column with non-null value + timestampWithBitmap.put(timestamp, currentRowNullValueBitmap); - filterDeletedTimeStamp(list, deletionList, ignoreAllNullRows, timestampWithBitmap); - for (AlignedTVList alignedTVList : sortedList) { - filterDeletedTimeStamp(alignedTVList, deletionList, ignoreAllNullRows, timestampWithBitmap); + if (hasNonNullValueColumnCount.get() == schemaList.size()) { + return; + } } + } - List filteredTimestamps = new ArrayList<>(); - for (Map.Entry entry : timestampWithBitmap.entrySet()) { - filteredTimestamps.add(entry.getKey()); - bitMaps.add(entry.getValue()); - } - return filteredTimestamps.stream().mapToLong(Long::valueOf).toArray(); + private BitMap getAllMarkedBitmap(int size) { + BitMap bitMap = new BitMap(size); + bitMap.markAll(); + return bitMap; } @Override diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/IMemTable.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/IMemTable.java index 6c6e09c578256..fd9ffe90b0ae8 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/IMemTable.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/IMemTable.java @@ -125,7 +125,8 @@ void queryForSeriesRegionScan( long ttlLowerBound, Map> chunkMetadataMap, Map> memChunkHandleMap, - List> modsToMemtabled) + List> modsToMemtabled, + Filter globalTimeFilter) throws IOException, QueryProcessException, MetadataException; void queryForDeviceRegionScan( @@ -134,7 +135,8 @@ void queryForDeviceRegionScan( long ttlLowerBound, Map> chunkMetadataMap, Map> memChunkHandleMap, - List> modsToMemtabled) + List> modsToMemtabled, + Filter globalTimeFilter) throws IOException, QueryProcessException, MetadataException; /** putBack all 
the memory resources. */ diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/TsFileProcessor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/TsFileProcessor.java index 68776c92c0b59..36cbb4e0f880d 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/TsFileProcessor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/TsFileProcessor.java @@ -1974,7 +1974,8 @@ private List getAlignedVisibleMetadataListFromWriterByDeviceID( public void queryForSeriesRegionScanWithoutLock( List pathList, QueryContext queryContext, - List fileScanHandlesForQuery) { + List fileScanHandlesForQuery, + Filter globalTimeFilter) { long startTime = System.nanoTime(); try { Map>> deviceToMemChunkHandleMap = new HashMap<>(); @@ -1995,7 +1996,8 @@ public void queryForSeriesRegionScanWithoutLock( timeLowerBound, measurementToChunkMetaList, measurementToChunkHandleList, - modsToMemtable); + modsToMemtable, + globalTimeFilter); } if (workMemTable != null) { workMemTable.queryForSeriesRegionScan( @@ -2003,7 +2005,8 @@ public void queryForSeriesRegionScanWithoutLock( timeLowerBound, measurementToChunkMetaList, measurementToChunkHandleList, - null); + null, + globalTimeFilter); } IDeviceID deviceID = seriesPath.getDeviceId(); // Some memTable have been flushed already, so we need to get the chunk metadata from @@ -2054,7 +2057,8 @@ public void queryForSeriesRegionScanWithoutLock( public void queryForDeviceRegionScanWithoutLock( Map devicePathsToContext, QueryContext queryContext, - List fileScanHandlesForQuery) { + List fileScanHandlesForQuery, + Filter globalTimeFilter) { long startTime = System.nanoTime(); try { Map>> deviceToMemChunkHandleMap = new HashMap<>(); @@ -2077,7 +2081,8 @@ public void queryForDeviceRegionScanWithoutLock( timeLowerBound, measurementToChunkMetadataList, measurementToMemChunkHandleList, - 
modsToMemtable); + modsToMemtable, + globalTimeFilter); } if (workMemTable != null) { workMemTable.queryForDeviceRegionScan( @@ -2086,7 +2091,8 @@ public void queryForDeviceRegionScanWithoutLock( timeLowerBound, measurementToChunkMetadataList, measurementToMemChunkHandleList, - null); + null, + globalTimeFilter); } buildChunkHandleForFlushedMemTable( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunk.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunk.java index c4a871a29cc13..6499d33fcd9f4 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunk.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunk.java @@ -35,6 +35,7 @@ import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.read.TimeValuePair; import org.apache.tsfile.read.common.TimeRange; +import org.apache.tsfile.read.filter.basic.Filter; import org.apache.tsfile.utils.Binary; import org.apache.tsfile.utils.BitMap; import org.apache.tsfile.write.UnSupportedDataTypeException; @@ -48,10 +49,9 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; +import java.util.Optional; import java.util.concurrent.BlockingQueue; -import java.util.stream.Collectors; import static org.apache.iotdb.db.utils.MemUtils.getBinarySize; @@ -575,11 +575,30 @@ public List getSortedList() { return sortedList; } - private void filterDeletedTimestamp( - TVList tvlist, List deletionList, List timestampList) { - long lastTime = Long.MIN_VALUE; + public Optional getAnySatisfiedTimestamp( + List deletionList, Filter globalTimeFilter) { + Optional anySatisfiedTimestamp = + getAnySatisfiedTimestamp(list, deletionList, globalTimeFilter); + if (anySatisfiedTimestamp.isPresent()) { + return 
anySatisfiedTimestamp; + } + for (TVList tvList : sortedList) { + anySatisfiedTimestamp = getAnySatisfiedTimestamp(tvList, deletionList, globalTimeFilter); + if (anySatisfiedTimestamp.isPresent()) { + break; + } + } + return anySatisfiedTimestamp; + } + + private Optional getAnySatisfiedTimestamp( + TVList tvlist, List deletionList, Filter globalTimeFilter) { int[] deletionCursor = {0}; int rowCount = tvlist.rowCount(); + if (globalTimeFilter != null + && !globalTimeFilter.satisfyStartEndTime(tvlist.getMinTime(), tvlist.getMaxTime())) { + return Optional.empty(); + } for (int i = 0; i < rowCount; i++) { if (tvlist.getBitMap() != null && tvlist.isNullValue(tvlist.getValueIndex(i))) { continue; @@ -589,27 +608,12 @@ private void filterDeletedTimestamp( && ModificationUtils.isPointDeleted(curTime, deletionList, deletionCursor)) { continue; } - - if (i == rowCount - 1 || curTime != lastTime) { - timestampList.add(curTime); + if (globalTimeFilter != null && !globalTimeFilter.satisfy(curTime, null)) { + continue; } - lastTime = curTime; + return Optional.of(curTime); } - } - - public long[] getFilteredTimestamp(List deletionList) { - List timestampList = new ArrayList<>(); - filterDeletedTimestamp(list, deletionList, timestampList); - for (TVList tvList : sortedList) { - filterDeletedTimestamp(tvList, deletionList, timestampList); - } - - // remove duplicated time - List distinctTimestamps = timestampList.stream().distinct().collect(Collectors.toList()); - // sort timestamps - long[] filteredTimestamps = distinctTimestamps.stream().mapToLong(Long::longValue).toArray(); - Arrays.sort(filteredTimestamps); - return filteredTimestamps; + return Optional.empty(); } @Override diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/load/active/ActiveLoadAgent.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/load/active/ActiveLoadAgent.java index f060bd9a96f3c..6065c349c8c54 100644 --- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/load/active/ActiveLoadAgent.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/load/active/ActiveLoadAgent.java @@ -19,8 +19,21 @@ package org.apache.iotdb.db.storageengine.load.active; +import org.apache.iotdb.commons.utils.FileUtils; +import org.apache.iotdb.db.conf.IoTDBDescriptor; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + public class ActiveLoadAgent { + private static final Logger LOGGER = LoggerFactory.getLogger(ActiveLoadAgent.class); + private final ActiveLoadTsFileLoader activeLoadTsFileLoader; private final ActiveLoadDirScanner activeLoadDirScanner; private final ActiveLoadMetricsCollector activeLoadMetricsCollector; @@ -48,4 +61,81 @@ public synchronized void start() { activeLoadDirScanner.start(); activeLoadMetricsCollector.start(); } + + /** + * Clean up all listening directories for active load on DataNode first startup. This method will + * clean up all files and subdirectories in the listening directories, including: 1. Pending + * directories (configured by load_active_listening_dirs) 2. Pipe directory (for pipe data sync) + * 3. Failed directory (for failed files) + * + *

This method is called during DataNode startup and must not throw any exceptions to ensure + * startup can proceed normally. All exceptions are caught and logged internally. + */ + public static void cleanupListeningDirectories() { + try { + final List dirsToClean = new ArrayList<>(); + + dirsToClean.addAll( + Arrays.asList(IoTDBDescriptor.getInstance().getConfig().getLoadActiveListeningDirs())); + + // Add pipe dir + dirsToClean.add(IoTDBDescriptor.getInstance().getConfig().getLoadActiveListeningPipeDir()); + + // Add failed dir + dirsToClean.add(IoTDBDescriptor.getInstance().getConfig().getLoadActiveListeningFailDir()); + + // Clean up each directory + for (final String dirPath : dirsToClean) { + try { + if (dirPath == null || dirPath.isEmpty()) { + continue; + } + + final File dir = new File(dirPath); + + // Check if directory exists and is a directory + // These methods may throw SecurityException if access is denied + try { + if (!dir.exists() || !dir.isDirectory()) { + continue; + } + } catch (Exception e) { + LOGGER.debug("Failed to check directory: {}", dirPath, e); + continue; + } + + // Only delete contents inside the directory, not the directory itself + // listFiles() may throw SecurityException if access is denied + File[] files = null; + try { + files = dir.listFiles(); + } catch (Exception e) { + LOGGER.warn("Failed to list files in directory: {}", dirPath, e); + continue; + } + + if (files != null) { + for (final File file : files) { + // FileUtils.deleteFileOrDirectory internally calls file.isDirectory() and + // file.listFiles() without try-catch, so exceptions may propagate here. + // We need to catch it to prevent one file failure from stopping the cleanup. 
+ try { + FileUtils.deleteFileOrDirectory(file, true); + } catch (Exception e) { + LOGGER.debug("Failed to delete file or directory: {}", file.getAbsolutePath(), e); + } + } + } + } catch (Exception e) { + LOGGER.warn("Failed to cleanup directory: {}", dirPath, e); + } + } + + LOGGER.info("Cleaned up active load listening directories"); + } catch (Throwable t) { + // Catch all exceptions and errors (including OutOfMemoryError, etc.) + // to ensure startup process is not affected + LOGGER.warn("Unexpected error during cleanup of active load listening directories", t); + } + } } diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java index 1a122c63d4f23..0cc4470882efd 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java @@ -437,7 +437,7 @@ public void testPipeTransferTabletBatchReqV2() throws IOException { try (final PublicBAOS byteArrayOutputStream = new PublicBAOS(); final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream)) { t.serialize(outputStream); - ReadWriteIOUtils.write(false, outputStream); + ReadWriteIOUtils.write(true, outputStream); tabletBuffers.add( ByteBuffer.wrap(byteArrayOutputStream.getBuf(), 0, byteArrayOutputStream.size())); tabletDataBase.add("test"); @@ -459,7 +459,7 @@ public void testPipeTransferTabletBatchReqV2() throws IOException { new byte[] {'a', 'b'}, deserializedReq.getBinaryReqs().get(0).getByteBuffer().array()); Assert.assertEquals(node, deserializedReq.getInsertNodeReqs().get(0).getInsertNode()); Assert.assertEquals(t, deserializedReq.getTabletReqs().get(0).getTablet()); - Assert.assertFalse(deserializedReq.getTabletReqs().get(0).getIsAligned()); + 
Assert.assertTrue(deserializedReq.getTabletReqs().get(0).getIsAligned()); Assert.assertEquals("test", deserializedReq.getBinaryReqs().get(0).getDataBaseName()); Assert.assertEquals("test", deserializedReq.getTabletReqs().get(0).getDataBaseName()); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeStatementEventSorterTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeStatementEventSorterTest.java new file mode 100644 index 0000000000000..7c13d9764b3cf --- /dev/null +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeStatementEventSorterTest.java @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.pipe.sink; + +import org.apache.iotdb.db.pipe.sink.util.sorter.PipeTableModelTabletEventSorter; +import org.apache.iotdb.db.pipe.sink.util.sorter.PipeTreeModelTabletEventSorter; +import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; + +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.write.record.Tablet; +import org.apache.tsfile.write.schema.IMeasurementSchema; +import org.apache.tsfile.write.schema.MeasurementSchema; +import org.junit.Assert; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; +import java.time.LocalDate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public class PipeStatementEventSorterTest { + + @Test + public void testTreeModelDeduplicateAndSort() throws Exception { + List schemaList = new ArrayList<>(); + schemaList.add(new MeasurementSchema("s1", TSDataType.INT64)); + schemaList.add(new MeasurementSchema("s2", TSDataType.INT64)); + schemaList.add(new MeasurementSchema("s3", TSDataType.INT64)); + + Tablet tablet = new Tablet("root.sg.device", schemaList, 30); + + long timestamp = 300; + for (long i = 0; i < 10; i++) { + int rowIndex = tablet.getRowSize(); + tablet.addTimestamp(rowIndex, timestamp + i); + for (int s = 0; s < 3; s++) { + tablet.addValue(schemaList.get(s).getMeasurementName(), rowIndex, timestamp + i); + } + + rowIndex = tablet.getRowSize(); + tablet.addTimestamp(rowIndex, timestamp - i); + for (int s = 0; s < 3; s++) { + tablet.addValue(schemaList.get(s).getMeasurementName(), rowIndex, timestamp - i); + } + + rowIndex = tablet.getRowSize(); + tablet.addTimestamp(rowIndex, timestamp); + for (int s = 0; s < 3; s++) { + tablet.addValue(schemaList.get(s).getMeasurementName(), rowIndex, timestamp); + } + } + + Set indices = new HashSet<>(); + for (int i = 0; i < 30; i++) { + indices.add((int) 
tablet.getTimestamp(i)); + } + + Assert.assertFalse(tablet.isSorted()); + + // Convert Tablet to Statement + InsertTabletStatement statement = new InsertTabletStatement(tablet, true, null); + + // Sort using Statement + new PipeTreeModelTabletEventSorter(statement).deduplicateAndSortTimestampsIfNecessary(); + + Assert.assertEquals(indices.size(), statement.getRowCount()); + + final long[] timestamps = Arrays.copyOfRange(statement.getTimes(), 0, statement.getRowCount()); + final Object[] columns = statement.getColumns(); + for (int i = 0; i < 3; ++i) { + Assert.assertArrayEquals( + timestamps, Arrays.copyOfRange((long[]) columns[i], 0, statement.getRowCount())); + } + + for (int i = 1; i < statement.getRowCount(); ++i) { + Assert.assertTrue(timestamps[i] > timestamps[i - 1]); + for (int j = 0; j < 3; ++j) { + Assert.assertTrue(((long[]) columns[j])[i] > ((long[]) columns[j])[i - 1]); + } + } + } + + @Test + public void testTreeModelDeduplicate() throws Exception { + final List schemaList = new ArrayList<>(); + schemaList.add(new MeasurementSchema("s1", TSDataType.INT64)); + schemaList.add(new MeasurementSchema("s2", TSDataType.INT64)); + schemaList.add(new MeasurementSchema("s3", TSDataType.INT64)); + + final Tablet tablet = new Tablet("root.sg.device", schemaList, 10); + + final long timestamp = 300; + for (long i = 0; i < 10; i++) { + final int rowIndex = tablet.getRowSize(); + tablet.addTimestamp(rowIndex, timestamp); + for (int s = 0; s < 3; s++) { + tablet.addValue( + schemaList.get(s).getMeasurementName(), + rowIndex, + (i + s) % 3 != 0 ? 
timestamp + i : null); + } + } + + final Set indices = new HashSet<>(); + for (int i = 0; i < 10; i++) { + indices.add((int) tablet.getTimestamp(i)); + } + + Assert.assertTrue(tablet.isSorted()); + + // Convert Tablet to Statement + InsertTabletStatement statement = new InsertTabletStatement(tablet, true, null); + + // Sort using Statement + new PipeTreeModelTabletEventSorter(statement).deduplicateAndSortTimestampsIfNecessary(); + + Assert.assertEquals(indices.size(), statement.getRowCount()); + + final long[] timestamps = Arrays.copyOfRange(statement.getTimes(), 0, statement.getRowCount()); + final Object[] columns = statement.getColumns(); + Assert.assertEquals(timestamps[0] + 8, ((long[]) columns[0])[0]); + for (int i = 1; i < 3; ++i) { + Assert.assertEquals(timestamps[0] + 9, ((long[]) columns[i])[0]); + } + } + + @Test + public void testTreeModelSort() throws Exception { + List schemaList = new ArrayList<>(); + schemaList.add(new MeasurementSchema("s1", TSDataType.INT64)); + schemaList.add(new MeasurementSchema("s2", TSDataType.INT64)); + schemaList.add(new MeasurementSchema("s3", TSDataType.INT64)); + + Tablet tablet = new Tablet("root.sg.device", schemaList, 30); + + for (long i = 0; i < 10; i++) { + int rowIndex = tablet.getRowSize(); + tablet.addTimestamp(rowIndex, (long) rowIndex + 2); + for (int s = 0; s < 3; s++) { + tablet.addValue(schemaList.get(s).getMeasurementName(), rowIndex, (long) rowIndex + 2); + } + + rowIndex = tablet.getRowSize(); + tablet.addTimestamp(rowIndex, rowIndex); + for (int s = 0; s < 3; s++) { + tablet.addValue(schemaList.get(s).getMeasurementName(), rowIndex, (long) rowIndex); + } + + rowIndex = tablet.getRowSize(); + tablet.addTimestamp(rowIndex, (long) rowIndex - 2); + for (int s = 0; s < 3; s++) { + tablet.addValue(schemaList.get(s).getMeasurementName(), rowIndex, (long) rowIndex - 2); + } + } + + Set indices = new HashSet<>(); + for (int i = 0; i < 30; i++) { + indices.add((int) tablet.getTimestamp(i)); + } + + 
Assert.assertFalse(tablet.isSorted()); + + long[] timestamps = Arrays.copyOfRange(tablet.getTimestamps(), 0, tablet.getRowSize()); + for (int i = 0; i < 3; ++i) { + Assert.assertArrayEquals( + timestamps, Arrays.copyOfRange((long[]) tablet.getValues()[i], 0, tablet.getRowSize())); + } + + for (int i = 1; i < tablet.getRowSize(); ++i) { + Assert.assertTrue(timestamps[i] != timestamps[i - 1]); + for (int j = 0; j < 3; ++j) { + Assert.assertNotEquals((long) tablet.getValue(i, j), (long) tablet.getValue(i - 1, j)); + } + } + + // Convert Tablet to Statement + InsertTabletStatement statement = new InsertTabletStatement(tablet, true, null); + + // Sort using Statement + new PipeTreeModelTabletEventSorter(statement).deduplicateAndSortTimestampsIfNecessary(); + + Assert.assertEquals(indices.size(), statement.getRowCount()); + + timestamps = Arrays.copyOfRange(statement.getTimes(), 0, statement.getRowCount()); + final Object[] columns = statement.getColumns(); + for (int i = 0; i < 3; ++i) { + Assert.assertArrayEquals( + timestamps, Arrays.copyOfRange((long[]) columns[i], 0, statement.getRowCount())); + } + + for (int i = 1; i < statement.getRowCount(); ++i) { + Assert.assertTrue(timestamps[i] > timestamps[i - 1]); + for (int j = 0; j < 3; ++j) { + Assert.assertTrue(((long[]) columns[j])[i] > ((long[]) columns[j])[i - 1]); + } + } + } + + @Test + public void testTableModelDeduplicateAndSort() throws Exception { + doTableModelTest(true, true); + } + + @Test + public void testTableModelDeduplicate() throws Exception { + doTableModelTest(true, false); + } + + @Test + public void testTableModelSort() throws Exception { + doTableModelTest(false, true); + } + + @Test + public void testTableModelSort1() throws Exception { + doTableModelTest1(); + } + + public void doTableModelTest(final boolean hasDuplicates, final boolean isUnSorted) + throws Exception { + final Tablet tablet = + PipeTabletEventSorterTest.generateTablet("test", 10, hasDuplicates, isUnSorted); + + // Convert 
Tablet to Statement + InsertTabletStatement statement = new InsertTabletStatement(tablet, false, "test_db"); + + // Sort using Statement + new PipeTableModelTabletEventSorter(statement).sortAndDeduplicateByDevIdTimestamp(); + + long[] timestamps = statement.getTimes(); + final Object[] columns = statement.getColumns(); + for (int i = 1; i < statement.getRowCount(); i++) { + long time = timestamps[i]; + Assert.assertTrue(time > timestamps[i - 1]); + Assert.assertEquals( + ((Binary[]) columns[0])[i], + new Binary(String.valueOf(i / 100).getBytes(StandardCharsets.UTF_8))); + Assert.assertEquals(((long[]) columns[1])[i], (long) i); + Assert.assertEquals(((float[]) columns[2])[i], i * 1.0f, 0.001f); + Assert.assertEquals( + ((Binary[]) columns[3])[i], + new Binary(String.valueOf(i).getBytes(StandardCharsets.UTF_8))); + Assert.assertEquals(((long[]) columns[4])[i], (long) i); + Assert.assertEquals(((int[]) columns[5])[i], i); + Assert.assertEquals(((double[]) columns[6])[i], i * 0.1, 0.0001); + // DATE is stored as int[] in Statement, not LocalDate[] + LocalDate expectedDate = PipeTabletEventSorterTest.getDate(i); + int expectedDateInt = + org.apache.tsfile.utils.DateUtils.parseDateExpressionToInt(expectedDate); + Assert.assertEquals(((int[]) columns[7])[i], expectedDateInt); + Assert.assertEquals( + ((Binary[]) columns[8])[i], + new Binary(String.valueOf(i).getBytes(StandardCharsets.UTF_8))); + } + } + + public void doTableModelTest1() throws Exception { + final Tablet tablet = PipeTabletEventSorterTest.generateTablet("test", 10, false, true); + + // Convert Tablet to Statement + InsertTabletStatement statement = new InsertTabletStatement(tablet, false, "test_db"); + + // Sort using Statement + new PipeTableModelTabletEventSorter(statement).sortByTimestampIfNecessary(); + + long[] timestamps = statement.getTimes(); + final Object[] columns = statement.getColumns(); + for (int i = 1; i < statement.getRowCount(); i++) { + long time = timestamps[i]; + 
Assert.assertTrue(time > timestamps[i - 1]); + Assert.assertEquals( + ((Binary[]) columns[0])[i], + new Binary(String.valueOf(i / 100).getBytes(StandardCharsets.UTF_8))); + Assert.assertEquals(((long[]) columns[1])[i], (long) i); + Assert.assertEquals(((float[]) columns[2])[i], i * 1.0f, 0.001f); + Assert.assertEquals( + ((Binary[]) columns[3])[i], + new Binary(String.valueOf(i).getBytes(StandardCharsets.UTF_8))); + Assert.assertEquals(((long[]) columns[4])[i], (long) i); + Assert.assertEquals(((int[]) columns[5])[i], i); + Assert.assertEquals(((double[]) columns[6])[i], i * 0.1, 0.0001); + // DATE is stored as int[] in Statement, not LocalDate[] + LocalDate expectedDate = PipeTabletEventSorterTest.getDate(i); + int expectedDateInt = + org.apache.tsfile.utils.DateUtils.parseDateExpressionToInt(expectedDate); + Assert.assertEquals(((int[]) columns[7])[i], expectedDateInt); + Assert.assertEquals( + ((Binary[]) columns[8])[i], + new Binary(String.valueOf(i).getBytes(StandardCharsets.UTF_8))); + } + } +} diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverterTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverterTest.java new file mode 100644 index 0000000000000..410afc7613030 --- /dev/null +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverterTest.java @@ -0,0 +1,607 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under this License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.pipe.sink.util; + +import org.apache.iotdb.commons.exception.MetadataException; +import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; + +import org.apache.tsfile.enums.ColumnCategory; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.DateUtils; +import org.apache.tsfile.utils.PublicBAOS; +import org.apache.tsfile.utils.ReadWriteIOUtils; +import org.apache.tsfile.write.record.Tablet; +import org.apache.tsfile.write.schema.IMeasurementSchema; +import org.apache.tsfile.write.schema.MeasurementSchema; +import org.junit.Assert; +import org.junit.Test; + +import java.io.DataOutputStream; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.time.LocalDate; +import java.util.ArrayList; +import java.util.List; + +public class TabletStatementConverterTest { + + @Test + public void testConvertStatementToTabletTreeModel() throws MetadataException { + final int columnCount = 1000; + final int rowCount = 100; + final String deviceName = "root.sg.device"; + final boolean isAligned = true; + + // Generate Tablet and construct Statement from it + final Tablet originalTablet = generateTreeModelTablet(deviceName, columnCount, rowCount); + final InsertTabletStatement statement = + new InsertTabletStatement(originalTablet, isAligned, null); + + // Convert Statement to Tablet + final Tablet convertedTablet = statement.convertToTablet(); + + // Verify conversion + assertTabletsEqual(originalTablet, convertedTablet); + } + + 
@Test + public void testConvertStatementToTabletTableModel() throws MetadataException { + final int columnCount = 1000; + final int rowCount = 100; + final String tableName = "table1"; + final String databaseName = "test_db"; + final boolean isAligned = false; + + // Generate Tablet and construct Statement from it + final Tablet originalTablet = generateTableModelTablet(tableName, columnCount, rowCount); + final InsertTabletStatement statement = + new InsertTabletStatement(originalTablet, isAligned, databaseName); + + // Convert Statement to Tablet + final Tablet convertedTablet = statement.convertToTablet(); + + // Verify conversion + assertTabletsEqual(originalTablet, convertedTablet); + } + + @Test + public void testDeserializeStatementFromTabletFormat() throws IOException, MetadataException { + final int columnCount = 1000; + final int rowCount = 100; + final String deviceName = "root.sg.device"; + + // Generate test Tablet + final Tablet originalTablet = generateTreeModelTablet(deviceName, columnCount, rowCount); + + // Serialize Tablet to ByteBuffer + final PublicBAOS byteArrayOutputStream = new PublicBAOS(); + final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream); + // Then serialize the tablet + originalTablet.serialize(outputStream); + // Write isAligned at the end + final boolean isAligned = true; + ReadWriteIOUtils.write(isAligned, outputStream); + + final ByteBuffer buffer = + ByteBuffer.wrap(byteArrayOutputStream.getBuf(), 0, byteArrayOutputStream.size()); + + // Deserialize Statement from Tablet format + final InsertTabletStatement statement = + TabletStatementConverter.deserializeStatementFromTabletFormat(buffer); + + // Verify basic information + Assert.assertEquals(deviceName, statement.getDevicePath().getFullPath()); + Assert.assertEquals(rowCount, statement.getRowCount()); + Assert.assertEquals(columnCount, statement.getMeasurements().length); + Assert.assertEquals(isAligned, statement.isAligned()); + + // Verify data 
by converting Statement back to Tablet + final Tablet convertedTablet = statement.convertToTablet(); + assertTabletsEqual(originalTablet, convertedTablet); + } + + @Test + public void testRoundTripConversionTreeModel() throws MetadataException, IOException { + final int columnCount = 1000; + final int rowCount = 100; + final String deviceName = "root.sg.device"; + + // Generate original Tablet + final Tablet originalTablet = generateTreeModelTablet(deviceName, columnCount, rowCount); + + // Serialize Tablet to ByteBuffer + final PublicBAOS byteArrayOutputStream = new PublicBAOS(); + final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream); + originalTablet.serialize(outputStream); + final boolean isAligned = true; + ReadWriteIOUtils.write(isAligned, outputStream); + final ByteBuffer buffer = + ByteBuffer.wrap(byteArrayOutputStream.getBuf(), 0, byteArrayOutputStream.size()); + + // Deserialize to Statement + final InsertTabletStatement statement = + TabletStatementConverter.deserializeStatementFromTabletFormat(buffer); + // Convert Statement back to Tablet + final Tablet convertedTablet = statement.convertToTablet(); + + // Verify round trip + assertTabletsEqual(originalTablet, convertedTablet); + } + + @Test + public void testRoundTripConversionTableModel() throws MetadataException { + final int columnCount = 1000; + final int rowCount = 100; + final String tableName = "table1"; + final String databaseName = "test_db"; + final boolean isAligned = false; + + // Generate original Tablet for table model + final Tablet originalTablet = generateTableModelTablet(tableName, columnCount, rowCount); + + // Construct Statement from Tablet + final InsertTabletStatement originalStatement = + new InsertTabletStatement(originalTablet, isAligned, databaseName); + + // Convert Statement to Tablet + final Tablet convertedTablet = originalStatement.convertToTablet(); + + // Convert Tablet back to Statement + final InsertTabletStatement convertedStatement = + 
new InsertTabletStatement(convertedTablet, isAligned, databaseName); + + // Verify round trip: original Tablet should equal converted Tablet + assertTabletsEqual(originalTablet, convertedTablet); + } + + /** + * Generate a Tablet for tree model with all data types and specified number of columns and rows. + * + * @param deviceName device name + * @param columnCount number of columns + * @param rowCount number of rows + * @return Tablet with test data + */ + private Tablet generateTreeModelTablet( + final String deviceName, final int columnCount, final int rowCount) { + final List schemaList = new ArrayList<>(); + final TSDataType[] dataTypes = new TSDataType[columnCount]; + final String[] measurementNames = new String[columnCount]; + final Object[] columnData = new Object[columnCount]; + + // Create schemas and generate test data + for (int col = 0; col < columnCount; col++) { + final TSDataType dataType = ALL_DATA_TYPES[col % ALL_DATA_TYPES.length]; + final String measurementName = "col_" + col + "_" + dataType.name(); + schemaList.add(new MeasurementSchema(measurementName, dataType)); + dataTypes[col] = dataType; + measurementNames[col] = measurementName; + + // Generate test data for this column + switch (dataType) { + case BOOLEAN: + final boolean[] boolValues = new boolean[rowCount]; + for (int row = 0; row < rowCount; row++) { + boolValues[row] = (row + col) % 2 == 0; + } + columnData[col] = boolValues; + break; + case INT32: + final int[] intValues = new int[rowCount]; + for (int row = 0; row < rowCount; row++) { + intValues[row] = row * 100 + col; + } + columnData[col] = intValues; + break; + case DATE: + final LocalDate[] dateValues = new LocalDate[rowCount]; + for (int row = 0; row < rowCount; row++) { + // Generate valid dates starting from 2024-01-01 + dateValues[row] = LocalDate.of(2024, 1, 1).plusDays((row + col) % 365); + } + columnData[col] = dateValues; + break; + case INT64: + case TIMESTAMP: + final long[] longValues = new long[rowCount]; + for 
(int row = 0; row < rowCount; row++) { + longValues[row] = (long) row * 1000L + col; + } + columnData[col] = longValues; + break; + case FLOAT: + final float[] floatValues = new float[rowCount]; + for (int row = 0; row < rowCount; row++) { + floatValues[row] = row * 1.5f + col * 0.1f; + } + columnData[col] = floatValues; + break; + case DOUBLE: + final double[] doubleValues = new double[rowCount]; + for (int row = 0; row < rowCount; row++) { + doubleValues[row] = row * 2.5 + col * 0.01; + } + columnData[col] = doubleValues; + break; + case TEXT: + case STRING: + case BLOB: + final Binary[] binaryValues = new Binary[rowCount]; + for (int row = 0; row < rowCount; row++) { + binaryValues[row] = new Binary(("value_row_" + row + "_col_" + col).getBytes()); + } + columnData[col] = binaryValues; + break; + default: + throw new IllegalArgumentException("Unsupported data type: " + dataType); + } + } + + // Create and fill tablet + final Tablet tablet = new Tablet(deviceName, schemaList, rowCount); + final long[] times = new long[rowCount]; + for (int row = 0; row < rowCount; row++) { + times[row] = row * 1000L; + final int rowIndex = tablet.getRowSize(); + tablet.addTimestamp(rowIndex, times[row]); + for (int col = 0; col < columnCount; col++) { + final TSDataType dataType = dataTypes[col]; + final Object data = columnData[col]; + switch (dataType) { + case BOOLEAN: + tablet.addValue(measurementNames[col], rowIndex, ((boolean[]) data)[row]); + break; + case INT32: + tablet.addValue(measurementNames[col], rowIndex, ((int[]) data)[row]); + break; + case DATE: + tablet.addValue(measurementNames[col], rowIndex, ((LocalDate[]) data)[row]); + break; + case INT64: + case TIMESTAMP: + tablet.addValue(measurementNames[col], rowIndex, ((long[]) data)[row]); + break; + case FLOAT: + tablet.addValue(measurementNames[col], rowIndex, ((float[]) data)[row]); + break; + case DOUBLE: + tablet.addValue(measurementNames[col], rowIndex, ((double[]) data)[row]); + break; + case TEXT: + case 
STRING: + case BLOB: + tablet.addValue(measurementNames[col], rowIndex, ((Binary[]) data)[row]); + break; + } + } + } + + return tablet; + } + + /** + * Generate a Tablet for table model with all data types and specified number of columns and rows. + * + * @param tableName table name + * @param columnCount number of columns + * @param rowCount number of rows + * @return Tablet with test data + */ + private Tablet generateTableModelTablet( + final String tableName, final int columnCount, final int rowCount) { + final List schemaList = new ArrayList<>(); + final TSDataType[] dataTypes = new TSDataType[columnCount]; + final String[] measurementNames = new String[columnCount]; + final List columnTypes = new ArrayList<>(); + final Object[] columnData = new Object[columnCount]; + + // Create schemas and generate test data + for (int col = 0; col < columnCount; col++) { + final TSDataType dataType = ALL_DATA_TYPES[col % ALL_DATA_TYPES.length]; + final String measurementName = "col_" + col + "_" + dataType.name(); + schemaList.add(new MeasurementSchema(measurementName, dataType)); + dataTypes[col] = dataType; + measurementNames[col] = measurementName; + // For table model, all columns are FIELD (can be TAG/ATTRIBUTE/FIELD, but we use FIELD for + // simplicity) + columnTypes.add(ColumnCategory.FIELD); + + // Generate test data for this column + switch (dataType) { + case BOOLEAN: + final boolean[] boolValues = new boolean[rowCount]; + for (int row = 0; row < rowCount; row++) { + boolValues[row] = (row + col) % 2 == 0; + } + columnData[col] = boolValues; + break; + case INT32: + final int[] intValues = new int[rowCount]; + for (int row = 0; row < rowCount; row++) { + intValues[row] = row * 100 + col; + } + columnData[col] = intValues; + break; + case DATE: + final LocalDate[] dateValues = new LocalDate[rowCount]; + for (int row = 0; row < rowCount; row++) { + // Generate valid dates starting from 2024-01-01 + dateValues[row] = LocalDate.of(2024, 1, 1).plusDays((row + col) % 
365); + } + columnData[col] = dateValues; + break; + case INT64: + case TIMESTAMP: + final long[] longValues = new long[rowCount]; + for (int row = 0; row < rowCount; row++) { + longValues[row] = (long) row * 1000L + col; + } + columnData[col] = longValues; + break; + case FLOAT: + final float[] floatValues = new float[rowCount]; + for (int row = 0; row < rowCount; row++) { + floatValues[row] = row * 1.5f + col * 0.1f; + } + columnData[col] = floatValues; + break; + case DOUBLE: + final double[] doubleValues = new double[rowCount]; + for (int row = 0; row < rowCount; row++) { + doubleValues[row] = row * 2.5 + col * 0.01; + } + columnData[col] = doubleValues; + break; + case TEXT: + case STRING: + case BLOB: + final Binary[] binaryValues = new Binary[rowCount]; + for (int row = 0; row < rowCount; row++) { + binaryValues[row] = new Binary(("value_row_" + row + "_col_" + col).getBytes()); + } + columnData[col] = binaryValues; + break; + default: + throw new IllegalArgumentException("Unsupported data type: " + dataType); + } + } + + // Create tablet using table model constructor: Tablet(String, List, List, + // List, int) + final List measurementNameList = IMeasurementSchema.getMeasurementNameList(schemaList); + final List dataTypeList = IMeasurementSchema.getDataTypeList(schemaList); + final Tablet tablet = + new Tablet(tableName, measurementNameList, dataTypeList, columnTypes, rowCount); + tablet.initBitMaps(); + + // Fill tablet with data + final long[] times = new long[rowCount]; + for (int row = 0; row < rowCount; row++) { + times[row] = row * 1000L; + final int rowIndex = tablet.getRowSize(); + tablet.addTimestamp(rowIndex, times[row]); + for (int col = 0; col < columnCount; col++) { + final TSDataType dataType = dataTypes[col]; + final Object data = columnData[col]; + switch (dataType) { + case BOOLEAN: + tablet.addValue(measurementNames[col], rowIndex, ((boolean[]) data)[row]); + break; + case INT32: + tablet.addValue(measurementNames[col], rowIndex, ((int[]) 
data)[row]); + break; + case DATE: + tablet.addValue(measurementNames[col], rowIndex, ((LocalDate[]) data)[row]); + break; + case INT64: + case TIMESTAMP: + tablet.addValue(measurementNames[col], rowIndex, ((long[]) data)[row]); + break; + case FLOAT: + tablet.addValue(measurementNames[col], rowIndex, ((float[]) data)[row]); + break; + case DOUBLE: + tablet.addValue(measurementNames[col], rowIndex, ((double[]) data)[row]); + break; + case TEXT: + case STRING: + case BLOB: + tablet.addValue(measurementNames[col], rowIndex, ((Binary[]) data)[row]); + break; + } + } + tablet.setRowSize(rowIndex + 1); + } + + return tablet; + } + + /** + * Check if two Tablets are equal in all aspects. + * + * @param expected expected Tablet + * @param actual actual Tablet + */ + private void assertTabletsEqual(final Tablet expected, final Tablet actual) { + Assert.assertEquals(expected.getDeviceId(), actual.getDeviceId()); + Assert.assertEquals(expected.getRowSize(), actual.getRowSize()); + Assert.assertEquals(expected.getSchemas().size(), actual.getSchemas().size()); + + // Verify timestamps + final long[] expectedTimes = expected.getTimestamps(); + final long[] actualTimes = actual.getTimestamps(); + Assert.assertArrayEquals(expectedTimes, actualTimes); + + // Verify each column + final int columnCount = expected.getSchemas().size(); + final int rowCount = expected.getRowSize(); + final Object[] expectedValues = expected.getValues(); + final Object[] actualValues = actual.getValues(); + + for (int col = 0; col < columnCount; col++) { + final IMeasurementSchema schema = expected.getSchemas().get(col); + final TSDataType dataType = schema.getType(); + final Object expectedColumn = expectedValues[col]; + final Object actualColumn = actualValues[col]; + + Assert.assertNotNull(actualColumn); + + // Verify each row in this column + for (int row = 0; row < rowCount; row++) { + switch (dataType) { + case BOOLEAN: + final boolean expectedBool = ((boolean[]) expectedColumn)[row]; + final 
boolean actualBool = ((boolean[]) actualColumn)[row]; + Assert.assertEquals(expectedBool, actualBool); + break; + case INT32: + final int expectedInt = ((int[]) expectedColumn)[row]; + final int actualInt = ((int[]) actualColumn)[row]; + Assert.assertEquals(expectedInt, actualInt); + break; + case DATE: + final LocalDate expectedDate = ((LocalDate[]) expectedColumn)[row]; + final LocalDate actualDate = ((LocalDate[]) actualColumn)[row]; + Assert.assertEquals(expectedDate, actualDate); + break; + case INT64: + case TIMESTAMP: + final long expectedLong = ((long[]) expectedColumn)[row]; + final long actualLong = ((long[]) actualColumn)[row]; + Assert.assertEquals(expectedLong, actualLong); + break; + case FLOAT: + final float expectedFloat = ((float[]) expectedColumn)[row]; + final float actualFloat = ((float[]) actualColumn)[row]; + Assert.assertEquals(expectedFloat, actualFloat, 0.0001f); + break; + case DOUBLE: + final double expectedDouble = ((double[]) expectedColumn)[row]; + final double actualDouble = ((double[]) actualColumn)[row]; + Assert.assertEquals(expectedDouble, actualDouble, 0.0001); + break; + case TEXT: + case STRING: + case BLOB: + final Binary expectedBinary = ((Binary[]) expectedColumn)[row]; + final Binary actualBinary = ((Binary[]) actualColumn)[row]; + Assert.assertNotNull(actualBinary); + Assert.assertEquals(expectedBinary, actualBinary); + break; + } + } + } + } + + /** + * Check if a Tablet and an InsertTabletStatement contain the same data. + * + * @param tablet Tablet + * @param statement InsertTabletStatement + */ + private void assertTabletAndStatementEqual( + final Tablet tablet, final InsertTabletStatement statement) { + Assert.assertEquals( + tablet.getDeviceId(), + statement.getDevicePath() != null ? 
statement.getDevicePath().getFullPath() : null); + Assert.assertEquals(tablet.getRowSize(), statement.getRowCount()); + Assert.assertEquals(tablet.getSchemas().size(), statement.getMeasurements().length); + + // Verify timestamps + Assert.assertArrayEquals(tablet.getTimestamps(), statement.getTimes()); + + // Verify each column + final int columnCount = tablet.getSchemas().size(); + final int rowCount = tablet.getRowSize(); + final Object[] tabletValues = tablet.getValues(); + final Object[] statementColumns = statement.getColumns(); + + for (int col = 0; col < columnCount; col++) { + final TSDataType dataType = tablet.getSchemas().get(col).getType(); + final Object tabletColumn = tabletValues[col]; + final Object statementColumn = statementColumns[col]; + + Assert.assertNotNull(statementColumn); + + // Verify each row + for (int row = 0; row < rowCount; row++) { + switch (dataType) { + case BOOLEAN: + final boolean tabletBool = ((boolean[]) tabletColumn)[row]; + final boolean statementBool = ((boolean[]) statementColumn)[row]; + Assert.assertEquals(tabletBool, statementBool); + break; + case INT32: + final int tabletInt = ((int[]) tabletColumn)[row]; + final int statementInt = ((int[]) statementColumn)[row]; + Assert.assertEquals(tabletInt, statementInt); + break; + case DATE: + // DATE type: Tablet uses LocalDate[], Statement uses int[] (YYYYMMDD format) + final LocalDate tabletDate = ((LocalDate[]) tabletColumn)[row]; + final int statementDateInt = ((int[]) statementColumn)[row]; + // Convert LocalDate to int (YYYYMMDD format) for comparison + final int tabletDateInt = DateUtils.parseDateExpressionToInt(tabletDate); + Assert.assertEquals(tabletDateInt, statementDateInt); + break; + case INT64: + case TIMESTAMP: + final long tabletLong = ((long[]) tabletColumn)[row]; + final long statementLong = ((long[]) statementColumn)[row]; + Assert.assertEquals(tabletLong, statementLong); + break; + case FLOAT: + final float tabletFloat = ((float[]) tabletColumn)[row]; + 
final float statementFloat = ((float[]) statementColumn)[row]; + Assert.assertEquals(tabletFloat, statementFloat, 0.0001f); + break; + case DOUBLE: + final double tabletDouble = ((double[]) tabletColumn)[row]; + final double statementDouble = ((double[]) statementColumn)[row]; + Assert.assertEquals(tabletDouble, statementDouble, 0.0001); + break; + case TEXT: + case STRING: + case BLOB: + final Binary tabletBinary = ((Binary[]) tabletColumn)[row]; + final Binary statementBinary = ((Binary[]) statementColumn)[row]; + Assert.assertNotNull(statementBinary); + Assert.assertEquals(tabletBinary, statementBinary); + break; + } + } + } + } + + // Define all supported data types + private static final TSDataType[] ALL_DATA_TYPES = { + TSDataType.BOOLEAN, + TSDataType.INT32, + TSDataType.INT64, + TSDataType.FLOAT, + TSDataType.DOUBLE, + TSDataType.TEXT, + TSDataType.TIMESTAMP, + TSDataType.DATE, + TSDataType.BLOB, + TSDataType.STRING + }; +} diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/MergeTreeSortOperatorTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/MergeTreeSortOperatorTest.java index 4ca38a5c8d9f3..8fc7d01437cfc 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/MergeTreeSortOperatorTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/MergeTreeSortOperatorTest.java @@ -1910,5 +1910,10 @@ public QueryType getQueryType() { public boolean isUserQuery() { return false; } + + @Override + public String getClientHostname() { + return SessionConfig.DEFAULT_HOST; + } } } diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java index e60a14b727cad..79c031560973a 100644 --- 
a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java @@ -28,7 +28,6 @@ import org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.SessionInfo; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType; import org.apache.iotdb.db.queryengine.plan.relational.metadata.AlignedDeviceEntry; @@ -402,11 +401,6 @@ public TableFunction getTableFunction(String functionName) { return null; } - @Override - public IModelFetcher getModelFetcher() { - return null; - } - private static final DataPartition DATA_PARTITION = MockTSBSDataPartition.constructDataPartition(); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java index e56b48936b96d..7bbfe150ade45 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java @@ -19,7 +19,6 @@ package org.apache.iotdb.db.queryengine.plan.relational.analyzer; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan; import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction; @@ -378,7 +377,6 @@ public void testForecastFunction() { 96, DEFAULT_OUTPUT_START_TIME, 
DEFAULT_OUTPUT_INTERVAL, - new TEndPoint("127.0.0.1", 10810), Collections.singletonList(DOUBLE))); // Verify full LogicalPlan // Output - TableFunctionProcessor - TableScan @@ -439,7 +437,6 @@ public void testForecastFunctionWithNoLowerCase() { 96, DEFAULT_OUTPUT_START_TIME, DEFAULT_OUTPUT_INTERVAL, - new TEndPoint("127.0.0.1", 10810), Collections.singletonList(DOUBLE))); // Verify full LogicalPlan // Output - TableFunctionProcessor - TableScan diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java index aa9fcdfd1b514..4b1d18944b732 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java @@ -19,8 +19,6 @@ package org.apache.iotdb.db.queryengine.plan.relational.analyzer; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.commons.partition.DataPartition; import org.apache.iotdb.commons.partition.DataPartitionQueryParam; import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition; @@ -32,12 +30,10 @@ import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.SessionInfo; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.function.Exclude; import org.apache.iotdb.db.queryengine.plan.function.Repeat; import org.apache.iotdb.db.queryengine.plan.function.Split; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import 
org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType; import org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction; import org.apache.iotdb.db.queryengine.plan.relational.function.arithmetic.SubtractionResolver; @@ -560,21 +556,6 @@ public TableFunction getTableFunction(String functionName) { } } - @Override - public IModelFetcher getModelFetcher() { - String modelId = "timer_xl"; - IModelFetcher fetcher = Mockito.mock(IModelFetcher.class); - ModelInferenceDescriptor descriptor = Mockito.mock(ModelInferenceDescriptor.class); - Mockito.when(descriptor.getTargetAINode()).thenReturn(new TEndPoint("127.0.0.1", 10810)); - ModelInformation modelInformation = Mockito.mock(ModelInformation.class); - Mockito.when(modelInformation.available()).thenReturn(true); - Mockito.when(modelInformation.getInputShape()).thenReturn(new int[] {1440, 96}); - Mockito.when(descriptor.getModelInformation()).thenReturn(modelInformation); - Mockito.when(descriptor.getModelName()).thenReturn(modelId); - Mockito.when(fetcher.fetchModel(modelId)).thenReturn(descriptor); - return fetcher; - } - private static final DataPartition TABLE_DATA_PARTITION = MockTableModelDataPartition.constructDataPartition(DB1); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/PlanTester.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/PlanTester.java index a4cef177d8132..850755702016c 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/PlanTester.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/PlanTester.java @@ -83,6 +83,8 @@ public class PlanTester { public List getDataNodeLocations(String table) { switch (table) { case "queries": + case "current_queries": + case "queries_costs_histogram": return ImmutableList.of( genDataNodeLocation(1, "192.0.1.1"), 
genDataNodeLocation(2, "192.0.1.2")); default: diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/informationschema/CurrentQueriesTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/informationschema/CurrentQueriesTest.java new file mode 100644 index 0000000000000..1e3163321ec45 --- /dev/null +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/informationschema/CurrentQueriesTest.java @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.iotdb.db.queryengine.plan.relational.planner.informationschema; + +import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan; +import org.apache.iotdb.db.queryengine.plan.relational.planner.PlanTester; + +import com.google.common.collect.ImmutableList; +import org.junit.Test; + +import java.util.Optional; + +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.BIN; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.CLIENT_IP; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.COST_TIME; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.DATA_NODE_ID_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.END_TIME_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.NUMS; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.QUERY_ID_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.START_TIME_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATEMENT_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATE_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.USER_TABLE_MODEL; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanAssert.assertPlan; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.collect; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.exchange; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.infoSchemaTableScan; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.output; + +public class CurrentQueriesTest { + private final PlanTester planTester = new 
PlanTester(); + + @Test + public void testCurrentQueries() { + LogicalQueryPlan logicalQueryPlan = + planTester.createPlan("select * from information_schema.current_queries"); + assertPlan( + logicalQueryPlan, + output( + infoSchemaTableScan( + "information_schema.current_queries", + Optional.empty(), + ImmutableList.of( + QUERY_ID_TABLE_MODEL, + STATE_TABLE_MODEL, + START_TIME_TABLE_MODEL, + END_TIME_TABLE_MODEL, + DATA_NODE_ID_TABLE_MODEL, + COST_TIME, + STATEMENT_TABLE_MODEL, + USER_TABLE_MODEL, + CLIENT_IP)))); + + // - Exchange + // Output - Collect - Exchange + assertPlan(planTester.getFragmentPlan(0), output(collect(exchange(), exchange()))); + // TableScan + assertPlan( + planTester.getFragmentPlan(1), + infoSchemaTableScan("information_schema.current_queries", Optional.of(1))); + // TableScan + assertPlan( + planTester.getFragmentPlan(2), + infoSchemaTableScan("information_schema.current_queries", Optional.of(2))); + } + + @Test + public void testQueriesCostsHistogram() { + LogicalQueryPlan logicalQueryPlan = + planTester.createPlan("select * from information_schema.queries_costs_histogram"); + assertPlan( + logicalQueryPlan, + output( + infoSchemaTableScan( + "information_schema.queries_costs_histogram", + Optional.empty(), + ImmutableList.of(BIN, NUMS)))); + + // - Exchange + // Output - Collect - Exchange + assertPlan(planTester.getFragmentPlan(0), output(collect(exchange(), exchange()))); + // TableScan + assertPlan( + planTester.getFragmentPlan(1), + infoSchemaTableScan("information_schema.queries_costs_histogram", Optional.of(1))); + // TableScan + assertPlan( + planTester.getFragmentPlan(2), + infoSchemaTableScan("information_schema.queries_costs_histogram", Optional.of(2))); + } +} diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ShowQueriesTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/informationschema/ShowQueriesTest.java similarity 
index 94% rename from iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ShowQueriesTest.java rename to iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/informationschema/ShowQueriesTest.java index 63b7b783d7473..7161c68f4e9ef 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ShowQueriesTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/informationschema/ShowQueriesTest.java @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.iotdb.db.queryengine.plan.relational.analyzer; +package org.apache.iotdb.db.queryengine.plan.relational.planner.informationschema; import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan; import org.apache.iotdb.db.queryengine.plan.relational.planner.PlanTester; @@ -32,7 +32,9 @@ import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.QUERY_ID_TABLE_MODEL; import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.START_TIME_TABLE_MODEL; import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATEMENT; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATEMENT_TABLE_MODEL; import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.USER; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.USER_TABLE_MODEL; import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanAssert.assertPlan; import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.collect; import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.exchange; @@ -61,8 +63,8 @@ public void testNormal() { START_TIME_TABLE_MODEL, DATA_NODE_ID_TABLE_MODEL, ELAPSED_TIME_TABLE_MODEL, - 
STATEMENT.toLowerCase(Locale.ENGLISH), - USER.toLowerCase(Locale.ENGLISH))))); + STATEMENT_TABLE_MODEL, + USER_TABLE_MODEL)))); // - Exchange // Output - Collect - Exchange diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunkRegionScanTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunkRegionScanTest.java new file mode 100644 index 0000000000000..3967409db52de --- /dev/null +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunkRegionScanTest.java @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.storageengine.dataregion.memtable; + +import org.apache.iotdb.commons.path.AlignedFullPath; +import org.apache.iotdb.commons.path.NonAlignedFullPath; +import org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.storageengine.dataregion.read.filescan.IChunkHandle; + +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.StringArrayDeviceID; +import org.apache.tsfile.read.common.TimeRange; +import org.apache.tsfile.read.filter.operator.TimeFilterOperators; +import org.apache.tsfile.utils.BitMap; +import org.apache.tsfile.write.schema.IMeasurementSchema; +import org.apache.tsfile.write.schema.MeasurementSchema; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +@RunWith(Parameterized.class) +public class WritableMemChunkRegionScanTest { + + @Parameterized.Parameters(name = "{0}") + public static Collection data() { + return Arrays.asList(new Object[][] {{0}, {1000}, {10000}, {20000}}); + } + + private int defaultTvListThreshold; + private int tvListSortThreshold; + + public WritableMemChunkRegionScanTest(int tvListSortThreshold) { + this.tvListSortThreshold = tvListSortThreshold; + } + + @Before + public void setup() { + defaultTvListThreshold = IoTDBDescriptor.getInstance().getConfig().getTvListSortThreshold(); + IoTDBDescriptor.getInstance().getConfig().setTVListSortThreshold(tvListSortThreshold); + } + + @After + public void tearDown() { + IoTDBDescriptor.getInstance().getConfig().setTVListSortThreshold(defaultTvListThreshold); + } + + @Test + public void testAlignedWritableMemChunkRegionScan() { + PrimitiveMemTable memTable = new 
PrimitiveMemTable("root.test", "0"); + try { + List measurementSchemas = + Arrays.asList( + new MeasurementSchema("s1", TSDataType.INT32), + new MeasurementSchema("s2", TSDataType.INT32), + new MeasurementSchema("s3", TSDataType.INT32)); + AlignedWritableMemChunk writableMemChunk = null; + int size = 100000; + for (int i = 0; i < size; i++) { + if (i <= 10000) { + memTable.writeAlignedRow( + new StringArrayDeviceID("root.test.d1"), + measurementSchemas, + i, + new Object[] {1, null, 1}); + } else if (i <= 20000) { + memTable.writeAlignedRow( + new StringArrayDeviceID("root.test.d1"), + measurementSchemas, + i, + new Object[] {null, null, 2}); + } else if (i <= 30000) { + memTable.writeAlignedRow( + new StringArrayDeviceID("root.test.d1"), + measurementSchemas, + i, + new Object[] {3, null, null}); + } else { + memTable.writeAlignedRow( + new StringArrayDeviceID("root.test.d1"), + measurementSchemas, + i, + new Object[] {4, 4, 4}); + } + } + writableMemChunk = + (AlignedWritableMemChunk) + memTable.getWritableMemChunk(new StringArrayDeviceID("root.test.d1"), ""); + List bitMaps = new ArrayList<>(); + long[] timestamps = + writableMemChunk.getAnySatisfiedTimestamp( + Arrays.asList( + Collections.emptyList(), Collections.emptyList(), Collections.emptyList()), + bitMaps, + true, + null); + Assert.assertEquals(2, timestamps.length); + Assert.assertEquals(0, timestamps[0]); + Assert.assertFalse(bitMaps.get(0).isMarked(0)); + Assert.assertTrue(bitMaps.get(0).isMarked(1)); + Assert.assertFalse(bitMaps.get(0).isMarked(2)); + Assert.assertTrue(bitMaps.get(1).isMarked(0)); + Assert.assertFalse(bitMaps.get(1).isMarked(1)); + Assert.assertTrue(bitMaps.get(1).isMarked(2)); + Assert.assertEquals(30001, timestamps[1]); + + bitMaps = new ArrayList<>(); + timestamps = + writableMemChunk.getAnySatisfiedTimestamp( + Arrays.asList( + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(new TimeRange(0, 12000))), + bitMaps, + true, + new 
TimeFilterOperators.TimeGt(10000000)); + Assert.assertEquals(0, timestamps.length); + + bitMaps = new ArrayList<>(); + timestamps = + writableMemChunk.getAnySatisfiedTimestamp( + Arrays.asList( + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(new TimeRange(0, 12000))), + bitMaps, + true, + new TimeFilterOperators.TimeGt(11000)); + + Assert.assertEquals(3, timestamps.length); + Assert.assertEquals(12001, timestamps[0]); + Assert.assertTrue(bitMaps.get(0).isMarked(0)); + Assert.assertTrue(bitMaps.get(0).isMarked(1)); + Assert.assertFalse(bitMaps.get(0).isMarked(2)); + Assert.assertEquals(20001, timestamps[1]); + Assert.assertFalse(bitMaps.get(1).isMarked(0)); + Assert.assertTrue(bitMaps.get(1).isMarked(1)); + Assert.assertTrue(bitMaps.get(1).isMarked(2)); + Assert.assertEquals(30001, timestamps[2]); + Assert.assertTrue(bitMaps.get(2).isMarked(0)); + Assert.assertFalse(bitMaps.get(2).isMarked(1)); + Assert.assertTrue(bitMaps.get(2).isMarked(2)); + + writableMemChunk.writeAlignedPoints( + 1000001, new Object[] {1, null, null}, measurementSchemas); + writableMemChunk.writeAlignedPoints( + 1000002, new Object[] {null, 1, null}, measurementSchemas); + writableMemChunk.writeAlignedPoints(1000002, new Object[] {1, 1, null}, measurementSchemas); + writableMemChunk.writeAlignedPoints( + 1000003, new Object[] {1, null, null}, measurementSchemas); + writableMemChunk.writeAlignedPoints(1000004, new Object[] {1, null, 1}, measurementSchemas); + bitMaps = new ArrayList<>(); + timestamps = + writableMemChunk.getAnySatisfiedTimestamp( + Arrays.asList( + Collections.emptyList(), Collections.emptyList(), Collections.emptyList()), + bitMaps, + true, + new TimeFilterOperators.TimeGt(1000000)); + Assert.assertEquals(3, timestamps.length); + Assert.assertEquals(1000001, timestamps[0]); + Assert.assertFalse(bitMaps.get(0).isMarked(0)); + Assert.assertTrue(bitMaps.get(0).isMarked(1)); + Assert.assertTrue(bitMaps.get(0).isMarked(2)); + 
Assert.assertEquals(1000002, timestamps[1]); + Assert.assertTrue(bitMaps.get(1).isMarked(0)); + Assert.assertFalse(bitMaps.get(1).isMarked(1)); + Assert.assertTrue(bitMaps.get(1).isMarked(2)); + Assert.assertEquals(1000004, timestamps[2]); + Assert.assertTrue(bitMaps.get(2).isMarked(0)); + Assert.assertTrue(bitMaps.get(2).isMarked(1)); + Assert.assertFalse(bitMaps.get(2).isMarked(2)); + + Map> chunkHandleMap = new HashMap<>(); + memTable.queryForDeviceRegionScan( + new StringArrayDeviceID("root.test.d1"), + true, + Long.MIN_VALUE, + new HashMap<>(), + chunkHandleMap, + Collections.emptyList(), + new TimeFilterOperators.TimeGt(1000000)); + Assert.assertEquals(3, chunkHandleMap.size()); + Assert.assertArrayEquals( + new long[] {1000001, 1000001}, chunkHandleMap.get("s1").get(0).getPageStatisticsTime()); + Assert.assertArrayEquals( + new long[] {1000002, 1000002}, chunkHandleMap.get("s2").get(0).getPageStatisticsTime()); + Assert.assertArrayEquals( + new long[] {1000004, 1000004}, chunkHandleMap.get("s3").get(0).getPageStatisticsTime()); + + memTable.queryForSeriesRegionScan( + new AlignedFullPath( + new StringArrayDeviceID("root.test.d1"), + IMeasurementSchema.getMeasurementNameList(measurementSchemas), + measurementSchemas), + Long.MIN_VALUE, + new HashMap<>(), + chunkHandleMap, + Collections.emptyList(), + new TimeFilterOperators.TimeGt(1000000)); + Assert.assertEquals(3, chunkHandleMap.size()); + Assert.assertArrayEquals( + new long[] {1000001, 1000001}, chunkHandleMap.get("s1").get(0).getPageStatisticsTime()); + Assert.assertArrayEquals( + new long[] {1000002, 1000002}, chunkHandleMap.get("s2").get(0).getPageStatisticsTime()); + Assert.assertArrayEquals( + new long[] {1000004, 1000004}, chunkHandleMap.get("s3").get(0).getPageStatisticsTime()); + } finally { + memTable.release(); + } + } + + @Test + public void testTableWritableMemChunkRegionScan() { + List measurementSchemas = + Arrays.asList( + new MeasurementSchema("s1", TSDataType.INT32), + new 
MeasurementSchema("s2", TSDataType.INT32), + new MeasurementSchema("s3", TSDataType.INT32)); + AlignedWritableMemChunk writableMemChunk = + new AlignedWritableMemChunk(measurementSchemas, true); + int size = 100000; + for (int i = 0; i < size; i++) { + if (i <= 10000) { + writableMemChunk.writeAlignedPoints(i, new Object[] {1, null, 1}, measurementSchemas); + } else if (i <= 20000) { + writableMemChunk.writeAlignedPoints(i, new Object[] {null, null, 2}, measurementSchemas); + } else if (i <= 30000) { + writableMemChunk.writeAlignedPoints(i, new Object[] {3, null, null}, measurementSchemas); + } else { + writableMemChunk.writeAlignedPoints(i, new Object[] {4, 4, 4}, measurementSchemas); + } + } + List bitMaps = new ArrayList<>(); + long[] timestamps = + writableMemChunk.getAnySatisfiedTimestamp( + Arrays.asList( + Collections.emptyList(), Collections.emptyList(), Collections.emptyList()), + bitMaps, + false, + null); + Assert.assertEquals(1, timestamps.length); + Assert.assertEquals(0, timestamps[0]); + Assert.assertFalse(bitMaps.get(0).isMarked(0)); + Assert.assertTrue(bitMaps.get(0).isMarked(1)); + Assert.assertFalse(bitMaps.get(0).isMarked(2)); + + bitMaps = new ArrayList<>(); + timestamps = + writableMemChunk.getAnySatisfiedTimestamp( + Arrays.asList( + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(new TimeRange(0, 12000))), + bitMaps, + false, + new TimeFilterOperators.TimeGt(11000)); + + Assert.assertEquals(1, timestamps.length); + Assert.assertEquals(11001, timestamps[0]); + Assert.assertTrue(bitMaps.get(0).isMarked(0)); + Assert.assertTrue(bitMaps.get(0).isMarked(1)); + Assert.assertTrue(bitMaps.get(0).isMarked(2)); + } + + @Test + public void testNonAlignedWritableMemChunkRegionScan() { + PrimitiveMemTable memTable = new PrimitiveMemTable("root.test", "0"); + try { + MeasurementSchema measurementSchema = new MeasurementSchema("s1", TSDataType.INT32); + int size = 100000; + for (int i = 0; i < size; i++) { + memTable.write( 
+ new StringArrayDeviceID("root.test.d1"), + Collections.singletonList(measurementSchema), + i, + new Object[] {i}); + } + WritableMemChunk writableMemChunk = + (WritableMemChunk) + memTable.getWritableMemChunk(new StringArrayDeviceID("root.test.d1"), "s1"); + Optional timestamp = writableMemChunk.getAnySatisfiedTimestamp(null, null); + Assert.assertTrue(timestamp.isPresent()); + Assert.assertEquals(0, timestamp.get().longValue()); + + timestamp = + writableMemChunk.getAnySatisfiedTimestamp( + null, new TimeFilterOperators.TimeBetweenAnd(1000L, 2000L)); + Assert.assertTrue(timestamp.isPresent()); + Assert.assertEquals(1000, timestamp.get().longValue()); + + timestamp = + writableMemChunk.getAnySatisfiedTimestamp( + Collections.singletonList(new TimeRange(1, 1500)), + new TimeFilterOperators.TimeBetweenAnd(1000L, 2000L)); + Assert.assertTrue(timestamp.isPresent()); + Assert.assertEquals(1501, timestamp.get().longValue()); + + timestamp = + writableMemChunk.getAnySatisfiedTimestamp( + Collections.singletonList(new TimeRange(1, 1500)), + new TimeFilterOperators.TimeBetweenAnd(100000L, 200000L)); + Assert.assertFalse(timestamp.isPresent()); + + Map> chunkHandleMap = new HashMap<>(); + memTable.queryForDeviceRegionScan( + new StringArrayDeviceID("root.test.d1"), + false, + Long.MIN_VALUE, + new HashMap<>(), + chunkHandleMap, + Collections.emptyList(), + new TimeFilterOperators.TimeGt(1)); + Assert.assertEquals(1, chunkHandleMap.size()); + Assert.assertArrayEquals( + new long[] {2, 2}, chunkHandleMap.get("s1").get(0).getPageStatisticsTime()); + memTable.queryForSeriesRegionScan( + new NonAlignedFullPath(new StringArrayDeviceID("root.test.d1"), measurementSchema), + Long.MIN_VALUE, + new HashMap<>(), + chunkHandleMap, + Collections.emptyList(), + new TimeFilterOperators.TimeGt(1)); + Assert.assertEquals(1, chunkHandleMap.size()); + Assert.assertArrayEquals( + new long[] {2, 2}, chunkHandleMap.get("s1").get(0).getPageStatisticsTime()); + } finally { + memTable.release(); + 
} + } +} diff --git a/iotdb-core/node-commons/pom.xml b/iotdb-core/node-commons/pom.xml index 85ff69ee8ac7a..e7c508c195d55 100644 --- a/iotdb-core/node-commons/pom.xml +++ b/iotdb-core/node-commons/pom.xml @@ -65,6 +65,11 @@ iotdb-thrift-confignode 2.0.6-SNAPSHOT + + org.apache.iotdb + iotdb-thrift-ainode + 2.0.6-SNAPSHOT + org.apache.iotdb iotdb-thrift diff --git a/iotdb-core/node-commons/src/assembly/resources/conf/iotdb-system.properties.template b/iotdb-core/node-commons/src/assembly/resources/conf/iotdb-system.properties.template index 5a32d6a123140..0a9a47ee89073 100644 --- a/iotdb-core/node-commons/src/assembly/resources/conf/iotdb-system.properties.template +++ b/iotdb-core/node-commons/src/assembly/resources/conf/iotdb-system.properties.template @@ -1080,6 +1080,12 @@ max_tsblock_line_number=1000 # Datatype: long slow_query_threshold=10000 +# Time window threshold(min) for record of history queries. +# effectiveMode: hot_reload +# Datatype: int +# Privilege: SYSTEM +query_cost_stat_window=0 + # The max executing time of query. 
unit: ms # effectiveMode: restart # Datatype: int diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java index 106d67b6279d9..115f322348c06 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java @@ -20,6 +20,7 @@ package org.apache.iotdb.commons.client; import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.async.AsyncAINodeInternalServiceClient; import org.apache.iotdb.commons.client.async.AsyncConfigNodeInternalServiceClient; import org.apache.iotdb.commons.client.async.AsyncDataNodeExternalServiceClient; import org.apache.iotdb.commons.client.async.AsyncDataNodeInternalServiceClient; @@ -390,4 +391,31 @@ public GenericKeyedObjectPool create return clientPool; } } + + public static class AsyncAINodeHeartbeatServiceClientPoolFactory + implements IClientPoolFactory { + + @Override + public GenericKeyedObjectPool createClientPool( + ClientManager manager) { + GenericKeyedObjectPool clientPool = + new GenericKeyedObjectPool<>( + new AsyncAINodeInternalServiceClient.Factory( + manager, + new ThriftClientProperty.Builder() + .setConnectionTimeoutMs(conf.getCnConnectionTimeoutInMS()) + .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnabled()) + .setSelectorNumOfAsyncClientManager(conf.getSelectorNumOfClientManager()) + .setPrintLogWhenEncounterException(false) + .build(), + ThreadName.ASYNC_DATANODE_HEARTBEAT_CLIENT_POOL.getName()), + new ClientPoolProperty.Builder() + .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) + .build() + .getConfig()); + ClientManagerMetrics.getInstance() + .registerClientManager(this.getClass().getSimpleName(), clientPool); + return clientPool; + } + } } diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java similarity index 83% rename from iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java rename to iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java index 26130287697c4..8cbd55759633f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java @@ -17,7 +17,7 @@ * under the License. */ -package org.apache.iotdb.db.protocol.client.ainode; +package org.apache.iotdb.commons.client.async; import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService; import org.apache.iotdb.common.rpc.thrift.TEndPoint; @@ -35,20 +35,20 @@ import java.io.IOException; -public class AsyncAINodeServiceClient extends IAINodeRPCService.AsyncClient +public class AsyncAINodeInternalServiceClient extends IAINodeRPCService.AsyncClient implements ThriftClient { private static final CommonConfig commonConfig = CommonDescriptor.getInstance().getConfig(); - private final boolean printLogWhenEncounterException; private final TEndPoint endPoint; - private final ClientManager clientManager; + private final boolean printLogWhenEncounterException; + private final ClientManager clientManager; - public AsyncAINodeServiceClient( + public AsyncAINodeInternalServiceClient( ThriftClientProperty property, TEndPoint endPoint, TAsyncClientManager tClientManager, - ClientManager clientManager) + ClientManager clientManager) throws IOException { super( property.getProtocolFactory(), @@ -122,10 +122,10 @@ public boolean isReady() { } public static class Factory - extends AsyncThriftClientFactory { + extends 
AsyncThriftClientFactory { public Factory( - ClientManager clientManager, + ClientManager clientManager, ThriftClientProperty thriftClientProperty, String threadName) { super(clientManager, thriftClientProperty, threadName); @@ -133,14 +133,15 @@ public Factory( @Override public void destroyObject( - TEndPoint endPoint, PooledObject pooledObject) { + TEndPoint endPoint, PooledObject pooledObject) { pooledObject.getObject().close(); } @Override - public PooledObject makeObject(TEndPoint endPoint) throws Exception { + public PooledObject makeObject(TEndPoint endPoint) + throws Exception { return new DefaultPooledObject<>( - new AsyncAINodeServiceClient( + new AsyncAINodeInternalServiceClient( thriftClientProperty, endPoint, tManagers[clientCnt.incrementAndGet() % tManagers.length], @@ -149,7 +150,7 @@ public PooledObject makeObject(TEndPoint endPoint) thr @Override public boolean validateObject( - TEndPoint endPoint, PooledObject pooledObject) { + TEndPoint endPoint, PooledObject pooledObject) { return pooledObject.getObject().isReady(); } } diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java index 390e9f80e9b39..6f9f95ca8fe88 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java @@ -35,6 +35,7 @@ public enum ThreadName { FRAGMENT_INSTANCE_NOTIFICATION("Fragment-Instance-Notification"), FRAGMENT_INSTANCE_DISPATCH("Fragment-Instance-Dispatch"), DRIVER_TASK_SCHEDULER_NOTIFICATION("Driver-Task-Scheduler-Notification"), + EXPIRED_QUERIES_INFO_CLEAR("Expired-Queries-Info-Clear"), // -------------------------- MPP -------------------------- MPP_COORDINATOR_SCHEDULED_EXECUTOR("MPP-Coordinator-Scheduled-Executor"), MPP_DATA_EXCHANGE_TASK_EXECUTOR("MPP-Data-Exchange-Task-Executors"), diff 
--git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/file/SystemPropertiesHandler.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/file/SystemPropertiesHandler.java index dfbb2104cfb8d..ab28952dbcd07 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/file/SystemPropertiesHandler.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/file/SystemPropertiesHandler.java @@ -21,9 +21,9 @@ import org.apache.iotdb.commons.conf.ConfigurationFileUtils; import org.apache.iotdb.commons.conf.IoTDBConstant; +import org.apache.iotdb.commons.utils.FileUtils; import org.apache.ratis.util.AutoCloseableLock; -import org.apache.ratis.util.FileUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -184,15 +184,8 @@ private void replaceFormalFile() throws IOException { "Delete formal system properties file fail: %s", formalFile.getAbsoluteFile()); throw new IOException(msg); } - try { - FileUtils.move(tmpFile.toPath(), formalFile.toPath()); - } catch (IOException e) { - String msg = - String.format( - "Failed to replace formal system properties file, you may manually rename it: %s -> %s", - tmpFile.getAbsolutePath(), formalFile.getAbsolutePath()); - throw new IOException(msg, e); - } + + FileUtils.moveFileSafe(tmpFile, formalFile); } public void resetFilePath(String filePath) { diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java index 3fa107685438e..01968833db7fb 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java @@ -32,9 +32,12 @@ public class ModelInformation { + private static final int[] DEFAULT_MODEL_INPUT_SHAPE = new int[] {2880, 1}; + private static final int[] DEFAULT_MODEL_OUTPUT_SHAPE = new 
int[] {720, 1}; + ModelType modelType; - private final String modelName; + private final String modelId; private final int[] inputShape; @@ -48,9 +51,17 @@ public class ModelInformation { String attribute = ""; + public ModelInformation(String modelId) { + this.modelId = modelId; + this.inputShape = DEFAULT_MODEL_INPUT_SHAPE; + this.inputDataType = new TSDataType[] {TSDataType.DOUBLE}; + this.outputShape = DEFAULT_MODEL_OUTPUT_SHAPE; + this.outputDataType = new TSDataType[] {TSDataType.DOUBLE}; + } + public ModelInformation( ModelType modelType, - String modelName, + String modelId, int[] inputShape, int[] outputShape, TSDataType[] inputDataType, @@ -58,7 +69,7 @@ public ModelInformation( String attribute, ModelStatus status) { this.modelType = modelType; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = inputShape; this.outputShape = outputShape; this.inputDataType = inputDataType; @@ -68,14 +79,14 @@ public ModelInformation( } public ModelInformation( - String modelName, + String modelId, int[] inputShape, int[] outputShape, TSDataType[] inputDataType, TSDataType[] outputDataType, String attribute) { this.modelType = ModelType.USER_DEFINED; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = inputShape; this.outputShape = outputShape; this.inputDataType = inputDataType; @@ -83,9 +94,9 @@ public ModelInformation( this.attribute = attribute; } - public ModelInformation(String modelName, ModelStatus status) { + public ModelInformation(String modelId, ModelStatus status) { this.modelType = ModelType.BUILT_IN_FORECAST; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = new int[0]; this.outputShape = new int[0]; this.outputDataType = new TSDataType[0]; @@ -94,9 +105,9 @@ public ModelInformation(String modelName, ModelStatus status) { } // init built-in modelInformation - public ModelInformation(ModelType modelType, String modelName) { + public ModelInformation(ModelType modelType, String modelId) { 
this.modelType = modelType; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = new int[2]; this.outputShape = new int[2]; this.inputDataType = new TSDataType[0]; @@ -116,8 +127,8 @@ public void updateStatus(ModelStatus status) { this.status = status; } - public String getModelName() { - return modelName; + public String getModelId() { + return modelId; } public void setInputLength(int length) { @@ -197,7 +208,7 @@ public void setAttribute(String attribute) { public void serialize(DataOutputStream stream) throws IOException { ReadWriteIOUtils.write(modelType.ordinal(), stream); ReadWriteIOUtils.write(status.ordinal(), stream); - ReadWriteIOUtils.write(modelName, stream); + ReadWriteIOUtils.write(modelId, stream); if (status == ModelStatus.UNAVAILABLE) { return; } @@ -222,7 +233,7 @@ public void serialize(DataOutputStream stream) throws IOException { public void serialize(FileOutputStream stream) throws IOException { ReadWriteIOUtils.write(modelType.ordinal(), stream); ReadWriteIOUtils.write(status.ordinal(), stream); - ReadWriteIOUtils.write(modelName, stream); + ReadWriteIOUtils.write(modelId, stream); if (status == ModelStatus.UNAVAILABLE) { return; } @@ -247,7 +258,7 @@ public void serialize(FileOutputStream stream) throws IOException { public void serialize(ByteBuffer byteBuffer) { ReadWriteIOUtils.write(modelType.ordinal(), byteBuffer); ReadWriteIOUtils.write(status.ordinal(), byteBuffer); - ReadWriteIOUtils.write(modelName, byteBuffer); + ReadWriteIOUtils.write(modelId, byteBuffer); if (status == ModelStatus.UNAVAILABLE) { return; } @@ -353,7 +364,7 @@ public static ModelInformation deserialize(InputStream stream) throws IOExceptio public ByteBuffer serializeShowModelResult() throws IOException { PublicBAOS buffer = new PublicBAOS(); DataOutputStream stream = new DataOutputStream(buffer); - ReadWriteIOUtils.write(modelName, stream); + ReadWriteIOUtils.write(modelId, stream); ReadWriteIOUtils.write(modelType.toString(), stream); 
ReadWriteIOUtils.write(status.toString(), stream); ReadWriteIOUtils.write(Arrays.toString(inputShape), stream); @@ -370,7 +381,7 @@ public ByteBuffer serializeShowModelResult() throws IOException { public boolean equals(Object obj) { if (obj instanceof ModelInformation) { ModelInformation other = (ModelInformation) obj; - return modelName.equals(other.modelName) + return modelId.equals(other.modelId) && modelType.equals(other.modelType) && Arrays.equals(inputShape, other.inputShape) && Arrays.equals(outputShape, other.outputShape) diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java index 64aff12f284ef..6c6100086316e 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java @@ -42,7 +42,7 @@ public boolean containsModel(String modelId) { } public void addModel(ModelInformation modelInformation) { - modelInfoMap.put(modelInformation.getModelName(), modelInformation); + modelInfoMap.put(modelInformation.getModelId(), modelInformation); } public void removeModel(String modelId) { @@ -63,7 +63,7 @@ public ModelInformation getModelInformationById(String modelId) { public void clearFailedModel() { for (ModelInformation modelInformation : modelInfoMap.values()) { if (modelInformation.getStatus() == ModelStatus.UNAVAILABLE) { - modelInfoMap.remove(modelInformation.getModelName()); + modelInfoMap.remove(modelInformation.getModelId()); } } } diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/pipe/agent/plugin/service/PipePluginExecutableManager.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/pipe/agent/plugin/service/PipePluginExecutableManager.java index 9d28ea0130e0b..276af31bb1a06 100644 --- 
a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/pipe/agent/plugin/service/PipePluginExecutableManager.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/pipe/agent/plugin/service/PipePluginExecutableManager.java @@ -22,6 +22,7 @@ import org.apache.iotdb.commons.executable.ExecutableManager; import org.apache.iotdb.commons.file.SystemFileFactory; import org.apache.iotdb.commons.pipe.agent.plugin.meta.PipePluginMeta; +import org.apache.iotdb.commons.utils.FileUtils; import org.apache.iotdb.pipe.api.exception.PipeException; import org.apache.tsfile.external.commons.codec.digest.DigestUtils; @@ -123,6 +124,14 @@ public String getPluginInstallPathV1(String fileName) { return this.libRoot + File.separator + INSTALL_DIR + File.separator + fileName; } + public void linkExistedPlugin( + final String oldPluginName, final String newPluginName, final String fileName) + throws IOException { + FileUtils.createHardLink( + new File(getPluginsDirPath(oldPluginName), fileName), + new File(getPluginsDirPath(newPluginName), fileName)); + } + /** * @param byteBuffer file * @param pluginName diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java index 943bdeb9cba46..0459d4d2c86e1 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java @@ -206,7 +206,7 @@ private ColumnHeaderConstant() { public static final String STATEMENT = "Statement"; // column names for show idle connection - public static final String DATANODE_ID = "data_node_id"; + public static final String DATANODE_ID = "datanode_id"; public static final String USERID = "user_id"; public static final String SESSION_ID = "session_id"; public static final String 
USER_NAME = "user_name"; @@ -214,11 +214,19 @@ private ColumnHeaderConstant() { public static final String CLIENT_IP = "client_ip"; public static final String QUERY_ID_TABLE_MODEL = "query_id"; - public static final String QUERY_ID_START_TIME_TABLE_MODEL = "start_time"; public static final String DATA_NODE_ID_TABLE_MODEL = "datanode_id"; public static final String START_TIME_TABLE_MODEL = "start_time"; public static final String ELAPSED_TIME_TABLE_MODEL = "elapsed_time"; + // column names for current_queries and queries_costs_histogram + public static final String STATE_TABLE_MODEL = "state"; + public static final String END_TIME_TABLE_MODEL = "end_time"; + public static final String COST_TIME = "cost_time"; + public static final String STATEMENT_TABLE_MODEL = "statement"; + public static final String USER_TABLE_MODEL = "user"; + public static final String BIN = "bin"; + public static final String NUMS = "nums"; + public static final String TABLE_NAME_TABLE_MODEL = "table_name"; public static final String TABLE_TYPE_TABLE_MODEL = "table_type"; public static final String COLUMN_NAME_TABLE_MODEL = "column_name"; diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java index 2db41cc3c2dd7..b8e03423d61c6 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java @@ -43,7 +43,6 @@ public class InformationSchema { public static final String TOPICS = "topics"; public static final String SUBSCRIPTIONS = "subscriptions"; public static final String VIEWS = "views"; - public static final String MODELS = "models"; public static final String FUNCTIONS = "functions"; public static final String CONFIGURATIONS = "configurations"; public static final String KEYWORDS = 
"keywords"; @@ -51,6 +50,8 @@ public class InformationSchema { public static final String CONFIG_NODES = "config_nodes"; public static final String DATA_NODES = "data_nodes"; public static final String CONNECTIONS = "connections"; + public static final String CURRENT_QUERIES = "current_queries"; + public static final String QUERIES_COSTS_HISTOGRAM = "queries_costs_histogram"; static { final TsTable queriesTable = new TsTable(QUERIES); @@ -58,17 +59,15 @@ public class InformationSchema { new TagColumnSchema(ColumnHeaderConstant.QUERY_ID_TABLE_MODEL, TSDataType.STRING)); queriesTable.addColumnSchema( new AttributeColumnSchema( - ColumnHeaderConstant.QUERY_ID_START_TIME_TABLE_MODEL, TSDataType.TIMESTAMP)); + ColumnHeaderConstant.START_TIME_TABLE_MODEL, TSDataType.TIMESTAMP)); queriesTable.addColumnSchema( new AttributeColumnSchema(ColumnHeaderConstant.DATA_NODE_ID_TABLE_MODEL, TSDataType.INT32)); queriesTable.addColumnSchema( new AttributeColumnSchema(ColumnHeaderConstant.ELAPSED_TIME_TABLE_MODEL, TSDataType.FLOAT)); queriesTable.addColumnSchema( - new AttributeColumnSchema( - ColumnHeaderConstant.STATEMENT.toLowerCase(Locale.ENGLISH), TSDataType.STRING)); + new AttributeColumnSchema(ColumnHeaderConstant.STATEMENT_TABLE_MODEL, TSDataType.STRING)); queriesTable.addColumnSchema( - new AttributeColumnSchema( - ColumnHeaderConstant.USER.toLowerCase(Locale.ENGLISH), TSDataType.STRING)); + new AttributeColumnSchema(ColumnHeaderConstant.USER_TABLE_MODEL, TSDataType.STRING)); queriesTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME); schemaTables.put(QUERIES, queriesTable); @@ -256,23 +255,6 @@ public class InformationSchema { viewTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME); schemaTables.put(VIEWS, viewTable); - final TsTable modelTable = new TsTable(MODELS); - modelTable.addColumnSchema( - new TagColumnSchema(ColumnHeaderConstant.MODEL_ID_TABLE_MODEL, TSDataType.STRING)); - modelTable.addColumnSchema( - new 
AttributeColumnSchema(ColumnHeaderConstant.MODEL_TYPE_TABLE_MODEL, TSDataType.STRING)); - modelTable.addColumnSchema( - new AttributeColumnSchema( - ColumnHeaderConstant.STATE.toLowerCase(Locale.ENGLISH), TSDataType.STRING)); - modelTable.addColumnSchema( - new AttributeColumnSchema( - ColumnHeaderConstant.CONFIGS.toLowerCase(Locale.ENGLISH), TSDataType.STRING)); - modelTable.addColumnSchema( - new AttributeColumnSchema( - ColumnHeaderConstant.NOTES.toLowerCase(Locale.ENGLISH), TSDataType.STRING)); - modelTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME); - schemaTables.put(MODELS, modelTable); - final TsTable functionTable = new TsTable(FUNCTIONS); functionTable.addColumnSchema( new TagColumnSchema(ColumnHeaderConstant.FUNCTION_NAME_TABLE_MODEL, TSDataType.STRING)); @@ -379,6 +361,37 @@ public class InformationSchema { new AttributeColumnSchema(ColumnHeaderConstant.CLIENT_IP, TSDataType.STRING)); connectionsTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME); schemaTables.put(CONNECTIONS, connectionsTable); + + final TsTable currentQueriesTable = new TsTable(CURRENT_QUERIES); + currentQueriesTable.addColumnSchema( + new TagColumnSchema(ColumnHeaderConstant.QUERY_ID_TABLE_MODEL, TSDataType.STRING)); + currentQueriesTable.addColumnSchema( + new AttributeColumnSchema(ColumnHeaderConstant.STATE_TABLE_MODEL, TSDataType.STRING)); + currentQueriesTable.addColumnSchema( + new AttributeColumnSchema( + ColumnHeaderConstant.START_TIME_TABLE_MODEL, TSDataType.TIMESTAMP)); + currentQueriesTable.addColumnSchema( + new AttributeColumnSchema(ColumnHeaderConstant.END_TIME_TABLE_MODEL, TSDataType.TIMESTAMP)); + currentQueriesTable.addColumnSchema( + new AttributeColumnSchema(ColumnHeaderConstant.DATA_NODE_ID_TABLE_MODEL, TSDataType.INT32)); + currentQueriesTable.addColumnSchema( + new AttributeColumnSchema(ColumnHeaderConstant.COST_TIME, TSDataType.FLOAT)); + currentQueriesTable.addColumnSchema( + new AttributeColumnSchema(ColumnHeaderConstant.STATEMENT_TABLE_MODEL, 
TSDataType.STRING)); + currentQueriesTable.addColumnSchema( + new AttributeColumnSchema(ColumnHeaderConstant.USER_TABLE_MODEL, TSDataType.STRING)); + currentQueriesTable.addColumnSchema( + new AttributeColumnSchema(ColumnHeaderConstant.CLIENT_IP, TSDataType.STRING)); + currentQueriesTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME); + schemaTables.put(CURRENT_QUERIES, currentQueriesTable); + + final TsTable queriesCostsHistogramTable = new TsTable(QUERIES_COSTS_HISTOGRAM); + queriesCostsHistogramTable.addColumnSchema( + new TagColumnSchema(ColumnHeaderConstant.BIN, TSDataType.STRING)); + queriesCostsHistogramTable.addColumnSchema( + new AttributeColumnSchema(ColumnHeaderConstant.NUMS, TSDataType.INT32)); + queriesCostsHistogramTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME); + schemaTables.put(QUERIES_COSTS_HISTOGRAM, queriesCostsHistogramTable); } public static Map getSchemaTables() { diff --git a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4 b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4 index 4bd8a2b0bd7e9..1690ea4c855b0 100644 --- a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4 +++ b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4 @@ -174,6 +174,12 @@ statement | loadModelStatement | unloadModelStatement + // Prepared Statement + | prepareStatement + | executeStatement + | executeImmediateStatement + | deallocateStatement + // View, Trigger, CQ, Quota are not supported yet ; @@ -857,6 +863,23 @@ unloadModelStatement : UNLOAD MODEL existingModelId=identifier FROM DEVICES deviceIdList=string ; +// ------------------------------------------- Prepared Statement --------------------------------------------------------- +prepareStatement + : PREPARE statementName=identifier FROM sql=statement + ; + +executeStatement + : EXECUTE 
statementName=identifier (USING literalExpression (',' literalExpression)*)? + ; + +executeImmediateStatement + : EXECUTE IMMEDIATE sql=string (USING literalExpression (',' literalExpression)*)? + ; + +deallocateStatement + : DEALLOCATE PREPARE statementName=identifier + ; + // ------------------------------------------- Query Statement --------------------------------------------------------- queryStatement : query #statementDefault diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift index e0680cd29b794..1dc2f025f5c34 100644 --- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift +++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift @@ -60,13 +60,7 @@ struct TRegisterModelResp { struct TInferenceReq { 1: required string modelId 2: required binary dataset - 3: optional TWindowParams windowParams - 4: optional map inferenceAttributes -} - -struct TWindowParams { - 1: required i32 windowInterval - 2: required i32 windowStep + 3: optional map inferenceAttributes } struct TInferenceResp { diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift index d8f6318063ebb..f2b8ec6b8b071 100644 --- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift +++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift @@ -1096,34 +1096,6 @@ struct TUnsetSchemaTemplateReq { 4: optional bool isGeneratedByPipe } -struct TCreateModelReq { - 1: required string modelName - 2: required string uri -} - -struct TDropModelReq { - 1: required string modelId -} - -struct TGetModelInfoReq { - 1: required string modelId -} - -struct TGetModelInfoResp { - 1: required common.TSStatus status - 2: optional binary modelInfo - 3: optional common.TEndPoint aiNodeAddress -} - -struct TUpdateModelInfoReq { - 1: required string modelId - 2: required i32 modelStatus - 3: optional string 
attributes - 4: optional list aiNodeIds - 5: optional i32 inputLength - 6: optional i32 outputLength -} - struct TDataSchemaForTable{ 1: required string targetSql } @@ -1132,16 +1104,6 @@ struct TDataSchemaForTree{ 1: required list path } -struct TCreateTrainingReq { - 1: required string modelId - 2: required bool isTableModel - 3: required string existingModelId - 4: optional TDataSchemaForTable dataSchemaForTable - 5: optional TDataSchemaForTree dataSchemaForTree - 6: optional map parameters - 7: optional list> timeRanges -} - // ==================================================== // Quota // ==================================================== @@ -2006,31 +1968,6 @@ service IConfigNodeRPCService { */ TShowCQResp showCQ() - // ==================================================== - // AI Model - // ==================================================== - - /** - * Create a model - * - * @return SUCCESS_STATUS if the model was created successfully - */ - common.TSStatus createModel(TCreateModelReq req) - - /** - * Drop a model - * - * @return SUCCESS_STATUS if the model was removed successfully - */ - common.TSStatus dropModel(TDropModelReq req) - - /** - * Return the model info by model_id - */ - TGetModelInfoResp getModelInfo(TGetModelInfoReq req) - - common.TSStatus updateModelInfo(TUpdateModelInfoReq req) - // ====================================================== // Quota // ====================================================== diff --git a/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift b/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift index 7d334059fff01..48afb89d33661 100644 --- a/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift +++ b/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift @@ -164,6 +164,7 @@ struct TSCloseOperationReq { 1: required i64 sessionId 2: optional i64 queryId 3: optional i64 statementId + 4: optional string preparedStatementName } struct TSFetchResultsReq{ diff --git a/pom.xml b/pom.xml 
index d9df3de1bf7b1..66c7116a9fb5d 100644 --- a/pom.xml +++ b/pom.xml @@ -173,7 +173,7 @@ 0.14.1 1.9 1.5.6-3 - 2.2.0-251209-SNAPSHOT + 2.2.0-251210-SNAPSHOT