From bf00ae851b57d0a895eaca9217ce0dbb689bfa79 Mon Sep 17 00:00:00 2001 From: Caideyipi <87789683+Caideyipi@users.noreply.github.com> Date: Mon, 8 Dec 2025 14:34:52 +0800 Subject: [PATCH 01/13] Pipe: Fixed the NPE caused by new regions + history only logic (#16879) (cherry picked from commit c560247f8b541d61a7d1f1302e2f20756417ee62) --- .../procedure/impl/pipe/task/AlterPipeProcedureV2.java | 1 + 1 file changed, 1 insertion(+) diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/pipe/task/AlterPipeProcedureV2.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/pipe/task/AlterPipeProcedureV2.java index f6b84cb1f6e3..53f908bf4fc1 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/pipe/task/AlterPipeProcedureV2.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/pipe/task/AlterPipeProcedureV2.java @@ -187,6 +187,7 @@ public void executeFromCalculateInfoForTask(final ConfigNodeProcedureEnv env) { && !databaseName.startsWith(SchemaConstant.SYSTEM_DATABASE + ".") && !databaseName.equals(SchemaConstant.AUDIT_DATABASE) && !databaseName.startsWith(SchemaConstant.AUDIT_DATABASE + ".") + && !Objects.isNull(currentPipeTaskMeta) && !(PipeTaskAgent.isHistoryOnlyPipe( currentPipeStaticMeta.getSourceParameters()) && PipeTaskAgent.isHistoryOnlyPipe( From 98ecdc2b6d3618297bbadc3d1b5b6d7afda47c8a Mon Sep 17 00:00:00 2001 From: Caideyipi <87789683+Caideyipi@users.noreply.github.com> Date: Mon, 8 Dec 2025 14:41:33 +0800 Subject: [PATCH 02/13] Pipe: Fixed the bug that reused plugins may not get loader and visibility (#16877) (cherry picked from commit a0dbf9b18e0b8f1042d613c672eab91e3fc5ef2e) --- .../persistence/pipe/PipePluginInfo.java | 78 +++++++++++-------- .../service/PipePluginExecutableManager.java | 9 +++ 2 files changed, 55 insertions(+), 32 deletions(-) diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/pipe/PipePluginInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/pipe/PipePluginInfo.java index 280a891b598d..27cc3cc4cbf5 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/pipe/PipePluginInfo.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/pipe/PipePluginInfo.java @@ -161,19 +161,19 @@ public boolean isJarNeededToBeSavedWhenCreatingPipePlugin(final String jarName) } public void checkPipePluginExistence( - final Map extractorAttributes, + final Map sourceAttributes, final Map processorAttributes, - final Map connectorAttributes) { - final PipeParameters extractorParameters = new PipeParameters(extractorAttributes); - final String extractorPluginName = - extractorParameters.getStringOrDefault( + final Map sinkAttributes) { + final PipeParameters sourceParameters = new PipeParameters(sourceAttributes); + final String sourcePluginName = + sourceParameters.getStringOrDefault( Arrays.asList(PipeSourceConstant.EXTRACTOR_KEY, PipeSourceConstant.SOURCE_KEY), IOTDB_EXTRACTOR.getPipePluginName()); - if (!pipePluginMetaKeeper.containsPipePlugin(extractorPluginName)) { + if (!pipePluginMetaKeeper.containsPipePlugin(sourcePluginName)) { final String exceptionMessage = String.format( "Failed to create or alter pipe, the pipe extractor plugin %s does not exist", - extractorPluginName); + sourcePluginName); LOGGER.warn(exceptionMessage); throw new PipeException(exceptionMessage); } @@ -191,16 +191,16 @@ public void 
checkPipePluginExistence( throw new PipeException(exceptionMessage); } - final PipeParameters connectorParameters = new PipeParameters(connectorAttributes); - final String connectorPluginName = - connectorParameters.getStringOrDefault( + final PipeParameters sinkParameters = new PipeParameters(sinkAttributes); + final String sinkPluginName = + sinkParameters.getStringOrDefault( Arrays.asList(PipeSinkConstant.CONNECTOR_KEY, PipeSinkConstant.SINK_KEY), IOTDB_THRIFT_CONNECTOR.getPipePluginName()); - if (!pipePluginMetaKeeper.containsPipePlugin(connectorPluginName)) { + if (!pipePluginMetaKeeper.containsPipePlugin(sinkPluginName)) { final String exceptionMessage = String.format( "Failed to create or alter pipe, the pipe connector plugin %s does not exist", - connectorPluginName); + sinkPluginName); LOGGER.warn(exceptionMessage); throw new PipeException(exceptionMessage); } @@ -212,34 +212,29 @@ public TSStatus createPipePlugin(final CreatePipePluginPlan createPipePluginPlan try { final PipePluginMeta pipePluginMeta = createPipePluginPlan.getPipePluginMeta(); final String pluginName = pipePluginMeta.getPluginName(); + final String className = pipePluginMeta.getClassName(); + final String jarName = pipePluginMeta.getJarName(); // try to drop the old pipe plugin if exists to reduce the effect of the inconsistency dropPipePlugin(new DropPipePluginPlan(pluginName)); pipePluginMetaKeeper.addPipePluginMeta(pluginName, pipePluginMeta); - pipePluginMetaKeeper.addJarNameAndMd5( - pipePluginMeta.getJarName(), pipePluginMeta.getJarMD5()); + pipePluginMetaKeeper.addJarNameAndMd5(jarName, pipePluginMeta.getJarMD5()); if (createPipePluginPlan.getJarFile() != null) { pipePluginExecutableManager.savePluginToInstallDir( - ByteBuffer.wrap(createPipePluginPlan.getJarFile().getValues()), - pluginName, - pipePluginMeta.getJarName()); - final String pluginDirPath = pipePluginExecutableManager.getPluginsDirPath(pluginName); - final PipePluginClassLoader pipePluginClassLoader = - classLoaderManager.createPipePluginClassLoader(pluginDirPath); - try { - final Class pluginClass = - Class.forName(pipePluginMeta.getClassName(), true, pipePluginClassLoader); - pipePluginMetaKeeper.addPipePluginVisibility( - pluginName, VisibilityUtils.calculateFromPluginClass(pluginClass)); - classLoaderManager.addPluginAndClassLoader(pluginName, pipePluginClassLoader); - } catch (final Exception e) { - try { - pipePluginClassLoader.close(); - } catch (final Exception ignored) { - } - throw e; + ByteBuffer.wrap(createPipePluginPlan.getJarFile().getValues()), pluginName, jarName); + computeFromPluginClass(pluginName, className); + } else { + final String existed = pipePluginMetaKeeper.getPluginNameByJarName(jarName); + if (Objects.nonNull(existed)) { + pipePluginExecutableManager.linkExistedPlugin(existed, pluginName, jarName); + computeFromPluginClass(pluginName, className); + } else { + throw new PipeException( + String.format( + "The %s's creation has not passed in jarName, which does not exist in other pipePlugins. 
Please check", + pluginName)); } } @@ -255,6 +250,25 @@ public TSStatus createPipePlugin(final CreatePipePluginPlan createPipePluginPlan } } + private void computeFromPluginClass(final String pluginName, final String className) + throws Exception { + final String pluginDirPath = pipePluginExecutableManager.getPluginsDirPath(pluginName); + final PipePluginClassLoader pipePluginClassLoader = + classLoaderManager.createPipePluginClassLoader(pluginDirPath); + try { + final Class pluginClass = Class.forName(className, true, pipePluginClassLoader); + pipePluginMetaKeeper.addPipePluginVisibility( + pluginName, VisibilityUtils.calculateFromPluginClass(pluginClass)); + classLoaderManager.addPluginAndClassLoader(pluginName, pipePluginClassLoader); + } catch (final Exception e) { + try { + pipePluginClassLoader.close(); + } catch (final Exception ignored) { + } + throw e; + } + } + public TSStatus dropPipePlugin(final DropPipePluginPlan dropPipePluginPlan) { final String pluginName = dropPipePluginPlan.getPluginName(); diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/pipe/agent/plugin/service/PipePluginExecutableManager.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/pipe/agent/plugin/service/PipePluginExecutableManager.java index 9d28ea0130e0..276af31bb1a0 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/pipe/agent/plugin/service/PipePluginExecutableManager.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/pipe/agent/plugin/service/PipePluginExecutableManager.java @@ -22,6 +22,7 @@ import org.apache.iotdb.commons.executable.ExecutableManager; import org.apache.iotdb.commons.file.SystemFileFactory; import org.apache.iotdb.commons.pipe.agent.plugin.meta.PipePluginMeta; +import org.apache.iotdb.commons.utils.FileUtils; import org.apache.iotdb.pipe.api.exception.PipeException; import org.apache.tsfile.external.commons.codec.digest.DigestUtils; @@ -123,6 +124,14 @@ public String getPluginInstallPathV1(String fileName) { return this.libRoot + File.separator + INSTALL_DIR + File.separator + fileName; } + public void linkExistedPlugin( + final String oldPluginName, final String newPluginName, final String fileName) + throws IOException { + FileUtils.createHardLink( + new File(getPluginsDirPath(oldPluginName), fileName), + new File(getPluginsDirPath(newPluginName), fileName)); + } + /** * @param byteBuffer file * @param pluginName From 5b3e6f1d4f57b95986917af1b2ab382fce9f0090 Mon Sep 17 00:00:00 2001 From: Jackie Tien Date: Mon, 8 Dec 2025 17:34:47 +0800 Subject: [PATCH 03/13] Implement PreparedStmt on the Server side (#16764) (#16880) (cherry picked from commit 7436c88304e71a21f961f2de3c2fd85285c13050) --- .../itbase/runtime/ClusterTestStatement.java | 7 + .../it/db/it/IoTDBPreparedStatementIT.java | 385 ++++++++++++++++++ .../db/protocol/session/ClientSession.java | 23 ++ .../db/protocol/session/IClientSession.java | 31 ++ .../session/InternalClientSession.java | 24 ++ .../protocol/session/MqttClientSession.java | 24 ++ .../session/PreparedStatementInfo.java | 99 +++++ .../protocol/session/RestClientSession.java | 25 ++ .../db/protocol/session/SessionManager.java | 41 +- .../handler/BaseServerContextHandler.java | 1 + .../thrift/impl/ClientRPCServiceImpl.java | 1 + .../db/queryengine/plan/Coordinator.java | 89 +++- .../config/TableConfigTaskVisitor.java | 16 + .../config/session/DeallocateTask.java | 72 ++++ .../execution/config/session/PrepareTask.java | 85 ++++ .../PreparedStatementMemoryManager.java | 157 
+++++++ .../analyzer/StatementAnalyzer.java | 76 ++-- .../relational/planner/TableModelPlanner.java | 16 +- .../relational/sql/AstMemoryEstimator.java | 67 +++ .../relational/sql/ParameterExtractor.java | 121 ++++++ .../plan/relational/sql/ast/AstVisitor.java | 16 + .../plan/relational/sql/ast/Deallocate.java | 79 ++++ .../plan/relational/sql/ast/Execute.java | 96 +++++ .../relational/sql/ast/ExecuteImmediate.java | 99 +++++ .../plan/relational/sql/ast/Prepare.java | 87 ++++ .../relational/sql/parser/AstBuilder.java | 38 ++ .../relational/grammar/sql/RelationalSql.g4 | 23 ++ .../src/main/thrift/client.thrift | 1 + 28 files changed, 1747 insertions(+), 52 deletions(-) create mode 100644 integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPreparedStatementIT.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/PreparedStatementInfo.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/DeallocateTask.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PrepareTask.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PreparedStatementMemoryManager.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/AstMemoryEstimator.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ParameterExtractor.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Deallocate.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Execute.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/ExecuteImmediate.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Prepare.java diff --git a/integration-test/src/main/java/org/apache/iotdb/itbase/runtime/ClusterTestStatement.java b/integration-test/src/main/java/org/apache/iotdb/itbase/runtime/ClusterTestStatement.java index 3f96fdf1372f..0523a3848289 100644 --- a/integration-test/src/main/java/org/apache/iotdb/itbase/runtime/ClusterTestStatement.java +++ b/integration-test/src/main/java/org/apache/iotdb/itbase/runtime/ClusterTestStatement.java @@ -78,6 +78,13 @@ private void updateConfig(Statement statement, int timeout) throws SQLException statement.setQueryTimeout(timeout); } + /** + * Executes a SQL query on all read statements in parallel. + * + *
<p>
Note: For PreparedStatement EXECUTE queries, use the write connection directly instead, + * because PreparedStatements are session-scoped and this method may route queries to different + * nodes where the PreparedStatement doesn't exist. + */ @Override public ResultSet executeQuery(String sql) throws SQLException { return new ClusterTestResultSet(readStatements, readEndpoints, sql, queryTimeout); diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPreparedStatementIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPreparedStatementIT.java new file mode 100644 index 000000000000..f06d46201aff --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBPreparedStatementIT.java @@ -0,0 +1,385 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.relational.it.db.it; + +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.TableClusterIT; +import org.apache.iotdb.itbase.category.TableLocalStandaloneIT; +import org.apache.iotdb.itbase.runtime.ClusterTestConnection; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.db.it.utils.TestUtils.tableResultSetEqualTest; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +@RunWith(IoTDBTestRunner.class) +@Category({TableLocalStandaloneIT.class, TableClusterIT.class}) +public class IoTDBPreparedStatementIT { + private static final String DATABASE_NAME = "test"; + private static final String[] sqls = + new String[] { + "CREATE DATABASE " + DATABASE_NAME, + "USE " + DATABASE_NAME, + "CREATE TABLE test_table(id INT64 FIELD, name STRING FIELD, value DOUBLE FIELD)", + "INSERT INTO test_table VALUES (2025-01-01T00:00:00, 1, 'Alice', 100.5)", + "INSERT INTO test_table VALUES (2025-01-01T00:01:00, 2, 'Bob', 200.3)", + "INSERT INTO test_table VALUES (2025-01-01T00:02:00, 3, 'Charlie', 300.7)", + "INSERT INTO test_table VALUES (2025-01-01T00:03:00, 4, 'David', 400.2)", + "INSERT INTO test_table VALUES (2025-01-01T00:04:00, 5, 'Eve', 500.9)", + }; + + protected static void insertData() { + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + for (String sql : sqls) { + statement.execute(sql); + } + } 
catch (Exception e) { + fail("insertData failed: " + e.getMessage()); + } + } + + @BeforeClass + public static void setUp() { + EnvFactory.getEnv().initClusterEnvironment(); + insertData(); + } + + @AfterClass + public static void tearDown() { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + /** + * Execute a prepared statement query and verify the result. For PreparedStatement EXECUTE + * queries, use the write connection directly instead of tableResultSetEqualTest, because + * PreparedStatements are session-scoped and tableResultSetEqualTest may route queries to + * different nodes where the PreparedStatement doesn't exist. + */ + private static void executePreparedStatementAndVerify( + Connection connection, + Statement statement, + String executeSql, + String[] expectedHeader, + String[] expectedRetArray) + throws SQLException { + // Execute with parameters using write connection directly + // In cluster test, we need to use write connection to ensure same session + if (connection instanceof ClusterTestConnection) { + // Use write connection directly for PreparedStatement queries + try (Statement writeStatement = + ((ClusterTestConnection) connection) + .writeConnection + .getUnderlyingConnection() + .createStatement(); + ResultSet resultSet = writeStatement.executeQuery(executeSql)) { + ResultSetMetaData metaData = resultSet.getMetaData(); + + // Verify header + assertEquals(expectedHeader.length, metaData.getColumnCount()); + for (int i = 1; i <= metaData.getColumnCount(); i++) { + assertEquals(expectedHeader[i - 1], metaData.getColumnName(i)); + } + + // Verify data + int cnt = 0; + while (resultSet.next()) { + StringBuilder builder = new StringBuilder(); + for (int i = 1; i <= expectedHeader.length; i++) { + builder.append(resultSet.getString(i)).append(","); + } + assertEquals(expectedRetArray[cnt], builder.toString()); + cnt++; + } + assertEquals(expectedRetArray.length, cnt); + } + } else { + try (ResultSet resultSet = statement.executeQuery(executeSql)) { + ResultSetMetaData metaData = resultSet.getMetaData(); + + // Verify header + assertEquals(expectedHeader.length, metaData.getColumnCount()); + for (int i = 1; i <= metaData.getColumnCount(); i++) { + assertEquals(expectedHeader[i - 1], metaData.getColumnName(i)); + } + + // Verify data + int cnt = 0; + while (resultSet.next()) { + StringBuilder builder = new StringBuilder(); + for (int i = 1; i <= expectedHeader.length; i++) { + builder.append(resultSet.getString(i)).append(","); + } + assertEquals(expectedRetArray[cnt], builder.toString()); + cnt++; + } + assertEquals(expectedRetArray.length, cnt); + } + } + } + + @Test + public void testPrepareAndExecute() { + String[] expectedHeader = new String[] {"time", "id", "name", "value"}; + String[] retArray = new String[] {"2025-01-01T00:01:00.000Z,2,Bob,200.3,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare a statement + statement.execute("PREPARE stmt1 FROM SELECT * FROM test_table WHERE id = ?"); + // Execute with parameter using write connection directly + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt1 USING 2", expectedHeader, retArray); + // Deallocate + statement.execute("DEALLOCATE PREPARE stmt1"); + } catch (SQLException e) { + fail("testPrepareAndExecute failed: " + e.getMessage()); + } + } + + @Test + public void testPrepareAndExecuteMultipleTimes() { + String[] expectedHeader = new String[] 
{"time", "id", "name", "value"}; + String[] retArray1 = new String[] {"2025-01-01T00:00:00.000Z,1,Alice,100.5,"}; + String[] retArray2 = new String[] {"2025-01-01T00:02:00.000Z,3,Charlie,300.7,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare a statement + statement.execute("PREPARE stmt2 FROM SELECT * FROM test_table WHERE id = ?"); + // Execute multiple times with different parameters using write connection directly + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt2 USING 1", expectedHeader, retArray1); + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt2 USING 3", expectedHeader, retArray2); + // Deallocate + statement.execute("DEALLOCATE PREPARE stmt2"); + } catch (SQLException e) { + fail("testPrepareAndExecuteMultipleTimes failed: " + e.getMessage()); + } + } + + @Test + public void testPrepareWithMultipleParameters() { + String[] expectedHeader = new String[] {"time", "id", "name", "value"}; + String[] retArray = new String[] {"2025-01-01T00:01:00.000Z,2,Bob,200.3,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare a statement with multiple parameters + statement.execute("PREPARE stmt3 FROM SELECT * FROM test_table WHERE id = ? AND value > ?"); + // Execute with multiple parameters using write connection directly + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt3 USING 2, 150.0", expectedHeader, retArray); + // Deallocate + statement.execute("DEALLOCATE PREPARE stmt3"); + } catch (SQLException e) { + fail("testPrepareWithMultipleParameters failed: " + e.getMessage()); + } + } + + @Test + public void testExecuteImmediate() { + String[] expectedHeader = new String[] {"time", "id", "name", "value"}; + String[] retArray = new String[] {"2025-01-01T00:03:00.000Z,4,David,400.2,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Execute immediate with SQL string and parameters + tableResultSetEqualTest( + "EXECUTE IMMEDIATE 'SELECT * FROM test_table WHERE id = ?' 
USING 4", + expectedHeader, + retArray, + DATABASE_NAME); + } catch (SQLException e) { + fail("testExecuteImmediate failed: " + e.getMessage()); + } + } + + @Test + public void testExecuteImmediateWithoutParameters() { + String[] expectedHeader = new String[] {"_col0"}; + String[] retArray = new String[] {"5,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Execute immediate without parameters + tableResultSetEqualTest( + "EXECUTE IMMEDIATE 'SELECT COUNT(*) FROM test_table'", + expectedHeader, + retArray, + DATABASE_NAME); + } catch (SQLException e) { + fail("testExecuteImmediateWithoutParameters failed: " + e.getMessage()); + } + } + + @Test + public void testExecuteImmediateWithMultipleParameters() { + String[] expectedHeader = new String[] {"time", "id", "name", "value"}; + String[] retArray = new String[] {"2025-01-01T00:04:00.000Z,5,Eve,500.9,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Execute immediate with multiple parameters + tableResultSetEqualTest( + "EXECUTE IMMEDIATE 'SELECT * FROM test_table WHERE id = ? AND value > ?' USING 5, 450.0", + expectedHeader, + retArray, + DATABASE_NAME); + } catch (SQLException e) { + fail("testExecuteImmediateWithMultipleParameters failed: " + e.getMessage()); + } + } + + @Test + public void testDeallocateNonExistentStatement() { + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Try to deallocate a non-existent statement + SQLException exception = + assertThrows( + SQLException.class, () -> statement.execute("DEALLOCATE PREPARE non_existent_stmt")); + assertTrue( + exception.getMessage().contains("does not exist") + || exception.getMessage().contains("Prepared statement")); + } catch (SQLException e) { + fail("testDeallocateNonExistentStatement failed: " + e.getMessage()); + } + } + + @Test + public void testExecuteNonExistentStatement() { + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Try to execute a non-existent statement + SQLException exception = + assertThrows( + SQLException.class, () -> statement.execute("EXECUTE non_existent_stmt USING 1")); + assertTrue( + exception.getMessage().contains("does not exist") + || exception.getMessage().contains("Prepared statement")); + } catch (SQLException e) { + fail("testExecuteNonExistentStatement failed: " + e.getMessage()); + } + } + + @Test + public void testMultiplePreparedStatements() { + String[] expectedHeader1 = new String[] {"time", "id", "name", "value"}; + String[] retArray1 = new String[] {"2025-01-01T00:00:00.000Z,1,Alice,100.5,"}; + String[] expectedHeader2 = new String[] {"_col0"}; + String[] retArray2 = new String[] {"4,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare multiple statements + statement.execute("PREPARE stmt4 FROM SELECT * FROM test_table WHERE id = ?"); + statement.execute("PREPARE stmt5 FROM SELECT COUNT(*) FROM test_table WHERE value > ?"); + // Execute both statements using write connection directly + 
executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt4 USING 1", expectedHeader1, retArray1); + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt5 USING 200.0", expectedHeader2, retArray2); + // Deallocate both + statement.execute("DEALLOCATE PREPARE stmt4"); + statement.execute("DEALLOCATE PREPARE stmt5"); + } catch (SQLException e) { + fail("testMultiplePreparedStatements failed: " + e.getMessage()); + } + } + + @Test + public void testPrepareDuplicateName() { + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare a statement + statement.execute("PREPARE stmt6 FROM SELECT * FROM test_table WHERE id = ?"); + // Try to prepare another statement with the same name + SQLException exception = + assertThrows( + SQLException.class, + () -> statement.execute("PREPARE stmt6 FROM SELECT * FROM test_table WHERE id = ?")); + assertTrue( + exception.getMessage().contains("already exists") + || exception.getMessage().contains("Prepared statement")); + // Cleanup + statement.execute("DEALLOCATE PREPARE stmt6"); + } catch (SQLException e) { + fail("testPrepareDuplicateName failed: " + e.getMessage()); + } + } + + @Test + public void testPrepareAndExecuteWithAggregation() { + String[] expectedHeader = new String[] {"_col0"}; + String[] retArray = new String[] {"300.40000000000003,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare a statement with aggregation + statement.execute( + "PREPARE stmt7 FROM SELECT AVG(value) FROM test_table WHERE id >= ? AND id <= ?"); + // Execute with parameters using write connection directly + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt7 USING 2, 4", expectedHeader, retArray); + // Deallocate + statement.execute("DEALLOCATE PREPARE stmt7"); + } catch (SQLException e) { + fail("testPrepareAndExecuteWithAggregation failed: " + e.getMessage()); + } + } + + @Test + public void testPrepareAndExecuteWithStringParameter() { + String[] expectedHeader = new String[] {"time", "id", "name", "value"}; + String[] retArray = new String[] {"2025-01-01T00:02:00.000Z,3,Charlie,300.7,"}; + try (Connection connection = EnvFactory.getEnv().getTableConnection(); + Statement statement = connection.createStatement()) { + statement.execute("USE " + DATABASE_NAME); + // Prepare a statement with string parameter + statement.execute("PREPARE stmt8 FROM SELECT * FROM test_table WHERE name = ?"); + // Execute with string parameter using write connection directly + executePreparedStatementAndVerify( + connection, statement, "EXECUTE stmt8 USING 'Charlie'", expectedHeader, retArray); + // Deallocate + statement.execute("DEALLOCATE PREPARE stmt8"); + } catch (SQLException e) { + fail("testPrepareAndExecuteWithStringParameter failed: " + e.getMessage()); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/ClientSession.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/ClientSession.java index 6aa862b9242f..ea90fbafeccd 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/ClientSession.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/ClientSession.java @@ -33,6 +33,9 @@ public class ClientSession extends IClientSession { private final 
Map<Long, Set<Long>> statementIdToQueryId = new ConcurrentHashMap<>(); + // Map from statement name to PreparedStatementInfo + private final Map<String, PreparedStatementInfo> preparedStatements = new ConcurrentHashMap<>(); + public ClientSession(Socket clientSocket) { this.clientSocket = clientSocket; } @@ -103,4 +106,24 @@ public static void removeQueryId( } } } + + @Override + public void addPreparedStatement(String statementName, PreparedStatementInfo info) { + preparedStatements.put(statementName, info); + } + + @Override + public PreparedStatementInfo removePreparedStatement(String statementName) { + return preparedStatements.remove(statementName); + } + + @Override + public PreparedStatementInfo getPreparedStatement(String statementName) { + return preparedStatements.get(statementName); + } + + @Override + public Set<String> getPreparedStatementNames() { + return preparedStatements.keySet(); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/IClientSession.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/IClientSession.java index 351806de099c..97585673e824 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/IClientSession.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/IClientSession.java @@ -188,6 +188,37 @@ public void setDatabaseName(@Nullable String databaseName) { this.databaseName = databaseName; } + /** + * Add a prepared statement to this session. + * + * @param statementName the name of the prepared statement + * @param info the prepared statement information + */ + public abstract void addPreparedStatement(String statementName, PreparedStatementInfo info); + + /** + * Remove a prepared statement from this session. + * + * @param statementName the name of the prepared statement + * @return the removed prepared statement info, or null if not found + */ + public abstract PreparedStatementInfo removePreparedStatement(String statementName); + + /** + * Get a prepared statement from this session. + * + * @param statementName the name of the prepared statement + * @return the prepared statement info, or null if not found + */ + public abstract PreparedStatementInfo getPreparedStatement(String statementName); + + /** + * Get all prepared statement names in this session. 
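+ * <p>The returned set may be a live view of the session's internal map and should be treated as read-only.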
+ *
 + * @return set of prepared statement names + */ + public abstract Set<String> getPreparedStatementNames(); + public long getLastActiveTime() { return lastActiveTime; } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/InternalClientSession.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/InternalClientSession.java index 3c72d083a8c7..ed87d0b0ee32 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/InternalClientSession.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/InternalClientSession.java @@ -88,4 +88,28 @@ public void addQueryId(Long statementId, long queryId) { public void removeQueryId(Long statementId, Long queryId) { ClientSession.removeQueryId(statementIdToQueryId, statementId, queryId); } + + @Override + public void addPreparedStatement(String statementName, PreparedStatementInfo info) { + throw new UnsupportedOperationException( + "InternalClientSession should never call PREPARE statement methods."); + } + + @Override + public PreparedStatementInfo removePreparedStatement(String statementName) { + throw new UnsupportedOperationException( + "InternalClientSession should never call PREPARE statement methods."); + } + + @Override + public PreparedStatementInfo getPreparedStatement(String statementName) { + throw new UnsupportedOperationException( + "InternalClientSession should never call PREPARE statement methods."); + } + + @Override + public Set<String> getPreparedStatementNames() { + throw new UnsupportedOperationException( + "InternalClientSession should never call PREPARE statement methods."); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/MqttClientSession.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/MqttClientSession.java index ae9e2cd03616..c0b68e885a14 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/MqttClientSession.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/MqttClientSession.java @@ -76,4 +76,28 @@ public void addQueryId(Long statementId, long queryId) { public void removeQueryId(Long statementId, Long queryId) { throw new UnsupportedOperationException(); } + + @Override + public void addPreparedStatement(String statementName, PreparedStatementInfo info) { + throw new UnsupportedOperationException( + "MQTT client session does not support PREPARE statement."); + } + + @Override + public PreparedStatementInfo removePreparedStatement(String statementName) { + throw new UnsupportedOperationException( + "MQTT client session does not support PREPARE statement."); + } + + @Override + public PreparedStatementInfo getPreparedStatement(String statementName) { + throw new UnsupportedOperationException( + "MQTT client session does not support PREPARE statement."); + } + + @Override + public Set<String> getPreparedStatementNames() { + throw new UnsupportedOperationException( + "MQTT client session does not support PREPARE statement."); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/PreparedStatementInfo.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/PreparedStatementInfo.java new file mode 100644 index 000000000000..0bfc750c4ba0 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/PreparedStatementInfo.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor 
license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.protocol.session; + +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +/** + * Information about a prepared statement stored in a session. The AST is cached here to avoid + * reparsing on EXECUTE. + */ +public class PreparedStatementInfo { + + private final String statementName; + private final Statement sql; // Cached AST (contains Parameter nodes) + private final long createTime; + private final long memorySizeInBytes; // Memory size allocated for this PreparedStatement + + public PreparedStatementInfo(String statementName, Statement sql, long memorySizeInBytes) { + this.statementName = requireNonNull(statementName, "statementName is null"); + this.sql = requireNonNull(sql, "sql is null"); + this.createTime = System.currentTimeMillis(); + this.memorySizeInBytes = memorySizeInBytes; + } + + public PreparedStatementInfo( + String statementName, Statement sql, long createTime, long memorySizeInBytes) { + this.statementName = requireNonNull(statementName, "statementName is null"); + this.sql = requireNonNull(sql, "sql is null"); + this.createTime = createTime; + this.memorySizeInBytes = memorySizeInBytes; + } + + public String getStatementName() { + return statementName; + } + + public Statement getSql() { + return sql; + } + + public long getCreateTime() { + return createTime; + } + + public long getMemorySizeInBytes() { + return memorySizeInBytes; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + PreparedStatementInfo that = (PreparedStatementInfo) o; + return Objects.equals(statementName, that.statementName) && Objects.equals(sql, that.sql); + } + + @Override + public int hashCode() { + return Objects.hash(statementName, sql); + } + + @Override + public String toString() { + return "PreparedStatementInfo{" + + "statementName='" + + statementName + + '\'' + + ", sql=" + + sql + + ", createTime=" + + createTime + + '}'; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/RestClientSession.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/RestClientSession.java index fa830ace3fbc..d122c3c7dc5f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/RestClientSession.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/RestClientSession.java @@ -22,12 +22,17 @@ import org.apache.iotdb.service.rpc.thrift.TSConnectionType; import java.util.Collections; +import java.util.Map; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; public class RestClientSession extends IClientSession { private final 
String clientID; + // Map from statement name to PreparedStatementInfo + private final Map preparedStatements = new ConcurrentHashMap<>(); + public RestClientSession(String clientID) { this.clientID = clientID; } @@ -76,4 +81,24 @@ public void addQueryId(Long statementId, long queryId) { public void removeQueryId(Long statementId, Long queryId) { throw new UnsupportedOperationException(); } + + @Override + public void addPreparedStatement(String statementName, PreparedStatementInfo info) { + preparedStatements.put(statementName, info); + } + + @Override + public PreparedStatementInfo removePreparedStatement(String statementName) { + return preparedStatements.remove(statementName); + } + + @Override + public PreparedStatementInfo getPreparedStatement(String statementName) { + return preparedStatements.get(statementName); + } + + @Override + public Set getPreparedStatementNames() { + return preparedStatements.keySet(); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/SessionManager.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/SessionManager.java index a4a28efa5ec3..e5e95d4f82ca 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/SessionManager.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/session/SessionManager.java @@ -40,6 +40,7 @@ import org.apache.iotdb.db.protocol.thrift.OperationType; import org.apache.iotdb.db.queryengine.common.ConnectionInfo; import org.apache.iotdb.db.queryengine.common.SessionInfo; +import org.apache.iotdb.db.queryengine.plan.execution.config.session.PreparedStatementMemoryManager; import org.apache.iotdb.db.storageengine.dataregion.read.control.QueryResourceManager; import org.apache.iotdb.db.utils.DataNodeAuthUtils; import org.apache.iotdb.metrics.utils.MetricLevel; @@ -276,6 +277,7 @@ public boolean closeSession(IClientSession session, LongConsumer releaseByQueryI } private void releaseSessionResource(IClientSession session, LongConsumer releaseQueryResource) { + // Release query resources Iterable statementIds = session.getStatementIds(); if (statementIds != null) { for (Long statementId : statementIds) { @@ -287,6 +289,17 @@ private void releaseSessionResource(IClientSession session, LongConsumer release } } } + + // Release PreparedStatement memory resources + try { + PreparedStatementMemoryManager.getInstance().releaseAllForSession(session); + } catch (Exception e) { + LOGGER.warn( + "Failed to release PreparedStatement resources for session {}: {}", + session, + e.getMessage(), + e); + } } public TSStatus closeOperation( @@ -295,6 +308,7 @@ public TSStatus closeOperation( long statementId, boolean haveStatementId, boolean haveSetQueryId, + String preparedStatementName, LongConsumer releaseByQueryId) { if (!checkLogin(session)) { return RpcUtils.getStatus( @@ -307,7 +321,7 @@ public TSStatus closeOperation( if (haveSetQueryId) { this.closeDataset(session, statementId, queryId, releaseByQueryId); } else { - this.closeStatement(session, statementId, releaseByQueryId); + this.closeStatement(session, statementId, preparedStatementName, releaseByQueryId); } return RpcUtils.getStatus(TSStatusCode.SUCCESS_STATUS); } else { @@ -342,14 +356,35 @@ public long requestStatementId(IClientSession session) { } public void closeStatement( - IClientSession session, long statementId, LongConsumer releaseByQueryId) { + IClientSession session, + long statementId, + String preparedStatementName, + LongConsumer releaseByQueryId) { Set queryIdSet = 
session.removeStatementId(statementId); if (queryIdSet != null) { for (Long queryId : queryIdSet) { releaseByQueryId.accept(queryId); } } - session.removeStatementId(statementId); + + // If preparedStatementName is provided, release the prepared statement resources + if (preparedStatementName != null && !preparedStatementName.isEmpty()) { + try { + PreparedStatementInfo removedInfo = session.removePreparedStatement(preparedStatementName); + if (removedInfo != null) { + // Release the memory allocated for this PreparedStatement + PreparedStatementMemoryManager.getInstance().release(removedInfo.getMemorySizeInBytes()); + } + } catch (Exception e) { + LOGGER.warn( + "Failed to release PreparedStatement '{}' resources when closing statement {} for session {}: {}", + preparedStatementName, + statementId, + session, + e.getMessage(), + e); + } + } } public long requestQueryId(IClientSession session, Long statementId) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/handler/BaseServerContextHandler.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/handler/BaseServerContextHandler.java index e633caa45f6a..9b7efd827806 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/handler/BaseServerContextHandler.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/handler/BaseServerContextHandler.java @@ -71,6 +71,7 @@ public ServerContext createContext(TProtocol in, TProtocol out) { public void deleteContext(ServerContext context, TProtocol in, TProtocol out) { getSessionManager().removeCurrSession(); + if (context != null && factory != null) { ((JudgableServerContext) context).whenDisconnect(); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java index ea7646efb2b2..ad16535bf573 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java @@ -1445,6 +1445,7 @@ public TSStatus closeOperation(TSCloseOperationReq req) { req.statementId, req.isSetStatementId(), req.isSetQueryId(), + req.isSetPreparedStatementName() ? 
req.getPreparedStatementName() : null, COORDINATOR::cleanupQueryExecution); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java index f10600cbda94..7708f6c18cda 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java @@ -29,10 +29,14 @@ import org.apache.iotdb.commons.conf.CommonConfig; import org.apache.iotdb.commons.conf.CommonDescriptor; import org.apache.iotdb.commons.conf.IoTDBConstant; +import org.apache.iotdb.commons.memory.IMemoryBlock; +import org.apache.iotdb.commons.memory.MemoryBlockType; import org.apache.iotdb.db.auth.AuthorityChecker; import org.apache.iotdb.db.conf.IoTDBConfig; import org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.protocol.session.IClientSession; +import org.apache.iotdb.db.protocol.session.PreparedStatementInfo; import org.apache.iotdb.db.queryengine.common.DataNodeEndPoints; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.QueryId; @@ -49,6 +53,7 @@ import org.apache.iotdb.db.queryengine.plan.execution.config.TreeConfigTaskVisitor; import org.apache.iotdb.db.queryengine.plan.planner.LocalExecutionPlanner; import org.apache.iotdb.db.queryengine.plan.planner.TreeModelPlanner; +import org.apache.iotdb.db.queryengine.plan.relational.analyzer.NodeRef; import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata; import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext; import org.apache.iotdb.db.queryengine.plan.relational.planner.TableModelPlanner; @@ -56,6 +61,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.DistributedOptimizeFactory; import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.LogicalOptimizeFactory; import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanOptimizer; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ParameterExtractor; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.AddColumn; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.AlterDB; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ClearCache; @@ -64,6 +70,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateModel; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateTable; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateTraining; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Deallocate; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DeleteDevice; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DescribeTable; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DropColumn; @@ -71,13 +78,19 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DropFunction; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DropModel; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DropTable; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Execute; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ExecuteImmediate; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ExtendRegion; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Flush; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.KillQuery; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LoadConfiguration; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LoadModel; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.MigrateRegion; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Parameter; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PipeStatement; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Prepare; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ReconstructRegion; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.RelationalAuthorStatement; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.RemoveAINode; @@ -132,7 +145,9 @@ import org.slf4j.LoggerFactory; import java.util.ArrayList; +import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; import java.util.concurrent.ScheduledExecutorService; @@ -184,6 +199,8 @@ public class Coordinator { private static final Coordinator INSTANCE = new Coordinator(); + private static final IMemoryBlock coordinatorMemoryBlock; + private final ConcurrentHashMap queryExecutionMap; private final StatementRewrite statementRewrite; @@ -192,6 +209,20 @@ public class Coordinator { private final DataNodeLocationSupplierFactory.DataNodeLocationSupplier dataNodeLocationSupplier; private final TypeManager typeManager; + static { + coordinatorMemoryBlock = + IoTDBDescriptor.getInstance() + .getMemoryConfig() + .getCoordinatorMemoryManager() + .exactAllocate("Coordinator", MemoryBlockType.DYNAMIC); + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "Initialized shared MemoryBlock 'Coordinator' with all available memory: {} bytes", + coordinatorMemoryBlock.getTotalMemorySizeInBytes()); + } + } + private Coordinator() { this.queryExecutionMap = new ConcurrentHashMap<>(); this.typeManager = new InternalTypeManager(); @@ -401,6 +432,8 @@ private IQueryExecution createQueryExecutionForTableModel( distributionPlanOptimizers, AuthorityChecker.getAccessControl(), dataNodeLocationSupplier, + Collections.emptyList(), + Collections.emptyMap(), typeManager); return new QueryExecution(tableModelPlanner, queryContext, executor); } @@ -475,7 +508,9 @@ private IQueryExecution createQueryExecutionForTableModel( || statement instanceof LoadModel || statement instanceof UnloadModel || statement instanceof ShowLoadedModels - || statement instanceof RemoveRegion) { + || statement instanceof RemoveRegion + || statement instanceof Prepare + || statement instanceof Deallocate) { return new ConfigExecution( queryContext, null, @@ -485,12 +520,54 @@ private IQueryExecution createQueryExecutionForTableModel( clientSession, metadata, AuthorityChecker.getAccessControl(), typeManager), queryContext)); } + // Initialize variables for TableModelPlanner + org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement statementToUse = statement; + List parameters = Collections.emptyList(); + Map, Expression> parameterLookup = Collections.emptyMap(); + + if (statement instanceof Execute) { + Execute executeStatement = (Execute) statement; + String statementName = executeStatement.getStatementName().getValue(); + + 
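// Prepared statements are session-scoped: EXECUTE can only see names prepared on this connection +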
// Get prepared statement from session (contains cached AST) + PreparedStatementInfo preparedInfo = clientSession.getPreparedStatement(statementName); + if (preparedInfo == null) { + throw new SemanticException( + String.format("Prepared statement '%s' does not exist", statementName)); + } + + // Use cached AST + statementToUse = preparedInfo.getSql(); + + // Bind parameters: create parameterLookup map + // Note: bindParameters() internally validates parameter count + parameterLookup = + ParameterExtractor.bindParameters(statementToUse, executeStatement.getParameters()); + parameters = new ArrayList<>(executeStatement.getParameters()); + + } else if (statement instanceof ExecuteImmediate) { + ExecuteImmediate executeImmediateStatement = (ExecuteImmediate) statement; + + // EXECUTE IMMEDIATE needs to parse SQL first + String sql = executeImmediateStatement.getSqlString(); + List literalParameters = executeImmediateStatement.getParameters(); + + statementToUse = sqlParser.createStatement(sql, clientSession.getZoneId(), clientSession); + + if (!literalParameters.isEmpty()) { + parameterLookup = ParameterExtractor.bindParameters(statementToUse, literalParameters); + parameters = new ArrayList<>(literalParameters); + } + } + if (statement instanceof WrappedInsertStatement) { ((WrappedInsertStatement) statement).setContext(queryContext); } - final TableModelPlanner tableModelPlanner = + + // Create QueryExecution with TableModelPlanner + TableModelPlanner tableModelPlanner = new TableModelPlanner( - statement, + statementToUse, sqlParser, metadata, scheduledExecutor, @@ -501,6 +578,8 @@ private IQueryExecution createQueryExecutionForTableModel( distributionPlanOptimizers, AuthorityChecker.getAccessControl(), dataNodeLocationSupplier, + parameters, + parameterLookup, typeManager); return new QueryExecution(tableModelPlanner, queryContext, executor); } @@ -609,6 +688,10 @@ public static Coordinator getInstance() { return INSTANCE; } + public static IMemoryBlock getCoordinatorMemoryBlock() { + return coordinatorMemoryBlock; + } + public void recordExecutionTime(long queryId, long executionTime) { IQueryExecution queryExecution = getQueryExecution(queryId); if (queryExecution != null) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java index 3356ed2b5849..f9d668944e00 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java @@ -101,6 +101,8 @@ import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.relational.ShowTablesDetailsTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.relational.ShowTablesTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.relational.UseDBTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.session.DeallocateTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.session.PrepareTask; import org.apache.iotdb.db.queryengine.plan.execution.config.session.SetSqlDialectTask; import org.apache.iotdb.db.queryengine.plan.execution.config.session.ShowCurrentDatabaseTask; import org.apache.iotdb.db.queryengine.plan.execution.config.session.ShowCurrentSqlDialectTask; @@ -150,6 +152,7 @@ import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CreateView; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DataType; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DatabaseStatement; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Deallocate; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DeleteDevice; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DescribeTable; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DropColumn; @@ -171,6 +174,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LongLiteral; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.MigrateRegion; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Node; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Prepare; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Property; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.QualifiedName; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ReconstructRegion; @@ -1375,6 +1379,18 @@ protected IConfigTask visitSetSqlDialect(SetSqlDialect node, MPPQueryContext con return new SetSqlDialectTask(node.getSqlDialect()); } + @Override + protected IConfigTask visitPrepare(Prepare node, MPPQueryContext context) { + context.setQueryType(QueryType.WRITE); + return new PrepareTask(node.getStatementName().getValue(), node.getSql()); + } + + @Override + protected IConfigTask visitDeallocate(Deallocate node, MPPQueryContext context) { + context.setQueryType(QueryType.WRITE); + return new DeallocateTask(node.getStatementName().getValue()); + } + @Override protected IConfigTask visitShowCurrentDatabase( ShowCurrentDatabase node, MPPQueryContext context) { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/DeallocateTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/DeallocateTask.java new file mode 100644 index 000000000000..6f5f3f484615 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/DeallocateTask.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.execution.config.session; + +import org.apache.iotdb.db.exception.sql.SemanticException; +import org.apache.iotdb.db.protocol.session.IClientSession; +import org.apache.iotdb.db.protocol.session.PreparedStatementInfo; +import org.apache.iotdb.db.protocol.session.SessionManager; +import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult; +import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.executor.IConfigTaskExecutor; +import org.apache.iotdb.rpc.TSStatusCode; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; + +/** + * Task for executing DEALLOCATE PREPARE statement. Removes the prepared statement from the session + * and releases its allocated memory. + */ +public class DeallocateTask implements IConfigTask { + + private final String statementName; + + public DeallocateTask(String statementName) { + this.statementName = statementName; + } + + @Override + public ListenableFuture execute(IConfigTaskExecutor configTaskExecutor) + throws InterruptedException { + SettableFuture future = SettableFuture.create(); + IClientSession session = SessionManager.getInstance().getCurrSession(); + if (session == null) { + future.setException( + new IllegalStateException("No current session available for DEALLOCATE statement")); + return future; + } + + // Remove the prepared statement + PreparedStatementInfo removedInfo = session.removePreparedStatement(statementName); + if (removedInfo == null) { + future.setException( + new SemanticException( + String.format("Prepared statement '%s' does not exist", statementName))); + return future; + } + + // Release the memory allocated for this PreparedStatement from the shared MemoryBlock + PreparedStatementMemoryManager.getInstance().release(removedInfo.getMemorySizeInBytes()); + + future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); + return future; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PrepareTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PrepareTask.java new file mode 100644 index 000000000000..bf61e702c72d --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PrepareTask.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.execution.config.session; + +import org.apache.iotdb.db.exception.sql.SemanticException; +import org.apache.iotdb.db.protocol.session.IClientSession; +import org.apache.iotdb.db.protocol.session.PreparedStatementInfo; +import org.apache.iotdb.db.protocol.session.SessionManager; +import org.apache.iotdb.db.queryengine.plan.execution.config.ConfigTaskResult; +import org.apache.iotdb.db.queryengine.plan.execution.config.IConfigTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.executor.IConfigTaskExecutor; +import org.apache.iotdb.db.queryengine.plan.relational.sql.AstMemoryEstimator; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement; +import org.apache.iotdb.rpc.TSStatusCode; + +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; + +/** + * Task for executing PREPARE statement. Stores the prepared statement AST in the session. The AST + * is cached to avoid reparsing on EXECUTE (skipping Parser phase). Memory is allocated from + * CoordinatorMemoryManager and shared across all sessions. + */ +public class PrepareTask implements IConfigTask { + + private final String statementName; + private final Statement sql; // AST containing Parameter nodes + + public PrepareTask(String statementName, Statement sql) { + this.statementName = statementName; + this.sql = sql; + } + + @Override + public ListenableFuture execute(IConfigTaskExecutor configTaskExecutor) + throws InterruptedException { + SettableFuture future = SettableFuture.create(); + IClientSession session = SessionManager.getInstance().getCurrSession(); + if (session == null) { + future.setException( + new IllegalStateException("No current session available for PREPARE statement")); + return future; + } + + // Check if prepared statement with the same name already exists + PreparedStatementInfo existingInfo = session.getPreparedStatement(statementName); + if (existingInfo != null) { + future.setException( + new SemanticException( + String.format("Prepared statement '%s' already exists.", statementName))); + return future; + } + + // Estimate memory size of the AST + long memorySizeInBytes = AstMemoryEstimator.estimateMemorySize(sql); + + // Allocate memory from CoordinatorMemoryManager + // This memory is shared across all sessions using a single MemoryBlock + PreparedStatementMemoryManager.getInstance().allocate(statementName, memorySizeInBytes); + + // Create and store the prepared statement info (AST is cached) + PreparedStatementInfo info = new PreparedStatementInfo(statementName, sql, memorySizeInBytes); + session.addPreparedStatement(statementName, info); + + future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); + return future; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PreparedStatementMemoryManager.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PreparedStatementMemoryManager.java new file mode 100644 index 000000000000..9d5a3fb098e5 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/session/PreparedStatementMemoryManager.java @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.execution.config.session; + +import org.apache.iotdb.commons.memory.IMemoryBlock; +import org.apache.iotdb.db.exception.sql.SemanticException; +import org.apache.iotdb.db.protocol.session.IClientSession; +import org.apache.iotdb.db.protocol.session.PreparedStatementInfo; +import org.apache.iotdb.db.queryengine.plan.Coordinator; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Set; + +/** + * Memory manager for PreparedStatement. All PreparedStatements from all sessions share a single + * MemoryBlock named "Coordinator" allocated from CoordinatorMemoryManager. The MemoryBlock is + * initialized in Coordinator with all available memory. + */ +public class PreparedStatementMemoryManager { + private static final Logger LOGGER = + LoggerFactory.getLogger(PreparedStatementMemoryManager.class); + + private static final PreparedStatementMemoryManager INSTANCE = + new PreparedStatementMemoryManager(); + + private static final String SHARED_MEMORY_BLOCK_NAME = "Coordinator"; + + private PreparedStatementMemoryManager() { + // singleton + } + + public static PreparedStatementMemoryManager getInstance() { + return INSTANCE; + } + + private IMemoryBlock getSharedMemoryBlock() { + return Coordinator.getCoordinatorMemoryBlock(); + } + + /** + * Allocate memory for a PreparedStatement. + * + * @param statementName the name of the prepared statement + * @param memorySizeInBytes the memory size in bytes to allocate + * @throws SemanticException if memory allocation fails + */ + public void allocate(String statementName, long memorySizeInBytes) { + IMemoryBlock sharedMemoryBlock = getSharedMemoryBlock(); + // Allocate memory from the shared block + boolean allocated = sharedMemoryBlock.allocate(memorySizeInBytes); + if (!allocated) { + LOGGER.warn( + "Failed to allocate {} bytes from shared MemoryBlock '{}' for PreparedStatement '{}'", + memorySizeInBytes, + SHARED_MEMORY_BLOCK_NAME, + statementName); + throw new SemanticException( + String.format( + "Insufficient memory for PreparedStatement '%s'. " + + "Please deallocate some PreparedStatements and try again.", + statementName)); + } + + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "Allocated {} bytes for PreparedStatement '{}' from shared MemoryBlock '{}'. ", + memorySizeInBytes, + statementName, + SHARED_MEMORY_BLOCK_NAME); + } + } + + /** + * Release memory for a PreparedStatement. + * + * @param memorySizeInBytes the memory size in bytes to release + */ + public void release(long memorySizeInBytes) { + if (memorySizeInBytes <= 0) { + return; + } + + IMemoryBlock sharedMemoryBlock = getSharedMemoryBlock(); + if (!sharedMemoryBlock.isReleased()) { + long releasedSize = sharedMemoryBlock.release(memorySizeInBytes); + if (LOGGER.isDebugEnabled()) { + LOGGER.debug( + "Released {} bytes from shared MemoryBlock '{}' for PreparedStatement. 
", + releasedSize, + SHARED_MEMORY_BLOCK_NAME); + } + } else { + LOGGER.error( + "Attempted to release memory from shared MemoryBlock '{}' but it is released", + SHARED_MEMORY_BLOCK_NAME); + } + } + + /** + * Release all PreparedStatements for a session. This method should be called when a session is + * closed or connection is lost. + * + * @param session the session whose PreparedStatements should be released + */ + public void releaseAllForSession(IClientSession session) { + if (session == null) { + return; + } + + Set preparedStatementNames = session.getPreparedStatementNames(); + if (preparedStatementNames == null || preparedStatementNames.isEmpty()) { + return; + } + + int releasedCount = 0; + long totalReleasedBytes = 0; + + for (String statementName : preparedStatementNames) { + PreparedStatementInfo info = session.getPreparedStatement(statementName); + if (info != null) { + long memorySize = info.getMemorySizeInBytes(); + if (memorySize > 0) { + release(memorySize); + releasedCount++; + totalReleasedBytes += memorySize; + } + } + } + + if (releasedCount > 0 && LOGGER.isDebugEnabled()) { + LOGGER.debug( + "Released {} PreparedStatement(s) ({} bytes total) for session {}", + releasedCount, + totalReleasedBytes, + session); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java index 5acf353e32e5..73c57157f722 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java @@ -46,7 +46,6 @@ import org.apache.iotdb.db.queryengine.plan.relational.metadata.QualifiedObjectName; import org.apache.iotdb.db.queryengine.plan.relational.metadata.TableMetadataImpl; import org.apache.iotdb.db.queryengine.plan.relational.metadata.TableSchema; -import org.apache.iotdb.db.queryengine.plan.relational.planner.IrExpressionInterpreter; import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext; import org.apache.iotdb.db.queryengine.plan.relational.planner.ScopeAware; import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol; @@ -133,6 +132,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NullIfExpression; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Offset; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.OrderBy; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Parameter; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PatternRecognitionRelation; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PipeEnriched; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Property; @@ -211,6 +211,7 @@ import org.apache.iotdb.udf.api.relational.table.specification.TableParameterSpecification; import com.google.common.base.Joiner; +import com.google.common.base.VerifyException; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -281,6 +282,7 @@ import static org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction.TIMECOL_PARAMETER_NAME; import static org.apache.iotdb.db.queryengine.plan.relational.metadata.MetadataUtil.createQualifiedObjectName; import static 
org.apache.iotdb.db.queryengine.plan.relational.metadata.TableMetadataImpl.isTimestampType; +import static org.apache.iotdb.db.queryengine.plan.relational.planner.IrExpressionInterpreter.evaluateConstantExpression; import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DereferenceExpression.getQualifiedName; import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Join.Type.FULL; import static org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Join.Type.INNER; @@ -3981,15 +3983,12 @@ private void analyzeOffset(Offset node, Scope scope) { if (node.getRowCount() instanceof LongLiteral) { rowCount = ((LongLiteral) node.getRowCount()).getParsedValue(); } else { - // checkState( - // node.getRowCount() instanceof Parameter, - // "unexpected OFFSET rowCount: " + - // node.getRowCount().getClass().getSimpleName()); - throw new SemanticException( + checkState( + node.getRowCount() instanceof Parameter, "unexpected OFFSET rowCount: " + node.getRowCount().getClass().getSimpleName()); - // OptionalLong providedValue = - // analyzeParameterAsRowCount((Parameter) node.getRowCount(), scope, "OFFSET"); - // rowCount = providedValue.orElse(0); + OptionalLong providedValue = + analyzeParameterAsRowCount((Parameter) node.getRowCount(), scope, "OFFSET"); + rowCount = providedValue.orElse(0); } if (rowCount < 0) { throw new SemanticException( @@ -4050,14 +4049,10 @@ private boolean analyzeLimit(Limit node, Scope scope) { } else if (node.getRowCount() instanceof LongLiteral) { rowCount = OptionalLong.of(((LongLiteral) node.getRowCount()).getParsedValue()); } else { - // checkState( - // node.getRowCount() instanceof Parameter, - // "unexpected LIMIT rowCount: " + - // node.getRowCount().getClass().getSimpleName()); - throw new SemanticException( + checkState( + node.getRowCount() instanceof Parameter, "unexpected LIMIT rowCount: " + node.getRowCount().getClass().getSimpleName()); - // rowCount = analyzeParameterAsRowCount((Parameter) node.getRowCount(), scope, - // "LIMIT"); + rowCount = analyzeParameterAsRowCount((Parameter) node.getRowCount(), scope, "LIMIT"); } rowCount.ifPresent( count -> { @@ -4073,32 +4068,27 @@ private boolean analyzeLimit(Limit node, Scope scope) { return false; } - // private OptionalLong analyzeParameterAsRowCount( - // Parameter parameter, Scope scope, String context) { - // // validate parameter index - // analyzeExpression(parameter, scope); - // Expression providedValue = analysis.getParameters().get(NodeRef.of(parameter)); - // Object value; - // try { - // value = - // evaluateConstantExpression( - // providedValue, - // BIGINT, - // plannerContext, - // session, - // accessControl, - // analysis.getParameters()); - // } catch (VerifyException e) { - // throw new SemanticException( - // String.format("Non constant parameter value for %s: %s", context, providedValue)); - // } - // if (value == null) { - // throw new SemanticException( - // String.format("Parameter value provided for %s is NULL: %s", context, - // providedValue)); - // } - // return OptionalLong.of((long) value); - // } + private OptionalLong analyzeParameterAsRowCount( + Parameter parameter, Scope scope, String context) { + // validate parameter index + analyzeExpression(parameter, scope); + Expression providedValue = analysis.getParameters().get(NodeRef.of(parameter)); + Object value; + try { + value = + evaluateConstantExpression( + providedValue, new PlannerContext(metadata, typeManager), sessionContext); + + } catch (VerifyException e) { + throw new SemanticException( + 
String.format("Non constant parameter value for %s: %s", context, providedValue));
+    }
+    if (value == null) {
+      throw new SemanticException(
+          String.format("Parameter value provided for %s is NULL: %s", context, providedValue));
+    }
+    return OptionalLong.of((long) value);
+  }
 
   private void analyzeAggregations(
       QuerySpecification node,
@@ -5191,7 +5181,7 @@ private ArgumentAnalysis analyzeScalarArgument(
       Expression expression, ScalarParameterSpecification argumentSpecification) {
     // currently, only constant arguments are supported
     Object constantValue =
-        IrExpressionInterpreter.evaluateConstantExpression(
+        evaluateConstantExpression(
             expression, new PlannerContext(metadata, typeManager), sessionContext);
     if (!argumentSpecification.getType().checkObjectType(constantValue)) {
       if ((argumentSpecification.getType().equals(org.apache.iotdb.udf.api.type.Type.STRING)
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/TableModelPlanner.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/TableModelPlanner.java
index fddff825f0d2..78f9729ece34 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/TableModelPlanner.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/TableModelPlanner.java
@@ -35,13 +35,16 @@
 import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan;
 import org.apache.iotdb.db.queryengine.plan.relational.analyzer.Analysis;
 import org.apache.iotdb.db.queryengine.plan.relational.analyzer.Analyzer;
+import org.apache.iotdb.db.queryengine.plan.relational.analyzer.NodeRef;
 import org.apache.iotdb.db.queryengine.plan.relational.analyzer.StatementAnalyzerFactory;
 import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.distribute.TableDistributedPlanner;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.DataNodeLocationSupplierFactory;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.optimizations.PlanOptimizer;
 import org.apache.iotdb.db.queryengine.plan.relational.security.AccessControl;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
 import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LoadTsFile;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Parameter;
 import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PipeEnriched;
 import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement;
 import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.WrappedInsertStatement;
@@ -57,8 +60,8 @@
 import org.apache.iotdb.rpc.TSStatusCode;
 
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 import java.util.concurrent.ScheduledExecutorService;
 
 import static org.apache.iotdb.db.queryengine.metric.QueryPlanCostMetricSet.DISTRIBUTION_PLANNER;
@@ -88,6 +91,9 @@ public class TableModelPlanner implements IPlanner {
   private final DataNodeLocationSupplierFactory.DataNodeLocationSupplier dataNodeLocationSupplier;
 
+  // Parameters for prepared statements (optional)
+  private final List<Expression> parameters;
+  private final Map<NodeRef<Parameter>, Expression> parameterLookup;
   private final TypeManager typeManager;
 
   public TableModelPlanner(
@@ -104,6 +110,8 @@ public TableModelPlanner(
       final List<PlanOptimizer> distributionPlanOptimizers,
       final AccessControl accessControl,
       final DataNodeLocationSupplierFactory.DataNodeLocationSupplier dataNodeLocationSupplier,
+      final List<Expression> parameters,
+      final Map<NodeRef<Parameter>, Expression> parameterLookup,
       final TypeManager typeManager) {
     this.statement = statement;
     this.sqlParser = sqlParser;
@@ -116,6 +124,8 @@ public TableModelPlanner(
     this.distributionPlanOptimizers = distributionPlanOptimizers;
     this.accessControl = accessControl;
     this.dataNodeLocationSupplier = dataNodeLocationSupplier;
+    this.parameters = parameters;
+    this.parameterLookup = parameterLookup;
     this.typeManager = typeManager;
   }
 
@@ -125,8 +135,8 @@ public IAnalysis analyze(final MPPQueryContext context) {
             context,
             context.getSession(),
             new StatementAnalyzerFactory(metadata, sqlParser, accessControl, typeManager),
-            Collections.emptyList(),
-            Collections.emptyMap(),
+            parameters,
+            parameterLookup,
             statementRewrite,
             warningCollector)
         .analyze(statement);
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/AstMemoryEstimator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/AstMemoryEstimator.java
new file mode 100644
index 000000000000..d45f6546e0f6
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/AstMemoryEstimator.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.sql;
+
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DefaultTraversalVisitor;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Node;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement;
+
+import org.apache.tsfile.utils.RamUsageEstimator;
+
+/**
+ * Utility class for estimating memory usage of AST nodes. Uses RamUsageEstimator to calculate
+ * approximate memory size.
+ */
+public final class AstMemoryEstimator {
+  private AstMemoryEstimator() {}
+
+  /**
+   * Estimate the memory size of a Statement AST node in bytes.
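+   * <p>Illustrative usage sketch (an assumption for clarity, not part of this patch; the parser
+   * call mirrors the one added to the Coordinator above):
+   *
+   * <pre>{@code
+   * Statement ast = sqlParser.createStatement(sql, clientSession.getZoneId(), clientSession);
+   * long estimatedBytes = AstMemoryEstimator.estimateMemorySize(ast);
+   * }</pre>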
+ * + * @param statement the statement AST to estimate + * @return estimated memory size in bytes + */ + public static long estimateMemorySize(Statement statement) { + if (statement == null) { + return 0L; + } + MemoryEstimatingVisitor visitor = new MemoryEstimatingVisitor(); + visitor.process(statement, null); + return visitor.getTotalMemorySize(); + } + + private static class MemoryEstimatingVisitor extends DefaultTraversalVisitor { + private long totalMemorySize = 0L; + + public long getTotalMemorySize() { + return totalMemorySize; + } + + @Override + protected Void visitNode(Node node, Void context) { + // Estimate shallow size of the node object + long nodeSize = RamUsageEstimator.shallowSizeOfInstance(node.getClass()); + totalMemorySize += nodeSize; + + // Traverse children (DefaultTraversalVisitor handles this) + return super.visitNode(node, context); + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ParameterExtractor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ParameterExtractor.java new file mode 100644 index 000000000000..d727acdc35e2 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ParameterExtractor.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.iotdb.db.queryengine.plan.relational.sql; + +import org.apache.iotdb.db.exception.sql.SemanticException; +import org.apache.iotdb.db.queryengine.plan.relational.analyzer.NodeRef; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DefaultTraversalVisitor; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Parameter; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement; + +import com.google.common.collect.ImmutableMap; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +/** Utility class for extracting and binding parameters in prepared statements. */ +public final class ParameterExtractor { + private ParameterExtractor() {} + + /** + * Get the number of parameters in a statement. + * + * @param statement the statement to analyze + * @return the number of parameters + */ + public static int getParameterCount(Statement statement) { + return extractParameters(statement).size(); + } + + /** + * Extract all Parameter nodes from a statement in order of appearance. 
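+   * <p>Illustrative example (hypothetical SQL): for {@code SELECT * FROM t1 WHERE a = ? AND b = ?},
+   * the two {@code Parameter} nodes are returned ordered by source location, comparing line
+   * number first and then column number, exactly as the comparator below does.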
+ * + * @param statement the statement to analyze + * @return list of Parameter nodes in order of appearance + */ + public static List extractParameters(Statement statement) { + ParameterExtractingVisitor visitor = new ParameterExtractingVisitor(); + visitor.process(statement, null); + return visitor.getParameters().stream() + .sorted( + Comparator.comparing( + parameter -> + parameter + .getLocation() + .orElseThrow( + () -> new SemanticException("Parameter node must have a location")), + Comparator.comparing( + org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NodeLocation + ::getLineNumber) + .thenComparing( + org.apache.iotdb.db.queryengine.plan.relational.sql.ast.NodeLocation + ::getColumnNumber))) + .collect(toImmutableList()); + } + + /** + * Bind parameter values to Parameter nodes in a statement. Creates a map from Parameter node + * references to their corresponding Expression values. + * + * @param statement the statement containing Parameter nodes + * @param values the parameter values (in order) + * @return map from Parameter node references to Expression values + * @throws SemanticException if the number of parameters doesn't match + */ + public static Map, Expression> bindParameters( + Statement statement, List values) { + List parametersList = extractParameters(statement); + + // Validate parameter count + if (parametersList.size() != values.size()) { + throw new SemanticException( + String.format( + "Invalid number of parameters: expected %d, got %d", + parametersList.size(), values.size())); + } + + ImmutableMap.Builder, Expression> builder = ImmutableMap.builder(); + Iterator iterator = values.iterator(); + for (Parameter parameter : parametersList) { + builder.put(NodeRef.of(parameter), iterator.next()); + } + return builder.buildOrThrow(); + } + + private static class ParameterExtractingVisitor extends DefaultTraversalVisitor { + private final List parameters = new ArrayList<>(); + + public List getParameters() { + return parameters; + } + + @Override + protected Void visitParameter(Parameter node, Void context) { + parameters.add(node); + return null; + } + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/AstVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/AstVisitor.java index 54802a1d2f60..2728750418fe 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/AstVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/AstVisitor.java @@ -93,6 +93,22 @@ protected R visitUse(Use node, C context) { return visitStatement(node, context); } + protected R visitPrepare(Prepare node, C context) { + return visitStatement(node, context); + } + + protected R visitExecute(Execute node, C context) { + return visitStatement(node, context); + } + + protected R visitExecuteImmediate(ExecuteImmediate node, C context) { + return visitStatement(node, context); + } + + protected R visitDeallocate(Deallocate node, C context) { + return visitStatement(node, context); + } + protected R visitGenericLiteral(GenericLiteral node, C context) { return visitLiteral(node, context); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Deallocate.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Deallocate.java new file mode 100644 index 000000000000..fd579bb64b5b --- /dev/null +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Deallocate.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.sql.ast; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +/** DEALLOCATE PREPARE statement AST node. Example: DEALLOCATE PREPARE stmt1 */ +public final class Deallocate extends Statement { + + private final Identifier statementName; + + public Deallocate(Identifier statementName) { + this(null, statementName); + } + + public Deallocate(NodeLocation location, Identifier statementName) { + super(location); + this.statementName = requireNonNull(statementName, "statementName is null"); + } + + public Identifier getStatementName() { + return statementName; + } + + @Override + public R accept(AstVisitor visitor, C context) { + return visitor.visitDeallocate(this, context); + } + + @Override + public List getChildren() { + return ImmutableList.of(statementName); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Deallocate that = (Deallocate) o; + return Objects.equals(statementName, that.statementName); + } + + @Override + public int hashCode() { + return Objects.hash(statementName); + } + + @Override + public String toString() { + return toStringHelper(this).add("statementName", statementName).toString(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Execute.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Execute.java new file mode 100644 index 000000000000..d7e219faf1b9 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Execute.java @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.sql.ast; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +/** EXECUTE statement AST node. Example: EXECUTE stmt1 USING 100, 'test' */ +public final class Execute extends Statement { + + private final Identifier statementName; + private final List parameters; + + public Execute(Identifier statementName) { + this(null, statementName, ImmutableList.of()); + } + + public Execute(Identifier statementName, List parameters) { + this(null, statementName, parameters); + } + + public Execute(NodeLocation location, Identifier statementName, List parameters) { + super(location); + this.statementName = requireNonNull(statementName, "statementName is null"); + this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); + } + + public Identifier getStatementName() { + return statementName; + } + + public List getParameters() { + return parameters; + } + + @Override + public R accept(AstVisitor visitor, C context) { + return visitor.visitExecute(this, context); + } + + @Override + public List getChildren() { + ImmutableList.Builder children = ImmutableList.builder(); + children.add(statementName); + children.addAll(parameters); + return children.build(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Execute that = (Execute) o; + return Objects.equals(statementName, that.statementName) + && Objects.equals(parameters, that.parameters); + } + + @Override + public int hashCode() { + return Objects.hash(statementName, parameters); + } + + @Override + public String toString() { + return toStringHelper(this) + .add("statementName", statementName) + .add("parameters", parameters) + .toString(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/ExecuteImmediate.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/ExecuteImmediate.java new file mode 100644 index 000000000000..955ac54e4fb8 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/ExecuteImmediate.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.queryengine.plan.relational.sql.ast; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +/** + * EXECUTE IMMEDIATE statement AST node. Example: EXECUTE IMMEDIATE 'SELECT * FROM table WHERE id = + * 100' + */ +public final class ExecuteImmediate extends Statement { + + private final StringLiteral sql; + private final List parameters; + + public ExecuteImmediate(StringLiteral sql) { + this(null, sql, ImmutableList.of()); + } + + public ExecuteImmediate(StringLiteral sql, List parameters) { + this(null, sql, parameters); + } + + public ExecuteImmediate(NodeLocation location, StringLiteral sql, List parameters) { + super(location); + this.sql = requireNonNull(sql, "sql is null"); + this.parameters = ImmutableList.copyOf(requireNonNull(parameters, "parameters is null")); + } + + public StringLiteral getSql() { + return sql; + } + + public String getSqlString() { + return sql.getValue(); + } + + public List getParameters() { + return parameters; + } + + @Override + public R accept(AstVisitor visitor, C context) { + return visitor.visitExecuteImmediate(this, context); + } + + @Override + public List getChildren() { + ImmutableList.Builder children = ImmutableList.builder(); + children.add(sql); + children.addAll(parameters); + return children.build(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ExecuteImmediate that = (ExecuteImmediate) o; + return Objects.equals(sql, that.sql) && Objects.equals(parameters, that.parameters); + } + + @Override + public int hashCode() { + return Objects.hash(sql, parameters); + } + + @Override + public String toString() { + return toStringHelper(this).add("sql", sql).add("parameters", parameters).toString(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Prepare.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Prepare.java new file mode 100644 index 000000000000..f413b8a19266 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Prepare.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.plan.relational.sql.ast; + +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Objects; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static java.util.Objects.requireNonNull; + +/** PREPARE statement AST node. 
Example: PREPARE stmt1 FROM SELECT * FROM table WHERE id = ? */ +public final class Prepare extends Statement { + + private final Identifier statementName; + private final Statement sql; + + public Prepare(Identifier statementName, Statement sql) { + super(null); + this.statementName = requireNonNull(statementName, "statementName is null"); + this.sql = requireNonNull(sql, "sql is null"); + } + + public Prepare(NodeLocation location, Identifier statementName, Statement sql) { + super(location); + this.statementName = requireNonNull(statementName, "statementName is null"); + this.sql = requireNonNull(sql, "sql is null"); + } + + public Identifier getStatementName() { + return statementName; + } + + public Statement getSql() { + return sql; + } + + @Override + public R accept(AstVisitor visitor, C context) { + return visitor.visitPrepare(this, context); + } + + @Override + public List getChildren() { + return ImmutableList.of(statementName, sql); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + Prepare that = (Prepare) o; + return Objects.equals(statementName, that.statementName) && Objects.equals(sql, that.sql); + } + + @Override + public int hashCode() { + return Objects.hash(statementName, sql); + } + + @Override + public String toString() { + return toStringHelper(this).add("statementName", statementName).add("sql", sql).toString(); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java index f62b45d8b8a8..b041c56b2bc9 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/parser/AstBuilder.java @@ -71,6 +71,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.CurrentUser; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DataType; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DataTypeParameter; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Deallocate; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Delete; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DeleteDevice; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.DereferenceExpression; @@ -89,6 +90,8 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.EmptyPattern; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Except; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ExcludedPattern; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Execute; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ExecuteImmediate; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ExistsPredicate; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Explain; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ExplainAnalyze; @@ -145,6 +148,7 @@ import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PatternRecognitionRelation; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PatternRecognitionRelation.RowsPerMatch; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.PatternVariable; +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Prepare; import 
org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ProcessingMode; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Property; import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.QualifiedName; @@ -3786,6 +3790,40 @@ public Node visitDropModelStatement(RelationalSqlParser.DropModelStatementContex return new DropModel(modelId); } + @Override + public Node visitPrepareStatement(RelationalSqlParser.PrepareStatementContext ctx) { + Identifier statementName = lowerIdentifier((Identifier) visit(ctx.statementName)); + Statement sql = (Statement) visit(ctx.sql); + return new Prepare(getLocation(ctx), statementName, sql); + } + + @Override + public Node visitExecuteStatement(RelationalSqlParser.ExecuteStatementContext ctx) { + Identifier statementName = lowerIdentifier((Identifier) visit(ctx.statementName)); + List parameters = + ctx.literalExpression() != null && !ctx.literalExpression().isEmpty() + ? visit(ctx.literalExpression(), Literal.class) + : ImmutableList.of(); + return new Execute(getLocation(ctx), statementName, parameters); + } + + @Override + public Node visitExecuteImmediateStatement( + RelationalSqlParser.ExecuteImmediateStatementContext ctx) { + StringLiteral sql = (StringLiteral) visit(ctx.sql); + List parameters = + ctx.literalExpression() != null && !ctx.literalExpression().isEmpty() + ? visit(ctx.literalExpression(), Literal.class) + : ImmutableList.of(); + return new ExecuteImmediate(getLocation(ctx), sql, parameters); + } + + @Override + public Node visitDeallocateStatement(RelationalSqlParser.DeallocateStatementContext ctx) { + Identifier statementName = lowerIdentifier((Identifier) visit(ctx.statementName)); + return new Deallocate(getLocation(ctx), statementName); + } + // ***************** arguments ***************** @Override public Node visitGenericType(RelationalSqlParser.GenericTypeContext ctx) { diff --git a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4 b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4 index 4bd8a2b0bd7e..1690ea4c855b 100644 --- a/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4 +++ b/iotdb-core/relational-grammar/src/main/antlr4/org/apache/iotdb/db/relational/grammar/sql/RelationalSql.g4 @@ -174,6 +174,12 @@ statement | loadModelStatement | unloadModelStatement + // Prepared Statement + | prepareStatement + | executeStatement + | executeImmediateStatement + | deallocateStatement + // View, Trigger, CQ, Quota are not supported yet ; @@ -857,6 +863,23 @@ unloadModelStatement : UNLOAD MODEL existingModelId=identifier FROM DEVICES deviceIdList=string ; +// ------------------------------------------- Prepared Statement --------------------------------------------------------- +prepareStatement + : PREPARE statementName=identifier FROM sql=statement + ; + +executeStatement + : EXECUTE statementName=identifier (USING literalExpression (',' literalExpression)*)? + ; + +executeImmediateStatement + : EXECUTE IMMEDIATE sql=string (USING literalExpression (',' literalExpression)*)? 
+ ; + +deallocateStatement + : DEALLOCATE PREPARE statementName=identifier + ; + // ------------------------------------------- Query Statement --------------------------------------------------------- queryStatement : query #statementDefault diff --git a/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift b/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift index 7d334059fff0..48afb89d3366 100644 --- a/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift +++ b/iotdb-protocol/thrift-datanode/src/main/thrift/client.thrift @@ -164,6 +164,7 @@ struct TSCloseOperationReq { 1: required i64 sessionId 2: optional i64 queryId 3: optional i64 statementId + 4: optional string preparedStatementName } struct TSFetchResultsReq{ From 8e1d0e36b4917a8825d81574386d2aac0ce7ed2f Mon Sep 17 00:00:00 2001 From: Zhenyu Luo Date: Tue, 9 Dec 2025 14:16:40 +0800 Subject: [PATCH 04/13] Load: Fix excessive GC caused by loading too many TsFiles at once (#16853) * Fix excessive GC caused by loading too many TsFiles at once When loading multiple TsFiles, all file resources were loaded into memory simultaneously, causing excessive memory consumption and frequent GC pauses. This commit introduces batch execution for multi-file loading scenarios: 1. Split LoadTsFileStatement/LoadTsFile into sub-statements, each handling one TsFile, to avoid loading all file resources at once 2. Refactor duplicate code in ClientRPCServiceImpl by extracting helper methods for both tree model and table model 3. Add progress logging to track the loading status of each file 4. Support both synchronous and asynchronous loading modes Changes: - Added getSubStatement() method to LoadTsFileStatement and LoadTsFile for splitting multi-file statements - Extracted shouldSplitLoadTsFileStatement() and shouldSplitTableLoadTsFile() to determine if splitting is needed - Extracted executeBatchLoadTsFile() and executeBatchTableLoadTsFile() to handle batch execution with progress logging - Applied the optimization to 4 execution paths (tree/table model, sync/async loading) This fix significantly reduces memory pressure and improves system stability when loading large numbers of TsFiles. * fix * update (cherry picked from commit bc4f8e9bd8160e2a10745c7dd2fc5cdb8732de7a) --- .../org/apache/iotdb/db/conf/IoTDBConfig.java | 55 +++ .../apache/iotdb/db/conf/IoTDBDescriptor.java | 24 ++ .../thrift/impl/ClientRPCServiceImpl.java | 319 +++++++++++++++--- .../plan/relational/sql/ast/LoadTsFile.java | 61 +++- .../plan/relational/sql/ast/Statement.java | 25 ++ .../queryengine/plan/statement/Statement.java | 23 ++ .../statement/crud/LoadTsFileStatement.java | 48 +++ 7 files changed, 511 insertions(+), 44 deletions(-) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java index 375285dfa2fb..7591e9d3cabf 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java @@ -1121,6 +1121,21 @@ public class IoTDBConfig { private int loadTsFileSpiltPartitionMaxSize = 10; + /** + * The threshold for splitting statement when loading multiple TsFiles. When the number of TsFiles + * exceeds this threshold, the statement will be split into multiple sub-statements for batch + * execution to limit resource consumption during statement analysis. Default value is 10, which + * means splitting will occur when there are more than 10 files. 
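+   * <p>Worked example under the defaults documented here (threshold 10, sub-statement batch size
+   * 10; the exact grouping is an inference from these two settings): a LOAD of 35 TsFiles
+   * exceeds the threshold and is split into four sub-statements handling 10, 10, 10, and 5 files
+   * respectively.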
+ */ + private int loadTsFileStatementSplitThreshold = 10; + + /** + * The number of TsFiles that each sub-statement handles when splitting a statement. This + * parameter controls how many files are grouped together in each sub-statement during batch + * execution. Default value is 10, which means each sub-statement handles 10 files. + */ + private int loadTsFileSubStatementBatchSize = 10; + private String[] loadActiveListeningDirs = new String[] { IoTDBConstant.EXT_FOLDER_NAME @@ -4057,6 +4072,46 @@ public void setLoadTsFileSpiltPartitionMaxSize(int loadTsFileSpiltPartitionMaxSi this.loadTsFileSpiltPartitionMaxSize = loadTsFileSpiltPartitionMaxSize; } + public int getLoadTsFileStatementSplitThreshold() { + return loadTsFileStatementSplitThreshold; + } + + public void setLoadTsFileStatementSplitThreshold(final int loadTsFileStatementSplitThreshold) { + if (loadTsFileStatementSplitThreshold < 0) { + logger.warn( + "Invalid loadTsFileStatementSplitThreshold value: {}. Using default value: 10", + loadTsFileStatementSplitThreshold); + return; + } + if (this.loadTsFileStatementSplitThreshold != loadTsFileStatementSplitThreshold) { + logger.info( + "loadTsFileStatementSplitThreshold changed from {} to {}", + this.loadTsFileStatementSplitThreshold, + loadTsFileStatementSplitThreshold); + } + this.loadTsFileStatementSplitThreshold = loadTsFileStatementSplitThreshold; + } + + public int getLoadTsFileSubStatementBatchSize() { + return loadTsFileSubStatementBatchSize; + } + + public void setLoadTsFileSubStatementBatchSize(final int loadTsFileSubStatementBatchSize) { + if (loadTsFileSubStatementBatchSize <= 0) { + logger.warn( + "Invalid loadTsFileSubStatementBatchSize value: {}. Using default value: 10", + loadTsFileSubStatementBatchSize); + return; + } + if (this.loadTsFileSubStatementBatchSize != loadTsFileSubStatementBatchSize) { + logger.info( + "loadTsFileSubStatementBatchSize changed from {} to {}", + this.loadTsFileSubStatementBatchSize, + loadTsFileSubStatementBatchSize); + } + this.loadTsFileSubStatementBatchSize = loadTsFileSubStatementBatchSize; + } + public String[] getPipeReceiverFileDirs() { return (Objects.isNull(this.pipeReceiverFileDirs) || this.pipeReceiverFileDirs.length == 0) ? 
new String[] {systemDir + File.separator + "pipe" + File.separator + "receiver"} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java index c194f304e7f6..bd579b24c3f4 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java @@ -2406,6 +2406,18 @@ private void loadLoadTsFileProps(TrimProperties properties) { properties.getProperty( "skip_failed_table_schema_check", String.valueOf(conf.isSkipFailedTableSchemaCheck())))); + + conf.setLoadTsFileStatementSplitThreshold( + Integer.parseInt( + properties.getProperty( + "load_tsfile_statement_split_threshold", + Integer.toString(conf.getLoadTsFileStatementSplitThreshold())))); + + conf.setLoadTsFileSubStatementBatchSize( + Integer.parseInt( + properties.getProperty( + "load_tsfile_sub_statement_batch_size", + Integer.toString(conf.getLoadTsFileSubStatementBatchSize())))); } private void loadLoadTsFileHotModifiedProp(TrimProperties properties) throws IOException { @@ -2454,6 +2466,18 @@ private void loadLoadTsFileHotModifiedProp(TrimProperties properties) throws IOE "load_tsfile_split_partition_max_size", Integer.toString(conf.getLoadTsFileSpiltPartitionMaxSize())))); + conf.setLoadTsFileStatementSplitThreshold( + Integer.parseInt( + properties.getProperty( + "load_tsfile_statement_split_threshold", + Integer.toString(conf.getLoadTsFileStatementSplitThreshold())))); + + conf.setLoadTsFileSubStatementBatchSize( + Integer.parseInt( + properties.getProperty( + "load_tsfile_sub_statement_batch_size", + Integer.toString(conf.getLoadTsFileSubStatementBatchSize())))); + conf.setSkipFailedTableSchemaCheck( Boolean.parseBoolean( properties.getProperty( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java index ad16535bf573..07c2800799c8 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java @@ -365,16 +365,30 @@ private TSExecuteStatementResp executeStatementInternal( queryId = SESSION_MANAGER.requestQueryId(clientSession, req.statementId); - result = - COORDINATOR.executeForTreeModel( - s, - queryId, - SESSION_MANAGER.getSessionInfo(clientSession), - statement, - partitionFetcher, - schemaFetcher, - req.getTimeout(), - true); + // Split statement if needed to limit resource consumption during statement analysis + if (s.shouldSplit()) { + result = + executeBatchStatement( + s, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + partitionFetcher, + schemaFetcher, + config.getQueryTimeoutThreshold(), + true); + } else { + result = + COORDINATOR.executeForTreeModel( + s, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + partitionFetcher, + schemaFetcher, + req.getTimeout(), + true); + } } } else { org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement s = @@ -396,17 +410,32 @@ private TSExecuteStatementResp executeStatementInternal( queryId = SESSION_MANAGER.requestQueryId(clientSession, req.statementId); - result = - COORDINATOR.executeForTableModel( - s, - relationSqlParser, - clientSession, - queryId, - 
SESSION_MANAGER.getSessionInfo(clientSession), - statement, - metadata, - req.getTimeout(), - true); + // Split statement if needed to limit resource consumption during statement analysis + if (s.shouldSplit()) { + result = + executeBatchTableStatement( + s, + relationSqlParser, + clientSession, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + metadata, + config.getQueryTimeoutThreshold(), + true); + } else { + result = + COORDINATOR.executeForTableModel( + s, + relationSqlParser, + clientSession, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + metadata, + req.getTimeout(), + true); + } } if (result.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode() @@ -1846,16 +1875,31 @@ public TSStatus executeBatchStatement(TSExecuteBatchStatementReq req) { queryId = SESSION_MANAGER.requestQueryId(); type = s.getType() == null ? null : s.getType().name(); // create and cache dataset - result = - COORDINATOR.executeForTreeModel( - s, - queryId, - SESSION_MANAGER.getSessionInfo(clientSession), - statement, - partitionFetcher, - schemaFetcher, - config.getQueryTimeoutThreshold(), - false); + + // Split statement if needed to limit resource consumption during statement analysis + if (s.shouldSplit()) { + result = + executeBatchStatement( + s, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + partitionFetcher, + schemaFetcher, + config.getQueryTimeoutThreshold(), + false); + } else { + result = + COORDINATOR.executeForTreeModel( + s, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + partitionFetcher, + schemaFetcher, + config.getQueryTimeoutThreshold(), + false); + } } } else { @@ -1876,17 +1920,32 @@ public TSStatus executeBatchStatement(TSExecuteBatchStatementReq req) { queryId = SESSION_MANAGER.requestQueryId(); - result = - COORDINATOR.executeForTableModel( - s, - relationSqlParser, - clientSession, - queryId, - SESSION_MANAGER.getSessionInfo(clientSession), - statement, - metadata, - config.getQueryTimeoutThreshold(), - false); + // Split statement if needed to limit resource consumption during statement analysis + if (s.shouldSplit()) { + result = + executeBatchTableStatement( + s, + relationSqlParser, + clientSession, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + metadata, + config.getQueryTimeoutThreshold(), + false); + } else { + result = + COORDINATOR.executeForTableModel( + s, + relationSqlParser, + clientSession, + queryId, + SESSION_MANAGER.getSessionInfo(clientSession), + statement, + metadata, + config.getQueryTimeoutThreshold(), + false); + } } results.add(result.status); @@ -3191,4 +3250,180 @@ public void handleClientExit() { PipeDataNodeAgent.receiver().legacy().handleClientExit(); SubscriptionAgent.receiver().handleClientExit(); } + + /** + * Executes tree-model Statement sub-statements in batch. 
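
For orientation, a minimal sketch of the dispatch that the RPC entry points above now perform; this is illustrative only, not code from the patch. In the tree-model path, `s` is the parsed Statement, `statement` the raw SQL string, and `req`, `config`, and the fetchers are the surrounding method's locals:

    // Large LOAD statements are split and run batch-by-batch to bound the memory
    // consumed during statement analysis; everything else keeps the original
    // single-shot coordinator path.
    final ExecutionResult result =
        s.shouldSplit()
            ? executeBatchStatement(
                s, queryId, SESSION_MANAGER.getSessionInfo(clientSession), statement,
                partitionFetcher, schemaFetcher, config.getQueryTimeoutThreshold(), true)
            : COORDINATOR.executeForTreeModel(
                s, queryId, SESSION_MANAGER.getSessionInfo(clientSession), statement,
                partitionFetcher, schemaFetcher, req.getTimeout(), true);
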
+ * + * @param statement the Statement to be executed + * @param queryId the query ID + * @param sessionInfo the session information + * @param statementStr the SQL statement string + * @param partitionFetcher the partition fetcher + * @param schemaFetcher the schema fetcher + * @param timeoutMs the timeout in milliseconds + * @param userQuery whether the execution is triggered by a user query + * @return the execution result + */ + private ExecutionResult executeBatchStatement( + final Statement statement, + final long queryId, + final SessionInfo sessionInfo, + final String statementStr, + final IPartitionFetcher partitionFetcher, + final ISchemaFetcher schemaFetcher, + final long timeoutMs, + final boolean userQuery) { + + ExecutionResult result = null; + final List<Statement> subStatements = statement.getSubStatements(); + final int totalSubStatements = subStatements.size(); + + LOGGER.info( + "Start batch executing {} sub-statement(s) in tree model, queryId: {}", + totalSubStatements, + queryId); + + for (int i = 0; i < totalSubStatements; i++) { + final Statement subStatement = subStatements.get(i); + + LOGGER.info( + "Executing sub-statement {}/{} in tree model, queryId: {}", + i + 1, + totalSubStatements, + queryId); + + result = + COORDINATOR.executeForTreeModel( + subStatement, + queryId, + sessionInfo, + statementStr, + partitionFetcher, + schemaFetcher, + timeoutMs, + userQuery); + + // Exit early if any sub-statement execution fails + if (result != null + && result.status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + final int completed = i + 1; + final int remaining = totalSubStatements - completed; + final double percentage = (completed * 100.0) / totalSubStatements; + LOGGER.warn( + "Failed to execute sub-statement {}/{} in tree model, queryId: {}, completed: {}, remaining: {}, progress: {}%, error: {}", + i + 1, + totalSubStatements, + queryId, + completed, + remaining, + String.format("%.2f", percentage), + result.status.getMessage()); + break; + } + + LOGGER.info( + "Successfully executed sub-statement {}/{} in tree model, queryId: {}", + i + 1, + totalSubStatements, + queryId); + } + + if (result != null && result.status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + LOGGER.info( + "Completed batch executing all {} sub-statement(s) in tree model, queryId: {}", + totalSubStatements, + queryId); + } + + return result; + } + + /** + * Executes table-model Statement sub-statements in batch. 
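
Both batch executors follow the same fail-fast loop; a condensed sketch of the shared shape, with `execute` standing in for the coordinator call (it is not a real method in this patch):

    ExecutionResult result = null;
    for (int i = 0; i < subStatements.size(); i++) {
      result = execute(subStatements.get(i)); // delegate one batch to the coordinator
      if (result != null
          && result.status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) {
        break; // fail fast: remaining batches are skipped, the failed status is returned
      }
    }
    return result;
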
+ * + * @param statement the Statement to be executed + * @param relationSqlParser the relational SQL parser + * @param clientSession the client session + * @param queryId the query ID + * @param sessionInfo the session information + * @param statementStr the SQL statement string + * @param metadata the metadata + * @param timeoutMs the timeout in milliseconds + * @param userQuery whether the execution is triggered by a user query + * @return the execution result + */ + private ExecutionResult executeBatchTableStatement( + final org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement statement, + final SqlParser relationSqlParser, + final IClientSession clientSession, + final long queryId, + final SessionInfo sessionInfo, + final String statementStr, + final Metadata metadata, + final long timeoutMs, + final boolean userQuery) { + + ExecutionResult result = null; + final List<org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement> + subStatements = statement.getSubStatements(); + final int totalSubStatements = subStatements.size(); + LOGGER.info( + "Start batch executing {} sub-statement(s) in table model, queryId: {}", + totalSubStatements, + queryId); + + for (int i = 0; i < totalSubStatements; i++) { + final org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement subStatement = + subStatements.get(i); + + LOGGER.info( + "Executing sub-statement {}/{} in table model, queryId: {}", + i + 1, + totalSubStatements, + queryId); + + result = + COORDINATOR.executeForTableModel( + subStatement, + relationSqlParser, + clientSession, + queryId, + sessionInfo, + statementStr, + metadata, + timeoutMs, + userQuery); + + // Exit early if any sub-statement execution fails + if (result != null + && result.status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + final int completed = i + 1; + final int remaining = totalSubStatements - completed; + final double percentage = (completed * 100.0) / totalSubStatements; + LOGGER.warn( + "Failed to execute sub-statement {}/{} in table model, queryId: {}, completed: {}, remaining: {}, progress: {}%, error: {}", + i + 1, + totalSubStatements, + queryId, + completed, + remaining, + String.format("%.2f", percentage), + result.status.getMessage()); + break; + } + + LOGGER.info( + "Successfully executed sub-statement {}/{} in table model, queryId: {}", + i + 1, + totalSubStatements, + queryId); + } + + if (result != null && result.status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { + LOGGER.info( + "Completed batch executing all {} sub-statement(s) in table model, queryId: {}", + totalSubStatements, + queryId); + } + + return result; + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/LoadTsFile.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/LoadTsFile.java index 93fc8c7b5833..166f06b85e32 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/LoadTsFile.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/LoadTsFile.java @@ -37,7 +37,7 @@ public class LoadTsFile extends Statement { - private final String filePath; + private String filePath; private int databaseLevel; // For loading to tree-model only private String database; // For loading to table-model only @@ -50,7 +50,7 @@ public class LoadTsFile extends Statement { private boolean isGeneratedByPipe = false; - private final Map<String, String> loadAttributes; + private Map<String, String> loadAttributes; private List<File> tsFiles; private List<TsFileResource> resources; @@ -232,6 +232,63 @@ public boolean reconstructStatementIfMiniFileConverted(final List 
isMin return tsFiles == null || tsFiles.isEmpty(); } + @Override + public boolean shouldSplit() { + final int splitThreshold = + IoTDBDescriptor.getInstance().getConfig().getLoadTsFileStatementSplitThreshold(); + return tsFiles.size() > splitThreshold && !isAsyncLoad; + } + + /** + * Splits the current LoadTsFile statement into multiple sub-statements, each handling a batch of + * TsFiles. Used to limit resource consumption during statement analysis, etc. + * + * @return the list of sub-statements + */ + @Override + public List<Statement> getSubStatements() { + final int batchSize = + IoTDBDescriptor.getInstance().getConfig().getLoadTsFileSubStatementBatchSize(); + final int totalBatches = (tsFiles.size() + batchSize - 1) / batchSize; // Ceiling division + final List<Statement> subStatements = new ArrayList<>(totalBatches); + + for (int i = 0; i < tsFiles.size(); i += batchSize) { + final int endIndex = Math.min(i + batchSize, tsFiles.size()); + final List<File> batchFiles = tsFiles.subList(i, endIndex); + + // Use the first file's path for the sub-statement + final String filePath = batchFiles.get(0).getAbsolutePath(); + final Map<String, String> properties = this.loadAttributes; + + final LoadTsFile subStatement = + new LoadTsFile(getLocation().orElse(null), filePath, properties); + + // Copy all configuration properties + subStatement.databaseLevel = this.databaseLevel; + subStatement.database = this.database; + subStatement.verify = this.verify; + subStatement.deleteAfterLoad = this.deleteAfterLoad; + subStatement.convertOnTypeMismatch = this.convertOnTypeMismatch; + subStatement.tabletConversionThresholdBytes = this.tabletConversionThresholdBytes; + subStatement.autoCreateDatabase = this.autoCreateDatabase; + subStatement.isAsyncLoad = this.isAsyncLoad; + subStatement.isGeneratedByPipe = this.isGeneratedByPipe; + + // Set all files in the batch + subStatement.tsFiles = new ArrayList<>(batchFiles); + subStatement.resources = new ArrayList<>(batchFiles.size()); + subStatement.writePointCountList = new ArrayList<>(batchFiles.size()); + subStatement.isTableModel = new ArrayList<>(batchFiles.size()); + for (int j = 0; j < batchFiles.size(); j++) { + subStatement.isTableModel.add(true); + } + + subStatements.add(subStatement); + } + + return subStatements; + } + @Override + public <R, C> R accept(AstVisitor<R, C> visitor, C context) { return visitor.visitLoadTsFile(this, context); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Statement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Statement.java index 0352c85f1eb4..7ba19b972a29 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Statement.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/sql/ast/Statement.java @@ -21,6 +21,9 @@ import javax.annotation.Nullable; +import java.util.Collections; +import java.util.List; + public abstract class Statement extends Node { protected Statement(final @Nullable NodeLocation location) { @@ -31,4 +34,26 @@ protected Statement(final @Nullable NodeLocation location) { public <R, C> R accept(final AstVisitor<R, C> visitor, final C context) { return visitor.visitStatement(this, context); } + + /** + * Checks whether this statement should be split into multiple sub-statements. Used to limit + * resource consumption during statement analysis, etc. 
+ * + * @return true if the statement should be split, false otherwise. Default implementation returns + * false. + */ + public boolean shouldSplit() { + return false; + } + + /** + * Splits the current statement into multiple sub-statements. Used to limit resource consumption + * during statement analysis, etc. + * + * @return the list of sub-statements. Default implementation returns an empty list. + */ + public List<Statement> getSubStatements() { + return Collections.emptyList(); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/Statement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/Statement.java index 5b31f08ca677..0b2ecff6b5bc 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/Statement.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/Statement.java @@ -23,6 +23,7 @@ import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.plan.parser.ASTVisitor; +import java.util.Collections; import java.util.List; /** @@ -68,4 +69,26 @@ public org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Statement toRelat public String getPipeLoggingString() { return toString(); } + + /** + * Checks whether this statement should be split into multiple sub-statements. Used to limit + * resource consumption during statement analysis, etc. + * + * @return true if the statement should be split, false otherwise. Default implementation returns + * false. + */ + public boolean shouldSplit() { + return false; + } + + /** + * Splits the current statement into multiple sub-statements. Used to limit resource consumption + * during statement analysis, etc. + * + * @return the list of sub-statements. Default implementation returns an empty list. + */ + public List<Statement> getSubStatements() { + return Collections.emptyList(); + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/LoadTsFileStatement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/LoadTsFileStatement.java index d1dff1bb9cf2..a51dcaf09d2b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/LoadTsFileStatement.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/LoadTsFileStatement.java @@ -306,6 +306,54 @@ public boolean reconstructStatementIfMiniFileConverted(final List isMin return tsFiles == null || tsFiles.isEmpty(); } + @Override + public boolean shouldSplit() { + final int splitThreshold = + IoTDBDescriptor.getInstance().getConfig().getLoadTsFileStatementSplitThreshold(); + return tsFiles.size() > splitThreshold && !isAsyncLoad; + } + + /** + * Splits the current LoadTsFileStatement into multiple sub-statements, each handling a batch of + * TsFiles. Used to limit resource consumption during statement analysis, etc. 
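
As a worked example of the batching arithmetic shared by both getSubStatements implementations (illustrative only, not part of the patch): with 25 TsFiles and the default batch size of 10,

    final int totalBatches = (25 + 10 - 1) / 10; // ceiling division, == 3
    // The sub-statements cover file index ranges [0, 10), [10, 20), and [20, 25).
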
+ * + * @return the list of sub-statements + */ + @Override + public List<Statement> getSubStatements() { + final int batchSize = + IoTDBDescriptor.getInstance().getConfig().getLoadTsFileSubStatementBatchSize(); + final int totalBatches = (tsFiles.size() + batchSize - 1) / batchSize; // Ceiling division + final List<Statement> subStatements = new ArrayList<>(totalBatches); + + for (int i = 0; i < tsFiles.size(); i += batchSize) { + final int endIndex = Math.min(i + batchSize, tsFiles.size()); + final List<File> batchFiles = tsFiles.subList(i, endIndex); + + final LoadTsFileStatement statement = new LoadTsFileStatement(); + statement.databaseLevel = this.databaseLevel; + statement.verifySchema = this.verifySchema; + statement.deleteAfterLoad = this.deleteAfterLoad; + statement.convertOnTypeMismatch = this.convertOnTypeMismatch; + statement.tabletConversionThresholdBytes = this.tabletConversionThresholdBytes; + statement.autoCreateDatabase = this.autoCreateDatabase; + statement.isAsyncLoad = this.isAsyncLoad; + statement.isGeneratedByPipe = this.isGeneratedByPipe; + + statement.tsFiles = new ArrayList<>(batchFiles); + statement.resources = new ArrayList<>(batchFiles.size()); + statement.writePointCountList = new ArrayList<>(batchFiles.size()); + statement.isTableModel = new ArrayList<>(batchFiles.size()); + for (int j = 0; j < batchFiles.size(); j++) { + statement.isTableModel.add(false); + } + + subStatements.add(statement); + } + + return subStatements; + } + @Override + public List<PartialPath> getPaths() { return Collections.emptyList(); From b4604237b032d736f263af62ef39ea8123fbefa3 Mon Sep 17 00:00:00 2001 From: Zhenyu Luo Date: Tue, 9 Dec 2025 15:00:00 +0800 Subject: [PATCH 05/13] Pipe: Modify the TableRawReq deserialization method to support direct conversion to TableStatement. (#16844) * Pipe: Modify the TableRawReq deserialization method to support direct conversion to TableStatement. 
* fix * fix * fix * fix * fix * update * update * update * refactor: optimize TabletStatementConverter according to code review - Optimize times array copy: skip copy when lengths are equal, use System.arraycopy - Add warning logs when times array is null or too small - Ensure all arrays (values, times, bitMaps) are copied to rowSize length for immutability - Filter out null columns when converting Statement to Tablet - Rename idColumnIndices to tagColumnIndices - Add skipString method to avoid constructing temporary objects - Add comments explaining skipped fields in readMeasurement - Use direct buffer position increment instead of reading bytes for skipping - Ensure all column values are copied to ensure immutability * update * update (cherry picked from commit 13b0582dfb1a63a3f242b9545304d0d9fdede5cd) --- .../request/PipeTransferTabletBatchReqV2.java | 7 +- .../request/PipeTransferTabletRawReq.java | 110 +++- .../request/PipeTransferTabletRawReqV2.java | 50 +- .../sink/util/TabletStatementConverter.java | 476 ++++++++++++++ .../util/sorter/InsertEventDataAdapter.java | 127 ++++ .../sorter/InsertTabletStatementAdapter.java | 118 ++++ ...Sorter.java => PipeInsertEventSorter.java} | 94 ++- .../PipeTableModelTabletEventSorter.java | 67 +- .../PipeTreeModelTabletEventSorter.java | 48 +- .../pipe/sink/util/sorter/TabletAdapter.java | 113 ++++ .../statement/crud/InsertBaseStatement.java | 10 + .../statement/crud/InsertTabletStatement.java | 197 ++++++ .../sink/PipeDataNodeThriftRequestTest.java | 4 +- .../sink/PipeStatementEventSorterTest.java | 313 +++++++++ .../util/TabletStatementConverterTest.java | 607 ++++++++++++++++++ 15 files changed, 2257 insertions(+), 84 deletions(-) create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverter.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertEventDataAdapter.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertTabletStatementAdapter.java rename iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/{PipeTabletEventSorter.java => PipeInsertEventSorter.java} (65%) create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/TabletAdapter.java create mode 100644 iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeStatementEventSorterTest.java create mode 100644 iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverterTest.java diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletBatchReqV2.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletBatchReqV2.java index d9c3fabfae8a..c136ffbe7d3e 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletBatchReqV2.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletBatchReqV2.java @@ -33,7 +33,6 @@ import org.apache.tsfile.utils.PublicBAOS; import org.apache.tsfile.utils.ReadWriteIOUtils; -import org.apache.tsfile.write.record.Tablet; import java.io.DataOutputStream; import java.io.IOException; @@ -247,11 +246,7 @@ public static PipeTransferTabletBatchReqV2 fromTPipeTransferReq( size = ReadWriteIOUtils.readInt(transferReq.body); for (int i = 0; i < size; ++i) { - batchReq.tabletReqs.add( - 
PipeTransferTabletRawReqV2.toTPipeTransferRawReq( - Tablet.deserialize(transferReq.body), - ReadWriteIOUtils.readBool(transferReq.body), - ReadWriteIOUtils.readString(transferReq.body))); + batchReq.tabletReqs.add(PipeTransferTabletRawReqV2.toTPipeTransferRawReq(transferReq.body)); } batchReq.version = transferReq.version; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReq.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReq.java index 7da44f297d14..af9b37edbf6f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReq.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReq.java @@ -22,6 +22,7 @@ import org.apache.iotdb.commons.exception.MetadataException; import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.IoTDBSinkRequestVersion; import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.PipeRequestType; +import org.apache.iotdb.db.pipe.sink.util.TabletStatementConverter; import org.apache.iotdb.db.pipe.sink.util.sorter.PipeTreeModelTabletEventSorter; import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; import org.apache.iotdb.service.rpc.thrift.TPipeTransferReq; @@ -43,10 +44,25 @@ public class PipeTransferTabletRawReq extends TPipeTransferReq { private static final Logger LOGGER = LoggerFactory.getLogger(PipeTransferTabletRawReq.class); - protected transient Tablet tablet; + protected transient InsertTabletStatement statement; + protected transient boolean isAligned; + protected transient Tablet tablet; + /** + * Get Tablet. If tablet is null, convert from statement. + * + * @return Tablet object + */ public Tablet getTablet() { + if (tablet == null && statement != null) { + try { + tablet = statement.convertToTablet(); + } catch (final MetadataException e) { + LOGGER.warn("Failed to convert statement to tablet.", e); + return null; + } + } return tablet; } @@ -54,16 +70,29 @@ public boolean getIsAligned() { return isAligned; } + /** + * Construct Statement. If statement already exists, return it. Otherwise, convert from tablet. 
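
A sketch of how this dual representation is meant to be consumed on the receiving side; the method names are from this patch, while `transferReq` is an assumed incoming request:

    final PipeTransferTabletRawReq req =
        PipeTransferTabletRawReq.fromTPipeTransferReq(transferReq);
    final Tablet tablet = req.getTablet();                 // statement -> tablet, on demand
    final InsertTabletStatement stmt = req.constructStatement(); // cached, or tablet -> statement
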
+ * + * @return InsertTabletStatement + */ public InsertTabletStatement constructStatement() { + if (statement != null) { + new PipeTreeModelTabletEventSorter(statement).deduplicateAndSortTimestampsIfNecessary(); + return statement; + } + + // Sort and deduplicate tablet before converting new PipeTreeModelTabletEventSorter(tablet).deduplicateAndSortTimestampsIfNecessary(); try { if (isTabletEmpty(tablet)) { // Empty statement, will be filtered after construction - return new InsertTabletStatement(); + statement = new InsertTabletStatement(); + return statement; } - return new InsertTabletStatement(tablet, isAligned, null); + statement = new InsertTabletStatement(tablet, isAligned, null); + return statement; } catch (final MetadataException e) { LOGGER.warn("Generate Statement from tablet {} error.", tablet, e); return null; @@ -107,8 +136,20 @@ public static PipeTransferTabletRawReq toTPipeTransferReq( public static PipeTransferTabletRawReq fromTPipeTransferReq(final TPipeTransferReq transferReq) { final PipeTransferTabletRawReq tabletReq = new PipeTransferTabletRawReq(); - tabletReq.tablet = Tablet.deserialize(transferReq.body); - tabletReq.isAligned = ReadWriteIOUtils.readBool(transferReq.body); + final ByteBuffer buffer = transferReq.body; + final int startPosition = buffer.position(); + try { + // V1: no databaseName, readDatabaseName = false + final InsertTabletStatement insertTabletStatement = + TabletStatementConverter.deserializeStatementFromTabletFormat(buffer, false); + tabletReq.isAligned = insertTabletStatement.isAligned(); + // devicePath is already set in deserializeStatementFromTabletFormat for V1 format + tabletReq.statement = insertTabletStatement; + } catch (final Exception e) { + buffer.position(startPosition); + tabletReq.tablet = Tablet.deserialize(buffer); + tabletReq.isAligned = ReadWriteIOUtils.readBool(buffer); + } tabletReq.version = transferReq.version; tabletReq.type = transferReq.type; @@ -118,18 +159,56 @@ public static PipeTransferTabletRawReq fromTPipeTransferReq(final TPipeTransferR /////////////////////////////// Air Gap /////////////////////////////// - public static byte[] toTPipeTransferBytes(final Tablet tablet, final boolean isAligned) - throws IOException { + /** + * Serialize to bytes. If tablet is null, convert from statement first. 
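
For reference, the byte layout this method produces, per the stream writes in its body (field widths follow the ReadWriteIOUtils conventions):

    // [version][TRANSFER_TABLET_RAW type code][Tablet.serialize(...) payload][isAligned]
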
+ * + * @return serialized bytes + * @throws IOException if serialization fails + */ + public byte[] toTPipeTransferBytes() throws IOException { + Tablet tabletToSerialize = tablet; + boolean isAlignedToSerialize = isAligned; + + // If tablet is null, convert from statement + if (tabletToSerialize == null && statement != null) { + try { + tabletToSerialize = statement.convertToTablet(); + isAlignedToSerialize = statement.isAligned(); + } catch (final MetadataException e) { + throw new IOException("Failed to convert statement to tablet for serialization", e); + } + } + + if (tabletToSerialize == null) { + throw new IOException("Cannot serialize: both tablet and statement are null"); + } + try (final PublicBAOS byteArrayOutputStream = new PublicBAOS(); final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream)) { ReadWriteIOUtils.write(IoTDBSinkRequestVersion.VERSION_1.getVersion(), outputStream); ReadWriteIOUtils.write(PipeRequestType.TRANSFER_TABLET_RAW.getType(), outputStream); - tablet.serialize(outputStream); - ReadWriteIOUtils.write(isAligned, outputStream); + tabletToSerialize.serialize(outputStream); + ReadWriteIOUtils.write(isAlignedToSerialize, outputStream); return byteArrayOutputStream.toByteArray(); } } + /** + * Static method for backward compatibility. Creates a temporary instance and serializes. + * + * @param tablet Tablet to serialize + * @param isAligned whether aligned + * @return serialized bytes + * @throws IOException if serialization fails + */ + public static byte[] toTPipeTransferBytes(final Tablet tablet, final boolean isAligned) + throws IOException { + final PipeTransferTabletRawReq req = new PipeTransferTabletRawReq(); + req.tablet = tablet; + req.isAligned = isAligned; + return req.toTPipeTransferBytes(); + } + /////////////////////////////// Object /////////////////////////////// @Override @@ -141,7 +220,16 @@ public boolean equals(final Object obj) { return false; } final PipeTransferTabletRawReq that = (PipeTransferTabletRawReq) obj; - return Objects.equals(tablet, that.tablet) + // Compare statement if both have it, otherwise compare tablet + if (statement != null && that.statement != null) { + return Objects.equals(statement, that.statement) + && isAligned == that.isAligned + && version == that.version + && type == that.type + && Objects.equals(body, that.body); + } + // Fallback to tablet comparison + return Objects.equals(getTablet(), that.getTablet()) && isAligned == that.isAligned && version == that.version && type == that.type @@ -150,6 +238,6 @@ public boolean equals(final Object obj) { @Override public int hashCode() { - return Objects.hash(tablet, isAligned, version, type, body); + return Objects.hash(getTablet(), isAligned, version, type, body); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReqV2.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReqV2.java index 43d8501252c5..3c5f420a317f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReqV2.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/payload/evolvable/request/PipeTransferTabletRawReqV2.java @@ -22,6 +22,7 @@ import org.apache.iotdb.commons.exception.MetadataException; import org.apache.iotdb.commons.pipe.sink.payload.thrift.request.IoTDBSinkRequestVersion; import 
org.apache.iotdb.commons.pipe.sink.payload.thrift.request.PipeRequestType; +import org.apache.iotdb.db.pipe.sink.util.TabletStatementConverter; import org.apache.iotdb.db.pipe.sink.util.sorter.PipeTableModelTabletEventSorter; import org.apache.iotdb.db.pipe.sink.util.sorter.PipeTreeModelTabletEventSorter; import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; @@ -52,6 +53,16 @@ public String getDataBaseName() { @Override public InsertTabletStatement constructStatement() { + if (statement != null) { + if (Objects.isNull(dataBaseName)) { + new PipeTreeModelTabletEventSorter(statement).deduplicateAndSortTimestampsIfNecessary(); + } else { + new PipeTableModelTabletEventSorter(statement).sortByTimestampIfNecessary(); + } + + return statement; + } + if (Objects.isNull(dataBaseName)) { new PipeTreeModelTabletEventSorter(tablet).deduplicateAndSortTimestampsIfNecessary(); } else { @@ -86,6 +97,16 @@ public static PipeTransferTabletRawReqV2 toTPipeTransferRawReq( return tabletReq; } + public static PipeTransferTabletRawReqV2 toTPipeTransferRawReq(final ByteBuffer buffer) { + final PipeTransferTabletRawReqV2 tabletReq = new PipeTransferTabletRawReqV2(); + + tabletReq.deserializeTPipeTransferRawReq(buffer); + tabletReq.version = IoTDBSinkRequestVersion.VERSION_1.getVersion(); + tabletReq.type = PipeRequestType.TRANSFER_TABLET_RAW_V2.getType(); + + return tabletReq; + } + /////////////////////////////// Thrift /////////////////////////////// public static PipeTransferTabletRawReqV2 toTPipeTransferReq( @@ -114,13 +135,11 @@ public static PipeTransferTabletRawReqV2 fromTPipeTransferReq( final TPipeTransferReq transferReq) { final PipeTransferTabletRawReqV2 tabletReq = new PipeTransferTabletRawReqV2(); - tabletReq.tablet = Tablet.deserialize(transferReq.body); - tabletReq.isAligned = ReadWriteIOUtils.readBool(transferReq.body); - tabletReq.dataBaseName = ReadWriteIOUtils.readString(transferReq.body); + tabletReq.deserializeTPipeTransferRawReq(transferReq.body); + tabletReq.body = transferReq.body; tabletReq.version = transferReq.version; tabletReq.type = transferReq.type; - tabletReq.body = transferReq.body; return tabletReq; } @@ -161,4 +180,27 @@ public boolean equals(Object o) { public int hashCode() { return Objects.hash(super.hashCode(), dataBaseName); } + + /////////////////////////////// Util /////////////////////////////// + + public void deserializeTPipeTransferRawReq(final ByteBuffer buffer) { + final int startPosition = buffer.position(); + try { + // V2: read databaseName, readDatabaseName = true + final InsertTabletStatement insertTabletStatement = + TabletStatementConverter.deserializeStatementFromTabletFormat(buffer, true); + this.isAligned = insertTabletStatement.isAligned(); + // databaseName is already set in deserializeStatementFromTabletFormat when + // readDatabaseName=true + this.dataBaseName = insertTabletStatement.getDatabaseName().orElse(null); + this.statement = insertTabletStatement; + } catch (final Exception e) { + // If Statement deserialization fails, fallback to Tablet format + // Reset buffer position for Tablet deserialization + buffer.position(startPosition); + this.tablet = Tablet.deserialize(buffer); + this.isAligned = ReadWriteIOUtils.readBool(buffer); + this.dataBaseName = ReadWriteIOUtils.readString(buffer); + } + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverter.java new file 
mode 100644 index 000000000000..c5b9ebed4d5f --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverter.java @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.pipe.sink.util; + +import org.apache.iotdb.commons.exception.IllegalPathException; +import org.apache.iotdb.commons.path.PartialPath; +import org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory; +import org.apache.iotdb.db.pipe.resource.memory.InsertNodeMemoryEstimator; +import org.apache.iotdb.db.queryengine.plan.analyze.cache.schema.DataNodeDevicePathCache; +import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; + +import org.apache.tsfile.enums.ColumnCategory; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.utils.Binary; +import org.apache.tsfile.utils.BitMap; +import org.apache.tsfile.utils.BytesUtils; +import org.apache.tsfile.utils.Pair; +import org.apache.tsfile.utils.ReadWriteIOUtils; +import org.apache.tsfile.write.UnSupportedDataTypeException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.nio.ByteBuffer; +import java.util.Arrays; + +/** + * Utility class for converting between InsertTabletStatement and Tablet format ByteBuffer. This + * avoids creating intermediate Tablet objects and directly converts between formats with only the + * fields needed. + */ +public class TabletStatementConverter { + + private static final Logger LOGGER = LoggerFactory.getLogger(TabletStatementConverter.class); + + // Memory calculation constants - extracted from RamUsageEstimator for better performance + private static final long NUM_BYTES_ARRAY_HEADER = + org.apache.tsfile.utils.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; + private static final long NUM_BYTES_OBJECT_REF = + org.apache.tsfile.utils.RamUsageEstimator.NUM_BYTES_OBJECT_REF; + private static final long NUM_BYTES_OBJECT_HEADER = + org.apache.tsfile.utils.RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; + private static final long SIZE_OF_ARRAYLIST = + org.apache.tsfile.utils.RamUsageEstimator.shallowSizeOfInstance(java.util.ArrayList.class); + private static final long SIZE_OF_BITMAP = + org.apache.tsfile.utils.RamUsageEstimator.shallowSizeOfInstance( + org.apache.tsfile.utils.BitMap.class); + + private TabletStatementConverter() { + // Utility class, no instantiation + } + + /** + * Deserialize InsertTabletStatement from Tablet format ByteBuffer. 
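
A hypothetical caller-side sketch of the rewind-and-fallback pattern this converter enables, mirroring the deserialization logic later in this patch (`buffer` is assumed to hold a received payload):

    final int start = buffer.position();
    try {
      // Preferred path: build the statement directly, with no intermediate Tablet
      final InsertTabletStatement st =
          TabletStatementConverter.deserializeStatementFromTabletFormat(buffer, true);
    } catch (final Exception e) {
      buffer.position(start); // rewind and fall back to the legacy Tablet format
      final Tablet tablet = Tablet.deserialize(buffer);
    }
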
+ * + * @param byteBuffer ByteBuffer containing serialized data + * @param readDatabaseName whether to read databaseName from buffer (for V2 format) + * @return InsertTabletStatement with all fields set, including devicePath + */ + public static InsertTabletStatement deserializeStatementFromTabletFormat( + final ByteBuffer byteBuffer, final boolean readDatabaseName) throws IllegalPathException { + final InsertTabletStatement statement = new InsertTabletStatement(); + + // Calculate memory size during deserialization, use INSTANCE_SIZE constant + long memorySize = InsertTabletStatement.getInstanceSize(); + + final String insertTargetName = ReadWriteIOUtils.readString(byteBuffer); + + final int rowSize = ReadWriteIOUtils.readInt(byteBuffer); + + // deserialize schemas + final int schemaSize = + BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)) + ? ReadWriteIOUtils.readInt(byteBuffer) + : 0; + final String[] measurement = new String[schemaSize]; + final TsTableColumnCategory[] columnCategories = new TsTableColumnCategory[schemaSize]; + final TSDataType[] dataTypes = new TSDataType[schemaSize]; + + // Calculate memory for arrays headers and references during deserialization + // measurements array: array header + object references + long measurementMemorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize); + + // dataTypes array: shallow size (array header + references) + long dataTypesMemorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize); + + // columnCategories array: shallow size (array header + references) + long columnCategoriesMemorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize); + + // tagColumnIndices (TAG columns): ArrayList base + array header + long tagColumnIndicesSize = SIZE_OF_ARRAYLIST; + tagColumnIndicesSize += NUM_BYTES_ARRAY_HEADER; + + // Deserialize and calculate memory in the same loop + for (int i = 0; i < schemaSize; i++) { + final boolean hasSchema = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (hasSchema) { + final Pair pair = readMeasurement(byteBuffer); + measurement[i] = pair.getLeft(); + dataTypes[i] = pair.getRight(); + columnCategories[i] = + TsTableColumnCategory.fromTsFileColumnCategory( + ColumnCategory.values()[byteBuffer.get()]); + + // Calculate memory for each measurement string + if (measurement[i] != null) { + measurementMemorySize += org.apache.tsfile.utils.RamUsageEstimator.sizeOf(measurement[i]); + } + + // Calculate memory for TAG column indices + if (columnCategories[i] != null && columnCategories[i].equals(TsTableColumnCategory.TAG)) { + tagColumnIndicesSize += + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + Integer.BYTES + NUM_BYTES_OBJECT_HEADER) + + NUM_BYTES_OBJECT_REF; + } + } + } + + // Add all calculated memory to total + memorySize += measurementMemorySize; + memorySize += dataTypesMemorySize; + + // deserialize times and calculate memory during deserialization + final long[] times = new long[rowSize]; + // Calculate memory: array header + long size * rowSize + final long timesMemorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + (long) Long.BYTES * rowSize); + + final boolean isTimesNotNull = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (isTimesNotNull) { + for (int i = 0; i < rowSize; 
i++) { + times[i] = ReadWriteIOUtils.readLong(byteBuffer); + } + } + + // Add times memory to total + memorySize += timesMemorySize; + + // deserialize bitmaps and calculate memory during deserialization + final BitMap[] bitMaps; + final long bitMapsMemorySize; + + final boolean isBitMapsNotNull = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (isBitMapsNotNull) { + // Use the method that returns both BitMap array and memory size + final Pair bitMapsAndMemory = + readBitMapsFromBufferWithMemory(byteBuffer, schemaSize); + bitMaps = bitMapsAndMemory.getLeft(); + bitMapsMemorySize = bitMapsAndMemory.getRight(); + } else { + // Calculate memory for empty BitMap array: array header + references + bitMaps = new BitMap[schemaSize]; + bitMapsMemorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize); + } + + // Add bitMaps memory to total + memorySize += bitMapsMemorySize; + + // Deserialize values and calculate memory during deserialization + final Object[] values; + final long valuesMemorySize; + + final boolean isValuesNotNull = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (isValuesNotNull) { + // Use the method that returns both values array and memory size + final Pair valuesAndMemory = + readValuesFromBufferWithMemory(byteBuffer, dataTypes, schemaSize, rowSize); + values = valuesAndMemory.getLeft(); + valuesMemorySize = valuesAndMemory.getRight(); + } else { + // Calculate memory for empty values array: array header + references + values = new Object[schemaSize]; + valuesMemorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * schemaSize); + } + + // Add values memory to total + memorySize += valuesMemorySize; + + final boolean isAligned = ReadWriteIOUtils.readBoolean(byteBuffer); + + statement.setMeasurements(measurement); + statement.setTimes(times); + statement.setBitMaps(bitMaps); + statement.setDataTypes(dataTypes); + statement.setColumns(values); + statement.setRowCount(rowSize); + statement.setAligned(isAligned); + + // Read databaseName if requested (V2 format) + if (readDatabaseName) { + final String databaseName = ReadWriteIOUtils.readString(byteBuffer); + if (databaseName != null) { + statement.setDatabaseName(databaseName); + statement.setWriteToTable(true); + // For table model, insertTargetName is table name, convert to lowercase + statement.setDevicePath(new PartialPath(insertTargetName.toLowerCase(), false)); + // Calculate memory for databaseName + memorySize += org.apache.tsfile.utils.RamUsageEstimator.sizeOf(databaseName); + + statement.setColumnCategories(columnCategories); + + memorySize += columnCategoriesMemorySize; + memorySize += tagColumnIndicesSize; + } else { + // For tree model, use DataNodeDevicePathCache + statement.setDevicePath( + DataNodeDevicePathCache.getInstance().getPartialPath(insertTargetName)); + statement.setColumnCategories(null); + } + } else { + // V1 format: no databaseName in buffer, always use DataNodeDevicePathCache + statement.setDevicePath( + DataNodeDevicePathCache.getInstance().getPartialPath(insertTargetName)); + statement.setColumnCategories(null); + } + + // Calculate memory for devicePath + memorySize += InsertNodeMemoryEstimator.sizeOfPartialPath(statement.getDevicePath()); + + // Set the pre-calculated memory size to avoid recalculation + statement.setRamBytesUsed(memorySize); + + return statement; + } + + /** + * Deserialize InsertTabletStatement 
from Tablet format ByteBuffer (V1 format, no databaseName). + * + * @param byteBuffer ByteBuffer containing serialized data + * @return InsertTabletStatement with devicePath set using DataNodeDevicePathCache + */ + public static InsertTabletStatement deserializeStatementFromTabletFormat( + final ByteBuffer byteBuffer) throws IllegalPathException { + return deserializeStatementFromTabletFormat(byteBuffer, false); + } + + /** + * Skip a string in ByteBuffer without reading it. This is more efficient than reading and + * discarding the string. + * + * @param buffer ByteBuffer to skip string from + */ + private static void skipString(final ByteBuffer buffer) { + final int size = ReadWriteIOUtils.readInt(buffer); + if (size > 0) { + buffer.position(buffer.position() + size); + } + } + + /** + * Read measurement name and data type from buffer, skipping other measurement schema fields + * (encoding, compression, and tags/attributes) that are not needed for InsertTabletStatement. + * + * @param buffer ByteBuffer containing serialized measurement schema + * @return Pair of measurement name and data type + */ + private static Pair readMeasurement(final ByteBuffer buffer) { + // Read measurement name and data type + final Pair pair = + new Pair<>(ReadWriteIOUtils.readString(buffer), TSDataType.deserializeFrom(buffer)); + + // Skip encoding type (byte) and compression type (byte) - 2 bytes total + buffer.position(buffer.position() + 2); + + // Skip props map (Map) + final int size = ReadWriteIOUtils.readInt(buffer); + if (size > 0) { + for (int i = 0; i < size; i++) { + // Skip key (String) and value (String) without constructing temporary objects + skipString(buffer); + skipString(buffer); + } + } + + return pair; + } + + /** + * Deserialize bitmaps and calculate memory size during deserialization. Returns a Pair of BitMap + * array and the calculated memory size. + */ + private static Pair readBitMapsFromBufferWithMemory( + final ByteBuffer byteBuffer, final int columns) { + final BitMap[] bitMaps = new BitMap[columns]; + + // Calculate memory: array header + object references + long memorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * columns); + + for (int i = 0; i < columns; i++) { + final boolean hasBitMap = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (hasBitMap) { + final int size = ReadWriteIOUtils.readInt(byteBuffer); + final Binary valueBinary = ReadWriteIOUtils.readBinary(byteBuffer); + final byte[] byteArray = valueBinary.getValues(); + bitMaps[i] = new BitMap(size, byteArray); + + // Calculate memory for this BitMap: BitMap object + byte array + // BitMap shallow size + byte array (array header + array length) + memorySize += + SIZE_OF_BITMAP + + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + byteArray.length); + } + } + + return new Pair<>(bitMaps, memorySize); + } + + /** + * Deserialize values from buffer and calculate memory size during deserialization. Returns a Pair + * of values array and the calculated memory size. 
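
The memory accounting below always follows the same shape; for instance, a single long[] value column is charged as follows (constants and helpers as used in this file, with an assumed example row count):

    final int rowSize = 1024; // example row count
    final long colBytes =
        org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize(
            org.apache.tsfile.utils.RamUsageEstimator.NUM_BYTES_ARRAY_HEADER
                + (long) Long.BYTES * rowSize);
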
+ * + * @param byteBuffer data values + * @param types data types + * @param columns column number + * @param rowSize row number + * @return Pair of values array and memory size + */ + @SuppressWarnings("squid:S3776") // Suppress high Cognitive Complexity warning + private static Pair readValuesFromBufferWithMemory( + final ByteBuffer byteBuffer, final TSDataType[] types, final int columns, final int rowSize) { + final Object[] values = new Object[columns]; + + // Calculate memory: array header + object references + long memorySize = + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * columns); + + for (int i = 0; i < columns; i++) { + final boolean isValueColumnsNotNull = + BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (isValueColumnsNotNull && types[i] == null) { + continue; + } + + switch (types[i]) { + case BOOLEAN: + final boolean[] boolValues = new boolean[rowSize]; + if (isValueColumnsNotNull) { + for (int index = 0; index < rowSize; index++) { + boolValues[index] = BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + } + } + values[i] = boolValues; + // Calculate memory for boolean array: array header + 1 byte per element (aligned) + memorySize += + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + rowSize); + break; + case INT32: + case DATE: + final int[] intValues = new int[rowSize]; + if (isValueColumnsNotNull) { + for (int index = 0; index < rowSize; index++) { + intValues[index] = ReadWriteIOUtils.readInt(byteBuffer); + } + } + values[i] = intValues; + // Calculate memory for int array: array header + 4 bytes per element (aligned) + memorySize += + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + (long) Integer.BYTES * rowSize); + break; + case INT64: + case TIMESTAMP: + final long[] longValues = new long[rowSize]; + if (isValueColumnsNotNull) { + for (int index = 0; index < rowSize; index++) { + longValues[index] = ReadWriteIOUtils.readLong(byteBuffer); + } + } + values[i] = longValues; + // Calculate memory for long array: array header + 8 bytes per element (aligned) + memorySize += + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + (long) Long.BYTES * rowSize); + break; + case FLOAT: + final float[] floatValues = new float[rowSize]; + if (isValueColumnsNotNull) { + for (int index = 0; index < rowSize; index++) { + floatValues[index] = ReadWriteIOUtils.readFloat(byteBuffer); + } + } + values[i] = floatValues; + // Calculate memory for float array: array header + 4 bytes per element (aligned) + memorySize += + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + (long) Float.BYTES * rowSize); + break; + case DOUBLE: + final double[] doubleValues = new double[rowSize]; + if (isValueColumnsNotNull) { + for (int index = 0; index < rowSize; index++) { + doubleValues[index] = ReadWriteIOUtils.readDouble(byteBuffer); + } + } + values[i] = doubleValues; + // Calculate memory for double array: array header + 8 bytes per element (aligned) + memorySize += + org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + (long) Double.BYTES * rowSize); + break; + case TEXT: + case STRING: + case BLOB: + case OBJECT: + // Handle object array type: Binary[] is an array of objects + final Binary[] binaryValues = new Binary[rowSize]; + // Calculate memory for Binary array: array header + object references + long binaryArrayMemory = + 
org.apache.tsfile.utils.RamUsageEstimator.alignObjectSize( + NUM_BYTES_ARRAY_HEADER + NUM_BYTES_OBJECT_REF * rowSize); + + if (isValueColumnsNotNull) { + for (int index = 0; index < rowSize; index++) { + final boolean isNotNull = + BytesUtils.byteToBool(ReadWriteIOUtils.readByte(byteBuffer)); + if (isNotNull) { + binaryValues[index] = ReadWriteIOUtils.readBinary(byteBuffer); + // Calculate memory for each Binary object during deserialization + binaryArrayMemory += binaryValues[index].ramBytesUsed(); + } else { + binaryValues[index] = Binary.EMPTY_VALUE; + // EMPTY_VALUE also has memory cost + binaryArrayMemory += Binary.EMPTY_VALUE.ramBytesUsed(); + } + } + } else { + Arrays.fill(binaryValues, Binary.EMPTY_VALUE); + // Calculate memory for all EMPTY_VALUE + binaryArrayMemory += (long) rowSize * Binary.EMPTY_VALUE.ramBytesUsed(); + } + values[i] = binaryValues; + // Add calculated Binary array memory to total + memorySize += binaryArrayMemory; + break; + default: + throw new UnSupportedDataTypeException( + String.format("data type %s is not supported when convert data at client", types[i])); + } + } + + return new Pair<>(values, memorySize); + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertEventDataAdapter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertEventDataAdapter.java new file mode 100644 index 000000000000..62111cb4129c --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertEventDataAdapter.java @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.pipe.sink.util.sorter; + +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.IDeviceID; +import org.apache.tsfile.utils.BitMap; + +/** + * Adapter interface to encapsulate common operations needed for sorting and deduplication. This + * interface allows the sorter to work with both Tablet and InsertTabletStatement. + */ +public interface InsertEventDataAdapter { + + /** + * Get the number of columns. + * + * @return number of columns + */ + int getColumnCount(); + + /** + * Get data type for a specific column. + * + * @param columnIndex column index + * @return data type of the column + */ + TSDataType getDataType(int columnIndex); + + /** + * Get bit maps for null values. + * + * @return array of bit maps, may be null + */ + BitMap[] getBitMaps(); + + /** + * Set bit maps for null values. + * + * @param bitMaps array of bit maps + */ + void setBitMaps(BitMap[] bitMaps); + + /** + * Get value arrays for all columns. + * + * @return array of value arrays (Object[]) + */ + Object[] getValues(); + + /** + * Set value array for a specific column. 
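
A sketch of the payoff of this adapter interface: one sorter hierarchy serves both carriers. This assumes PipeTreeModelTabletEventSorter exposes the same pair of constructors that its table-model sibling gains in this patch, and that `tablet` and `statement` variables exist:

    final PipeInsertEventSorter fromTablet = new PipeTreeModelTabletEventSorter(tablet);
    final PipeInsertEventSorter fromStatement = new PipeTreeModelTabletEventSorter(statement);
    fromTablet.deduplicateAndSortTimestampsIfNecessary();
    fromStatement.deduplicateAndSortTimestampsIfNecessary();
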
+ * + * @param columnIndex column index + * @param value value array + */ + void setValue(int columnIndex, Object value); + + /** + * Get timestamps array. + * + * @return array of timestamps + */ + long[] getTimestamps(); + + /** + * Set timestamps array. + * + * @param timestamps array of timestamps + */ + void setTimestamps(long[] timestamps); + + /** + * Get row size/count. + * + * @return number of rows + */ + int getRowSize(); + + /** + * Set row size/count. + * + * @param rowSize number of rows + */ + void setRowSize(int rowSize); + + /** + * Get timestamp at a specific row index. + * + * @param rowIndex row index + * @return timestamp value + */ + long getTimestamp(int rowIndex); + + /** + * Get device ID at a specific row index (for table model). + * + * @param rowIndex row index + * @return device ID + */ + IDeviceID getDeviceID(int rowIndex); + + /** + * Check if the DATE type column value is stored as LocalDate[] (Tablet) or int[] (Statement). + * + * @param columnIndex column index + * @return true if DATE type is stored as LocalDate[], false if stored as int[] + */ + boolean isDateStoredAsLocalDate(int columnIndex); +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertTabletStatementAdapter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertTabletStatementAdapter.java new file mode 100644 index 000000000000..30f74a22965a --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/InsertTabletStatementAdapter.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.pipe.sink.util.sorter; + +import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; + +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.IDeviceID; +import org.apache.tsfile.utils.BitMap; + +/** Adapter for InsertTabletStatement to implement InsertEventDataAdapter interface. */ +public class InsertTabletStatementAdapter implements InsertEventDataAdapter { + + private final InsertTabletStatement statement; + + public InsertTabletStatementAdapter(final InsertTabletStatement statement) { + this.statement = statement; + } + + @Override + public int getColumnCount() { + final Object[] columns = statement.getColumns(); + return columns != null ? 
columns.length : 0; + } + + @Override + public TSDataType getDataType(int columnIndex) { + final TSDataType[] dataTypes = statement.getDataTypes(); + if (dataTypes != null && columnIndex < dataTypes.length) { + return dataTypes[columnIndex]; + } + return null; + } + + @Override + public BitMap[] getBitMaps() { + return statement.getBitMaps(); + } + + @Override + public void setBitMaps(BitMap[] bitMaps) { + statement.setBitMaps(bitMaps); + } + + @Override + public Object[] getValues() { + return statement.getColumns(); + } + + @Override + public void setValue(int columnIndex, Object value) { + Object[] columns = statement.getColumns(); + if (columns != null && columnIndex < columns.length) { + columns[columnIndex] = value; + } + } + + @Override + public long[] getTimestamps() { + return statement.getTimes(); + } + + @Override + public void setTimestamps(long[] timestamps) { + statement.setTimes(timestamps); + } + + @Override + public int getRowSize() { + return statement.getRowCount(); + } + + @Override + public void setRowSize(int rowSize) { + statement.setRowCount(rowSize); + } + + @Override + public long getTimestamp(int rowIndex) { + long[] times = statement.getTimes(); + if (times != null && rowIndex < times.length) { + return times[rowIndex]; + } + return 0; + } + + @Override + public IDeviceID getDeviceID(int rowIndex) { + return statement.getTableDeviceID(rowIndex); + } + + @Override + public boolean isDateStoredAsLocalDate(int columnIndex) { + // InsertTabletStatement stores DATE as int[] + return false; + } + + public InsertTabletStatement getStatement() { + return statement; + } +} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTabletEventSorter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeInsertEventSorter.java similarity index 65% rename from iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTabletEventSorter.java rename to iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeInsertEventSorter.java index 6540bf9855d6..46a3fc6df942 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTabletEventSorter.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeInsertEventSorter.java @@ -19,19 +19,20 @@ package org.apache.iotdb.db.pipe.sink.util.sorter; +import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; + import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.utils.Binary; import org.apache.tsfile.utils.BitMap; import org.apache.tsfile.write.UnSupportedDataTypeException; import org.apache.tsfile.write.record.Tablet; -import org.apache.tsfile.write.schema.IMeasurementSchema; import java.time.LocalDate; import java.util.Objects; -public class PipeTabletEventSorter { +public class PipeInsertEventSorter { - protected final Tablet tablet; + protected final InsertEventDataAdapter dataAdapter; protected Integer[] index; protected boolean isSorted = true; @@ -39,8 +40,31 @@ public class PipeTabletEventSorter { protected int[] deDuplicatedIndex; protected int deDuplicatedSize; - public PipeTabletEventSorter(final Tablet tablet) { - this.tablet = tablet; + /** + * Constructor for Tablet. + * + * @param tablet the tablet to sort + */ + public PipeInsertEventSorter(final Tablet tablet) { + this.dataAdapter = new TabletAdapter(tablet); + } + + /** + * Constructor for InsertTabletStatement. 
+ * + * @param statement the insert tablet statement to sort + */ + public PipeInsertEventSorter(final InsertTabletStatement statement) { + this.dataAdapter = new InsertTabletStatementAdapter(statement); + } + + /** + * Constructor with adapter (for internal use or advanced scenarios). + * + * @param adapter the data adapter + */ + protected PipeInsertEventSorter(final InsertEventDataAdapter adapter) { + this.dataAdapter = adapter; } // Input: @@ -54,35 +78,42 @@ public PipeTabletEventSorter(final Tablet tablet) { // (Used index: [2(3), 4(0)]) // Col: [6, 1] protected void sortAndMayDeduplicateValuesAndBitMaps() { - int columnIndex = 0; - for (int i = 0, size = tablet.getSchemas().size(); i < size; i++) { - final IMeasurementSchema schema = tablet.getSchemas().get(i); - if (schema != null) { + final int columnCount = dataAdapter.getColumnCount(); + BitMap[] bitMaps = dataAdapter.getBitMaps(); + boolean bitMapsModified = false; + + for (int columnIndex = 0; columnIndex < columnCount; columnIndex++) { + final TSDataType dataType = dataAdapter.getDataType(columnIndex); + if (dataType != null) { BitMap deDuplicatedBitMap = null; BitMap originalBitMap = null; - if (tablet.getBitMaps() != null && tablet.getBitMaps()[columnIndex] != null) { - originalBitMap = tablet.getBitMaps()[columnIndex]; + if (bitMaps != null && columnIndex < bitMaps.length && bitMaps[columnIndex] != null) { + originalBitMap = bitMaps[columnIndex]; deDuplicatedBitMap = new BitMap(originalBitMap.getSize()); } - tablet.getValues()[columnIndex] = + final Object[] values = dataAdapter.getValues(); + final Object reorderedValue = reorderValueListAndBitMap( - tablet.getValues()[columnIndex], - schema.getType(), - originalBitMap, - deDuplicatedBitMap); + values[columnIndex], dataType, columnIndex, originalBitMap, deDuplicatedBitMap); + dataAdapter.setValue(columnIndex, reorderedValue); - if (tablet.getBitMaps() != null && tablet.getBitMaps()[columnIndex] != null) { - tablet.getBitMaps()[columnIndex] = deDuplicatedBitMap; + if (bitMaps != null && columnIndex < bitMaps.length && bitMaps[columnIndex] != null) { + bitMaps[columnIndex] = deDuplicatedBitMap; + bitMapsModified = true; } - columnIndex++; } } + + if (bitMapsModified) { + dataAdapter.setBitMaps(bitMaps); + } } protected Object reorderValueListAndBitMap( final Object valueList, final TSDataType dataType, + final int columnIndex, final BitMap originalBitMap, final BitMap deDuplicatedBitMap) { // Older version's sender may contain null values, we need to cover this case @@ -107,13 +138,26 @@ protected Object reorderValueListAndBitMap( } return deDuplicatedIntValues; case DATE: - final LocalDate[] dateValues = (LocalDate[]) valueList; - final LocalDate[] deDuplicatedDateValues = new LocalDate[dateValues.length]; - for (int i = 0; i < deDuplicatedSize; i++) { - deDuplicatedDateValues[i] = - dateValues[getLastNonnullIndex(i, originalBitMap, deDuplicatedBitMap)]; + // DATE type: Tablet uses LocalDate[], InsertTabletStatement uses int[] + if (dataAdapter.isDateStoredAsLocalDate(columnIndex)) { + // Tablet: LocalDate[] + final LocalDate[] dateValues = (LocalDate[]) valueList; + final LocalDate[] deDuplicatedDateValues = new LocalDate[dateValues.length]; + for (int i = 0; i < deDuplicatedSize; i++) { + deDuplicatedDateValues[i] = + dateValues[getLastNonnullIndex(i, originalBitMap, deDuplicatedBitMap)]; + } + return deDuplicatedDateValues; + } else { + // InsertTabletStatement: int[] + final int[] intDateValues = (int[]) valueList; + final int[] deDuplicatedIntDateValues = new 
int[intDateValues.length]; + for (int i = 0; i < deDuplicatedSize; i++) { + deDuplicatedIntDateValues[i] = + intDateValues[getLastNonnullIndex(i, originalBitMap, deDuplicatedBitMap)]; + } + return deDuplicatedIntDateValues; } - return deDuplicatedDateValues; case INT64: case TIMESTAMP: final long[] longValues = (long[]) valueList; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTableModelTabletEventSorter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTableModelTabletEventSorter.java index 5735b51c6d02..ba034cae3a31 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTableModelTabletEventSorter.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTableModelTabletEventSorter.java @@ -19,6 +19,8 @@ package org.apache.iotdb.db.pipe.sink.util.sorter; +import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; + import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.file.metadata.IDeviceID; import org.apache.tsfile.utils.Pair; @@ -31,14 +33,29 @@ import java.util.List; import java.util.Map; -public class PipeTableModelTabletEventSorter extends PipeTabletEventSorter { +public class PipeTableModelTabletEventSorter extends PipeInsertEventSorter { private int initIndexSize; + /** + * Constructor for Tablet. + * + * @param tablet the tablet to sort + */ public PipeTableModelTabletEventSorter(final Tablet tablet) { super(tablet); deDuplicatedSize = tablet == null ? 0 : tablet.getRowSize(); } + /** + * Constructor for InsertTabletStatement. + * + * @param statement the insert tablet statement to sort + */ + public PipeTableModelTabletEventSorter(final InsertTabletStatement statement) { + super(statement); + deDuplicatedSize = statement == null ? 0 : statement.getRowCount(); + } + /** * For the sorting and deduplication needs of the table model tablet, it is done according to the * {@link IDeviceID}. For sorting, it is necessary to sort the {@link IDeviceID} first, and then @@ -46,18 +63,19 @@ public PipeTableModelTabletEventSorter(final Tablet tablet) { * the same timestamp in different {@link IDeviceID} will not be processed. 
 */
   public void sortAndDeduplicateByDevIdTimestamp() {
-    if (tablet == null || tablet.getRowSize() < 1) {
+    if (dataAdapter == null || dataAdapter.getRowSize() < 1) {
       return;
     }
 
     final HashMap<IDeviceID, List<Pair<Integer, Integer>>> deviceIDToIndexMap = new HashMap<>();
-    final long[] timestamps = tablet.getTimestamps();
+    final long[] timestamps = dataAdapter.getTimestamps();
+    final int rowSize = dataAdapter.getRowSize();
 
-    IDeviceID lastDevice = tablet.getDeviceID(0);
-    long previousTimestamp = tablet.getTimestamp(0);
+    IDeviceID lastDevice = dataAdapter.getDeviceID(0);
+    long previousTimestamp = dataAdapter.getTimestamp(0);
     int lasIndex = 0;
-    for (int i = 1, size = tablet.getRowSize(); i < size; ++i) {
-      final IDeviceID deviceID = tablet.getDeviceID(i);
+    for (int i = 1; i < rowSize; ++i) {
+      final IDeviceID deviceID = dataAdapter.getDeviceID(i);
       final long currentTimestamp = timestamps[i];
       final int deviceComparison = deviceID.compareTo(lastDevice);
       if (deviceComparison == 0) {
@@ -92,7 +110,7 @@ public void sortAndDeduplicateByDevIdTimestamp() {
     if (!list.isEmpty()) {
       isSorted = false;
     }
-    list.add(new Pair<>(lasIndex, tablet.getRowSize()));
+    list.add(new Pair<>(lasIndex, rowSize));
 
     if (isSorted && isDeDuplicated) {
       return;
@@ -100,8 +118,8 @@ public void sortAndDeduplicateByDevIdTimestamp() {
 
     initIndexSize = 0;
     deDuplicatedSize = 0;
-    index = new Integer[tablet.getRowSize()];
-    deDuplicatedIndex = new int[tablet.getRowSize()];
+    index = new Integer[rowSize];
+    deDuplicatedIndex = new int[rowSize];
     deviceIDToIndexMap.entrySet().stream()
         .sorted(Map.Entry.comparingByKey())
         .forEach(
@@ -129,19 +147,22 @@ public void sortAndDeduplicateByDevIdTimestamp() {
   }
 
   private void sortAndDeduplicateValuesAndBitMapsWithTimestamp() {
-    tablet.setTimestamps(
+    // TIMESTAMP is not a DATE type, so columnIndex is not relevant here, use -1
+    dataAdapter.setTimestamps(
         (long[])
-            reorderValueListAndBitMap(tablet.getTimestamps(), TSDataType.TIMESTAMP, null, null));
+            reorderValueListAndBitMap(
+                dataAdapter.getTimestamps(), TSDataType.TIMESTAMP, -1, null, null));
     sortAndMayDeduplicateValuesAndBitMaps();
-    tablet.setRowSize(deDuplicatedSize);
+    dataAdapter.setRowSize(deDuplicatedSize);
   }
 
   private void sortTimestamps(final int startIndex, final int endIndex) {
-    Arrays.sort(this.index, startIndex, endIndex, Comparator.comparingLong(tablet::getTimestamp));
+    Arrays.sort(
+        this.index, startIndex, endIndex, Comparator.comparingLong(dataAdapter::getTimestamp));
   }
 
   private void deDuplicateTimestamps(final int startIndex, final int endIndex) {
-    final long[] timestamps = tablet.getTimestamps();
+    final long[] timestamps = dataAdapter.getTimestamps();
     long lastTime = timestamps[index[startIndex]];
     for (int i = startIndex + 1; i < endIndex; i++) {
       if (lastTime != (lastTime = timestamps[index[i]])) {
@@ -153,12 +174,13 @@ private void deDuplicateTimestamps(final int startIndex, final int endIndex) {
 
   /** Sort by time only.
*/ public void sortByTimestampIfNecessary() { - if (tablet == null || tablet.getRowSize() == 0) { + if (dataAdapter == null || dataAdapter.getRowSize() == 0) { return; } - final long[] timestamps = tablet.getTimestamps(); - for (int i = 1, size = tablet.getRowSize(); i < size; ++i) { + final long[] timestamps = dataAdapter.getTimestamps(); + final int rowSize = dataAdapter.getRowSize(); + for (int i = 1; i < rowSize; ++i) { final long currentTimestamp = timestamps[i]; final long previousTimestamp = timestamps[i - 1]; @@ -172,8 +194,8 @@ public void sortByTimestampIfNecessary() { return; } - index = new Integer[tablet.getRowSize()]; - for (int i = 0, size = tablet.getRowSize(); i < size; i++) { + index = new Integer[rowSize]; + for (int i = 0; i < rowSize; i++) { index[i] = i; } @@ -185,7 +207,8 @@ public void sortByTimestampIfNecessary() { } private void sortTimestamps() { - Arrays.sort(this.index, Comparator.comparingLong(tablet::getTimestamp)); - Arrays.sort(tablet.getTimestamps(), 0, tablet.getRowSize()); + Arrays.sort(this.index, Comparator.comparingLong(dataAdapter::getTimestamp)); + final long[] timestamps = dataAdapter.getTimestamps(); + Arrays.sort(timestamps, 0, dataAdapter.getRowSize()); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTreeModelTabletEventSorter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTreeModelTabletEventSorter.java index c26f59220f98..2a56b706463f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTreeModelTabletEventSorter.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/PipeTreeModelTabletEventSorter.java @@ -19,25 +19,43 @@ package org.apache.iotdb.db.pipe.sink.util.sorter; +import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement; + import org.apache.tsfile.write.record.Tablet; import java.util.Arrays; import java.util.Comparator; -public class PipeTreeModelTabletEventSorter extends PipeTabletEventSorter { +public class PipeTreeModelTabletEventSorter extends PipeInsertEventSorter { + /** + * Constructor for Tablet. + * + * @param tablet the tablet to sort + */ public PipeTreeModelTabletEventSorter(final Tablet tablet) { super(tablet); deDuplicatedSize = tablet == null ? 0 : tablet.getRowSize(); } + /** + * Constructor for InsertTabletStatement. + * + * @param statement the insert tablet statement to sort + */ + public PipeTreeModelTabletEventSorter(final InsertTabletStatement statement) { + super(statement); + deDuplicatedSize = statement == null ? 
0 : statement.getRowCount(); + } + public void deduplicateAndSortTimestampsIfNecessary() { - if (tablet == null || tablet.getRowSize() == 0) { + if (dataAdapter == null || dataAdapter.getRowSize() == 0) { return; } - long[] timestamps = tablet.getTimestamps(); - for (int i = 1, size = tablet.getRowSize(); i < size; ++i) { + long[] timestamps = dataAdapter.getTimestamps(); + final int rowSize = dataAdapter.getRowSize(); + for (int i = 1; i < rowSize; ++i) { final long currentTimestamp = timestamps[i]; final long previousTimestamp = timestamps[i - 1]; @@ -54,9 +72,9 @@ public void deduplicateAndSortTimestampsIfNecessary() { return; } - index = new Integer[tablet.getRowSize()]; - deDuplicatedIndex = new int[tablet.getRowSize()]; - for (int i = 0, size = tablet.getRowSize(); i < size; i++) { + index = new Integer[rowSize]; + deDuplicatedIndex = new int[rowSize]; + for (int i = 0; i < rowSize; i++) { index[i] = i; } @@ -78,14 +96,16 @@ public void deduplicateAndSortTimestampsIfNecessary() { private void sortTimestamps() { // Index is sorted stably because it is Integer[] - Arrays.sort(index, Comparator.comparingLong(tablet::getTimestamp)); - Arrays.sort(tablet.getTimestamps(), 0, tablet.getRowSize()); + Arrays.sort(index, Comparator.comparingLong(dataAdapter::getTimestamp)); + final long[] timestamps = dataAdapter.getTimestamps(); + Arrays.sort(timestamps, 0, dataAdapter.getRowSize()); } private void deduplicateTimestamps() { deDuplicatedSize = 0; - long[] timestamps = tablet.getTimestamps(); - for (int i = 1, size = tablet.getRowSize(); i < size; i++) { + long[] timestamps = dataAdapter.getTimestamps(); + final int rowSize = dataAdapter.getRowSize(); + for (int i = 1; i < rowSize; i++) { if (timestamps[i] != timestamps[i - 1]) { deDuplicatedIndex[deDuplicatedSize] = i - 1; timestamps[deDuplicatedSize] = timestamps[i - 1]; @@ -94,8 +114,8 @@ private void deduplicateTimestamps() { } } - deDuplicatedIndex[deDuplicatedSize] = tablet.getRowSize() - 1; - timestamps[deDuplicatedSize] = timestamps[tablet.getRowSize() - 1]; - tablet.setRowSize(++deDuplicatedSize); + deDuplicatedIndex[deDuplicatedSize] = rowSize - 1; + timestamps[deDuplicatedSize] = timestamps[rowSize - 1]; + dataAdapter.setRowSize(++deDuplicatedSize); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/TabletAdapter.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/TabletAdapter.java new file mode 100644 index 000000000000..b200127a5d45 --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/pipe/sink/util/sorter/TabletAdapter.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.iotdb.db.pipe.sink.util.sorter;
+
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.file.metadata.IDeviceID;
+import org.apache.tsfile.utils.BitMap;
+import org.apache.tsfile.write.record.Tablet;
+import org.apache.tsfile.write.schema.IMeasurementSchema;
+
+import java.util.List;
+
+/** Adapter for Tablet to implement InsertEventDataAdapter interface. */
+public class TabletAdapter implements InsertEventDataAdapter {
+
+  private final Tablet tablet;
+
+  public TabletAdapter(final Tablet tablet) {
+    this.tablet = tablet;
+  }
+
+  @Override
+  public int getColumnCount() {
+    final Object[] values = tablet.getValues();
+    return values != null ? values.length : 0;
+  }
+
+  @Override
+  public TSDataType getDataType(int columnIndex) {
+    final List<IMeasurementSchema> schemas = tablet.getSchemas();
+    if (schemas != null && columnIndex < schemas.size()) {
+      final IMeasurementSchema schema = schemas.get(columnIndex);
+      return schema != null ? schema.getType() : null;
+    }
+    return null;
+  }
+
+  @Override
+  public BitMap[] getBitMaps() {
+    return tablet.getBitMaps();
+  }
+
+  @Override
+  public void setBitMaps(BitMap[] bitMaps) {
+    tablet.setBitMaps(bitMaps);
+  }
+
+  @Override
+  public Object[] getValues() {
+    return tablet.getValues();
+  }
+
+  @Override
+  public void setValue(int columnIndex, Object value) {
+    tablet.getValues()[columnIndex] = value;
+  }
+
+  @Override
+  public long[] getTimestamps() {
+    return tablet.getTimestamps();
+  }
+
+  @Override
+  public void setTimestamps(long[] timestamps) {
+    tablet.setTimestamps(timestamps);
+  }
+
+  @Override
+  public int getRowSize() {
+    return tablet.getRowSize();
+  }
+
+  @Override
+  public void setRowSize(int rowSize) {
+    tablet.setRowSize(rowSize);
+  }
+
+  @Override
+  public long getTimestamp(int rowIndex) {
+    return tablet.getTimestamp(rowIndex);
+  }
+
+  @Override
+  public IDeviceID getDeviceID(int rowIndex) {
+    return tablet.getDeviceID(rowIndex);
+  }
+
+  @Override
+  public boolean isDateStoredAsLocalDate(int columnIndex) {
+    return true;
+  }
+
+  public Tablet getTablet() {
+    return tablet;
+  }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertBaseStatement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertBaseStatement.java
index 1aae871ea0c1..d8786e33959a 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertBaseStatement.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertBaseStatement.java
@@ -839,6 +839,16 @@ public long ramBytesUsed() {
     return ramBytesUsed;
   }
 
+  /**
+   * Set the pre-calculated memory size. This is used when memory size is calculated during
+   * deserialization to avoid recalculation.
+   *
+   * @param ramBytesUsed the calculated memory size in bytes
+   */
+  public void setRamBytesUsed(long ramBytesUsed) {
+    this.ramBytesUsed = ramBytesUsed;
+  }
+
   private long shallowSizeOfList(List list) {
     return Objects.nonNull(list) ?
UpdateDetailContainer.LIST_SIZE diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertTabletStatement.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertTabletStatement.java index d2c1d2a783a9..12369d81bfd3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertTabletStatement.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/statement/crud/InsertTabletStatement.java @@ -44,6 +44,7 @@ import org.apache.iotdb.db.queryengine.plan.statement.StatementVisitor; import org.apache.iotdb.db.utils.CommonUtils; +import org.apache.tsfile.enums.ColumnCategory; import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.file.metadata.IDeviceID; import org.apache.tsfile.file.metadata.IDeviceID.Factory; @@ -58,6 +59,8 @@ import org.apache.tsfile.write.record.Tablet; import org.apache.tsfile.write.schema.IMeasurementSchema; import org.apache.tsfile.write.schema.MeasurementSchema; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.time.LocalDate; import java.util.ArrayList; @@ -69,11 +72,22 @@ import java.util.Objects; public class InsertTabletStatement extends InsertBaseStatement implements ISchemaValidation { + private static final Logger LOGGER = LoggerFactory.getLogger(InsertTabletStatement.class); + private static final long INSTANCE_SIZE = RamUsageEstimator.shallowSizeOfInstance(InsertTabletStatement.class); private static final String DATATYPE_UNSUPPORTED = "Data type %s is not supported."; + /** + * Get the instance size of InsertTabletStatement for memory calculation. + * + * @return instance size in bytes + */ + public static long getInstanceSize() { + return INSTANCE_SIZE; + } + protected long[] times; // times should be sorted. It is done in the session API. protected BitMap[] nullBitMaps; protected Object[] columns; @@ -701,6 +715,189 @@ protected void subRemoveAttributeColumns(List columnsToKeep) { } } + /** + * Convert this InsertTabletStatement to Tablet. This method constructs a Tablet object from this + * statement, converting all necessary fields. All arrays are copied to rowSize length to ensure + * immutability. + * + * @return Tablet object + * @throws MetadataException if conversion fails + */ + public Tablet convertToTablet() throws MetadataException { + try { + // Get deviceId/tableName from devicePath + final String deviceIdOrTableName = + this.getDevicePath() != null ? this.getDevicePath().getFullPath() : ""; + + // Get schemas from measurementSchemas + final MeasurementSchema[] measurementSchemas = this.getMeasurementSchemas(); + final String[] measurements = this.getMeasurements(); + final TSDataType[] dataTypes = this.getDataTypes(); + // If measurements and dataTypes are not null, use measurements.length as the standard length + final int originalSchemaSize = measurements != null ? 
measurements.length : 0;
+
+      // Build schemas and track valid column indices (skip null columns)
+      // Null entries in measurements/dataTypes are expected here; such columns are skipped
+      final List<IMeasurementSchema> schemas = new ArrayList<>();
+      final List<Integer> validColumnIndices = new ArrayList<>();
+      for (int i = 0; i < originalSchemaSize; i++) {
+        if (dataTypes != null && measurements[i] != null && dataTypes[i] != null) {
+          // Build a MeasurementSchema for each valid column
+          schemas.add(new MeasurementSchema(measurements[i], dataTypes[i]));
+          validColumnIndices.add(i);
+        }
+        // Skip null columns - don't add to schemas or validColumnIndices
+      }
+
+      final int schemaSize = schemas.size();
+
+      // Get columnTypes (for table model) - only for valid columns
+      final TsTableColumnCategory[] columnCategories = this.getColumnCategories();
+      final List<ColumnCategory> tabletColumnTypes = new ArrayList<>();
+      if (columnCategories != null && columnCategories.length > 0) {
+        for (final int validIndex : validColumnIndices) {
+          if (columnCategories[validIndex] != null) {
+            tabletColumnTypes.add(columnCategories[validIndex].toTsFileColumnType());
+          } else {
+            tabletColumnTypes.add(ColumnCategory.FIELD);
+          }
+        }
+      } else {
+        // Default to FIELD for all valid columns if not specified
+        for (int i = 0; i < schemaSize; i++) {
+          tabletColumnTypes.add(ColumnCategory.FIELD);
+        }
+      }
+
+      // Get timestamps - always copy to ensure immutability
+      final long[] times = this.getTimes();
+      final int rowSize = this.getRowCount();
+      final long[] timestamps;
+      if (times != null && times.length >= rowSize && rowSize > 0) {
+        timestamps = new long[rowSize];
+        System.arraycopy(times, 0, timestamps, 0, rowSize);
+      } else {
+        LOGGER.warn(
+            "Times array is null or too small. times.length={}, rowSize={}, deviceId={}",
+            times != null ? times.length : 0,
+            rowSize,
+            deviceIdOrTableName);
+        timestamps = new long[0];
+      }
+
+      // Get values - convert Statement columns to Tablet format, only for valid columns
+      // All arrays are truncated/copied to rowSize length
+      final Object[] statementColumns = this.getColumns();
+      final Object[] tabletValues = new Object[schemaSize];
+      if (statementColumns != null && statementColumns.length > 0) {
+        for (int i = 0; i < validColumnIndices.size(); i++) {
+          final int originalIndex = validColumnIndices.get(i);
+          if (statementColumns[originalIndex] != null && dataTypes[originalIndex] != null) {
+            tabletValues[i] =
+                convertColumnToTablet(
+                    statementColumns[originalIndex], dataTypes[originalIndex], rowSize);
+          } else {
+            tabletValues[i] = null;
+          }
+        }
+      }
+
+      // Get bitMaps - copy and truncate to rowSize, only for valid columns
+      final BitMap[] originalBitMaps = this.getBitMaps();
+      final BitMap[] bitMaps;
+      if (originalBitMaps != null && originalBitMaps.length > 0) {
+        bitMaps = new BitMap[schemaSize];
+        for (int i = 0; i < validColumnIndices.size(); i++) {
+          final int originalIndex = validColumnIndices.get(i);
+          if (originalBitMaps[originalIndex] != null) {
+            // Create a new BitMap truncated to rowSize
+            final byte[] truncatedBytes =
+                originalBitMaps[originalIndex].getTruncatedByteArray(rowSize);
+            bitMaps[i] = new BitMap(rowSize, truncatedBytes);
+          } else {
+            bitMaps[i] = null;
+          }
+        }
+      } else {
+        bitMaps = null;
+      }
+
+      // Create Tablet using the full constructor:
+      // Tablet(String tableName, List<IMeasurementSchema> schemas,
+      //     List<ColumnCategory> columnTypes, long[] timestamps, Object[] values,
+      //     BitMap[] bitMaps, int rowSize)
+      return new Tablet(
+          deviceIdOrTableName,
+          schemas,
+          tabletColumnTypes,
+          timestamps,
+          tabletValues,
+          bitMaps,
+          rowSize);
+    } catch (final
Exception e) { + throw new MetadataException("Failed to convert InsertTabletStatement to Tablet", e); + } + } + + /** + * Convert a single column value from Statement format to Tablet format. Statement uses primitive + * arrays (e.g., int[], long[], float[]), while Tablet may need different format. All arrays are + * copied and truncated to rowSize length to ensure immutability - even if the original array is + * modified, the converted array remains unchanged. + * + * @param columnValue column value from Statement (primitive array) + * @param dataType data type of the column + * @param rowSize number of rows to copy (truncate to this length) + * @return column value in Tablet format (copied and truncated array) + */ + private Object convertColumnToTablet( + final Object columnValue, final TSDataType dataType, final int rowSize) { + + if (columnValue == null) { + return null; + } + + if (TSDataType.DATE.equals(dataType)) { + final int[] values = (int[]) columnValue; + // Copy and truncate to rowSize + final int[] copiedValues = Arrays.copyOf(values, Math.min(values.length, rowSize)); + final LocalDate[] localDateValue = new LocalDate[rowSize]; + for (int i = 0; i < copiedValues.length; i++) { + localDateValue[i] = DateUtils.parseIntToLocalDate(copiedValues[i]); + } + // Fill remaining with null if needed + for (int i = copiedValues.length; i < rowSize; i++) { + localDateValue[i] = null; + } + return localDateValue; + } + + // For primitive arrays, copy and truncate to rowSize + if (columnValue instanceof boolean[]) { + final boolean[] original = (boolean[]) columnValue; + return Arrays.copyOf(original, Math.min(original.length, rowSize)); + } else if (columnValue instanceof int[]) { + final int[] original = (int[]) columnValue; + return Arrays.copyOf(original, Math.min(original.length, rowSize)); + } else if (columnValue instanceof long[]) { + final long[] original = (long[]) columnValue; + return Arrays.copyOf(original, Math.min(original.length, rowSize)); + } else if (columnValue instanceof float[]) { + final float[] original = (float[]) columnValue; + return Arrays.copyOf(original, Math.min(original.length, rowSize)); + } else if (columnValue instanceof double[]) { + final double[] original = (double[]) columnValue; + return Arrays.copyOf(original, Math.min(original.length, rowSize)); + } else if (columnValue instanceof Binary[]) { + // For Binary arrays, create a new array and copy references, truncate to rowSize + final Binary[] original = (Binary[]) columnValue; + return Arrays.copyOf(original, Math.min(original.length, rowSize)); + } + + // For other types, return as-is (should not happen for standard types) + return columnValue; + } + @Override public String toString() { final int size = CommonDescriptor.getInstance().getConfig().getPathLogMaxSize(); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java index 1a122c63d4f2..0cc4470882ef 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeDataNodeThriftRequestTest.java @@ -437,7 +437,7 @@ public void testPipeTransferTabletBatchReqV2() throws IOException { try (final PublicBAOS byteArrayOutputStream = new PublicBAOS(); final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream)) { t.serialize(outputStream); - 
ReadWriteIOUtils.write(false, outputStream); + ReadWriteIOUtils.write(true, outputStream); tabletBuffers.add( ByteBuffer.wrap(byteArrayOutputStream.getBuf(), 0, byteArrayOutputStream.size())); tabletDataBase.add("test"); @@ -459,7 +459,7 @@ public void testPipeTransferTabletBatchReqV2() throws IOException { new byte[] {'a', 'b'}, deserializedReq.getBinaryReqs().get(0).getByteBuffer().array()); Assert.assertEquals(node, deserializedReq.getInsertNodeReqs().get(0).getInsertNode()); Assert.assertEquals(t, deserializedReq.getTabletReqs().get(0).getTablet()); - Assert.assertFalse(deserializedReq.getTabletReqs().get(0).getIsAligned()); + Assert.assertTrue(deserializedReq.getTabletReqs().get(0).getIsAligned()); Assert.assertEquals("test", deserializedReq.getBinaryReqs().get(0).getDataBaseName()); Assert.assertEquals("test", deserializedReq.getTabletReqs().get(0).getDataBaseName()); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeStatementEventSorterTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeStatementEventSorterTest.java new file mode 100644 index 000000000000..7c13d9764b3c --- /dev/null +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/PipeStatementEventSorterTest.java @@ -0,0 +1,313 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+package org.apache.iotdb.db.pipe.sink;
+
+import org.apache.iotdb.db.pipe.sink.util.sorter.PipeTableModelTabletEventSorter;
+import org.apache.iotdb.db.pipe.sink.util.sorter.PipeTreeModelTabletEventSorter;
+import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement;
+
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.write.record.Tablet;
+import org.apache.tsfile.write.schema.IMeasurementSchema;
+import org.apache.tsfile.write.schema.MeasurementSchema;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.nio.charset.StandardCharsets;
+import java.time.LocalDate;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+public class PipeStatementEventSorterTest {
+
+  @Test
+  public void testTreeModelDeduplicateAndSort() throws Exception {
+    List<IMeasurementSchema> schemaList = new ArrayList<>();
+    schemaList.add(new MeasurementSchema("s1", TSDataType.INT64));
+    schemaList.add(new MeasurementSchema("s2", TSDataType.INT64));
+    schemaList.add(new MeasurementSchema("s3", TSDataType.INT64));
+
+    Tablet tablet = new Tablet("root.sg.device", schemaList, 30);
+
+    long timestamp = 300;
+    for (long i = 0; i < 10; i++) {
+      int rowIndex = tablet.getRowSize();
+      tablet.addTimestamp(rowIndex, timestamp + i);
+      for (int s = 0; s < 3; s++) {
+        tablet.addValue(schemaList.get(s).getMeasurementName(), rowIndex, timestamp + i);
+      }
+
+      rowIndex = tablet.getRowSize();
+      tablet.addTimestamp(rowIndex, timestamp - i);
+      for (int s = 0; s < 3; s++) {
+        tablet.addValue(schemaList.get(s).getMeasurementName(), rowIndex, timestamp - i);
+      }
+
+      rowIndex = tablet.getRowSize();
+      tablet.addTimestamp(rowIndex, timestamp);
+      for (int s = 0; s < 3; s++) {
+        tablet.addValue(schemaList.get(s).getMeasurementName(), rowIndex, timestamp);
+      }
+    }
+
+    Set<Integer> indices = new HashSet<>();
+    for (int i = 0; i < 30; i++) {
+      indices.add((int) tablet.getTimestamp(i));
+    }
+
+    Assert.assertFalse(tablet.isSorted());
+
+    // Convert Tablet to Statement
+    InsertTabletStatement statement = new InsertTabletStatement(tablet, true, null);
+
+    // Sort using Statement
+    new PipeTreeModelTabletEventSorter(statement).deduplicateAndSortTimestampsIfNecessary();
+
+    Assert.assertEquals(indices.size(), statement.getRowCount());
+
+    final long[] timestamps = Arrays.copyOfRange(statement.getTimes(), 0, statement.getRowCount());
+    final Object[] columns = statement.getColumns();
+    for (int i = 0; i < 3; ++i) {
+      Assert.assertArrayEquals(
+          timestamps, Arrays.copyOfRange((long[]) columns[i], 0, statement.getRowCount()));
+    }
+
+    for (int i = 1; i < statement.getRowCount(); ++i) {
+      Assert.assertTrue(timestamps[i] > timestamps[i - 1]);
+      for (int j = 0; j < 3; ++j) {
+        Assert.assertTrue(((long[]) columns[j])[i] > ((long[]) columns[j])[i - 1]);
+      }
+    }
+  }
+
+  @Test
+  public void testTreeModelDeduplicate() throws Exception {
+    final List<IMeasurementSchema> schemaList = new ArrayList<>();
+    schemaList.add(new MeasurementSchema("s1", TSDataType.INT64));
+    schemaList.add(new MeasurementSchema("s2", TSDataType.INT64));
+    schemaList.add(new MeasurementSchema("s3", TSDataType.INT64));
+
+    final Tablet tablet = new Tablet("root.sg.device", schemaList, 10);
+
+    final long timestamp = 300;
+    for (long i = 0; i < 10; i++) {
+      final int rowIndex = tablet.getRowSize();
+      tablet.addTimestamp(rowIndex, timestamp);
+      for (int s = 0; s < 3; s++) {
+        tablet.addValue(
+            schemaList.get(s).getMeasurementName(),
+            rowIndex,
+            (i + s) % 3 != 0 ?
timestamp + i : null);
+      }
+    }
+
+    final Set<Integer> indices = new HashSet<>();
+    for (int i = 0; i < 10; i++) {
+      indices.add((int) tablet.getTimestamp(i));
+    }
+
+    Assert.assertTrue(tablet.isSorted());
+
+    // Convert Tablet to Statement
+    InsertTabletStatement statement = new InsertTabletStatement(tablet, true, null);
+
+    // Sort using Statement
+    new PipeTreeModelTabletEventSorter(statement).deduplicateAndSortTimestampsIfNecessary();
+
+    Assert.assertEquals(indices.size(), statement.getRowCount());
+
+    final long[] timestamps = Arrays.copyOfRange(statement.getTimes(), 0, statement.getRowCount());
+    final Object[] columns = statement.getColumns();
+    Assert.assertEquals(timestamps[0] + 8, ((long[]) columns[0])[0]);
+    for (int i = 1; i < 3; ++i) {
+      Assert.assertEquals(timestamps[0] + 9, ((long[]) columns[i])[0]);
+    }
+  }
+
+  @Test
+  public void testTreeModelSort() throws Exception {
+    List<IMeasurementSchema> schemaList = new ArrayList<>();
+    schemaList.add(new MeasurementSchema("s1", TSDataType.INT64));
+    schemaList.add(new MeasurementSchema("s2", TSDataType.INT64));
+    schemaList.add(new MeasurementSchema("s3", TSDataType.INT64));
+
+    Tablet tablet = new Tablet("root.sg.device", schemaList, 30);
+
+    for (long i = 0; i < 10; i++) {
+      int rowIndex = tablet.getRowSize();
+      tablet.addTimestamp(rowIndex, (long) rowIndex + 2);
+      for (int s = 0; s < 3; s++) {
+        tablet.addValue(schemaList.get(s).getMeasurementName(), rowIndex, (long) rowIndex + 2);
+      }
+
+      rowIndex = tablet.getRowSize();
+      tablet.addTimestamp(rowIndex, rowIndex);
+      for (int s = 0; s < 3; s++) {
+        tablet.addValue(schemaList.get(s).getMeasurementName(), rowIndex, (long) rowIndex);
+      }
+
+      rowIndex = tablet.getRowSize();
+      tablet.addTimestamp(rowIndex, (long) rowIndex - 2);
+      for (int s = 0; s < 3; s++) {
+        tablet.addValue(schemaList.get(s).getMeasurementName(), rowIndex, (long) rowIndex - 2);
+      }
+    }
+
+    Set<Integer> indices = new HashSet<>();
+    for (int i = 0; i < 30; i++) {
+      indices.add((int) tablet.getTimestamp(i));
+    }
+
+    Assert.assertFalse(tablet.isSorted());
+
+    long[] timestamps = Arrays.copyOfRange(tablet.getTimestamps(), 0, tablet.getRowSize());
+    for (int i = 0; i < 3; ++i) {
+      Assert.assertArrayEquals(
+          timestamps, Arrays.copyOfRange((long[]) tablet.getValues()[i], 0, tablet.getRowSize()));
+    }
+
+    for (int i = 1; i < tablet.getRowSize(); ++i) {
+      Assert.assertTrue(timestamps[i] != timestamps[i - 1]);
+      for (int j = 0; j < 3; ++j) {
+        Assert.assertNotEquals((long) tablet.getValue(i, j), (long) tablet.getValue(i - 1, j));
+      }
+    }
+
+    // Convert Tablet to Statement
+    InsertTabletStatement statement = new InsertTabletStatement(tablet, true, null);
+
+    // Sort using Statement
+    new PipeTreeModelTabletEventSorter(statement).deduplicateAndSortTimestampsIfNecessary();
+
+    Assert.assertEquals(indices.size(), statement.getRowCount());
+
+    timestamps = Arrays.copyOfRange(statement.getTimes(), 0, statement.getRowCount());
+    final Object[] columns = statement.getColumns();
+    for (int i = 0; i < 3; ++i) {
+      Assert.assertArrayEquals(
+          timestamps, Arrays.copyOfRange((long[]) columns[i], 0, statement.getRowCount()));
+    }
+
+    for (int i = 1; i < statement.getRowCount(); ++i) {
+      Assert.assertTrue(timestamps[i] > timestamps[i - 1]);
+      for (int j = 0; j < 3; ++j) {
+        Assert.assertTrue(((long[]) columns[j])[i] > ((long[]) columns[j])[i - 1]);
+      }
+    }
+  }
+
+  @Test
+  public void testTableModelDeduplicateAndSort() throws Exception {
+    doTableModelTest(true, true);
+  }
+
+  @Test
+  public void testTableModelDeduplicate() throws Exception {
+    doTableModelTest(true, false);
+  }
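The statement-based tests above and below all exercise one pattern: the sorter overloads introduced by this patch take an InsertTabletStatement directly, so a pipe sink can normalize row order in place without first materializing a Tablet. A minimal sketch of that usage, assuming a populated `tablet` and reusing only the constructors and sorter methods shown in this patch (imports as in this test file; the helper name and database name are made up for illustration):

  // Hypothetical helper, not part of the patch: normalize statement data in place.
  static void normalizeForSink(final Tablet tablet) {
    // Tree model: collapse rows with duplicate timestamps, then sort by time.
    final InsertTabletStatement treeStmt = new InsertTabletStatement(tablet, true, null);
    new PipeTreeModelTabletEventSorter(treeStmt).deduplicateAndSortTimestampsIfNecessary();

    // Table model: group rows by device ID first, then sort and deduplicate
    // timestamps within each device.
    final InsertTabletStatement tableStmt = new InsertTabletStatement(tablet, false, "test_db");
    new PipeTableModelTabletEventSorter(tableStmt).sortAndDeduplicateByDevIdTimestamp();
  }

Both sorters run through the shared InsertEventDataAdapter path, so Tablet and InsertTabletStatement get identical sorting behavior; the only representational difference the adapters expose is DATE, reported through isDateStoredAsLocalDate (LocalDate[] for Tablet, int[] for statements).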
+ + @Test + public void testTableModelSort() throws Exception { + doTableModelTest(false, true); + } + + @Test + public void testTableModelSort1() throws Exception { + doTableModelTest1(); + } + + public void doTableModelTest(final boolean hasDuplicates, final boolean isUnSorted) + throws Exception { + final Tablet tablet = + PipeTabletEventSorterTest.generateTablet("test", 10, hasDuplicates, isUnSorted); + + // Convert Tablet to Statement + InsertTabletStatement statement = new InsertTabletStatement(tablet, false, "test_db"); + + // Sort using Statement + new PipeTableModelTabletEventSorter(statement).sortAndDeduplicateByDevIdTimestamp(); + + long[] timestamps = statement.getTimes(); + final Object[] columns = statement.getColumns(); + for (int i = 1; i < statement.getRowCount(); i++) { + long time = timestamps[i]; + Assert.assertTrue(time > timestamps[i - 1]); + Assert.assertEquals( + ((Binary[]) columns[0])[i], + new Binary(String.valueOf(i / 100).getBytes(StandardCharsets.UTF_8))); + Assert.assertEquals(((long[]) columns[1])[i], (long) i); + Assert.assertEquals(((float[]) columns[2])[i], i * 1.0f, 0.001f); + Assert.assertEquals( + ((Binary[]) columns[3])[i], + new Binary(String.valueOf(i).getBytes(StandardCharsets.UTF_8))); + Assert.assertEquals(((long[]) columns[4])[i], (long) i); + Assert.assertEquals(((int[]) columns[5])[i], i); + Assert.assertEquals(((double[]) columns[6])[i], i * 0.1, 0.0001); + // DATE is stored as int[] in Statement, not LocalDate[] + LocalDate expectedDate = PipeTabletEventSorterTest.getDate(i); + int expectedDateInt = + org.apache.tsfile.utils.DateUtils.parseDateExpressionToInt(expectedDate); + Assert.assertEquals(((int[]) columns[7])[i], expectedDateInt); + Assert.assertEquals( + ((Binary[]) columns[8])[i], + new Binary(String.valueOf(i).getBytes(StandardCharsets.UTF_8))); + } + } + + public void doTableModelTest1() throws Exception { + final Tablet tablet = PipeTabletEventSorterTest.generateTablet("test", 10, false, true); + + // Convert Tablet to Statement + InsertTabletStatement statement = new InsertTabletStatement(tablet, false, "test_db"); + + // Sort using Statement + new PipeTableModelTabletEventSorter(statement).sortByTimestampIfNecessary(); + + long[] timestamps = statement.getTimes(); + final Object[] columns = statement.getColumns(); + for (int i = 1; i < statement.getRowCount(); i++) { + long time = timestamps[i]; + Assert.assertTrue(time > timestamps[i - 1]); + Assert.assertEquals( + ((Binary[]) columns[0])[i], + new Binary(String.valueOf(i / 100).getBytes(StandardCharsets.UTF_8))); + Assert.assertEquals(((long[]) columns[1])[i], (long) i); + Assert.assertEquals(((float[]) columns[2])[i], i * 1.0f, 0.001f); + Assert.assertEquals( + ((Binary[]) columns[3])[i], + new Binary(String.valueOf(i).getBytes(StandardCharsets.UTF_8))); + Assert.assertEquals(((long[]) columns[4])[i], (long) i); + Assert.assertEquals(((int[]) columns[5])[i], i); + Assert.assertEquals(((double[]) columns[6])[i], i * 0.1, 0.0001); + // DATE is stored as int[] in Statement, not LocalDate[] + LocalDate expectedDate = PipeTabletEventSorterTest.getDate(i); + int expectedDateInt = + org.apache.tsfile.utils.DateUtils.parseDateExpressionToInt(expectedDate); + Assert.assertEquals(((int[]) columns[7])[i], expectedDateInt); + Assert.assertEquals( + ((Binary[]) columns[8])[i], + new Binary(String.valueOf(i).getBytes(StandardCharsets.UTF_8))); + } + } +} diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverterTest.java 
b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverterTest.java
new file mode 100644
index 000000000000..410afc761303
--- /dev/null
+++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/pipe/sink/util/TabletStatementConverterTest.java
@@ -0,0 +1,607 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.pipe.sink.util;
+
+import org.apache.iotdb.commons.exception.MetadataException;
+import org.apache.iotdb.db.queryengine.plan.statement.crud.InsertTabletStatement;
+
+import org.apache.tsfile.enums.ColumnCategory;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.utils.Binary;
+import org.apache.tsfile.utils.DateUtils;
+import org.apache.tsfile.utils.PublicBAOS;
+import org.apache.tsfile.utils.ReadWriteIOUtils;
+import org.apache.tsfile.write.record.Tablet;
+import org.apache.tsfile.write.schema.IMeasurementSchema;
+import org.apache.tsfile.write.schema.MeasurementSchema;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.DataOutputStream;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.time.LocalDate;
+import java.util.ArrayList;
+import java.util.List;
+
+public class TabletStatementConverterTest {
+
+  @Test
+  public void testConvertStatementToTabletTreeModel() throws MetadataException {
+    final int columnCount = 1000;
+    final int rowCount = 100;
+    final String deviceName = "root.sg.device";
+    final boolean isAligned = true;
+
+    // Generate Tablet and construct Statement from it
+    final Tablet originalTablet = generateTreeModelTablet(deviceName, columnCount, rowCount);
+    final InsertTabletStatement statement =
+        new InsertTabletStatement(originalTablet, isAligned, null);
+
+    // Convert Statement to Tablet
+    final Tablet convertedTablet = statement.convertToTablet();
+
+    // Verify conversion
+    assertTabletsEqual(originalTablet, convertedTablet);
+  }
+
+  @Test
+  public void testConvertStatementToTabletTableModel() throws MetadataException {
+    final int columnCount = 1000;
+    final int rowCount = 100;
+    final String tableName = "table1";
+    final String databaseName = "test_db";
+    final boolean isAligned = false;
+
+    // Generate Tablet and construct Statement from it
+    final Tablet originalTablet = generateTableModelTablet(tableName, columnCount, rowCount);
+    final InsertTabletStatement statement =
+        new InsertTabletStatement(originalTablet, isAligned, databaseName);
+
+    // Convert Statement to Tablet
+    final Tablet convertedTablet = statement.convertToTablet();
+
+    // Verify conversion
+    assertTabletsEqual(originalTablet, convertedTablet);
+  }
+
+  @Test
+  public void testDeserializeStatementFromTabletFormat() throws IOException, MetadataException {
+    final int columnCount = 1000;
+    final int rowCount =
100; + final String deviceName = "root.sg.device"; + + // Generate test Tablet + final Tablet originalTablet = generateTreeModelTablet(deviceName, columnCount, rowCount); + + // Serialize Tablet to ByteBuffer + final PublicBAOS byteArrayOutputStream = new PublicBAOS(); + final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream); + // Then serialize the tablet + originalTablet.serialize(outputStream); + // Write isAligned at the end + final boolean isAligned = true; + ReadWriteIOUtils.write(isAligned, outputStream); + + final ByteBuffer buffer = + ByteBuffer.wrap(byteArrayOutputStream.getBuf(), 0, byteArrayOutputStream.size()); + + // Deserialize Statement from Tablet format + final InsertTabletStatement statement = + TabletStatementConverter.deserializeStatementFromTabletFormat(buffer); + + // Verify basic information + Assert.assertEquals(deviceName, statement.getDevicePath().getFullPath()); + Assert.assertEquals(rowCount, statement.getRowCount()); + Assert.assertEquals(columnCount, statement.getMeasurements().length); + Assert.assertEquals(isAligned, statement.isAligned()); + + // Verify data by converting Statement back to Tablet + final Tablet convertedTablet = statement.convertToTablet(); + assertTabletsEqual(originalTablet, convertedTablet); + } + + @Test + public void testRoundTripConversionTreeModel() throws MetadataException, IOException { + final int columnCount = 1000; + final int rowCount = 100; + final String deviceName = "root.sg.device"; + + // Generate original Tablet + final Tablet originalTablet = generateTreeModelTablet(deviceName, columnCount, rowCount); + + // Serialize Tablet to ByteBuffer + final PublicBAOS byteArrayOutputStream = new PublicBAOS(); + final DataOutputStream outputStream = new DataOutputStream(byteArrayOutputStream); + originalTablet.serialize(outputStream); + final boolean isAligned = true; + ReadWriteIOUtils.write(isAligned, outputStream); + final ByteBuffer buffer = + ByteBuffer.wrap(byteArrayOutputStream.getBuf(), 0, byteArrayOutputStream.size()); + + // Deserialize to Statement + final InsertTabletStatement statement = + TabletStatementConverter.deserializeStatementFromTabletFormat(buffer); + // Convert Statement back to Tablet + final Tablet convertedTablet = statement.convertToTablet(); + + // Verify round trip + assertTabletsEqual(originalTablet, convertedTablet); + } + + @Test + public void testRoundTripConversionTableModel() throws MetadataException { + final int columnCount = 1000; + final int rowCount = 100; + final String tableName = "table1"; + final String databaseName = "test_db"; + final boolean isAligned = false; + + // Generate original Tablet for table model + final Tablet originalTablet = generateTableModelTablet(tableName, columnCount, rowCount); + + // Construct Statement from Tablet + final InsertTabletStatement originalStatement = + new InsertTabletStatement(originalTablet, isAligned, databaseName); + + // Convert Statement to Tablet + final Tablet convertedTablet = originalStatement.convertToTablet(); + + // Convert Tablet back to Statement + final InsertTabletStatement convertedStatement = + new InsertTabletStatement(convertedTablet, isAligned, databaseName); + + // Verify round trip: original Tablet should equal converted Tablet + assertTabletsEqual(originalTablet, convertedTablet); + } + + /** + * Generate a Tablet for tree model with all data types and specified number of columns and rows. 
+   *
+   * @param deviceName device name
+   * @param columnCount number of columns
+   * @param rowCount number of rows
+   * @return Tablet with test data
+   */
+  private Tablet generateTreeModelTablet(
+      final String deviceName, final int columnCount, final int rowCount) {
+    final List<IMeasurementSchema> schemaList = new ArrayList<>();
+    final TSDataType[] dataTypes = new TSDataType[columnCount];
+    final String[] measurementNames = new String[columnCount];
+    final Object[] columnData = new Object[columnCount];
+
+    // Create schemas and generate test data
+    for (int col = 0; col < columnCount; col++) {
+      final TSDataType dataType = ALL_DATA_TYPES[col % ALL_DATA_TYPES.length];
+      final String measurementName = "col_" + col + "_" + dataType.name();
+      schemaList.add(new MeasurementSchema(measurementName, dataType));
+      dataTypes[col] = dataType;
+      measurementNames[col] = measurementName;
+
+      // Generate test data for this column
+      switch (dataType) {
+        case BOOLEAN:
+          final boolean[] boolValues = new boolean[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            boolValues[row] = (row + col) % 2 == 0;
+          }
+          columnData[col] = boolValues;
+          break;
+        case INT32:
+          final int[] intValues = new int[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            intValues[row] = row * 100 + col;
+          }
+          columnData[col] = intValues;
+          break;
+        case DATE:
+          final LocalDate[] dateValues = new LocalDate[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            // Generate valid dates starting from 2024-01-01
+            dateValues[row] = LocalDate.of(2024, 1, 1).plusDays((row + col) % 365);
+          }
+          columnData[col] = dateValues;
+          break;
+        case INT64:
+        case TIMESTAMP:
+          final long[] longValues = new long[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            longValues[row] = (long) row * 1000L + col;
+          }
+          columnData[col] = longValues;
+          break;
+        case FLOAT:
+          final float[] floatValues = new float[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            floatValues[row] = row * 1.5f + col * 0.1f;
+          }
+          columnData[col] = floatValues;
+          break;
+        case DOUBLE:
+          final double[] doubleValues = new double[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            doubleValues[row] = row * 2.5 + col * 0.01;
+          }
+          columnData[col] = doubleValues;
+          break;
+        case TEXT:
+        case STRING:
+        case BLOB:
+          final Binary[] binaryValues = new Binary[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            binaryValues[row] = new Binary(("value_row_" + row + "_col_" + col).getBytes());
+          }
+          columnData[col] = binaryValues;
+          break;
+        default:
+          throw new IllegalArgumentException("Unsupported data type: " + dataType);
+      }
+    }
+
+    // Create and fill tablet
+    final Tablet tablet = new Tablet(deviceName, schemaList, rowCount);
+    final long[] times = new long[rowCount];
+    for (int row = 0; row < rowCount; row++) {
+      times[row] = row * 1000L;
+      final int rowIndex = tablet.getRowSize();
+      tablet.addTimestamp(rowIndex, times[row]);
+      for (int col = 0; col < columnCount; col++) {
+        final TSDataType dataType = dataTypes[col];
+        final Object data = columnData[col];
+        switch (dataType) {
+          case BOOLEAN:
+            tablet.addValue(measurementNames[col], rowIndex, ((boolean[]) data)[row]);
+            break;
+          case INT32:
+            tablet.addValue(measurementNames[col], rowIndex, ((int[]) data)[row]);
+            break;
+          case DATE:
+            tablet.addValue(measurementNames[col], rowIndex, ((LocalDate[]) data)[row]);
+            break;
+          case INT64:
+          case TIMESTAMP:
+            tablet.addValue(measurementNames[col], rowIndex, ((long[]) data)[row]);
+            break;
+          case FLOAT:
+            tablet.addValue(measurementNames[col], rowIndex, ((float[])
data)[row]);
+            break;
+          case DOUBLE:
+            tablet.addValue(measurementNames[col], rowIndex, ((double[]) data)[row]);
+            break;
+          case TEXT:
+          case STRING:
+          case BLOB:
+            tablet.addValue(measurementNames[col], rowIndex, ((Binary[]) data)[row]);
+            break;
+        }
+      }
+    }
+
+    return tablet;
+  }
+
+  /**
+   * Generate a Tablet for table model with all data types and specified number of columns and rows.
+   *
+   * @param tableName table name
+   * @param columnCount number of columns
+   * @param rowCount number of rows
+   * @return Tablet with test data
+   */
+  private Tablet generateTableModelTablet(
+      final String tableName, final int columnCount, final int rowCount) {
+    final List<IMeasurementSchema> schemaList = new ArrayList<>();
+    final TSDataType[] dataTypes = new TSDataType[columnCount];
+    final String[] measurementNames = new String[columnCount];
+    final List<ColumnCategory> columnTypes = new ArrayList<>();
+    final Object[] columnData = new Object[columnCount];
+
+    // Create schemas and generate test data
+    for (int col = 0; col < columnCount; col++) {
+      final TSDataType dataType = ALL_DATA_TYPES[col % ALL_DATA_TYPES.length];
+      final String measurementName = "col_" + col + "_" + dataType.name();
+      schemaList.add(new MeasurementSchema(measurementName, dataType));
+      dataTypes[col] = dataType;
+      measurementNames[col] = measurementName;
+      // For table model, all columns are FIELD (can be TAG/ATTRIBUTE/FIELD, but we use FIELD for
+      // simplicity)
+      columnTypes.add(ColumnCategory.FIELD);
+
+      // Generate test data for this column
+      switch (dataType) {
+        case BOOLEAN:
+          final boolean[] boolValues = new boolean[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            boolValues[row] = (row + col) % 2 == 0;
+          }
+          columnData[col] = boolValues;
+          break;
+        case INT32:
+          final int[] intValues = new int[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            intValues[row] = row * 100 + col;
+          }
+          columnData[col] = intValues;
+          break;
+        case DATE:
+          final LocalDate[] dateValues = new LocalDate[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            // Generate valid dates starting from 2024-01-01
+            dateValues[row] = LocalDate.of(2024, 1, 1).plusDays((row + col) % 365);
+          }
+          columnData[col] = dateValues;
+          break;
+        case INT64:
+        case TIMESTAMP:
+          final long[] longValues = new long[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            longValues[row] = (long) row * 1000L + col;
+          }
+          columnData[col] = longValues;
+          break;
+        case FLOAT:
+          final float[] floatValues = new float[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            floatValues[row] = row * 1.5f + col * 0.1f;
+          }
+          columnData[col] = floatValues;
+          break;
+        case DOUBLE:
+          final double[] doubleValues = new double[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            doubleValues[row] = row * 2.5 + col * 0.01;
+          }
+          columnData[col] = doubleValues;
+          break;
+        case TEXT:
+        case STRING:
+        case BLOB:
+          final Binary[] binaryValues = new Binary[rowCount];
+          for (int row = 0; row < rowCount; row++) {
+            binaryValues[row] = new Binary(("value_row_" + row + "_col_" + col).getBytes());
+          }
+          columnData[col] = binaryValues;
+          break;
+        default:
+          throw new IllegalArgumentException("Unsupported data type: " + dataType);
+      }
+    }
+
+    // Create tablet using table model constructor:
+    // Tablet(String, List<String>, List<TSDataType>, List<ColumnCategory>, int)
+    final List<String> measurementNameList = IMeasurementSchema.getMeasurementNameList(schemaList);
+    final List<TSDataType> dataTypeList = IMeasurementSchema.getDataTypeList(schemaList);
+    final Tablet tablet =
+        new Tablet(tableName, measurementNameList, dataTypeList, columnTypes, rowCount);
+
tablet.initBitMaps(); + + // Fill tablet with data + final long[] times = new long[rowCount]; + for (int row = 0; row < rowCount; row++) { + times[row] = row * 1000L; + final int rowIndex = tablet.getRowSize(); + tablet.addTimestamp(rowIndex, times[row]); + for (int col = 0; col < columnCount; col++) { + final TSDataType dataType = dataTypes[col]; + final Object data = columnData[col]; + switch (dataType) { + case BOOLEAN: + tablet.addValue(measurementNames[col], rowIndex, ((boolean[]) data)[row]); + break; + case INT32: + tablet.addValue(measurementNames[col], rowIndex, ((int[]) data)[row]); + break; + case DATE: + tablet.addValue(measurementNames[col], rowIndex, ((LocalDate[]) data)[row]); + break; + case INT64: + case TIMESTAMP: + tablet.addValue(measurementNames[col], rowIndex, ((long[]) data)[row]); + break; + case FLOAT: + tablet.addValue(measurementNames[col], rowIndex, ((float[]) data)[row]); + break; + case DOUBLE: + tablet.addValue(measurementNames[col], rowIndex, ((double[]) data)[row]); + break; + case TEXT: + case STRING: + case BLOB: + tablet.addValue(measurementNames[col], rowIndex, ((Binary[]) data)[row]); + break; + } + } + tablet.setRowSize(rowIndex + 1); + } + + return tablet; + } + + /** + * Check if two Tablets are equal in all aspects. + * + * @param expected expected Tablet + * @param actual actual Tablet + */ + private void assertTabletsEqual(final Tablet expected, final Tablet actual) { + Assert.assertEquals(expected.getDeviceId(), actual.getDeviceId()); + Assert.assertEquals(expected.getRowSize(), actual.getRowSize()); + Assert.assertEquals(expected.getSchemas().size(), actual.getSchemas().size()); + + // Verify timestamps + final long[] expectedTimes = expected.getTimestamps(); + final long[] actualTimes = actual.getTimestamps(); + Assert.assertArrayEquals(expectedTimes, actualTimes); + + // Verify each column + final int columnCount = expected.getSchemas().size(); + final int rowCount = expected.getRowSize(); + final Object[] expectedValues = expected.getValues(); + final Object[] actualValues = actual.getValues(); + + for (int col = 0; col < columnCount; col++) { + final IMeasurementSchema schema = expected.getSchemas().get(col); + final TSDataType dataType = schema.getType(); + final Object expectedColumn = expectedValues[col]; + final Object actualColumn = actualValues[col]; + + Assert.assertNotNull(actualColumn); + + // Verify each row in this column + for (int row = 0; row < rowCount; row++) { + switch (dataType) { + case BOOLEAN: + final boolean expectedBool = ((boolean[]) expectedColumn)[row]; + final boolean actualBool = ((boolean[]) actualColumn)[row]; + Assert.assertEquals(expectedBool, actualBool); + break; + case INT32: + final int expectedInt = ((int[]) expectedColumn)[row]; + final int actualInt = ((int[]) actualColumn)[row]; + Assert.assertEquals(expectedInt, actualInt); + break; + case DATE: + final LocalDate expectedDate = ((LocalDate[]) expectedColumn)[row]; + final LocalDate actualDate = ((LocalDate[]) actualColumn)[row]; + Assert.assertEquals(expectedDate, actualDate); + break; + case INT64: + case TIMESTAMP: + final long expectedLong = ((long[]) expectedColumn)[row]; + final long actualLong = ((long[]) actualColumn)[row]; + Assert.assertEquals(expectedLong, actualLong); + break; + case FLOAT: + final float expectedFloat = ((float[]) expectedColumn)[row]; + final float actualFloat = ((float[]) actualColumn)[row]; + Assert.assertEquals(expectedFloat, actualFloat, 0.0001f); + break; + case DOUBLE: + final double expectedDouble = ((double[]) 
expectedColumn)[row]; + final double actualDouble = ((double[]) actualColumn)[row]; + Assert.assertEquals(expectedDouble, actualDouble, 0.0001); + break; + case TEXT: + case STRING: + case BLOB: + final Binary expectedBinary = ((Binary[]) expectedColumn)[row]; + final Binary actualBinary = ((Binary[]) actualColumn)[row]; + Assert.assertNotNull(actualBinary); + Assert.assertEquals(expectedBinary, actualBinary); + break; + } + } + } + } + + /** + * Check if a Tablet and an InsertTabletStatement contain the same data. + * + * @param tablet Tablet + * @param statement InsertTabletStatement + */ + private void assertTabletAndStatementEqual( + final Tablet tablet, final InsertTabletStatement statement) { + Assert.assertEquals( + tablet.getDeviceId(), + statement.getDevicePath() != null ? statement.getDevicePath().getFullPath() : null); + Assert.assertEquals(tablet.getRowSize(), statement.getRowCount()); + Assert.assertEquals(tablet.getSchemas().size(), statement.getMeasurements().length); + + // Verify timestamps + Assert.assertArrayEquals(tablet.getTimestamps(), statement.getTimes()); + + // Verify each column + final int columnCount = tablet.getSchemas().size(); + final int rowCount = tablet.getRowSize(); + final Object[] tabletValues = tablet.getValues(); + final Object[] statementColumns = statement.getColumns(); + + for (int col = 0; col < columnCount; col++) { + final TSDataType dataType = tablet.getSchemas().get(col).getType(); + final Object tabletColumn = tabletValues[col]; + final Object statementColumn = statementColumns[col]; + + Assert.assertNotNull(statementColumn); + + // Verify each row + for (int row = 0; row < rowCount; row++) { + switch (dataType) { + case BOOLEAN: + final boolean tabletBool = ((boolean[]) tabletColumn)[row]; + final boolean statementBool = ((boolean[]) statementColumn)[row]; + Assert.assertEquals(tabletBool, statementBool); + break; + case INT32: + final int tabletInt = ((int[]) tabletColumn)[row]; + final int statementInt = ((int[]) statementColumn)[row]; + Assert.assertEquals(tabletInt, statementInt); + break; + case DATE: + // DATE type: Tablet uses LocalDate[], Statement uses int[] (YYYYMMDD format) + final LocalDate tabletDate = ((LocalDate[]) tabletColumn)[row]; + final int statementDateInt = ((int[]) statementColumn)[row]; + // Convert LocalDate to int (YYYYMMDD format) for comparison + final int tabletDateInt = DateUtils.parseDateExpressionToInt(tabletDate); + Assert.assertEquals(tabletDateInt, statementDateInt); + break; + case INT64: + case TIMESTAMP: + final long tabletLong = ((long[]) tabletColumn)[row]; + final long statementLong = ((long[]) statementColumn)[row]; + Assert.assertEquals(tabletLong, statementLong); + break; + case FLOAT: + final float tabletFloat = ((float[]) tabletColumn)[row]; + final float statementFloat = ((float[]) statementColumn)[row]; + Assert.assertEquals(tabletFloat, statementFloat, 0.0001f); + break; + case DOUBLE: + final double tabletDouble = ((double[]) tabletColumn)[row]; + final double statementDouble = ((double[]) statementColumn)[row]; + Assert.assertEquals(tabletDouble, statementDouble, 0.0001); + break; + case TEXT: + case STRING: + case BLOB: + final Binary tabletBinary = ((Binary[]) tabletColumn)[row]; + final Binary statementBinary = ((Binary[]) statementColumn)[row]; + Assert.assertNotNull(statementBinary); + Assert.assertEquals(tabletBinary, statementBinary); + break; + } + } + } + } + + // Define all supported data types + private static final TSDataType[] ALL_DATA_TYPES = { + TSDataType.BOOLEAN, + 
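// The generator methods above assign column i the type ALL_DATA_TYPES[i % ALL_DATA_TYPES.length], so every type is exercised regardless of the requested column count. +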
TSDataType.INT32, + TSDataType.INT64, + TSDataType.FLOAT, + TSDataType.DOUBLE, + TSDataType.TEXT, + TSDataType.TIMESTAMP, + TSDataType.DATE, + TSDataType.BLOB, + TSDataType.STRING + }; +} From bdea04155b9e06dfeb2396bc7e82b20b35593580 Mon Sep 17 00:00:00 2001 From: shuwenwei <55970239+shuwenwei@users.noreply.github.com> Date: Tue, 9 Dec 2025 19:10:23 +0800 Subject: [PATCH 06/13] Optimize memtable region scan (#16883) (cherry picked from commit a899c48a09c94de1c6bfd7730d013b5d12b1a0b9) --- .../storageengine/dataregion/DataRegion.java | 6 +- .../dataregion/memtable/AbstractMemTable.java | 86 +++-- .../memtable/AlignedWritableMemChunk.java | 132 +++++-- .../dataregion/memtable/IMemTable.java | 6 +- .../dataregion/memtable/TsFileProcessor.java | 18 +- .../dataregion/memtable/WritableMemChunk.java | 52 +-- .../WritableMemChunkRegionScanTest.java | 359 ++++++++++++++++++ 7 files changed, 572 insertions(+), 87 deletions(-) create mode 100644 iotdb-core/datanode/src/test/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunkRegionScanTest.java diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/DataRegion.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/DataRegion.java index 1124e33a7df9..4db15d48bc67 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/DataRegion.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/DataRegion.java @@ -2476,7 +2476,8 @@ private List getFileHandleListForQuery( } else { tsFileResource .getProcessor() - .queryForSeriesRegionScanWithoutLock(partialPaths, context, fileScanHandles); + .queryForSeriesRegionScanWithoutLock( + partialPaths, context, fileScanHandles, globalTimeFilter); } } return fileScanHandles; @@ -2553,7 +2554,8 @@ private List getFileHandleListForQuery( } else { tsFileResource .getProcessor() - .queryForDeviceRegionScanWithoutLock(devicePathsToContext, context, fileScanHandles); + .queryForDeviceRegionScanWithoutLock( + devicePathsToContext, context, fileScanHandles, globalTimeFilter); } } return fileScanHandles; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AbstractMemTable.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AbstractMemTable.java index 92803984cee1..22606f232a8d 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AbstractMemTable.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AbstractMemTable.java @@ -70,6 +70,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; @@ -488,7 +489,8 @@ public void queryForSeriesRegionScan( long ttlLowerBound, Map> chunkMetaDataMap, Map> memChunkHandleMap, - List> modsToMemTabled) { + List> modsToMemTabled, + Filter globalTimeFilter) { IDeviceID deviceID = fullPath.getDeviceId(); if (fullPath instanceof NonAlignedFullPath) { @@ -506,7 +508,12 @@ public void queryForSeriesRegionScan( fullPath.getDeviceId(), measurementId, this, modsToMemTabled, ttlLowerBound); } getMemChunkHandleFromMemTable( - deviceID, measurementId, chunkMetaDataMap, memChunkHandleMap, deletionList); + deviceID, + measurementId, + chunkMetaDataMap, + memChunkHandleMap, + deletionList, + globalTimeFilter); } else { // check If MemTable 
Contains this path if (!memTableMap.containsKey(deviceID)) { @@ -528,7 +535,8 @@ public void queryForSeriesRegionScan( ((AlignedFullPath) fullPath).getSchemaList(), chunkMetaDataMap, memChunkHandleMap, - deletionList); + deletionList, + globalTimeFilter); } } @@ -539,7 +547,8 @@ public void queryForDeviceRegionScan( long ttlLowerBound, Map> chunkMetadataMap, Map> memChunkHandleMap, - List> modsToMemTabled) { + List> modsToMemTabled, + Filter globalTimeFilter) { Map memTableMap = getMemTableMap(); @@ -556,7 +565,8 @@ public void queryForDeviceRegionScan( chunkMetadataMap, memChunkHandleMap, ttlLowerBound, - modsToMemTabled); + modsToMemTabled, + globalTimeFilter); } else { getMemChunkHandleFromMemTable( deviceID, @@ -564,7 +574,8 @@ public void queryForDeviceRegionScan( chunkMetadataMap, memChunkHandleMap, ttlLowerBound, - modsToMemTabled); + modsToMemTabled, + globalTimeFilter); } } @@ -573,24 +584,30 @@ private void getMemChunkHandleFromMemTable( String measurementId, Map> chunkMetadataMap, Map> memChunkHandleMap, - List deletionList) { + List deletionList, + Filter globalTimeFilter) { WritableMemChunk memChunk = (WritableMemChunk) memTableMap.get(deviceID).getMemChunkMap().get(measurementId); - long[] timestamps = memChunk.getFilteredTimestamp(deletionList); + if (memChunk == null) { + return; + } + Optional anySatisfiedTimestamp = + memChunk.getAnySatisfiedTimestamp(deletionList, globalTimeFilter); + if (!anySatisfiedTimestamp.isPresent()) { + return; + } + long satisfiedTimestamp = anySatisfiedTimestamp.get(); chunkMetadataMap .computeIfAbsent(measurementId, k -> new ArrayList<>()) .add( - buildChunkMetaDataForMemoryChunk( - measurementId, - timestamps[0], - timestamps[timestamps.length - 1], - Collections.emptyList())); + buildFakeChunkMetaDataForFakeMemoryChunk( + measurementId, satisfiedTimestamp, satisfiedTimestamp, Collections.emptyList())); memChunkHandleMap .computeIfAbsent(measurementId, k -> new ArrayList<>()) - .add(new MemChunkHandleImpl(deviceID, measurementId, timestamps)); + .add(new MemChunkHandleImpl(deviceID, measurementId, new long[] {satisfiedTimestamp})); } private void getMemAlignedChunkHandleFromMemTable( @@ -598,7 +615,8 @@ private void getMemAlignedChunkHandleFromMemTable( List schemaList, Map> chunkMetadataList, Map> memChunkHandleMap, - List> deletionList) { + List> deletionList, + Filter globalTimeFilter) { AlignedWritableMemChunk alignedMemChunk = ((AlignedWritableMemChunkGroup) memTableMap.get(deviceID)).getAlignedMemChunk(); @@ -615,7 +633,11 @@ private void getMemAlignedChunkHandleFromMemTable( } List bitMaps = new ArrayList<>(); - long[] timestamps = alignedMemChunk.getFilteredTimestamp(deletionList, bitMaps, true); + long[] timestamps = + alignedMemChunk.getAnySatisfiedTimestamp(deletionList, bitMaps, true, globalTimeFilter); + if (timestamps.length == 0) { + return; + } buildAlignedMemChunkHandle( deviceID, @@ -633,7 +655,8 @@ private void getMemAlignedChunkHandleFromMemTable( Map> chunkMetadataList, Map> memChunkHandleMap, long ttlLowerBound, - List> modsToMemTabled) { + List> modsToMemTabled, + Filter globalTimeFilter) { AlignedWritableMemChunk memChunk = writableMemChunkGroup.getAlignedMemChunk(); List schemaList = memChunk.getSchemaList(); @@ -648,7 +671,11 @@ private void getMemAlignedChunkHandleFromMemTable( } List bitMaps = new ArrayList<>(); - long[] timestamps = memChunk.getFilteredTimestamp(deletionList, bitMaps, true); + long[] timestamps = + memChunk.getAnySatisfiedTimestamp(deletionList, bitMaps, true, globalTimeFilter); + if 
(timestamps.length == 0) { + return; + } buildAlignedMemChunkHandle( deviceID, timestamps, @@ -665,7 +692,8 @@ private void getMemChunkHandleFromMemTable( Map> chunkMetadataMap, Map> memChunkHandleMap, long ttlLowerBound, - List> modsToMemTabled) { + List> modsToMemTabled, + Filter globalTimeFilter) { for (Entry entry : writableMemChunkGroup.getMemChunkMap().entrySet()) { @@ -679,18 +707,20 @@ private void getMemChunkHandleFromMemTable( ModificationUtils.constructDeletionList( deviceID, measurementId, this, modsToMemTabled, ttlLowerBound); } - long[] timestamps = writableMemChunk.getFilteredTimestamp(deletionList); + Optional anySatisfiedTimestamp = + writableMemChunk.getAnySatisfiedTimestamp(deletionList, globalTimeFilter); + if (!anySatisfiedTimestamp.isPresent()) { + return; + } + long satisfiedTimestamp = anySatisfiedTimestamp.get(); chunkMetadataMap .computeIfAbsent(measurementId, k -> new ArrayList<>()) .add( - buildChunkMetaDataForMemoryChunk( - measurementId, - timestamps[0], - timestamps[timestamps.length - 1], - Collections.emptyList())); + buildFakeChunkMetaDataForFakeMemoryChunk( + measurementId, satisfiedTimestamp, satisfiedTimestamp, Collections.emptyList())); memChunkHandleMap .computeIfAbsent(measurementId, k -> new ArrayList<>()) - .add(new MemChunkHandleImpl(deviceID, measurementId, timestamps)); + .add(new MemChunkHandleImpl(deviceID, measurementId, new long[] {satisfiedTimestamp})); } } @@ -714,7 +744,7 @@ private void buildAlignedMemChunkHandle( chunkMetadataList .computeIfAbsent(measurement, k -> new ArrayList<>()) .add( - buildChunkMetaDataForMemoryChunk( + buildFakeChunkMetaDataForFakeMemoryChunk( measurement, startEndTime[0], startEndTime[1], deletion)); chunkHandleMap .computeIfAbsent(measurement, k -> new ArrayList<>()) @@ -745,7 +775,7 @@ private long[] calculateStartEndTime(long[] timestamps, List bitMaps, in return new long[] {startTime, endTime}; } - private IChunkMetadata buildChunkMetaDataForMemoryChunk( + private IChunkMetadata buildFakeChunkMetaDataForFakeMemoryChunk( String measurement, long startTime, long endTime, List deletionList) { TimeStatistics timeStatistics = new TimeStatistics(); timeStatistics.setStartTime(startTime); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AlignedWritableMemChunk.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AlignedWritableMemChunk.java index 5f6af57746fe..f42867d4772c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AlignedWritableMemChunk.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/AlignedWritableMemChunk.java @@ -34,6 +34,7 @@ import org.apache.tsfile.encrypt.EncryptUtils; import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.read.common.TimeRange; +import org.apache.tsfile.read.filter.basic.Filter; import org.apache.tsfile.utils.Binary; import org.apache.tsfile.utils.BitMap; import org.apache.tsfile.utils.Pair; @@ -55,6 +56,7 @@ import java.util.Set; import java.util.TreeMap; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; import static org.apache.iotdb.db.utils.ModificationUtils.isPointDeleted; @@ -287,30 +289,99 @@ private Pair checkAndReorderColumnValuesInInsertPlan( return new Pair<>(reorderedColumnValues, reorderedBitMaps); } - private void filterDeletedTimeStamp( + public long[] getAnySatisfiedTimestamp( + List> deletionList, + List bitMaps, + 
boolean ignoreAllNullRows, + Filter globalTimeFilter) { + BitMap columnHasNonNullValue = new BitMap(schemaList.size()); + AtomicInteger hasNonNullValueColumnCount = new AtomicInteger(0); + Map timestampWithBitmap = new TreeMap<>(); + + getAnySatisfiedTimestamp( + list, + deletionList, + ignoreAllNullRows, + timestampWithBitmap, + globalTimeFilter, + columnHasNonNullValue, + hasNonNullValueColumnCount); + for (int i = 0; + i < sortedList.size() && hasNonNullValueColumnCount.get() < schemaList.size(); + i++) { + if (!ignoreAllNullRows && !timestampWithBitmap.isEmpty()) { + // count devices in table model + break; + } + getAnySatisfiedTimestamp( + sortedList.get(i), + deletionList, + ignoreAllNullRows, + timestampWithBitmap, + globalTimeFilter, + columnHasNonNullValue, + hasNonNullValueColumnCount); + } + + long[] timestamps = new long[timestampWithBitmap.size()]; + int idx = 0; + for (Map.Entry entry : timestampWithBitmap.entrySet()) { + timestamps[idx++] = entry.getKey(); + bitMaps.add(entry.getValue()); + } + return timestamps; + } + + private void getAnySatisfiedTimestamp( AlignedTVList alignedTVList, List> valueColumnsDeletionList, boolean ignoreAllNullRows, - Map timestampWithBitmap) { + Map timestampWithBitmap, + Filter globalTimeFilter, + BitMap columnHasNonNullValue, + AtomicInteger hasNonNullValueColumnCount) { + if (globalTimeFilter != null + && !globalTimeFilter.satisfyStartEndTime( + alignedTVList.getMinTime(), alignedTVList.getMaxTime())) { + return; + } BitMap allValueColDeletedMap = alignedTVList.getAllValueColDeletedMap(); - int rowCount = alignedTVList.rowCount(); List valueColumnDeleteCursor = new ArrayList<>(); if (valueColumnsDeletionList != null) { valueColumnsDeletionList.forEach(x -> valueColumnDeleteCursor.add(new int[] {0})); } + // example: + // globalTimeFilter:null, ignoreAllNullRows: true + // tvList: + // time s1 s2 s3 + // 1 1 null null + // 2 null 1 null + // 2 1 1 null + // 3 1 null null + // 4 1 null 1 + // timestampWithBitmap: + // timestamp: 1 bitmap: 011 + // timestamp: 2 bitmap: 101 + // timestamp: 4 bitmap: 110 for (int row = 0; row < rowCount; row++) { // the row is deleted if (allValueColDeletedMap != null && allValueColDeletedMap.isMarked(row)) { continue; } long timestamp = alignedTVList.getTime(row); + if (globalTimeFilter != null && !globalTimeFilter.satisfy(timestamp, null)) { + continue; + } + + // Note that this method will only perform bitmap unmarking on the first occurrence of a + // non-null value in multiple timestamps for the same column. + BitMap currentRowNullValueBitmap = null; - BitMap bitMap = new BitMap(schemaList.size()); for (int column = 0; column < schemaList.size(); column++) { if (alignedTVList.isNullValue(alignedTVList.getValueIndex(row), column)) { - bitMap.mark(column); + continue; } // skip deleted row @@ -320,33 +391,44 @@ && isPointDeleted( timestamp, valueColumnsDeletionList.get(column), valueColumnDeleteCursor.get(column))) { - bitMap.mark(column); - } - - // skip all-null row - if (ignoreAllNullRows && bitMap.isAllMarked()) { continue; } - timestampWithBitmap.put(timestamp, bitMap); + if (!columnHasNonNullValue.isMarked(column)) { + hasNonNullValueColumnCount.incrementAndGet(); + columnHasNonNullValue.mark(column); + currentRowNullValueBitmap = + currentRowNullValueBitmap != null + ? 
currentRowNullValueBitmap + : timestampWithBitmap.computeIfAbsent( + timestamp, k -> getAllMarkedBitmap(schemaList.size())); + currentRowNullValueBitmap.unmark(column); + } } - } - } - public long[] getFilteredTimestamp( - List> deletionList, List bitMaps, boolean ignoreAllNullRows) { - Map timestampWithBitmap = new TreeMap<>(); + if (!ignoreAllNullRows) { + timestampWithBitmap.put( + timestamp, + currentRowNullValueBitmap != null + ? currentRowNullValueBitmap + : getAllMarkedBitmap(schemaList.size())); + return; + } + if (currentRowNullValueBitmap == null) { + continue; + } + // found new column with non-null value + timestampWithBitmap.put(timestamp, currentRowNullValueBitmap); - filterDeletedTimeStamp(list, deletionList, ignoreAllNullRows, timestampWithBitmap); - for (AlignedTVList alignedTVList : sortedList) { - filterDeletedTimeStamp(alignedTVList, deletionList, ignoreAllNullRows, timestampWithBitmap); + if (hasNonNullValueColumnCount.get() == schemaList.size()) { + return; + } } + } - List filteredTimestamps = new ArrayList<>(); - for (Map.Entry entry : timestampWithBitmap.entrySet()) { - filteredTimestamps.add(entry.getKey()); - bitMaps.add(entry.getValue()); - } - return filteredTimestamps.stream().mapToLong(Long::valueOf).toArray(); + private BitMap getAllMarkedBitmap(int size) { + BitMap bitMap = new BitMap(size); + bitMap.markAll(); + return bitMap; } @Override diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/IMemTable.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/IMemTable.java index 6c6e09c57825..fd9ffe90b0ae 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/IMemTable.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/IMemTable.java @@ -125,7 +125,8 @@ void queryForSeriesRegionScan( long ttlLowerBound, Map> chunkMetadataMap, Map> memChunkHandleMap, - List> modsToMemtabled) + List> modsToMemtabled, + Filter globalTimeFilter) throws IOException, QueryProcessException, MetadataException; void queryForDeviceRegionScan( @@ -134,7 +135,8 @@ void queryForDeviceRegionScan( long ttlLowerBound, Map> chunkMetadataMap, Map> memChunkHandleMap, - List> modsToMemtabled) + List> modsToMemtabled, + Filter globalTimeFilter) throws IOException, QueryProcessException, MetadataException; /** putBack all the memory resources. 
*/ diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/TsFileProcessor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/TsFileProcessor.java index 68776c92c0b5..36cbb4e0f880 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/TsFileProcessor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/TsFileProcessor.java @@ -1974,7 +1974,8 @@ private List getAlignedVisibleMetadataListFromWriterByDeviceID( public void queryForSeriesRegionScanWithoutLock( List pathList, QueryContext queryContext, - List fileScanHandlesForQuery) { + List fileScanHandlesForQuery, + Filter globalTimeFilter) { long startTime = System.nanoTime(); try { Map>> deviceToMemChunkHandleMap = new HashMap<>(); @@ -1995,7 +1996,8 @@ public void queryForSeriesRegionScanWithoutLock( timeLowerBound, measurementToChunkMetaList, measurementToChunkHandleList, - modsToMemtable); + modsToMemtable, + globalTimeFilter); } if (workMemTable != null) { workMemTable.queryForSeriesRegionScan( @@ -2003,7 +2005,8 @@ public void queryForSeriesRegionScanWithoutLock( timeLowerBound, measurementToChunkMetaList, measurementToChunkHandleList, - null); + null, + globalTimeFilter); } IDeviceID deviceID = seriesPath.getDeviceId(); // Some memTable have been flushed already, so we need to get the chunk metadata from @@ -2054,7 +2057,8 @@ public void queryForSeriesRegionScanWithoutLock( public void queryForDeviceRegionScanWithoutLock( Map devicePathsToContext, QueryContext queryContext, - List fileScanHandlesForQuery) { + List fileScanHandlesForQuery, + Filter globalTimeFilter) { long startTime = System.nanoTime(); try { Map>> deviceToMemChunkHandleMap = new HashMap<>(); @@ -2077,7 +2081,8 @@ public void queryForDeviceRegionScanWithoutLock( timeLowerBound, measurementToChunkMetadataList, measurementToMemChunkHandleList, - modsToMemtable); + modsToMemtable, + globalTimeFilter); } if (workMemTable != null) { workMemTable.queryForDeviceRegionScan( @@ -2086,7 +2091,8 @@ public void queryForDeviceRegionScanWithoutLock( timeLowerBound, measurementToChunkMetadataList, measurementToMemChunkHandleList, - null); + null, + globalTimeFilter); } buildChunkHandleForFlushedMemTable( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunk.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunk.java index c4a871a29cc1..6499d33fcd9f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunk.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunk.java @@ -35,6 +35,7 @@ import org.apache.tsfile.enums.TSDataType; import org.apache.tsfile.read.TimeValuePair; import org.apache.tsfile.read.common.TimeRange; +import org.apache.tsfile.read.filter.basic.Filter; import org.apache.tsfile.utils.Binary; import org.apache.tsfile.utils.BitMap; import org.apache.tsfile.write.UnSupportedDataTypeException; @@ -48,10 +49,9 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; +import java.util.Optional; import java.util.concurrent.BlockingQueue; -import java.util.stream.Collectors; import static org.apache.iotdb.db.utils.MemUtils.getBinarySize; @@ -575,11 +575,30 @@ public List getSortedList() 
{ return sortedList; } - private void filterDeletedTimestamp( - TVList tvlist, List deletionList, List timestampList) { - long lastTime = Long.MIN_VALUE; + public Optional getAnySatisfiedTimestamp( + List deletionList, Filter globalTimeFilter) { + Optional anySatisfiedTimestamp = + getAnySatisfiedTimestamp(list, deletionList, globalTimeFilter); + if (anySatisfiedTimestamp.isPresent()) { + return anySatisfiedTimestamp; + } + for (TVList tvList : sortedList) { + anySatisfiedTimestamp = getAnySatisfiedTimestamp(tvList, deletionList, globalTimeFilter); + if (anySatisfiedTimestamp.isPresent()) { + break; + } + } + return anySatisfiedTimestamp; + } + + private Optional getAnySatisfiedTimestamp( + TVList tvlist, List deletionList, Filter globalTimeFilter) { int[] deletionCursor = {0}; int rowCount = tvlist.rowCount(); + if (globalTimeFilter != null + && !globalTimeFilter.satisfyStartEndTime(tvlist.getMinTime(), tvlist.getMaxTime())) { + return Optional.empty(); + } for (int i = 0; i < rowCount; i++) { if (tvlist.getBitMap() != null && tvlist.isNullValue(tvlist.getValueIndex(i))) { continue; @@ -589,27 +608,12 @@ private void filterDeletedTimestamp( && ModificationUtils.isPointDeleted(curTime, deletionList, deletionCursor)) { continue; } - - if (i == rowCount - 1 || curTime != lastTime) { - timestampList.add(curTime); + if (globalTimeFilter != null && !globalTimeFilter.satisfy(curTime, null)) { + continue; } - lastTime = curTime; + return Optional.of(curTime); } - } - - public long[] getFilteredTimestamp(List deletionList) { - List timestampList = new ArrayList<>(); - filterDeletedTimestamp(list, deletionList, timestampList); - for (TVList tvList : sortedList) { - filterDeletedTimestamp(tvList, deletionList, timestampList); - } - - // remove duplicated time - List distinctTimestamps = timestampList.stream().distinct().collect(Collectors.toList()); - // sort timestamps - long[] filteredTimestamps = distinctTimestamps.stream().mapToLong(Long::longValue).toArray(); - Arrays.sort(filteredTimestamps); - return filteredTimestamps; + return Optional.empty(); } @Override diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunkRegionScanTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunkRegionScanTest.java new file mode 100644 index 000000000000..3967409db52d --- /dev/null +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/storageengine/dataregion/memtable/WritableMemChunkRegionScanTest.java @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iotdb.db.storageengine.dataregion.memtable; + +import org.apache.iotdb.commons.path.AlignedFullPath; +import org.apache.iotdb.commons.path.NonAlignedFullPath; +import org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.storageengine.dataregion.read.filescan.IChunkHandle; + +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.file.metadata.StringArrayDeviceID; +import org.apache.tsfile.read.common.TimeRange; +import org.apache.tsfile.read.filter.operator.TimeFilterOperators; +import org.apache.tsfile.utils.BitMap; +import org.apache.tsfile.write.schema.IMeasurementSchema; +import org.apache.tsfile.write.schema.MeasurementSchema; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +@RunWith(Parameterized.class) +public class WritableMemChunkRegionScanTest { + + @Parameterized.Parameters(name = "{0}") + public static Collection data() { + return Arrays.asList(new Object[][] {{0}, {1000}, {10000}, {20000}}); + } + + private int defaultTvListThreshold; + private int tvListSortThreshold; + + public WritableMemChunkRegionScanTest(int tvListSortThreshold) { + this.tvListSortThreshold = tvListSortThreshold; + } + + @Before + public void setup() { + defaultTvListThreshold = IoTDBDescriptor.getInstance().getConfig().getTvListSortThreshold(); + IoTDBDescriptor.getInstance().getConfig().setTVListSortThreshold(tvListSortThreshold); + } + + @After + public void tearDown() { + IoTDBDescriptor.getInstance().getConfig().setTVListSortThreshold(defaultTvListThreshold); + } + + @Test + public void testAlignedWritableMemChunkRegionScan() { + PrimitiveMemTable memTable = new PrimitiveMemTable("root.test", "0"); + try { + List measurementSchemas = + Arrays.asList( + new MeasurementSchema("s1", TSDataType.INT32), + new MeasurementSchema("s2", TSDataType.INT32), + new MeasurementSchema("s3", TSDataType.INT32)); + AlignedWritableMemChunk writableMemChunk = null; + int size = 100000; + for (int i = 0; i < size; i++) { + if (i <= 10000) { + memTable.writeAlignedRow( + new StringArrayDeviceID("root.test.d1"), + measurementSchemas, + i, + new Object[] {1, null, 1}); + } else if (i <= 20000) { + memTable.writeAlignedRow( + new StringArrayDeviceID("root.test.d1"), + measurementSchemas, + i, + new Object[] {null, null, 2}); + } else if (i <= 30000) { + memTable.writeAlignedRow( + new StringArrayDeviceID("root.test.d1"), + measurementSchemas, + i, + new Object[] {3, null, null}); + } else { + memTable.writeAlignedRow( + new StringArrayDeviceID("root.test.d1"), + measurementSchemas, + i, + new Object[] {4, 4, 4}); + } + } + writableMemChunk = + (AlignedWritableMemChunk) + memTable.getWritableMemChunk(new StringArrayDeviceID("root.test.d1"), ""); + List bitMaps = new ArrayList<>(); + long[] timestamps = + writableMemChunk.getAnySatisfiedTimestamp( + Arrays.asList( + Collections.emptyList(), Collections.emptyList(), Collections.emptyList()), + bitMaps, + true, + null); + Assert.assertEquals(2, timestamps.length); + Assert.assertEquals(0, timestamps[0]); + Assert.assertFalse(bitMaps.get(0).isMarked(0)); + Assert.assertTrue(bitMaps.get(0).isMarked(1)); + Assert.assertFalse(bitMaps.get(0).isMarked(2)); + 
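// A column is unmarked only in the bitmap of the timestamp where it first becomes non-null: s1 and s3 were already satisfied at t=0, so the bitmap for t=30001 (the first row with a non-null s2) keeps them marked. +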
Assert.assertTrue(bitMaps.get(1).isMarked(0)); + Assert.assertFalse(bitMaps.get(1).isMarked(1)); + Assert.assertTrue(bitMaps.get(1).isMarked(2)); + Assert.assertEquals(30001, timestamps[1]); + + bitMaps = new ArrayList<>(); + timestamps = + writableMemChunk.getAnySatisfiedTimestamp( + Arrays.asList( + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(new TimeRange(0, 12000))), + bitMaps, + true, + new TimeFilterOperators.TimeGt(10000000)); + Assert.assertEquals(0, timestamps.length); + + bitMaps = new ArrayList<>(); + timestamps = + writableMemChunk.getAnySatisfiedTimestamp( + Arrays.asList( + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(new TimeRange(0, 12000))), + bitMaps, + true, + new TimeFilterOperators.TimeGt(11000)); + + Assert.assertEquals(3, timestamps.length); + Assert.assertEquals(12001, timestamps[0]); + Assert.assertTrue(bitMaps.get(0).isMarked(0)); + Assert.assertTrue(bitMaps.get(0).isMarked(1)); + Assert.assertFalse(bitMaps.get(0).isMarked(2)); + Assert.assertEquals(20001, timestamps[1]); + Assert.assertFalse(bitMaps.get(1).isMarked(0)); + Assert.assertTrue(bitMaps.get(1).isMarked(1)); + Assert.assertTrue(bitMaps.get(1).isMarked(2)); + Assert.assertEquals(30001, timestamps[2]); + Assert.assertTrue(bitMaps.get(2).isMarked(0)); + Assert.assertFalse(bitMaps.get(2).isMarked(1)); + Assert.assertTrue(bitMaps.get(2).isMarked(2)); + + writableMemChunk.writeAlignedPoints( + 1000001, new Object[] {1, null, null}, measurementSchemas); + writableMemChunk.writeAlignedPoints( + 1000002, new Object[] {null, 1, null}, measurementSchemas); + writableMemChunk.writeAlignedPoints(1000002, new Object[] {1, 1, null}, measurementSchemas); + writableMemChunk.writeAlignedPoints( + 1000003, new Object[] {1, null, null}, measurementSchemas); + writableMemChunk.writeAlignedPoints(1000004, new Object[] {1, null, 1}, measurementSchemas); + bitMaps = new ArrayList<>(); + timestamps = + writableMemChunk.getAnySatisfiedTimestamp( + Arrays.asList( + Collections.emptyList(), Collections.emptyList(), Collections.emptyList()), + bitMaps, + true, + new TimeFilterOperators.TimeGt(1000000)); + Assert.assertEquals(3, timestamps.length); + Assert.assertEquals(1000001, timestamps[0]); + Assert.assertFalse(bitMaps.get(0).isMarked(0)); + Assert.assertTrue(bitMaps.get(0).isMarked(1)); + Assert.assertTrue(bitMaps.get(0).isMarked(2)); + Assert.assertEquals(1000002, timestamps[1]); + Assert.assertTrue(bitMaps.get(1).isMarked(0)); + Assert.assertFalse(bitMaps.get(1).isMarked(1)); + Assert.assertTrue(bitMaps.get(1).isMarked(2)); + Assert.assertEquals(1000004, timestamps[2]); + Assert.assertTrue(bitMaps.get(2).isMarked(0)); + Assert.assertTrue(bitMaps.get(2).isMarked(1)); + Assert.assertFalse(bitMaps.get(2).isMarked(2)); + + Map> chunkHandleMap = new HashMap<>(); + memTable.queryForDeviceRegionScan( + new StringArrayDeviceID("root.test.d1"), + true, + Long.MIN_VALUE, + new HashMap<>(), + chunkHandleMap, + Collections.emptyList(), + new TimeFilterOperators.TimeGt(1000000)); + Assert.assertEquals(3, chunkHandleMap.size()); + Assert.assertArrayEquals( + new long[] {1000001, 1000001}, chunkHandleMap.get("s1").get(0).getPageStatisticsTime()); + Assert.assertArrayEquals( + new long[] {1000002, 1000002}, chunkHandleMap.get("s2").get(0).getPageStatisticsTime()); + Assert.assertArrayEquals( + new long[] {1000004, 1000004}, chunkHandleMap.get("s3").get(0).getPageStatisticsTime()); + + memTable.queryForSeriesRegionScan( + new AlignedFullPath( + new 
StringArrayDeviceID("root.test.d1"), + IMeasurementSchema.getMeasurementNameList(measurementSchemas), + measurementSchemas), + Long.MIN_VALUE, + new HashMap<>(), + chunkHandleMap, + Collections.emptyList(), + new TimeFilterOperators.TimeGt(1000000)); + Assert.assertEquals(3, chunkHandleMap.size()); + Assert.assertArrayEquals( + new long[] {1000001, 1000001}, chunkHandleMap.get("s1").get(0).getPageStatisticsTime()); + Assert.assertArrayEquals( + new long[] {1000002, 1000002}, chunkHandleMap.get("s2").get(0).getPageStatisticsTime()); + Assert.assertArrayEquals( + new long[] {1000004, 1000004}, chunkHandleMap.get("s3").get(0).getPageStatisticsTime()); + } finally { + memTable.release(); + } + } + + @Test + public void testTableWritableMemChunkRegionScan() { + List measurementSchemas = + Arrays.asList( + new MeasurementSchema("s1", TSDataType.INT32), + new MeasurementSchema("s2", TSDataType.INT32), + new MeasurementSchema("s3", TSDataType.INT32)); + AlignedWritableMemChunk writableMemChunk = + new AlignedWritableMemChunk(measurementSchemas, true); + int size = 100000; + for (int i = 0; i < size; i++) { + if (i <= 10000) { + writableMemChunk.writeAlignedPoints(i, new Object[] {1, null, 1}, measurementSchemas); + } else if (i <= 20000) { + writableMemChunk.writeAlignedPoints(i, new Object[] {null, null, 2}, measurementSchemas); + } else if (i <= 30000) { + writableMemChunk.writeAlignedPoints(i, new Object[] {3, null, null}, measurementSchemas); + } else { + writableMemChunk.writeAlignedPoints(i, new Object[] {4, 4, 4}, measurementSchemas); + } + } + List bitMaps = new ArrayList<>(); + long[] timestamps = + writableMemChunk.getAnySatisfiedTimestamp( + Arrays.asList( + Collections.emptyList(), Collections.emptyList(), Collections.emptyList()), + bitMaps, + false, + null); + Assert.assertEquals(1, timestamps.length); + Assert.assertEquals(0, timestamps[0]); + Assert.assertFalse(bitMaps.get(0).isMarked(0)); + Assert.assertTrue(bitMaps.get(0).isMarked(1)); + Assert.assertFalse(bitMaps.get(0).isMarked(2)); + + bitMaps = new ArrayList<>(); + timestamps = + writableMemChunk.getAnySatisfiedTimestamp( + Arrays.asList( + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(new TimeRange(0, 12000))), + bitMaps, + false, + new TimeFilterOperators.TimeGt(11000)); + + Assert.assertEquals(1, timestamps.length); + Assert.assertEquals(11001, timestamps[0]); + Assert.assertTrue(bitMaps.get(0).isMarked(0)); + Assert.assertTrue(bitMaps.get(0).isMarked(1)); + Assert.assertTrue(bitMaps.get(0).isMarked(2)); + } + + @Test + public void testNonAlignedWritableMemChunkRegionScan() { + PrimitiveMemTable memTable = new PrimitiveMemTable("root.test", "0"); + try { + MeasurementSchema measurementSchema = new MeasurementSchema("s1", TSDataType.INT32); + int size = 100000; + for (int i = 0; i < size; i++) { + memTable.write( + new StringArrayDeviceID("root.test.d1"), + Collections.singletonList(measurementSchema), + i, + new Object[] {i}); + } + WritableMemChunk writableMemChunk = + (WritableMemChunk) + memTable.getWritableMemChunk(new StringArrayDeviceID("root.test.d1"), "s1"); + Optional timestamp = writableMemChunk.getAnySatisfiedTimestamp(null, null); + Assert.assertTrue(timestamp.isPresent()); + Assert.assertEquals(0, timestamp.get().longValue()); + + timestamp = + writableMemChunk.getAnySatisfiedTimestamp( + null, new TimeFilterOperators.TimeBetweenAnd(1000L, 2000L)); + Assert.assertTrue(timestamp.isPresent()); + Assert.assertEquals(1000, timestamp.get().longValue()); + + timestamp = + 
writableMemChunk.getAnySatisfiedTimestamp( + Collections.singletonList(new TimeRange(1, 1500)), + new TimeFilterOperators.TimeBetweenAnd(1000L, 2000L)); + Assert.assertTrue(timestamp.isPresent()); + Assert.assertEquals(1501, timestamp.get().longValue()); + + timestamp = + writableMemChunk.getAnySatisfiedTimestamp( + Collections.singletonList(new TimeRange(1, 1500)), + new TimeFilterOperators.TimeBetweenAnd(100000L, 200000L)); + Assert.assertFalse(timestamp.isPresent()); + + Map> chunkHandleMap = new HashMap<>(); + memTable.queryForDeviceRegionScan( + new StringArrayDeviceID("root.test.d1"), + false, + Long.MIN_VALUE, + new HashMap<>(), + chunkHandleMap, + Collections.emptyList(), + new TimeFilterOperators.TimeGt(1)); + Assert.assertEquals(1, chunkHandleMap.size()); + Assert.assertArrayEquals( + new long[] {2, 2}, chunkHandleMap.get("s1").get(0).getPageStatisticsTime()); + memTable.queryForSeriesRegionScan( + new NonAlignedFullPath(new StringArrayDeviceID("root.test.d1"), measurementSchema), + Long.MIN_VALUE, + new HashMap<>(), + chunkHandleMap, + Collections.emptyList(), + new TimeFilterOperators.TimeGt(1)); + Assert.assertEquals(1, chunkHandleMap.size()); + Assert.assertArrayEquals( + new long[] {2, 2}, chunkHandleMap.get("s1").get(0).getPageStatisticsTime()); + } finally { + memTable.release(); + } + } +} From a8903f478aca3fa7086e748281ab4961521645d0 Mon Sep 17 00:00:00 2001 From: Yongzao Date: Wed, 10 Dec 2025 09:26:22 +0800 Subject: [PATCH 07/13] [AINode] Refactoring of Model Storage, Loading, and Inference Pipeline (#16819) Co-authored-by: RkGrit Co-authored-by: Gewu <89496957+RkGrit@users.noreply.github.com> (cherry picked from commit d7898c44343c731c936b06aebfb23ca8172cf58b) --- .github/workflows/cluster-it-1c1d1a.yml | 3 - .../it/env/cluster/node/AINodeWrapper.java | 2 +- .../ainode/it/AINodeCallInferenceIT.java | 117 ++ .../ainode/it/AINodeConcurrentForecastIT.java | 113 ++ .../it/AINodeConcurrentInferenceIT.java | 187 --- .../iotdb/ainode/it/AINodeForecastIT.java | 98 ++ .../iotdb/ainode/it/AINodeInferenceSQLIT.java | 344 ----- .../ainode/it/AINodeInstanceManagementIT.java | 79 +- .../iotdb/ainode/it/AINodeModelManageIT.java | 53 +- .../iotdb/ainode/utils/AINodeTestUtils.java | 119 +- .../relational/it/schema/IoTDBDatabaseIT.java | 14 +- .../test/resources/ainode-example/config.yaml | 5 - .../test/resources/ainode-example/model.pt | Bin 1906 -> 0 bytes iotdb-core/ainode/ainode.spec | 138 +- iotdb-core/ainode/iotdb/ainode/core/config.py | 33 +- .../ainode/iotdb/ainode/core/constant.py | 173 +-- .../ainode/iotdb/ainode/core/exception.py | 2 +- .../core/inference/inference_request.py | 22 +- .../core/inference/inference_request_pool.py | 116 +- .../{strategy => pipeline}/__init__.py | 0 .../core/inference/pipeline/basic_pipeline.py | 87 ++ .../inference/pipeline/pipeline_loader.py | 56 + .../ainode/core/inference/pool_controller.py | 83 +- .../pool_scheduler/basic_pool_scheduler.py | 2 +- .../strategy/abstract_inference_pipeline.py | 60 - .../iotdb/ainode/core/inference/utils.py | 45 +- .../ainode/core/manager/inference_manager.py | 172 +-- .../ainode/core/manager/model_manager.py | 162 +-- .../ainode/iotdb/ainode/core/manager/utils.py | 7 +- .../core/model/built_in_model_factory.py | 1238 ----------------- .../ainode/core/model/model_constants.py | 48 + .../iotdb/ainode/core/model/model_enums.py | 70 - .../iotdb/ainode/core/model/model_factory.py | 291 ---- .../iotdb/ainode/core/model/model_info.py | 130 +- .../iotdb/ainode/core/model/model_loader.py | 156 +++ 
.../iotdb/ainode/core/model/model_storage.py | 808 ++++++----- .../model/{timerxl => sktime}/__init__.py | 0 .../core/model/sktime/arima/config.json | 25 + .../core/model/sktime/configuration_sktime.py | 409 ++++++ .../sktime/exponential_smoothing/config.json | 11 + .../model/sktime/gaussian_hmm/config.json | 22 + .../core/model/sktime/gmm_hmm/config.json | 24 + .../core/model/sktime/modeling_sktime.py | 180 +++ .../model/sktime/naive_forecaster/config.json | 9 + .../core/model/sktime/pipeline_sktime.py | 68 + .../model/sktime/stl_forecaster/config.json | 22 + .../core/model/sktime/stray/config.json | 11 + .../core/model/sundial/modeling_sundial.py | 21 +- .../sundial/pipeline_sundial.py} | 42 +- .../ainode/core/model/timer_xl/__init__.py | 17 + .../configuration_timer.py | 0 .../{timerxl => timer_xl}/modeling_timer.py | 15 +- .../timer_xl/pipeline_timer.py} | 36 +- .../ts_generation_mixin.py | 0 .../iotdb/ainode/core/model/uri_utils.py | 137 -- .../ainode/iotdb/ainode/core/model/utils.py | 98 ++ .../ainode/iotdb/ainode/core/rpc/client.py | 39 - .../ainode/iotdb/ainode/core/rpc/handler.py | 47 +- iotdb-core/ainode/pyproject.toml | 9 +- .../async/AsyncAINodeHeartbeatClientPool.java | 19 +- .../AsyncDataNodeHeartbeatClientPool.java | 1 - .../consensus/request/ConfigPhysicalPlan.java | 24 - .../request/read/model/GetModelInfoPlan.java | 64 - .../request/read/model/ShowModelPlan.java | 70 - .../request/write/model/CreateModelPlan.java | 79 -- .../write/model/DropModelInNodePlan.java | 70 - .../request/write/model/DropModelPlan.java | 79 -- .../write/model/UpdateModelInfoPlan.java | 122 -- .../response/model/GetModelInfoResp.java | 63 - .../response/model/ModelTableResp.java | 62 - .../confignode/manager/ConfigManager.java | 179 --- .../iotdb/confignode/manager/IManager.java | 42 - .../confignode/manager/ModelManager.java | 245 ---- .../confignode/manager/ProcedureManager.java | 20 - .../confignode/persistence/ModelInfo.java | 378 ----- .../executor/ConfigPlanExecutor.java | 25 - .../impl/model/CreateModelProcedure.java | 250 ---- .../impl/model/DropModelProcedure.java | 200 --- .../impl/node/RemoveAINodeProcedure.java | 17 +- .../procedure/state/RemoveAINodeState.java | 1 - .../procedure/store/ProcedureFactory.java | 12 - .../procedure/store/ProcedureType.java | 2 + .../thrift/ConfigNodeRPCServiceProcessor.java | 25 - .../protocol/client/AINodeClientFactory.java | 133 -- .../db/protocol/client/ConfigNodeClient.java | 30 +- .../client/DataNodeClientPoolFactory.java | 42 +- .../protocol/client/ainode/AINodeClient.java | 401 ------ .../client/ainode/AINodeClientManager.java | 75 - .../db/protocol/client/an/AINodeClient.java | 321 +++++ .../client/an/AINodeClientManager.java | 47 + .../process/ai/InferenceOperator.java | 82 +- ...formationSchemaContentSupplierFactory.java | 113 -- .../plan/analyze/AnalyzeVisitor.java | 134 +- .../plan/analyze/IModelFetcher.java | 4 - .../plan/analyze/ModelFetcher.java | 51 +- .../executor/ClusterConfigTaskExecutor.java | 43 +- .../plan/node/process/AI/InferenceNode.java | 3 +- .../model/ModelInferenceDescriptor.java | 61 +- .../analyzer/StatementAnalyzer.java | 6 - .../function/tvf/ForecastTableFunction.java | 39 +- .../plan/relational/metadata/Metadata.java | 6 - .../metadata/TableMetadataImpl.java | 5 - .../DataNodeLocationSupplierFactory.java | 1 - .../db/queryengine/plan/udf/UDTFForecast.java | 25 +- .../relational/analyzer/TSBSMetadata.java | 6 - .../analyzer/TableFunctionTest.java | 3 - .../relational/analyzer/TestMetadata.java | 19 - 
iotdb-core/node-commons/pom.xml | 5 + .../commons/client/ClientPoolFactory.java | 28 + .../AsyncAINodeInternalServiceClient.java} | 25 +- .../iotdb/commons/model/ModelInformation.java | 43 +- .../iotdb/commons/model/ModelTable.java | 4 +- .../schema/table/InformationSchema.java | 18 - .../src/main/thrift/ainode.thrift | 8 +- .../src/main/thrift/confignode.thrift | 63 - 115 files changed, 3205 insertions(+), 6963 deletions(-) create mode 100644 integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java create mode 100644 integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java delete mode 100644 integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java create mode 100644 integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java delete mode 100644 integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java delete mode 100644 integration-test/src/test/resources/ainode-example/config.yaml delete mode 100644 integration-test/src/test/resources/ainode-example/model.pt rename iotdb-core/ainode/iotdb/ainode/core/inference/{strategy => pipeline}/__init__.py (100%) create mode 100644 iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py rename iotdb-core/ainode/iotdb/ainode/core/model/{timerxl => sktime}/__init__.py (100%) create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json rename iotdb-core/ainode/iotdb/ainode/core/{inference/strategy/timer_sundial_inference_pipeline.py => model/sundial/pipeline_sundial.py} (56%) create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py rename iotdb-core/ainode/iotdb/ainode/core/model/{timerxl => timer_xl}/configuration_timer.py (100%) rename iotdb-core/ainode/iotdb/ainode/core/model/{timerxl => timer_xl}/modeling_timer.py (98%) rename iotdb-core/ainode/iotdb/ainode/core/{inference/strategy/timerxl_inference_pipeline.py => model/timer_xl/pipeline_timer.py} (52%) rename iotdb-core/ainode/iotdb/ainode/core/model/{timerxl => timer_xl}/ts_generation_mixin.py 
(100%) delete mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py create mode 100644 iotdb-core/ainode/iotdb/ainode/core/model/utils.py delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java delete mode 100644 iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java delete mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java delete mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java delete mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java create mode 100644 iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java rename iotdb-core/{datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java => node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java} (83%) diff --git a/.github/workflows/cluster-it-1c1d1a.yml b/.github/workflows/cluster-it-1c1d1a.yml index d4c40fa7ad88..67be8f1a5f6c 100644 --- a/.github/workflows/cluster-it-1c1d1a.yml +++ b/.github/workflows/cluster-it-1c1d1a.yml @@ -41,9 +41,6 @@ jobs: steps: - uses: actions/checkout@v4 - - name: Build AINode - shell: bash - run: mvn clean package -DskipTests -P with-ainode - name: IT Test shell: bash run: | diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java index e118d6c3a98f..34fd7e85240c 100644 --- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java +++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/node/AINodeWrapper.java @@ -59,7 +59,7 @@ public class AINodeWrapper extends AbstractNodeWrapper { private static final String PROPERTIES_FILE = "iotdb-ainode.properties"; public static final String CONFIG_PATH = "conf"; public static final String 
SCRIPT_PATH = "sbin"; - public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/weights"; + public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/builtin"; public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models/weights"; private void replaceAttribute(String[] keys, String[] values, String filePath) { diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java new file mode 100644 index 000000000000..44e280eca169 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.ainode.it; + +import org.apache.iotdb.ainode.utils.AINodeTestUtils; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; +import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; + +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) +public class AINodeCallInferenceIT { + + private static final String[] WRITE_SQL_IN_TREE = + new String[] { + "CREATE DATABASE root.AI", + "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE", + "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE", + }; + + private static final String CALL_INFERENCE_SQL_TEMPLATE = + "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)"; + private static final int DEFAULT_INPUT_LENGTH = 256; + private static final int DEFAULT_OUTPUT_LENGTH = 48; + + @BeforeClass + public static void setUp() throws Exception { + // Init 1C1D1A cluster environment + EnvFactory.getEnv().initClusterEnvironment(1, 1); + prepareData(WRITE_SQL_IN_TREE); + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + for (int i = 0; i < 2880; i++) { + statement.execute( + 
String.format( + "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", + i, (float) i, (double) i, i, i)); + } + } + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void callInferenceTest() throws SQLException { + for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + callInferenceTest(statement, modelInfo); + } + } + } + + public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) + throws SQLException { + // Invoke call inference for specified models, there should exist result. + for (int i = 0; i < 4; i++) { + String callInferenceSQL = + String.format( + CALL_INFERENCE_SQL_TEMPLATE, + modelInfo.getModelId(), + i, + DEFAULT_INPUT_LENGTH, + DEFAULT_OUTPUT_LENGTH); + try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "Time,output"); + int count = 0; + while (resultSet.next()) { + count++; + } + // Ensure the call inference return results + Assert.assertEquals(DEFAULT_OUTPUT_LENGTH, count); + } + } + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java new file mode 100644 index 000000000000..64029c1e34b8 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+      try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) {
+        ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
+        checkHeader(resultSetMetaData, "Time,output");
+        int count = 0;
+        while (resultSet.next()) {
+          count++;
+        }
+        // Ensure CALL INFERENCE returns the requested number of rows
+        Assert.assertEquals(DEFAULT_OUTPUT_LENGTH, count);
+      }
+    }
+  }
+}
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
new file mode 100644
index 000000000000..64029c1e34b8
--- /dev/null
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentForecastIT.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.ainode.it;
+
+import org.apache.iotdb.ainode.utils.AINodeTestUtils;
+import org.apache.iotdb.it.env.EnvFactory;
+import org.apache.iotdb.it.framework.IoTDBTestRunner;
+import org.apache.iotdb.itbase.category.AIClusterIT;
+import org.apache.iotdb.itbase.env.BaseEnv;
+
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.experimental.categories.Category;
+import org.junit.runner.RunWith;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.sql.Connection;
+import java.sql.SQLException;
+import java.sql.Statement;
+
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice;
+import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;
+
+@RunWith(IoTDBTestRunner.class)
+@Category({AIClusterIT.class})
+public class AINodeConcurrentForecastIT {
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);
+
+  private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
+      "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, forecast_length=>%d)";
+
+  @BeforeClass
+  public static void setUp() throws Exception {
+    // Init 1C1D1A cluster environment
+    EnvFactory.getEnv().initClusterEnvironment(1, 1);
+    prepareDataForTableModel();
+  }
+
+  @AfterClass
+  public static void tearDown() throws Exception {
+    EnvFactory.getEnv().cleanClusterEnvironment();
+  }
+
+  private static void prepareDataForTableModel() throws SQLException {
+    try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+        Statement statement = connection.createStatement()) {
+      statement.execute("CREATE DATABASE root");
+      statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)");
+      for (int i = 0; i < 2880; i++) {
+        statement.execute(
+            String.format(
+                "INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440)));
+      }
+    }
+  }
+
+  @Test
+  public void concurrentGPUForecastTest() throws SQLException, InterruptedException {
+    for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_LTSM_MAP.values()) {
+      concurrentGPUForecastTest(modelInfo);
+    }
+  }
+
+  public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo)
+      throws SQLException, InterruptedException {
+    final int forecastLength = 512;
+    try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+        Statement statement = connection.createStatement()) {
+      // Build the forecast request that each worker thread will execute concurrently
+      final String forecastSQL =
+          String.format(
+              FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), forecastLength);
+      final int threadCnt = 10;
+      final int loop = 100;
+      final String devices = "0,1";
+      statement.execute(
+          String.format("LOAD MODEL %s TO DEVICES '%s'", modelInfo.getModelId(), devices));
+      checkModelOnSpecifiedDevice(statement, modelInfo.getModelId(), devices);
+      long startTime = System.currentTimeMillis();
+      concurrentInference(statement, forecastSQL, threadCnt, loop, forecastLength);
+      long endTime = System.currentTimeMillis();
+      LOGGER.info(
+          String.format(
+              "Model %s concurrent inference %d reqs (%d threads, %d loops) on GPU took %dms",
+              modelInfo.getModelId(), threadCnt * loop, threadCnt, loop, endTime - startTime));
+      statement.execute(
String.format("UNLOAD MODEL %s FROM DEVICES '%s'", modelInfo.getModelId(), devices)); + checkModelNotOnSpecifiedDevice(statement, modelInfo.getModelId(), devices); + } + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java deleted file mode 100644 index a08990d472fe..000000000000 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeConcurrentInferenceIT.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.ainode.it; - -import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; -import org.apache.iotdb.itbase.env.BaseEnv; - -import com.google.common.collect.ImmutableSet; -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.HashSet; -import java.util.Set; -import java.util.concurrent.TimeUnit; - -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference; - -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) -public class AINodeConcurrentInferenceIT { - - private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentInferenceIT.class); - - @BeforeClass - public static void setUp() throws Exception { - // Init 1C1D1A cluster environment - EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareDataForTreeModel(); - prepareDataForTableModel(); - } - - @AfterClass - public static void tearDown() throws Exception { - EnvFactory.getEnv().cleanClusterEnvironment(); - } - - private static void prepareDataForTreeModel() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - statement.execute("CREATE DATABASE root.AI"); - statement.execute("CREATE TIMESERIES root.AI.s WITH DATATYPE=DOUBLE, ENCODING=RLE"); - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(timestamp, s) VALUES(%d, %f)", - i, Math.sin(i * Math.PI / 1440))); - } - } - } - - private static void prepareDataForTableModel() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - statement.execute("CREATE DATABASE 
root"); - statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)"); - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440))); - } - } - } - - // @Test - public void concurrentGPUCallInferenceTest() throws SQLException, InterruptedException { - concurrentGPUCallInferenceTest("timer_xl"); - concurrentGPUCallInferenceTest("sundial"); - } - - private void concurrentGPUCallInferenceTest(String modelId) - throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - final int threadCnt = 10; - final int loop = 100; - final int predictLength = 512; - final String devices = "0,1"; - statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices)); - checkModelOnSpecifiedDevice(statement, modelId, devices); - concurrentInference( - statement, - String.format( - "CALL INFERENCE(%s, 'SELECT s FROM root.AI', predict_length=%d)", - modelId, predictLength), - threadCnt, - loop, - predictLength); - statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId)); - } - } - - String forecastTableFunctionSql = - "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time), predict_length=>%d"; - String forecastUDTFSql = - "SELECT forecast(s, 'MODEL_ID'='%s', 'PREDICT_LENGTH'='%d') FROM root.AI"; - - @Test - public void concurrentGPUForecastTest() throws SQLException, InterruptedException { - concurrentGPUForecastTest("timer_xl", forecastUDTFSql); - concurrentGPUForecastTest("sundial", forecastUDTFSql); - concurrentGPUForecastTest("timer_xl", forecastTableFunctionSql); - concurrentGPUForecastTest("sundial", forecastTableFunctionSql); - } - - public void concurrentGPUForecastTest(String modelId, String selectSql) - throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - final int threadCnt = 10; - final int loop = 100; - final int predictLength = 512; - final String devices = "0,1"; - statement.execute(String.format("LOAD MODEL %s TO DEVICES '%s'", modelId, devices)); - checkModelOnSpecifiedDevice(statement, modelId, devices); - long startTime = System.currentTimeMillis(); - concurrentInference( - statement, - String.format(selectSql, modelId, predictLength), - threadCnt, - loop, - predictLength); - long endTime = System.currentTimeMillis(); - LOGGER.info( - String.format( - "Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms", - modelId, threadCnt * loop, threadCnt, loop, endTime - startTime)); - statement.execute(String.format("UNLOAD MODEL %s FROM DEVICES '0,1'", modelId)); - } - } - - private void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device) - throws SQLException, InterruptedException { - Set targetDevices = ImmutableSet.copyOf(device.split(",")); - LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices); - for (int retry = 0; retry < 200; retry++) { - Set foundDevices = new HashSet<>(); - try (final ResultSet resultSet = - statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) { - while (resultSet.next()) { - String deviceId = resultSet.getString("DeviceId"); - String loadedModelId = resultSet.getString("ModelId"); - int count = resultSet.getInt("Count(instances)"); 
- LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count); - if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) { - foundDevices.add(deviceId); - LOGGER.info("Model {} is loaded to device {}", modelId, device); - } - } - if (foundDevices.containsAll(targetDevices)) { - LOGGER.info("Model {} is loaded to devices {}, start testing", modelId, targetDevices); - return; - } - } - TimeUnit.SECONDS.sleep(3); - } - Assert.fail("Model " + modelId + " is not loaded on device " + device); - } -} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java new file mode 100644 index 000000000000..a06656d4adac --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.ainode.it; + +import org.apache.iotdb.ainode.utils.AINodeTestUtils; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; + +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) +public class AINodeForecastIT { + + private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = + "SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time, s%d FROM db.AI) ORDER BY time)"; + + @BeforeClass + public static void setUp() throws Exception { + // Init 1C1D1A cluster environment + EnvFactory.getEnv().initClusterEnvironment(1, 1); + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + statement.execute("CREATE DATABASE db"); + statement.execute( + "CREATE TABLE db.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)"); + for (int i = 0; i < 2880; i++) { + statement.execute( + String.format( + "INSERT INTO db.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", + i, (float) i, (double) i, i, i)); + } + } + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void forecastTableFunctionTest() throws SQLException { + for (AINodeTestUtils.FakeModelInfo 
modelInfo : BUILTIN_MODEL_MAP.values()) {
+      try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
+          Statement statement = connection.createStatement()) {
+        forecastTableFunctionTest(statement, modelInfo);
+      }
+    }
+  }
+
+  public void forecastTableFunctionTest(
+      Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException {
+    // Invoke the FORECAST table function with each built-in model; every call should return results.
+    for (int i = 0; i < 4; i++) {
+      String forecastTableFunctionSQL =
+          String.format(FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), i);
+      try (ResultSet resultSet = statement.executeQuery(forecastTableFunctionSQL)) {
+        int count = 0;
+        while (resultSet.next()) {
+          count++;
+        }
+        // Ensure the forecast returns at least one row
+        Assert.assertTrue(count > 0);
+      }
+    }
+  }
+}
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java
deleted file mode 100644
index 70f7a1d9f9eb..000000000000
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInferenceSQLIT.java
+++ /dev/null
@@ -1,344 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *      http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */ - -package org.apache.iotdb.ainode.it; - -import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; -import org.apache.iotdb.itbase.env.BaseEnv; - -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.SQLException; -import java.sql.Statement; - -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.EXAMPLE_MODEL_PATH; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; -import static org.apache.iotdb.db.it.utils.TestUtils.prepareData; -import static org.apache.iotdb.db.it.utils.TestUtils.prepareTableData; -import static org.junit.Assert.assertEquals; - -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) -public class AINodeInferenceSQLIT { - - static String[] WRITE_SQL_IN_TREE = - new String[] { - "set configuration \"trusted_uri_pattern\"='.*'", - "create model identity using uri \"" + EXAMPLE_MODEL_PATH + "\"", - "CREATE DATABASE root.AI", - "CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE", - "CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE", - }; - - static String[] WRITE_SQL_IN_TABLE = - new String[] { - "CREATE DATABASE root", - "CREATE TABLE root.AI (s0 FLOAT FIELD, s1 DOUBLE FIELD, s2 INT32 FIELD, s3 INT64 FIELD)", - }; - - @BeforeClass - public static void setUp() throws Exception { - // Init 1C1D1A cluster environment - EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareData(WRITE_SQL_IN_TREE); - prepareTableData(WRITE_SQL_IN_TABLE); - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", - i, (float) i, (double) i, i, i)); - } - } - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - for (int i = 0; i < 2880; i++) { - statement.execute( - String.format( - "INSERT INTO root.AI(time,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)", - i, (float) i, (double) i, i, i)); - } - } - } - - @AfterClass - public static void tearDown() throws Exception { - EnvFactory.getEnv().cleanClusterEnvironment(); - } - - // @Test - public void callInferenceTestInTree() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - callInferenceTest(statement); - } - } - - // TODO: Enable this test after the call inference is supported by the table model - // @Test - public void callInferenceTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - callInferenceTest(statement); - } - } - - public void callInferenceTest(Statement statement) throws SQLException { - // SQL0: Invoke timer-sundial and timer-xl to inference, the result should success - try 
(ResultSet resultSet = - statement.executeQuery( - "CALL INFERENCE(sundial, \"select s1 from root.AI\", generateTime=true, predict_length=720)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "Time,output0"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(720, count); - } - try (ResultSet resultSet = - statement.executeQuery( - "CALL INFERENCE(timer_xl, \"select s2 from root.AI\", generateTime=true, predict_length=256)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "Time,output0"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(256, count); - } - // SQL1: user-defined model inferences multi-columns with generateTime=true - String sql1 = - "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI\", generateTime=true)"; - // SQL2: user-defined model inferences multi-columns with generateTime=false - String sql2 = - "CALL INFERENCE(identity, \"select s2,s0,s3,s1 from root.AI\", generateTime=false)"; - // SQL3: built-in model inferences single column with given predict_length and multi-outputs - String sql3 = - "CALL INFERENCE(naive_forecaster, \"select s0 from root.AI\", predict_length=3, generateTime=true)"; - // SQL4: built-in model inferences single column with given predict_length - String sql4 = - "CALL INFERENCE(holtwinters, \"select s0 from root.AI\", predict_length=6, generateTime=true)"; - // TODO: enable following tests after refactor the CALL INFERENCE - - // try (ResultSet resultSet = statement.executeQuery(sql1)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "Time,output0,output1,output2,output3"); - // int count = 0; - // while (resultSet.next()) { - // float s0 = resultSet.getFloat(2); - // float s1 = resultSet.getFloat(3); - // float s2 = resultSet.getFloat(4); - // float s3 = resultSet.getFloat(5); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - // - // try (ResultSet resultSet = statement.executeQuery(sql2)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "output0,output1,output2"); - // int count = 0; - // while (resultSet.next()) { - // float s2 = resultSet.getFloat(1); - // float s0 = resultSet.getFloat(2); - // float s3 = resultSet.getFloat(3); - // float s1 = resultSet.getFloat(4); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql3)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "Time,output0,output1,output2"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(3, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql4)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "Time,output0"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(6, count); - // } - } - - // @Test - public void errorCallInferenceTestInTree() throws SQLException { - 
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - errorCallInferenceTest(statement); - } - } - - // TODO: Enable this test after the call inference is supported by the table model - // @Test - public void errorCallInferenceTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - errorCallInferenceTest(statement); - } - } - - public void errorCallInferenceTest(Statement statement) { - String sql = "CALL INFERENCE(notFound404, \"select s0,s1,s2 from root.AI\", window=head(5))"; - errorTest(statement, sql, "1505: model [notFound404] has not been created."); - sql = "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI\", window=head(2))"; - // TODO: enable following tests after refactor the CALL INFERENCE - // errorTest(statement, sql, "701: Window output 2 is not equal to input size of model 7"); - sql = "CALL INFERENCE(identity, \"select s0,s1,s2,s3 from root.AI limit 5\")"; - // errorTest( - // statement, - // sql, - // "301: The number of rows 5 in the input data does not match the model input 7. Try to - // use LIMIT in SQL or WINDOW in CALL INFERENCE"); - sql = "CREATE MODEL 中文 USING URI \"" + EXAMPLE_MODEL_PATH + "\""; - errorTest(statement, sql, "701: ModelId can only contain letters, numbers, and underscores"); - } - - @Test - public void selectForecastTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - // SQL0: Invoke timer-sundial and timer-xl to forecast, the result should success - try (ResultSet resultSet = - statement.executeQuery( - "SELECT * FROM FORECAST(model_id=>'sundial', input=>(SELECT time,s1 FROM root.AI) ORDER BY time, output_length=>720)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "time,s1"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(720, count); - } - try (ResultSet resultSet = - statement.executeQuery( - "SELECT * FROM FORECAST(model_id=>'timer_xl', input=>(SELECT time,s2 FROM root.AI) ORDER BY time, output_length=>256)")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "time,s2"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(256, count); - } - // SQL1: user-defined model inferences multi-columns with generateTime=true - String sql1 = - "SELECT * FROM FORECAST(model_id=>'identity', input=>(SELECT time,s0,s1,s2,s3 FROM root.AI) ORDER BY time, output_length=>7)"; - // SQL2: user-defined model inferences multi-columns with generateTime=false - String sql2 = - "SELECT * FROM FORECAST(model_id=>'identity', input=>(SELECT time,s2,s0,s3,s1 FROM root.AI) ORDER BY time, output_length=>7)"; - // SQL3: built-in model inferences single column with given predict_length and multi-outputs - String sql3 = - "SELECT * FROM FORECAST(model_id=>'naive_forecaster', input=>(SELECT time,s0 FROM root.AI) ORDER BY time, output_length=>3)"; - // SQL4: built-in model inferences single column with given predict_length - String sql4 = - "SELECT * FROM FORECAST(model_id=>'holtwinters', input=>(SELECT time,s0 FROM root.AI) ORDER BY time, output_length=>6)"; - // TODO: enable following tests after refactor the FORECAST - // try (ResultSet resultSet = 
statement.executeQuery(sql1)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s0,s1,s2,s3"); - // int count = 0; - // while (resultSet.next()) { - // float s0 = resultSet.getFloat(2); - // float s1 = resultSet.getFloat(3); - // float s2 = resultSet.getFloat(4); - // float s3 = resultSet.getFloat(5); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - // - // try (ResultSet resultSet = statement.executeQuery(sql2)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s2,s0,s3,s1"); - // int count = 0; - // while (resultSet.next()) { - // float s2 = resultSet.getFloat(1); - // float s0 = resultSet.getFloat(2); - // float s3 = resultSet.getFloat(3); - // float s1 = resultSet.getFloat(4); - // - // assertEquals(s0, count + 1.0, 0.0001); - // assertEquals(s1, count + 2.0, 0.0001); - // assertEquals(s2, count + 3.0, 0.0001); - // assertEquals(s3, count + 4.0, 0.0001); - // count++; - // } - // assertEquals(7, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql3)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s0,s1,s2"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(3, count); - // } - - // try (ResultSet resultSet = statement.executeQuery(sql4)) { - // ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - // checkHeader(resultSetMetaData, "time,s0"); - // int count = 0; - // while (resultSet.next()) { - // count++; - // } - // assertEquals(6, count); - // } - } - } -} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java index 93351c017852..2ae1b860cd23 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java @@ -35,14 +35,14 @@ import java.util.Arrays; import java.util.HashSet; import java.util.Set; -import java.util.concurrent.TimeUnit; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice; import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; public class AINodeInstanceManagementIT { - private static final int WAITING_TIME_SEC = 30; private static final Set TARGET_DEVICES = new HashSet<>(Arrays.asList("cpu", "0", "1")); @BeforeClass @@ -85,52 +85,18 @@ private void basicManagementTest(Statement statement) throws SQLException, Inter } // Load sundial to each device - statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\""); - TimeUnit.SECONDS.sleep(WAITING_TIME_SEC); - try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS 0")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)"); - while (resultSet.next()) { - Assert.assertEquals("0", resultSet.getString("DeviceID")); - Assert.assertEquals("Timer-Sundial", resultSet.getString("ModelType")); - 
Assert.assertTrue(resultSet.getInt("Count(instances)") > 1);
-      }
-    }
-    try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
-      ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
-      checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
-      final Set<String> resultDevices = new HashSet<>();
-      while (resultSet.next()) {
-        resultDevices.add(resultSet.getString("DeviceID"));
-      }
-      Assert.assertEquals(TARGET_DEVICES, resultDevices);
-    }
+    statement.execute(
+        String.format("LOAD MODEL sundial TO DEVICES '%s'", String.join(",", TARGET_DEVICES)));
+    checkModelOnSpecifiedDevice(statement, "sundial", String.join(",", TARGET_DEVICES));
 
     // Load timer_xl to each device
-    statement.execute("LOAD MODEL timer_xl TO DEVICES \"cpu,0,1\"");
-    TimeUnit.SECONDS.sleep(WAITING_TIME_SEC);
-    try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
-      ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
-      checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
-      final Set<String> resultDevices = new HashSet<>();
-      while (resultSet.next()) {
-        if (resultSet.getString("ModelType").equals("Timer-XL")) {
-          resultDevices.add(resultSet.getString("DeviceID"));
-        }
-        Assert.assertTrue(resultSet.getInt("Count(instances)") > 1);
-      }
-      Assert.assertEquals(TARGET_DEVICES, resultDevices);
-    }
+    statement.execute(
+        String.format("LOAD MODEL timer_xl TO DEVICES '%s'", String.join(",", TARGET_DEVICES)));
+    checkModelOnSpecifiedDevice(statement, "timer_xl", String.join(",", TARGET_DEVICES));
 
     // Clean every device
-    statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
-    statement.execute("UNLOAD MODEL timer_xl FROM DEVICES \"cpu,0,1\"");
-    TimeUnit.SECONDS.sleep(WAITING_TIME_SEC);
-    try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
-      ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
-      checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
-      Assert.assertFalse(resultSet.next());
-    }
+    statement.execute(
+        String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", String.join(",", TARGET_DEVICES)));
+    statement.execute(
+        String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", String.join(",", TARGET_DEVICES)));
+    checkModelNotOnSpecifiedDevice(statement, "timer_xl", String.join(",", TARGET_DEVICES));
+    checkModelNotOnSpecifiedDevice(statement, "sundial", String.join(",", TARGET_DEVICES));
   }
 
   private static final int LOOP_CNT = 10;
@@ -141,23 +107,9 @@ public void repeatLoadAndUnloadTest() throws SQLException, InterruptedException
       Statement statement = connection.createStatement()) {
       for (int i = 0; i < LOOP_CNT; i++) {
         statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\"");
-        TimeUnit.SECONDS.sleep(WAITING_TIME_SEC);
-        try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
-          ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
-          checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
-          final Set<String> resultDevices = new HashSet<>();
-          while (resultSet.next()) {
-            resultDevices.add(resultSet.getString("DeviceID"));
-          }
-          Assert.assertEquals(TARGET_DEVICES, resultDevices);
-        }
+        checkModelOnSpecifiedDevice(statement, "sundial", String.join(",", TARGET_DEVICES));
         statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
-        TimeUnit.SECONDS.sleep(WAITING_TIME_SEC);
-        try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
-          ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
-          checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
-          Assert.assertFalse(resultSet.next());
-        }
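+        // Unloading completes asynchronously; checkModelNotOnSpecifiedDevice polls
+        // SHOW LOADED MODELS instead of the old fixed WAITING_TIME_SEC sleep.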
+        checkModelNotOnSpecifiedDevice(statement, "sundial", String.join(",", TARGET_DEVICES));
       }
     }
   }
@@ -170,12 +122,7 @@ public void concurrentLoadAndUnloadTest() throws SQLException, InterruptedExcept
       statement.execute("LOAD MODEL sundial TO DEVICES \"cpu,0,1\"");
       statement.execute("UNLOAD MODEL sundial FROM DEVICES \"cpu,0,1\"");
     }
-    TimeUnit.SECONDS.sleep(WAITING_TIME_SEC * LOOP_CNT);
-    try (ResultSet resultSet = statement.executeQuery("SHOW LOADED MODELS")) {
-      ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
-      checkHeader(resultSetMetaData, "DeviceID,ModelType,Count(instances)");
-      Assert.assertFalse(resultSet.next());
-    }
+    checkModelNotOnSpecifiedDevice(statement, "sundial", String.join(",", TARGET_DEVICES));
   }
 }
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
index 2a1461e4a15b..b92b80aecf32 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java
@@ -19,6 +19,7 @@
 
 package org.apache.iotdb.ainode.it;
 
+import org.apache.iotdb.ainode.utils.AINodeTestUtils;
 import org.apache.iotdb.ainode.utils.AINodeTestUtils.FakeModelInfo;
 import org.apache.iotdb.it.env.EnvFactory;
 import org.apache.iotdb.it.framework.IoTDBTestRunner;
@@ -36,13 +37,8 @@
 import java.sql.ResultSetMetaData;
 import java.sql.SQLException;
 import java.sql.Statement;
-import java.util.AbstractMap;
-import java.util.Map;
 import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
 
-import static org.apache.iotdb.ainode.utils.AINodeTestUtils.EXAMPLE_MODEL_PATH;
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
 import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest;
 import static org.junit.Assert.assertEquals;
@@ -54,36 +50,6 @@
 @Category({AIClusterIT.class})
 public class AINodeModelManageIT {
 
-  private static final Map<String, FakeModelInfo> BUILT_IN_MODEL_MAP =
-      Stream.of(
-              new AbstractMap.SimpleEntry<>(
-                  "arima", new FakeModelInfo("arima", "Arima", "BUILT-IN", "ACTIVE")),
-              new AbstractMap.SimpleEntry<>(
-                  "holtwinters",
-                  new FakeModelInfo("holtwinters", "HoltWinters", "BUILT-IN", "ACTIVE")),
-              new AbstractMap.SimpleEntry<>(
-                  "exponential_smoothing",
-                  new FakeModelInfo(
-                      "exponential_smoothing", "ExponentialSmoothing", "BUILT-IN", "ACTIVE")),
-              new AbstractMap.SimpleEntry<>(
-                  "naive_forecaster",
-                  new FakeModelInfo("naive_forecaster", "NaiveForecaster", "BUILT-IN", "ACTIVE")),
-              new AbstractMap.SimpleEntry<>(
-                  "stl_forecaster",
-                  new FakeModelInfo("stl_forecaster", "StlForecaster", "BUILT-IN", "ACTIVE")),
-              new AbstractMap.SimpleEntry<>(
-                  "gaussian_hmm",
-                  new FakeModelInfo("gaussian_hmm", "GaussianHmm", "BUILT-IN", "ACTIVE")),
-              new AbstractMap.SimpleEntry<>(
-                  "gmm_hmm", new FakeModelInfo("gmm_hmm", "GmmHmm", "BUILT-IN", "ACTIVE")),
-              new AbstractMap.SimpleEntry<>(
-                  "stray", new FakeModelInfo("stray", "Stray", "BUILT-IN", "ACTIVE")),
-              new AbstractMap.SimpleEntry<>(
-                  "sundial", new FakeModelInfo("sundial", "Timer-Sundial", "BUILT-IN", "ACTIVE")),
-              new AbstractMap.SimpleEntry<>(
-                  "timer_xl", new FakeModelInfo("timer_xl", "Timer-XL", "BUILT-IN", "ACTIVE")))
-          .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
-
   @BeforeClass
   public static void setUp() throws Exception {
     // Init 1C1D1A cluster environment
@@ -95,7 +61,7 @@ public static void tearDown() throws Exception {
EnvFactory.getEnv().cleanClusterEnvironment(); } - @Test + // @Test public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { @@ -103,7 +69,7 @@ public void userDefinedModelManagementTestInTree() throws SQLException, Interrup } } - @Test + // @Test public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { @@ -114,8 +80,7 @@ public void userDefinedModelManagementTestInTable() throws SQLException, Interru private void userDefinedModelManagementTest(Statement statement) throws SQLException, InterruptedException { final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'"; - final String registerSql = - "create model operationTest using uri \"" + EXAMPLE_MODEL_PATH + "\""; + final String registerSql = "create model operationTest using uri \"" + "\""; final String showSql = "SHOW MODELS operationTest"; final String dropSql = "DROP MODEL operationTest"; @@ -166,7 +131,7 @@ private void userDefinedModelManagementTest(Statement statement) public void dropBuiltInModelErrorTestInTree() throws SQLException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { - errorTest(statement, "drop model sundial", "1501: Built-in model sundial can't be removed"); + errorTest(statement, "drop model sundial", "1510: Cannot delete built-in model: sundial"); } } @@ -174,7 +139,7 @@ public void dropBuiltInModelErrorTestInTree() throws SQLException { public void dropBuiltInModelErrorTestInTable() throws SQLException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { - errorTest(statement, "drop model sundial", "1501: Built-in model sundial can't be removed"); + errorTest(statement, "drop model sundial", "1510: Cannot delete built-in model: sundial"); } } @@ -208,10 +173,10 @@ private void showBuiltInModelTest(Statement statement) throws SQLException { resultSet.getString(2), resultSet.getString(3), resultSet.getString(4)); - assertTrue(BUILT_IN_MODEL_MAP.containsKey(modelInfo.getModelId())); - assertEquals(BUILT_IN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo); + assertTrue(AINodeTestUtils.BUILTIN_MODEL_MAP.containsKey(modelInfo.getModelId())); + assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo); } } - assertEquals(BUILT_IN_MODEL_MAP.size(), built_in_model_count); + assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.size(), built_in_model_count); } } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java index cbb0b03b2299..0de90c42925f 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java @@ -19,30 +19,68 @@ package org.apache.iotdb.ainode.utils; -import java.io.File; +import com.google.common.collect.ImmutableSet; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import java.sql.ResultSet; import java.sql.ResultSetMetaData; import java.sql.SQLException; 
 import java.sql.Statement;
+import java.util.AbstractMap;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Map;
 import java.util.Objects;
+import java.util.Set;
 import java.util.concurrent.TimeUnit;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
 
 public class AINodeTestUtils {
 
-  public static final String EXAMPLE_MODEL_PATH =
-      "file://"
-          + System.getProperty("user.dir")
-          + File.separator
-          + "src"
-          + File.separator
-          + "test"
-          + File.separator
-          + "resources"
-          + File.separator
-          + "ainode-example";
+  public static final Map<String, FakeModelInfo> BUILTIN_LTSM_MAP =
+      Stream.of(
+              new AbstractMap.SimpleEntry<>(
+                  "sundial", new FakeModelInfo("sundial", "sundial", "builtin", "active")),
+              new AbstractMap.SimpleEntry<>(
+                  "timer_xl", new FakeModelInfo("timer_xl", "timer", "builtin", "active")))
+          .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+
+  public static final Map<String, FakeModelInfo> BUILTIN_MODEL_MAP;
+
+  static {
+    Map<String, FakeModelInfo> tmp =
+        Stream.of(
+                new AbstractMap.SimpleEntry<>(
+                    "arima", new FakeModelInfo("arima", "sktime", "builtin", "active")),
+                new AbstractMap.SimpleEntry<>(
+                    "holtwinters", new FakeModelInfo("holtwinters", "sktime", "builtin", "active")),
+                new AbstractMap.SimpleEntry<>(
+                    "exponential_smoothing",
+                    new FakeModelInfo("exponential_smoothing", "sktime", "builtin", "active")),
+                new AbstractMap.SimpleEntry<>(
+                    "naive_forecaster",
+                    new FakeModelInfo("naive_forecaster", "sktime", "builtin", "active")),
+                new AbstractMap.SimpleEntry<>(
+                    "stl_forecaster",
+                    new FakeModelInfo("stl_forecaster", "sktime", "builtin", "active")),
+                new AbstractMap.SimpleEntry<>(
+                    "gaussian_hmm",
+                    new FakeModelInfo("gaussian_hmm", "sktime", "builtin", "active")),
+                new AbstractMap.SimpleEntry<>(
+                    "gmm_hmm", new FakeModelInfo("gmm_hmm", "sktime", "builtin", "active")),
+                new AbstractMap.SimpleEntry<>(
+                    "stray", new FakeModelInfo("stray", "sktime", "builtin", "active")))
+            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+    tmp.putAll(BUILTIN_LTSM_MAP);
+    BUILTIN_MODEL_MAP = Collections.unmodifiableMap(tmp);
+  }
+
+  private static final Logger LOGGER = LoggerFactory.getLogger(AINodeTestUtils.class);
 
   public static void checkHeader(ResultSetMetaData resultSetMetaData, String title)
       throws SQLException {
@@ -94,6 +132,63 @@ public static void concurrentInference(
     }
   }
 
+  public static void checkModelOnSpecifiedDevice(Statement statement, String modelId, String device)
+      throws SQLException, InterruptedException {
+    Set<String> targetDevices = ImmutableSet.copyOf(device.split(","));
+    LOGGER.info("Checking model: {} on target devices: {}", modelId, targetDevices);
+    for (int retry = 0; retry < 200; retry++) {
+      Set<String> foundDevices = new HashSet<>();
+      try (final ResultSet resultSet =
+          statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
+        while (resultSet.next()) {
+          String deviceId = resultSet.getString("DeviceId");
+          String loadedModelId = resultSet.getString("ModelId");
+          int count = resultSet.getInt("Count(instances)");
+          LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count);
+          if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) {
+            foundDevices.add(deviceId);
+            LOGGER.info("Model {} is loaded to device {}", modelId, deviceId);
+          }
+        }
+        if (foundDevices.containsAll(targetDevices)) {
+          LOGGER.info("Model {} is loaded to devices {}, start testing", modelId, targetDevices);
+          return;
+        }
+      }
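+      // Model loading happens in the background on the AINode side; retry for up
+      // to 200 * 3s = 10 minutes before failing the test.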
+      TimeUnit.SECONDS.sleep(3);
+    }
+    fail("Model " + modelId + " is not loaded on device " + device);
+  }
+
+  public static void checkModelNotOnSpecifiedDevice(
+      Statement statement, String modelId, String device)
+      throws SQLException, InterruptedException {
+    Set<String> targetDevices = ImmutableSet.copyOf(device.split(","));
+    LOGGER.info("Checking model: {} not on target devices: {}", modelId, targetDevices);
+    for (int retry = 0; retry < 50; retry++) {
+      Set<String> foundDevices = new HashSet<>();
+      try (final ResultSet resultSet =
+          statement.executeQuery(String.format("SHOW LOADED MODELS '%s'", device))) {
+        while (resultSet.next()) {
+          String deviceId = resultSet.getString("DeviceId");
+          String loadedModelId = resultSet.getString("ModelId");
+          int count = resultSet.getInt("Count(instances)");
+          LOGGER.info("Model {} found in device {}, count {}", loadedModelId, deviceId, count);
+          if (loadedModelId.equals(modelId) && targetDevices.contains(deviceId) && count > 0) {
+            foundDevices.add(deviceId);
+            LOGGER.info("Model {} is still loaded to device {}", modelId, deviceId);
+          }
+        }
+        if (foundDevices.isEmpty()) {
+          LOGGER.info("Model {} is unloaded from devices {}.", modelId, targetDevices);
+          return;
+        }
+      }
+      TimeUnit.SECONDS.sleep(3);
+    }
+    fail("Model " + modelId + " is still loaded on device " + device);
+  }
+
   public static class FakeModelInfo {
 
     private final String modelId;
diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
index e04ff838819e..e7bab16ad1ff 100644
--- a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java
@@ -403,7 +403,6 @@ public void testInformationSchema() throws SQLException {
                 "databases,INF,",
                 "functions,INF,",
                 "keywords,INF,",
-                "models,INF,",
                 "nodes,INF,",
                 "pipe_plugins,INF,",
                 "pipes,INF,",
@@ -504,16 +503,6 @@ public void testInformationSchema() throws SQLException {
                     "database,STRING,TAG,",
                     "table_name,STRING,TAG,",
                     "view_definition,STRING,ATTRIBUTE,")));
-      TestUtils.assertResultSetEqual(
-          statement.executeQuery("desc models"),
-          "ColumnName,DataType,Category,",
-          new HashSet<>(
-              Arrays.asList(
-                  "model_id,STRING,TAG,",
-                  "model_type,STRING,ATTRIBUTE,",
-                  "state,STRING,ATTRIBUTE,",
-                  "configs,STRING,ATTRIBUTE,",
-                  "notes,STRING,ATTRIBUTE,")));
       TestUtils.assertResultSetEqual(
           statement.executeQuery("desc functions"),
           "ColumnName,DataType,Category,",
@@ -638,7 +627,6 @@ public void testInformationSchema() throws SQLException {
               "information_schema,pipes,INF,USING,null,SYSTEM VIEW,",
               "information_schema,subscriptions,INF,USING,null,SYSTEM VIEW,",
               "information_schema,views,INF,USING,null,SYSTEM VIEW,",
-              "information_schema,models,INF,USING,null,SYSTEM VIEW,",
               "information_schema,functions,INF,USING,null,SYSTEM VIEW,",
               "information_schema,configurations,INF,USING,null,SYSTEM VIEW,",
               "information_schema,keywords,INF,USING,null,SYSTEM VIEW,",
@@ -651,7 +639,7 @@ public void testInformationSchema() throws SQLException {
       TestUtils.assertResultSetEqual(
           statement.executeQuery("count devices from tables where status = 'USING'"),
           "count(devices),",
-          Collections.singleton("20,"));
+          Collections.singleton("19,"));
       TestUtils.assertResultSetEqual(
           statement.executeQuery(
              "select * from columns where table_name = 'queries' or database = 'test'"),
diff --git a/integration-test/src/test/resources/ainode-example/config.yaml 
b/integration-test/src/test/resources/ainode-example/config.yaml deleted file mode 100644 index 665acb8704e2..000000000000 --- a/integration-test/src/test/resources/ainode-example/config.yaml +++ /dev/null @@ -1,5 +0,0 @@ -configs: - input_shape: [7, 4] - output_shape: [7, 4] - input_type: ["float32", "float32", "float32", "float32"] - output_type: ["float32", "float32", "float32", "float32"] diff --git a/integration-test/src/test/resources/ainode-example/model.pt b/integration-test/src/test/resources/ainode-example/model.pt deleted file mode 100644 index 67d4aec6999f1b677d7e71e2415ba3178f7f618b..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1906 zcmWIWW@cev;NW1u0D=sH44EmZc_o=8mHH`(C5d_k**R`bybMvupn)klKE5QsC^;iO zUJp#`<>l$+=BJeAq!#PtWagzN7IAq(jo~U}&}^*LhydAEQk0mPmzkGd$k-7f2IR+Q z7RRTR=H$dDB_?N=Cl;l|XXNK+7c%*kCKWR41$eV_IGE>6O9!e1;Q*ksMS#x6bhiB#-dViGfQ?(V%DD<11r@UFA?#QIlCwVffiWf>m5x#By; zQ&>r6;m-W5=so`@Z`yUBQ9fazXs5P_&@Un7ZR@1x`&G-h96P%IoF&KZ8*3iSdV8{2 zJa(p5;FmJZ6N~OPWNgcxadu0uml~Ve@@M8{`=2l}omY+CA*g;)if{S*Z1ZqU@u&P- zs+mE-=bLt7=Q3c>y@3WF_E@)JFd!rN^ioojO4H-P2}FmafNWrjkN`T!%|FQ3F(f|R zGsGi4I3&o^&pkfG(aFcPkRbvn%TUPJTFB(hkPJy+S(znz@dcU5**U3PNu`-NDe;+k zB{`YJC0vEf8nGIwB|+W{-VE)9EMN-AYAs}K2PdJ8ANSg30nGzpP!hr(24W0C$YGFI zT#}eqQVdD{d}zLFVA2GeoU6|n4m6GdgmIfJz+i_kxh%D)I5R)b&B+SQOnm7LUCx*b z6t@@WrH3*BZ3bc7whJNKo>WK%01K|K~Mo?hEOFrbnMGz!`0x-%!h;~E?gq*pI zP_%9b^5EKuE|1Wihn#S2P|W|vNRIi442y0PazX}`%Lwob7+^>~LCO~BW*{d=0fYfS zRtPha8PE)Xt{XWa38Cn|gsdB$fYJ3MN3%SN{xD$fg!=${;tTL*W7C0Zl4I6|YiEbD nV6+4{^)N8}0A+X}0O|uv2|yJ9V+AP23d#%&>_7-o4^ayM@w(>s diff --git a/iotdb-core/ainode/ainode.spec b/iotdb-core/ainode/ainode.spec index a131b2bcff21..bde8b845fb8c 100644 --- a/iotdb-core/ainode/ainode.spec +++ b/iotdb-core/ainode/ainode.spec @@ -41,14 +41,24 @@ all_hiddenimports = [] # Only collect essential data files and binaries for critical libraries # This reduces startup time by avoiding unnecessary module imports essential_libraries = { - 'torch': True, # Keep collect_all for torch as it has many dynamic imports - 'transformers': True, # Keep collect_all for transformers - 'safetensors': True, # Keep collect_all for safetensors + 'torch': True, + 'transformers': True, + 'tokenizers': True, + 'huggingface_hub': True, + 'safetensors': True, + 'hf_xet': True, + 'numpy': True, + 'scipy': True, + 'pandas': True, + 'sklearn': True, + 'statsmodels': True, + 'sktime': True, + 'pmdarima': True, + 'hmmlearn': True, + 'accelerate': True } -# For other libraries, use selective collection to speed up startup -other_libraries = ['sktime', 'scipy', 'pandas', 'sklearn', 'statsmodels', 'optuna'] - +# Collect all libraries using collect_all (includes data files and binaries) for lib in essential_libraries: try: lib_datas, lib_binaries, lib_hiddenimports = collect_all(lib) @@ -58,48 +68,105 @@ for lib in essential_libraries: except Exception: pass -# For other libraries, only collect submodules (lighter weight) -# This relies on PyInstaller's dependency analysis to include what's actually used -for lib in other_libraries: +# Additionally collect ALL submodules for libraries that commonly have dynamic imports +# This is a more aggressive approach but ensures we don't miss any modules +# Libraries that are known to have many dynamic imports and submodules +libraries_with_dynamic_imports = [ + 'scipy', # Has many subpackages: stats, interpolate, optimize, linalg, sparse, signal, etc. 
+ 'sklearn', # Has many submodules that may be dynamically imported + 'transformers', # Has dynamic model loading + 'torch', # Has many submodules, especially _dynamo.polyfills +] + +# Collect all submodules for these libraries to ensure comprehensive coverage +for lib in libraries_with_dynamic_imports: try: submodules = collect_submodules(lib) all_hiddenimports.extend(submodules) - # Only collect essential data files and binaries, not all submodules - # This significantly reduces startup time - try: - lib_datas, lib_binaries, _ = collect_all(lib) - all_datas.extend(lib_datas) - all_binaries.extend(lib_binaries) - except Exception: - # If collect_all fails, try collect_data_files for essential data only - try: - lib_datas = collect_data_files(lib) - all_datas.extend(lib_datas) - except Exception: - pass - except Exception: - pass + print(f"Collected {len(submodules)} submodules from {lib}") + except Exception as e: + print(f"Warning: Failed to collect submodules from {lib}: {e}") + + +# Helper function to collect submodules with fallback +def collect_submodules_with_fallback(package, fallback_modules=None, package_name=None): + """ + Collect all submodules for a package, with fallback to manual module list if collection fails. + + Args: + package: Package name to collect submodules from + fallback_modules: List of module names to add if collection fails (optional) + package_name: Display name for logging (defaults to package) + """ + if package_name is None: + package_name = package + try: + submodules = collect_submodules(package) + all_hiddenimports.extend(submodules) + print(f"Collected {len(submodules)} submodules from {package_name}") + except Exception as e: + print(f"Warning: Failed to collect {package_name} submodules: {e}") + if fallback_modules: + all_hiddenimports.extend(fallback_modules) + print(f"Using fallback modules for {package_name}") + + +# Additional specific packages that need submodule collection +# Note: scipy, sklearn, transformers, torch are already collected above via libraries_with_dynamic_imports +# This section is for more specific sub-packages that need special handling +# Format: (package_name, fallback_modules_list, display_name) +submodule_collection_configs = [ + # torch._dynamo.polyfills - critical for torch dynamo functionality + # (torch is already collected above, but this ensures polyfills are included) + ( + 'torch._dynamo.polyfills', + [ + 'torch._dynamo.polyfills', + 'torch._dynamo.polyfills.functools', + 'torch._dynamo.polyfills.operator', + 'torch._dynamo.polyfills.collections', + ], + 'torch._dynamo.polyfills' + ), + # transformers sub-packages with dynamic imports + # (transformers is already collected above, but these specific sub-packages may need extra attention) + ('transformers.generation', None, 'transformers.generation'), + ('transformers.models.auto', None, 'transformers.models.auto'), +] + +# Collect submodules for all configured packages +for package, fallback_modules, display_name in submodule_collection_configs: + collect_submodules_with_fallback(package, fallback_modules, display_name) # Project-specific packages that need their submodules collected # Only list top-level packages - collect_submodules will recursively collect all submodules -TOP_LEVEL_PACKAGES = [ +project_packages = [ 'iotdb.ainode.core', # This will include all sub-packages: manager, model, inference, etc. 
'iotdb.thrift', # This will include all thrift sub-packages ] # Collect all submodules for project packages automatically # Using top-level packages avoids duplicate collection -for package in TOP_LEVEL_PACKAGES: - try: - submodules = collect_submodules(package) - all_hiddenimports.extend(submodules) - except Exception: - # If package doesn't exist or collection fails, add the package itself - all_hiddenimports.append(package) +# If collection fails, add the package itself as fallback +for package in project_packages: + collect_submodules_with_fallback(package, fallback_modules=[package], package_name=package) # Add parent packages to ensure they are included all_hiddenimports.extend(['iotdb', 'iotdb.ainode']) +# Fix circular import issues in scipy.stats +# scipy.stats has circular imports that can cause issues in PyInstaller +# We need to ensure _stats is imported before scipy.stats tries to import it +# This helps resolve the "partially initialized module" error +scipy_stats_critical_modules = [ + 'scipy.stats._stats', # Core stats module, must be imported first + 'scipy.stats._stats_py', # Python implementation + 'scipy.stats._continuous_distns', # Continuous distributions + 'scipy.stats._discrete_distns', # Discrete distributions + 'scipy.stats.distributions', # Distribution base classes +] +all_hiddenimports.extend(scipy_stats_critical_modules) + # Multiprocessing support for PyInstaller # When using multiprocessing with PyInstaller, we need to ensure proper handling multiprocessing_modules = [ @@ -119,9 +186,6 @@ multiprocessing_modules = [ # Additional dependencies that may need explicit import # These are external libraries that might use dynamic imports external_dependencies = [ - 'huggingface_hub', - 'tokenizers', - 'hf_xet', 'einops', 'dynaconf', 'tzlocal', @@ -161,7 +225,9 @@ a = Analysis( win_no_prefer_redirects=False, win_private_assemblies=False, cipher=block_cipher, - noarchive=True, # Set to True to speed up startup - files are not archived into PYZ + noarchive=False, # Set to False to avoid circular import issues with scipy.stats + # When noarchive=True, modules are loaded as separate files which can cause + # circular import issues. Using PYZ archive helps PyInstaller handle module loading order better. 
) # Package all PYZ files diff --git a/iotdb-core/ainode/iotdb/ainode/core/config.py b/iotdb-core/ainode/iotdb/ainode/core/config.py index afcf0683d7d0..e465df7e36d2 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/config.py +++ b/iotdb-core/ainode/iotdb/ainode/core/config.py @@ -20,7 +20,6 @@ from iotdb.ainode.core.constant import ( AINODE_BUILD_INFO, - AINODE_BUILTIN_MODELS_DIR, AINODE_CLUSTER_INGRESS_ADDRESS, AINODE_CLUSTER_INGRESS_PASSWORD, AINODE_CLUSTER_INGRESS_PORT, @@ -33,10 +32,11 @@ AINODE_CONF_POM_FILE_NAME, AINODE_INFERENCE_BATCH_INTERVAL_IN_MS, AINODE_INFERENCE_EXTRA_MEMORY_RATIO, - AINODE_INFERENCE_MAX_PREDICT_LENGTH, + AINODE_INFERENCE_MAX_OUTPUT_LENGTH, AINODE_INFERENCE_MEMORY_USAGE_RATIO, AINODE_INFERENCE_MODEL_MEM_USAGE_MAP, AINODE_LOG_DIR, + AINODE_MODELS_BUILTIN_DIR, AINODE_MODELS_DIR, AINODE_RPC_ADDRESS, AINODE_RPC_PORT, @@ -75,9 +75,7 @@ def __init__(self): self._ain_inference_batch_interval_in_ms: int = ( AINODE_INFERENCE_BATCH_INTERVAL_IN_MS ) - self._ain_inference_max_predict_length: int = ( - AINODE_INFERENCE_MAX_PREDICT_LENGTH - ) + self._ain_inference_max_output_length: int = AINODE_INFERENCE_MAX_OUTPUT_LENGTH self._ain_inference_model_mem_usage_map: dict[str, int] = ( AINODE_INFERENCE_MODEL_MEM_USAGE_MAP ) @@ -95,7 +93,7 @@ def __init__(self): # Directory to save models self._ain_models_dir = AINODE_MODELS_DIR - self._ain_builtin_models_dir = AINODE_BUILTIN_MODELS_DIR + self._ain_models_builtin_dir = AINODE_MODELS_BUILTIN_DIR self._ain_system_dir = AINODE_SYSTEM_DIR # Whether to enable compression for thrift @@ -160,13 +158,13 @@ def set_ain_inference_batch_interval_in_ms( ) -> None: self._ain_inference_batch_interval_in_ms = ain_inference_batch_interval_in_ms - def get_ain_inference_max_predict_length(self) -> int: - return self._ain_inference_max_predict_length + def get_ain_inference_max_output_length(self) -> int: + return self._ain_inference_max_output_length - def set_ain_inference_max_predict_length( - self, ain_inference_max_predict_length: int + def set_ain_inference_max_output_length( + self, ain_inference_max_output_length: int ) -> None: - self._ain_inference_max_predict_length = ain_inference_max_predict_length + self._ain_inference_max_output_length = ain_inference_max_output_length def get_ain_inference_model_mem_usage_map(self) -> dict[str, int]: return self._ain_inference_model_mem_usage_map @@ -204,11 +202,11 @@ def get_ain_models_dir(self) -> str: def set_ain_models_dir(self, ain_models_dir: str) -> None: self._ain_models_dir = ain_models_dir - def get_ain_builtin_models_dir(self) -> str: - return self._ain_builtin_models_dir + def get_ain_models_builtin_dir(self) -> str: + return self._ain_models_builtin_dir - def set_ain_builtin_models_dir(self, ain_builtin_models_dir: str) -> None: - self._ain_builtin_models_dir = ain_builtin_models_dir + def set_ain_models_builtin_dir(self, ain_models_builtin_dir: str) -> None: + self._ain_models_builtin_dir = ain_models_builtin_dir def get_ain_system_dir(self) -> str: return self._ain_system_dir @@ -374,6 +372,11 @@ def _load_config_from_file(self) -> None: if "ain_models_dir" in config_keys: self._config.set_ain_models_dir(file_configs["ain_models_dir"]) + if "ain_models_builtin_dir" in config_keys: + self._config.set_ain_models_builtin_dir( + file_configs["ain_models_builtin_dir"] + ) + if "ain_system_dir" in config_keys: self._config.set_ain_system_dir(file_configs["ain_system_dir"]) diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py index 
b9923d3e3ee7..d8f730c829c8 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/constant.py +++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py @@ -18,9 +18,7 @@ import logging import os from enum import Enum -from typing import List -from iotdb.ainode.core.model.model_enums import BuiltInModelType from iotdb.thrift.common.ttypes import TEndPoint IOTDB_AINODE_HOME = os.getenv("IOTDB_AINODE_HOME", "") @@ -50,21 +48,21 @@ # AINode inference configuration AINODE_INFERENCE_BATCH_INTERVAL_IN_MS = 15 -AINODE_INFERENCE_MAX_PREDICT_LENGTH = 2880 +AINODE_INFERENCE_MAX_OUTPUT_LENGTH = 2880 + +# TODO: Should be optimized AINODE_INFERENCE_MODEL_MEM_USAGE_MAP = { - BuiltInModelType.SUNDIAL.value: 1036 * 1024**2, # 1036 MiB - BuiltInModelType.TIMER_XL.value: 856 * 1024**2, # 856 MiB + "sundial": 1036 * 1024**2, # 1036 MiB + "timer": 856 * 1024**2, # 856 MiB } # the memory usage of each model in bytes + AINODE_INFERENCE_MEMORY_USAGE_RATIO = 0.4 # the device space allocated for inference AINODE_INFERENCE_EXTRA_MEMORY_RATIO = ( 1.2 # the overhead ratio for inference, used to estimate the pool size ) -# AINode folder structure AINODE_MODELS_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/models") -AINODE_BUILTIN_MODELS_DIR = os.path.join( - IOTDB_AINODE_HOME, "data/ainode/models/weights" -) # For built-in models, we only need to store their weights and config. +AINODE_MODELS_BUILTIN_DIR = "iotdb.ainode.core.model" AINODE_SYSTEM_DIR = os.path.join(IOTDB_AINODE_HOME, "data/ainode/system") AINODE_LOG_DIR = os.path.join(IOTDB_AINODE_HOME, "logs") @@ -77,11 +75,6 @@ "log_inference_rank_{}_" # example: log_inference_rank_0_all.log ) -# AINode model management -MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors" -MODEL_CONFIG_FILE_IN_JSON = "config.json" -MODEL_WEIGHTS_FILE_IN_PT = "model.pt" -MODEL_CONFIG_FILE_IN_YAML = "config.yaml" DEFAULT_CHUNK_SIZE = 8192 @@ -100,27 +93,6 @@ def get_status_code(self) -> int: return self.value -class TaskType(Enum): - FORECAST = "forecast" - - -class OptionsKey(Enum): - # common - TASK_TYPE = "task_type" - MODEL_TYPE = "model_type" - AUTO_TUNING = "auto_tuning" - INPUT_VARS = "input_vars" - - # forecast - INPUT_LENGTH = "input_length" - PREDICT_LENGTH = "predict_length" - PREDICT_INDEX_LIST = "predict_index_list" - INPUT_TYPE_LIST = "input_type_list" - - def name(self) -> str: - return self.value - - class HyperparameterName(Enum): # Training hyperparameter LEARNING_RATE = "learning_rate" @@ -139,134 +111,3 @@ class HyperparameterName(Enum): def name(self): return self.value - - -class ForecastModelType(Enum): - DLINEAR = "dlinear" - DLINEAR_INDIVIDUAL = "dlinear_individual" - NBEATS = "nbeats" - - @classmethod - def values(cls) -> List[str]: - values = [] - for item in list(cls): - values.append(item.value) - return values - - -class ModelInputName(Enum): - DATA_X = "data_x" - TIME_STAMP_X = "time_stamp_x" - TIME_STAMP_Y = "time_stamp_y" - DEC_INP = "dec_inp" - - -class AttributeName(Enum): - # forecast Attribute - PREDICT_LENGTH = "predict_length" - - # NaiveForecaster - STRATEGY = "strategy" - SP = "sp" - - # STLForecaster - # SP = 'sp' - SEASONAL = "seasonal" - SEASONAL_DEG = "seasonal_deg" - TREND_DEG = "trend_deg" - LOW_PASS_DEG = "low_pass_deg" - SEASONAL_JUMP = "seasonal_jump" - TREND_JUMP = "trend_jump" - LOSS_PASS_JUMP = "low_pass_jump" - - # ExponentialSmoothing - DAMPED_TREND = "damped_trend" - INITIALIZATION_METHOD = "initialization_method" - OPTIMIZED = "optimized" - REMOVE_BIAS = "remove_bias" - USE_BRUTE = "use_brute" - - # Arima - ORDER = "order" - 
SEASONAL_ORDER = "seasonal_order" - METHOD = "method" - MAXITER = "maxiter" - SUPPRESS_WARNINGS = "suppress_warnings" - OUT_OF_SAMPLE_SIZE = "out_of_sample_size" - SCORING = "scoring" - WITH_INTERCEPT = "with_intercept" - TIME_VARYING_REGRESSION = "time_varying_regression" - ENFORCE_STATIONARITY = "enforce_stationarity" - ENFORCE_INVERTIBILITY = "enforce_invertibility" - SIMPLE_DIFFERENCING = "simple_differencing" - MEASUREMENT_ERROR = "measurement_error" - MLE_REGRESSION = "mle_regression" - HAMILTON_REPRESENTATION = "hamilton_representation" - CONCENTRATE_SCALE = "concentrate_scale" - - # GAUSSIAN_HMM - N_COMPONENTS = "n_components" - COVARIANCE_TYPE = "covariance_type" - MIN_COVAR = "min_covar" - STARTPROB_PRIOR = "startprob_prior" - TRANSMAT_PRIOR = "transmat_prior" - MEANS_PRIOR = "means_prior" - MEANS_WEIGHT = "means_weight" - COVARS_PRIOR = "covars_prior" - COVARS_WEIGHT = "covars_weight" - ALGORITHM = "algorithm" - N_ITER = "n_iter" - TOL = "tol" - PARAMS = "params" - INIT_PARAMS = "init_params" - IMPLEMENTATION = "implementation" - - # GMMHMM - # N_COMPONENTS = "n_components" - N_MIX = "n_mix" - # MIN_COVAR = "min_covar" - # STARTPROB_PRIOR = "startprob_prior" - # TRANSMAT_PRIOR = "transmat_prior" - WEIGHTS_PRIOR = "weights_prior" - - # MEANS_PRIOR = "means_prior" - # MEANS_WEIGHT = "means_weight" - # ALGORITHM = "algorithm" - # COVARIANCE_TYPE = "covariance_type" - # N_ITER = "n_iter" - # TOL = "tol" - # INIT_PARAMS = "init_params" - # PARAMS = "params" - # IMPLEMENTATION = "implementation" - - # STRAY - ALPHA = "alpha" - K = "k" - KNN_ALGORITHM = "knn_algorithm" - P = "p" - SIZE_THRESHOLD = "size_threshold" - OUTLIER_TAIL = "outlier_tail" - - # timerxl - INPUT_TOKEN_LEN = "input_token_len" - HIDDEN_SIZE = "hidden_size" - INTERMEDIATE_SIZE = "intermediate_size" - OUTPUT_TOKEN_LENS = "output_token_lens" - NUM_HIDDEN_LAYERS = "num_hidden_layers" - NUM_ATTENTION_HEADS = "num_attention_heads" - HIDDEN_ACT = "hidden_act" - USE_CACHE = "use_cache" - ROPE_THETA = "rope_theta" - ATTENTION_DROPOUT = "attention_dropout" - INITIALIZER_RANGE = "initializer_range" - MAX_POSITION_EMBEDDINGS = "max_position_embeddings" - CKPT_PATH = "ckpt_path" - - # sundial - DROPOUT_RATE = "dropout_rate" - FLOW_LOSS_DEPTH = "flow_loss_depth" - NUM_SAMPLING_STEPS = "num_sampling_steps" - DIFFUSION_BATCH_MUL = "diffusion_batch_mul" - - def name(self) -> str: - return self.value diff --git a/iotdb-core/ainode/iotdb/ainode/core/exception.py b/iotdb-core/ainode/iotdb/ainode/core/exception.py index bc89cdc30662..30b9d54dcc7d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/exception.py +++ b/iotdb-core/ainode/iotdb/ainode/core/exception.py @@ -17,7 +17,7 @@ # import re -from iotdb.ainode.core.constant import ( +from iotdb.ainode.core.model.model_constants import ( MODEL_CONFIG_FILE_IN_YAML, MODEL_WEIGHTS_FILE_IN_PT, ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py index 82c72cc37abf..50634914c273 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request.py @@ -15,14 +15,12 @@ # specific language governing permissions and limitations # under the License. 
# + import threading from typing import Any import torch -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) from iotdb.ainode.core.log import Logger from iotdb.ainode.core.util.atmoic_int import AtomicInt @@ -41,8 +39,7 @@ def __init__( req_id: str, model_id: str, inputs: torch.Tensor, - inference_pipeline: AbstractInferencePipeline, - max_new_tokens: int = 96, + output_length: int = 96, **infer_kwargs, ): if inputs.ndim == 1: @@ -52,9 +49,8 @@ def __init__( self.model_id = model_id self.inputs = inputs self.infer_kwargs = infer_kwargs - self.inference_pipeline = inference_pipeline - self.max_new_tokens = ( - max_new_tokens # Number of time series data points to generate + self.output_length = ( + output_length # Number of time series data points to generate ) self.batch_size = inputs.size(0) @@ -65,7 +61,7 @@ def __init__( # Preallocate output buffer [batch_size, max_new_tokens] self.output_tensor = torch.zeros( - self.batch_size, max_new_tokens, device="cpu" + self.batch_size, output_length, device="cpu" ) # shape: [self.batch_size, max_new_steps] def mark_running(self): @@ -77,7 +73,7 @@ def mark_finished(self): def is_finished(self) -> bool: return ( self.state == InferenceRequestState.FINISHED - or self.cur_step_idx >= self.max_new_tokens + or self.cur_step_idx >= self.output_length ) def write_step_output(self, step_output: torch.Tensor): @@ -87,11 +83,11 @@ def write_step_output(self, step_output: torch.Tensor): batch_size, step_size = step_output.shape end_idx = self.cur_step_idx + step_size - if end_idx > self.max_new_tokens: + if end_idx > self.output_length: self.output_tensor[:, self.cur_step_idx :] = step_output[ - :, : self.max_new_tokens - self.cur_step_idx + :, : self.output_length - self.cur_step_idx ] - self.cur_step_idx = self.max_new_tokens + self.cur_step_idx = self.output_length else: self.output_tensor[:, self.cur_step_idx : end_idx] = step_output self.cur_step_idx = end_idx diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py index 6b054c91fe31..a6c415a6c848 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/inference_request_pool.py @@ -25,19 +25,22 @@ import numpy as np import torch import torch.multiprocessing as mp -from transformers import PretrainedConfig from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.constant import INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE from iotdb.ainode.core.inference.batcher.basic_batcher import BasicBatcher from iotdb.ainode.core.inference.inference_request import InferenceRequest +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ( + ChatPipeline, + ClassificationPipeline, + ForecastPipeline, +) +from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline from iotdb.ainode.core.inference.request_scheduler.basic_request_scheduler import ( BasicRequestScheduler, ) from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.model_info import ModelInfo +from iotdb.ainode.core.model.model_storage import ModelInfo from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device @@ -62,7 +65,6 @@ def __init__( pool_id: int, model_info: ModelInfo, device: str, - config: 
PretrainedConfig, request_queue: mp.Queue, result_queue: mp.Queue, ready_event, @@ -71,7 +73,6 @@ def __init__( super().__init__() self.pool_id = pool_id self.model_info = model_info - self.config = config self.pool_kwargs = pool_kwargs self.ready_event = ready_event self.device = convert_device_id_to_torch_device(device) @@ -86,8 +87,8 @@ def __init__( self._batcher = BasicBatcher() self._stop_event = mp.Event() - self._model = None - self._model_manager = None + self._inference_pipeline = None + self._logger = None # Fix inference seed @@ -98,9 +99,6 @@ def __init__( def _activate_requests(self): requests = self._request_scheduler.schedule_activate() for request in requests: - request.inputs = request.inference_pipeline.preprocess_inputs( - request.inputs - ) request.mark_running() self._running_queue.put(request) self._logger.debug( @@ -117,72 +115,51 @@ def _step(self): grouped_requests = defaultdict(list) for req in all_requests: - key = (req.inputs.shape[1], req.max_new_tokens) + key = (req.inputs.shape[1], req.output_length) grouped_requests[key].append(req) grouped_requests = list(grouped_requests.values()) for requests in grouped_requests: batch_inputs = self._batcher.batch_request(requests).to(self.device) - if self.model_info.model_type == BuiltInModelType.SUNDIAL.value: - batch_output = self._model.generate( + if isinstance(self._inference_pipeline, ForecastPipeline): + batch_output = self._inference_pipeline.forecast( batch_inputs, - max_new_tokens=requests[0].max_new_tokens, - num_samples=10, + predict_length=requests[0].output_length, revin=True, ) - - offset = 0 - for request in requests: - request.output_tensor = request.output_tensor.to(self.device) - cur_batch_size = request.batch_size - cur_output = batch_output[offset : offset + cur_batch_size] - offset += cur_batch_size - request.write_step_output(cur_output.mean(dim=1)) - - request.inference_pipeline.post_decode() - if request.is_finished(): - request.inference_pipeline.post_inference() - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" - ) - # ensure the output tensor is on CPU before sending to result queue - request.output_tensor = request.output_tensor.cpu() - self._finished_queue.put(request) - else: - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" - ) - self._waiting_queue.put(request) - - elif self.model_info.model_type == BuiltInModelType.TIMER_XL.value: - batch_output = self._model.generate( + elif isinstance(self._inference_pipeline, ClassificationPipeline): + batch_output = self._inference_pipeline.classify( batch_inputs, - max_new_tokens=requests[0].max_new_tokens, - revin=True, + # more infer kwargs can be added here ) - - offset = 0 - for request in requests: - request.output_tensor = request.output_tensor.to(self.device) - cur_batch_size = request.batch_size - cur_output = batch_output[offset : offset + cur_batch_size] - offset += cur_batch_size - request.write_step_output(cur_output) - - request.inference_pipeline.post_decode() - if request.is_finished(): - request.inference_pipeline.post_inference() - self._logger.debug( - f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" - ) - # ensure the output tensor is on CPU before sending to result queue - request.output_tensor = request.output_tensor.cpu() - self._finished_queue.put(request) - else: - self._logger.debug( - 
f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" - ) - self._waiting_queue.put(request) + elif isinstance(self._inference_pipeline, ChatPipeline): + batch_output = self._inference_pipeline.chat( + batch_inputs, + # more infer kwargs can be added here + ) + else: + self._logger.error("[Inference] Unsupported pipeline type.") + offset = 0 + for request in requests: + request.output_tensor = request.output_tensor.to(self.device) + cur_batch_size = request.batch_size + cur_output = batch_output[offset : offset + cur_batch_size] + offset += cur_batch_size + request.write_step_output(cur_output) + + if request.is_finished(): + # ensure the output tensor is on CPU before sending to result queue + request.output_tensor = request.output_tensor.cpu() + self._finished_queue.put(request) + self._logger.debug( + f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is finished" + ) + else: + self._waiting_queue.put(request) + self._logger.debug( + f"[Inference][Device-{self.device}][Pool-{self.pool_id}][ID-{request.req_id}] Request is not finished, re-queueing" + ) + return def _requests_execute_loop(self): while not self._stop_event.is_set(): @@ -193,11 +170,8 @@ def run(self): self._logger = Logger( INFERENCE_LOG_FILE_NAME_PREFIX_TEMPLATE.format(self.device) ) - self._model_manager = ModelManager() self._request_scheduler.device = self.device - self._model = self._model_manager.load_model(self.model_info.model_id, {}).to( - self.device - ) + self._inference_pipeline = load_pipeline(self.model_info, str(self.device)) self.ready_event.set() activate_daemon = threading.Thread( diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/__init__.py rename to iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/__init__.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py new file mode 100644 index 000000000000..82601e398059 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/basic_pipeline.py @@ -0,0 +1,87 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+#
+
+from abc import ABC, abstractmethod
+
+import torch
+
+from iotdb.ainode.core.model.model_loader import load_model
+
+
+class BasicPipeline(ABC):
+    def __init__(self, model_info, **model_kwargs):
+        self.model_info = model_info
+        self.device = model_kwargs.get("device", "cpu")
+        self.model = load_model(model_info, device_map=self.device, **model_kwargs)
+
+    def _preprocess(self, inputs):
+        """
+        Preprocess the input before inference, including shape validation and value transformation.
+        """
+        return inputs
+
+    def _postprocess(self, output: torch.Tensor):
+        """
+        Post-process the outputs after the entire inference task.
+        """
+        return output
+
+
+class ForecastPipeline(BasicPipeline):
+    def __init__(self, model_info, **model_kwargs):
+        super().__init__(model_info, **model_kwargs)
+
+    def _preprocess(self, inputs):
+        return inputs
+
+    @abstractmethod
+    def forecast(self, inputs, **infer_kwargs):
+        pass
+
+    def _postprocess(self, output: torch.Tensor):
+        return output
+
+
+class ClassificationPipeline(BasicPipeline):
+    def __init__(self, model_info, **model_kwargs):
+        super().__init__(model_info, **model_kwargs)
+
+    def _preprocess(self, inputs):
+        return inputs
+
+    @abstractmethod
+    def classify(self, inputs, **kwargs):
+        pass
+
+    def _postprocess(self, output: torch.Tensor):
+        return output
+
+
+class ChatPipeline(BasicPipeline):
+    def __init__(self, model_info, **model_kwargs):
+        super().__init__(model_info, **model_kwargs)
+
+    def _preprocess(self, inputs):
+        return inputs
+
+    @abstractmethod
+    def chat(self, inputs, **kwargs):
+        pass
+
+    def _postprocess(self, output: torch.Tensor):
+        return output
diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py
new file mode 100644
index 000000000000..a30038dd5fef
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pipeline/pipeline_loader.py
@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+from pathlib import Path
+
+from iotdb.ainode.core.config import AINodeDescriptor
+from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.model.model_constants import ModelCategory
+from iotdb.ainode.core.model.model_storage import ModelInfo
+from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path
+
+logger = Logger()
+
+
+def load_pipeline(model_info: ModelInfo, device: str, **model_kwargs):
+    if model_info.model_type == "sktime":
+        from iotdb.ainode.core.model.sktime.pipeline_sktime import SktimePipeline
+
+        pipeline_cls = SktimePipeline
+    elif model_info.category == ModelCategory.BUILTIN:
+        module_name = (
+            AINodeDescriptor().get_config().get_ain_models_builtin_dir()
+            + "."
+ + model_info.model_id + ) + pipeline_cls = import_class_from_path(module_name, model_info.pipeline_cls) + else: + model_path = os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + model_info.category.value, + model_info.model_id, + ) + module_parent = str(Path(model_path).parent.absolute()) + with temporary_sys_path(module_parent): + pipeline_cls = import_class_from_path( + model_info.model_id, model_info.pipeline_cls + ) + + return pipeline_cls(model_info, device=device, **model_kwargs) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py index 54580402ec29..8ffa89ffd675 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py @@ -22,7 +22,6 @@ from concurrent.futures import wait from typing import Dict, Optional -import torch import torch.multiprocessing as mp from iotdb.ainode.core.exception import InferenceModelInternalError @@ -41,9 +40,6 @@ ) from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig from iotdb.ainode.core.util.atmoic_int import AtomicInt from iotdb.ainode.core.util.batch_executor import BatchExecutor from iotdb.ainode.core.util.decorator import synchronized @@ -76,9 +72,9 @@ def __init__(self, result_queue: mp.Queue): thread_name_prefix=ThreadName.INFERENCE_POOL_CONTROLLER.value ) - # =============== Pool Management =============== + # =============== Automatic Pool Management (Developing) =============== @synchronized(threading.Lock()) - def first_req_init(self, model_id: str): + def first_req_init(self, model_id: str, device): """ Initialize the pools when the first request for the given model_id arrives. """ @@ -107,38 +103,35 @@ def _first_pool_init(self, model_id: str, device_str: str): Initialize the first pool for the given model_id. Ensure the pool is ready before returning. 
""" - device = torch.device(device_str) - device_id = device.index - - if model_id == "sundial": - config = SundialConfig() - elif model_id == "timer_xl": - config = TimerConfig() - first_queue = mp.Queue() - ready_event = mp.Event() - first_pool = InferenceRequestPool( - pool_id=0, - model_id=model_id, - device=device_str, - config=config, - request_queue=first_queue, - result_queue=self._result_queue, - ready_event=ready_event, - ) - first_pool.start() - self._register_pool(model_id, device_str, 0, first_pool, first_queue) - - if not ready_event.wait(timeout=30): - self._erase_pool(model_id, device_id, 0) - logger.error( - f"[Inference][Device-{device}][Pool-0] Pool failed to be ready in time" - ) - else: - self.set_state(model_id, device_id, 0, PoolState.RUNNING) - logger.info( - f"[Inference][Device-{device}][Pool-0] Pool started running for model {model_id}" - ) + pass + # device = torch.device(device_str) + # device_id = device.index + # + # first_queue = mp.Queue() + # ready_event = mp.Event() + # first_pool = InferenceRequestPool( + # pool_id=0, + # model_id=model_id, + # device=device_str, + # request_queue=first_queue, + # result_queue=self._result_queue, + # ready_event=ready_event, + # ) + # first_pool.start() + # self._register_pool(model_id, device_str, 0, first_pool, first_queue) + # + # if not ready_event.wait(timeout=30): + # self._erase_pool(model_id, device_id, 0) + # logger.error( + # f"[Inference][Device-{device}][Pool-0] Pool failed to be ready in time" + # ) + # else: + # self.set_state(model_id, device_id, 0, PoolState.RUNNING) + # logger.info( + # f"[Inference][Device-{device}][Pool-0] Pool started running for model {model_id}" + # ) + # =============== Pool Management =============== def load_model(self, model_id: str, device_id_list: list[str]): """ Load the model to the specified devices asynchronously. 
@@ -255,29 +248,19 @@ def _expand_pools_on_device(self, model_id: str, device_id: str, count: int): """ def _expand_pool_on_device(*_): - result_queue = mp.Queue() + request_queue = mp.Queue() pool_id = self._new_pool_id.get_and_increment() model_info = self._model_manager.get_model_info(model_id) - model_type = model_info.model_type - if model_type == BuiltInModelType.SUNDIAL.value: - config = SundialConfig() - elif model_type == BuiltInModelType.TIMER_XL.value: - config = TimerConfig() - else: - raise InferenceModelInternalError( - f"Unsupported model type {model_type} for loading model {model_id}" - ) pool = InferenceRequestPool( pool_id=pool_id, model_info=model_info, device=device_id, - config=config, - request_queue=result_queue, + request_queue=request_queue, result_queue=self._result_queue, ready_event=mp.Event(), ) pool.start() - self._register_pool(model_id, device_id, pool_id, pool, result_queue) + self._register_pool(model_id, device_id, pool_id, pool, request_queue) if not pool.ready_event.wait(timeout=300): logger.error( f"[Inference][Device-{device_id}][Pool-{pool_id}] Pool failed to be ready in time" diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py index 6a2bd2b619aa..d2e7292ecd8f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py @@ -36,7 +36,7 @@ estimate_pool_size, evaluate_system_resources, ) -from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP, ModelInfo +from iotdb.ainode.core.model.model_info import ModelInfo from iotdb.ainode.core.util.gpu_mapping import convert_device_id_to_torch_device logger = Logger() diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py deleted file mode 100644 index 2300169a6ee9..000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/abstract_inference_pipeline.py +++ /dev/null @@ -1,60 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# - -from abc import ABC, abstractmethod - -import torch - - -class AbstractInferencePipeline(ABC): - """ - Abstract assistance strategy class for model inference. - This class shall define the interface process for specific model. - """ - - def __init__(self, model_config, **infer_kwargs): - self.model_config = model_config - self.infer_kwargs = infer_kwargs - - @abstractmethod - def preprocess_inputs(self, inputs: torch.Tensor): - """ - Preprocess the inputs before inference, including shape validation and value transformation. 
- - Args: - inputs (torch.Tensor): The input tensor to be preprocessed. - - Returns: - torch.Tensor: The preprocessed input tensor. - """ - # TODO: Integrate with the data processing pipeline operators - pass - - @abstractmethod - def post_decode(self): - """ - Post-process the outputs after each decode step. - """ - pass - - @abstractmethod - def post_inference(self): - """ - Post-process the outputs after the entire inference task. - """ - pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py b/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py index cf10b5b2cd4d..d17f9fbcec53 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/utils.py @@ -19,7 +19,6 @@ import string import torch -from transformers.modeling_outputs import MoeCausalLMOutputWithPast def generate_req_id(length=10, charset=string.ascii_letters + string.digits) -> str: @@ -56,25 +55,25 @@ def _slice_pkv(pkv, s, e): return out -def split_moe_output(batch_out: MoeCausalLMOutputWithPast, split_sizes): - """ - split batch_out with type: MoeCausalLMOutputWithPast into len(split_sizes) - split_sizes[i] = ith request's batch_size。 - """ - outs = [] - start = 0 - for bsz in split_sizes: - end = start + bsz - outs.append( - MoeCausalLMOutputWithPast( - loss=_slice_tensor(batch_out.loss, start, end), - logits=batch_out.logits[start:end], - past_key_values=_slice_pkv(batch_out.past_key_values, start, end), - hidden_states=_slice_tuple_of_tensors( - batch_out.hidden_states, start, end - ), - attentions=_slice_tuple_of_tensors(batch_out.attentions, start, end), - ) - ) - start = end - return outs +# def split_moe_output(batch_out: MoeCausalLMOutputWithPast, split_sizes): +# """ +# split batch_out with type: MoeCausalLMOutputWithPast into len(split_sizes) +# split_sizes[i] = ith request's batch_size。 +# """ +# outs = [] +# start = 0 +# for bsz in split_sizes: +# end = start + bsz +# outs.append( +# MoeCausalLMOutputWithPast( +# loss=_slice_tensor(batch_out.loss, start, end), +# logits=batch_out.logits[start:end], +# past_key_values=_slice_pkv(batch_out.past_key_values, start, end), +# hidden_states=_slice_tuple_of_tensors( +# batch_out.hidden_states, start, end +# ), +# attentions=_slice_tuple_of_tensors(batch_out.attentions, start, end), +# ) +# ) +# start = end +# return outs diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index a67d576b0ec8..1ce2e84e0592 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -18,7 +18,6 @@ import threading import time -from abc import ABC, abstractmethod from typing import Dict import pandas as pd @@ -29,29 +28,22 @@ from iotdb.ainode.core.constant import TSStatusCode from iotdb.ainode.core.exception import ( InferenceModelInternalError, - InvalidWindowArgumentError, NumericalRangeException, - runtime_error_extractor, ) from iotdb.ainode.core.inference.inference_request import ( InferenceRequest, InferenceRequestProxy, ) -from iotdb.ainode.core.inference.pool_controller import PoolController -from iotdb.ainode.core.inference.strategy.timer_sundial_inference_pipeline import ( - TimerSundialInferencePipeline, -) -from iotdb.ainode.core.inference.strategy.timerxl_inference_pipeline import ( - TimerXLInferencePipeline, +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ( + ChatPipeline, + ClassificationPipeline, 
+ ForecastPipeline, ) +from iotdb.ainode.core.inference.pipeline.pipeline_loader import load_pipeline +from iotdb.ainode.core.inference.pool_controller import PoolController from iotdb.ainode.core.inference.utils import generate_req_id from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig -from iotdb.ainode.core.model.sundial.modeling_sundial import SundialForPrediction -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig -from iotdb.ainode.core.model.timerxl.modeling_timer import TimerForPrediction from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.gpu_mapping import get_available_devices from iotdb.ainode.core.util.serde import convert_to_binary @@ -71,83 +63,6 @@ logger = Logger() -class InferenceStrategy(ABC): - def __init__(self, model): - self.model = model - - @abstractmethod - def infer(self, full_data, **kwargs): - pass - - -# [IoTDB] full data deserialized from iotdb is composed of [timestampList, valueList, length], -# we only get valueList currently. -class TimerXLStrategy(InferenceStrategy): - def infer(self, full_data, predict_length=96, **_): - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - seqs = torch.tensor(data).unsqueeze(0).float() - # TODO: unify model inference input - output = self.model.generate(seqs, max_new_tokens=predict_length, revin=True) - df = pd.DataFrame(output[0]) - return convert_to_binary(df) - - -class SundialStrategy(InferenceStrategy): - def infer(self, full_data, predict_length=96, **_): - data = full_data[1][0] - if data.dtype.byteorder not in ("=", "|"): - np_data = data.byteswap() - data = np_data.view(np_data.dtype.newbyteorder()) - seqs = torch.tensor(data).unsqueeze(0).float() - # TODO: unify model inference input - output = self.model.generate( - seqs, max_new_tokens=predict_length, num_samples=10, revin=True - ) - df = pd.DataFrame(output[0].mean(dim=0)) - return convert_to_binary(df) - - -class BuiltInStrategy(InferenceStrategy): - def infer(self, full_data, **_): - data = pd.DataFrame(full_data[1]).T - output = self.model.inference(data) - df = pd.DataFrame(output) - return convert_to_binary(df) - - -class RegisteredStrategy(InferenceStrategy): - def infer(self, full_data, window_interval=None, window_step=None, **_): - _, dataset, _, length = full_data - if window_interval is None or window_step is None: - window_interval = length - window_step = float("inf") - - if window_interval <= 0 or window_step <= 0 or window_interval > length: - raise InvalidWindowArgumentError(window_interval, window_step, length) - - data = torch.tensor(dataset, dtype=torch.float32).unsqueeze(0).permute(0, 2, 1) - - times = int((length - window_interval) // window_step + 1) - results = [] - try: - for i in range(times): - start = 0 if window_step == float("inf") else i * window_step - end = start + window_interval - window = data[:, start:end, :] - out = self.model(window) - df = pd.DataFrame(out.squeeze(0).detach().numpy()) - results.append(df) - except Exception as e: - msg = runtime_error_extractor(str(e)) or str(e) - raise InferenceModelInternalError(msg) - - # concatenate or return first window for forecast - return [convert_to_binary(df) for df in results] - - class InferenceManager: WAITING_INTERVAL_IN_MS = 
( AINodeDescriptor().get_config().get_ain_inference_batch_interval_in_ms() @@ -251,15 +166,6 @@ def _process_request(self, req): with self._result_wrapper_lock: del self._result_wrapper_map[req_id] - def _get_strategy(self, model_id, model): - if isinstance(model, TimerForPrediction): - return TimerXLStrategy(model) - if isinstance(model, SundialForPrediction): - return SundialStrategy(model) - if self._model_manager.model_storage.is_built_in_or_fine_tuned(model_id): - return BuiltInStrategy(model) - return RegisteredStrategy(model) - def _run( self, req, @@ -272,59 +178,54 @@ def _run( model_id = req.modelId try: raw = data_getter(req) + # full data deserialized from iotdb is composed of [timestampList, valueList, None, length], we only get valueList currently. full_data = deserializer(raw) - inference_attrs = extract_attrs(req) + # TODO: TSBlock -> Tensor codes should be unified + data = full_data[1][0] # get valueList in ndarray + if data.dtype.byteorder not in ("=", "|"): + np_data = data.byteswap() + data = np_data.view(np_data.dtype.newbyteorder()) + # the inputs should be on CPU before passing to the inference request + inputs = torch.tensor(data).unsqueeze(0).float().to("cpu") - predict_length = int(inference_attrs.pop("predict_length", 96)) + inference_attrs = extract_attrs(req) + output_length = int(inference_attrs.pop("output_length", 96)) if ( - predict_length - > AINodeDescriptor().get_config().get_ain_inference_max_predict_length() + output_length + > AINodeDescriptor().get_config().get_ain_inference_max_output_length() ): raise NumericalRangeException( "output_length", 1, AINodeDescriptor() .get_config() - .get_ain_inference_max_predict_length(), - predict_length, + .get_ain_inference_max_output_length(), + output_length, ) if self._pool_controller.has_request_pools(model_id): - # use request pool to accelerate inference when the model instance is already loaded. 
-                # TODO: TSBlock -> Tensor codes should be unified
-                data = full_data[1][0]
-                if data.dtype.byteorder not in ("=", "|"):
-                    np_data = data.byteswap()
-                    data = np_data.view(np_data.dtype.newbyteorder())
-                # the inputs should be on CPU before passing to the inference request
-                inputs = torch.tensor(data).unsqueeze(0).float().to("cpu")
-                model_type = self._model_manager.get_model_info(model_id).model_type
-                if model_type == BuiltInModelType.SUNDIAL.value:
-                    inference_pipeline = TimerSundialInferencePipeline(SundialConfig())
-                elif model_type == BuiltInModelType.TIMER_XL.value:
-                    inference_pipeline = TimerXLInferencePipeline(TimerConfig())
-                else:
-                    raise InferenceModelInternalError(
-                        f"Unsupported model_id: {model_id}"
-                    )
                 infer_req = InferenceRequest(
                     req_id=generate_req_id(),
                     model_id=model_id,
                     inputs=inputs,
-                    inference_pipeline=inference_pipeline,
-                    max_new_tokens=predict_length,
+                    output_length=output_length,
                 )
                 outputs = self._process_request(infer_req)
                 outputs = convert_to_binary(pd.DataFrame(outputs[0]))
             else:
-                # load model
-                accel = str(inference_attrs.get("acceleration", "")).lower() == "true"
-                model = self._model_manager.load_model(model_id, inference_attrs, accel)
-                # inference by strategy
-                strategy = self._get_strategy(model_id, model)
-                outputs = strategy.infer(
-                    full_data, predict_length=predict_length, **inference_attrs
-                )
+                model_info = self._model_manager.get_model_info(model_id)
+                inference_pipeline = load_pipeline(model_info, device="cpu")
+                if isinstance(inference_pipeline, ForecastPipeline):
+                    outputs = inference_pipeline.forecast(
+                        inputs, predict_length=output_length, **inference_attrs
+                    )
+                elif isinstance(inference_pipeline, ClassificationPipeline):
+                    outputs = inference_pipeline.classify(inputs)
+                elif isinstance(inference_pipeline, ChatPipeline):
+                    outputs = inference_pipeline.chat(inputs)
+                else:
+                    raise InferenceModelInternalError(f"Unsupported pipeline type for model {model_id}")
+                outputs = convert_to_binary(pd.DataFrame(outputs[0]))

             # construct response
             status = get_status(TSStatusCode.SUCCESS_STATUS)
@@ -345,7 +246,7 @@ def forecast(self, req: TForecastReq):
             data_getter=lambda r: r.inputData,
             deserializer=deserialize,
             extract_attrs=lambda r: {
-                "predict_length": r.outputLength,
+                "output_length": r.outputLength,
                 **(r.options or {}),
             },
             resp_cls=TForecastResp,
@@ -358,8 +259,7 @@ def inference(self, req: TInferenceReq):
             data_getter=lambda r: r.dataset,
             deserializer=deserialize,
             extract_attrs=lambda r: {
-                "window_interval": getattr(r.windowParams, "windowInterval", None),
-                "window_step": getattr(r.windowParams, "windowStep", None),
+                "output_length": int((r.inferenceAttributes or {}).pop("outputLength", 96)),
                 **(r.inferenceAttributes or {}),
             },
             resp_cls=TInferenceResp,
diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
index d84bca77c843..8ffb33d91e2d 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py
@@ -15,17 +15,14 @@
 # specific language governing permissions and limitations
 # under the License.
# -from typing import Callable, Dict -from torch import nn -from yaml import YAMLError +from typing import Any, List, Optional from iotdb.ainode.core.constant import TSStatusCode -from iotdb.ainode.core.exception import BadConfigValueError, InvalidUriError +from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import BuiltInModelType, ModelStates -from iotdb.ainode.core.model.model_info import ModelInfo -from iotdb.ainode.core.model.model_storage import ModelStorage +from iotdb.ainode.core.model.model_loader import load_model +from iotdb.ainode.core.model.model_storage import ModelCategory, ModelInfo, ModelStorage from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.decorator import singleton from iotdb.thrift.ainode.ttypes import ( @@ -43,127 +40,60 @@ @singleton class ModelManager: def __init__(self): - self.model_storage = ModelStorage() + self._model_storage = ModelStorage() - def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp: - logger.info(f"register model {req.modelId} from {req.uri}") + def register_model( + self, + req: TRegisterModelReq, + ) -> TRegisterModelResp: try: - configs, attributes = self.model_storage.register_model( - req.modelId, req.uri - ) - return TRegisterModelResp( - get_status(TSStatusCode.SUCCESS_STATUS), configs, attributes - ) - except InvalidUriError as e: - logger.warning(e) - return TRegisterModelResp( - get_status(TSStatusCode.INVALID_URI_ERROR, e.message) - ) - except BadConfigValueError as e: - logger.warning(e) + if self._model_storage.register_model(model_id=req.modelId, uri=req.uri): + return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS)) + return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) + except ValueError as e: return TRegisterModelResp( - get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message) + get_status(TSStatusCode.INVALID_URI_ERROR, str(e)) ) - except YAMLError as e: - logger.warning(e) - if hasattr(e, "problem_mark"): - mark = e.problem_mark - return TRegisterModelResp( - get_status( - TSStatusCode.INVALID_INFERENCE_CONFIG, - f"An error occurred while parsing the yaml file, " - f"at line {mark.line + 1} column {mark.column + 1}.", - ) - ) + except Exception as e: return TRegisterModelResp( - get_status( - TSStatusCode.INVALID_INFERENCE_CONFIG, - f"An error occurred while parsing the yaml file", - ) + get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) ) - except Exception as e: - logger.warning(e) - return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) + + def show_models(self, req: TShowModelsReq) -> TShowModelsResp: + self._refresh() + return self._model_storage.show_models(req) def delete_model(self, req: TDeleteModelReq) -> TSStatus: - logger.info(f"delete model {req.modelId}") try: - self.model_storage.delete_model(req.modelId) + self._model_storage.delete_model(req.modelId) return get_status(TSStatusCode.SUCCESS_STATUS) - except Exception as e: + except BuiltInModelDeletionError as e: logger.warning(e) return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) - - def load_model( - self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool = False - ) -> Callable: - """ - Load the model with the given model_id. 
- """ - logger.info(f"Load model {model_id}") - try: - model = self.model_storage.load_model( - model_id, inference_attrs, acceleration - ) - logger.info(f"Model {model_id} loaded") - return model - except Exception as e: - logger.error(f"Failed to load model {model_id}: {e}") - raise - - def save_model(self, model_id: str, model: nn.Module) -> TSStatus: - """ - Save the model using save_pretrained - """ - logger.info(f"Saving model {model_id}") - try: - self.model_storage.save_model(model_id, model) - logger.info(f"Saving model {model_id} successfully") - return get_status( - TSStatusCode.SUCCESS_STATUS, f"Model {model_id} saved successfully" - ) except Exception as e: - logger.error(f"Save model failed: {e}") + logger.warning(e) return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) - def get_ckpt_path(self, model_id: str) -> str: - """ - Get the checkpoint path for a given model ID. - - Args: - model_id (str): The ID of the model. - - Returns: - str: The path to the checkpoint file for the model. - """ - return self.model_storage.get_ckpt_path(model_id) - - def show_models(self, req: TShowModelsReq) -> TShowModelsResp: - return self.model_storage.show_models(req) - - def register_built_in_model(self, model_info: ModelInfo): - self.model_storage.register_built_in_model(model_info) - - def get_model_info(self, model_id: str) -> ModelInfo: - return self.model_storage.get_model_info(model_id) - - def update_model_state(self, model_id: str, state: ModelStates): - self.model_storage.update_model_state(model_id, state) - - def get_built_in_model_type(self, model_id: str) -> BuiltInModelType: - """ - Get the type of the model with the given model_id. - """ - return self.model_storage.get_built_in_model_type(model_id) - - def is_built_in_or_fine_tuned(self, model_id: str) -> bool: - """ - Check if the model_id corresponds to a built-in or fine-tuned model. - - Args: - model_id (str): The ID of the model. - - Returns: - bool: True if the model is built-in or fine_tuned, False otherwise. 
- """ - return self.model_storage.is_built_in_or_fine_tuned(model_id) + def get_model_info( + self, + model_id: str, + category: Optional[ModelCategory] = None, + ) -> Optional[ModelInfo]: + return self._model_storage.get_model_info(model_id, category) + + def get_model_infos( + self, + category: Optional[ModelCategory] = None, + model_type: Optional[str] = None, + ) -> List[ModelInfo]: + return self._model_storage.get_model_infos(category, model_type) + + def _refresh(self): + """Refresh the model list (re-scan the file system)""" + self._model_storage.discover_all_models() + + def get_registered_models(self) -> List[str]: + return self._model_storage.get_registered_models() + + def is_model_registered(self, model_id: str) -> bool: + return self._model_storage.is_model_registered(model_id) diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py index 0264e27331a8..23a98f26bbff 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py @@ -25,7 +25,7 @@ from iotdb.ainode.core.exception import ModelNotExistError from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.model_manager import ModelManager -from iotdb.ainode.core.model.model_info import BUILT_IN_LTSM_MAP +from iotdb.ainode.core.model.model_loader import load_model logger = Logger() @@ -47,7 +47,8 @@ def measure_model_memory(device: torch.device, model_id: str) -> int: torch.cuda.synchronize(device) start = torch.cuda.memory_reserved(device) - model = ModelManager().load_model(model_id, {}).to(device) + model_info = ModelManager().get_model_info(model_id) + model = load_model(model_info).to(device) torch.cuda.synchronize(device) end = torch.cuda.memory_reserved(device) usage = end - start @@ -80,7 +81,7 @@ def evaluate_system_resources(device: torch.device) -> dict: def estimate_pool_size(device: torch.device, model_id: str) -> int: - model_info = BUILT_IN_LTSM_MAP.get(model_id, None) + model_info = ModelManager().get_model_info(model_id) if model_info is None or model_info.model_type not in MODEL_MEM_USAGE_MAP: logger.error( f"[Inference] Cannot estimate inference pool size on device: {device}, because model: {model_id} is not supported." diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py deleted file mode 100644 index 3b55142350ba..000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/built_in_model_factory.py +++ /dev/null @@ -1,1238 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -import os -from abc import abstractmethod -from typing import Callable, Dict, List - -import numpy as np -from huggingface_hub import hf_hub_download -from sklearn.preprocessing import MinMaxScaler -from sktime.detection.hmm_learn import GMMHMM, GaussianHMM -from sktime.detection.stray import STRAY -from sktime.forecasting.arima import ARIMA -from sktime.forecasting.exp_smoothing import ExponentialSmoothing -from sktime.forecasting.naive import NaiveForecaster -from sktime.forecasting.trend import STLForecaster - -from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, - AttributeName, -) -from iotdb.ainode.core.exception import ( - BuiltInModelNotSupportError, - InferenceModelInternalError, - ListRangeException, - NumericalRangeException, - StringRangeException, - WrongAttributeTypeError, -) -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import BuiltInModelType -from iotdb.ainode.core.model.model_info import TIMER_REPO_ID -from iotdb.ainode.core.model.sundial import modeling_sundial -from iotdb.ainode.core.model.timerxl import modeling_timer - -logger = Logger() - - -def _download_file_from_hf_if_necessary(local_dir: str, repo_id: str) -> bool: - weights_path = os.path.join(local_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) - config_path = os.path.join(local_dir, MODEL_CONFIG_FILE_IN_JSON) - if not os.path.exists(weights_path): - logger.info( - f"Model weights file not found at {weights_path}, downloading from HuggingFace..." - ) - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS, - local_dir=local_dir, - ) - logger.info(f"Got file to {weights_path}") - except Exception as e: - logger.error( - f"Failed to download model weights file to {local_dir} due to {e}" - ) - return False - if not os.path.exists(config_path): - logger.info( - f"Model config file not found at {config_path}, downloading from HuggingFace..." - ) - try: - hf_hub_download( - repo_id=repo_id, - filename=MODEL_CONFIG_FILE_IN_JSON, - local_dir=local_dir, - ) - logger.info(f"Got file to {config_path}") - except Exception as e: - logger.error( - f"Failed to download model config file to {local_dir} due to {e}" - ) - return False - return True - - -def download_built_in_ltsm_from_hf_if_necessary( - model_type: BuiltInModelType, local_dir: str -) -> bool: - """ - Download the built-in ltsm from HuggingFace repository when necessary. - - Return: - bool: True if the model is existed or downloaded successfully, False otherwise. 
- """ - repo_id = TIMER_REPO_ID[model_type] - if not _download_file_from_hf_if_necessary(local_dir, repo_id): - return False - return True - - -def get_model_attributes(model_type: BuiltInModelType): - if model_type == BuiltInModelType.ARIMA: - attribute_map = arima_attribute_map - elif model_type == BuiltInModelType.NAIVE_FORECASTER: - attribute_map = naive_forecaster_attribute_map - elif ( - model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING - or model_type == BuiltInModelType.HOLTWINTERS - ): - attribute_map = exponential_smoothing_attribute_map - elif model_type == BuiltInModelType.STL_FORECASTER: - attribute_map = stl_forecaster_attribute_map - elif model_type == BuiltInModelType.GMM_HMM: - attribute_map = gmmhmm_attribute_map - elif model_type == BuiltInModelType.GAUSSIAN_HMM: - attribute_map = gaussian_hmm_attribute_map - elif model_type == BuiltInModelType.STRAY: - attribute_map = stray_attribute_map - elif model_type == BuiltInModelType.TIMER_XL: - attribute_map = timerxl_attribute_map - elif model_type == BuiltInModelType.SUNDIAL: - attribute_map = sundial_attribute_map - else: - raise BuiltInModelNotSupportError(model_type.value) - return attribute_map - - -def fetch_built_in_model( - model_type: BuiltInModelType, model_dir, inference_attrs: Dict[str, str] -) -> Callable: - """ - Fetch the built-in model according to its id and directory, not that this directory only contains model weights and config. - Args: - model_type: the type of the built-in model - model_dir: for huggingface models only, the directory where the model is stored - Returns: - model: the built-in model - """ - default_attributes = get_model_attributes(model_type) - # parse the attributes from inference_attrs - attributes = parse_attribute(inference_attrs, default_attributes) - - # build the built-in model - if model_type == BuiltInModelType.ARIMA: - model = ArimaModel(attributes) - elif ( - model_type == BuiltInModelType.EXPONENTIAL_SMOOTHING - or model_type == BuiltInModelType.HOLTWINTERS - ): - model = ExponentialSmoothingModel(attributes) - elif model_type == BuiltInModelType.NAIVE_FORECASTER: - model = NaiveForecasterModel(attributes) - elif model_type == BuiltInModelType.STL_FORECASTER: - model = STLForecasterModel(attributes) - elif model_type == BuiltInModelType.GMM_HMM: - model = GMMHMMModel(attributes) - elif model_type == BuiltInModelType.GAUSSIAN_HMM: - model = GaussianHmmModel(attributes) - elif model_type == BuiltInModelType.STRAY: - model = STRAYModel(attributes) - elif model_type == BuiltInModelType.TIMER_XL: - model = modeling_timer.TimerForPrediction.from_pretrained(model_dir) - elif model_type == BuiltInModelType.SUNDIAL: - model = modeling_sundial.SundialForPrediction.from_pretrained(model_dir) - else: - raise BuiltInModelNotSupportError(model_type.value) - - return model - - -class Attribute(object): - def __init__(self, name: str): - """ - Args: - name: the name of the attribute - """ - self._name = name - - @abstractmethod - def get_default_value(self): - raise NotImplementedError - - @abstractmethod - def validate_value(self, value): - raise NotImplementedError - - @abstractmethod - def parse(self, string_value: str): - raise NotImplementedError - - -class IntAttribute(Attribute): - def __init__( - self, - name: str, - default_value: int, - default_low: int, - default_high: int, - ): - super(IntAttribute, self).__init__(name) - self.__default_value = default_value - self.__default_low = default_low - self.__default_high = default_high - - def get_default_value(self): - return 
self.__default_value - - def validate_value(self, value): - if self.__default_low <= value <= self.__default_high: - return True - raise NumericalRangeException( - self._name, value, self.__default_low, self.__default_high - ) - - def parse(self, string_value: str): - try: - int_value = int(string_value) - except: - raise WrongAttributeTypeError(self._name, "int") - return int_value - - -class FloatAttribute(Attribute): - def __init__( - self, - name: str, - default_value: float, - default_low: float, - default_high: float, - ): - super(FloatAttribute, self).__init__(name) - self.__default_value = default_value - self.__default_low = default_low - self.__default_high = default_high - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if self.__default_low <= value <= self.__default_high: - return True - raise NumericalRangeException( - self._name, value, self.__default_low, self.__default_high - ) - - def parse(self, string_value: str): - try: - float_value = float(string_value) - except: - raise WrongAttributeTypeError(self._name, "float") - return float_value - - -class StringAttribute(Attribute): - def __init__(self, name: str, default_value: str, value_choices: List[str]): - super(StringAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_choices = value_choices - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if value in self.__value_choices: - return True - raise StringRangeException(self._name, value, self.__value_choices) - - def parse(self, string_value: str): - return string_value - - -class BooleanAttribute(Attribute): - def __init__(self, name: str, default_value: bool): - super(BooleanAttribute, self).__init__(name) - self.__default_value = default_value - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if isinstance(value, bool): - return True - raise WrongAttributeTypeError(self._name, "bool") - - def parse(self, string_value: str): - if string_value.lower() == "true": - return True - elif string_value.lower() == "false": - return False - else: - raise WrongAttributeTypeError(self._name, "bool") - - -class ListAttribute(Attribute): - def __init__(self, name: str, default_value: List, value_type): - """ - value_type is the type of the elements in the list, e.g. 
int, float, str - """ - super(ListAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_type = value_type - self.__type_to_str = {str: "str", int: "int", float: "float"} - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if not isinstance(value, list): - raise WrongAttributeTypeError(self._name, "list") - for value_item in value: - if not isinstance(value_item, self.__value_type): - raise WrongAttributeTypeError(self._name, self.__value_type) - return True - - def parse(self, string_value: str): - try: - list_value = eval(string_value) - except: - raise WrongAttributeTypeError(self._name, "list") - if not isinstance(list_value, list): - raise WrongAttributeTypeError(self._name, "list") - for i in range(len(list_value)): - try: - list_value[i] = self.__value_type(list_value[i]) - except: - raise ListRangeException( - self._name, list_value, self.__type_to_str[self.__value_type] - ) - return list_value - - -class TupleAttribute(Attribute): - def __init__(self, name: str, default_value: tuple, value_type): - """ - value_type is the type of the elements in the list, e.g. int, float, str - """ - super(TupleAttribute, self).__init__(name) - self.__default_value = default_value - self.__value_type = value_type - self.__type_to_str = {str: "str", int: "int", float: "float"} - - def get_default_value(self): - return self.__default_value - - def validate_value(self, value): - if not isinstance(value, tuple): - raise WrongAttributeTypeError(self._name, "tuple") - for value_item in value: - if not isinstance(value_item, self.__value_type): - raise WrongAttributeTypeError(self._name, self.__value_type) - return True - - def parse(self, string_value: str): - try: - tuple_value = eval(string_value) - except: - raise WrongAttributeTypeError(self._name, "tuple") - if not isinstance(tuple_value, tuple): - raise WrongAttributeTypeError(self._name, "tuple") - list_value = list(tuple_value) - for i in range(len(list_value)): - try: - list_value[i] = self.__value_type(list_value[i]) - except: - raise ListRangeException( - self._name, list_value, self.__type_to_str[self.__value_type] - ) - tuple_value = tuple(list_value) - return tuple_value - - -def parse_attribute( - input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute] -): - """ - Args: - input_attributes: a dict of attributes, where the key is the attribute name, the value is the string value of - the attribute - attribute_map: a dict of hyperparameters, where the key is the attribute name, the value is the Attribute - object - Returns: - a dict of attributes, where the key is the attribute name, the value is the parsed value of the attribute - """ - attributes = {} - for attribute_name in attribute_map: - # user specified the attribute - if attribute_name in input_attributes: - attribute = attribute_map[attribute_name] - value = attribute.parse(input_attributes[attribute_name]) - attribute.validate_value(value) - attributes[attribute_name] = value - # user did not specify the attribute, use the default value - else: - try: - attributes[attribute_name] = attribute_map[ - attribute_name - ].get_default_value() - except NotImplementedError as e: - logger.error(f"attribute {attribute_name} is not implemented.") - raise e - return attributes - - -sundial_attribute_map = { - AttributeName.INPUT_TOKEN_LEN.value: IntAttribute( - name=AttributeName.INPUT_TOKEN_LEN.value, - default_value=16, - default_low=1, - default_high=5000, - ), - AttributeName.HIDDEN_SIZE.value: 
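For reference, the attribute machinery removed above is typically exercised as follows. This is a minimal sketch, assuming the Attribute classes and parse_attribute defined in this file; the map and values are illustrative only and not part of the patch:

from typing import Dict

# One declared hyperparameter with a default value and a valid range.
demo_map: Dict[str, Attribute] = {
    "predict_length": IntAttribute(
        name="predict_length", default_value=1, default_low=1, default_high=5000
    ),
}

# A user-supplied string is parsed, validated against the range, and returned
# as a typed value; keys the user omits fall back to get_default_value().
parsed = parse_attribute({"predict_length": "24"}, demo_map)
assert parsed == {"predict_length": 24}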
IntAttribute( - name=AttributeName.HIDDEN_SIZE.value, - default_value=768, - default_low=1, - default_high=5000, - ), - AttributeName.INTERMEDIATE_SIZE.value: IntAttribute( - name=AttributeName.INTERMEDIATE_SIZE.value, - default_value=3072, - default_low=1, - default_high=5000, - ), - AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute( - name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[720], value_type=int - ), - AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute( - name=AttributeName.NUM_HIDDEN_LAYERS.value, - default_value=12, - default_low=1, - default_high=16, - ), - AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute( - name=AttributeName.NUM_ATTENTION_HEADS.value, - default_value=12, - default_low=1, - default_high=192, - ), - AttributeName.HIDDEN_ACT.value: StringAttribute( - name=AttributeName.HIDDEN_ACT.value, - default_value="silu", - value_choices=["relu", "gelu", "silu", "tanh"], - ), - AttributeName.USE_CACHE.value: BooleanAttribute( - name=AttributeName.USE_CACHE.value, - default_value=True, - ), - AttributeName.ROPE_THETA.value: IntAttribute( - name=AttributeName.ROPE_THETA.value, - default_value=10000, - default_low=1000, - default_high=50000, - ), - AttributeName.DROPOUT_RATE.value: FloatAttribute( - name=AttributeName.DROPOUT_RATE.value, - default_value=0.1, - default_low=0.0, - default_high=1.0, - ), - AttributeName.INITIALIZER_RANGE.value: FloatAttribute( - name=AttributeName.INITIALIZER_RANGE.value, - default_value=0.02, - default_low=0.0, - default_high=1.0, - ), - AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute( - name=AttributeName.MAX_POSITION_EMBEDDINGS.value, - default_value=10000, - default_low=1, - default_high=50000, - ), - AttributeName.FLOW_LOSS_DEPTH.value: IntAttribute( - name=AttributeName.FLOW_LOSS_DEPTH.value, - default_value=3, - default_low=1, - default_high=50, - ), - AttributeName.NUM_SAMPLING_STEPS.value: IntAttribute( - name=AttributeName.NUM_SAMPLING_STEPS.value, - default_value=50, - default_low=1, - default_high=5000, - ), - AttributeName.DIFFUSION_BATCH_MUL.value: IntAttribute( - name=AttributeName.DIFFUSION_BATCH_MUL.value, - default_value=4, - default_low=1, - default_high=5000, - ), - AttributeName.CKPT_PATH.value: StringAttribute( - name=AttributeName.CKPT_PATH.value, - default_value=os.path.join( - os.getcwd(), - AINodeDescriptor().get_config().get_ain_models_dir(), - "weights", - "sundial", - ), - value_choices=[""], - ), -} - -timerxl_attribute_map = { - AttributeName.INPUT_TOKEN_LEN.value: IntAttribute( - name=AttributeName.INPUT_TOKEN_LEN.value, - default_value=96, - default_low=1, - default_high=5000, - ), - AttributeName.HIDDEN_SIZE.value: IntAttribute( - name=AttributeName.HIDDEN_SIZE.value, - default_value=1024, - default_low=1, - default_high=5000, - ), - AttributeName.INTERMEDIATE_SIZE.value: IntAttribute( - name=AttributeName.INTERMEDIATE_SIZE.value, - default_value=2048, - default_low=1, - default_high=5000, - ), - AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute( - name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[96], value_type=int - ), - AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute( - name=AttributeName.NUM_HIDDEN_LAYERS.value, - default_value=8, - default_low=1, - default_high=16, - ), - AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute( - name=AttributeName.NUM_ATTENTION_HEADS.value, - default_value=8, - default_low=1, - default_high=192, - ), - AttributeName.HIDDEN_ACT.value: StringAttribute( - name=AttributeName.HIDDEN_ACT.value, - default_value="silu", - 
value_choices=["relu", "gelu", "silu", "tanh"], - ), - AttributeName.USE_CACHE.value: BooleanAttribute( - name=AttributeName.USE_CACHE.value, - default_value=True, - ), - AttributeName.ROPE_THETA.value: IntAttribute( - name=AttributeName.ROPE_THETA.value, - default_value=10000, - default_low=1000, - default_high=50000, - ), - AttributeName.ATTENTION_DROPOUT.value: FloatAttribute( - name=AttributeName.ATTENTION_DROPOUT.value, - default_value=0.0, - default_low=0.0, - default_high=1.0, - ), - AttributeName.INITIALIZER_RANGE.value: FloatAttribute( - name=AttributeName.INITIALIZER_RANGE.value, - default_value=0.02, - default_low=0.0, - default_high=1.0, - ), - AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute( - name=AttributeName.MAX_POSITION_EMBEDDINGS.value, - default_value=10000, - default_low=1, - default_high=50000, - ), - AttributeName.CKPT_PATH.value: StringAttribute( - name=AttributeName.CKPT_PATH.value, - default_value=os.path.join( - os.getcwd(), - AINodeDescriptor().get_config().get_ain_models_dir(), - "weights", - "timerxl", - "model.safetensors", - ), - value_choices=[""], - ), -} - -# built-in sktime model attributes -# NaiveForecaster -naive_forecaster_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.STRATEGY.value: StringAttribute( - name=AttributeName.STRATEGY.value, - default_value="last", - value_choices=["last", "mean"], - ), - AttributeName.SP.value: IntAttribute( - name=AttributeName.SP.value, default_value=1, default_low=1, default_high=5000 - ), -} -# ExponentialSmoothing -exponential_smoothing_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.DAMPED_TREND.value: BooleanAttribute( - name=AttributeName.DAMPED_TREND.value, - default_value=False, - ), - AttributeName.INITIALIZATION_METHOD.value: StringAttribute( - name=AttributeName.INITIALIZATION_METHOD.value, - default_value="estimated", - value_choices=["estimated", "heuristic", "legacy-heuristic", "known"], - ), - AttributeName.OPTIMIZED.value: BooleanAttribute( - name=AttributeName.OPTIMIZED.value, - default_value=True, - ), - AttributeName.REMOVE_BIAS.value: BooleanAttribute( - name=AttributeName.REMOVE_BIAS.value, - default_value=False, - ), - AttributeName.USE_BRUTE.value: BooleanAttribute( - name=AttributeName.USE_BRUTE.value, - default_value=False, - ), -} -# Arima -arima_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.ORDER.value: TupleAttribute( - name=AttributeName.ORDER.value, default_value=(1, 0, 0), value_type=int - ), - AttributeName.SEASONAL_ORDER.value: TupleAttribute( - name=AttributeName.SEASONAL_ORDER.value, - default_value=(0, 0, 0, 0), - value_type=int, - ), - AttributeName.METHOD.value: StringAttribute( - name=AttributeName.METHOD.value, - default_value="lbfgs", - value_choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"], - ), - AttributeName.MAXITER.value: IntAttribute( - name=AttributeName.MAXITER.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.SUPPRESS_WARNINGS.value: BooleanAttribute( - name=AttributeName.SUPPRESS_WARNINGS.value, - default_value=True, - ), - AttributeName.OUT_OF_SAMPLE_SIZE.value: IntAttribute( - 
name=AttributeName.OUT_OF_SAMPLE_SIZE.value, - default_value=0, - default_low=0, - default_high=5000, - ), - AttributeName.SCORING.value: StringAttribute( - name=AttributeName.SCORING.value, - default_value="mse", - value_choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"], - ), - AttributeName.WITH_INTERCEPT.value: BooleanAttribute( - name=AttributeName.WITH_INTERCEPT.value, - default_value=True, - ), - AttributeName.TIME_VARYING_REGRESSION.value: BooleanAttribute( - name=AttributeName.TIME_VARYING_REGRESSION.value, - default_value=False, - ), - AttributeName.ENFORCE_STATIONARITY.value: BooleanAttribute( - name=AttributeName.ENFORCE_STATIONARITY.value, - default_value=True, - ), - AttributeName.ENFORCE_INVERTIBILITY.value: BooleanAttribute( - name=AttributeName.ENFORCE_INVERTIBILITY.value, - default_value=True, - ), - AttributeName.SIMPLE_DIFFERENCING.value: BooleanAttribute( - name=AttributeName.SIMPLE_DIFFERENCING.value, - default_value=False, - ), - AttributeName.MEASUREMENT_ERROR.value: BooleanAttribute( - name=AttributeName.MEASUREMENT_ERROR.value, - default_value=False, - ), - AttributeName.MLE_REGRESSION.value: BooleanAttribute( - name=AttributeName.MLE_REGRESSION.value, - default_value=True, - ), - AttributeName.HAMILTON_REPRESENTATION.value: BooleanAttribute( - name=AttributeName.HAMILTON_REPRESENTATION.value, - default_value=False, - ), - AttributeName.CONCENTRATE_SCALE.value: BooleanAttribute( - name=AttributeName.CONCENTRATE_SCALE.value, - default_value=False, - ), -} -# STLForecaster -stl_forecaster_attribute_map = { - AttributeName.PREDICT_LENGTH.value: IntAttribute( - name=AttributeName.PREDICT_LENGTH.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.SP.value: IntAttribute( - name=AttributeName.SP.value, default_value=2, default_low=1, default_high=5000 - ), - AttributeName.SEASONAL.value: IntAttribute( - name=AttributeName.SEASONAL.value, - default_value=7, - default_low=1, - default_high=5000, - ), - AttributeName.SEASONAL_DEG.value: IntAttribute( - name=AttributeName.SEASONAL_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.TREND_DEG.value: IntAttribute( - name=AttributeName.TREND_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.LOW_PASS_DEG.value: IntAttribute( - name=AttributeName.LOW_PASS_DEG.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.SEASONAL_JUMP.value: IntAttribute( - name=AttributeName.SEASONAL_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.TREND_JUMP.value: IntAttribute( - name=AttributeName.TREND_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), - AttributeName.LOSS_PASS_JUMP.value: IntAttribute( - name=AttributeName.LOSS_PASS_JUMP.value, - default_value=1, - default_low=0, - default_high=5000, - ), -} - -# GAUSSIAN_HMM -gaussian_hmm_attribute_map = { - AttributeName.N_COMPONENTS.value: IntAttribute( - name=AttributeName.N_COMPONENTS.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.COVARIANCE_TYPE.value: StringAttribute( - name=AttributeName.COVARIANCE_TYPE.value, - default_value="diag", - value_choices=["spherical", "diag", "full", "tied"], - ), - AttributeName.MIN_COVAR.value: FloatAttribute( - name=AttributeName.MIN_COVAR.value, - default_value=1e-3, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.STARTPROB_PRIOR.value: FloatAttribute( - name=AttributeName.STARTPROB_PRIOR.value, 
- default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.TRANSMAT_PRIOR.value: FloatAttribute( - name=AttributeName.TRANSMAT_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_PRIOR.value: FloatAttribute( - name=AttributeName.MEANS_PRIOR.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_WEIGHT.value: FloatAttribute( - name=AttributeName.MEANS_WEIGHT.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.COVARS_PRIOR.value: FloatAttribute( - name=AttributeName.COVARS_PRIOR.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.COVARS_WEIGHT.value: FloatAttribute( - name=AttributeName.COVARS_WEIGHT.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.ALGORITHM.value: StringAttribute( - name=AttributeName.ALGORITHM.value, - default_value="viterbi", - value_choices=["viterbi", "map"], - ), - AttributeName.N_ITER.value: IntAttribute( - name=AttributeName.N_ITER.value, - default_value=10, - default_low=1, - default_high=5000, - ), - AttributeName.TOL.value: FloatAttribute( - name=AttributeName.TOL.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.PARAMS.value: StringAttribute( - name=AttributeName.PARAMS.value, - default_value="stmc", - value_choices=["stmc", "stm"], - ), - AttributeName.INIT_PARAMS.value: StringAttribute( - name=AttributeName.INIT_PARAMS.value, - default_value="stmc", - value_choices=["stmc", "stm"], - ), - AttributeName.IMPLEMENTATION.value: StringAttribute( - name=AttributeName.IMPLEMENTATION.value, - default_value="log", - value_choices=["log", "scaling"], - ), -} - -# GMMHMM -gmmhmm_attribute_map = { - AttributeName.N_COMPONENTS.value: IntAttribute( - name=AttributeName.N_COMPONENTS.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.N_MIX.value: IntAttribute( - name=AttributeName.N_MIX.value, - default_value=1, - default_low=1, - default_high=5000, - ), - AttributeName.MIN_COVAR.value: FloatAttribute( - name=AttributeName.MIN_COVAR.value, - default_value=1e-3, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.STARTPROB_PRIOR.value: FloatAttribute( - name=AttributeName.STARTPROB_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.TRANSMAT_PRIOR.value: FloatAttribute( - name=AttributeName.TRANSMAT_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.WEIGHTS_PRIOR.value: FloatAttribute( - name=AttributeName.WEIGHTS_PRIOR.value, - default_value=1.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_PRIOR.value: FloatAttribute( - name=AttributeName.MEANS_PRIOR.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.MEANS_WEIGHT.value: FloatAttribute( - name=AttributeName.MEANS_WEIGHT.value, - default_value=0.0, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.ALGORITHM.value: StringAttribute( - name=AttributeName.ALGORITHM.value, - default_value="viterbi", - value_choices=["viterbi", "map"], - ), - AttributeName.COVARIANCE_TYPE.value: StringAttribute( - name=AttributeName.COVARIANCE_TYPE.value, - default_value="diag", - value_choices=["sperical", "diag", "full", "tied"], - ), - AttributeName.N_ITER.value: IntAttribute( - name=AttributeName.N_ITER.value, - default_value=10, - default_low=1, - 
default_high=5000, - ), - AttributeName.TOL.value: FloatAttribute( - name=AttributeName.TOL.value, - default_value=1e-2, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.INIT_PARAMS.value: StringAttribute( - name=AttributeName.INIT_PARAMS.value, - default_value="stmcw", - value_choices=[ - "s", - "t", - "m", - "c", - "w", - "st", - "sm", - "sc", - "sw", - "tm", - "tc", - "tw", - "mc", - "mw", - "cw", - "stm", - "stc", - "stw", - "smc", - "smw", - "scw", - "tmc", - "tmw", - "tcw", - "mcw", - "stmc", - "stmw", - "stcw", - "smcw", - "tmcw", - "stmcw", - ], - ), - AttributeName.PARAMS.value: StringAttribute( - name=AttributeName.PARAMS.value, - default_value="stmcw", - value_choices=[ - "s", - "t", - "m", - "c", - "w", - "st", - "sm", - "sc", - "sw", - "tm", - "tc", - "tw", - "mc", - "mw", - "cw", - "stm", - "stc", - "stw", - "smc", - "smw", - "scw", - "tmc", - "tmw", - "tcw", - "mcw", - "stmc", - "stmw", - "stcw", - "smcw", - "tmcw", - "stmcw", - ], - ), - AttributeName.IMPLEMENTATION.value: StringAttribute( - name=AttributeName.IMPLEMENTATION.value, - default_value="log", - value_choices=["log", "scaling"], - ), -} - -# STRAY -stray_attribute_map = { - AttributeName.ALPHA.value: FloatAttribute( - name=AttributeName.ALPHA.value, - default_value=0.01, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.K.value: IntAttribute( - name=AttributeName.K.value, default_value=10, default_low=1, default_high=5000 - ), - AttributeName.KNN_ALGORITHM.value: StringAttribute( - name=AttributeName.KNN_ALGORITHM.value, - default_value="brute", - value_choices=["brute", "kd_tree", "ball_tree", "auto"], - ), - AttributeName.P.value: FloatAttribute( - name=AttributeName.P.value, - default_value=0.5, - default_low=-1e10, - default_high=1e10, - ), - AttributeName.SIZE_THRESHOLD.value: IntAttribute( - name=AttributeName.SIZE_THRESHOLD.value, - default_value=50, - default_low=1, - default_high=5000, - ), - AttributeName.OUTLIER_TAIL.value: StringAttribute( - name=AttributeName.OUTLIER_TAIL.value, - default_value="max", - value_choices=["min", "max"], - ), -} - - -class BuiltInModel(object): - def __init__(self, attributes): - self._attributes = attributes - self._model = None - - @abstractmethod - def inference(self, data): - raise NotImplementedError - - -class ArimaModel(BuiltInModel): - def __init__(self, attributes): - super(ArimaModel, self).__init__(attributes) - self._model = ARIMA( - order=attributes["order"], - seasonal_order=attributes["seasonal_order"], - method=attributes["method"], - suppress_warnings=attributes["suppress_warnings"], - maxiter=attributes["maxiter"], - out_of_sample_size=attributes["out_of_sample_size"], - scoring=attributes["scoring"], - with_intercept=attributes["with_intercept"], - time_varying_regression=attributes["time_varying_regression"], - enforce_stationarity=attributes["enforce_stationarity"], - enforce_invertibility=attributes["enforce_invertibility"], - simple_differencing=attributes["simple_differencing"], - measurement_error=attributes["measurement_error"], - mle_regression=attributes["mle_regression"], - hamilton_representation=attributes["hamilton_representation"], - concentrate_scale=attributes["concentrate_scale"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class 
ExponentialSmoothingModel(BuiltInModel): - def __init__(self, attributes): - super(ExponentialSmoothingModel, self).__init__(attributes) - self._model = ExponentialSmoothing( - damped_trend=attributes["damped_trend"], - initialization_method=attributes["initialization_method"], - optimized=attributes["optimized"], - remove_bias=attributes["remove_bias"], - use_brute=attributes["use_brute"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class NaiveForecasterModel(BuiltInModel): - def __init__(self, attributes): - super(NaiveForecasterModel, self).__init__(attributes) - self._model = NaiveForecaster( - strategy=attributes["strategy"], sp=attributes["sp"] - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class STLForecasterModel(BuiltInModel): - def __init__(self, attributes): - super(STLForecasterModel, self).__init__(attributes) - self._model = STLForecaster( - sp=attributes["sp"], - seasonal=attributes["seasonal"], - seasonal_deg=attributes["seasonal_deg"], - trend_deg=attributes["trend_deg"], - low_pass_deg=attributes["low_pass_deg"], - seasonal_jump=attributes["seasonal_jump"], - trend_jump=attributes["trend_jump"], - low_pass_jump=attributes["low_pass_jump"], - ) - - def inference(self, data): - try: - predict_length = self._attributes["predict_length"] - self._model.fit(data) - output = self._model.predict(fh=range(predict_length)) - output = np.array(output, dtype=np.float64) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class GMMHMMModel(BuiltInModel): - def __init__(self, attributes): - super(GMMHMMModel, self).__init__(attributes) - self._model = GMMHMM( - n_components=attributes["n_components"], - n_mix=attributes["n_mix"], - min_covar=attributes["min_covar"], - startprob_prior=attributes["startprob_prior"], - transmat_prior=attributes["transmat_prior"], - means_prior=attributes["means_prior"], - means_weight=attributes["means_weight"], - weights_prior=attributes["weights_prior"], - algorithm=attributes["algorithm"], - covariance_type=attributes["covariance_type"], - n_iter=attributes["n_iter"], - tol=attributes["tol"], - params=attributes["params"], - init_params=attributes["init_params"], - implementation=attributes["implementation"], - ) - - def inference(self, data): - try: - self._model.fit(data) - output = self._model.predict(data) - output = np.array(output, dtype=np.int32) - return output - except Exception as e: - raise InferenceModelInternalError(str(e)) - - -class GaussianHmmModel(BuiltInModel): - def __init__(self, attributes): - super(GaussianHmmModel, self).__init__(attributes) - self._model = GaussianHMM( - n_components=attributes["n_components"], - covariance_type=attributes["covariance_type"], - min_covar=attributes["min_covar"], - startprob_prior=attributes["startprob_prior"], - transmat_prior=attributes["transmat_prior"], - means_prior=attributes["means_prior"], - means_weight=attributes["means_weight"], - covars_prior=attributes["covars_prior"], - covars_weight=attributes["covars_weight"], - 
algorithm=attributes["algorithm"],
- n_iter=attributes["n_iter"],
- tol=attributes["tol"],
- params=attributes["params"],
- init_params=attributes["init_params"],
- implementation=attributes["implementation"],
- )
-
- def inference(self, data):
- try:
- self._model.fit(data)
- output = self._model.predict(data)
- output = np.array(output, dtype=np.int32)
- return output
- except Exception as e:
- raise InferenceModelInternalError(str(e))
-
-
-class STRAYModel(BuiltInModel):
- def __init__(self, attributes):
- super(STRAYModel, self).__init__(attributes)
- self._model = STRAY(
- alpha=attributes["alpha"],
- k=attributes["k"],
- knn_algorithm=attributes["knn_algorithm"],
- p=attributes["p"],
- size_threshold=attributes["size_threshold"],
- outlier_tail=attributes["outlier_tail"],
- )
-
- def inference(self, data):
- try:
- data = MinMaxScaler().fit_transform(data)
- output = self._model.fit_transform(data)
- # change the output to int
- output = np.array(output, dtype=np.int32)
- return output
- except Exception as e:
- raise InferenceModelInternalError(str(e))
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py
new file mode 100644
index 000000000000..c42ec98551b8
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py
@@ -0,0 +1,41 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from enum import Enum
+
+# Model file constants
+MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors"
+MODEL_CONFIG_FILE_IN_JSON = "config.json"
+MODEL_WEIGHTS_FILE_IN_PT = "model.pt"
+MODEL_CONFIG_FILE_IN_YAML = "config.yaml"
+
+
+class ModelCategory(Enum):
+ BUILTIN = "builtin"
+ USER_DEFINED = "user_defined"
+
+
+class ModelStates(Enum):
+ INACTIVE = "inactive"
+ ACTIVATING = "activating"
+ ACTIVE = "active"
+ DROPPING = "dropping"
+
+
+class UriType(Enum):
+ REPO = "repo"
+ FILE = "file"
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py
deleted file mode 100644
index 348f9924316b..000000000000
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_enums.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -from enum import Enum -from typing import List - - -class BuiltInModelType(Enum): - # forecast models - ARIMA = "Arima" - HOLTWINTERS = "HoltWinters" - EXPONENTIAL_SMOOTHING = "ExponentialSmoothing" - NAIVE_FORECASTER = "NaiveForecaster" - STL_FORECASTER = "StlForecaster" - - # anomaly detection models - GAUSSIAN_HMM = "GaussianHmm" - GMM_HMM = "GmmHmm" - STRAY = "Stray" - - # large time series models (LTSM) - TIMER_XL = "Timer-XL" - # sundial - SUNDIAL = "Timer-Sundial" - - @classmethod - def values(cls) -> List[str]: - return [item.value for item in cls] - - @staticmethod - def is_built_in_model(model_type: str) -> bool: - """ - Check if the given model type corresponds to a built-in model. - """ - return model_type in BuiltInModelType.values() - - -class ModelFileType(Enum): - SAFETENSORS = "safetensors" - PYTORCH = "pytorch" - UNKNOWN = "unknown" - - -class ModelCategory(Enum): - BUILT_IN = "BUILT-IN" - FINE_TUNED = "FINE-TUNED" - USER_DEFINED = "USER-DEFINED" - - -class ModelStates(Enum): - ACTIVE = "ACTIVE" - INACTIVE = "INACTIVE" - LOADING = "LOADING" - DROPPING = "DROPPING" - TRAINING = "TRAINING" - FAILED = "FAILED" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py deleted file mode 100644 index 26d863156f37..000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_factory.py +++ /dev/null @@ -1,291 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# -import glob -import os -import shutil -from urllib.parse import urljoin - -import yaml - -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_YAML, - MODEL_WEIGHTS_FILE_IN_PT, -) -from iotdb.ainode.core.exception import BadConfigValueError, InvalidUriError -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import ModelFileType -from iotdb.ainode.core.model.uri_utils import ( - UriType, - download_file, - download_snapshot_from_hf, -) -from iotdb.ainode.core.util.serde import get_data_type_byte_from_str -from iotdb.thrift.ainode.ttypes import TConfigs - -logger = Logger() - - -def fetch_model_by_uri( - uri_type: UriType, uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Fetch the model files from the specified URI. 
- - Args: - uri_type (UriType): type of the URI, either repo, file, http or https - uri (str): a network or a local path of the model to be registered - storage_path (str): path to save the whole model, including weights, config, codes, etc. - model_file_type (ModelFileType): The type of model file, either safetensors or pytorch - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if uri_type == UriType.REPO or uri_type in [UriType.HTTP, UriType.HTTPS]: - return _fetch_model_from_network(uri, storage_path, model_file_type) - elif uri_type == UriType.FILE: - return _fetch_model_from_local(uri, storage_path, model_file_type) - else: - raise InvalidUriError(f"Invalid URI type: {uri_type}") - - -def _fetch_model_from_network( - uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if model_file_type == ModelFileType.SAFETENSORS: - download_snapshot_from_hf(uri, storage_path) - return _process_huggingface_files(storage_path) - - # TODO: The following codes might be refactored in future - # concat uri to get complete url - uri = uri if uri.endswith("/") else uri + "/" - target_model_path = urljoin(uri, MODEL_WEIGHTS_FILE_IN_PT) - target_config_path = urljoin(uri, MODEL_CONFIG_FILE_IN_YAML) - - # download config file - config_storage_path = os.path.join(storage_path, MODEL_CONFIG_FILE_IN_YAML) - download_file(target_config_path, config_storage_path) - - # read and parse config dict from config.yaml - with open(config_storage_path, "r", encoding="utf-8") as file: - config_dict = yaml.safe_load(file) - configs, attributes = _parse_inference_config(config_dict) - - # if config.yaml is correct, download model file - model_storage_path = os.path.join(storage_path, MODEL_WEIGHTS_FILE_IN_PT) - download_file(target_model_path, model_storage_path) - return configs, attributes - - -def _fetch_model_from_local( - uri: str, storage_path: str, model_file_type: ModelFileType -): - """ - Returns: TODO: Will be removed in future - configs: TConfigs - attributes: str - """ - if model_file_type == ModelFileType.SAFETENSORS: - # copy anything in the uri to local_dir - for file in os.listdir(uri): - shutil.copy(os.path.join(uri, file), storage_path) - return _process_huggingface_files(storage_path) - # concat uri to get complete path - target_model_path = os.path.join(uri, MODEL_WEIGHTS_FILE_IN_PT) - model_storage_path = os.path.join(storage_path, MODEL_WEIGHTS_FILE_IN_PT) - target_config_path = os.path.join(uri, MODEL_CONFIG_FILE_IN_YAML) - config_storage_path = os.path.join(storage_path, MODEL_CONFIG_FILE_IN_YAML) - - # check if file exist - exist_model_file = os.path.exists(target_model_path) - exist_config_file = os.path.exists(target_config_path) - - configs = None - attributes = None - if exist_model_file and exist_config_file: - # copy config.yaml - shutil.copy(target_config_path, config_storage_path) - logger.info( - f"copy file from {target_config_path} to {config_storage_path} success" - ) - - # read and parse config dict from config.yaml - with open(config_storage_path, "r", encoding="utf-8") as file: - config_dict = yaml.safe_load(file) - configs, attributes = _parse_inference_config(config_dict) - - # if config.yaml is correct, copy model file - shutil.copy(target_model_path, model_storage_path) - logger.info( - f"copy file from {target_model_path} to {model_storage_path} success" - ) - - elif not exist_model_file or not exist_config_file: - raise 
InvalidUriError(uri) - - return configs, attributes - - -def _parse_inference_config(config_dict): - """ - Args: - config_dict: dict - - configs: dict - - input_shape (list): input shape of the model and needs to be two-dimensional array like [96, 2] - - output_shape (list): output shape of the model and needs to be two-dimensional array like [96, 2] - - input_type (list): input type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float64 - - output_type (list): output type of the model and each element needs to be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'], default float64 - - attributes: dict - Returns: - configs: TConfigs - attributes: str - """ - configs = config_dict["configs"] - - # check if input_shape and output_shape are two-dimensional array - if not ( - isinstance(configs["input_shape"], list) and len(configs["input_shape"]) == 2 - ): - raise BadConfigValueError( - "input_shape", - configs["input_shape"], - "input_shape should be a two-dimensional array.", - ) - if not ( - isinstance(configs["output_shape"], list) and len(configs["output_shape"]) == 2 - ): - raise BadConfigValueError( - "output_shape", - configs["output_shape"], - "output_shape should be a two-dimensional array.", - ) - - # check if input_shape and output_shape are positive integer - input_shape_is_positive_number = ( - isinstance(configs["input_shape"][0], int) - and isinstance(configs["input_shape"][1], int) - and configs["input_shape"][0] > 0 - and configs["input_shape"][1] > 0 - ) - if not input_shape_is_positive_number: - raise BadConfigValueError( - "input_shape", - configs["input_shape"], - "element in input_shape should be positive integer.", - ) - - output_shape_is_positive_number = ( - isinstance(configs["output_shape"][0], int) - and isinstance(configs["output_shape"][1], int) - and configs["output_shape"][0] > 0 - and configs["output_shape"][1] > 0 - ) - if not output_shape_is_positive_number: - raise BadConfigValueError( - "output_shape", - configs["output_shape"], - "element in output_shape should be positive integer.", - ) - - # check if input_type and output_type are one-dimensional array with right length - if "input_type" in configs and not ( - isinstance(configs["input_type"], list) - and len(configs["input_type"]) == configs["input_shape"][1] - ): - raise BadConfigValueError( - "input_type", - configs["input_type"], - "input_type should be a one-dimensional array and length of it should be equal to input_shape[1].", - ) - - if "output_type" in configs and not ( - isinstance(configs["output_type"], list) - and len(configs["output_type"]) == configs["output_shape"][1] - ): - raise BadConfigValueError( - "output_type", - configs["output_type"], - "output_type should be a one-dimensional array and length of it should be equal to output_shape[1].", - ) - - # parse input_type and output_type to byte - if "input_type" in configs: - input_type = [get_data_type_byte_from_str(x) for x in configs["input_type"]] - else: - input_type = [get_data_type_byte_from_str("float32")] * configs["input_shape"][ - 1 - ] - - if "output_type" in configs: - output_type = [get_data_type_byte_from_str(x) for x in configs["output_type"]] - else: - output_type = [get_data_type_byte_from_str("float32")] * configs[ - "output_shape" - ][1] - - # parse attributes - attributes = "" - if "attributes" in config_dict: - attributes = str(config_dict["attributes"]) - - return ( - TConfigs( - configs["input_shape"], configs["output_shape"], input_type, 
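To make the shape and type checks above concrete, here is a minimal config dict that would pass _parse_inference_config as written; the shapes and types are illustrative only, and the call assumes this module's context:

example_config = {
    "configs": {
        # [rows, columns]; both entries must be positive integers
        "input_shape": [96, 2],
        "output_shape": [96, 2],
        # one entry per column of the corresponding shape; when omitted,
        # the code above falls back to float32 for every column
        "input_type": ["float32", "float32"],
        "output_type": ["float32", "float32"],
    },
    # optional free-form attributes, serialized to str
    "attributes": {"model_type": "example"},
}

configs, attributes = _parse_inference_config(example_config)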
output_type - ), - attributes, - ) - - -def _process_huggingface_files(local_dir: str): - """ - TODO: Currently, we use this function to convert the model config from huggingface, we will refactor this in the future. - """ - config_file = None - for config_name in ["config.json", "model_config.json"]: - config_path = os.path.join(local_dir, config_name) - if os.path.exists(config_path): - config_file = config_path - break - - if not config_file: - raise InvalidUriError(f"No config.json found in {local_dir}") - - safetensors_files = glob.glob(os.path.join(local_dir, "*.safetensors")) - if not safetensors_files: - raise InvalidUriError(f"No .safetensors files found in {local_dir}") - - simple_config = { - "configs": { - "input_shape": [96, 1], - "output_shape": [96, 1], - "input_type": ["float32"], - "output_type": ["float32"], - }, - "attributes": { - "model_type": "huggingface_model", - "source_dir": local_dir, - "files": [os.path.basename(f) for f in safetensors_files], - }, - } - - configs, attributes = _parse_inference_config(simple_config) - return configs, attributes diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py index 167bfd76640d..718ead530dd2 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py @@ -15,140 +15,118 @@ # specific language governing permissions and limitations # under the License. # -import glob -import os -from iotdb.ainode.core.constant import ( - MODEL_CONFIG_FILE_IN_JSON, - MODEL_CONFIG_FILE_IN_YAML, - MODEL_WEIGHTS_FILE_IN_PT, - MODEL_WEIGHTS_FILE_IN_SAFETENSORS, -) -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, - ModelCategory, - ModelFileType, - ModelStates, -) +from typing import Dict, Optional - -def get_model_file_type(model_path: str) -> ModelFileType: - """ - Determine the file type of the specified model directory. 
- """ - if _has_safetensors_format(model_path): - return ModelFileType.SAFETENSORS - elif _has_pytorch_format(model_path): - return ModelFileType.PYTORCH - else: - return ModelFileType.UNKNOWN - - -def _has_safetensors_format(path: str) -> bool: - """Check if directory contains safetensors files.""" - safetensors_files = glob.glob(os.path.join(path, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)) - json_files = glob.glob(os.path.join(path, MODEL_CONFIG_FILE_IN_JSON)) - return len(safetensors_files) > 0 and len(json_files) > 0 - - -def _has_pytorch_format(path: str) -> bool: - """Check if directory contains pytorch files.""" - pt_files = glob.glob(os.path.join(path, MODEL_WEIGHTS_FILE_IN_PT)) - yaml_files = glob.glob(os.path.join(path, MODEL_CONFIG_FILE_IN_YAML)) - return len(pt_files) > 0 and len(yaml_files) > 0 - - -def get_built_in_model_type(model_type: str) -> BuiltInModelType: - if not BuiltInModelType.is_built_in_model(model_type): - raise ValueError(f"Invalid built-in model type: {model_type}") - return BuiltInModelType(model_type) +from iotdb.ainode.core.model.model_constants import ModelCategory, ModelStates class ModelInfo: def __init__( self, model_id: str, - model_type: str, category: ModelCategory, state: ModelStates, + model_type: str = "", + config_cls: str = "", + model_cls: str = "", + pipeline_cls: str = "", + repo_id: str = "", + auto_map: Optional[Dict] = None, + _transformers_registered: bool = False, ): self.model_id = model_id self.model_type = model_type self.category = category self.state = state + self.config_cls = config_cls + self.model_cls = model_cls + self.pipeline_cls = pipeline_cls + self.repo_id = repo_id + self.auto_map = auto_map # If exists, indicates it's a Transformers model + self._transformers_registered = _transformers_registered # Internal flag: whether registered to Transformers + def __repr__(self): + return ( + f"ModelInfo(model_id={self.model_id}, model_type={self.model_type}, " + f"category={self.category.value}, state={self.state.value}, " + f"has_auto_map={self.auto_map is not None})" + ) -TIMER_REPO_ID = { - BuiltInModelType.TIMER_XL: "thuml/timer-base-84m", - BuiltInModelType.SUNDIAL: "thuml/sundial-base-128m", -} -# Built-in machine learning models, they can be employed directly -BUILT_IN_MACHINE_LEARNING_MODEL_MAP = { +BUILTIN_SKTIME_MODEL_MAP = { # forecast models "arima": ModelInfo( model_id="arima", - model_type=BuiltInModelType.ARIMA.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "holtwinters": ModelInfo( model_id="holtwinters", - model_type=BuiltInModelType.HOLTWINTERS.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "exponential_smoothing": ModelInfo( model_id="exponential_smoothing", - model_type=BuiltInModelType.EXPONENTIAL_SMOOTHING.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "naive_forecaster": ModelInfo( model_id="naive_forecaster", - model_type=BuiltInModelType.NAIVE_FORECASTER.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "stl_forecaster": ModelInfo( model_id="stl_forecaster", - model_type=BuiltInModelType.STL_FORECASTER.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), # anomaly detection models "gaussian_hmm": ModelInfo( 
model_id="gaussian_hmm", - model_type=BuiltInModelType.GAUSSIAN_HMM.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "gmm_hmm": ModelInfo( model_id="gmm_hmm", - model_type=BuiltInModelType.GMM_HMM.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), "stray": ModelInfo( model_id="stray", - model_type=BuiltInModelType.STRAY.value, - category=ModelCategory.BUILT_IN, + category=ModelCategory.BUILTIN, state=ModelStates.ACTIVE, + model_type="sktime", ), } -# Built-in large time series models (LTSM), their weights are not included in AINode by default -BUILT_IN_LTSM_MAP = { +# Built-in huggingface transformers models, their weights are not included in AINode by default +BUILTIN_HF_TRANSFORMERS_MODEL_MAP = { "timer_xl": ModelInfo( model_id="timer_xl", - model_type=BuiltInModelType.TIMER_XL.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.LOADING, + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="timer", + config_cls="configuration_timer.TimerConfig", + model_cls="modeling_timer.TimerForPrediction", + pipeline_cls="pipeline_timer.TimerPipeline", + repo_id="thuml/timer-base-84m", ), "sundial": ModelInfo( model_id="sundial", - model_type=BuiltInModelType.SUNDIAL.value, - category=ModelCategory.BUILT_IN, - state=ModelStates.LOADING, + category=ModelCategory.BUILTIN, + state=ModelStates.INACTIVE, + model_type="sundial", + config_cls="configuration_sundial.SundialConfig", + model_cls="modeling_sundial.SundialForPrediction", + pipeline_cls="pipeline_sundial.SundialPipeline", + repo_id="thuml/sundial-base-128m", ), } diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py new file mode 100644 index 000000000000..a6e3b1f7b5e3 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -0,0 +1,156 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# + +import os +from pathlib import Path +from typing import Any + +import torch +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForNextSentencePrediction, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoModelForTimeSeriesPrediction, + AutoModelForTokenClassification, +) + +from iotdb.ainode.core.config import AINodeDescriptor +from iotdb.ainode.core.exception import ModelNotExistError +from iotdb.ainode.core.log import Logger +from iotdb.ainode.core.model.model_constants import ModelCategory +from iotdb.ainode.core.model.model_info import ModelInfo +from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model +from iotdb.ainode.core.model.utils import import_class_from_path, temporary_sys_path + +logger = Logger() + + +def load_model(model_info: ModelInfo, **model_kwargs) -> Any: + if model_info.auto_map is not None: + model = load_model_from_transformers(model_info, **model_kwargs) + else: + if model_info.model_type == "sktime": + model = create_sktime_model(model_info.model_id) + else: + model = load_model_from_pt(model_info, **model_kwargs) + + logger.info( + f"Model {model_info.model_id} loaded to device {model.device if model_info.model_type != 'sktime' else 'cpu'} successfully." + ) + return model + + +def load_model_from_transformers(model_info: ModelInfo, **model_kwargs): + device_map = model_kwargs.get("device_map", "cpu") + trust_remote_code = model_kwargs.get("trust_remote_code", True) + train_from_scratch = model_kwargs.get("train_from_scratch", False) + + model_path = os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + model_info.category.value, + model_info.model_id, + ) + + if model_info.category == ModelCategory.BUILTIN: + module_name = ( + AINodeDescriptor().get_config().get_ain_models_builtin_dir() + + "." 
+ + model_info.model_id
+ )
+ config_cls = import_class_from_path(module_name, model_info.config_cls)
+ model_cls = import_class_from_path(module_name, model_info.model_cls)
+ elif model_info.model_cls and model_info.config_cls:
+ module_parent = str(Path(model_path).parent.absolute())
+ with temporary_sys_path(module_parent):
+ config_cls = import_class_from_path(
+ model_info.model_id, model_info.config_cls
+ )
+ model_cls = import_class_from_path(
+ model_info.model_id, model_info.model_cls
+ )
+ else:
+ config_cls = AutoConfig.from_pretrained(model_path)
+ if type(config_cls) in AutoModelForTimeSeriesPrediction._model_mapping.keys():
+ model_cls = AutoModelForTimeSeriesPrediction
+ elif (
+ type(config_cls) in AutoModelForNextSentencePrediction._model_mapping.keys()
+ ):
+ model_cls = AutoModelForNextSentencePrediction
+ elif type(config_cls) in AutoModelForSeq2SeqLM._model_mapping.keys():
+ model_cls = AutoModelForSeq2SeqLM
+ elif (
+ type(config_cls) in AutoModelForSequenceClassification._model_mapping.keys()
+ ):
+ model_cls = AutoModelForSequenceClassification
+ elif type(config_cls) in AutoModelForTokenClassification._model_mapping.keys():
+ model_cls = AutoModelForTokenClassification
+ else:
+ model_cls = AutoModelForCausalLM
+
+ if train_from_scratch:
+ model = model_cls.from_config(
+ config_cls, trust_remote_code=trust_remote_code, device_map=device_map
+ )
+ else:
+ model = model_cls.from_pretrained(
+ model_path,
+ trust_remote_code=trust_remote_code,
+ device_map=device_map,
+ )
+
+ return model
+
+
+def load_model_from_pt(model_info: ModelInfo, **kwargs):
+ device_map = kwargs.get("device_map", "cpu")
+ acceleration = kwargs.get("acceleration", False)
+ model_path = os.path.join(
+ os.getcwd(),
+ AINodeDescriptor().get_config().get_ain_models_dir(),
+ model_info.category.value,
+ model_info.model_id,
+ )
+ model_file = os.path.join(model_path, "model.pt")
+ if not os.path.exists(model_file):
+ logger.error(f"Model file not found at {model_file}.")
+ raise ModelNotExistError(model_file)
+ model = torch.jit.load(model_file)
+ if isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not acceleration:
+ return model.to(device_map)
+ try:
+ model = torch.compile(model)
+ except Exception as e:
+ logger.warning(f"Acceleration failed, falling back to normal mode: {str(e)}")
+ return model.to(device_map)
+
+
+def load_model_for_efficient_inference():
+ # TODO: An efficient model loading method for inference based on model_arguments
+ pass
+
+
+def load_model_for_powerful_finetune():
+ # TODO: A powerful model loading method for fine-tuning based on model_arguments
+ pass
+
+
+def unload_model():
+ pass
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
index e346f569102e..5194ed4df1bd 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py
@@ -20,43 +20,37 @@
 import json
 import os
 import shutil
-from collections.abc import Callable
-from typing import Dict
+from pathlib import Path
+from typing import Dict, List, Optional
-import torch
-from torch import nn
+from huggingface_hub import hf_hub_download, snapshot_download
+from transformers import AutoConfig, AutoModelForCausalLM
 from iotdb.ainode.core.config import AINodeDescriptor
-from iotdb.ainode.core.constant import (
- MODEL_CONFIG_FILE_IN_JSON,
- MODEL_WEIGHTS_FILE_IN_PT,
- TSStatusCode,
-)
-from iotdb.ainode.core.exception import (
- BuiltInModelDeletionError,
-
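As a usage sketch (not part of the patch), the loader above can be driven directly from a ModelInfo. For a built-in sktime model the call reduces to create_sktime_model, since auto_map is None; the device_map and trust_remote_code kwargs mentioned in the comment are the ones read by load_model_from_transformers above:

from iotdb.ainode.core.model.model_info import BUILTIN_SKTIME_MODEL_MAP
from iotdb.ainode.core.model.model_loader import load_model

# "arima" is registered in BUILTIN_SKTIME_MODEL_MAP with model_type="sktime",
# so load_model() dispatches to create_sktime_model("arima") on CPU.
arima = load_model(BUILTIN_SKTIME_MODEL_MAP["arima"])

# For a transformers-backed ModelInfo (auto_map set), kwargs such as
# device_map="cpu" and trust_remote_code=True would be forwarded instead.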
ModelNotExistError, - UnsupportedError, -) +from iotdb.ainode.core.constant import TSStatusCode +from iotdb.ainode.core.exception import BuiltInModelDeletionError from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.built_in_model_factory import ( - download_built_in_ltsm_from_hf_if_necessary, - fetch_built_in_model, -) -from iotdb.ainode.core.model.model_enums import ( - BuiltInModelType, +from iotdb.ainode.core.model.model_constants import ( + MODEL_CONFIG_FILE_IN_JSON, + MODEL_WEIGHTS_FILE_IN_SAFETENSORS, ModelCategory, - ModelFileType, ModelStates, + UriType, ) -from iotdb.ainode.core.model.model_factory import fetch_model_by_uri from iotdb.ainode.core.model.model_info import ( - BUILT_IN_LTSM_MAP, - BUILT_IN_MACHINE_LEARNING_MODEL_MAP, + BUILTIN_HF_TRANSFORMERS_MODEL_MAP, + BUILTIN_SKTIME_MODEL_MAP, ModelInfo, - get_built_in_model_type, - get_model_file_type, ) -from iotdb.ainode.core.model.uri_utils import get_model_register_strategy +from iotdb.ainode.core.model.utils import ( + ensure_init_file, + get_parsed_uri, + import_class_from_path, + load_model_config_in_json, + parse_uri_type, + temporary_sys_path, + validate_model_files, +) from iotdb.ainode.core.util.lock import ModelLockPool from iotdb.thrift.ainode.ttypes import TShowModelsReq, TShowModelsResp from iotdb.thrift.common.ttypes import TSStatus @@ -64,320 +58,368 @@ logger = Logger() -class ModelStorage(object): +class ModelStorage: + """Model storage class - unified management of model discovery and registration""" + def __init__(self): - self._model_dir = os.path.join( + self._models_dir = os.path.join( os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir() ) - if not os.path.exists(self._model_dir): - try: - os.makedirs(self._model_dir) - except PermissionError as e: - logger.error(e) - raise e - self._builtin_model_dir = os.path.join( - os.getcwd(), AINodeDescriptor().get_config().get_ain_builtin_models_dir() - ) - if not os.path.exists(self._builtin_model_dir): - try: - os.makedirs(self._builtin_model_dir) - except PermissionError as e: - logger.error(e) - raise e + # Unified storage: category -> {model_id -> ModelInfo} + self._models: Dict[str, Dict[str, ModelInfo]] = { + ModelCategory.BUILTIN.value: {}, + ModelCategory.USER_DEFINED.value: {}, + } + # Async download executor + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=2) + # Thread lock pool for protecting concurrent access to model information self._lock_pool = ModelLockPool() - self._executor = concurrent.futures.ThreadPoolExecutor( - max_workers=1 - ) # TODO: Here we set the work_num=1 cause we found that the hf download interface is not stable for concurrent downloading. 
- self._model_info_map: Dict[str, ModelInfo] = {} - self._init_model_info_map() + self._initialize_directories() + self.discover_all_models() + + def _initialize_directories(self): + """Initialize directory structure and ensure __init__.py files exist""" + os.makedirs(self._models_dir, exist_ok=True) + ensure_init_file(self._models_dir) + for category in ModelCategory: + category_path = os.path.join(self._models_dir, category.value) + os.makedirs(category_path, exist_ok=True) + ensure_init_file(category_path) + + # ==================== Discovery Methods ==================== + + def discover_all_models(self): + """Scan file system to discover all models""" + self._discover_category(ModelCategory.BUILTIN) + self._discover_category(ModelCategory.USER_DEFINED) + + def _discover_category(self, category: ModelCategory): + """Discover all models in a category directory""" + category_path = os.path.join(self._models_dir, category.value) + if category == ModelCategory.BUILTIN: + self._discover_builtin_models(category_path) + elif category == ModelCategory.USER_DEFINED: + for model_id in os.listdir(category_path): + if os.path.isdir(os.path.join(category_path, model_id)): + self._process_user_defined_model_directory( + os.path.join(category_path, model_id), model_id + ) - def _init_model_info_map(self): - """ - Initialize the model info map. - """ - # 1. initialize built-in and ready-to-use models - for model_id in BUILT_IN_MACHINE_LEARNING_MODEL_MAP: - self._model_info_map[model_id] = BUILT_IN_MACHINE_LEARNING_MODEL_MAP[ - model_id - ] - # 2. retrieve fine-tuned models from the built-in model directory - fine_tuned_models = self._retrieve_fine_tuned_models() - for model_id in fine_tuned_models: - self._model_info_map[model_id] = fine_tuned_models[model_id] - # 3. automatically downloading the weights of built-in LSTM models when necessary - for model_id in BUILT_IN_LTSM_MAP: - if model_id not in self._model_info_map: - self._model_info_map[model_id] = BUILT_IN_LTSM_MAP[model_id] - future = self._executor.submit( - self._download_built_in_model_if_necessary, model_id - ) - future.add_done_callback( - lambda f, mid=model_id: self._callback_model_download_result(f, mid) - ) - # 4. retrieve user-defined models from the model directory - user_defined_models = self._retrieve_user_defined_models() - for model_id in user_defined_models: - self._model_info_map[model_id] = user_defined_models[model_id] + def _discover_builtin_models(self, category_path: str): + # Register SKTIME models directly from map + for model_id in BUILTIN_SKTIME_MODEL_MAP.keys(): + with self._lock_pool.get_lock(model_id).write_lock(): + self._models[ModelCategory.BUILTIN.value][model_id] = ( + BUILTIN_SKTIME_MODEL_MAP[model_id] + ) - def _retrieve_fine_tuned_models(self): - """ - Retrieve fine-tuned models from the built-in model directory. 
+        # Process HuggingFace Transformers models
+        for model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys():
+            model_dir = os.path.join(category_path, model_id)
+            os.makedirs(model_dir, exist_ok=True)
+            self._process_builtin_model_directory(model_dir, model_id)
 
-        Returns:
-            {"model_id": ModelInfo}
-        """
-        result = {}
-        build_in_dirs = [
-            d
-            for d in os.listdir(self._builtin_model_dir)
-            if os.path.isdir(os.path.join(self._builtin_model_dir, d))
-        ]
-        for model_id in build_in_dirs:
-            config_file_path = os.path.join(
-                self._builtin_model_dir, model_id, MODEL_CONFIG_FILE_IN_JSON
+    def _process_builtin_model_directory(self, model_dir: str, model_id: str):
+        """Handles discovery for a single built-in model directory."""
+        ensure_init_file(model_dir)
+        with self._lock_pool.get_lock(model_id).write_lock():
+            # Check if the model already exists and is in a valid state
+            existing_model = self._models[ModelCategory.BUILTIN.value].get(model_id)
+            if existing_model:
+                # If the model is already ACTIVATING or ACTIVE, skip the duplicate download
+                if existing_model.state in (ModelStates.ACTIVATING, ModelStates.ACTIVE):
+                    return
+
+            # If the model does not exist yet or is INACTIVE, update its info and download its weights
+            self._models[ModelCategory.BUILTIN.value][model_id] = (
+                BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id]
            )
-            if os.path.isfile(config_file_path):
-                with open(config_file_path, "r") as f:
-                    model_config = json.load(f)
-                    if "model_type" in model_config:
-                        model_type = model_config["model_type"]
-                        model_info = ModelInfo(
-                            model_id=model_id,
-                            model_type=model_type,
-                            category=ModelCategory.FINE_TUNED,
-                            state=ModelStates.ACTIVE,
+        self._models[ModelCategory.BUILTIN.value][
+            model_id
+        ].state = ModelStates.ACTIVATING
+
+        def _download_model_if_necessary() -> bool:
+            """Returns: True if the model already exists or was downloaded successfully, False otherwise."""
+            repo_id = BUILTIN_HF_TRANSFORMERS_MODEL_MAP[model_id].repo_id
+            weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS)
+            config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON)
+            if not os.path.exists(weights_path):
+                try:
+                    hf_hub_download(
+                        repo_id=repo_id,
+                        filename=MODEL_WEIGHTS_FILE_IN_SAFETENSORS,
+                        local_dir=model_dir,
                        )
-                        # Refactor the built-in model category
-                        if "timer_xl" == model_id:
-                            model_info.category = ModelCategory.BUILT_IN
-                        if "sundial" == model_id:
-                            model_info.category = ModelCategory.BUILT_IN
-                        # Compatible patch with the codes in HuggingFace
-                        if "timer" == model_type:
-                            model_info.model_type = BuiltInModelType.TIMER_XL.value
-                        if "sundial" == model_type:
-                            model_info.model_type = BuiltInModelType.SUNDIAL.value
-                        result[model_id] = model_info
-        return result
-
-    def _download_built_in_model_if_necessary(self, model_id: str) -> bool:
-        """
-        Download the built-in model if it is not already downloaded.
-
-        Args:
-            model_id (str): The ID of the model to download.
+                except Exception as e:
+                    logger.error(
+                        f"Failed to download model weights from HuggingFace: {e}"
+                    )
+                    return False
+            if not os.path.exists(config_path):
+                try:
+                    hf_hub_download(
+                        repo_id=repo_id,
+                        filename=MODEL_CONFIG_FILE_IN_JSON,
+                        local_dir=model_dir,
+                    )
+                except Exception as e:
+                    logger.error(
+                        f"Failed to download model config from HuggingFace: {e}"
+                    )
+                    return False
+            return True
 
-        Return:
-            bool: True if the model is existed or downloaded successfully, False otherwise.
-        """
-        with self._lock_pool.get_lock(model_id).write_lock():
-            local_dir = os.path.join(self._builtin_model_dir, model_id)
-            return download_built_in_ltsm_from_hf_if_necessary(
-                get_built_in_model_type(self._model_info_map[model_id].model_type),
-                local_dir,
-            )
+        future = self._executor.submit(_download_model_if_necessary)
+        future.add_done_callback(
+            lambda f, mid=model_id: self._callback_model_download_result(f, mid)
+        )
 
     def _callback_model_download_result(self, future, model_id: str):
+        """Callback for handling the result of an async model download"""
         with self._lock_pool.get_lock(model_id).write_lock():
-            if future.result():
-                self._model_info_map[model_id].state = ModelStates.ACTIVE
-                logger.info(
-                    f"The built-in model: {model_id} is active and ready to use."
-                )
-            else:
-                self._model_info_map[model_id].state = ModelStates.INACTIVE
-
-    def _retrieve_user_defined_models(self):
-        """
-        Retrieve user_defined models from the model directory.
+            try:
+                if future.result():
+                    model_info = self._models[ModelCategory.BUILTIN.value][model_id]
+                    model_info.state = ModelStates.ACTIVE
+                    config_path = os.path.join(
+                        self._models_dir,
+                        ModelCategory.BUILTIN.value,
+                        model_id,
+                        MODEL_CONFIG_FILE_IN_JSON,
+                    )
+                    if os.path.exists(config_path):
+                        with open(config_path, "r", encoding="utf-8") as f:
+                            config = json.load(f)
+                        if model_info.model_type == "":
+                            model_info.model_type = config.get("model_type", "")
+                        model_info.auto_map = config.get("auto_map", None)
+                    logger.info(
+                        f"Model {model_id} downloaded successfully and is ready to use."
+                    )
+                else:
+                    self._models[ModelCategory.BUILTIN.value][
+                        model_id
+                    ].state = ModelStates.INACTIVE
+                    logger.warning(f"Failed to download model {model_id}.")
+            except Exception as e:
+                logger.error(f"Error in download callback for model {model_id}: {e}")
+                self._models[ModelCategory.BUILTIN.value][
+                    model_id
+                ].state = ModelStates.INACTIVE
+
+    def _process_user_defined_model_directory(self, model_dir: str, model_id: str):
+        """Handles discovery for a single user-defined model directory."""
+        config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON)
+        model_type = ""
+        auto_map = {}
+        pipeline_cls = ""
+        if os.path.exists(config_path):
+            config = load_model_config_in_json(config_path)
+            model_type = config.get("model_type", "")
+            auto_map = config.get("auto_map", None)
+            pipeline_cls = config.get("pipeline_cls", "")
 
-        Returns:
-            {"model_id": ModelInfo}
-        """
-        result = {}
-        user_dirs = [
-            d
-            for d in os.listdir(self._model_dir)
-            if os.path.isdir(os.path.join(self._model_dir, d)) and d != "weights"
-        ]
-        for model_id in user_dirs:
-            result[model_id] = ModelInfo(
+        with self._lock_pool.get_lock(model_id).write_lock():
+            model_info = ModelInfo(
                 model_id=model_id,
-                model_type="",
+                model_type=model_type,
                 category=ModelCategory.USER_DEFINED,
                 state=ModelStates.ACTIVE,
+                pipeline_cls=pipeline_cls,
+                auto_map=auto_map,
+                _transformers_registered=False,  # Lazy registration
             )
-        return result
+            self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info
+
+    # ==================== Registration Methods ====================
 
-    def register_model(self, model_id: str, uri: str):
+    def register_model(self, model_id: str, uri: str) -> bool:
         """
-        Args:
-            model_id: id of model to register
-            uri: network or local dir path of the model to register
-        Returns:
-            configs: TConfigs
-            attributes: str
+        Supported URI formats:
+        - repo:// (a HuggingFace repository id, fetched via snapshot_download)
+        - file:// (a local directory)
         """
+        uri_type = parse_uri_type(uri)
+        parsed_uri = get_parsed_uri(uri)
+
+        model_dir = 
os.path.join( + self._models_dir, ModelCategory.USER_DEFINED.value, model_id + ) + os.makedirs(model_dir, exist_ok=True) + ensure_init_file(model_dir) + + if uri_type == UriType.REPO: + self._fetch_model_from_hf_repo(parsed_uri, model_dir) + else: + self._fetch_model_from_local(os.path.expanduser(parsed_uri), model_dir) + + config_path, _ = validate_model_files(model_dir) + config = load_model_config_in_json(config_path) + model_type = config.get("model_type", "") + auto_map = config.get("auto_map") + pipeline_cls = config.get("pipeline_cls", "") + with self._lock_pool.get_lock(model_id).write_lock(): - storage_path = os.path.join(self._model_dir, f"{model_id}") - # create storage dir if not exist - if not os.path.exists(storage_path): - os.makedirs(storage_path) - uri_type, parsed_uri, model_file_type = get_model_register_strategy(uri) - self._model_info_map[model_id] = ModelInfo( + model_info = ModelInfo( model_id=model_id, - model_type="", + model_type=model_type, category=ModelCategory.USER_DEFINED, - state=ModelStates.LOADING, + state=ModelStates.ACTIVE, + pipeline_cls=pipeline_cls, + auto_map=auto_map, + _transformers_registered=False, # Register later + ) + self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info + + if auto_map: + # Transformers model: immediately register to Transformers auto-loading mechanism + success = self._register_transformers_model(model_info) + if success: + with self._lock_pool.get_lock(model_id).write_lock(): + model_info._transformers_registered = True + else: + with self._lock_pool.get_lock(model_id).write_lock(): + model_info.state = ModelStates.INACTIVE + logger.error(f"Failed to register Transformers model {model_id}") + return False + else: + # Other type models: only log + self._register_other_model(model_info) + + logger.info(f"Successfully registered model {model_id} from URI: {uri}") + return True + + def _fetch_model_from_hf_repo(self, repo_id: str, storage_path: str): + logger.info( + f"Downloading model from HuggingFace repository: {repo_id} -> {storage_path}" + ) + # Use snapshot_download to download entire repository (including config.json and model.safetensors) + try: + snapshot_download( + repo_id=repo_id, + local_dir=storage_path, + local_dir_use_symlinks=False, + ) + except Exception as e: + logger.error(f"Failed to download model from HuggingFace: {e}") + raise + + def _fetch_model_from_local(self, source_path: str, storage_path: str): + logger.info(f"Copying model from local path: {source_path} -> {storage_path}") + source_dir = Path(source_path) + if not source_dir.is_dir(): + raise ValueError( + f"Source path does not exist or is not a directory: {source_path}" ) - try: - # TODO: The uri should be fetched asynchronously - configs, attributes = fetch_model_by_uri( - uri_type, parsed_uri, storage_path, model_file_type - ) - self._model_info_map[model_id].state = ModelStates.ACTIVE - return configs, attributes - except Exception as e: - logger.error(f"Failed to register model {model_id}: {e}") - self._model_info_map[model_id].state = ModelStates.INACTIVE - raise e - def delete_model(self, model_id: str) -> None: - """ - Args: - model_id: id of model to delete - Returns: - None - """ - # check if the model is built-in - with self._lock_pool.get_lock(model_id).read_lock(): - if self._is_built_in(model_id): - raise BuiltInModelDeletionError(model_id) + storage_dir = Path(storage_path) + for file in source_dir.iterdir(): + if file.is_file(): + shutil.copy2(file, storage_dir / file.name) + return - # delete the user-defined or 
fine-tuned model - with self._lock_pool.get_lock(model_id).write_lock(): - storage_path = os.path.join(self._model_dir, f"{model_id}") - if os.path.exists(storage_path): - shutil.rmtree(storage_path) - storage_path = os.path.join(self._builtin_model_dir, f"{model_id}") - if os.path.exists(storage_path): - shutil.rmtree(storage_path) - if model_id in self._model_info_map: - del self._model_info_map[model_id] - logger.info(f"Model {model_id} deleted successfully.") - - def _is_built_in(self, model_id: str) -> bool: + def _register_transformers_model(self, model_info: ModelInfo) -> bool: """ - Check if the model_id corresponds to a built-in model. + Register Transformers model to auto-loading mechanism (internal method) + """ + auto_map = model_info.auto_map + if not auto_map: + return False - Args: - model_id (str): The ID of the model. + auto_config_path = auto_map.get("AutoConfig") + auto_model_path = auto_map.get("AutoModelForCausalLM") - Returns: - bool: True if the model is built-in, False otherwise. - """ - return ( - model_id in self._model_info_map - and self._model_info_map[model_id].category == ModelCategory.BUILT_IN - ) + try: + model_path = os.path.join( + self._models_dir, model_info.category.value, model_info.model_id + ) + module_parent = str(Path(model_path).parent.absolute()) + with temporary_sys_path(module_parent): + config_class = import_class_from_path( + model_info.model_id, auto_config_path + ) + AutoConfig.register(model_info.model_type, config_class) + logger.info( + f"Registered AutoConfig: {model_info.model_type} -> {auto_config_path}" + ) - def is_built_in_or_fine_tuned(self, model_id: str) -> bool: - """ - Check if the model_id corresponds to a built-in or fine-tuned model. + model_class = import_class_from_path( + model_info.model_id, auto_model_path + ) + AutoModelForCausalLM.register(config_class, model_class) + logger.info( + f"Registered AutoModelForCausalLM: {config_class.__name__} -> {auto_model_path}" + ) - Args: - model_id (str): The ID of the model. + return True + except Exception as e: + logger.warning( + f"Failed to register Transformers model {model_info.model_id}: {e}. Model may still work via auto_map, but ensure module path is correct." + ) + return False - Returns: - bool: True if the model is built-in or fine_tuned, False otherwise. - """ - return model_id in self._model_info_map and ( - self._model_info_map[model_id].category == ModelCategory.BUILT_IN - or self._model_info_map[model_id].category == ModelCategory.FINE_TUNED + def _register_other_model(self, model_info: ModelInfo): + """Register other type models (non-Transformers models)""" + logger.info( + f"Registered other type model: {model_info.model_id} ({model_info.model_type})" ) - def load_model( - self, model_id: str, inference_attrs: Dict[str, str], acceleration: bool - ) -> Callable: + def ensure_transformers_registered(self, model_id: str) -> ModelInfo: """ - Load a model with automatic detection of .safetensors or .pt format - + Ensure Transformers model is registered (called for lazy registration) + This method uses locks to ensure thread safety. All check logic is within lock protection. 
Returns: - model: The model instance corresponding to specific model_id - """ - with self._lock_pool.get_lock(model_id).read_lock(): - if self.is_built_in_or_fine_tuned(model_id): - model_dir = os.path.join(self._builtin_model_dir, f"{model_id}") - return fetch_built_in_model( - get_built_in_model_type(self._model_info_map[model_id].model_type), - model_dir, - inference_attrs, - ) - else: - # load the user-defined model - model_dir = os.path.join(self._model_dir, f"{model_id}") - model_file_type = get_model_file_type(model_dir) - if model_file_type == ModelFileType.SAFETENSORS: - # TODO: Support this function - raise UnsupportedError("SAFETENSORS format") - else: - model_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_PT) - - if not os.path.exists(model_path): - raise ModelNotExistError(model_path) - model = torch.jit.load(model_path) - if ( - isinstance(model, torch._dynamo.eval_frame.OptimizedModule) - or not acceleration - ): - return model - - try: - model = torch.compile(model) - except Exception as e: - logger.warning( - f"acceleration failed, fallback to normal mode: {str(e)}" - ) - return model - - def save_model(self, model_id: str, model: nn.Module): - """ - Save the model using save_pretrained - - Returns: - Whether saving succeeded + str: If None, registration failed, otherwise returns model path """ + # Use lock to protect entire check-execute process with self._lock_pool.get_lock(model_id).write_lock(): - if self.is_built_in_or_fine_tuned(model_id): - model_dir = os.path.join(self._builtin_model_dir, f"{model_id}") - model.save_pretrained(model_dir) - else: - # save the user-defined model - model_dir = os.path.join(self._model_dir, f"{model_id}") - os.makedirs(model_dir, exist_ok=True) - model_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_PT) - try: - scripted_model = ( - model - if isinstance(model, torch.jit.ScriptModule) - else torch.jit.script(model) + # Directly access _models dictionary (avoid calling get_model_info which may cause deadlock) + model_info = None + for category_dict in self._models.values(): + if model_id in category_dict: + model_info = category_dict[model_id] + break + + if not model_info: + logger.warning(f"Model {model_id} does not exist, cannot register") + return None + + # If already registered, return directly + if model_info._transformers_registered: + return model_info + + # If no auto_map, not a Transformers model, mark as registered (avoid duplicate checks) + if ( + not model_info.auto_map + or model_id in BUILTIN_HF_TRANSFORMERS_MODEL_MAP.keys() + ): + model_info._transformers_registered = True + return model_info + + # Execute registration (under lock protection) + try: + success = self._register_transformers_model(model_info) + if success: + model_info._transformers_registered = True + logger.info( + f"Model {model_id} successfully registered to Transformers" ) - torch.jit.save(scripted_model, model_path) - except Exception as e: - logger.error(f"Failed to save scripted model: {e}") - - def get_ckpt_path(self, model_id: str) -> str: - """ - Get the checkpoint path for a given model ID. + return model_info + else: + model_info.state = ModelStates.INACTIVE + logger.error(f"Model {model_id} failed to register to Transformers") + return None - Args: - model_id (str): The ID of the model. 
+            except Exception as e:
+                # Ensure state consistency in exception cases
+                model_info.state = ModelStates.INACTIVE
+                model_info._transformers_registered = False
+                logger.error(
+                    f"Exception occurred while registering model {model_id} to Transformers: {e}"
+                )
+                return None
 
-        Returns:
-            str: The path to the checkpoint file for the model.
-        """
-        # Only support built-in models for now
-        return os.path.join(self._builtin_model_dir, f"{model_id}")
+    # ==================== Show and Delete Models ====================
 
     def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
         resp_status = TSStatus(
@@ -385,8 +427,14 @@ def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
             message="Show models successfully",
         )
         if req.modelId:
-            if req.modelId in self._model_info_map:
-                model_info = self._model_info_map[req.modelId]
+            # Find the specified model
+            model_info = None
+            for category_dict in self._models.values():
+                if req.modelId in category_dict:
+                    model_info = category_dict[req.modelId]
+                    break
+
+            if model_info:
                 return TShowModelsResp(
                     status=resp_status,
                     modelIdList=[req.modelId],
@@ -402,55 +450,133 @@ def show_models(self, req: TShowModelsReq) -> TShowModelsResp:
                 categoryMap={},
                 stateMap={},
             )
+        # Return all models
+        model_id_list = []
+        model_type_map = {}
+        category_map = {}
+        state_map = {}
+
+        for category_dict in self._models.values():
+            for model_id, model_info in category_dict.items():
+                model_id_list.append(model_id)
+                model_type_map[model_id] = model_info.model_type
+                category_map[model_id] = model_info.category.value
+                state_map[model_id] = model_info.state.value
+
         return TShowModelsResp(
            status=resp_status,
-            modelIdList=list(self._model_info_map.keys()),
-            modelTypeMap=dict(
-                (model_id, model_info.model_type)
-                for model_id, model_info in self._model_info_map.items()
-            ),
-            categoryMap=dict(
-                (model_id, model_info.category.value)
-                for model_id, model_info in self._model_info_map.items()
-            ),
-            stateMap=dict(
-                (model_id, model_info.state.value)
-                for model_id, model_info in self._model_info_map.items()
-            ),
+            modelIdList=model_id_list,
+            modelTypeMap=model_type_map,
+            categoryMap=category_map,
+            stateMap=state_map,
         )
 
-    def register_built_in_model(self, model_info: ModelInfo):
-        with self._lock_pool.get_lock(model_info.model_id).write_lock():
-            self._model_info_map[model_info.model_id] = model_info
+    def delete_model(self, model_id: str) -> None:
+        # Use a write lock to protect the entire deletion process
+        with self._lock_pool.get_lock(model_id).write_lock():
+            model_info = None
+            category_value = None
+            for cat_value, category_dict in self._models.items():
+                if model_id in category_dict:
+                    model_info = category_dict[model_id]
+                    category_value = cat_value
+                    break
+
+            if not model_info:
+                logger.warning(f"Model {model_id} does not exist, cannot delete")
+                return
+
+            if model_info.category == ModelCategory.BUILTIN:
+                raise BuiltInModelDeletionError(model_id)
+            model_info.state = ModelStates.DROPPING
+            model_path = os.path.join(
+                self._models_dir, model_info.category.value, model_id
+            )
+            if os.path.exists(model_path):
+                try:
+                    shutil.rmtree(model_path)
+                    logger.info(f"Deleted model directory: {model_path}")
+                except Exception as e:
+                    logger.error(f"Failed to delete model directory {model_path}: {e}")
+                    raise
 
-    def get_model_info(self, model_id: str) -> ModelInfo:
-        with self._lock_pool.get_lock(model_id).read_lock():
-            if model_id in self._model_info_map:
-                return self._model_info_map[model_id]
-            else:
-                raise ValueError(f"Model {model_id} does not exist.")
+            if category_value and 
model_id in self._models[category_value]: + del self._models[category_value][model_id] + logger.info(f"Model {model_id} has been removed from storage") - def update_model_state(self, model_id: str, state: ModelStates): - with self._lock_pool.get_lock(model_id).write_lock(): - if model_id in self._model_info_map: - self._model_info_map[model_id].state = state - else: - raise ValueError(f"Model {model_id} does not exist.") + return - def get_built_in_model_type(self, model_id: str) -> BuiltInModelType: + # ==================== Query Methods ==================== + + def get_model_info( + self, model_id: str, category: Optional[ModelCategory] = None + ) -> Optional[ModelInfo]: """ - Get the type of the model with the given model_id. + Get single model information - Args: - model_id (str): The ID of the model. + If category is specified, use model_id's lock + If category is not specified, need to traverse all dictionaries, use global lock + """ + if category: + # Category specified, only need to access specific dictionary, use model_id's lock + with self._lock_pool.get_lock(model_id).read_lock(): + return self._models[category.value].get(model_id) + else: + # Category not specified, need to traverse all dictionaries, use global lock + with self._lock_pool.get_lock("").read_lock(): + for category_dict in self._models.values(): + if model_id in category_dict: + return category_dict[model_id] + return None + + def get_model_infos( + self, category: Optional[ModelCategory] = None, model_type: Optional[str] = None + ) -> List[ModelInfo]: + """ + Get model information list - Returns: - str: The type of the model. + Note: Since we need to traverse all models, use a global lock to protect the entire dictionary structure + For single model access, using model_id-based lock would be more efficient """ - with self._lock_pool.get_lock(model_id).read_lock(): - if model_id in self._model_info_map: - return get_built_in_model_type( - self._model_info_map[model_id].model_type - ) + matching_models = [] + + # For traversal operations, we need to protect the entire dictionary structure + # Use a special lock (using empty string as key) to protect the entire dictionary + with self._lock_pool.get_lock("").read_lock(): + if category and model_type: + for model_info in self._models[category.value].values(): + if model_info.model_type == model_type: + matching_models.append(model_info) + return matching_models + elif category: + return list(self._models[category.value].values()) + elif model_type: + for category_dict in self._models.values(): + for model_info in category_dict.values(): + if model_info.model_type == model_type: + matching_models.append(model_info) + return matching_models else: - raise ValueError(f"Model {model_id} does not exist.") + for category_dict in self._models.values(): + matching_models.extend(category_dict.values()) + return matching_models + + def is_model_registered(self, model_id: str) -> bool: + """Check if model is registered (search in _models)""" + # Lazy registration: if it's a Transformers model and not registered, register it first + if self.ensure_transformers_registered(model_id) is None: + return False + + with self._lock_pool.get_lock("").read_lock(): + for category_dict in self._models.values(): + if model_id in category_dict: + return True + return False + + def get_registered_models(self) -> List[str]: + """Get list of all registered model IDs""" + with self._lock_pool.get_lock("").read_lock(): + model_ids = [] + for category_dict in self._models.values(): + 
model_ids.extend(category_dict.keys()) + return model_ids diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/__init__.py rename to iotdb-core/ainode/iotdb/ainode/core/model/sktime/__init__.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json new file mode 100644 index 000000000000..1561124badd1 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/arima/config.json @@ -0,0 +1,25 @@ +{ + "model_type": "sktime", + "model_id": "arima", + "predict_length": 1, + "order": [1, 0, 0], + "seasonal_order": [0, 0, 0, 0], + "start_params": null, + "method": "lbfgs", + "maxiter": 50, + "suppress_warnings": false, + "out_of_sample_size": 0, + "scoring": "mse", + "scoring_args": null, + "trend": null, + "with_intercept": true, + "time_varying_regression": false, + "enforce_stationarity": true, + "enforce_invertibility": true, + "simple_differencing": false, + "measurement_error": false, + "mle_regression": true, + "hamilton_representation": false, + "concentrate_scale": false +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py new file mode 100644 index 000000000000..261de3c9abe7 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py @@ -0,0 +1,409 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
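
Before the sktime configuration module below, an aside on the lazy registration flow that register_model and _register_transformers_model implement above: it is the standard Transformers auto-class mechanism. A self-contained sketch, where ToyConfig/ToyModel are hypothetical stand-ins for the classes a user model's auto_map would point at:

    import torch
    from transformers import (
        AutoConfig,
        AutoModelForCausalLM,
        PretrainedConfig,
        PreTrainedModel,
    )

    class ToyConfig(PretrainedConfig):
        model_type = "toy-ts"  # must match "model_type" in the model's config.json

    class ToyModel(PreTrainedModel):
        config_class = ToyConfig

        def __init__(self, config):
            super().__init__(config)
            self.proj = torch.nn.Linear(4, 4)

        def forward(self, x):
            return self.proj(x)

    # The two calls _register_transformers_model performs for each user model:
    AutoConfig.register("toy-ts", ToyConfig)
    AutoModelForCausalLM.register(ToyConfig, ToyModel)

    model = AutoModelForCausalLM.from_config(ToyConfig())  # resolved via the registry

Once both registrations have run, the generic AutoModelForCausalLM loading paths in model_loader.py can resolve the user-defined classes by model_type alone.
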
+# + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Union + +from iotdb.ainode.core.exception import ( + BuiltInModelNotSupportError, + ListRangeException, + NumericalRangeException, + StringRangeException, + WrongAttributeTypeError, +) +from iotdb.ainode.core.log import Logger + +logger = Logger() + + +@dataclass +class AttributeConfig: + """Base class for attribute configuration""" + + name: str + default: Any + type: str # 'int', 'float', 'str', 'bool', 'list', 'tuple' + low: Union[int, float, None] = None + high: Union[int, float, None] = None + choices: List[str] = field(default_factory=list) + value_type: type = None # Element type for list and tuple + + def validate_value(self, value): + """Validate if the value meets the requirements""" + if self.type == "int": + if value is None: + return True # Allow None for optional int parameters + if not isinstance(value, int): + raise WrongAttributeTypeError(self.name, "int") + if self.low is not None and self.high is not None: + if not (self.low <= value <= self.high): + raise NumericalRangeException(self.name, value, self.low, self.high) + elif self.type == "float": + if value is None: + return True # Allow None for optional float parameters + if not isinstance(value, (int, float)): + raise WrongAttributeTypeError(self.name, "float") + value = float(value) + if self.low is not None and self.high is not None: + if not (self.low <= value <= self.high): + raise NumericalRangeException(self.name, value, self.low, self.high) + elif self.type == "str": + if value is None: + return True # Allow None for optional str parameters + if not isinstance(value, str): + raise WrongAttributeTypeError(self.name, "str") + if self.choices and value not in self.choices: + raise StringRangeException(self.name, value, self.choices) + elif self.type == "bool": + if value is None: + return True # Allow None for optional bool parameters + if not isinstance(value, bool): + raise WrongAttributeTypeError(self.name, "bool") + elif self.type == "list": + if not isinstance(value, list): + raise WrongAttributeTypeError(self.name, "list") + for item in value: + if not isinstance(item, self.value_type): + raise WrongAttributeTypeError(self.name, self.value_type) + elif self.type == "tuple": + if not isinstance(value, tuple): + raise WrongAttributeTypeError(self.name, "tuple") + for item in value: + if not isinstance(item, self.value_type): + raise WrongAttributeTypeError(self.name, self.value_type) + return True + + def parse(self, string_value: str): + """Parse string value to corresponding type""" + if self.type == "int": + if string_value.lower() == "none" or string_value.strip() == "": + return None + try: + return int(string_value) + except: + raise WrongAttributeTypeError(self.name, "int") + elif self.type == "float": + if string_value.lower() == "none" or string_value.strip() == "": + return None + try: + return float(string_value) + except: + raise WrongAttributeTypeError(self.name, "float") + elif self.type == "str": + if string_value.lower() == "none" or string_value.strip() == "": + return None + return string_value + elif self.type == "bool": + if string_value.lower() == "true": + return True + elif string_value.lower() == "false": + return False + elif string_value.lower() == "none" or string_value.strip() == "": + return None + else: + raise WrongAttributeTypeError(self.name, "bool") + elif self.type == "list": + try: + list_value = eval(string_value) + except: + raise WrongAttributeTypeError(self.name, "list") + if not 
isinstance(list_value, list): + raise WrongAttributeTypeError(self.name, "list") + for i in range(len(list_value)): + try: + list_value[i] = self.value_type(list_value[i]) + except: + raise ListRangeException( + self.name, list_value, str(self.value_type) + ) + return list_value + elif self.type == "tuple": + try: + tuple_value = eval(string_value) + except: + raise WrongAttributeTypeError(self.name, "tuple") + if not isinstance(tuple_value, tuple): + raise WrongAttributeTypeError(self.name, "tuple") + list_value = list(tuple_value) + for i in range(len(list_value)): + try: + list_value[i] = self.value_type(list_value[i]) + except: + raise ListRangeException( + self.name, list_value, str(self.value_type) + ) + return tuple(list_value) + + +# Model configuration definitions - using concise dictionary format +MODEL_CONFIGS = { + "NAIVE_FORECASTER": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "strategy": AttributeConfig( + "strategy", "last", "str", choices=["last", "mean", "drift"] + ), + "window_length": AttributeConfig("window_length", None, "int"), + "sp": AttributeConfig("sp", 1, "int", 1, 5000), + }, + "EXPONENTIAL_SMOOTHING": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "damped_trend": AttributeConfig("damped_trend", False, "bool"), + "initialization_method": AttributeConfig( + "initialization_method", + "estimated", + "str", + choices=["estimated", "heuristic", "legacy-heuristic", "known"], + ), + "optimized": AttributeConfig("optimized", True, "bool"), + "remove_bias": AttributeConfig("remove_bias", False, "bool"), + "use_brute": AttributeConfig("use_brute", False, "bool"), + }, + "ARIMA": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "order": AttributeConfig("order", (1, 0, 0), "tuple", value_type=int), + "seasonal_order": AttributeConfig( + "seasonal_order", (0, 0, 0, 0), "tuple", value_type=int + ), + "start_params": AttributeConfig("start_params", None, "str"), + "method": AttributeConfig( + "method", + "lbfgs", + "str", + choices=["lbfgs", "bfgs", "newton", "nm", "cg", "ncg", "powell"], + ), + "maxiter": AttributeConfig("maxiter", 50, "int", 1, 5000), + "suppress_warnings": AttributeConfig("suppress_warnings", False, "bool"), + "out_of_sample_size": AttributeConfig("out_of_sample_size", 0, "int", 0, 5000), + "scoring": AttributeConfig( + "scoring", + "mse", + "str", + choices=["mse", "mae", "rmse", "mape", "smape", "rmsle", "r2"], + ), + "scoring_args": AttributeConfig("scoring_args", None, "str"), + "trend": AttributeConfig("trend", None, "str"), + "with_intercept": AttributeConfig("with_intercept", True, "bool"), + "time_varying_regression": AttributeConfig( + "time_varying_regression", False, "bool" + ), + "enforce_stationarity": AttributeConfig("enforce_stationarity", True, "bool"), + "enforce_invertibility": AttributeConfig("enforce_invertibility", True, "bool"), + "simple_differencing": AttributeConfig("simple_differencing", False, "bool"), + "measurement_error": AttributeConfig("measurement_error", False, "bool"), + "mle_regression": AttributeConfig("mle_regression", True, "bool"), + "hamilton_representation": AttributeConfig( + "hamilton_representation", False, "bool" + ), + "concentrate_scale": AttributeConfig("concentrate_scale", False, "bool"), + }, + "STL_FORECASTER": { + "predict_length": AttributeConfig("predict_length", 1, "int", 1, 5000), + "sp": AttributeConfig("sp", 2, "int", 1, 5000), + "seasonal": AttributeConfig("seasonal", 7, "int", 1, 5000), + "trend": 
AttributeConfig("trend", None, "int"), + "low_pass": AttributeConfig("low_pass", None, "int"), + "seasonal_deg": AttributeConfig("seasonal_deg", 1, "int", 0, 5000), + "trend_deg": AttributeConfig("trend_deg", 1, "int", 0, 5000), + "low_pass_deg": AttributeConfig("low_pass_deg", 1, "int", 0, 5000), + "robust": AttributeConfig("robust", False, "bool"), + "seasonal_jump": AttributeConfig("seasonal_jump", 1, "int", 0, 5000), + "trend_jump": AttributeConfig("trend_jump", 1, "int", 0, 5000), + "low_pass_jump": AttributeConfig("low_pass_jump", 1, "int", 0, 5000), + "inner_iter": AttributeConfig("inner_iter", None, "int"), + "outer_iter": AttributeConfig("outer_iter", None, "int"), + "forecaster_trend": AttributeConfig("forecaster_trend", None, "str"), + "forecaster_seasonal": AttributeConfig("forecaster_seasonal", None, "str"), + "forecaster_resid": AttributeConfig("forecaster_resid", None, "str"), + }, + "GAUSSIAN_HMM": { + "n_components": AttributeConfig("n_components", 1, "int", 1, 5000), + "covariance_type": AttributeConfig( + "covariance_type", + "diag", + "str", + choices=["spherical", "diag", "full", "tied"], + ), + "min_covar": AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10), + "startprob_prior": AttributeConfig( + "startprob_prior", 1.0, "float", -1e10, 1e10 + ), + "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10), + "means_prior": AttributeConfig("means_prior", 0, "float", -1e10, 1e10), + "means_weight": AttributeConfig("means_weight", 0, "float", -1e10, 1e10), + "covars_prior": AttributeConfig("covars_prior", 0.01, "float", -1e10, 1e10), + "covars_weight": AttributeConfig("covars_weight", 1, "float", -1e10, 1e10), + "algorithm": AttributeConfig( + "algorithm", "viterbi", "str", choices=["viterbi", "map"] + ), + "random_state": AttributeConfig("random_state", None, "float"), + "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000), + "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10), + "verbose": AttributeConfig("verbose", False, "bool"), + "params": AttributeConfig("params", "stmc", "str", choices=["stmc", "stm"]), + "init_params": AttributeConfig( + "init_params", "stmc", "str", choices=["stmc", "stm"] + ), + "implementation": AttributeConfig( + "implementation", "log", "str", choices=["log", "scaling"] + ), + }, + "GMM_HMM": { + "n_components": AttributeConfig("n_components", 1, "int", 1, 5000), + "n_mix": AttributeConfig("n_mix", 1, "int", 1, 5000), + "min_covar": AttributeConfig("min_covar", 1e-3, "float", -1e10, 1e10), + "startprob_prior": AttributeConfig( + "startprob_prior", 1.0, "float", -1e10, 1e10 + ), + "transmat_prior": AttributeConfig("transmat_prior", 1.0, "float", -1e10, 1e10), + "weights_prior": AttributeConfig("weights_prior", 1.0, "float", -1e10, 1e10), + "means_prior": AttributeConfig("means_prior", 0.0, "float", -1e10, 1e10), + "means_weight": AttributeConfig("means_weight", 0.0, "float", -1e10, 1e10), + "covars_prior": AttributeConfig("covars_prior", None, "float"), + "covars_weight": AttributeConfig("covars_weight", None, "float"), + "algorithm": AttributeConfig( + "algorithm", "viterbi", "str", choices=["viterbi", "map"] + ), + "covariance_type": AttributeConfig( + "covariance_type", + "diag", + "str", + choices=["spherical", "diag", "full", "tied"], + ), + "random_state": AttributeConfig("random_state", None, "int"), + "n_iter": AttributeConfig("n_iter", 10, "int", 1, 5000), + "tol": AttributeConfig("tol", 1e-2, "float", -1e10, 1e10), + "verbose": AttributeConfig("verbose", False, "bool"), + "init_params": 
AttributeConfig( + "init_params", + "stmcw", + "str", + choices=[ + "s", + "t", + "m", + "c", + "w", + "st", + "sm", + "sc", + "sw", + "tm", + "tc", + "tw", + "mc", + "mw", + "cw", + "stm", + "stc", + "stw", + "smc", + "smw", + "scw", + "tmc", + "tmw", + "tcw", + "mcw", + "stmc", + "stmw", + "stcw", + "smcw", + "tmcw", + "stmcw", + ], + ), + "params": AttributeConfig( + "params", + "stmcw", + "str", + choices=[ + "s", + "t", + "m", + "c", + "w", + "st", + "sm", + "sc", + "sw", + "tm", + "tc", + "tw", + "mc", + "mw", + "cw", + "stm", + "stc", + "stw", + "smc", + "smw", + "scw", + "tmc", + "tmw", + "tcw", + "mcw", + "stmc", + "stmw", + "stcw", + "smcw", + "tmcw", + "stmcw", + ], + ), + "implementation": AttributeConfig( + "implementation", "log", "str", choices=["log", "scaling"] + ), + }, + "STRAY": { + "alpha": AttributeConfig("alpha", 0.01, "float", -1e10, 1e10), + "k": AttributeConfig("k", 10, "int", 1, 5000), + "knn_algorithm": AttributeConfig( + "knn_algorithm", + "brute", + "str", + choices=["brute", "kd_tree", "ball_tree", "auto"], + ), + "p": AttributeConfig("p", 0.5, "float", -1e10, 1e10), + "size_threshold": AttributeConfig("size_threshold", 50, "int", 1, 5000), + "outlier_tail": AttributeConfig( + "outlier_tail", "max", "str", choices=["min", "max"] + ), + }, +} + + +def get_attributes(model_id: str) -> Dict[str, AttributeConfig]: + """Get attribute configuration for Sktime model""" + model_id = "EXPONENTIAL_SMOOTHING" if model_id == "HOLTWINTERS" else model_id + if model_id not in MODEL_CONFIGS: + raise BuiltInModelNotSupportError(model_id) + return MODEL_CONFIGS[model_id] + + +def update_attribute( + input_attributes: Dict[str, str], attribute_map: Dict[str, AttributeConfig] +) -> Dict[str, Any]: + """Update Sktime model attributes using input attributes""" + attributes = {} + for name, config in attribute_map.items(): + if name in input_attributes: + value = config.parse(input_attributes[name]) + config.validate_value(value) + attributes[name] = value + else: + attributes[name] = config.default + return attributes diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json new file mode 100644 index 000000000000..4126d9de857a --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/exponential_smoothing/config.json @@ -0,0 +1,11 @@ +{ + "model_type": "sktime", + "model_id": "exponential_smoothing", + "predict_length": 1, + "damped_trend": false, + "initialization_method": "estimated", + "optimized": true, + "remove_bias": false, + "use_brute": false +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json new file mode 100644 index 000000000000..94f7d7ec659f --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gaussian_hmm/config.json @@ -0,0 +1,22 @@ +{ + "model_type": "sktime", + "model_id": "gaussian_hmm", + "n_components": 1, + "covariance_type": "diag", + "min_covar": 0.001, + "startprob_prior": 1.0, + "transmat_prior": 1.0, + "means_prior": 0, + "means_weight": 0, + "covars_prior": 0.01, + "covars_weight": 1, + "algorithm": "viterbi", + "random_state": null, + "n_iter": 10, + "tol": 0.01, + "verbose": false, + "params": "stmc", + "init_params": "stmc", + "implementation": "log" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json 
b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json new file mode 100644 index 000000000000..fb19d1aaf86d --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/gmm_hmm/config.json @@ -0,0 +1,24 @@ +{ + "model_type": "sktime", + "model_id": "gmm_hmm", + "n_components": 1, + "n_mix": 1, + "min_covar": 0.001, + "startprob_prior": 1.0, + "transmat_prior": 1.0, + "weights_prior": 1.0, + "means_prior": 0.0, + "means_weight": 0.0, + "covars_prior": null, + "covars_weight": null, + "algorithm": "viterbi", + "covariance_type": "diag", + "random_state": null, + "n_iter": 10, + "tol": 0.01, + "verbose": false, + "init_params": "stmcw", + "params": "stmcw", + "implementation": "log" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py new file mode 100644 index 000000000000..eca812d35ec9 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py @@ -0,0 +1,180 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
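
Before the modeling code below, a usage sketch for get_attributes/update_attribute from configuration_sktime.py above. The attribute values shown are illustrative, and inputs are passed as strings because each AttributeConfig.parse expects string-typed options:

    from iotdb.ainode.core.model.sktime.configuration_sktime import (
        get_attributes,
        update_attribute,
    )

    # String-typed options are parsed and range-checked per AttributeConfig;
    # anything unspecified falls back to its declared default.
    attrs = update_attribute(
        {"predict_length": "24", "order": "(2, 1, 1)", "maxiter": "100"},
        get_attributes("ARIMA"),
    )
    assert attrs["predict_length"] == 24    # parsed int, validated against [1, 5000]
    assert attrs["order"] == (2, 1, 1)      # parsed tuple of ints
    assert attrs["with_intercept"] is True  # default applied
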
+# + +from abc import abstractmethod +from typing import Any, Dict + +import numpy as np +import pandas as pd +from sklearn.preprocessing import MinMaxScaler +from sktime.detection.hmm_learn import GMMHMM, GaussianHMM +from sktime.detection.stray import STRAY +from sktime.forecasting.arima import ARIMA +from sktime.forecasting.exp_smoothing import ExponentialSmoothing +from sktime.forecasting.naive import NaiveForecaster +from sktime.forecasting.trend import STLForecaster + +from iotdb.ainode.core.exception import ( + BuiltInModelNotSupportError, + InferenceModelInternalError, +) +from iotdb.ainode.core.log import Logger + +from .configuration_sktime import get_attributes, update_attribute + +logger = Logger() + + +class SktimeModel: + """Base class for Sktime models""" + + def __init__(self, attributes: Dict[str, Any]): + self._attributes = attributes + self._model = None + + @abstractmethod + def generate(self, data, **kwargs): + """Execute generation/inference""" + raise NotImplementedError + + +class ForecastingModel(SktimeModel): + """Base class for forecasting models""" + + def generate(self, data, **kwargs): + """Execute forecasting""" + try: + predict_length = kwargs.get( + "predict_length", self._attributes["predict_length"] + ) + self._model.fit(data) + output = self._model.predict(fh=range(predict_length)) + return np.array(output, dtype=np.float64) + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class DetectionModel(SktimeModel): + """Base class for detection models""" + + def generate(self, data, **kwargs): + """Execute detection""" + try: + predict_length = kwargs.get("predict_length", data.size) + output = self._model.fit_transform(data[:predict_length]) + if isinstance(output, pd.DataFrame): + return np.array(output["labels"], dtype=np.int32) + else: + return np.array(output, dtype=np.int32) + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +class ArimaModel(ForecastingModel): + """ARIMA model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = ARIMA( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class ExponentialSmoothingModel(ForecastingModel): + """Exponential smoothing model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = ExponentialSmoothing( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class NaiveForecasterModel(ForecastingModel): + """Naive forecaster model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = NaiveForecaster( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class STLForecasterModel(ForecastingModel): + """STL forecaster model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = STLForecaster( + **{k: v for k, v in attributes.items() if k != "predict_length"} + ) + + +class GMMHMMModel(DetectionModel): + """GMM HMM model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = GMMHMM(**attributes) + + +class GaussianHmmModel(DetectionModel): + """Gaussian HMM model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = GaussianHMM(**attributes) + + +class STRAYModel(DetectionModel): + """STRAY anomaly detection model""" + + def __init__(self, attributes: Dict[str, Any]): + super().__init__(attributes) + self._model = STRAY(**{k: 
v for k, v in attributes.items() if v is not None}) + + def generate(self, data, **kwargs): + """STRAY requires special handling: normalize first""" + try: + scaled_data = MinMaxScaler().fit_transform(data.values.reshape(-1, 1)) + scaled_data = pd.Series(scaled_data.flatten()) + return super().generate(scaled_data, **kwargs) + except Exception as e: + raise InferenceModelInternalError(str(e)) + + +# Model factory mapping +_MODEL_FACTORY = { + "ARIMA": ArimaModel, + "EXPONENTIAL_SMOOTHING": ExponentialSmoothingModel, + "HOLTWINTERS": ExponentialSmoothingModel, # Use the same model class + "NAIVE_FORECASTER": NaiveForecasterModel, + "STL_FORECASTER": STLForecasterModel, + "GMM_HMM": GMMHMMModel, + "GAUSSIAN_HMM": GaussianHmmModel, + "STRAY": STRAYModel, +} + + +def create_sktime_model(model_id: str, **kwargs) -> SktimeModel: + """Create a Sktime model instance""" + attributes = update_attribute({**kwargs}, get_attributes(model_id.upper())) + model_class = _MODEL_FACTORY.get(model_id.upper()) + if model_class is None: + raise BuiltInModelNotSupportError(model_id) + return model_class(attributes) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json new file mode 100644 index 000000000000..3dadd7c3b1e5 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/naive_forecaster/config.json @@ -0,0 +1,9 @@ +{ + "model_type": "sktime", + "model_id": "naive_forecaster", + "predict_length": 1, + "strategy": "last", + "window_length": null, + "sp": 1 +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py new file mode 100644 index 000000000000..ced21f29a2b8 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/pipeline_sktime.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
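
A usage sketch for the create_sktime_model factory just defined, ahead of the pipeline wrapper below. The series and option values are illustrative; as above, attribute values are passed as strings to match the parsing contract in configuration_sktime.py:

    import numpy as np
    import pandas as pd

    from iotdb.ainode.core.model.sktime.modeling_sktime import create_sktime_model

    series = pd.Series(np.sin(np.linspace(0.0, 12.0, 120)))

    # Forecasting path: fit on the series, then emit `predict_length` points.
    forecaster = create_sktime_model(
        "naive_forecaster", predict_length="8", strategy="last"
    )
    print(forecaster.generate(series))  # np.ndarray of 8 float64 values

    # Detection path: integer labels per point, via fit_transform.
    detector = create_sktime_model("stray", k="5")
    print(detector.generate(series))
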
+# + +import numpy as np +import pandas as pd +import torch + +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline + + +class SktimePipeline(ForecastPipeline): + def __init__(self, model_info, **model_kwargs): + model_kwargs.pop("device", None) # sktime models run on CPU + super().__init__(model_info, model_kwargs=model_kwargs) + + def _preprocess(self, inputs): + return inputs + + def forecast(self, inputs, **infer_kwargs): + predict_length = infer_kwargs.get("predict_length", 96) + input_ids = self._preprocess(inputs) + + # Convert to pandas Series for sktime (sktime expects Series or DataFrame) + # Handle batch dimension: if batch_size > 1, process each sample separately + if len(input_ids.shape) == 2 and input_ids.shape[0] > 1: + # Batch processing: convert each row to Series + outputs = [] + for i in range(input_ids.shape[0]): + series = pd.Series( + input_ids[i].cpu().numpy() + if isinstance(input_ids, torch.Tensor) + else input_ids[i] + ) + output = self.model.generate(series, predict_length=predict_length) + outputs.append(output) + output = np.array(outputs) + else: + # Single sample: convert to Series + if isinstance(input_ids, torch.Tensor): + series = pd.Series(input_ids.squeeze().cpu().numpy()) + else: + series = pd.Series(input_ids.squeeze()) + output = self.model.generate(series, predict_length=predict_length) + # Add batch dimension if needed + if len(output.shape) == 1: + output = output[np.newaxis, :] + + return self._postprocess(output) + + def _postprocess(self, output): + if isinstance(output, np.ndarray): + return torch.from_numpy(output).float() + return output diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json new file mode 100644 index 000000000000..bfe71dbc4861 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stl_forecaster/config.json @@ -0,0 +1,22 @@ +{ + "model_type": "sktime", + "model_id": "stl_forecaster", + "predict_length": 1, + "sp": 2, + "seasonal": 7, + "trend": null, + "low_pass": null, + "seasonal_deg": 1, + "trend_deg": 1, + "low_pass_deg": 1, + "robust": false, + "seasonal_jump": 1, + "trend_jump": 1, + "low_pass_jump": 1, + "inner_iter": null, + "outer_iter": null, + "forecaster_trend": null, + "forecaster_seasonal": null, + "forecaster_resid": null +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json new file mode 100644 index 000000000000..e5bcc03cd071 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/stray/config.json @@ -0,0 +1,11 @@ +{ + "model_type": "sktime", + "model_id": "stray", + "alpha": 0.01, + "k": 10, + "knn_algorithm": "brute", + "p": 0.5, + "size_threshold": 50, + "outlier_tail": "max" +} + diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py index 3ebf516f705e..dc1de32506e5 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/modeling_sundial.py @@ -16,13 +16,10 @@ # under the License. 
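
Before the Sundial changes below, a standalone restatement (for review) of the batch handling that SktimePipeline.forecast above implements; `model` here is any object honoring the generate(series, predict_length=...) contract from modeling_sktime.py:

    import numpy as np
    import pandas as pd
    import torch

    def forecast_batch(inputs: torch.Tensor, model, predict_length: int = 96) -> torch.Tensor:
        # One pandas Series per row: sktime estimators consume Series, not tensors.
        if inputs.dim() == 2 and inputs.shape[0] > 1:
            outputs = np.array(
                [
                    model.generate(
                        pd.Series(row.cpu().numpy()), predict_length=predict_length
                    )
                    for row in inputs
                ]
            )
        else:
            out = model.generate(
                pd.Series(inputs.squeeze().cpu().numpy()), predict_length=predict_length
            )
            outputs = out[np.newaxis, :] if out.ndim == 1 else out
        return torch.from_numpy(outputs).float()  # matches _postprocess's dtype
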
# -import os -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F -from huggingface_hub import hf_hub_download -from safetensors.torch import load_file as load_safetensors from torch import nn from transformers import Cache, DynamicCache, PreTrainedModel from transformers.activations import ACT2FN @@ -32,13 +29,10 @@ MoeModelOutputWithPast, ) -from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig from iotdb.ainode.core.model.sundial.flow_loss import FlowLoss from iotdb.ainode.core.model.sundial.ts_generation_mixin import TSGenerationMixin -logger = Logger() - def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] @@ -616,11 +610,7 @@ def prepare_inputs_for_generation( if attention_mask is not None and attention_mask.shape[1] > ( input_ids.shape[1] // self.config.input_token_len ): - input_ids = input_ids[ - :, - -(attention_mask.shape[1] - past_length) - * self.config.input_token_len :, - ] + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < (input_ids.shape[1] // self.config.input_token_len): @@ -633,10 +623,9 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - token_num = ( - input_ids.shape[1] + self.config.input_token_len - 1 - ) // self.config.input_token_len - position_ids = position_ids[:, -token_num:] + position_ids = position_ids[ + :, -(input_ids.shape[1] // self.config.input_token_len) : + ] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py similarity index 56% rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py rename to iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py index 17c88e32fb5a..85b6f7db2ffe 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timer_sundial_inference_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py @@ -19,33 +19,33 @@ import torch from iotdb.ainode.core.exception import InferenceModelInternalError -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) -from iotdb.ainode.core.model.sundial.configuration_sundial import SundialConfig +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline -class TimerSundialInferencePipeline(AbstractInferencePipeline): - """ - Strategy for Timer-Sundial model inference. 
- """ +class SundialPipeline(ForecastPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) - def __init__(self, model_config: SundialConfig, **infer_kwargs): - super().__init__(model_config, infer_kwargs=infer_kwargs) - - def preprocess_inputs(self, inputs: torch.Tensor): - super().preprocess_inputs(inputs) + def _preprocess(self, inputs): if len(inputs.shape) != 2: raise InferenceModelInternalError( f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" ) - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py return inputs - def post_decode(self): - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py - pass - - def post_inference(self): - # TODO: Disassemble and adapt with Sundial's ts_generation_mixin.py - pass + def forecast(self, inputs, **infer_kwargs): + predict_length = infer_kwargs.get("predict_length", 96) + num_samples = infer_kwargs.get("num_samples", 10) + revin = infer_kwargs.get("revin", True) + + input_ids = self._preprocess(inputs) + output = self.model.generate( + input_ids, + max_new_tokens=predict_length, + num_samples=num_samples, + revin=revin, + ) + return self._postprocess(output) + + def _postprocess(self, output: torch.Tensor): + return output.mean(dim=1) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py new file mode 100644 index 000000000000..2a1e720805f2 --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/configuration_timer.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/configuration_timer.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/configuration_timer.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py similarity index 98% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py index 0a33c682742a..fc9d7b41388b 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/modeling_timer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/modeling_timer.py @@ -16,7 +16,7 @@ # under the License. 
# -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn.functional as F @@ -29,11 +29,8 @@ MoeModelOutputWithPast, ) -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig -from iotdb.ainode.core.model.timerxl.ts_generation_mixin import TSGenerationMixin - -logger = Logger() +from iotdb.ainode.core.model.timer_xl.configuration_timer import TimerConfig +from iotdb.ainode.core.model.timer_xl.ts_generation_mixin import TSGenerationMixin def rotate_half(x): @@ -606,11 +603,7 @@ def prepare_inputs_for_generation( if attention_mask is not None and attention_mask.shape[1] > ( input_ids.shape[1] // self.config.input_token_len ): - input_ids = input_ids[ - :, - -(attention_mask.shape[1] - past_length) - * self.config.input_token_len :, - ] + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < (input_ids.shape[1] // self.config.input_token_len): diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py similarity index 52% rename from iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py index dc1dd304f68e..c0f00b1f5caf 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/strategy/timerxl_inference_pipeline.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py @@ -19,33 +19,29 @@ import torch from iotdb.ainode.core.exception import InferenceModelInternalError -from iotdb.ainode.core.inference.strategy.abstract_inference_pipeline import ( - AbstractInferencePipeline, -) -from iotdb.ainode.core.model.timerxl.configuration_timer import TimerConfig +from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline -class TimerXLInferencePipeline(AbstractInferencePipeline): - """ - Strategy for Timer-XL model inference. 
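The Sundial and Timer-XL modeling files above receive the same edit to prepare_inputs_for_generation: the retained tail of input_ids shrinks from (attention_mask.shape[1] - past_length) * input_token_len trailing values to (attention_mask.shape[1] - past_length) trailing positions. A toy comparison of the two slices, with hypothetical sizes:

import torch

input_token_len = 16  # hypothetical values covered by one token
past_length = 4       # tokens already cached in past_key_values
mask_len = 6          # attention_mask.shape[1]

input_ids = torch.randn(1, mask_len * input_token_len)  # [1, 96]

old_window = input_ids[:, -(mask_len - past_length) * input_token_len :]  # before the patch
new_window = input_ids[:, -(mask_len - past_length) :]                    # after the patch
print(old_window.shape, new_window.shape)  # torch.Size([1, 32]) torch.Size([1, 2])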
- """ +class TimerPipeline(ForecastPipeline): + def __init__(self, model_info, **model_kwargs): + super().__init__(model_info, model_kwargs=model_kwargs) - def __init__(self, model_config: TimerConfig, **infer_kwargs): - super().__init__(model_config, infer_kwargs=infer_kwargs) - - def preprocess_inputs(self, inputs: torch.Tensor): - super().preprocess_inputs(inputs) + def _preprocess(self, inputs): if len(inputs.shape) != 2: raise InferenceModelInternalError( f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" ) - # Considering that we are currently using the generate function interface, it seems that no pre-processing is required return inputs - def post_decode(self): - # Considering that we are currently using the generate function interface, it seems that no post-processing is required - pass + def forecast(self, inputs, **infer_kwargs): + predict_length = infer_kwargs.get("predict_length", 96) + revin = infer_kwargs.get("revin", True) + + input_ids = self._preprocess(inputs) + output = self.model.generate( + input_ids, max_new_tokens=predict_length, revin=revin + ) + return self._postprocess(output) - def post_inference(self): - # Considering that we are currently using the generate function interface, it seems that no post-processing is required - pass + def _postprocess(self, output: torch.Tensor): + return output diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/ts_generation_mixin.py similarity index 100% rename from iotdb-core/ainode/iotdb/ainode/core/model/timerxl/ts_generation_mixin.py rename to iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/ts_generation_mixin.py diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py deleted file mode 100644 index b2e759e00ce0..000000000000 --- a/iotdb-core/ainode/iotdb/ainode/core/model/uri_utils.py +++ /dev/null @@ -1,137 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# -import os -from enum import Enum -from typing import List - -from huggingface_hub import snapshot_download -from requests import Session -from requests.adapters import HTTPAdapter - -from iotdb.ainode.core.constant import ( - DEFAULT_CHUNK_SIZE, - DEFAULT_RECONNECT_TIMEOUT, - DEFAULT_RECONNECT_TIMES, -) -from iotdb.ainode.core.exception import UnsupportedError -from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_enums import ModelFileType -from iotdb.ainode.core.model.model_info import get_model_file_type - -HTTP_PREFIX = "http://" -HTTPS_PREFIX = "https://" - -logger = Logger() - - -class UriType(Enum): - REPO = "repo" - FILE = "file" - HTTP = "http" - HTTPS = "https" - - @classmethod - def values(cls) -> List[str]: - return [item.value for item in cls] - - @staticmethod - def parse_uri_type(uri: str): - """ - Parse the URI type from the given string. - """ - if uri.startswith("repo://"): - return UriType.REPO - elif uri.startswith("file://"): - return UriType.FILE - elif uri.startswith("http://"): - return UriType.HTTP - elif uri.startswith("https://"): - return UriType.HTTPS - else: - raise ValueError(f"Invalid URI type for {uri}") - - -def get_model_register_strategy(uri: str): - """ - Determine the loading strategy for a model based on its URI/path. - - Args: - uri (str): The URI of the model to be registered. - - Returns: - uri_type (UriType): The type of the URI, which can be one of: REPO, FILE, HTTP, or HTTPS. - parsed_uri (str): Parsed uri to get related file - model_file_type (ModelFileType): The type of the model file, which can be one of: SAFETENSORS, PYTORCH, or UNKNOWN. - """ - - uri_type = UriType.parse_uri_type(uri) - if uri_type in (UriType.HTTP, UriType.HTTPS): - # TODO: support HTTP(S) URI - raise UnsupportedError("CREATE MODEL FROM HTTP(S) URI") - else: - parsed_uri = uri[7:] - if uri_type == UriType.FILE: - # handle ~ in URI - parsed_uri = os.path.expanduser(parsed_uri) - model_file_type = get_model_file_type(uri) - elif uri_type == UriType.REPO: - # Currently, UriType.REPO only corresponds to huggingface repository with SAFETENSORS format - model_file_type = ModelFileType.SAFETENSORS - else: - raise ValueError(f"Invalid URI type for {uri}") - return uri_type, parsed_uri, model_file_type - - -def download_snapshot_from_hf(repo_id: str, local_dir: str): - """ - Download everything from a HuggingFace repository. - - Args: - repo_id (str): The HuggingFace repository ID. - local_dir (str): The local directory to save the downloaded files. 
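get_model_register_strategy above dispatched on the URI scheme; its replacement in model/utils.py below keeps only the repo:// and file:// branches. A minimal sketch of that scheme dispatch (the UriType enum here is a local stand-in, and the example URIs are made up):

import os
from enum import Enum


class UriType(Enum):
    REPO = "repo"
    FILE = "file"


def parse_uri(uri: str):
    # Both "repo://" and "file://" prefixes are exactly 7 characters long.
    if uri.startswith("repo://"):
        return UriType.REPO, uri[7:]
    if uri.startswith("file://"):
        # Expand ~ so file URIs can point into a user's home directory.
        return UriType.FILE, os.path.expanduser(uri[7:])
    raise ValueError(f"Unsupported URI type: {uri}. Supported formats: repo:// or file://")


print(parse_uri("repo://some-org/some-model"))
print(parse_uri("file://~/models/example"))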
- """ - try: - snapshot_download( - repo_id=repo_id, - local_dir=local_dir, - ) - except Exception as e: - logger.error(f"Failed to download HuggingFace model {repo_id}: {e}") - raise e - - -def download_file(url: str, storage_path: str) -> None: - """ - Args: - url: url of file to download - storage_path: path to save the file - Returns: - None - """ - logger.info(f"Start Downloading file from {url} to {storage_path}") - session = Session() - adapter = HTTPAdapter(max_retries=DEFAULT_RECONNECT_TIMES) - session.mount(HTTP_PREFIX, adapter) - session.mount(HTTPS_PREFIX, adapter) - response = session.get(url, timeout=DEFAULT_RECONNECT_TIMEOUT, stream=True) - response.raise_for_status() - with open(storage_path, "wb") as file: - for chunk in response.iter_content(chunk_size=DEFAULT_CHUNK_SIZE): - if chunk: - file.write(chunk) - logger.info(f"Download file from {url} to {storage_path} success") diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py new file mode 100644 index 000000000000..1cd0ee44912d --- /dev/null +++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py @@ -0,0 +1,98 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import importlib +import json +import os.path +import sys +from contextlib import contextmanager +from typing import Dict, Tuple + +from iotdb.ainode.core.model.model_constants import ( + MODEL_CONFIG_FILE_IN_JSON, + MODEL_WEIGHTS_FILE_IN_SAFETENSORS, + UriType, +) + + +def parse_uri_type(uri: str) -> UriType: + if uri.startswith("repo://"): + return UriType.REPO + elif uri.startswith("file://"): + return UriType.FILE + else: + raise ValueError( + f"Unsupported URI type: {uri}. 
Supported formats: repo:// or file://" + ) + + +def get_parsed_uri(uri: str) -> str: + return uri[7:] # Remove "repo://" or "file://" prefix + + +@contextmanager +def temporary_sys_path(path: str): + """Context manager for temporarily adding a path to sys.path""" + path_added = path not in sys.path + if path_added: + sys.path.insert(0, path) + try: + yield + finally: + if path_added and path in sys.path: + sys.path.remove(path) + + +def load_model_config_in_json(config_path: str) -> Dict: + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f) + + +def validate_model_files(model_dir: str) -> Tuple[str, str]: + """Validate model files exist, return config and weights file paths""" + + config_path = os.path.join(model_dir, MODEL_CONFIG_FILE_IN_JSON) + weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) + + if not os.path.exists(config_path): + raise ValueError(f"Model config file does not exist: {config_path}") + if not os.path.exists(weights_path): + raise ValueError(f"Model weights file does not exist: {weights_path}") + + # Create __init__.py file to ensure model directory can be imported as a module + init_file = os.path.join(model_dir, "__init__.py") + if not os.path.exists(init_file): + with open(init_file, "w"): + pass + + return config_path, weights_path + + +def import_class_from_path(module_name, class_path: str): + file_name, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_name + "." + file_name) + return getattr(module, class_name) + + +def ensure_init_file(dir_path: str): + """Ensure __init__.py file exists in the given dir path""" + init_file = os.path.join(dir_path, "__init__.py") + os.makedirs(dir_path, exist_ok=True) + if not os.path.exists(init_file): + with open(init_file, "w"): + pass diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py index e2be6459508b..ea6362ef080a 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/client.py @@ -38,7 +38,6 @@ TAINodeRemoveReq, TAINodeRestartReq, TNodeVersionInfo, - TUpdateModelInfoReq, ) logger = Logger() @@ -155,13 +154,8 @@ def _wait_and_reconnect(self) -> None: self._try_to_connect() except TException: # can not connect to each config node - self._sync_latest_config_node_list() self._try_to_connect() - def _sync_latest_config_node_list(self) -> None: - # TODO - pass - def _update_config_node_leader(self, status: TSStatus) -> bool: if status.code == TSStatusCode.REDIRECTION_RECOMMEND.get_status_code(): if status.redirectNode is not None: @@ -271,36 +265,3 @@ def get_ainode_configuration(self, node_id: int) -> map: self._config_leader = None self._wait_and_reconnect() raise TException(self._MSG_RECONNECTION_FAIL) - - def update_model_info( - self, - model_id: str, - model_status: int, - attribute: str = "", - ainode_id=None, - input_length=0, - output_length=0, - ) -> None: - if ainode_id is None: - ainode_id = [] - for _ in range(0, self._RETRY_NUM): - try: - req = TUpdateModelInfoReq(model_id, model_status, attribute) - if ainode_id is not None: - req.aiNodeIds = ainode_id - req.inputLength = input_length - req.outputLength = output_length - status = self._client.updateModelInfo(req) - if not self._update_config_node_leader(status): - verify_success( - status, "An error occurs when calling update model info" - ) - return status - except TTransport.TException: - logger.warning( - "Failed to connect to ConfigNode {} from AINode when executing 
update model info", - self._config_leader, - ) - self._config_leader = None - self._wait_and_reconnect() - raise TException(self._MSG_RECONNECTION_FAIL) diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py index f01e1594f069..6c4eedeb99f7 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py @@ -29,6 +29,7 @@ TAIHeartbeatResp, TDeleteModelReq, TForecastReq, + TForecastResp, TInferenceReq, TInferenceResp, TLoadModelReq, @@ -78,8 +79,14 @@ def stopAINode(self) -> TSStatus: def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp: return self._model_manager.register_model(req) + def deleteModel(self, req: TDeleteModelReq) -> TSStatus: + return self._model_manager.delete_model(req) + + def showModels(self, req: TShowModelsReq) -> TShowModelsResp: + return self._model_manager.show_models(req) + def loadModel(self, req: TLoadModelReq) -> TSStatus: - status = self._ensure_model_is_built_in_or_fine_tuned(req.existingModelId) + status = self._ensure_model_is_registered(req.existingModelId) if status.code != TSStatusCode.SUCCESS_STATUS.value: return status status = _ensure_device_id_is_available(req.deviceIdList) @@ -88,7 +95,7 @@ def loadModel(self, req: TLoadModelReq) -> TSStatus: return self._inference_manager.load_model(req) def unloadModel(self, req: TUnloadModelReq) -> TSStatus: - status = self._ensure_model_is_built_in_or_fine_tuned(req.modelId) + status = self._ensure_model_is_registered(req.modelId) if status.code != TSStatusCode.SUCCESS_STATUS.value: return status status = _ensure_device_id_is_available(req.deviceIdList) @@ -96,21 +103,6 @@ def unloadModel(self, req: TUnloadModelReq) -> TSStatus: return status return self._inference_manager.unload_model(req) - def deleteModel(self, req: TDeleteModelReq) -> TSStatus: - return self._model_manager.delete_model(req) - - def inference(self, req: TInferenceReq) -> TInferenceResp: - return self._inference_manager.inference(req) - - def forecast(self, req: TForecastReq) -> TSStatus: - return self._inference_manager.forecast(req) - - def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: - return ClusterManager.get_heart_beat(req) - - def showModels(self, req: TShowModelsReq) -> TShowModelsResp: - return self._model_manager.show_models(req) - def showLoadedModels(self, req: TShowLoadedModelsReq) -> TShowLoadedModelsResp: status = _ensure_device_id_is_available(req.deviceIdList) if status.code != TSStatusCode.SUCCESS_STATUS.value: @@ -123,13 +115,28 @@ def showAIDevices(self) -> TShowAIDevicesResp: deviceIdList=get_available_devices(), ) + def inference(self, req: TInferenceReq) -> TInferenceResp: + status = self._ensure_model_is_registered(req.modelId) + if status.code != TSStatusCode.SUCCESS_STATUS.value: + return TInferenceResp(status, []) + return self._inference_manager.inference(req) + + def forecast(self, req: TForecastReq) -> TForecastResp: + status = self._ensure_model_is_registered(req.modelId) + if status.code != TSStatusCode.SUCCESS_STATUS.value: + return TForecastResp(status, []) + return self._inference_manager.forecast(req) + + def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: + return ClusterManager.get_heart_beat(req) + def createTrainingTask(self, req: TTrainingReq) -> TSStatus: pass - def _ensure_model_is_built_in_or_fine_tuned(self, model_id: str) -> TSStatus: - if not self._model_manager.is_built_in_or_fine_tuned(model_id): + def 
_ensure_model_is_registered(self, model_id: str) -> TSStatus: + if not self._model_manager.is_model_registered(model_id): return TSStatus( code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, - message=f"Model [{model_id}] is not a built-in or fine-tuned model. You can use 'SHOW MODELS' to retrieve the available models.", + message=f"Model [{model_id}] is not registered yet. You can use 'SHOW MODELS' to retrieve the available models.", ) return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml index 331cb8ab32a3..e93bb3dfdaf1 100644 --- a/iotdb-core/ainode/pyproject.toml +++ b/iotdb-core/ainode/pyproject.toml @@ -79,7 +79,7 @@ exclude = [ python = ">=3.11.0,<3.14.0" # ---- DL / HF stack ---- -torch = ">=2.7.0" +torch = "^2.8.0,<2.9.0" torchmetrics = "^1.8.0" transformers = "==4.56.2" tokenizers = ">=0.22.0,<=0.23.0" @@ -93,7 +93,10 @@ scipy = "^1.12.0" pandas = "^2.3.2" scikit-learn = "^1.7.1" statsmodels = "^0.14.5" -sktime = "0.38.5" +sktime = "0.40.1" +pmdarima = "2.1.1" +hmmlearn = "0.3.2" +accelerate = "^1.10.1" # ---- Optimizers / utils ---- optuna = "^4.4.0" @@ -112,7 +115,7 @@ black = "25.1.0" isort = "6.0.1" setuptools = ">=75.3.0" joblib = ">=1.4.2" -urllib3 = ">=2.2.3" +urllib3 = "2.6.0" [tool.poetry.scripts] ainode = "iotdb.ainode.core.script:main" diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java index 2721fedafb1e..8d9081f43527 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncAINodeHeartbeatClientPool.java @@ -21,21 +21,28 @@ import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatReq; import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.ClientPoolFactory; import org.apache.iotdb.commons.client.IClientManager; +import org.apache.iotdb.commons.client.async.AsyncAINodeInternalServiceClient; import org.apache.iotdb.confignode.client.async.handlers.heartbeat.AINodeHeartbeatHandler; -import org.apache.iotdb.db.protocol.client.AINodeClientFactory; -import org.apache.iotdb.db.protocol.client.ainode.AsyncAINodeServiceClient; +/** Asynchronously send RPC requests to AINodes. */ public class AsyncAINodeHeartbeatClientPool { - private final IClientManager<TEndPoint, AsyncAINodeServiceClient> clientManager; + private final IClientManager<TEndPoint, AsyncAINodeInternalServiceClient> clientManager; private AsyncAINodeHeartbeatClientPool() { clientManager = - new IClientManager.Factory<TEndPoint, AsyncAINodeServiceClient>() - .createClientManager(new AINodeClientFactory.AINodeHeartbeatClientPoolFactory()); + new IClientManager.Factory<TEndPoint, AsyncAINodeInternalServiceClient>() + .createClientManager( + new ClientPoolFactory.AsyncAINodeHeartbeatServiceClientPoolFactory()); } + /** + * Only used in LoadManager. 
+ * + * @param endPoint The specific AINode + */ public void getAINodeHeartBeat( TEndPoint endPoint, TAIHeartbeatReq req, AINodeHeartbeatHandler handler) { try { @@ -56,6 +63,6 @@ private AsyncAINodeHeartbeatClientPoolHolder() { } public static AsyncAINodeHeartbeatClientPool getInstance() { - return AsyncAINodeHeartbeatClientPool.AsyncAINodeHeartbeatClientPoolHolder.INSTANCE; + return AsyncAINodeHeartbeatClientPoolHolder.INSTANCE; } } diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java index ccc19f1a9f38..324e35130278 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/client/async/AsyncDataNodeHeartbeatClientPool.java @@ -63,7 +63,6 @@ public void writeAuditLog( } } - // TODO: Is the AsyncDataNodeHeartbeatClientPool must be a singleton? private static class AsyncDataNodeHeartbeatClientPoolHolder { private static final AsyncDataNodeHeartbeatClientPool INSTANCE = diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java index e0b2c144c0ea..65c1ee0a9fed 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/ConfigPhysicalPlan.java @@ -21,8 +21,6 @@ import org.apache.iotdb.commons.exception.runtime.SerializationRunTimeException; import org.apache.iotdb.confignode.consensus.request.read.ainode.GetAINodeConfigurationPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; import org.apache.iotdb.confignode.consensus.request.read.subscription.ShowTopicPlan; import org.apache.iotdb.confignode.consensus.request.write.ainode.RegisterAINodePlan; import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; @@ -52,10 +50,6 @@ import org.apache.iotdb.confignode.consensus.request.write.function.DropTableModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.DropTreeModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.UpdateFunctionPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.AddRegionLocationPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.AutoCleanPartitionTablePlan; import org.apache.iotdb.confignode.consensus.request.write.partition.CreateDataPartitionPlan; @@ -574,24 +568,6 @@ public static ConfigPhysicalPlan create(final ByteBuffer buffer) throws IOExcept case UPDATE_CQ_LAST_EXEC_TIME: plan = new UpdateCQLastExecTimePlan(); break; - case 
DropModel: - plan = new DropModelPlan(); - break; - case ShowModel: - plan = new ShowModelPlan(); - break; - case DropModelInNode: - plan = new DropModelInNodePlan(); - break; - case GetModelInfo: - plan = new GetModelInfoPlan(); - break; case CreatePipePlugin: plan = new CreatePipePluginPlan(); break; diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java deleted file mode 100644 index dd79910e51fa..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/GetModelInfoPlan.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.request.read.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; -import org.apache.iotdb.confignode.consensus.request.read.ConfigPhysicalReadPlan; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; - -import java.util.Objects; - -public class GetModelInfoPlan extends ConfigPhysicalReadPlan { - - private String modelId; - - public GetModelInfoPlan() { - super(ConfigPhysicalPlanType.GetModelInfo); - } - - public GetModelInfoPlan(final TGetModelInfoReq getModelInfoReq) { - super(ConfigPhysicalPlanType.GetModelInfo); - this.modelId = getModelInfoReq.getModelId(); - } - - public String getModelId() { - return modelId; - } - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - final GetModelInfoPlan that = (GetModelInfoPlan) o; - return Objects.equals(modelId, that.modelId); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelId); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java deleted file mode 100644 index eca00e8827d9..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/read/model/ShowModelPlan.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.request.read.model; - -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; -import org.apache.iotdb.confignode.consensus.request.read.ConfigPhysicalReadPlan; - -import java.util.Objects; - -public class ShowModelPlan extends ConfigPhysicalReadPlan { - - private String modelName; - - public ShowModelPlan() { - super(ConfigPhysicalPlanType.ShowModel); - } - - public ShowModelPlan(final TShowModelsReq showModelReq) { - super(ConfigPhysicalPlanType.ShowModel); - if (showModelReq.isSetModelId()) { - this.modelName = showModelReq.getModelId(); - } - } - - public boolean isSetModelName() { - return modelName != null; - } - - public String getModelName() { - return modelName; - } - - @Override - public boolean equals(final Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - final ShowModelPlan that = (ShowModelPlan) o; - return Objects.equals(modelName, that.modelName); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java deleted file mode 100644 index 61e37cdd2187..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/CreateModelPlan.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import org.apache.tsfile.utils.ReadWriteIOUtils; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Objects; - -public class CreateModelPlan extends ConfigPhysicalPlan { - - private String modelName; - - public CreateModelPlan() { - super(ConfigPhysicalPlanType.CreateModel); - } - - public CreateModelPlan(String modelName) { - super(ConfigPhysicalPlanType.CreateModel); - this.modelName = modelName; - } - - public String getModelName() { - return modelName; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - ReadWriteIOUtils.write(modelName, stream); - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - modelName = ReadWriteIOUtils.readString(buffer); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - CreateModelPlan that = (CreateModelPlan) o; - return Objects.equals(modelName, that.modelName); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java deleted file mode 100644 index 885543f84e15..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelInNodePlan.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Objects; - -public class DropModelInNodePlan extends ConfigPhysicalPlan { - - private int nodeId; - - public DropModelInNodePlan() { - super(ConfigPhysicalPlanType.DropModelInNode); - } - - public DropModelInNodePlan(int nodeId) { - super(ConfigPhysicalPlanType.DropModelInNode); - this.nodeId = nodeId; - } - - public int getNodeId() { - return nodeId; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - stream.writeInt(nodeId); - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - nodeId = buffer.getInt(); - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (!(o instanceof DropModelInNodePlan)) return false; - DropModelInNodePlan that = (DropModelInNodePlan) o; - return nodeId == that.nodeId; - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), nodeId); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java deleted file mode 100644 index 813b116c645c..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/DropModelPlan.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import org.apache.tsfile.utils.ReadWriteIOUtils; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.Objects; - -public class DropModelPlan extends ConfigPhysicalPlan { - - private String modelName; - - public DropModelPlan() { - super(ConfigPhysicalPlanType.DropModel); - } - - public DropModelPlan(String modelName) { - super(ConfigPhysicalPlanType.DropModel); - this.modelName = modelName; - } - - public String getModelName() { - return modelName; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - ReadWriteIOUtils.write(modelName, stream); - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - modelName = ReadWriteIOUtils.readString(buffer); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - DropModelPlan that = (DropModelPlan) o; - return modelName.equals(that.modelName); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java deleted file mode 100644 index ce7219e42813..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/request/write/model/UpdateModelInfoPlan.java +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.request.write.model; - -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlan; -import org.apache.iotdb.confignode.consensus.request.ConfigPhysicalPlanType; - -import org.apache.tsfile.utils.ReadWriteIOUtils; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Objects; - -public class UpdateModelInfoPlan extends ConfigPhysicalPlan { - - private String modelName; - private ModelInformation modelInformation; - - // The node which has the model which is only updated in model registration - private List nodeIds; - - public UpdateModelInfoPlan() { - super(ConfigPhysicalPlanType.UpdateModelInfo); - } - - public UpdateModelInfoPlan(String modelName, ModelInformation modelInformation) { - super(ConfigPhysicalPlanType.UpdateModelInfo); - this.modelName = modelName; - this.modelInformation = modelInformation; - this.nodeIds = Collections.emptyList(); - } - - public UpdateModelInfoPlan( - String modelName, ModelInformation modelInformation, List nodeIds) { - super(ConfigPhysicalPlanType.UpdateModelInfo); - this.modelName = modelName; - this.modelInformation = modelInformation; - this.nodeIds = nodeIds; - } - - public String getModelName() { - return modelName; - } - - public ModelInformation getModelInformation() { - return modelInformation; - } - - public List getNodeIds() { - return nodeIds; - } - - public void setNodeIds(List nodeIds) { - this.nodeIds = nodeIds; - } - - @Override - protected void serializeImpl(DataOutputStream stream) throws IOException { - stream.writeShort(getType().getPlanType()); - ReadWriteIOUtils.write(modelName, stream); - this.modelInformation.serialize(stream); - ReadWriteIOUtils.write(nodeIds.size(), stream); - for (Integer nodeId : nodeIds) { - ReadWriteIOUtils.write(nodeId, stream); - } - } - - @Override - protected void deserializeImpl(ByteBuffer buffer) throws IOException { - this.modelName = ReadWriteIOUtils.readString(buffer); - this.modelInformation = ModelInformation.deserialize(buffer); - int size = ReadWriteIOUtils.readInt(buffer); - this.nodeIds = new ArrayList<>(); - for (int i = 0; i < size; i++) { - this.nodeIds.add(ReadWriteIOUtils.readInt(buffer)); - } - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - if (!super.equals(o)) { - return false; - } - UpdateModelInfoPlan that = (UpdateModelInfoPlan) o; - return modelName.equals(that.modelName) - && modelInformation.equals(that.modelInformation) - && nodeIds.equals(that.nodeIds); - } - - @Override - public int hashCode() { - return Objects.hash(super.hashCode(), modelName, modelInformation, nodeIds); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java deleted file mode 100644 index cebc1301b891..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/GetModelInfoResp.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. 
See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.consensus.response.model; - -import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; -import org.apache.iotdb.consensus.common.DataSet; - -public class GetModelInfoResp implements DataSet { - - private final TSStatus status; - - private int targetAINodeId; - private TEndPoint targetAINodeAddress; - - public TSStatus getStatus() { - return status; - } - - public GetModelInfoResp(TSStatus status) { - this.status = status; - } - - public int getTargetAINodeId() { - return targetAINodeId; - } - - public void setTargetAINodeId(int targetAINodeId) { - this.targetAINodeId = targetAINodeId; - } - - public void setTargetAINodeAddress(TAINodeConfiguration aiNodeConfiguration) { - if (aiNodeConfiguration.getLocation() == null) { - return; - } - this.targetAINodeAddress = aiNodeConfiguration.getLocation().getInternalEndPoint(); - } - - public TGetModelInfoResp convertToThriftResponse() { - TGetModelInfoResp resp = new TGetModelInfoResp(status); - resp.setAiNodeAddress(targetAINodeAddress); - return resp; - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java deleted file mode 100644 index 7490a53a01c5..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/consensus/response/model/ModelTableResp.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ - -package org.apache.iotdb.confignode.consensus.response.model; - -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.consensus.common.DataSet; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -// TODO: Will be removed in the future -public class ModelTableResp implements DataSet { - - private final TSStatus status; - private final List serializedAllModelInformation; - private Map modelTypeMap; - private Map algorithmMap; - - public ModelTableResp(TSStatus status) { - this.status = status; - this.serializedAllModelInformation = new ArrayList<>(); - } - - public void addModelInformation(List modelInformationList) throws IOException { - for (ModelInformation modelInformation : modelInformationList) { - this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult()); - } - } - - public void addModelInformation(ModelInformation modelInformation) throws IOException { - this.serializedAllModelInformation.add(modelInformation.serializeShowModelResult()); - } - - public void setModelTypeMap(Map modelTypeMap) { - this.modelTypeMap = modelTypeMap; - } - - public void setAlgorithmMap(Map algorithmMap) { - this.algorithmMap = algorithmMap; - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java index 5d4b09adfc71..9d7151a8d20e 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ConfigManager.java @@ -19,22 +19,12 @@ package org.apache.iotdb.confignode.manager; -import org.apache.iotdb.ainode.rpc.thrift.IDataSchema; -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId; import org.apache.iotdb.common.rpc.thrift.TDataNodeConfiguration; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.common.rpc.thrift.TFlushReq; import org.apache.iotdb.common.rpc.thrift.TPipeHeartbeatResp; import org.apache.iotdb.common.rpc.thrift.TRegionReplicaSet; @@ -58,7 +48,6 @@ import org.apache.iotdb.commons.conf.TrimProperties; import org.apache.iotdb.commons.exception.IllegalPathException; import org.apache.iotdb.commons.exception.MetadataException; -import org.apache.iotdb.commons.model.ModelStatus; import org.apache.iotdb.commons.path.PartialPath; import org.apache.iotdb.commons.path.PathPatternTree; import org.apache.iotdb.commons.path.PathPatternUtil; @@ -97,7 +86,6 @@ import org.apache.iotdb.confignode.consensus.request.write.database.SetTTLPlan; import 
org.apache.iotdb.confignode.consensus.request.write.database.SetTimePartitionIntervalPlan; import org.apache.iotdb.confignode.consensus.request.write.datanode.RemoveDataNodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; import org.apache.iotdb.confignode.consensus.request.write.template.CreateSchemaTemplatePlan; import org.apache.iotdb.confignode.consensus.response.ainode.AINodeRegisterResp; import org.apache.iotdb.confignode.consensus.response.auth.PermissionInfoResp; @@ -129,7 +117,6 @@ import org.apache.iotdb.confignode.manager.schema.ClusterSchemaQuotaStatistics; import org.apache.iotdb.confignode.manager.subscription.SubscriptionManager; import org.apache.iotdb.confignode.persistence.ClusterInfo; -import org.apache.iotdb.confignode.persistence.ModelInfo; import org.apache.iotdb.confignode.persistence.ProcedureInfo; import org.apache.iotdb.confignode.persistence.TTLInfo; import org.apache.iotdb.confignode.persistence.TriggerInfo; @@ -144,7 +131,6 @@ import org.apache.iotdb.confignode.persistence.schema.ClusterSchemaInfo; import org.apache.iotdb.confignode.persistence.subscription.SubscriptionInfo; import org.apache.iotdb.confignode.procedure.impl.schema.SchemaUtils; -import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRegisterReq; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartReq; import org.apache.iotdb.confignode.rpc.thrift.TAINodeRestartResp; @@ -163,13 +149,11 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTableViewReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTopicReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateTrainingReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTriggerReq; import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRegisterReq; import org.apache.iotdb.confignode.rpc.thrift.TDataNodeRestartReq; @@ -186,7 +170,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -203,8 +186,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -257,11 +238,8 @@ import org.apache.iotdb.confignode.rpc.thrift.TTimeSlotList; import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq; import 
org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; import org.apache.iotdb.consensus.common.DataSet; import org.apache.iotdb.consensus.exception.ConsensusException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; import org.apache.iotdb.db.schemaengine.template.Template; import org.apache.iotdb.db.schemaengine.template.TemplateAlterOperationType; import org.apache.iotdb.db.schemaengine.template.alter.TemplateAlterOperationUtil; @@ -340,9 +318,6 @@ public class ConfigManager implements IManager { /** CQ. */ private final CQManager cqManager; - /** AI Model. */ - private final ModelManager modelManager; - /** Pipe */ private final PipeManager pipeManager; @@ -362,8 +337,6 @@ public class ConfigManager implements IManager { private static final String DATABASE = "\tDatabase="; - private static final String DOT = "."; - public ConfigManager() throws IOException { // Build the persistence module ClusterInfo clusterInfo = new ClusterInfo(); @@ -375,7 +348,6 @@ public ConfigManager() throws IOException { UDFInfo udfInfo = new UDFInfo(); TriggerInfo triggerInfo = new TriggerInfo(); CQInfo cqInfo = new CQInfo(); - ModelInfo modelInfo = new ModelInfo(); PipeInfo pipeInfo = new PipeInfo(); QuotaInfo quotaInfo = new QuotaInfo(); TTLInfo ttlInfo = new TTLInfo(); @@ -393,7 +365,6 @@ public ConfigManager() throws IOException { udfInfo, triggerInfo, cqInfo, - modelInfo, pipeInfo, subscriptionInfo, quotaInfo, @@ -415,7 +386,6 @@ public ConfigManager() throws IOException { this.udfManager = new UDFManager(this, udfInfo); this.triggerManager = new TriggerManager(this, triggerInfo); this.cqManager = new CQManager(this); - this.modelManager = new ModelManager(this, modelInfo); this.pipeManager = new PipeManager(this, pipeInfo); this.subscriptionManager = new SubscriptionManager(this, subscriptionInfo); this.auditLogger = new CNAuditLogger(this); @@ -1289,11 +1259,6 @@ public TriggerManager getTriggerManager() { return triggerManager; } - @Override - public ModelManager getModelManager() { - return modelManager; - } - @Override public PipeManager getPipeManager() { return pipeManager; @@ -2757,150 +2722,6 @@ public TSStatus transfer(List newUnknownDataList) { return transferResult; } - @Override - public TSStatus createModel(TCreateModelReq req) { - TSStatus status = confirmLeader(); - if (nodeManager.getRegisteredAINodes().isEmpty()) { - return new TSStatus(TSStatusCode.NO_REGISTERED_AI_NODE_ERROR.getStatusCode()) - .setMessage("There is no available AINode! Try to start one."); - } - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? 
modelManager.createModel(req) - : status; - } - - private List fetchSchemaForTreeModel(TCreateTrainingReq req) { - List dataSchemaList = new ArrayList<>(); - for (int i = 0; i < req.getDataSchemaForTree().getPathSize(); i++) { - IDataSchema dataSchema = new IDataSchema(req.getDataSchemaForTree().getPath().get(i)); - dataSchema.setTimeRange(req.getTimeRanges().get(i)); - dataSchemaList.add(dataSchema); - } - return dataSchemaList; - } - - private List fetchSchemaForTableModel(TCreateTrainingReq req) { - return Collections.singletonList(new IDataSchema(req.getDataSchemaForTable().getTargetSql())); - } - - public TSStatus createTraining(TCreateTrainingReq req) { - TSStatus status = confirmLeader(); - if (nodeManager.getRegisteredAINodes().isEmpty()) { - return new TSStatus(TSStatusCode.NO_REGISTERED_AI_NODE_ERROR.getStatusCode()) - .setMessage("There is no available AINode! Try to start one."); - } - - TTrainingReq trainingReq = new TTrainingReq(); - trainingReq.setModelId(req.getModelId()); - if (req.isSetExistingModelId()) { - trainingReq.setExistingModelId(req.getExistingModelId()); - } - if (req.isSetParameters() && !req.getParameters().isEmpty()) { - trainingReq.setParameters(req.getParameters()); - } - - try { - status = getConsensusManager().write(new CreateModelPlan(req.getModelId())); - if (status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new MetadataException("Can't init model " + req.getModelId()); - } - - List dataSchema; - if (req.isTableModel) { - dataSchema = fetchSchemaForTableModel(req); - trainingReq.setDbType("iotdb.table"); - } else { - dataSchema = fetchSchemaForTreeModel(req); - trainingReq.setDbType("iotdb.tree"); - } - updateModelInfo(new TUpdateModelInfoReq(req.modelId, ModelStatus.TRAINING.ordinal())); - trainingReq.setTargetDataSchema(dataSchema); - - TAINodeInfo registeredAINode = getNodeManager().getRegisteredAINodeInfoList().get(0); - TEndPoint targetAINodeEndPoint = - new TEndPoint(registeredAINode.getInternalAddress(), registeredAINode.getInternalPort()); - try (AINodeClient client = - AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint)) { - status = client.createTrainingTask(trainingReq); - if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new IllegalArgumentException(status.message); - } - } - } catch (final Exception e) { - status.setCode(TSStatusCode.CAN_NOT_CONNECT_CONFIGNODE.getStatusCode()); - status.setMessage(e.getMessage()); - try { - updateModelInfo(new TUpdateModelInfoReq(req.modelId, ModelStatus.UNAVAILABLE.ordinal())); - } catch (Exception e2) { - LOGGER.error(e2.getMessage()); - } - } - return status; - } - - @Override - public TSStatus dropModel(TDropModelReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.dropModel(req) - : status; - } - - @Override - public TSStatus loadModel(TLoadModelReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.loadModel(req) - : status; - } - - @Override - public TSStatus unloadModel(TUnloadModelReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.unloadModel(req) - : status; - } - - @Override - public TShowModelsResp showModel(TShowModelsReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? 
modelManager.showModel(req) - : new TShowModelsResp(status); - } - - @Override - public TShowLoadedModelsResp showLoadedModel(TShowLoadedModelsReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.showLoadedModel(req) - : new TShowLoadedModelsResp(status, Collections.emptyMap()); - } - - @Override - public TShowAIDevicesResp showAIDevices() { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.showAIDevices() - : new TShowAIDevicesResp(status, Collections.emptyList()); - } - - @Override - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.getModelInfo(req) - : new TGetModelInfoResp(status); - } - - public TSStatus updateModelInfo(TUpdateModelInfoReq req) { - TSStatus status = confirmLeader(); - return status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode() - ? modelManager.updateModelInfo(req) - : status; - } - @Override public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) { TSStatus status = confirmLeader(); diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java index 33e77db24907..dff994d70e7e 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/IManager.java @@ -19,13 +19,6 @@ package org.apache.iotdb.confignode.manager; -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TConsensusGroupId; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; @@ -82,7 +75,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -103,7 +95,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -120,8 +111,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import 
org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -255,13 +244,6 @@ public interface IManager { */ CQManager getCQManager(); - /** - * Get {@link ModelManager}. - * - * @return {@link ModelManager} instance - */ - ModelManager getModelManager(); - /** * Get {@link PipeManager}. * @@ -880,30 +862,6 @@ TDataPartitionTableResp getOrCreateDataPartition( TSStatus transfer(List newUnknownDataList); - /** Create a model. */ - TSStatus createModel(TCreateModelReq req); - - /** Drop a model. */ - TSStatus dropModel(TDropModelReq req); - - /** Load the specific model to the specific devices. */ - TSStatus loadModel(TLoadModelReq req); - - /** Unload the specific model from the specific devices. */ - TSStatus unloadModel(TUnloadModelReq req); - - /** Return the model table. */ - TShowModelsResp showModel(TShowModelsReq req); - - /** Return the loaded model instances. */ - TShowLoadedModelsResp showLoadedModel(TShowLoadedModelsReq req); - - /** Return all available AI devices. */ - TShowAIDevicesResp showAIDevices(); - - /** Update the model state */ - TGetModelInfoResp getModelInfo(TGetModelInfoReq req); - /** Set space quota. */ TSStatus setSpaceQuota(TSetSpaceQuotaReq req); diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java deleted file mode 100644 index 3efdbc222b6d..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ModelManager.java +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
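All of the model RPCs removed from ConfigManager and IManager above shared one ConfigNode convention: confirm leadership first, and delegate to the manager only on success; a follower hands the non-success status straight back so the client can redirect. A minimal sketch of that guard, assuming plain int status codes as stand-ins for the Thrift-generated TSStatus and TSStatusCode (the deleted ModelManager listing resumes below):

    // Sketch of the leader-guard convention used by the removed entry points.
    // Assumption: ints stand in for Thrift's TSStatus/TSStatusCode.
    final class LeaderGuardSketch {
      static final int SUCCESS_STATUS = 200; // stand-in for TSStatusCode.SUCCESS_STATUS
      static final int NOT_LEADER = 500; // stand-in for a redirection status

      private final boolean isLeader;

      LeaderGuardSketch(boolean isLeader) {
        this.isLeader = isLeader;
      }

      // Mirrors confirmLeader(): success on the leader, redirection elsewhere.
      int confirmLeader() {
        return isLeader ? SUCCESS_STATUS : NOT_LEADER;
      }

      // Mirrors the removed createModel/dropModel/... wrappers.
      int createModel(String req) {
        int status = confirmLeader();
        return status == SUCCESS_STATUS ? delegate(req) : status;
      }

      private int delegate(String req) {
        return SUCCESS_STATUS; // the real code forwards to ModelManager
      }

      public static void main(String[] args) {
        System.out.println(new LeaderGuardSketch(true).createModel("m1")); // 200
        System.out.println(new LeaderGuardSketch(false).createModel("m1")); // 500
      }
    }
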
- */ - -package org.apache.iotdb.confignode.manager; - -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.client.exception.ClientManagerException; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.commons.model.ModelStatus; -import org.apache.iotdb.commons.model.ModelType; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; -import org.apache.iotdb.confignode.exception.NoAvailableAINodeException; -import org.apache.iotdb.confignode.persistence.ModelInfo; -import org.apache.iotdb.confignode.rpc.thrift.TAINodeInfo; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; -import org.apache.iotdb.consensus.exception.ConsensusException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.List; - -public class ModelManager { - - private static final Logger LOGGER = LoggerFactory.getLogger(ModelManager.class); - - private final ConfigManager configManager; - private final ModelInfo modelInfo; - - public ModelManager(ConfigManager configManager, ModelInfo modelInfo) { - this.configManager = configManager; - this.modelInfo = modelInfo; - } - - public TSStatus createModel(TCreateModelReq req) { - if (modelInfo.contain(req.modelName)) { - return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode()) - .setMessage(String.format("Model name %s already exists", req.modelName)); - } - try { - if (req.uri.isEmpty()) { - return configManager.getConsensusManager().write(new CreateModelPlan(req.modelName)); - } - return configManager.getProcedureManager().createModel(req.modelName, req.uri); - } catch (ConsensusException e) { - LOGGER.warn("Unexpected error happened while getting model: ", e); - // consensus layer related errors - TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()); - res.setMessage(e.getMessage()); - return res; - } - } - - public TSStatus dropModel(TDropModelReq req) { - if (modelInfo.checkModelType(req.getModelId()) != ModelType.USER_DEFINED) { - return new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode()) - .setMessage(String.format("Built-in model %s can't be removed", req.modelId)); - } - if (!modelInfo.contain(req.modelId)) { - return new TSStatus(TSStatusCode.MODEL_EXIST_ERROR.getStatusCode()) - .setMessage(String.format("Model name %s doesn't exists", req.modelId)); - } - return configManager.getProcedureManager().dropModel(req.getModelId()); - } - - public TSStatus loadModel(TLoadModelReq req) { - try (AINodeClient 
client = getAINodeClient()) { - org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq loadModelReq = - new org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq( - req.existingModelId, req.deviceIdList); - return client.loadModel(loadModelReq); - } catch (Exception e) { - LOGGER.warn("Failed to load model due to", e); - return new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage()); - } - } - - public TSStatus unloadModel(TUnloadModelReq req) { - try (AINodeClient client = getAINodeClient()) { - org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq unloadModelReq = - new org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq(req.modelId, req.deviceIdList); - return client.unloadModel(unloadModelReq); - } catch (Exception e) { - LOGGER.warn("Failed to unload model due to", e); - return new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage()); - } - } - - public TShowModelsResp showModel(final TShowModelsReq req) { - try (AINodeClient client = getAINodeClient()) { - TShowModelsReq showModelsReq = new TShowModelsReq(); - if (req.isSetModelId()) { - showModelsReq.setModelId(req.getModelId()); - } - TShowModelsResp resp = client.showModels(showModelsReq); - TShowModelsResp res = - new TShowModelsResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - res.setModelIdList(resp.getModelIdList()); - res.setModelTypeMap(resp.getModelTypeMap()); - res.setCategoryMap(resp.getCategoryMap()); - res.setStateMap(resp.getStateMap()); - return res; - } catch (Exception e) { - LOGGER.warn("Failed to show models due to", e); - return new TShowModelsResp() - .setStatus( - new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } - } - - public TShowLoadedModelsResp showLoadedModel(final TShowLoadedModelsReq req) { - try (AINodeClient client = getAINodeClient()) { - TShowLoadedModelsReq showModelsReq = - new TShowLoadedModelsReq().setDeviceIdList(req.getDeviceIdList()); - TShowLoadedModelsResp resp = client.showLoadedModels(showModelsReq); - TShowLoadedModelsResp res = - new TShowLoadedModelsResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - res.setDeviceLoadedModelsMap(resp.getDeviceLoadedModelsMap()); - return res; - } catch (Exception e) { - LOGGER.warn("Failed to show loaded models due to", e); - return new TShowLoadedModelsResp() - .setStatus( - new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } - } - - public TShowAIDevicesResp showAIDevices() { - try (AINodeClient client = getAINodeClient()) { - org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp resp = client.showAIDevices(); - TShowAIDevicesResp res = - new TShowAIDevicesResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - res.setDeviceIdList(resp.getDeviceIdList()); - return res; - } catch (Exception e) { - LOGGER.warn("Failed to show AI devices due to", e); - return new TShowAIDevicesResp() - .setStatus( - new TSStatus(TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } - } - - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) { - return new TGetModelInfoResp() - .setStatus(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())) - .setAiNodeAddress( - configManager - .getNodeManager() - .getRegisteredAINodes() - .get(0) - .getLocation() - .getInternalEndPoint()); - } - - // Currently this method is only used by built-in timer_xl - public TSStatus 
updateModelInfo(TUpdateModelInfoReq req) { - if (!modelInfo.contain(req.getModelId())) { - return new TSStatus(TSStatusCode.MODEL_NOT_FOUND_ERROR.getStatusCode()) - .setMessage(String.format("Model %s doesn't exists", req.getModelId())); - } - try { - ModelInformation modelInformation = - new ModelInformation(ModelType.USER_DEFINED, req.getModelId()); - modelInformation.updateStatus(ModelStatus.values()[req.getModelStatus()]); - modelInformation.setAttribute(req.getAttributes()); - modelInformation.setInputColumnSize(1); - if (req.isSetOutputLength()) { - modelInformation.setOutputLength(req.getOutputLength()); - } - if (req.isSetInputLength()) { - modelInformation.setInputLength(req.getInputLength()); - } - UpdateModelInfoPlan updateModelInfoPlan = - new UpdateModelInfoPlan(req.getModelId(), modelInformation); - if (req.isSetAiNodeIds()) { - updateModelInfoPlan.setNodeIds(req.getAiNodeIds()); - } - return configManager.getConsensusManager().write(updateModelInfoPlan); - } catch (ConsensusException e) { - LOGGER.warn("Unexpected error happened while updating model info: ", e); - // consensus layer related errors - TSStatus res = new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()); - res.setMessage(e.getMessage()); - return res; - } - } - - private AINodeClient getAINodeClient() throws NoAvailableAINodeException, ClientManagerException { - List aiNodeInfo = configManager.getNodeManager().getRegisteredAINodeInfoList(); - if (aiNodeInfo.isEmpty()) { - throw new NoAvailableAINodeException(); - } - TEndPoint targetAINodeEndPoint = - new TEndPoint(aiNodeInfo.get(0).getInternalAddress(), aiNodeInfo.get(0).getInternalPort()); - try { - return AINodeClientManager.getInstance().borrowClient(targetAINodeEndPoint); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - public List getModelDistributions(String modelName) { - return modelInfo.getNodeIds(modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java index 2e4227af3fc8..d67e7721eef8 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/manager/ProcedureManager.java @@ -61,8 +61,6 @@ import org.apache.iotdb.confignode.procedure.env.RegionMaintainHandler; import org.apache.iotdb.confignode.procedure.env.RemoveDataNodeHandler; import org.apache.iotdb.confignode.procedure.impl.cq.CreateCQProcedure; -import org.apache.iotdb.confignode.procedure.impl.model.CreateModelProcedure; -import org.apache.iotdb.confignode.procedure.impl.model.DropModelProcedure; import org.apache.iotdb.confignode.procedure.impl.node.AddConfigNodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveAINodeProcedure; import org.apache.iotdb.confignode.procedure.impl.node.RemoveConfigNodeProcedure; @@ -1414,24 +1412,6 @@ public TSStatus createCQ(TCreateCQReq req, ScheduledExecutorService scheduledExe return waitingProcedureFinished(procedure); } - public TSStatus createModel(String modelName, String uri) { - long procedureId = executor.submitProcedure(new CreateModelProcedure(modelName, uri)); - LOGGER.info("CreateModelProcedure was submitted, procedureId: {}.", procedureId); - return RpcUtils.SUCCESS_STATUS; - } - - public TSStatus dropModel(String modelId) { - DropModelProcedure procedure = new DropModelProcedure(modelId); - 
executor.submitProcedure(procedure); - TSStatus status = waitingProcedureFinished(procedure); - if (status.getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - return status; - } else { - return new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode()) - .setMessage(status.getMessage()); - } - } - public TSStatus createPipePlugin( PipePluginMeta pipePluginMeta, byte[] jarFile, boolean isSetIfNotExistsCondition) { final CreatePipePluginProcedure createPipePluginProcedure = diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java deleted file mode 100644 index aeada03d15cc..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/ModelInfo.java +++ /dev/null @@ -1,378 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.persistence; - -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.commons.model.ModelStatus; -import org.apache.iotdb.commons.model.ModelTable; -import org.apache.iotdb.commons.model.ModelType; -import org.apache.iotdb.commons.snapshot.SnapshotProcessor; -import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; -import org.apache.iotdb.confignode.consensus.response.model.GetModelInfoResp; -import org.apache.iotdb.confignode.consensus.response.model.ModelTableResp; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.thrift.TException; -import org.apache.tsfile.utils.PublicBAOS; -import org.apache.tsfile.utils.ReadWriteIOUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import javax.annotation.concurrent.ThreadSafe; - -import java.io.DataOutputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.FileOutputStream; -import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.concurrent.locks.ReadWriteLock; -import java.util.concurrent.locks.ReentrantReadWriteLock; - -@ThreadSafe -public class ModelInfo implements SnapshotProcessor { - - private static final Logger LOGGER = LoggerFactory.getLogger(ModelInfo.class); - - private static final String SNAPSHOT_FILENAME = "model_info.snapshot"; - - private ModelTable 
modelTable; - - private final Map> modelNameToNodes; - - private final ReadWriteLock modelTableLock = new ReentrantReadWriteLock(); - - private static final Set builtInForecastModel = new HashSet<>(); - - private static final Set builtInAnomalyDetectionModel = new HashSet<>(); - - static { - builtInForecastModel.add("arima"); - builtInForecastModel.add("naive_forecaster"); - builtInForecastModel.add("stl_forecaster"); - builtInForecastModel.add("holtwinters"); - builtInForecastModel.add("exponential_smoothing"); - builtInForecastModel.add("timer_xl"); - builtInForecastModel.add("sundial"); - builtInAnomalyDetectionModel.add("gaussian_hmm"); - builtInAnomalyDetectionModel.add("gmm_hmm"); - builtInAnomalyDetectionModel.add("stray"); - } - - public ModelInfo() { - this.modelTable = new ModelTable(); - this.modelNameToNodes = new HashMap<>(); - } - - public boolean contain(String modelName) { - return modelTable.containsModel(modelName); - } - - public void acquireModelTableReadLock() { - LOGGER.info("acquire ModelTableReadLock"); - modelTableLock.readLock().lock(); - } - - public void releaseModelTableReadLock() { - LOGGER.info("release ModelTableReadLock"); - modelTableLock.readLock().unlock(); - } - - public void acquireModelTableWriteLock() { - LOGGER.info("acquire ModelTableWriteLock"); - modelTableLock.writeLock().lock(); - } - - public void releaseModelTableWriteLock() { - LOGGER.info("release ModelTableWriteLock"); - modelTableLock.writeLock().unlock(); - } - - // init the model in modeInfo, it won't update the details information of the model - public TSStatus createModel(CreateModelPlan plan) { - try { - acquireModelTableWriteLock(); - String modelName = plan.getModelName(); - modelTable.addModel(new ModelInformation(modelName, ModelStatus.LOADING)); - return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } catch (Exception e) { - final String errorMessage = - String.format( - "Failed to add model [%s] in ModelTable on Config Nodes, because of %s", - plan.getModelName(), e); - LOGGER.warn(errorMessage, e); - return new TSStatus(TSStatusCode.CREATE_MODEL_ERROR.getStatusCode()).setMessage(errorMessage); - } finally { - releaseModelTableWriteLock(); - } - } - - public TSStatus dropModelInNode(int aiNodeId) { - acquireModelTableWriteLock(); - try { - for (Map.Entry> entry : modelNameToNodes.entrySet()) { - entry.getValue().remove(Integer.valueOf(aiNodeId)); - // if list is empty, remove this model totally - if (entry.getValue().isEmpty()) { - modelTable.removeModel(entry.getKey()); - modelNameToNodes.remove(entry.getKey()); - } - } - // currently, we only have one AINode at a time, so we can just clear failed model. 
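One hazard worth noting in the dropModelInNode body above: it removes map entries with modelNameToNodes.remove(entry.getKey()) while iterating entrySet(), and since modelNameToNodes is a plain HashMap, that pattern risks a ConcurrentModificationException on the next iteration step. The point is moot once the file is deleted, but for reference the iterator-based form below is the safe idiom (a sketch; the modelTable.removeModel call is elided, and the hunk continues right after it):

    import java.util.Iterator;
    import java.util.List;
    import java.util.Map;

    final class PruneSketch {
      // Removes aiNodeId from every model's node list and drops mappings that
      // become empty; Iterator.remove() keeps the traversal valid, unlike
      // calling modelNameToNodes.remove(key) inside a for-each loop.
      static void dropNode(Map<String, List<Integer>> modelNameToNodes, int aiNodeId) {
        Iterator<Map.Entry<String, List<Integer>>> it =
            modelNameToNodes.entrySet().iterator();
        while (it.hasNext()) {
          Map.Entry<String, List<Integer>> entry = it.next();
          entry.getValue().remove(Integer.valueOf(aiNodeId)); // drop this AINode
          if (entry.getValue().isEmpty()) {
            it.remove(); // safe equivalent of modelNameToNodes.remove(entry.getKey())
          }
        }
      }
    }
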
- modelTable.clearFailedModel(); - return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } finally { - releaseModelTableWriteLock(); - } - } - - public TSStatus dropModel(String modelName) { - acquireModelTableWriteLock(); - TSStatus status; - if (modelTable.containsModel(modelName)) { - modelTable.removeModel(modelName); - modelNameToNodes.remove(modelName); - status = new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } else { - status = - new TSStatus(TSStatusCode.DROP_MODEL_ERROR.getStatusCode()) - .setMessage(String.format("model [%s] has not been created.", modelName)); - } - releaseModelTableWriteLock(); - return status; - } - - public List getNodeIds(String modelName) { - return modelNameToNodes.getOrDefault(modelName, Collections.emptyList()); - } - - private ModelInformation getModelByName(String modelName) { - ModelType modelType = checkModelType(modelName); - if (modelType != ModelType.USER_DEFINED) { - if (modelType == ModelType.BUILT_IN_FORECAST && builtInForecastModel.contains(modelName)) { - return new ModelInformation(ModelType.BUILT_IN_FORECAST, modelName); - } else if (modelType == ModelType.BUILT_IN_ANOMALY_DETECTION - && builtInAnomalyDetectionModel.contains(modelName)) { - return new ModelInformation(ModelType.BUILT_IN_ANOMALY_DETECTION, modelName); - } - } else { - return modelTable.getModelInformationById(modelName); - } - return null; - } - - public ModelTableResp showModel(ShowModelPlan plan) { - acquireModelTableReadLock(); - try { - ModelTableResp modelTableResp = - new ModelTableResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - if (plan.isSetModelName()) { - ModelInformation modelInformation = getModelByName(plan.getModelName()); - if (modelInformation != null) { - modelTableResp.addModelInformation(modelInformation); - } - } else { - modelTableResp.addModelInformation(modelTable.getAllModelInformation()); - for (String modelName : builtInForecastModel) { - modelTableResp.addModelInformation( - new ModelInformation(ModelType.BUILT_IN_FORECAST, modelName)); - } - for (String modelName : builtInAnomalyDetectionModel) { - modelTableResp.addModelInformation( - new ModelInformation(ModelType.BUILT_IN_ANOMALY_DETECTION, modelName)); - } - } - return modelTableResp; - } catch (IOException e) { - LOGGER.warn("Fail to get ModelTable", e); - return new ModelTableResp( - new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } finally { - releaseModelTableReadLock(); - } - } - - private boolean containsBuiltInModelName(Set builtInModelSet, String modelName) { - // ignore the case - for (String builtInModelName : builtInModelSet) { - if (builtInModelName.equalsIgnoreCase(modelName)) { - return true; - } - } - return false; - } - - public ModelType checkModelType(String modelName) { - if (containsBuiltInModelName(builtInForecastModel, modelName)) { - return ModelType.BUILT_IN_FORECAST; - } else if (containsBuiltInModelName(builtInAnomalyDetectionModel, modelName)) { - return ModelType.BUILT_IN_ANOMALY_DETECTION; - } else { - return ModelType.USER_DEFINED; - } - } - - private int getAvailableAINodeForModel(String modelName, ModelType modelType) { - if (modelType == ModelType.USER_DEFINED) { - List aiNodeIds = modelNameToNodes.get(modelName); - if (aiNodeIds != null) { - return aiNodeIds.get(0); - } - } else { - // any AINode is fine for built-in model - // 0 is always the nodeId for configNode, so it's fine to use 0 as special value - return 0; - } - return -1; - } - - // This method will 
be used by dataNode to get schema of the model for inference - public GetModelInfoResp getModelInfo(GetModelInfoPlan plan) { - acquireModelTableReadLock(); - try { - String modelName = plan.getModelId(); - GetModelInfoResp getModelInfoResp; - ModelInformation modelInformation; - ModelType modelType; - // check if it's a built-in model - if ((modelType = checkModelType(modelName)) != ModelType.USER_DEFINED) { - modelInformation = new ModelInformation(modelType, modelName); - } else { - modelInformation = modelTable.getModelInformationById(modelName); - } - - if (modelInformation != null) { - getModelInfoResp = - new GetModelInfoResp(new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode())); - } else { - TSStatus errorStatus = new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); - errorStatus.setMessage(String.format("model [%s] has not been created.", modelName)); - getModelInfoResp = new GetModelInfoResp(errorStatus); - return getModelInfoResp; - } - PublicBAOS buffer = new PublicBAOS(); - DataOutputStream stream = new DataOutputStream(buffer); - modelInformation.serialize(stream); - // select the nodeId to process the task, currently we default use the first one. - int aiNodeId = getAvailableAINodeForModel(modelName, modelType); - if (aiNodeId == -1) { - TSStatus errorStatus = new TSStatus(TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); - errorStatus.setMessage(String.format("There is no AINode with %s available", modelName)); - getModelInfoResp = new GetModelInfoResp(errorStatus); - return getModelInfoResp; - } else { - getModelInfoResp.setTargetAINodeId(aiNodeId); - } - return getModelInfoResp; - } catch (IOException e) { - LOGGER.warn("Fail to get model info", e); - return new GetModelInfoResp( - new TSStatus(TSStatusCode.EXECUTE_STATEMENT_ERROR.getStatusCode()) - .setMessage(e.getMessage())); - } finally { - releaseModelTableReadLock(); - } - } - - public TSStatus updateModelInfo(UpdateModelInfoPlan plan) { - acquireModelTableWriteLock(); - try { - String modelName = plan.getModelName(); - if (modelTable.containsModel(modelName)) { - modelTable.updateModel(modelName, plan.getModelInformation()); - } - if (!plan.getNodeIds().isEmpty()) { - // only used in model registration, so we can just put the nodeIds in the map without - // checking - modelNameToNodes.put(modelName, plan.getNodeIds()); - } - return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } finally { - releaseModelTableWriteLock(); - } - } - - @Override - public boolean processTakeSnapshot(File snapshotDir) throws TException, IOException { - File snapshotFile = new File(snapshotDir, SNAPSHOT_FILENAME); - if (snapshotFile.exists() && snapshotFile.isFile()) { - LOGGER.error( - "Failed to take snapshot of ModelInfo, because snapshot file [{}] is already exist.", - snapshotFile.getAbsolutePath()); - return false; - } - - acquireModelTableReadLock(); - try (FileOutputStream fileOutputStream = new FileOutputStream(snapshotFile)) { - modelTable.serialize(fileOutputStream); - ReadWriteIOUtils.write(modelNameToNodes.size(), fileOutputStream); - for (Map.Entry> entry : modelNameToNodes.entrySet()) { - ReadWriteIOUtils.write(entry.getKey(), fileOutputStream); - ReadWriteIOUtils.write(entry.getValue().size(), fileOutputStream); - for (Integer nodeId : entry.getValue()) { - ReadWriteIOUtils.write(nodeId, fileOutputStream); - } - } - fileOutputStream.getFD().sync(); - return true; - } finally { - releaseModelTableReadLock(); - } - } - - @Override - public void processLoadSnapshot(File snapshotDir) throws 
TException, IOException { - File snapshotFile = new File(snapshotDir, SNAPSHOT_FILENAME); - if (!snapshotFile.exists() || !snapshotFile.isFile()) { - LOGGER.error( - "Failed to load snapshot of ModelInfo, snapshot file [{}] does not exist.", - snapshotFile.getAbsolutePath()); - return; - } - acquireModelTableWriteLock(); - try (FileInputStream fileInputStream = new FileInputStream(snapshotFile)) { - modelTable.clear(); - modelTable = ModelTable.deserialize(fileInputStream); - int size = ReadWriteIOUtils.readInt(fileInputStream); - for (int i = 0; i < size; i++) { - String modelName = ReadWriteIOUtils.readString(fileInputStream); - int nodeSize = ReadWriteIOUtils.readInt(fileInputStream); - List nodes = new LinkedList<>(); - for (int j = 0; j < nodeSize; j++) { - nodes.add(ReadWriteIOUtils.readInt(fileInputStream)); - } - modelNameToNodes.put(modelName, nodes); - } - } finally { - releaseModelTableWriteLock(); - } - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java index fe8b28c4da2e..d6bad518f6f4 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/persistence/executor/ConfigPlanExecutor.java @@ -35,8 +35,6 @@ import org.apache.iotdb.confignode.consensus.request.read.datanode.GetDataNodeConfigurationPlan; import org.apache.iotdb.confignode.consensus.request.read.function.GetFunctionTablePlan; import org.apache.iotdb.confignode.consensus.request.read.function.GetUDFJarPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.GetModelInfoPlan; -import org.apache.iotdb.confignode.consensus.request.read.model.ShowModelPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.CountTimeSlotListPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.GetDataPartitionPlan; import org.apache.iotdb.confignode.consensus.request.read.partition.GetNodePathsPartitionPlan; @@ -84,10 +82,6 @@ import org.apache.iotdb.confignode.consensus.request.write.function.DropTableModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.DropTreeModelFunctionPlan; import org.apache.iotdb.confignode.consensus.request.write.function.UpdateFunctionPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.AddRegionLocationPlan; import org.apache.iotdb.confignode.consensus.request.write.partition.AutoCleanPartitionTablePlan; import org.apache.iotdb.confignode.consensus.request.write.partition.CreateDataPartitionPlan; @@ -150,7 +144,6 @@ import org.apache.iotdb.confignode.exception.physical.UnknownPhysicalPlanTypeException; import org.apache.iotdb.confignode.manager.pipe.agent.PipeConfigNodeAgent; import org.apache.iotdb.confignode.persistence.ClusterInfo; -import org.apache.iotdb.confignode.persistence.ModelInfo; import org.apache.iotdb.confignode.persistence.ProcedureInfo; import org.apache.iotdb.confignode.persistence.TTLInfo; import 
org.apache.iotdb.confignode.persistence.TriggerInfo; @@ -210,8 +203,6 @@ public class ConfigPlanExecutor { private final CQInfo cqInfo; - private final ModelInfo modelInfo; - private final PipeInfo pipeInfo; private final SubscriptionInfo subscriptionInfo; @@ -230,7 +221,6 @@ public ConfigPlanExecutor( UDFInfo udfInfo, TriggerInfo triggerInfo, CQInfo cqInfo, - ModelInfo modelInfo, PipeInfo pipeInfo, SubscriptionInfo subscriptionInfo, QuotaInfo quotaInfo, @@ -262,9 +252,6 @@ public ConfigPlanExecutor( this.cqInfo = cqInfo; this.snapshotProcessorList.add(cqInfo); - this.modelInfo = modelInfo; - this.snapshotProcessorList.add(modelInfo); - this.pipeInfo = pipeInfo; this.snapshotProcessorList.add(pipeInfo); @@ -362,10 +349,6 @@ public DataSet executeQueryPlan(final ConfigPhysicalReadPlan req) return udfInfo.getUDFJar((GetUDFJarPlan) req); case GetAllFunctionTable: return udfInfo.getAllUDFTable(); - case ShowModel: - return modelInfo.showModel((ShowModelPlan) req); - case GetModelInfo: - return modelInfo.getModelInfo((GetModelInfoPlan) req); case GetPipePluginTable: return pipeInfo.getPipePluginInfo().showPipePlugins(); case GetPipePluginJar: @@ -648,14 +631,6 @@ public TSStatus executeNonQueryPlan(ConfigPhysicalPlan physicalPlan) return cqInfo.activeCQ((ActiveCQPlan) physicalPlan); case UPDATE_CQ_LAST_EXEC_TIME: return cqInfo.updateCQLastExecutionTime((UpdateCQLastExecTimePlan) physicalPlan); - case CreateModel: - return modelInfo.createModel((CreateModelPlan) physicalPlan); - case UpdateModelInfo: - return modelInfo.updateModelInfo((UpdateModelInfoPlan) physicalPlan); - case DropModel: - return modelInfo.dropModel(((DropModelPlan) physicalPlan).getModelName()); - case DropModelInNode: - return modelInfo.dropModelInNode(((DropModelInNodePlan) physicalPlan).getNodeId()); case CreatePipePlugin: return pipeInfo.getPipePluginInfo().createPipePlugin((CreatePipePluginPlan) physicalPlan); case DropPipePlugin: diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java deleted file mode 100644 index 989061610213..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/CreateModelProcedure.java +++ /dev/null @@ -1,250 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
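The ModelInfo snapshot pair removed above (processTakeSnapshot and processLoadSnapshot) follows the usual SnapshotProcessor shape: write an entry count, then each model name with its node list, and mirror the exact order on load. A self-contained round-trip sketch, assuming plain DataOutputStream/DataInputStream as stand-ins for ReadWriteIOUtils and an in-memory buffer instead of the snapshot file (the deleted CreateModelProcedure listing resumes below):

    import java.io.ByteArrayInputStream;
    import java.io.ByteArrayOutputStream;
    import java.io.DataInputStream;
    import java.io.DataOutputStream;
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;

    final class SnapshotRoundTrip {
      static byte[] write(Map<String, List<Integer>> m) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        DataOutputStream out = new DataOutputStream(buffer);
        out.writeInt(m.size()); // entry count first, mirroring processTakeSnapshot
        for (Map.Entry<String, List<Integer>> e : m.entrySet()) {
          out.writeUTF(e.getKey());
          out.writeInt(e.getValue().size());
          for (int nodeId : e.getValue()) {
            out.writeInt(nodeId);
          }
        }
        return buffer.toByteArray();
      }

      static Map<String, List<Integer>> read(byte[] bytes) throws IOException {
        DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes));
        Map<String, List<Integer>> m = new HashMap<>();
        int size = in.readInt(); // read back in the same order as written
        for (int i = 0; i < size; i++) {
          String model = in.readUTF();
          int nodeCount = in.readInt();
          List<Integer> nodes = new ArrayList<>();
          for (int j = 0; j < nodeCount; j++) {
            nodes.add(in.readInt());
          }
          m.put(model, nodes);
        }
        return m;
      }

      public static void main(String[] args) throws IOException {
        Map<String, List<Integer>> m = new HashMap<>();
        m.put("timer_xl", new ArrayList<>(Arrays.asList(3)));
        System.out.println(read(write(m))); // {timer_xl=[3]}
      }
    }
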
- */ - -package org.apache.iotdb.confignode.procedure.impl.model; - -import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.exception.ainode.LoadModelException; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.commons.model.ModelStatus; -import org.apache.iotdb.commons.model.exception.ModelManagementException; -import org.apache.iotdb.confignode.consensus.request.write.model.CreateModelPlan; -import org.apache.iotdb.confignode.consensus.request.write.model.UpdateModelInfoPlan; -import org.apache.iotdb.confignode.manager.ConfigManager; -import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; -import org.apache.iotdb.confignode.procedure.exception.ProcedureException; -import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure; -import org.apache.iotdb.confignode.procedure.state.model.CreateModelState; -import org.apache.iotdb.confignode.procedure.store.ProcedureType; -import org.apache.iotdb.consensus.exception.ConsensusException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.tsfile.utils.ReadWriteIOUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; - -public class CreateModelProcedure extends AbstractNodeProcedure { - - private static final Logger LOGGER = LoggerFactory.getLogger(CreateModelProcedure.class); - private static final int RETRY_THRESHOLD = 0; - - private String modelName; - - private String uri; - - private ModelInformation modelInformation = null; - - private List aiNodeIds; - - private String loadErrorMsg = ""; - - public CreateModelProcedure() { - super(); - } - - public CreateModelProcedure(String modelName, String uri) { - super(); - this.modelName = modelName; - this.uri = uri; - this.aiNodeIds = new ArrayList<>(); - } - - @Override - protected Flow executeFromState(ConfigNodeProcedureEnv env, CreateModelState state) { - if (modelName == null || uri == null) { - return Flow.NO_MORE_STATE; - } - try { - switch (state) { - case LOADING: - initModel(env); - loadModel(env); - setNextState(CreateModelState.ACTIVE); - break; - case ACTIVE: - modelInformation.updateStatus(ModelStatus.ACTIVE); - updateModel(env); - return Flow.NO_MORE_STATE; - default: - throw new UnsupportedOperationException( - String.format("Unknown state during executing createModelProcedure, %s", state)); - } - } catch (Exception e) { - if (isRollbackSupported(state)) { - LOGGER.error("Fail in CreateModelProcedure", e); - setFailure(new ProcedureException(e.getMessage())); - } else { - LOGGER.error( - "Retrievable error trying to create model [{}], state [{}]", modelName, state, e); - if (getCycles() > RETRY_THRESHOLD) { - modelInformation = new ModelInformation(modelName, ModelStatus.UNAVAILABLE); - modelInformation.setAttribute(loadErrorMsg); - updateModel(env); - setFailure( - new ProcedureException( - String.format("Fail to create model [%s] at STATE [%s]", modelName, state))); - } - } - } - return Flow.HAS_MORE_STATE; - } - - private void initModel(ConfigNodeProcedureEnv env) throws ConsensusException { - LOGGER.info("Start to add model [{}]", modelName); - - ConfigManager configManager = 
env.getConfigManager(); - TSStatus response = configManager.getConsensusManager().write(new CreateModelPlan(modelName)); - if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new ModelManagementException( - String.format( - "Failed to add model [%s] in ModelTable on Config Nodes: %s", - modelName, response.getMessage())); - } - } - - private void checkModelInformationEquals(ModelInformation receiveModelInfo) { - if (modelInformation == null) { - modelInformation = receiveModelInfo; - } else { - if (!modelInformation.equals(receiveModelInfo)) { - throw new ModelManagementException( - String.format( - "Failed to load model [%s] on AI Nodes, model information is not equal in different nodes", - modelName)); - } - } - } - - private void loadModel(ConfigNodeProcedureEnv env) { - for (TAINodeConfiguration curNodeConfig : - env.getConfigManager().getNodeManager().getRegisteredAINodes()) { - try (AINodeClient client = - AINodeClientManager.getInstance() - .borrowClient(curNodeConfig.getLocation().getInternalEndPoint())) { - ModelInformation resp = client.registerModel(modelName, uri); - checkModelInformationEquals(resp); - aiNodeIds.add(curNodeConfig.getLocation().aiNodeId); - } catch (LoadModelException e) { - LOGGER.warn(e.getMessage()); - loadErrorMsg = e.getMessage(); - } catch (Exception e) { - LOGGER.warn( - "Failed to load model on AINode {} from ConfigNode", - curNodeConfig.getLocation().getInternalEndPoint()); - loadErrorMsg = e.getMessage(); - } - } - - if (aiNodeIds.isEmpty()) { - throw new ModelManagementException( - String.format("CREATE MODEL [%s] failed on all AINodes:[%s]", modelName, loadErrorMsg)); - } - } - - private void updateModel(ConfigNodeProcedureEnv env) { - LOGGER.info("Start to update model [{}]", modelName); - - ConfigManager configManager = env.getConfigManager(); - try { - TSStatus response = - configManager - .getConsensusManager() - .write(new UpdateModelInfoPlan(modelName, modelInformation, aiNodeIds)); - if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new ModelManagementException( - String.format( - "Failed to update model [%s] in ModelTable on Config Nodes: %s", - modelName, response.getMessage())); - } - } catch (Exception e) { - throw new ModelManagementException( - String.format( - "Failed to update model [%s] in ModelTable on Config Nodes: %s", - modelName, e.getMessage())); - } - } - - @Override - protected void rollbackState(ConfigNodeProcedureEnv env, CreateModelState state) - throws IOException, InterruptedException, ProcedureException { - // do nothing - } - - @Override - protected boolean isRollbackSupported(CreateModelState state) { - return false; - } - - @Override - protected CreateModelState getState(int stateId) { - return CreateModelState.values()[stateId]; - } - - @Override - protected int getStateId(CreateModelState createModelState) { - return createModelState.ordinal(); - } - - @Override - protected CreateModelState getInitialState() { - return CreateModelState.LOADING; - } - - @Override - public void serialize(DataOutputStream stream) throws IOException { - stream.writeShort(ProcedureType.CREATE_MODEL_PROCEDURE.getTypeCode()); - super.serialize(stream); - ReadWriteIOUtils.write(modelName, stream); - ReadWriteIOUtils.write(uri, stream); - } - - @Override - public void deserialize(ByteBuffer byteBuffer) { - super.deserialize(byteBuffer); - modelName = ReadWriteIOUtils.readString(byteBuffer); - uri = ReadWriteIOUtils.readString(byteBuffer); - } - - @Override - public boolean 
equals(Object that) { - if (that instanceof CreateModelProcedure) { - CreateModelProcedure thatProc = (CreateModelProcedure) that; - return thatProc.getProcId() == this.getProcId() - && thatProc.getState() == this.getState() - && Objects.equals(thatProc.modelName, this.modelName) - && Objects.equals(thatProc.uri, this.uri); - } - return false; - } - - @Override - public int hashCode() { - return Objects.hash(getProcId(), getState(), modelName, uri); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java deleted file mode 100644 index daa029e04ddf..000000000000 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/model/DropModelProcedure.java +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.confignode.procedure.impl.model; - -import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; -import org.apache.iotdb.common.rpc.thrift.TAINodeConfiguration; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.model.exception.ModelManagementException; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelPlan; -import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; -import org.apache.iotdb.confignode.procedure.exception.ProcedureException; -import org.apache.iotdb.confignode.procedure.impl.node.AbstractNodeProcedure; -import org.apache.iotdb.confignode.procedure.state.model.DropModelState; -import org.apache.iotdb.confignode.procedure.store.ProcedureType; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.thrift.TException; -import org.apache.tsfile.utils.ReadWriteIOUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.util.List; -import java.util.Objects; - -import static org.apache.iotdb.confignode.procedure.state.model.DropModelState.CONFIG_NODE_DROPPED; - -public class DropModelProcedure extends AbstractNodeProcedure { - - private static final Logger LOGGER = LoggerFactory.getLogger(DropModelProcedure.class); - private static final int RETRY_THRESHOLD = 1; - - private String modelName; - - public DropModelProcedure() { - super(); - } - - public DropModelProcedure(String modelName) { - super(); - this.modelName = modelName; - } - - @Override - protected Flow executeFromState(ConfigNodeProcedureEnv env, DropModelState state) { - if 
(modelName == null) { - return Flow.NO_MORE_STATE; - } - try { - switch (state) { - case AI_NODE_DROPPED: - LOGGER.info("Start to drop model [{}] on AI Nodes", modelName); - dropModelOnAINode(env); - setNextState(CONFIG_NODE_DROPPED); - break; - case CONFIG_NODE_DROPPED: - dropModelOnConfigNode(env); - return Flow.NO_MORE_STATE; - default: - throw new UnsupportedOperationException( - String.format("Unknown state during executing dropModelProcedure, %s", state)); - } - } catch (Exception e) { - if (isRollbackSupported(state)) { - LOGGER.error("Fail in DropModelProcedure", e); - setFailure(new ProcedureException(e.getMessage())); - } else { - LOGGER.error( - "Retrievable error trying to drop model [{}], state [{}]", modelName, state, e); - if (getCycles() > RETRY_THRESHOLD) { - setFailure( - new ProcedureException( - String.format( - "Fail to drop model [%s] at STATE [%s], %s", - modelName, state, e.getMessage()))); - } - } - } - return Flow.HAS_MORE_STATE; - } - - private void dropModelOnAINode(ConfigNodeProcedureEnv env) { - LOGGER.info("Start to drop model file [{}] on AI Node", modelName); - - List aiNodes = - env.getConfigManager().getNodeManager().getRegisteredAINodes(); - aiNodes.forEach( - aiNode -> { - int nodeId = aiNode.getLocation().getAiNodeId(); - try (AINodeClient client = - AINodeClientManager.getInstance() - .borrowClient( - env.getConfigManager() - .getNodeManager() - .getRegisteredAINode(nodeId) - .getLocation() - .getInternalEndPoint())) { - TSStatus status = client.deleteModel(new TDeleteModelReq(modelName)); - if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - LOGGER.warn( - "Failed to drop model [{}] on AINode [{}], status: {}", - modelName, - nodeId, - status.getMessage()); - } - } catch (Exception e) { - LOGGER.warn( - "Failed to drop model [{}] on AINode [{}], status: {}", - modelName, - nodeId, - e.getMessage()); - } - }); - } - - private void dropModelOnConfigNode(ConfigNodeProcedureEnv env) { - try { - TSStatus response = - env.getConfigManager().getConsensusManager().write(new DropModelPlan(modelName)); - if (response.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new TException(response.getMessage()); - } - } catch (Exception e) { - throw new ModelManagementException( - String.format( - "Fail to start training model [%s] on AI Node: %s", modelName, e.getMessage())); - } - } - - @Override - protected void rollbackState(ConfigNodeProcedureEnv env, DropModelState state) - throws IOException, InterruptedException, ProcedureException { - // no need to rollback - } - - @Override - protected DropModelState getState(int stateId) { - return DropModelState.values()[stateId]; - } - - @Override - protected int getStateId(DropModelState dropModelState) { - return dropModelState.ordinal(); - } - - @Override - protected DropModelState getInitialState() { - return DropModelState.AI_NODE_DROPPED; - } - - @Override - public void serialize(DataOutputStream stream) throws IOException { - stream.writeShort(ProcedureType.DROP_MODEL_PROCEDURE.getTypeCode()); - super.serialize(stream); - ReadWriteIOUtils.write(modelName, stream); - } - - @Override - public void deserialize(ByteBuffer byteBuffer) { - super.deserialize(byteBuffer); - modelName = ReadWriteIOUtils.readString(byteBuffer); - } - - @Override - public boolean equals(Object that) { - if (that instanceof DropModelProcedure) { - DropModelProcedure thatProc = (DropModelProcedure) that; - return thatProc.getProcId() == this.getProcId() - && thatProc.getState() == this.getState() - && 
(thatProc.modelName).equals(this.modelName); - } - return false; - } - - @Override - public int hashCode() { - return Objects.hash(getProcId(), getState(), modelName); - } -} diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java index 2cab08c28244..2a1c6881b141 100644 --- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java +++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/impl/node/RemoveAINodeProcedure.java @@ -23,13 +23,12 @@ import org.apache.iotdb.common.rpc.thrift.TSStatus; import org.apache.iotdb.commons.utils.ThriftCommonsSerDeUtils; import org.apache.iotdb.confignode.consensus.request.write.ainode.RemoveAINodePlan; -import org.apache.iotdb.confignode.consensus.request.write.model.DropModelInNodePlan; import org.apache.iotdb.confignode.procedure.env.ConfigNodeProcedureEnv; import org.apache.iotdb.confignode.procedure.exception.ProcedureException; import org.apache.iotdb.confignode.procedure.state.RemoveAINodeState; import org.apache.iotdb.confignode.procedure.store.ProcedureType; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.rpc.TSStatusCode; import org.slf4j.Logger; @@ -65,17 +64,11 @@ protected Flow executeFromState(ConfigNodeProcedureEnv env, RemoveAINodeState st try { switch (state) { - case MODEL_DELETE: - env.getConfigManager() - .getConsensusManager() - .write(new DropModelInNodePlan(removedAINode.aiNodeId)); - // Cause the AINode is removed, so we don't need to remove the model file. 
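Both deleted model procedures, and the RemoveAINodeProcedure hunk that resumes just below, rely on the same framework contract: executeFromState performs one state's work, schedules the next state with setNextState, and returns Flow.HAS_MORE_STATE until the final state answers Flow.NO_MORE_STATE. A compact sketch of that contract, assuming simplified stand-ins for the framework's Flow and state bookkeeping (the real base class, AbstractNodeProcedure, also persists state between steps for crash recovery):

    final class TwoStateProcedureSketch {
      enum State { FIRST, DONE }
      enum Flow { HAS_MORE_STATE, NO_MORE_STATE }

      private State next = State.FIRST;

      Flow executeFromState(State state) {
        switch (state) {
          case FIRST:
            // ... do the first step's work ...
            next = State.DONE; // setNextState(...) in the real framework
            return Flow.HAS_MORE_STATE;
          case DONE:
            // ... final step ...
            return Flow.NO_MORE_STATE;
          default:
            throw new UnsupportedOperationException("Unknown state " + state);
        }
      }

      void run() {
        // The real executor persists the procedure after each step so a
        // restarted node can resume from the recorded state.
        while (executeFromState(next) == Flow.HAS_MORE_STATE) {
          // loop until the terminal state reports NO_MORE_STATE
        }
      }
    }
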
-          setNextState(RemoveAINodeState.NODE_STOP);
-          break;
         case NODE_STOP:
           TSStatus resp = null;
           try (AINodeClient client =
-              AINodeClientManager.getInstance().borrowClient(removedAINode.getInternalEndPoint())) {
+              AINodeClientManager.getInstance()
+                  .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) {
             resp = client.stopAINode();
           } catch (Exception e) {
             LOGGER.warn(
@@ -148,7 +141,7 @@ protected int getStateId(RemoveAINodeState removeAINodeState) {
 
   @Override
   protected RemoveAINodeState getInitialState() {
-    return RemoveAINodeState.MODEL_DELETE;
+    return RemoveAINodeState.NODE_STOP;
   }
 
   @Override
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java
index 8a1a6a1bb03b..49820df66361 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/state/RemoveAINodeState.java
@@ -20,7 +20,6 @@
 package org.apache.iotdb.confignode.procedure.state;
 
 public enum RemoveAINodeState {
-  MODEL_DELETE,
   NODE_STOP,
   NODE_REMOVE
 }
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java
index e023171f4fa8..f20a6999d593 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureFactory.java
@@ -22,8 +22,6 @@
 import org.apache.iotdb.commons.exception.runtime.ThriftSerDeException;
 import org.apache.iotdb.confignode.procedure.Procedure;
 import org.apache.iotdb.confignode.procedure.impl.cq.CreateCQProcedure;
-import org.apache.iotdb.confignode.procedure.impl.model.CreateModelProcedure;
-import org.apache.iotdb.confignode.procedure.impl.model.DropModelProcedure;
 import org.apache.iotdb.confignode.procedure.impl.node.AddConfigNodeProcedure;
 import org.apache.iotdb.confignode.procedure.impl.node.RemoveAINodeProcedure;
 import org.apache.iotdb.confignode.procedure.impl.node.RemoveConfigNodeProcedure;
@@ -263,12 +261,6 @@ public Procedure create(ByteBuffer buffer) throws IOException {
       case DROP_PIPE_PLUGIN_PROCEDURE:
         procedure = new DropPipePluginProcedure();
         break;
-      case CREATE_MODEL_PROCEDURE:
-        procedure = new CreateModelProcedure();
-        break;
-      case DROP_MODEL_PROCEDURE:
-        procedure = new DropModelProcedure();
-        break;
       case AUTH_OPERATE_PROCEDURE:
         procedure = new AuthOperationProcedure(false);
         break;
@@ -494,10 +486,6 @@ public static ProcedureType getProcedureType(final Procedure procedure) {
       return ProcedureType.CREATE_PIPE_PLUGIN_PROCEDURE;
     } else if (procedure instanceof DropPipePluginProcedure) {
       return ProcedureType.DROP_PIPE_PLUGIN_PROCEDURE;
-    } else if (procedure instanceof CreateModelProcedure) {
-      return ProcedureType.CREATE_MODEL_PROCEDURE;
-    } else if (procedure instanceof DropModelProcedure) {
-      return ProcedureType.DROP_MODEL_PROCEDURE;
     } else if (procedure instanceof CreatePipeProcedureV2) {
       return ProcedureType.CREATE_PIPE_PROCEDURE_V2;
     } else if (procedure instanceof StartPipeProcedureV2) {
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java
index 65ac1fb24ad5..d076a7d9d926 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/procedure/store/ProcedureType.java
@@ -85,7 +85,9 @@ public enum ProcedureType {
   RENAME_VIEW_PROCEDURE((short) 764),
 
   /** AI Model */
+  @Deprecated // Since 2.0.6, all models are managed by AINode
   CREATE_MODEL_PROCEDURE((short) 800),
+  @Deprecated // Since 2.0.6, all models are managed by AINode
   DROP_MODEL_PROCEDURE((short) 801),
   REMOVE_AI_NODE_PROCEDURE((short) 802),
 
diff --git a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
index 59ce7352312f..6582a5bfff8e 100644
--- a/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
+++ b/iotdb-core/confignode/src/main/java/org/apache/iotdb/confignode/service/thrift/ConfigNodeRPCServiceProcessor.java
@@ -115,7 +115,6 @@
 import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq;
 import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq;
 import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq;
-import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq;
 import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq;
 import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq;
 import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq;
@@ -144,7 +143,6 @@
 import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp;
 import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq;
 import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq;
-import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq;
 import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq;
 import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq;
 import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq;
@@ -163,8 +161,6 @@
 import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq;
 import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp;
 import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq;
-import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp;
 import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq;
 import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp;
 import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp;
@@ -226,7 +222,6 @@
 import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp;
 import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq;
 import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq;
-import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq;
 import org.apache.iotdb.confignode.service.ConfigNode;
 import org.apache.iotdb.consensus.exception.ConsensusException;
 import org.apache.iotdb.db.queryengine.plan.relational.type.AuthorRType;
@@ -1362,26 +1357,6 @@ public TShowCQResp showCQ() {
     return configManager.showCQ();
   }
 
-  @Override
-  public TSStatus createModel(TCreateModelReq req) {
-    return configManager.createModel(req);
-  }
-
-  @Override
-  public TSStatus dropModel(TDropModelReq req) {
-    return configManager.dropModel(req);
-  }
-
-  @Override
-  public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) {
-    return configManager.getModelInfo(req);
-  }
-
-  @Override
-  public TSStatus updateModelInfo(TUpdateModelInfoReq req) throws TException {
-    return configManager.updateModelInfo(req);
-  }
-
   @Override
   public TSStatus setSpaceQuota(final TSetSpaceQuotaReq req) throws TException {
     return configManager.setSpaceQuota(req);
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java
deleted file mode 100644
index 0d784617c090..000000000000
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/AINodeClientFactory.java
+++ /dev/null
@@ -1,133 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package org.apache.iotdb.db.protocol.client;
-
-import org.apache.iotdb.common.rpc.thrift.TEndPoint;
-import org.apache.iotdb.commons.client.ClientManager;
-import org.apache.iotdb.commons.client.ClientManagerMetrics;
-import org.apache.iotdb.commons.client.IClientPoolFactory;
-import org.apache.iotdb.commons.client.factory.ThriftClientFactory;
-import org.apache.iotdb.commons.client.property.ClientPoolProperty;
-import org.apache.iotdb.commons.client.property.ThriftClientProperty;
-import org.apache.iotdb.commons.concurrent.ThreadName;
-import org.apache.iotdb.commons.conf.CommonConfig;
-import org.apache.iotdb.commons.conf.CommonDescriptor;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClient;
-import org.apache.iotdb.db.protocol.client.ainode.AsyncAINodeServiceClient;
-
-import org.apache.commons.pool2.PooledObject;
-import org.apache.commons.pool2.impl.DefaultPooledObject;
-import org.apache.commons.pool2.impl.GenericKeyedObjectPool;
-
-import java.util.Optional;
-
-/** Dedicated factory for AINodeClient + AINodeClientPoolFactory.
*/ -public class AINodeClientFactory extends ThriftClientFactory { - - private static final int connectionTimeout = - CommonDescriptor.getInstance().getConfig().getDnConnectionTimeoutInMS(); - - public AINodeClientFactory( - ClientManager manager, ThriftClientProperty thriftProperty) { - super(manager, thriftProperty); - } - - @Override - public PooledObject makeObject(TEndPoint endPoint) throws Exception { - return new DefaultPooledObject<>( - new AINodeClient(thriftClientProperty, endPoint, clientManager)); - } - - @Override - public void destroyObject(TEndPoint key, PooledObject pooled) throws Exception { - pooled.getObject().invalidate(); - } - - @Override - public boolean validateObject(TEndPoint key, PooledObject pooledObject) { - return Optional.ofNullable(pooledObject.getObject().getTransport()) - .map(org.apache.thrift.transport.TTransport::isOpen) - .orElse(false); - } - - /** The PoolFactory originally inside ClientPoolFactory — now moved here. */ - public static class AINodeClientPoolFactory - implements IClientPoolFactory { - - @Override - public GenericKeyedObjectPool createClientPool( - ClientManager manager) { - - // Build thrift client properties - ThriftClientProperty thriftProperty = - new ThriftClientProperty.Builder() - .setConnectionTimeoutMs(connectionTimeout) - .setRpcThriftCompressionEnabled( - CommonDescriptor.getInstance().getConfig().isRpcThriftCompressionEnabled()) - .build(); - - GenericKeyedObjectPool pool = - new GenericKeyedObjectPool<>( - new AINodeClientFactory(manager, thriftProperty), - new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode( - CommonDescriptor.getInstance().getConfig().getMaxClientNumForEachNode()) - .build() - .getConfig()); - - ClientManagerMetrics.getInstance() - .registerClientManager(this.getClass().getSimpleName(), pool); - - return pool; - } - } - - public static class AINodeHeartbeatClientPoolFactory - implements IClientPoolFactory { - - @Override - public GenericKeyedObjectPool createClientPool( - ClientManager manager) { - - final CommonConfig conf = CommonDescriptor.getInstance().getConfig(); - - GenericKeyedObjectPool clientPool = - new GenericKeyedObjectPool<>( - new AsyncAINodeServiceClient.Factory( - manager, - new ThriftClientProperty.Builder() - .setConnectionTimeoutMs(conf.getCnConnectionTimeoutInMS()) - .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnabled()) - .setSelectorNumOfAsyncClientManager(conf.getSelectorNumOfClientManager()) - .setPrintLogWhenEncounterException(false) - .build(), - ThreadName.ASYNC_DATANODE_HEARTBEAT_CLIENT_POOL.getName()), - new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) - .build() - .getConfig()); - - ClientManagerMetrics.getInstance() - .registerClientManager(this.getClass().getSimpleName(), clientPool); - - return clientPool; - } - } -} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java index 2c037cf0f3e5..df80d49b502b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ConfigNodeClient.java @@ -73,7 +73,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateConsumerReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import 
org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateSchemaTemplateReq; @@ -102,7 +101,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -121,8 +119,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListReq; import org.apache.iotdb.confignode.rpc.thrift.TGetJarInListResp; import org.apache.iotdb.confignode.rpc.thrift.TGetLocationForTriggerResp; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesReq; import org.apache.iotdb.confignode.rpc.thrift.TGetPathsSetTemplatesResp; import org.apache.iotdb.confignode.rpc.thrift.TGetPipePluginTableResp; @@ -184,7 +180,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TThrottleQuotaResp; import org.apache.iotdb.confignode.rpc.thrift.TUnsetSchemaTemplateReq; import org.apache.iotdb.confignode.rpc.thrift.TUnsubscribeReq; -import org.apache.iotdb.confignode.rpc.thrift.TUpdateModelInfoReq; import org.apache.iotdb.db.conf.IoTDBConfig; import org.apache.iotdb.db.conf.IoTDBDescriptor; import org.apache.iotdb.rpc.DeepCopyRpcTransportFactory; @@ -525,7 +520,8 @@ public TAINodeRestartResp restartAINode(TAINodeRestartReq req) throws TException @Override public TGetAINodeLocationResp getAINodeLocation() throws TException { - return client.getAINodeLocation(); + return executeRemoteCallWithRetry( + () -> client.getAINodeLocation(), resp -> !updateConfigNodeLeader(resp.status)); } @Override @@ -1339,28 +1335,6 @@ public TShowCQResp showCQ() throws TException { () -> client.showCQ(), resp -> !updateConfigNodeLeader(resp.status)); } - @Override - public TSStatus createModel(TCreateModelReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.createModel(req), status -> !updateConfigNodeLeader(status)); - } - - @Override - public TSStatus dropModel(TDropModelReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.dropModel(req), status -> !updateConfigNodeLeader(status)); - } - - public TGetModelInfoResp getModelInfo(TGetModelInfoReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.getModelInfo(req), resp -> !updateConfigNodeLeader(resp.getStatus())); - } - - public TSStatus updateModelInfo(TUpdateModelInfoReq req) throws TException { - return executeRemoteCallWithRetry( - () -> client.updateModelInfo(req), status -> !updateConfigNodeLeader(status)); - } - @Override public TSStatus setSpaceQuota(TSetSpaceQuotaReq req) throws TException { return executeRemoteCallWithRetry( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java index b5f5df430129..da0d84d8466f 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/DataNodeClientPoolFactory.java @@ -27,12 +27,13 @@ import org.apache.iotdb.commons.consensus.ConfigRegionId; import org.apache.iotdb.db.conf.IoTDBConfig; import org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; import org.apache.commons.pool2.impl.GenericKeyedObjectPool; public class DataNodeClientPoolFactory { - private static final IoTDBConfig conf = IoTDBDescriptor.getInstance().getConfig(); + private static final IoTDBConfig CONF = IoTDBDescriptor.getInstance().getConfig(); private DataNodeClientPoolFactory() { // Empty constructor @@ -49,11 +50,11 @@ public GenericKeyedObjectPool createClientPool new ConfigNodeClient.Factory( manager, new ThriftClientProperty.Builder() - .setConnectionTimeoutMs(conf.getConnectionTimeoutInMS()) - .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnable()) + .setConnectionTimeoutMs(CONF.getConnectionTimeoutInMS()) + .setRpcThriftCompressionEnabled(CONF.isRpcThriftCompressionEnable()) .build()), new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) + .setMaxClientNumForEachNode(CONF.getMaxClientNumForEachNode()) .build() .getConfig()); ClientManagerMetrics.getInstance() @@ -73,15 +74,38 @@ public GenericKeyedObjectPool createClientPool new ConfigNodeClient.Factory( manager, new ThriftClientProperty.Builder() - .setConnectionTimeoutMs(conf.getConnectionTimeoutInMS() * 10) - .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnable()) + .setConnectionTimeoutMs(CONF.getConnectionTimeoutInMS() * 10) + .setRpcThriftCompressionEnabled(CONF.isRpcThriftCompressionEnable()) .setSelectorNumOfAsyncClientManager( - conf.getSelectorNumOfClientManager() / 10 > 0 - ? conf.getSelectorNumOfClientManager() / 10 + CONF.getSelectorNumOfClientManager() / 10 > 0 + ? CONF.getSelectorNumOfClientManager() / 10 : 1) .build()), new ClientPoolProperty.Builder() - .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) + .setMaxClientNumForEachNode(CONF.getMaxClientNumForEachNode()) + .build() + .getConfig()); + ClientManagerMetrics.getInstance() + .registerClientManager(this.getClass().getSimpleName(), clientPool); + return clientPool; + } + } + + public static class AINodeClientPoolFactory implements IClientPoolFactory { + + @Override + public GenericKeyedObjectPool createClientPool( + ClientManager manager) { + GenericKeyedObjectPool clientPool = + new GenericKeyedObjectPool<>( + new AINodeClient.Factory( + manager, + new ThriftClientProperty.Builder() + .setConnectionTimeoutMs(CONF.getConnectionTimeoutInMS()) + .setRpcThriftCompressionEnabled(CONF.isRpcThriftCompressionEnable()) + .build()), + new ClientPoolProperty.Builder() + .setMaxClientNumForEachNode(CONF.getMaxClientNumForEachNode()) .build() .getConfig()); ClientManagerMetrics.getInstance() diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java deleted file mode 100644 index 54150b8f3007..000000000000 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClient.java +++ /dev/null @@ -1,401 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.db.protocol.client.ainode; - -import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService; -import org.apache.iotdb.ainode.rpc.thrift.TConfigs; -import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; -import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; -import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq; -import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; -import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; -import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; -import org.apache.iotdb.ainode.rpc.thrift.TWindowParams; -import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.client.ClientManager; -import org.apache.iotdb.commons.client.IClientManager; -import org.apache.iotdb.commons.client.ThriftClient; -import org.apache.iotdb.commons.client.factory.ThriftClientFactory; -import org.apache.iotdb.commons.client.property.ThriftClientProperty; -import org.apache.iotdb.commons.conf.CommonConfig; -import org.apache.iotdb.commons.conf.CommonDescriptor; -import org.apache.iotdb.commons.consensus.ConfigRegionId; -import org.apache.iotdb.commons.exception.ainode.LoadModelException; -import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.confignode.rpc.thrift.TGetAINodeLocationResp; -import org.apache.iotdb.db.protocol.client.ConfigNodeClient; -import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; -import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; -import org.apache.iotdb.rpc.TConfigurationConst; -import org.apache.iotdb.rpc.TSStatusCode; - -import org.apache.commons.pool2.PooledObject; -import org.apache.commons.pool2.impl.DefaultPooledObject; -import org.apache.thrift.TException; -import org.apache.thrift.transport.TSSLTransportFactory; -import org.apache.thrift.transport.TSocket; -import org.apache.thrift.transport.TTransport; -import org.apache.thrift.transport.TTransportException; -import org.apache.thrift.transport.layered.TFramedTransport; -import org.apache.tsfile.enums.TSDataType; -import org.apache.tsfile.read.common.block.TsBlock; -import org.apache.tsfile.read.common.block.column.TsBlockSerde; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.util.Map; -import 
java.util.Optional; -import java.util.concurrent.atomic.AtomicReference; - -import static org.apache.iotdb.rpc.TSStatusCode.CAN_NOT_CONNECT_AINODE; -import static org.apache.iotdb.rpc.TSStatusCode.INTERNAL_SERVER_ERROR; - -public class AINodeClient implements AutoCloseable, ThriftClient { - - private static final Logger logger = LoggerFactory.getLogger(AINodeClient.class); - - private static final CommonConfig commonConfig = CommonDescriptor.getInstance().getConfig(); - - private TEndPoint endPoint; - - private TTransport transport; - - private final ThriftClientProperty property; - private IAINodeRPCService.Client client; - - public static final String MSG_CONNECTION_FAIL = - "Fail to connect to AINode. Please check status of AINode"; - private static final int MAX_RETRY = 3; - - @FunctionalInterface - private interface RemoteCall { - R apply(IAINodeRPCService.Client c) throws TException; - } - - private final TsBlockSerde tsBlockSerde = new TsBlockSerde(); - - ClientManager clientManager; - - private static final IClientManager CONFIG_NODE_CLIENT_MANAGER = - ConfigNodeClientManager.getInstance(); - - private static final AtomicReference CURRENT_LOCATION = new AtomicReference<>(); - - public static TEndPoint getCurrentEndpoint() { - TAINodeLocation loc = CURRENT_LOCATION.get(); - if (loc == null) { - loc = refreshFromConfigNode(); - } - return (loc == null) ? null : pickEndpointFrom(loc); - } - - public static void updateGlobalAINodeLocation(final TAINodeLocation loc) { - if (loc != null) { - CURRENT_LOCATION.set(loc); - } - } - - private R executeRemoteCallWithRetry(RemoteCall call) throws TException { - TException last = null; - for (int attempt = 1; attempt <= MAX_RETRY; attempt++) { - try { - if (transport == null || !transport.isOpen()) { - final TEndPoint ep = getCurrentEndpoint(); - if (ep == null) { - throw new TException("AINode endpoint unavailable"); - } - this.endPoint = ep; - init(); - } - return call.apply(client); - } catch (TException e) { - last = e; - invalidate(); - final TAINodeLocation loc = refreshFromConfigNode(); - if (loc != null) { - this.endPoint = pickEndpointFrom(loc); - } - try { - Thread.sleep(1000L * attempt); - } catch (InterruptedException ie) { - Thread.currentThread().interrupt(); - } - } - } - throw (last != null ? last : new TException(MSG_CONNECTION_FAIL)); - } - - private static TAINodeLocation refreshFromConfigNode() { - try (final ConfigNodeClient cn = - CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - final TGetAINodeLocationResp resp = cn.getAINodeLocation(); - if (resp != null && resp.isSetAiNodeLocation()) { - final TAINodeLocation loc = resp.getAiNodeLocation(); - CURRENT_LOCATION.set(loc); - return loc; - } - } catch (Exception e) { - LoggerFactory.getLogger(AINodeClient.class) - .debug("[AINodeClient] refreshFromConfigNode failed: {}", e.toString()); - } - return null; - } - - private static TEndPoint pickEndpointFrom(final TAINodeLocation loc) { - if (loc == null) return null; - if (loc.isSetInternalEndPoint() && loc.getInternalEndPoint() != null) { - return loc.getInternalEndPoint(); - } - return null; - } - - public AINodeClient( - ThriftClientProperty property, - TEndPoint endPoint, - ClientManager clientManager) - throws TException { - this.property = property; - this.clientManager = clientManager; - // Instance default endpoint (pool key). Global location can override it on retries. 
- this.endPoint = endPoint; - init(); - } - - private void init() throws TException { - try { - if (commonConfig.isEnableInternalSSL()) { - TSSLTransportFactory.TSSLTransportParameters params = - new TSSLTransportFactory.TSSLTransportParameters(); - params.setTrustStore(commonConfig.getTrustStorePath(), commonConfig.getTrustStorePwd()); - params.setKeyStore(commonConfig.getKeyStorePath(), commonConfig.getKeyStorePwd()); - transport = - new TFramedTransport.Factory() - .getTransport( - TSSLTransportFactory.getClientSocket( - endPoint.getIp(), - endPoint.getPort(), - property.getConnectionTimeoutMs(), - params)); - } else { - transport = - new TFramedTransport.Factory() - .getTransport( - new TSocket( - TConfigurationConst.defaultTConfiguration, - endPoint.getIp(), - endPoint.getPort(), - property.getConnectionTimeoutMs())); - } - if (!transport.isOpen()) { - transport.open(); - } - } catch (TTransportException e) { - throw new TException(MSG_CONNECTION_FAIL); - } - client = new IAINodeRPCService.Client(property.getProtocolFactory().getProtocol(transport)); - } - - public TTransport getTransport() { - return transport; - } - - public TSStatus stopAINode() throws TException { - try { - TSStatus status = client.stopAINode(); - if (status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new TException(status.message); - } - return status; - } catch (TException e) { - logger.warn( - "Failed to connect to AINode from ConfigNode when executing {}: {}", - Thread.currentThread().getStackTrace()[1].getMethodName(), - e.getMessage()); - throw new TException(MSG_CONNECTION_FAIL); - } - } - - public ModelInformation registerModel(String modelName, String uri) throws LoadModelException { - try { - TRegisterModelReq req = new TRegisterModelReq(uri, modelName); - TRegisterModelResp resp = client.registerModel(req); - if (resp.status.code != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - throw new LoadModelException(resp.status.message, resp.status.getCode()); - } - return parseModelInformation(modelName, resp.getAttributes(), resp.getConfigs()); - } catch (TException e) { - throw new LoadModelException( - e.getMessage(), TSStatusCode.AI_NODE_INTERNAL_ERROR.getStatusCode()); - } - } - - private ModelInformation parseModelInformation( - String modelName, String attributes, TConfigs configs) { - int[] inputShape = configs.getInput_shape().stream().mapToInt(Integer::intValue).toArray(); - int[] outputShape = configs.getOutput_shape().stream().mapToInt(Integer::intValue).toArray(); - - TSDataType[] inputType = new TSDataType[inputShape[1]]; - TSDataType[] outputType = new TSDataType[outputShape[1]]; - for (int i = 0; i < inputShape[1]; i++) { - inputType[i] = TSDataType.values()[configs.getInput_type().get(i)]; - } - for (int i = 0; i < outputShape[1]; i++) { - outputType[i] = TSDataType.values()[configs.getOutput_type().get(i)]; - } - - return new ModelInformation( - modelName, inputShape, outputShape, inputType, outputType, attributes); - } - - public TSStatus deleteModel(TDeleteModelReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.deleteModel(req)); - } - - public TSStatus loadModel(TLoadModelReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.loadModel(req)); - } - - public TSStatus unloadModel(TUnloadModelReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.unloadModel(req)); - } - - public TShowModelsResp showModels(TShowModelsReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.showModels(req)); - } - - 
public TShowLoadedModelsResp showLoadedModels(TShowLoadedModelsReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.showLoadedModels(req)); - } - - public TShowAIDevicesResp showAIDevices() throws TException { - return executeRemoteCallWithRetry(IAINodeRPCService.Client::showAIDevices); - } - - public TInferenceResp inference( - String modelId, - TsBlock inputTsBlock, - Map inferenceAttributes, - TWindowParams windowParams) - throws TException { - try { - TInferenceReq inferenceReq = new TInferenceReq(modelId, tsBlockSerde.serialize(inputTsBlock)); - if (windowParams != null) { - inferenceReq.setWindowParams(windowParams); - } - if (inferenceAttributes != null) { - inferenceReq.setInferenceAttributes(inferenceAttributes); - } - return executeRemoteCallWithRetry(c -> c.inference(inferenceReq)); - } catch (IOException e) { - throw new TException("An exception occurred while serializing input data", e); - } catch (TException e) { - logger.warn( - "Error happens in AINode when executing {}: {}", - Thread.currentThread().getStackTrace()[1].getMethodName(), - e.getMessage()); - throw new TException(MSG_CONNECTION_FAIL); - } - } - - public TForecastResp forecast( - String modelId, TsBlock inputTsBlock, int outputLength, Map options) { - try { - TForecastReq forecastReq = - new TForecastReq(modelId, tsBlockSerde.serialize(inputTsBlock), outputLength); - forecastReq.setOptions(options); - return executeRemoteCallWithRetry(c -> c.forecast(forecastReq)); - } catch (IOException e) { - TSStatus tsStatus = new TSStatus(INTERNAL_SERVER_ERROR.getStatusCode()); - tsStatus.setMessage(String.format("Failed to serialize input tsblock %s", e.getMessage())); - return new TForecastResp(tsStatus); - } catch (TException e) { - TSStatus tsStatus = new TSStatus(CAN_NOT_CONNECT_AINODE.getStatusCode()); - tsStatus.setMessage( - String.format( - "Failed to connect to AINode when executing %s: %s", - Thread.currentThread().getStackTrace()[1].getMethodName(), e.getMessage())); - return new TForecastResp(tsStatus); - } - } - - public TSStatus createTrainingTask(TTrainingReq req) throws TException { - return executeRemoteCallWithRetry(c -> c.createTrainingTask(req)); - } - - @Override - public void close() throws Exception { - clientManager.returnClient(endPoint, this); - } - - @Override - public void invalidate() { - Optional.ofNullable(transport).ifPresent(TTransport::close); - } - - @Override - public void invalidateAll() { - clientManager.clear(endPoint); - } - - @Override - public boolean printLogWhenEncounterException() { - return property.isPrintLogWhenEncounterException(); - } - - public static class Factory extends ThriftClientFactory { - - public Factory( - ClientManager clientClientManager, - ThriftClientProperty thriftClientProperty) { - super(clientClientManager, thriftClientProperty); - } - - @Override - public void destroyObject(TEndPoint tEndPoint, PooledObject pooledObject) - throws Exception { - pooledObject.getObject().invalidate(); - } - - @Override - public PooledObject makeObject(TEndPoint endPoint) throws Exception { - return new DefaultPooledObject<>( - new AINodeClient(thriftClientProperty, endPoint, clientManager)); - } - - @Override - public boolean validateObject(TEndPoint tEndPoint, PooledObject pooledObject) { - return Optional.ofNullable(pooledObject.getObject().getTransport()) - .map(TTransport::isOpen) - .orElse(false); - } - } -} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java deleted file mode 100644 index faef1c1ae7b6..000000000000 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AINodeClientManager.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.db.protocol.client.ainode; - -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.commons.client.IClientManager; -import org.apache.iotdb.db.protocol.client.AINodeClientFactory; - -public class AINodeClientManager { - - public static final int DEFAULT_AINODE_ID = 0; - - private static final AINodeClientManager INSTANCE = new AINodeClientManager(); - - private final IClientManager clientManager; - - private volatile TEndPoint defaultAINodeEndPoint; - - private AINodeClientManager() { - this.clientManager = - new IClientManager.Factory() - .createClientManager(new AINodeClientFactory.AINodeClientPoolFactory()); - } - - public static AINodeClientManager getInstance() { - return INSTANCE; - } - - public void updateDefaultAINodeLocation(TEndPoint endPoint) { - this.defaultAINodeEndPoint = endPoint; - } - - public AINodeClient borrowClient(TEndPoint endPoint) throws Exception { - return clientManager.borrowClient(endPoint); - } - - public AINodeClient borrowClient(int aiNodeId) throws Exception { - if (aiNodeId != DEFAULT_AINODE_ID) { - throw new IllegalArgumentException("Unsupported AINodeId: " + aiNodeId); - } - if (defaultAINodeEndPoint == null) { - defaultAINodeEndPoint = AINodeClient.getCurrentEndpoint(); - } - return clientManager.borrowClient(defaultAINodeEndPoint); - } - - public void clear(TEndPoint endPoint) { - clientManager.clear(endPoint); - } - - public void clearAll() { - clientManager.close(); - } - - public IClientManager getRawClientManager() { - return clientManager; - } -} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java new file mode 100644 index 000000000000..5eaffc40af9c --- /dev/null +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.protocol.client.an; + +import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService; +import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatReq; +import org.apache.iotdb.ainode.rpc.thrift.TAIHeartbeatResp; +import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; +import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; +import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq; +import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp; +import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; +import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; +import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; +import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; +import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; +import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; +import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; +import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; +import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; +import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.common.rpc.thrift.TSStatus; +import org.apache.iotdb.commons.client.ClientManager; +import org.apache.iotdb.commons.client.IClientManager; +import org.apache.iotdb.commons.client.ThriftClient; +import org.apache.iotdb.commons.client.factory.ThriftClientFactory; +import org.apache.iotdb.commons.client.property.ThriftClientProperty; +import org.apache.iotdb.commons.client.sync.SyncThriftClientWithErrorHandler; +import org.apache.iotdb.commons.conf.CommonConfig; +import org.apache.iotdb.commons.conf.CommonDescriptor; +import org.apache.iotdb.commons.consensus.ConfigRegionId; +import org.apache.iotdb.confignode.rpc.thrift.TGetAINodeLocationResp; +import org.apache.iotdb.db.conf.IoTDBConfig; +import org.apache.iotdb.db.conf.IoTDBDescriptor; +import org.apache.iotdb.db.protocol.client.ConfigNodeClient; +import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; +import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; +import org.apache.iotdb.rpc.DeepCopyRpcTransportFactory; + +import org.apache.commons.pool2.PooledObject; +import org.apache.commons.pool2.impl.DefaultPooledObject; +import org.apache.thrift.TException; +import org.apache.thrift.transport.TTransport; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import javax.net.ssl.SSLHandshakeException; + +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +public class AINodeClient implements IAINodeRPCService.Iface, AutoCloseable, ThriftClient { + + private static final Logger LOGGER = LoggerFactory.getLogger(AINodeClient.class); + + private static final CommonConfig COMMON_CONFIG = CommonDescriptor.getInstance().getConfig(); + private static final IoTDBConfig IOTDB_CONFIG = IoTDBDescriptor.getInstance().getConfig(); + + private TTransport transport; + + private final 
ThriftClientProperty property;
+  private IAINodeRPCService.Client client;
+
+  private static final int MAX_RETRY = 5;
+  private static final int RETRY_INTERVAL_MS = 100;
+  public static final String MSG_ALL_RETRY_FAILED =
+      String.format(
+          "Failed to connect to AINode after %d retries, please check the status of AINode",
+          MAX_RETRY);
+  public static final String MSG_AINODE_CONNECTION_FAIL =
+      "Failed to connect to AINode from DataNode %s when executing %s.";
+  private static final String UNSUPPORTED_INVOCATION =
+      "This method is not supported for invocation by DataNode";
+
+  @Override
+  public TSStatus stopAINode() throws TException {
+    return executeRemoteCallWithRetry(() -> client.stopAINode());
+  }
+
+  @Override
+  public TShowModelsResp showModels(TShowModelsReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.showModels(req));
+  }
+
+  @Override
+  public TShowLoadedModelsResp showLoadedModels(TShowLoadedModelsReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.showLoadedModels(req));
+  }
+
+  @Override
+  public TShowAIDevicesResp showAIDevices() throws TException {
+    return executeRemoteCallWithRetry(() -> client.showAIDevices());
+  }
+
+  @Override
+  public TSStatus deleteModel(TDeleteModelReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.deleteModel(req));
+  }
+
+  @Override
+  public TRegisterModelResp registerModel(TRegisterModelReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.registerModel(req));
+  }
+
+  @Override
+  public TAIHeartbeatResp getAIHeartbeat(TAIHeartbeatReq req) {
+    throw new UnsupportedOperationException(UNSUPPORTED_INVOCATION);
+  }
+
+  @Override
+  public TSStatus createTrainingTask(TTrainingReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.createTrainingTask(req));
+  }
+
+  @Override
+  public TSStatus loadModel(TLoadModelReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.loadModel(req));
+  }
+
+  @Override
+  public TSStatus unloadModel(TUnloadModelReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.unloadModel(req));
+  }
+
+  @Override
+  public TInferenceResp inference(TInferenceReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.inference(req));
+  }
+
+  @Override
+  public TForecastResp forecast(TForecastReq req) throws TException {
+    return executeRemoteCallWithRetry(() -> client.forecast(req));
+  }
+
+  @FunctionalInterface
+  private interface RemoteCall<R> {
+    R apply() throws TException;
+  }
+
+  ClientManager<Integer, AINodeClient> clientManager;
+
+  private static final IClientManager<ConfigRegionId, ConfigNodeClient>
+      CONFIG_NODE_CLIENT_MANAGER = ConfigNodeClientManager.getInstance();
+
+  private static final AtomicReference<TAINodeLocation> CURRENT_LOCATION = new AtomicReference<>();
+
+  private <R> R executeRemoteCallWithRetry(RemoteCall<R> call) throws TException {
+    for (int attempt = 0; attempt < MAX_RETRY; attempt++) {
+      try {
+        return call.apply();
+      } catch (TException e) {
+        final String message =
+            String.format(
+                MSG_AINODE_CONNECTION_FAIL,
+                IOTDB_CONFIG.getAddressAndPort(),
+                Thread.currentThread().getStackTrace()[2].getMethodName());
+        LOGGER.warn(message, e);
+        CURRENT_LOCATION.set(null);
+        if (e.getCause() instanceof SSLHandshakeException) {
+          throw e;
+        }
+      }
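+      // Back off briefly between attempts; tryToConnect() below refreshes the cached AINode
+      // location from the ConfigNode (CURRENT_LOCATION was cleared above) and re-opens the
+      // transport before the next retry.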
+      try {
+        TimeUnit.MILLISECONDS.sleep(RETRY_INTERVAL_MS);
+      } catch (InterruptedException e) {
+        Thread.currentThread().interrupt();
+        LOGGER.warn(
+            "Unexpected interruption while waiting to reconnect to AINode, possibly because the current node has been shut down. Breaking the retry loop to avoid a meaningless wait.");
+        break;
+      }
+      tryToConnect(property.getConnectionTimeoutMs());
+    }
+    throw new TException(MSG_ALL_RETRY_FAILED);
+  }
+
+  private void tryToConnect(int timeoutMs) {
+    TEndPoint endpoint = getCurrentEndpoint();
+    if (endpoint != null) {
+      try {
+        connect(endpoint, timeoutMs);
+        return;
+      } catch (TException e) {
+        LOGGER.warn("The AINode at {} may be down.", endpoint, e);
+        CURRENT_LOCATION.set(null);
+      }
+    } else {
+      LOGGER.warn("Cannot connect to any AINode because no AINode location is available.");
+    }
+    if (transport != null) {
+      transport.close();
+    }
+  }
+
+  public void connect(TEndPoint endpoint, int timeoutMs) throws TException {
+    transport =
+        COMMON_CONFIG.isEnableInternalSSL()
+            ? DeepCopyRpcTransportFactory.INSTANCE.getTransport(
+                endpoint.getIp(),
+                endpoint.getPort(),
+                timeoutMs,
+                COMMON_CONFIG.getTrustStorePath(),
+                COMMON_CONFIG.getTrustStorePwd(),
+                COMMON_CONFIG.getKeyStorePath(),
+                COMMON_CONFIG.getKeyStorePwd())
+            : DeepCopyRpcTransportFactory.INSTANCE.getTransport(
+                // A try-catch already guards this call, so TSocket.wrap is unnecessary
+                endpoint.getIp(), endpoint.getPort(), timeoutMs);
+    if (!transport.isOpen()) {
+      transport.open();
+    }
+    client = new IAINodeRPCService.Client(property.getProtocolFactory().getProtocol(transport));
+  }
+
+  public TEndPoint getCurrentEndpoint() {
+    TAINodeLocation loc = CURRENT_LOCATION.get();
+    if (loc == null) {
+      loc = refreshFromConfigNode();
+    }
+    return (loc == null) ? null : loc.getInternalEndPoint();
+  }
+
+  private TAINodeLocation refreshFromConfigNode() {
+    try (final ConfigNodeClient cn =
+        CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) {
+      final TGetAINodeLocationResp resp = cn.getAINodeLocation();
+      if (resp.isSetAiNodeLocation()) {
+        final TAINodeLocation loc = resp.getAiNodeLocation();
+        CURRENT_LOCATION.set(loc);
+        return loc;
+      }
+    } catch (Exception e) {
+      LoggerFactory.getLogger(AINodeClient.class)
+          .debug("[AINodeClient] refreshFromConfigNode failed: {}", e.toString());
+    }
+    return null;
+  }
+
+  public AINodeClient(
+      ThriftClientProperty property, ClientManager<Integer, AINodeClient> clientManager) {
+    this.property = property;
+    this.clientManager = clientManager;
+    tryToConnect(property.getConnectionTimeoutMs());
+  }
+
+  public TTransport getTransport() {
+    return transport;
+  }
+
+  @Override
+  public void close() {
+    clientManager.returnClient(AINodeClientManager.AINODE_ID_PLACEHOLDER, this);
+  }
+
+  @Override
+  public void invalidate() {
+    Optional.ofNullable(transport).ifPresent(TTransport::close);
+  }
+
+  @Override
+  public void invalidateAll() {
+    clientManager.clear(AINodeClientManager.AINODE_ID_PLACEHOLDER);
+  }
+
+  @Override
+  public boolean printLogWhenEncounterException() {
+    return property.isPrintLogWhenEncounterException();
+  }
+
+  public static class Factory extends ThriftClientFactory<Integer, AINodeClient> {
+
+    public Factory(
+        ClientManager<Integer, AINodeClient> clientManager,
+        ThriftClientProperty thriftClientProperty) {
+      super(clientManager, thriftClientProperty);
+    }
+
+    @Override
+    public void destroyObject(Integer aiNodeId, PooledObject<AINodeClient> pooledObject) {
+      pooledObject.getObject().invalidate();
+    }
+
+    @Override
+    public PooledObject<AINodeClient> makeObject(Integer aiNodeId) throws Exception {
+      return new DefaultPooledObject<>(
+          SyncThriftClientWithErrorHandler.newErrorHandler(
+              AINodeClient.class,
+              AINodeClient.class.getConstructor(
+                  thriftClientProperty.getClass(), clientManager.getClass()),
+              thriftClientProperty,
+              clientManager));
+    }
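+
+    // A pooled client is considered reusable only while its underlying Thrift transport is
+    // still open; otherwise the pool destroys it and makeObject() builds a fresh connection.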
+    @Override
+    public boolean validateObject(Integer aiNodeId, PooledObject<AINodeClient> pooledObject) {
+      return Optional.ofNullable(pooledObject.getObject().getTransport())
+          .map(TTransport::isOpen)
+          .orElse(false);
+    }
+  }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java
new file mode 100644
index 000000000000..698c8e793883
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClientManager.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.protocol.client.an;
+
+import org.apache.iotdb.commons.client.IClientManager;
+import org.apache.iotdb.db.protocol.client.DataNodeClientPoolFactory;
+
+public class AINodeClientManager {
+
+  public static final int AINODE_ID_PLACEHOLDER = 0;
+
+  private AINodeClientManager() {
+    // Empty constructor
+  }
+
+  public static IClientManager<Integer, AINodeClient> getInstance() {
+    return AINodeClientManagerHolder.INSTANCE;
+  }
+
+  private static class AINodeClientManagerHolder {
+
+    private static final IClientManager<Integer, AINodeClient> INSTANCE =
+        new IClientManager.Factory<Integer, AINodeClient>()
+            .createClientManager(new DataNodeClientPoolFactory.AINodeClientPoolFactory());
+
+    private AINodeClientManagerHolder() {
+      // Empty constructor
+    }
+  }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
index 7126af78b8b5..29e5580311d0 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/ai/InferenceOperator.java
@@ -19,18 +19,15 @@
package org.apache.iotdb.db.queryengine.execution.operator.process.ai;
+import org.apache.iotdb.ainode.rpc.thrift.TInferenceReq;
import org.apache.iotdb.ainode.rpc.thrift.TInferenceResp;
-import org.apache.iotdb.ainode.rpc.thrift.TWindowParams;
import org.apache.iotdb.db.exception.runtime.ModelInferenceProcessException;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClient;
-import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager;
+import org.apache.iotdb.db.protocol.client.an.AINodeClient;
+import org.apache.iotdb.db.protocol.client.an.AINodeClientManager;
import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper;
import org.apache.iotdb.db.queryengine.execution.operator.Operator;
import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext;
import org.apache.iotdb.db.queryengine.execution.operator.process.ProcessOperator; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.BottomInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowType; import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.rpc.TSStatusCode; @@ -75,7 +72,6 @@ public class InferenceOperator implements ProcessOperator { private int resultIndex = 0; private List results; private final TsBlockSerde serde = new TsBlockSerde(); - private InferenceWindowType windowType = null; private final boolean generateTimeColumn; private long maxTimestamp; @@ -109,10 +105,6 @@ public InferenceOperator( this.maxReturnSize = maxReturnSize; this.totalRow = 0; - if (modelInferenceDescriptor.getInferenceWindowParameter() != null) { - windowType = modelInferenceDescriptor.getInferenceWindowParameter().getWindowType(); - } - if (generateTimeColumn) { this.interval = 0; this.minTimestamp = Long.MAX_VALUE; @@ -237,62 +229,6 @@ private void appendTsBlockToBuilder(TsBlock inputTsBlock) { } } - private TWindowParams getWindowParams() { - TWindowParams windowParams; - if (windowType == null) { - return null; - } - if (windowType == InferenceWindowType.COUNT) { - CountInferenceWindowParameter countInferenceWindowParameter = - (CountInferenceWindowParameter) modelInferenceDescriptor.getInferenceWindowParameter(); - windowParams = new TWindowParams(); - windowParams.setWindowInterval((int) countInferenceWindowParameter.getInterval()); - windowParams.setWindowStep((int) countInferenceWindowParameter.getStep()); - } else { - windowParams = null; - } - return windowParams; - } - - private TsBlock preProcess(TsBlock inputTsBlock) { - // boolean notBuiltIn = !modelInferenceDescriptor.getModelInformation().isBuiltIn(); - boolean notBuiltIn = false; - if (windowType == null || windowType == InferenceWindowType.HEAD) { - if (notBuiltIn - && totalRow != modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { - throw new ModelInferenceProcessException( - String.format( - "The number of rows %s in the input data does not match the model input %s. Try to use LIMIT in SQL or WINDOW in CALL INFERENCE", - totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); - } - return inputTsBlock; - } else if (windowType == InferenceWindowType.COUNT) { - if (notBuiltIn - && totalRow < modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { - throw new ModelInferenceProcessException( - String.format( - "The number of rows %s in the input data is less than the model input %s. ", - totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); - } - } else if (windowType == InferenceWindowType.TAIL) { - if (notBuiltIn - && totalRow < modelInferenceDescriptor.getModelInformation().getInputShape()[0]) { - throw new ModelInferenceProcessException( - String.format( - "The number of rows %s in the input data is less than the model input %s. 
", - totalRow, modelInferenceDescriptor.getModelInformation().getInputShape()[0])); - } - // Tail window logic: get the latest data for inference - long windowSize = - (int) - ((BottomInferenceWindowParameter) - modelInferenceDescriptor.getInferenceWindowParameter()) - .getWindowSize(); - return inputTsBlock.subTsBlock((int) (totalRow - windowSize)); - } - return inputTsBlock; - } - private void submitInferenceTask() { if (generateTimeColumn) { @@ -301,20 +237,16 @@ private void submitInferenceTask() { TsBlock inputTsBlock = inputTsBlockBuilder.build(); - TsBlock finalInputTsBlock = preProcess(inputTsBlock); - TWindowParams windowParams = getWindowParams(); - inferenceExecutionFuture = Futures.submit( () -> { try (AINodeClient client = AINodeClientManager.getInstance() - .borrowClient(modelInferenceDescriptor.getTargetAINode())) { + .borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { return client.inference( - modelInferenceDescriptor.getModelName(), - finalInputTsBlock, - modelInferenceDescriptor.getInferenceAttributes(), - windowParams); + new TInferenceReq( + modelInferenceDescriptor.getModelId(), serde.serialize(inputTsBlock)) + .setInferenceAttributes(modelInferenceDescriptor.getInferenceAttributes())); } catch (Exception e) { throw new ModelInferenceProcessException(e.getMessage()); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java index fc6888165659..daceffce6b7b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java @@ -19,14 +19,11 @@ package org.apache.iotdb.db.queryengine.execution.operator.source.relational; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; -import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; import org.apache.iotdb.common.rpc.thrift.Model; import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; import org.apache.iotdb.common.rpc.thrift.TConfigNodeLocation; import org.apache.iotdb.common.rpc.thrift.TConsensusGroupType; import org.apache.iotdb.common.rpc.thrift.TDataNodeLocation; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.commons.audit.UserEntity; import org.apache.iotdb.commons.client.exception.ClientManagerException; import org.apache.iotdb.commons.conf.IoTDBConstant; @@ -68,8 +65,6 @@ import org.apache.iotdb.db.protocol.client.ConfigNodeClient; import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; import org.apache.iotdb.db.protocol.session.IClientSession; import org.apache.iotdb.db.protocol.session.SessionManager; import org.apache.iotdb.db.queryengine.common.ConnectionInfo; @@ -157,8 +152,6 @@ public static Iterator getSupplier( return new SubscriptionSupplier(dataTypes, userEntity); case InformationSchema.VIEWS: return new ViewsSupplier(dataTypes, userEntity); - case InformationSchema.MODELS: - return new ModelsSupplier(dataTypes); case InformationSchema.FUNCTIONS: return new 
FunctionsSupplier(dataTypes); case InformationSchema.CONFIGURATIONS: @@ -798,112 +791,6 @@ public boolean hasNext() { } } - private static class ModelsSupplier extends TsBlockSupplier { - private final ModelIterator iterator; - - private ModelsSupplier(final List dataTypes) throws Exception { - super(dataTypes); - final TEndPoint ep = AINodeClient.getCurrentEndpoint(); - try (final AINodeClient ai = AINodeClientManager.getInstance().borrowClient(ep)) { - iterator = new ModelIterator(ai.showModels(new TShowModelsReq())); - } - } - - private static class ModelIterator implements Iterator { - - private int index = 0; - private final TShowModelsResp resp; - - private ModelIterator(TShowModelsResp resp) { - this.resp = resp; - } - - @Override - public boolean hasNext() { - return index < resp.getModelIdListSize(); - } - - @Override - public ModelInfoInString next() { - String modelId = resp.getModelIdList().get(index++); - return new ModelInfoInString( - modelId, - resp.getModelTypeMap().get(modelId), - resp.getCategoryMap().get(modelId), - resp.getStateMap().get(modelId)); - } - } - - private static class ModelInfoInString { - - private final String modelId; - private final String modelType; - private final String category; - private final String state; - - public ModelInfoInString(String modelId, String modelType, String category, String state) { - this.modelId = modelId; - this.modelType = modelType; - this.category = category; - this.state = state; - } - - public String getModelId() { - return modelId; - } - - public String getModelType() { - return modelType; - } - - public String getCategory() { - return category; - } - - public String getState() { - return state; - } - } - - @Override - protected void constructLine() { - final ModelInfoInString modelInfo = iterator.next(); - columnBuilders[0].writeBinary( - new Binary(modelInfo.getModelId(), TSFileConfig.STRING_CHARSET)); - columnBuilders[1].writeBinary( - new Binary(modelInfo.getModelType(), TSFileConfig.STRING_CHARSET)); - columnBuilders[2].writeBinary( - new Binary(modelInfo.getCategory(), TSFileConfig.STRING_CHARSET)); - columnBuilders[3].writeBinary(new Binary(modelInfo.getState(), TSFileConfig.STRING_CHARSET)); - // if (Objects.equals(modelType, ModelType.USER_DEFINED.toString())) { - // columnBuilders[3].writeBinary( - // new Binary( - // INPUT_SHAPE - // + ReadWriteIOUtils.readString(modelInfo) - // + OUTPUT_SHAPE - // + ReadWriteIOUtils.readString(modelInfo) - // + INPUT_DATA_TYPE - // + ReadWriteIOUtils.readString(modelInfo) - // + OUTPUT_DATA_TYPE - // + ReadWriteIOUtils.readString(modelInfo), - // TSFileConfig.STRING_CHARSET)); - // columnBuilders[4].writeBinary( - // new Binary(ReadWriteIOUtils.readString(modelInfo), - // TSFileConfig.STRING_CHARSET)); - // } else { - // columnBuilders[3].appendNull(); - // columnBuilders[4].writeBinary( - // new Binary("Built-in model in IoTDB", TSFileConfig.STRING_CHARSET)); - // } - resultBuilder.declarePosition(); - } - - @Override - public boolean hasNext() { - return iterator.hasNext(); - } - } - private static class FunctionsSupplier extends TsBlockSupplier { private final Iterator udfIterator; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java index 34a289b76c9d..dc56fe118b7b 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/AnalyzeVisitor.java @@ -25,7 +25,6 @@ import org.apache.iotdb.commons.conf.IoTDBConstant; import org.apache.iotdb.commons.exception.IllegalPathException; import org.apache.iotdb.commons.exception.MetadataException; -import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.commons.partition.DataPartition; import org.apache.iotdb.commons.partition.DataPartitionQueryParam; import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition; @@ -55,14 +54,6 @@ import org.apache.iotdb.db.queryengine.common.schematree.IMeasurementSchemaInfo; import org.apache.iotdb.db.queryengine.common.schematree.ISchemaTree; import org.apache.iotdb.db.queryengine.execution.operator.window.WindowType; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.BottomInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindow; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.CountInferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.HeadInferenceWindow; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindow; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowParameter; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowType; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.TailInferenceWindow; import org.apache.iotdb.db.queryengine.metric.QueryPlanCostMetricSet; import org.apache.iotdb.db.queryengine.plan.analyze.load.LoadTsFileAnalyzer; import org.apache.iotdb.db.queryengine.plan.analyze.lock.DataNodeSchemaLockManager; @@ -425,46 +416,14 @@ private void analyzeModelInference(Analysis analysis, QueryStatement queryStatem return; } - // Get model metadata from configNode and do some check + // Get model metadata from AINode String modelId = queryStatement.getModelId(); TSStatus status = modelFetcher.fetchModel(modelId, analysis); if (status.getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { throw new GetModelInfoException(status.getMessage()); } - // set inference window if there is - if (queryStatement.isSetInferenceWindow()) { - InferenceWindow window = queryStatement.getInferenceWindow(); - if (InferenceWindowType.HEAD == window.getType()) { - long windowSize = ((HeadInferenceWindow) window).getWindowSize(); - // checkWindowSize(windowSize, modelInformation); - if (queryStatement.hasLimit() && queryStatement.getRowLimit() < windowSize) { - throw new SemanticException( - "Limit in Sql should be larger than window size in inference"); - } - // optimize head window by limitNode - queryStatement.setRowLimit(windowSize); - } else if (InferenceWindowType.TAIL == window.getType()) { - long windowSize = ((TailInferenceWindow) window).getWindowSize(); - // checkWindowSize(windowSize, modelInformation); - InferenceWindowParameter inferenceWindowParameter = - new BottomInferenceWindowParameter(windowSize); - analysis - .getModelInferenceDescriptor() - .setInferenceWindowParameter(inferenceWindowParameter); - } else if (InferenceWindowType.COUNT == window.getType()) { - CountInferenceWindow countInferenceWindow = (CountInferenceWindow) window; - // checkWindowSize(countInferenceWindow.getInterval(), modelInformation); - InferenceWindowParameter inferenceWindowParameter = - new CountInferenceWindowParameter( - 
countInferenceWindow.getInterval(), countInferenceWindow.getStep()); - analysis - .getModelInferenceDescriptor() - .setInferenceWindowParameter(inferenceWindowParameter); - } - } - - // set inference attributes if there is + // Set inference attributes if there is if (queryStatement.hasInferenceAttributes()) { analysis .getModelInferenceDescriptor() @@ -472,12 +431,6 @@ private void analyzeModelInference(Analysis analysis, QueryStatement queryStatem } } - private void checkWindowSize(long windowSize, ModelInformation modelInformation) { - if (modelInformation.isBuiltIn()) { - return; - } - } - private ISchemaTree analyzeSchema( QueryStatement queryStatement, Analysis analysis, @@ -1717,22 +1670,11 @@ static void analyzeOutput( } if (queryStatement.hasModelInference()) { - ModelInformation modelInformation = analysis.getModelInformation(); // check input - checkInputShape(modelInformation, outputExpressions); - checkInputType(analysis, modelInformation, outputExpressions); - + checkInputType(analysis, outputExpressions); // set output List columnHeaders = new ArrayList<>(); - int[] outputShape = modelInformation.getOutputShape(); - TSDataType[] outputDataType = modelInformation.getOutputDataType(); - for (int i = 0; i < outputShape[1]; i++) { - columnHeaders.add(new ColumnHeader(INFERENCE_COLUMN_NAME + i, outputDataType[i])); - } - analysis - .getModelInferenceDescriptor() - .setOutputColumnNames( - columnHeaders.stream().map(ColumnHeader::getColumnName).collect(Collectors.toList())); + columnHeaders.add(new ColumnHeader(INFERENCE_COLUMN_NAME, TSDataType.DOUBLE)); boolean isIgnoreTimestamp = !queryStatement.isGenerateTime(); analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, isIgnoreTimestamp)); return; @@ -1756,74 +1698,16 @@ static void analyzeOutput( analysis.setRespDatasetHeader(new DatasetHeader(columnHeaders, isIgnoreTimestamp)); } - // check if the result of SQL matches the input of model - private static void checkInputShape( - ModelInformation modelInformation, List> outputExpressions) { - if (modelInformation.isBuiltIn()) { - modelInformation.setInputColumnSize(outputExpressions.size()); - return; - } - - // check inputShape - int[] inputShape = modelInformation.getInputShape(); - if (inputShape.length != 2) { - throw new SemanticException( - String.format( - "The input shape of model is not correct, the dimension of input shape should be 2, actual dimension is %d", - inputShape.length)); - } - int columnNumber = inputShape[1]; - if (columnNumber != outputExpressions.size()) { - throw new SemanticException( - String.format( - "The column number of SQL result does not match the number of model input [%d] for inference", - columnNumber)); - } - } - private static void checkInputType( - Analysis analysis, - ModelInformation modelInformation, - List> outputExpressions) { - - if (modelInformation.isBuiltIn()) { - TSDataType[] inputType = new TSDataType[outputExpressions.size()]; - for (int i = 0; i < outputExpressions.size(); i++) { - Expression inputExpression = outputExpressions.get(i).left; - TSDataType inputDataType = analysis.getType(inputExpression); - if (!inputDataType.isNumeric()) { - throw new SemanticException( - String.format( - "The type of SQL result column [%s in %d] should be numeric when inference", - inputDataType, i)); - } - inputType[i] = inputDataType; - } - modelInformation.setInputDataType(inputType); - return; - } - - TSDataType[] inputType = modelInformation.getInputDataType(); - if (inputType.length != modelInformation.getInputShape()[1]) { - throw 
new SemanticException( - String.format( - "The inputType does not match the input shape [%d] for inference", - modelInformation.getInputShape()[1])); - } - for (int i = 0; i < inputType.length; i++) { + Analysis analysis, List> outputExpressions) { + for (int i = 0; i < outputExpressions.size(); i++) { Expression inputExpression = outputExpressions.get(i).left; TSDataType inputDataType = analysis.getType(inputExpression); - boolean isExpressionNumeric = inputDataType.isNumeric(); - boolean isModelNumeric = inputType[i].isNumeric(); - if (isExpressionNumeric && isModelNumeric) { - // every model supports numeric by default - continue; - } - if (inputDataType != inputType[i]) { + if (!inputDataType.isNumeric()) { throw new SemanticException( String.format( - "The type of SQL result column [%s in %d] does not match the type of model input [%s] when inference", - inputDataType, i, inputType[i])); + "The type of SQL result column [%s in %d] should be numeric when inference", + inputDataType, i)); } } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java index 586e12e589ab..1feecaefde9c 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/IModelFetcher.java @@ -20,12 +20,8 @@ package org.apache.iotdb.db.queryengine.plan.analyze; import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; public interface IModelFetcher { /** Get model information by model id from configNode. 
*/ TSStatus fetchModel(String modelId, Analysis analysis); - - // currently only used by table model - ModelInferenceDescriptor fetchModel(String modelName); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java index dbeee4e8ed4b..b4123c237bbd 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/analyze/ModelFetcher.java @@ -20,27 +20,13 @@ package org.apache.iotdb.db.queryengine.plan.analyze; import org.apache.iotdb.common.rpc.thrift.TSStatus; -import org.apache.iotdb.commons.client.IClientManager; -import org.apache.iotdb.commons.client.exception.ClientManagerException; -import org.apache.iotdb.commons.consensus.ConfigRegionId; -import org.apache.iotdb.commons.exception.IoTDBRuntimeException; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoReq; -import org.apache.iotdb.confignode.rpc.thrift.TGetModelInfoResp; -import org.apache.iotdb.db.exception.ainode.ModelNotFoundException; -import org.apache.iotdb.db.exception.sql.StatementAnalyzeException; -import org.apache.iotdb.db.protocol.client.ConfigNodeClient; -import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; -import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; +import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.rpc.TSStatusCode; -import org.apache.thrift.TException; - +// TODO: This class should contact with AINode directly and cache model info in DataNode public class ModelFetcher implements IModelFetcher { - private final IClientManager configNodeClientManager = - ConfigNodeClientManager.getInstance(); - private static final class ModelFetcherHolder { private static final ModelFetcher INSTANCE = new ModelFetcher(); @@ -55,34 +41,9 @@ public static ModelFetcher getInstance() { private ModelFetcher() {} @Override - public TSStatus fetchModel(String modelName, Analysis analysis) { - try (ConfigNodeClient client = - configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - TGetModelInfoResp getModelInfoResp = client.getModelInfo(new TGetModelInfoReq(modelName)); - if (getModelInfoResp.getStatus().getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); - } else { - throw new ModelNotFoundException(getModelInfoResp.getStatus().getMessage()); - } - } catch (ClientManagerException | TException e) { - throw new StatementAnalyzeException(e.getMessage()); - } - } - - @Override - public ModelInferenceDescriptor fetchModel(String modelName) { - try (ConfigNodeClient client = - configNodeClientManager.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - TGetModelInfoResp getModelInfoResp = client.getModelInfo(new TGetModelInfoReq(modelName)); - if (getModelInfoResp.getStatus().getCode() == TSStatusCode.SUCCESS_STATUS.getStatusCode()) { - return new ModelInferenceDescriptor(getModelInfoResp.aiNodeAddress); - } else { - throw new ModelNotFoundException(getModelInfoResp.getStatus().getMessage()); - } - } catch (ClientManagerException | TException e) { - throw new IoTDBRuntimeException( - String.format("fetch model [%s] info failed: %s", modelName, e.getMessage()), - TSStatusCode.GET_MODEL_INFO_ERROR.getStatusCode()); - } 
+ public TSStatus fetchModel(String modelId, Analysis analysis) { + analysis.setModelInferenceDescriptor( + new ModelInferenceDescriptor(new ModelInformation(modelId))); + return new TSStatus(TSStatusCode.SUCCESS_STATUS.getStatusCode()); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java index 34bbd57640a3..5f27b3feb8c3 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java @@ -19,7 +19,10 @@ package org.apache.iotdb.db.queryengine.plan.execution.config.executor; +import org.apache.iotdb.ainode.rpc.thrift.TDeleteModelReq; import org.apache.iotdb.ainode.rpc.thrift.TLoadModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelReq; +import org.apache.iotdb.ainode.rpc.thrift.TRegisterModelResp; import org.apache.iotdb.ainode.rpc.thrift.TShowAIDevicesResp; import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsReq; import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; @@ -96,7 +99,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TCountTimeSlotListResp; import org.apache.iotdb.confignode.rpc.thrift.TCreateCQReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TCreateModelReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TCreatePipeReq; import org.apache.iotdb.confignode.rpc.thrift.TCreateTableViewReq; @@ -114,7 +116,6 @@ import org.apache.iotdb.confignode.rpc.thrift.TDescTableResp; import org.apache.iotdb.confignode.rpc.thrift.TDropCQReq; import org.apache.iotdb.confignode.rpc.thrift.TDropFunctionReq; -import org.apache.iotdb.confignode.rpc.thrift.TDropModelReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipePluginReq; import org.apache.iotdb.confignode.rpc.thrift.TDropPipeReq; import org.apache.iotdb.confignode.rpc.thrift.TDropSubscriptionReq; @@ -175,8 +176,8 @@ import org.apache.iotdb.db.protocol.client.ConfigNodeClientManager; import org.apache.iotdb.db.protocol.client.ConfigNodeInfo; import org.apache.iotdb.db.protocol.client.DataNodeClientPoolFactory; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.db.protocol.session.IClientSession; import org.apache.iotdb.db.protocol.session.SessionManager; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; @@ -381,6 +382,8 @@ public class ClusterConfigTaskExecutor implements IConfigTaskExecutor { private static final IClientManager CONFIG_NODE_CLIENT_MANAGER = ConfigNodeClientManager.getInstance(); + private static final IClientManager AI_NODE_CLIENT_MANAGER = + AINodeClientManager.getInstance(); /** FIXME Consolidate this clientManager with the upper one. 
*/ private static final IClientManager @@ -3617,16 +3620,16 @@ public SettableFuture showContinuousQueries() { @Override public SettableFuture createModel(String modelId, String uri) { final SettableFuture future = SettableFuture.create(); - try (final ConfigNodeClient client = - CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - final TCreateModelReq req = new TCreateModelReq(modelId, uri); - final TSStatus status = client.createModel(req); - if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != status.getCode()) { - future.setException(new IoTDBException(status)); + try (final AINodeClient client = + AI_NODE_CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + final TRegisterModelReq req = new TRegisterModelReq(modelId, uri); + final TRegisterModelResp resp = client.registerModel(req); + if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != resp.getStatus().getCode()) { + future.setException(new IoTDBException(resp.getStatus())); } else { future.set(new ConfigTaskResult(TSStatusCode.SUCCESS_STATUS)); } - } catch (final ClientManagerException | TException e) { + } catch (final TException | ClientManagerException e) { future.setException(e); } return future; @@ -3635,9 +3638,9 @@ public SettableFuture createModel(String modelId, String uri) @Override public SettableFuture dropModel(final String modelId) { final SettableFuture future = SettableFuture.create(); - try (final ConfigNodeClient client = - CONFIG_NODE_CLIENT_MANAGER.borrowClient(ConfigNodeInfo.CONFIG_REGION_ID)) { - final TSStatus executionStatus = client.dropModel(new TDropModelReq(modelId)); + try (final AINodeClient client = + AI_NODE_CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + final TSStatus executionStatus = client.deleteModel(new TDeleteModelReq(modelId)); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != executionStatus.getCode()) { future.setException(new IoTDBException(executionStatus)); } else { @@ -3653,7 +3656,7 @@ public SettableFuture dropModel(final String modelId) { public SettableFuture showModels(final String modelId) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TShowModelsReq req = new TShowModelsReq(); if (modelId != null) { req.setModelId(modelId); @@ -3674,7 +3677,7 @@ public SettableFuture showModels(final String modelId) { public SettableFuture showLoadedModels(List deviceIdList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TShowLoadedModelsReq req = new TShowLoadedModelsReq(); req.setDeviceIdList(deviceIdList != null ? 
deviceIdList : new ArrayList<>()); final TShowLoadedModelsResp resp = ai.showLoadedModels(req); @@ -3693,7 +3696,7 @@ public SettableFuture showLoadedModels(List deviceIdLi public SettableFuture showAIDevices() { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TShowAIDevicesResp resp = ai.showAIDevices(); if (resp.getStatus().getCode() != TSStatusCode.SUCCESS_STATUS.getStatusCode()) { future.setException(new IoTDBException(resp.getStatus())); @@ -3711,7 +3714,7 @@ public SettableFuture loadModel( String existingModelId, List deviceIdList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TLoadModelReq req = new TLoadModelReq(existingModelId, deviceIdList); final TSStatus result = ai.loadModel(req); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != result.getCode()) { @@ -3730,7 +3733,7 @@ public SettableFuture unloadModel( String existingModelId, List deviceIdList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TUnloadModelReq req = new TUnloadModelReq(existingModelId, deviceIdList); final TSStatus result = ai.unloadModel(req); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != result.getCode()) { @@ -3755,7 +3758,7 @@ public SettableFuture createTraining( @Nullable List pathList) { final SettableFuture future = SettableFuture.create(); try (final AINodeClient ai = - AINodeClientManager.getInstance().borrowClient(AINodeClientManager.DEFAULT_AINODE_ID)) { + AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { final TTrainingReq req = new TTrainingReq(); req.setModelId(modelId); req.setParameters(parameters); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java index 09205c9eb564..a01acf86db57 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/node/process/AI/InferenceNode.java @@ -31,6 +31,7 @@ import java.io.DataOutputStream; import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Collections; import java.util.List; import java.util.Objects; @@ -90,7 +91,7 @@ public PlanNode clone() { @Override public List getOutputColumnNames() { - return modelInferenceDescriptor.getOutputColumnNames(); + return Collections.singletonList("output"); } @Override diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java index b7c6aaa4f4b0..1301ec97eb32 100644 --- 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/plan/parameter/model/ModelInferenceDescriptor.java @@ -19,9 +19,7 @@ package org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.commons.model.ModelInformation; -import org.apache.iotdb.db.queryengine.execution.operator.window.ainode.InferenceWindowParameter; import org.apache.tsfile.utils.ReadWriteIOUtils; @@ -36,19 +34,15 @@ public class ModelInferenceDescriptor { - private final TEndPoint targetAINode; - private ModelInformation modelInformation; + private final ModelInformation modelInformation; private List outputColumnNames; - private InferenceWindowParameter inferenceWindowParameter; private Map inferenceAttributes; - public ModelInferenceDescriptor(TEndPoint targetAINode) { - this.targetAINode = targetAINode; + public ModelInferenceDescriptor(ModelInformation modelInformation) { + this.modelInformation = modelInformation; } private ModelInferenceDescriptor(ByteBuffer buffer) { - this.targetAINode = - new TEndPoint(ReadWriteIOUtils.readString(buffer), ReadWriteIOUtils.readInt(buffer)); this.modelInformation = ModelInformation.deserialize(buffer); int outputColumnNamesSize = ReadWriteIOUtils.readInt(buffer); if (outputColumnNamesSize == 0) { @@ -59,12 +53,6 @@ private ModelInferenceDescriptor(ByteBuffer buffer) { this.outputColumnNames.add(ReadWriteIOUtils.readString(buffer)); } } - boolean hasInferenceWindowParameter = ReadWriteIOUtils.readBool(buffer); - if (hasInferenceWindowParameter) { - this.inferenceWindowParameter = InferenceWindowParameter.deserialize(buffer); - } else { - this.inferenceWindowParameter = null; - } int inferenceAttributesSize = ReadWriteIOUtils.readInt(buffer); if (inferenceAttributesSize == 0) { this.inferenceAttributes = null; @@ -85,24 +73,12 @@ public Map getInferenceAttributes() { return inferenceAttributes; } - public void setInferenceWindowParameter(InferenceWindowParameter inferenceWindowParameter) { - this.inferenceWindowParameter = inferenceWindowParameter; - } - - public InferenceWindowParameter getInferenceWindowParameter() { - return inferenceWindowParameter; - } - public ModelInformation getModelInformation() { return modelInformation; } - public TEndPoint getTargetAINode() { - return targetAINode; - } - - public String getModelName() { - return modelInformation.getModelName(); + public String getModelId() { + return modelInformation.getModelId(); } public void setOutputColumnNames(List outputColumnNames) { @@ -114,8 +90,6 @@ public List getOutputColumnNames() { } public void serialize(ByteBuffer byteBuffer) { - ReadWriteIOUtils.write(targetAINode.ip, byteBuffer); - ReadWriteIOUtils.write(targetAINode.port, byteBuffer); modelInformation.serialize(byteBuffer); if (outputColumnNames == null) { ReadWriteIOUtils.write(0, byteBuffer); @@ -125,12 +99,6 @@ public void serialize(ByteBuffer byteBuffer) { ReadWriteIOUtils.write(outputColumnName, byteBuffer); } } - if (inferenceWindowParameter == null) { - ReadWriteIOUtils.write(false, byteBuffer); - } else { - ReadWriteIOUtils.write(true, byteBuffer); - inferenceWindowParameter.serialize(byteBuffer); - } if (inferenceAttributes == null) { ReadWriteIOUtils.write(0, byteBuffer); } else { @@ -143,8 +111,6 @@ public void serialize(ByteBuffer byteBuffer) { } public void serialize(DataOutputStream stream) 
throws IOException { - ReadWriteIOUtils.write(targetAINode.ip, stream); - ReadWriteIOUtils.write(targetAINode.port, stream); modelInformation.serialize(stream); if (outputColumnNames == null) { ReadWriteIOUtils.write(0, stream); @@ -154,12 +120,6 @@ public void serialize(DataOutputStream stream) throws IOException { ReadWriteIOUtils.write(outputColumnName, stream); } } - if (inferenceWindowParameter == null) { - ReadWriteIOUtils.write(false, stream); - } else { - ReadWriteIOUtils.write(true, stream); - inferenceWindowParameter.serialize(stream); - } if (inferenceAttributes == null) { ReadWriteIOUtils.write(0, stream); } else { @@ -184,20 +144,13 @@ public boolean equals(Object o) { return false; } ModelInferenceDescriptor that = (ModelInferenceDescriptor) o; - return targetAINode.equals(that.targetAINode) - && modelInformation.equals(that.modelInformation) + return modelInformation.equals(that.modelInformation) && outputColumnNames.equals(that.outputColumnNames) - && inferenceWindowParameter.equals(that.inferenceWindowParameter) && inferenceAttributes.equals(that.inferenceAttributes); } @Override public int hashCode() { - return Objects.hash( - targetAINode, - modelInformation, - outputColumnNames, - inferenceWindowParameter, - inferenceAttributes); + return Objects.hash(modelInformation, outputColumnNames, inferenceAttributes); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java index 73c57157f722..e13c52ba8b16 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java @@ -40,7 +40,6 @@ import org.apache.iotdb.db.queryengine.plan.relational.analyzer.tablefunction.TableArgumentAnalysis; import org.apache.iotdb.db.queryengine.plan.relational.analyzer.tablefunction.TableFunctionInvocationAnalysis; import org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction; -import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction; import org.apache.iotdb.db.queryengine.plan.relational.metadata.ColumnSchema; import org.apache.iotdb.db.queryengine.plan.relational.metadata.Metadata; import org.apache.iotdb.db.queryengine.plan.relational.metadata.QualifiedObjectName; @@ -4684,11 +4683,6 @@ public Scope visitTableFunctionInvocation(TableFunctionInvocation node, Optional String functionName = node.getName().toString(); TableFunction function = metadata.getTableFunction(functionName); - // set model fetcher for ForecastTableFunction - if (function instanceof ForecastTableFunction) { - ((ForecastTableFunction) function).setModelFetcher(metadata.getModelFetcher()); - } - Node errorLocation = node; if (!node.getArguments().isEmpty()) { errorLocation = node.getArguments().get(0); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java index 887d7c26d305..08f7ec6c8335 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java +++ 
b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/function/tvf/ForecastTableFunction.java @@ -19,14 +19,14 @@ package org.apache.iotdb.db.queryengine.plan.relational.function.tvf; +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.IClientManager; import org.apache.iotdb.commons.exception.IoTDBRuntimeException; import org.apache.iotdb.db.exception.sql.SemanticException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.rpc.TSStatusCode; import org.apache.iotdb.udf.api.relational.TableFunction; import org.apache.iotdb.udf.api.relational.access.Record; @@ -74,8 +74,9 @@ public class ForecastTableFunction implements TableFunction { + private static final TsBlockSerde SERDE = new TsBlockSerde(); + public static class ForecastTableFunctionHandle implements TableFunctionHandle { - TEndPoint targetAINode; String modelId; int maxInputLength; int outputLength; @@ -95,7 +96,6 @@ public ForecastTableFunctionHandle( int outputLength, long outputStartTime, long outputInterval, - TEndPoint targetAINode, List types) { this.keepInput = keepInput; this.maxInputLength = maxInputLength; @@ -104,7 +104,6 @@ public ForecastTableFunctionHandle( this.outputLength = outputLength; this.outputStartTime = outputStartTime; this.outputInterval = outputInterval; - this.targetAINode = targetAINode; this.types = types; } @@ -112,8 +111,6 @@ public ForecastTableFunctionHandle( public byte[] serialize() { try (PublicBAOS publicBAOS = new PublicBAOS(); DataOutputStream outputStream = new DataOutputStream(publicBAOS)) { - ReadWriteIOUtils.write(targetAINode.getIp(), outputStream); - ReadWriteIOUtils.write(targetAINode.getPort(), outputStream); ReadWriteIOUtils.write(modelId, outputStream); ReadWriteIOUtils.write(maxInputLength, outputStream); ReadWriteIOUtils.write(outputLength, outputStream); @@ -138,8 +135,6 @@ public byte[] serialize() { @Override public void deserialize(byte[] bytes) { ByteBuffer buffer = ByteBuffer.wrap(bytes); - this.targetAINode = - new TEndPoint(ReadWriteIOUtils.readString(buffer), ReadWriteIOUtils.readInt(buffer)); this.modelId = ReadWriteIOUtils.readString(buffer); this.maxInputLength = ReadWriteIOUtils.readInt(buffer); this.outputLength = ReadWriteIOUtils.readInt(buffer); @@ -168,7 +163,6 @@ public boolean equals(Object o) { && outputStartTime == that.outputStartTime && outputInterval == that.outputInterval && keepInput == that.keepInput - && Objects.equals(targetAINode, that.targetAINode) && Objects.equals(modelId, that.modelId) && Objects.equals(options, that.options) && Objects.equals(types, that.types); @@ -177,7 +171,6 @@ public boolean equals(Object o) { @Override public int hashCode() { return Objects.hash( - targetAINode, modelId, maxInputLength, outputLength, @@ -284,8 +277,6 @@ public TableFunctionAnalysis analyze(Map arguments) { String.format("%s should never be null or empty", MODEL_ID_PARAMETER_NAME)); } - TEndPoint targetAINode = getModelInfo(modelId).getTargetAINode(); - int outputLength = (int) 
((ScalarArgument) arguments.get(OUTPUT_LENGTH_PARAMETER_NAME)).getValue(); if (outputLength <= 0) { @@ -390,7 +381,6 @@ public TableFunctionAnalysis analyze(Map arguments) { outputLength, outputStartTime, outputInterval, - targetAINode, predicatedColumnTypes); // outputColumnSchema @@ -417,10 +407,6 @@ public TableFunctionDataProcessor getDataProcessor() { }; } - private ModelInferenceDescriptor getModelInfo(String modelId) { - return modelFetcher.fetchModel(modelId); - } - // only allow for INT32, INT64, FLOAT, DOUBLE private void checkType(Type type, String columnName) { if (!ALLOWED_INPUT_TYPES.contains(type)) { @@ -456,9 +442,9 @@ private static Map parseOptions(String options) { private static class ForecastDataProcessor implements TableFunctionDataProcessor { private static final TsBlockSerde SERDE = new TsBlockSerde(); - private static final AINodeClientManager CLIENT_MANAGER = AINodeClientManager.getInstance(); + private static final IClientManager CLIENT_MANAGER = + AINodeClientManager.getInstance(); - private final TEndPoint targetAINode; private final String modelId; private final int maxInputLength; private final int outputLength; @@ -471,7 +457,6 @@ private static class ForecastDataProcessor implements TableFunctionDataProcessor private final TsBlockBuilder inputTsBlockBuilder; public ForecastDataProcessor(ForecastTableFunctionHandle functionHandle) { - this.targetAINode = functionHandle.targetAINode; this.modelId = functionHandle.modelId; this.maxInputLength = functionHandle.maxInputLength; this.outputLength = functionHandle.outputLength; @@ -619,8 +604,12 @@ private TsBlock forecast() { TsBlock inputData = inputTsBlockBuilder.build(); TForecastResp resp; - try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) { - resp = client.forecast(modelId, inputData, outputLength, options); + try (AINodeClient client = + CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + resp = + client.forecast( + new TForecastReq(modelId, SERDE.serialize(inputData), outputLength) + .setOptions(options)); } catch (Exception e) { throw new IoTDBRuntimeException(e.getMessage(), CAN_NOT_CONNECT_AINODE.getStatusCode()); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java index f0c041ad8053..db706d4980cb 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/Metadata.java @@ -28,7 +28,6 @@ import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.SessionInfo; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType; import org.apache.iotdb.db.queryengine.plan.relational.metadata.fetcher.TableHeaderSchemaValidator; @@ -211,9 +210,4 @@ DataPartition getDataPartitionWithUnclosedTimeRange( final String database, final List sgNameToQueryParamsMap); TableFunction getTableFunction(final String functionName); - - /** - * @return ModelFetcher - */ - IModelFetcher getModelFetcher(); } diff --git 
a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java index 9260ec8de032..7786876e19ad 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java @@ -1498,11 +1498,6 @@ public TableFunction getTableFunction(String functionName) { } } - @Override - public IModelFetcher getModelFetcher() { - return modelFetcher; - } - public static boolean isTwoNumericType(List argumentTypes) { return argumentTypes.size() == 2 && isNumericType(argumentTypes.get(0)) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java index 4676559bd7b1..f8cf497546e6 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java @@ -96,7 +96,6 @@ public List getDataNodeLocations(final String tableName) { case InformationSchema.TOPICS: case InformationSchema.SUBSCRIPTIONS: case InformationSchema.VIEWS: - case InformationSchema.MODELS: case InformationSchema.FUNCTIONS: case InformationSchema.CONFIGURATIONS: case InformationSchema.KEYWORDS: diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java index 260410954d4c..09f00f8ed672 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/udf/UDTFForecast.java @@ -19,14 +19,12 @@ package org.apache.iotdb.db.queryengine.plan.udf; +import org.apache.iotdb.ainode.rpc.thrift.TForecastReq; import org.apache.iotdb.ainode.rpc.thrift.TForecastResp; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.IClientManager; import org.apache.iotdb.commons.exception.IoTDBRuntimeException; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClient; -import org.apache.iotdb.db.protocol.client.ainode.AINodeClientManager; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; -import org.apache.iotdb.db.queryengine.plan.analyze.ModelFetcher; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; +import org.apache.iotdb.db.protocol.client.an.AINodeClient; +import org.apache.iotdb.db.protocol.client.an.AINodeClientManager; import org.apache.iotdb.rpc.TSStatusCode; import org.apache.iotdb.udf.api.UDTF; import org.apache.iotdb.udf.api.access.Row; @@ -54,8 +52,8 @@ public class UDTFForecast implements UDTF { private static final TsBlockSerde serde = new TsBlockSerde(); - private static final AINodeClientManager CLIENT_MANAGER = AINodeClientManager.getInstance(); - private TEndPoint targetAINode = new TEndPoint("127.0.0.1", 10810); + private static final IClientManager CLIENT_MANAGER = + 
AINodeClientManager.getInstance(); private String model_id; private int maxInputLength; private int outputLength; @@ -66,7 +64,6 @@ public class UDTFForecast implements UDTF { List types; private LinkedList inputRows; private TsBlockBuilder inputTsBlockBuilder; - private final IModelFetcher modelFetcher = ModelFetcher.getInstance(); private static final Set ALLOWED_INPUT_TYPES = new HashSet<>(); @@ -112,8 +109,6 @@ public void beforeStart(UDFParameters parameters, UDTFConfigurations configurati throw new IllegalArgumentException( "MODEL_ID parameter must be provided and cannot be empty."); } - ModelInferenceDescriptor descriptor = modelFetcher.fetchModel(this.model_id); - this.targetAINode = descriptor.getTargetAINode(); this.outputInterval = parameters.getLongOrDefault(OUTPUT_INTERVAL, DEFAULT_OUTPUT_INTERVAL); this.outputLength = @@ -211,8 +206,12 @@ private TsBlock forecast() throws Exception { TsBlock inputData = inputTsBlockBuilder.build(); TForecastResp resp; - try (AINodeClient client = CLIENT_MANAGER.borrowClient(targetAINode)) { - resp = client.forecast(model_id, inputData, outputLength, options); + try (AINodeClient client = + CLIENT_MANAGER.borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { + resp = + client.forecast( + new TForecastReq(model_id, serde.serialize(inputData), outputLength) + .setOptions(options)); } catch (Exception e) { throw new IoTDBRuntimeException( e.getMessage(), TSStatusCode.CAN_NOT_CONNECT_AINODE.getStatusCode()); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java index e60a14b727ca..79c031560973 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TSBSMetadata.java @@ -28,7 +28,6 @@ import org.apache.iotdb.commons.schema.table.column.TsTableColumnCategory; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.SessionInfo; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType; import org.apache.iotdb.db.queryengine.plan.relational.metadata.AlignedDeviceEntry; @@ -402,11 +401,6 @@ public TableFunction getTableFunction(String functionName) { return null; } - @Override - public IModelFetcher getModelFetcher() { - return null; - } - private static final DataPartition DATA_PARTITION = MockTSBSDataPartition.constructDataPartition(); diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java index e56b48936b96..7bbfe150ade4 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TableFunctionTest.java @@ -19,7 +19,6 @@ package org.apache.iotdb.db.queryengine.plan.relational.analyzer; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; import org.apache.iotdb.db.exception.sql.SemanticException; import 
org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan; import org.apache.iotdb.db.queryengine.plan.relational.function.tvf.ForecastTableFunction; @@ -378,7 +377,6 @@ public void testForecastFunction() { 96, DEFAULT_OUTPUT_START_TIME, DEFAULT_OUTPUT_INTERVAL, - new TEndPoint("127.0.0.1", 10810), Collections.singletonList(DOUBLE))); // Verify full LogicalPlan // Output - TableFunctionProcessor - TableScan @@ -439,7 +437,6 @@ public void testForecastFunctionWithNoLowerCase() { 96, DEFAULT_OUTPUT_START_TIME, DEFAULT_OUTPUT_INTERVAL, - new TEndPoint("127.0.0.1", 10810), Collections.singletonList(DOUBLE))); // Verify full LogicalPlan // Output - TableFunctionProcessor - TableScan diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java index aa9fcdfd1b51..4b1d18944b73 100644 --- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/TestMetadata.java @@ -19,8 +19,6 @@ package org.apache.iotdb.db.queryengine.plan.relational.analyzer; -import org.apache.iotdb.common.rpc.thrift.TEndPoint; -import org.apache.iotdb.commons.model.ModelInformation; import org.apache.iotdb.commons.partition.DataPartition; import org.apache.iotdb.commons.partition.DataPartitionQueryParam; import org.apache.iotdb.commons.partition.SchemaNodeManagementPartition; @@ -32,12 +30,10 @@ import org.apache.iotdb.db.exception.sql.SemanticException; import org.apache.iotdb.db.queryengine.common.MPPQueryContext; import org.apache.iotdb.db.queryengine.common.SessionInfo; -import org.apache.iotdb.db.queryengine.plan.analyze.IModelFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.function.Exclude; import org.apache.iotdb.db.queryengine.plan.function.Repeat; import org.apache.iotdb.db.queryengine.plan.function.Split; -import org.apache.iotdb.db.queryengine.plan.planner.plan.parameter.model.ModelInferenceDescriptor; import org.apache.iotdb.db.queryengine.plan.relational.function.OperatorType; import org.apache.iotdb.db.queryengine.plan.relational.function.TableBuiltinTableFunction; import org.apache.iotdb.db.queryengine.plan.relational.function.arithmetic.SubtractionResolver; @@ -560,21 +556,6 @@ public TableFunction getTableFunction(String functionName) { } } - @Override - public IModelFetcher getModelFetcher() { - String modelId = "timer_xl"; - IModelFetcher fetcher = Mockito.mock(IModelFetcher.class); - ModelInferenceDescriptor descriptor = Mockito.mock(ModelInferenceDescriptor.class); - Mockito.when(descriptor.getTargetAINode()).thenReturn(new TEndPoint("127.0.0.1", 10810)); - ModelInformation modelInformation = Mockito.mock(ModelInformation.class); - Mockito.when(modelInformation.available()).thenReturn(true); - Mockito.when(modelInformation.getInputShape()).thenReturn(new int[] {1440, 96}); - Mockito.when(descriptor.getModelInformation()).thenReturn(modelInformation); - Mockito.when(descriptor.getModelName()).thenReturn(modelId); - Mockito.when(fetcher.fetchModel(modelId)).thenReturn(descriptor); - return fetcher; - } - private static final DataPartition TABLE_DATA_PARTITION = MockTableModelDataPartition.constructDataPartition(DB1); diff --git a/iotdb-core/node-commons/pom.xml b/iotdb-core/node-commons/pom.xml index 
85ff69ee8ac7..e7c508c195d5 100644 --- a/iotdb-core/node-commons/pom.xml +++ b/iotdb-core/node-commons/pom.xml @@ -65,6 +65,11 @@ iotdb-thrift-confignode 2.0.6-SNAPSHOT + + org.apache.iotdb + iotdb-thrift-ainode + 2.0.6-SNAPSHOT + org.apache.iotdb iotdb-thrift diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java index 106d67b6279d..115f322348c0 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/ClientPoolFactory.java @@ -20,6 +20,7 @@ package org.apache.iotdb.commons.client; import org.apache.iotdb.common.rpc.thrift.TEndPoint; +import org.apache.iotdb.commons.client.async.AsyncAINodeInternalServiceClient; import org.apache.iotdb.commons.client.async.AsyncConfigNodeInternalServiceClient; import org.apache.iotdb.commons.client.async.AsyncDataNodeExternalServiceClient; import org.apache.iotdb.commons.client.async.AsyncDataNodeInternalServiceClient; @@ -390,4 +391,31 @@ public GenericKeyedObjectPool create return clientPool; } } + + public static class AsyncAINodeHeartbeatServiceClientPoolFactory + implements IClientPoolFactory { + + @Override + public GenericKeyedObjectPool createClientPool( + ClientManager manager) { + GenericKeyedObjectPool clientPool = + new GenericKeyedObjectPool<>( + new AsyncAINodeInternalServiceClient.Factory( + manager, + new ThriftClientProperty.Builder() + .setConnectionTimeoutMs(conf.getCnConnectionTimeoutInMS()) + .setRpcThriftCompressionEnabled(conf.isRpcThriftCompressionEnabled()) + .setSelectorNumOfAsyncClientManager(conf.getSelectorNumOfClientManager()) + .setPrintLogWhenEncounterException(false) + .build(), + ThreadName.ASYNC_DATANODE_HEARTBEAT_CLIENT_POOL.getName()), + new ClientPoolProperty.Builder() + .setMaxClientNumForEachNode(conf.getMaxClientNumForEachNode()) + .build() + .getConfig()); + ClientManagerMetrics.getInstance() + .registerClientManager(this.getClass().getSimpleName(), clientPool); + return clientPool; + } + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java similarity index 83% rename from iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java rename to iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java index 26130287697c..8cbd55759633 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/ainode/AsyncAINodeServiceClient.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/client/async/AsyncAINodeInternalServiceClient.java @@ -17,7 +17,7 @@ * under the License. 
*/ -package org.apache.iotdb.db.protocol.client.ainode; +package org.apache.iotdb.commons.client.async; import org.apache.iotdb.ainode.rpc.thrift.IAINodeRPCService; import org.apache.iotdb.common.rpc.thrift.TEndPoint; @@ -35,20 +35,20 @@ import java.io.IOException; -public class AsyncAINodeServiceClient extends IAINodeRPCService.AsyncClient +public class AsyncAINodeInternalServiceClient extends IAINodeRPCService.AsyncClient implements ThriftClient { private static final CommonConfig commonConfig = CommonDescriptor.getInstance().getConfig(); - private final boolean printLogWhenEncounterException; private final TEndPoint endPoint; - private final ClientManager clientManager; + private final boolean printLogWhenEncounterException; + private final ClientManager clientManager; - public AsyncAINodeServiceClient( + public AsyncAINodeInternalServiceClient( ThriftClientProperty property, TEndPoint endPoint, TAsyncClientManager tClientManager, - ClientManager clientManager) + ClientManager clientManager) throws IOException { super( property.getProtocolFactory(), @@ -122,10 +122,10 @@ public boolean isReady() { } public static class Factory - extends AsyncThriftClientFactory { + extends AsyncThriftClientFactory { public Factory( - ClientManager clientManager, + ClientManager clientManager, ThriftClientProperty thriftClientProperty, String threadName) { super(clientManager, thriftClientProperty, threadName); @@ -133,14 +133,15 @@ public Factory( @Override public void destroyObject( - TEndPoint endPoint, PooledObject pooledObject) { + TEndPoint endPoint, PooledObject pooledObject) { pooledObject.getObject().close(); } @Override - public PooledObject makeObject(TEndPoint endPoint) throws Exception { + public PooledObject makeObject(TEndPoint endPoint) + throws Exception { return new DefaultPooledObject<>( - new AsyncAINodeServiceClient( + new AsyncAINodeInternalServiceClient( thriftClientProperty, endPoint, tManagers[clientCnt.incrementAndGet() % tManagers.length], @@ -149,7 +150,7 @@ public PooledObject makeObject(TEndPoint endPoint) thr @Override public boolean validateObject( - TEndPoint endPoint, PooledObject pooledObject) { + TEndPoint endPoint, PooledObject pooledObject) { return pooledObject.getObject().isReady(); } } diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java index 3fa107685438..01968833db7f 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelInformation.java @@ -32,9 +32,12 @@ public class ModelInformation { + private static final int[] DEFAULT_MODEL_INPUT_SHAPE = new int[] {2880, 1}; + private static final int[] DEFAULT_MODEL_OUTPUT_SHAPE = new int[] {720, 1}; + ModelType modelType; - private final String modelName; + private final String modelId; private final int[] inputShape; @@ -48,9 +51,17 @@ public class ModelInformation { String attribute = ""; + public ModelInformation(String modelId) { + this.modelId = modelId; + this.inputShape = DEFAULT_MODEL_INPUT_SHAPE; + this.inputDataType = new TSDataType[] {TSDataType.DOUBLE}; + this.outputShape = DEFAULT_MODEL_OUTPUT_SHAPE; + this.outputDataType = new TSDataType[] {TSDataType.DOUBLE}; + } + public ModelInformation( ModelType modelType, - String modelName, + String modelId, int[] inputShape, int[] outputShape, TSDataType[] inputDataType, @@ -58,7 +69,7 @@ public 
ModelInformation( String attribute, ModelStatus status) { this.modelType = modelType; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = inputShape; this.outputShape = outputShape; this.inputDataType = inputDataType; @@ -68,14 +79,14 @@ public ModelInformation( } public ModelInformation( - String modelName, + String modelId, int[] inputShape, int[] outputShape, TSDataType[] inputDataType, TSDataType[] outputDataType, String attribute) { this.modelType = ModelType.USER_DEFINED; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = inputShape; this.outputShape = outputShape; this.inputDataType = inputDataType; @@ -83,9 +94,9 @@ public ModelInformation( this.attribute = attribute; } - public ModelInformation(String modelName, ModelStatus status) { + public ModelInformation(String modelId, ModelStatus status) { this.modelType = ModelType.BUILT_IN_FORECAST; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = new int[0]; this.outputShape = new int[0]; this.outputDataType = new TSDataType[0]; @@ -94,9 +105,9 @@ public ModelInformation(String modelName, ModelStatus status) { } // init built-in modelInformation - public ModelInformation(ModelType modelType, String modelName) { + public ModelInformation(ModelType modelType, String modelId) { this.modelType = modelType; - this.modelName = modelName; + this.modelId = modelId; this.inputShape = new int[2]; this.outputShape = new int[2]; this.inputDataType = new TSDataType[0]; @@ -116,8 +127,8 @@ public void updateStatus(ModelStatus status) { this.status = status; } - public String getModelName() { - return modelName; + public String getModelId() { + return modelId; } public void setInputLength(int length) { @@ -197,7 +208,7 @@ public void setAttribute(String attribute) { public void serialize(DataOutputStream stream) throws IOException { ReadWriteIOUtils.write(modelType.ordinal(), stream); ReadWriteIOUtils.write(status.ordinal(), stream); - ReadWriteIOUtils.write(modelName, stream); + ReadWriteIOUtils.write(modelId, stream); if (status == ModelStatus.UNAVAILABLE) { return; } @@ -222,7 +233,7 @@ public void serialize(DataOutputStream stream) throws IOException { public void serialize(FileOutputStream stream) throws IOException { ReadWriteIOUtils.write(modelType.ordinal(), stream); ReadWriteIOUtils.write(status.ordinal(), stream); - ReadWriteIOUtils.write(modelName, stream); + ReadWriteIOUtils.write(modelId, stream); if (status == ModelStatus.UNAVAILABLE) { return; } @@ -247,7 +258,7 @@ public void serialize(FileOutputStream stream) throws IOException { public void serialize(ByteBuffer byteBuffer) { ReadWriteIOUtils.write(modelType.ordinal(), byteBuffer); ReadWriteIOUtils.write(status.ordinal(), byteBuffer); - ReadWriteIOUtils.write(modelName, byteBuffer); + ReadWriteIOUtils.write(modelId, byteBuffer); if (status == ModelStatus.UNAVAILABLE) { return; } @@ -353,7 +364,7 @@ public static ModelInformation deserialize(InputStream stream) throws IOExceptio public ByteBuffer serializeShowModelResult() throws IOException { PublicBAOS buffer = new PublicBAOS(); DataOutputStream stream = new DataOutputStream(buffer); - ReadWriteIOUtils.write(modelName, stream); + ReadWriteIOUtils.write(modelId, stream); ReadWriteIOUtils.write(modelType.toString(), stream); ReadWriteIOUtils.write(status.toString(), stream); ReadWriteIOUtils.write(Arrays.toString(inputShape), stream); @@ -370,7 +381,7 @@ public ByteBuffer serializeShowModelResult() throws IOException { public boolean equals(Object obj) { if (obj 
instanceof ModelInformation) { ModelInformation other = (ModelInformation) obj; - return modelName.equals(other.modelName) + return modelId.equals(other.modelId) && modelType.equals(other.modelType) && Arrays.equals(inputShape, other.inputShape) && Arrays.equals(outputShape, other.outputShape) diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java index 64aff12f284e..6c6100086316 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/model/ModelTable.java @@ -42,7 +42,7 @@ public boolean containsModel(String modelId) { } public void addModel(ModelInformation modelInformation) { - modelInfoMap.put(modelInformation.getModelName(), modelInformation); + modelInfoMap.put(modelInformation.getModelId(), modelInformation); } public void removeModel(String modelId) { @@ -63,7 +63,7 @@ public ModelInformation getModelInformationById(String modelId) { public void clearFailedModel() { for (ModelInformation modelInformation : modelInfoMap.values()) { if (modelInformation.getStatus() == ModelStatus.UNAVAILABLE) { - modelInfoMap.remove(modelInformation.getModelName()); + modelInfoMap.remove(modelInformation.getModelId()); } } } diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java index 2db41cc3c2dd..243bc41c40ce 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java @@ -43,7 +43,6 @@ public class InformationSchema { public static final String TOPICS = "topics"; public static final String SUBSCRIPTIONS = "subscriptions"; public static final String VIEWS = "views"; - public static final String MODELS = "models"; public static final String FUNCTIONS = "functions"; public static final String CONFIGURATIONS = "configurations"; public static final String KEYWORDS = "keywords"; @@ -256,23 +255,6 @@ public class InformationSchema { viewTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME); schemaTables.put(VIEWS, viewTable); - final TsTable modelTable = new TsTable(MODELS); - modelTable.addColumnSchema( - new TagColumnSchema(ColumnHeaderConstant.MODEL_ID_TABLE_MODEL, TSDataType.STRING)); - modelTable.addColumnSchema( - new AttributeColumnSchema(ColumnHeaderConstant.MODEL_TYPE_TABLE_MODEL, TSDataType.STRING)); - modelTable.addColumnSchema( - new AttributeColumnSchema( - ColumnHeaderConstant.STATE.toLowerCase(Locale.ENGLISH), TSDataType.STRING)); - modelTable.addColumnSchema( - new AttributeColumnSchema( - ColumnHeaderConstant.CONFIGS.toLowerCase(Locale.ENGLISH), TSDataType.STRING)); - modelTable.addColumnSchema( - new AttributeColumnSchema( - ColumnHeaderConstant.NOTES.toLowerCase(Locale.ENGLISH), TSDataType.STRING)); - modelTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME); - schemaTables.put(MODELS, modelTable); - final TsTable functionTable = new TsTable(FUNCTIONS); functionTable.addColumnSchema( new TagColumnSchema(ColumnHeaderConstant.FUNCTION_NAME_TABLE_MODEL, TSDataType.STRING)); diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift index e0680cd29b79..1dc2f025f5c3 100644 --- 
a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift +++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift @@ -60,13 +60,7 @@ struct TRegisterModelResp { struct TInferenceReq { 1: required string modelId 2: required binary dataset - 3: optional TWindowParams windowParams - 4: optional map inferenceAttributes -} - -struct TWindowParams { - 1: required i32 windowInterval - 2: required i32 windowStep + 3: optional map inferenceAttributes } struct TInferenceResp { diff --git a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift index d8f6318063eb..f2b8ec6b8b07 100644 --- a/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift +++ b/iotdb-protocol/thrift-confignode/src/main/thrift/confignode.thrift @@ -1096,34 +1096,6 @@ struct TUnsetSchemaTemplateReq { 4: optional bool isGeneratedByPipe } -struct TCreateModelReq { - 1: required string modelName - 2: required string uri -} - -struct TDropModelReq { - 1: required string modelId -} - -struct TGetModelInfoReq { - 1: required string modelId -} - -struct TGetModelInfoResp { - 1: required common.TSStatus status - 2: optional binary modelInfo - 3: optional common.TEndPoint aiNodeAddress -} - -struct TUpdateModelInfoReq { - 1: required string modelId - 2: required i32 modelStatus - 3: optional string attributes - 4: optional list aiNodeIds - 5: optional i32 inputLength - 6: optional i32 outputLength -} - struct TDataSchemaForTable{ 1: required string targetSql } @@ -1132,16 +1104,6 @@ struct TDataSchemaForTree{ 1: required list path } -struct TCreateTrainingReq { - 1: required string modelId - 2: required bool isTableModel - 3: required string existingModelId - 4: optional TDataSchemaForTable dataSchemaForTable - 5: optional TDataSchemaForTree dataSchemaForTree - 6: optional map parameters - 7: optional list> timeRanges -} - // ==================================================== // Quota // ==================================================== @@ -2006,31 +1968,6 @@ service IConfigNodeRPCService { */ TShowCQResp showCQ() - // ==================================================== - // AI Model - // ==================================================== - - /** - * Create a model - * - * @return SUCCESS_STATUS if the model was created successfully - */ - common.TSStatus createModel(TCreateModelReq req) - - /** - * Drop a model - * - * @return SUCCESS_STATUS if the model was removed successfully - */ - common.TSStatus dropModel(TDropModelReq req) - - /** - * Return the model info by model_id - */ - TGetModelInfoResp getModelInfo(TGetModelInfoReq req) - - common.TSStatus updateModelInfo(TUpdateModelInfoReq req) - // ====================================================== // Quota // ====================================================== From 9397c134b3dc2f3aae930c9037379d195e908f47 Mon Sep 17 00:00:00 2001 From: Zhenyu Luo Date: Wed, 10 Dec 2025 11:39:49 +0800 Subject: [PATCH 08/13] Active Load: Add cleanup for active load listening directories on DataNode first startup (#16854) * Add cleanup for active load listening directories on DataNode first startup - Add cleanupListeningDirectories() method in ActiveLoadAgent to clean up all listening directories - Call cleanup method when DataNode starts for the first time - Clean up pending, pipe, and failed directories - Silent execution with minimal logging * update * fix (cherry picked from commit bfa71e00e763c62c3d4ef3f5b459d1a814d91ddb) --- .../org/apache/iotdb/db/service/DataNode.java | 3 + 
.../load/active/ActiveLoadAgent.java | 90 +++++++++++++++++++ 2 files changed, 93 insertions(+) diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java index c4b08f8f05a8..1fac05012fe8 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/service/DataNode.java @@ -111,6 +111,7 @@ import org.apache.iotdb.db.storageengine.dataregion.memtable.TsFileProcessor; import org.apache.iotdb.db.storageengine.dataregion.wal.WALManager; import org.apache.iotdb.db.storageengine.dataregion.wal.utils.WALMode; +import org.apache.iotdb.db.storageengine.load.active.ActiveLoadAgent; import org.apache.iotdb.db.storageengine.rescon.disk.TierManager; import org.apache.iotdb.db.subscription.agent.SubscriptionAgent; import org.apache.iotdb.db.trigger.executor.TriggerExecutor; @@ -245,6 +246,8 @@ protected void start() { sendRegisterRequestToConfigNode(false); saveSecretKey(); saveHardwareCode(); + // Clean up active load listening directories on first startup + ActiveLoadAgent.cleanupListeningDirectories(); } else { /* Check encrypt magic string */ try { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/load/active/ActiveLoadAgent.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/load/active/ActiveLoadAgent.java index f060bd9a96f3..6065c349c8c5 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/load/active/ActiveLoadAgent.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/storageengine/load/active/ActiveLoadAgent.java @@ -19,8 +19,21 @@ package org.apache.iotdb.db.storageengine.load.active; +import org.apache.iotdb.commons.utils.FileUtils; +import org.apache.iotdb.db.conf.IoTDBDescriptor; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.File; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + public class ActiveLoadAgent { + private static final Logger LOGGER = LoggerFactory.getLogger(ActiveLoadAgent.class); + private final ActiveLoadTsFileLoader activeLoadTsFileLoader; private final ActiveLoadDirScanner activeLoadDirScanner; private final ActiveLoadMetricsCollector activeLoadMetricsCollector; @@ -48,4 +61,81 @@ public synchronized void start() { activeLoadDirScanner.start(); activeLoadMetricsCollector.start(); } + + /** + * Clean up all listening directories for active load on DataNode first startup. This method will + * clean up all files and subdirectories in the listening directories, including: 1. Pending + * directories (configured by load_active_listening_dirs) 2. Pipe directory (for pipe data sync) + * 3. Failed directory (for failed files) + * + *
+   * <p>
This method is called during DataNode startup and must not throw any exceptions to ensure + * startup can proceed normally. All exceptions are caught and logged internally. + */ + public static void cleanupListeningDirectories() { + try { + final List dirsToClean = new ArrayList<>(); + + dirsToClean.addAll( + Arrays.asList(IoTDBDescriptor.getInstance().getConfig().getLoadActiveListeningDirs())); + + // Add pipe dir + dirsToClean.add(IoTDBDescriptor.getInstance().getConfig().getLoadActiveListeningPipeDir()); + + // Add failed dir + dirsToClean.add(IoTDBDescriptor.getInstance().getConfig().getLoadActiveListeningFailDir()); + + // Clean up each directory + for (final String dirPath : dirsToClean) { + try { + if (dirPath == null || dirPath.isEmpty()) { + continue; + } + + final File dir = new File(dirPath); + + // Check if directory exists and is a directory + // These methods may throw SecurityException if access is denied + try { + if (!dir.exists() || !dir.isDirectory()) { + continue; + } + } catch (Exception e) { + LOGGER.debug("Failed to check directory: {}", dirPath, e); + continue; + } + + // Only delete contents inside the directory, not the directory itself + // listFiles() may throw SecurityException if access is denied + File[] files = null; + try { + files = dir.listFiles(); + } catch (Exception e) { + LOGGER.warn("Failed to list files in directory: {}", dirPath, e); + continue; + } + + if (files != null) { + for (final File file : files) { + // FileUtils.deleteFileOrDirectory internally calls file.isDirectory() and + // file.listFiles() without try-catch, so exceptions may propagate here. + // We need to catch it to prevent one file failure from stopping the cleanup. + try { + FileUtils.deleteFileOrDirectory(file, true); + } catch (Exception e) { + LOGGER.debug("Failed to delete file or directory: {}", file.getAbsolutePath(), e); + } + } + } + } catch (Exception e) { + LOGGER.warn("Failed to cleanup directory: {}", dirPath, e); + } + } + + LOGGER.info("Cleaned up active load listening directories"); + } catch (Throwable t) { + // Catch all exceptions and errors (including OutOfMemoryError, etc.) 
+ // to ensure startup process is not affected + LOGGER.warn("Unexpected error during cleanup of active load listening directories", t); + } + } } From f74eb2671a248a62e6ba5ce486aff8435d1e2f0a Mon Sep 17 00:00:00 2001 From: libo Date: Wed, 10 Dec 2025 12:22:24 +0800 Subject: [PATCH 09/13] Remove the code check port is occupied and resolve the problem that can't rename file successfully (#16889) (cherry picked from commit 4a481f02bf8856b6588f371bbfed9b1359b9ce55) --- .../commons/file/SystemPropertiesHandler.java | 13 ++--- scripts/sbin/windows/start-confignode.bat | 28 ----------- scripts/sbin/windows/start-datanode.bat | 48 ------------------- 3 files changed, 3 insertions(+), 86 deletions(-) diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/file/SystemPropertiesHandler.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/file/SystemPropertiesHandler.java index dfbb2104cfb8..ab28952dbcd0 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/file/SystemPropertiesHandler.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/file/SystemPropertiesHandler.java @@ -21,9 +21,9 @@ import org.apache.iotdb.commons.conf.ConfigurationFileUtils; import org.apache.iotdb.commons.conf.IoTDBConstant; +import org.apache.iotdb.commons.utils.FileUtils; import org.apache.ratis.util.AutoCloseableLock; -import org.apache.ratis.util.FileUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -184,15 +184,8 @@ private void replaceFormalFile() throws IOException { "Delete formal system properties file fail: %s", formalFile.getAbsoluteFile()); throw new IOException(msg); } - try { - FileUtils.move(tmpFile.toPath(), formalFile.toPath()); - } catch (IOException e) { - String msg = - String.format( - "Failed to replace formal system properties file, you may manually rename it: %s -> %s", - tmpFile.getAbsolutePath(), formalFile.getAbsolutePath()); - throw new IOException(msg, e); - } + + FileUtils.moveFileSafe(tmpFile, formalFile); } public void resetFilePath(String filePath) { diff --git a/scripts/sbin/windows/start-confignode.bat b/scripts/sbin/windows/start-confignode.bat index 2501a0645c2a..64efa6f5580f 100644 --- a/scripts/sbin/windows/start-confignode.bat +++ b/scripts/sbin/windows/start-confignode.bat @@ -111,34 +111,6 @@ IF DEFINED CONFIG_FILE ( set cn_consensus_port=10720 ) -echo Check whether the ports are occupied.... -set occupied=0 -set cn_internal_port_occupied=0 -set cn_consensus_port_occupied=0 -for /f "tokens=1,3,7 delims=: " %%i in ('netstat /ano') do ( - if %%i==TCP ( - if %%j==%cn_internal_port% ( - if !cn_internal_port_occupied!==0 ( - echo The cn_internal_port %cn_internal_port% is already occupied, pid:%%k - set occupied=1 - set cn_internal_port_occupied=1 - ) - ) else if %%j==%cn_consensus_port% ( - if !cn_consensus_port_occupied!==0 ( - echo The cn_consensus_port %cn_consensus_port% is already occupied, pid:%%k - set occupied=1 - set cn_consensus_port_occupied=1 - ) - ) - ) -) - -if %occupied%==1 ( - echo There exists occupied port, please change the configuration. 
- TIMEOUT /T 10 /NOBREAK - exit 0 -) - set CONF_PARAMS=-s if NOT DEFINED MAIN_CLASS set MAIN_CLASS=org.apache.iotdb.confignode.service.ConfigNode if NOT DEFINED JAVA_HOME goto :err diff --git a/scripts/sbin/windows/start-datanode.bat b/scripts/sbin/windows/start-datanode.bat index 30a7aa50e836..0cf3e6487839 100755 --- a/scripts/sbin/windows/start-datanode.bat +++ b/scripts/sbin/windows/start-datanode.bat @@ -146,54 +146,6 @@ IF DEFINED CONFIG_FILE ( set dn_data_region_consensus_port=10760 ) -echo Check whether the ports are occupied.... -set occupied=0 -set dn_rpc_port_occupied=0 -set dn_internal_port_occupied=0 -set dn_mpp_data_exchange_port_occupied=0 -set dn_schema_region_consensus_port_occupied=0 -set dn_data_region_consensus_port_occupied=0 -for /f "tokens=1,3,7 delims=: " %%i in ('netstat /ano') do ( - if %%i==TCP ( - if %%j==%dn_rpc_port% ( - if !dn_rpc_port_occupied!==0 ( - echo The dn_rpc_port %dn_rpc_port% is already occupied, pid:%%k - set occupied=1 - set dn_rpc_port_occupied=1 - ) - ) else if %%j==%dn_internal_port% ( - if !dn_internal_port_occupied!==0 ( - echo The dn_internal_port %dn_internal_port% is already occupied, pid:%%k - set occupied=1 - set dn_internal_port_occupied=1 - ) - ) else if %%j==%dn_mpp_data_exchange_port% ( - if !dn_mpp_data_exchange_port_occupied!==0 ( - echo The dn_mpp_data_exchange_port %dn_mpp_data_exchange_port% is already occupied, pid:%%k - set occupied=1 - set dn_mpp_data_exchange_port_occupied=1 - ) - ) else if %%j==%dn_schema_region_consensus_port% ( - if !dn_schema_region_consensus_port_occupied!==0 ( - echo The dn_schema_region_consensus_port %dn_schema_region_consensus_port% is already occupied, pid:%%k - set occupied=1 - set dn_schema_region_consensus_port_occupied=1 - ) - ) else if %%j==%dn_data_region_consensus_port% ( - if !dn_data_region_consensus_port_occupied!==0 ( - echo The dn_data_region_consensus_port %dn_data_region_consensus_port% is already occupied, pid:%%k - set occupied=1 - ) - ) - ) -) - -if %occupied%==1 ( - echo There exists occupied port, please change the configuration. - TIMEOUT /T 10 /NOBREAK - exit 0 -) - @setlocal ENABLEDELAYEDEXPANSION ENABLEEXTENSIONS set CONF_PARAMS=-s if NOT DEFINED MAIN_CLASS set MAIN_CLASS=org.apache.iotdb.db.service.DataNode From 1a51743f80d582ecd88fd2953cb4c7bfa1001c72 Mon Sep 17 00:00:00 2001 From: libo Date: Wed, 10 Dec 2025 17:29:56 +0800 Subject: [PATCH 10/13] Remove the code check port is occupied, and resolve the problem that can't rename file successfully. 
(#16893) (cherry picked from commit 5d1efef31dab1ee7baab8dd5ba73b143a74c2709) --- .../org/apache/iotdb/session/it/IoTDBConnectionsIT.java | 6 +++--- .../iotdb/commons/schema/column/ColumnHeaderConstant.java | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBConnectionsIT.java b/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBConnectionsIT.java index bc043c79e8a6..80aa854e66cc 100644 --- a/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBConnectionsIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/session/it/IoTDBConnectionsIT.java @@ -176,7 +176,7 @@ public void testSameDataNodeGetConnections() { ResultSet resultSet = statement.executeQuery( - "SELECT * FROM connections WHERE data_node_id = '" + dataNodeId + "'"); + "SELECT * FROM connections WHERE datanode_id = '" + dataNodeId + "'"); if (!resultSet.next()) { fail(); } @@ -256,7 +256,7 @@ public void testClosedDataNodeGetConnections() throws Exception { ResultSet resultSet = statement.executeQuery( - "SELECT COUNT(*) FROM connections WHERE data_node_id = '" + closedDataNodeId + "'"); + "SELECT COUNT(*) FROM connections WHERE datanode_id = '" + closedDataNodeId + "'"); if (!resultSet.next()) { fail(); } @@ -306,7 +306,7 @@ public void testClosedDataNodeGetConnections() throws Exception { ResultSet resultSet = statement.executeQuery( - "SELECT COUNT(*) FROM connections WHERE data_node_id = '" + closedDataNodeId + "'"); + "SELECT COUNT(*) FROM connections WHERE datanode_id = '" + closedDataNodeId + "'"); if (!resultSet.next()) { fail(); } diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java index 943bdeb9cba4..67850d991cb6 100644 --- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java +++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java @@ -206,7 +206,7 @@ private ColumnHeaderConstant() { public static final String STATEMENT = "Statement"; // column names for show idle connection - public static final String DATANODE_ID = "data_node_id"; + public static final String DATANODE_ID = "datanode_id"; public static final String USERID = "user_id"; public static final String SESSION_ID = "session_id"; public static final String USER_NAME = "user_name"; From 97fa584ae1281449c53883e712437ddfca490598 Mon Sep 17 00:00:00 2001 From: Weihao Li <60659567+Wei-hao-Li@users.noreply.github.com> Date: Wed, 10 Dec 2025 19:36:32 +0800 Subject: [PATCH 11/13] Support system table current_queries and queries_costs_histogram (#16890) (cherry picked from commit 94461b065095dd8499c85b8691d89da35a878299) --- .../env/cluster/config/MppDataNodeConfig.java | 6 + .../remote/config/RemoteDataNodeConfig.java | 5 + .../iotdb/itbase/env/DataNodeConfig.java | 2 + .../IoTDBCurrentQueriesIT.java | 205 ++++++++++++ .../relational/it/schema/IoTDBDatabaseIT.java | 6 +- .../org/apache/iotdb/db/conf/IoTDBConfig.java | 11 + .../apache/iotdb/db/conf/IoTDBDescriptor.java | 5 + .../rest/v2/impl/RestApiServiceImpl.java | 16 +- .../thrift/impl/ClientRPCServiceImpl.java | 16 +- .../iotdb/db/queryengine/common/QueryId.java | 7 + ...formationSchemaContentSupplierFactory.java | 146 ++++++++- .../db/queryengine/plan/Coordinator.java | 301 +++++++++++++++++- .../plan/execution/IQueryExecution.java | 2 + 
.../plan/execution/QueryExecution.java | 5 + .../execution/config/ConfigExecution.java | 4 + .../DataNodeLocationSupplierFactory.java | 2 + .../operator/MergeTreeSortOperatorTest.java | 5 + .../plan/relational/planner/PlanTester.java | 2 + .../informationschema/CurrentQueriesTest.java | 107 +++++++ .../informationschema}/ShowQueriesTest.java | 8 +- .../conf/iotdb-system.properties.template | 6 + .../iotdb/commons/concurrent/ThreadName.java | 1 + .../schema/column/ColumnHeaderConstant.java | 10 +- .../schema/table/InformationSchema.java | 41 ++- 24 files changed, 895 insertions(+), 24 deletions(-) create mode 100644 integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBCurrentQueriesIT.java create mode 100644 iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/informationschema/CurrentQueriesTest.java rename iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/{analyzer => planner/informationschema}/ShowQueriesTest.java (94%) diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppDataNodeConfig.java b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppDataNodeConfig.java index 357c15c78569..5e418072a7d9 100644 --- a/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppDataNodeConfig.java +++ b/integration-test/src/main/java/org/apache/iotdb/it/env/cluster/config/MppDataNodeConfig.java @@ -137,4 +137,10 @@ public DataNodeConfig setDataNodeMemoryProportion(String dataNodeMemoryProportio setProperty("datanode_memory_proportion", dataNodeMemoryProportion); return this; } + + @Override + public DataNodeConfig setQueryCostStatWindow(int queryCostStatWindow) { + setProperty("query_cost_stat_window", String.valueOf(queryCostStatWindow)); + return this; + } } diff --git a/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteDataNodeConfig.java b/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteDataNodeConfig.java index 1af7cb8f613a..bba4c964f957 100644 --- a/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteDataNodeConfig.java +++ b/integration-test/src/main/java/org/apache/iotdb/it/env/remote/config/RemoteDataNodeConfig.java @@ -93,4 +93,9 @@ public DataNodeConfig setDeleteWalFilesPeriodInMs(long deleteWalFilesPeriodInMs) public DataNodeConfig setDataNodeMemoryProportion(String dataNodeMemoryProportion) { return this; } + + @Override + public DataNodeConfig setQueryCostStatWindow(int queryCostStatWindow) { + return this; + } } diff --git a/integration-test/src/main/java/org/apache/iotdb/itbase/env/DataNodeConfig.java b/integration-test/src/main/java/org/apache/iotdb/itbase/env/DataNodeConfig.java index 0ae46ffc70f2..d57015b13964 100644 --- a/integration-test/src/main/java/org/apache/iotdb/itbase/env/DataNodeConfig.java +++ b/integration-test/src/main/java/org/apache/iotdb/itbase/env/DataNodeConfig.java @@ -51,4 +51,6 @@ DataNodeConfig setLoadTsFileAnalyzeSchemaMemorySizeInBytes( DataNodeConfig setDeleteWalFilesPeriodInMs(long deleteWalFilesPeriodInMs); DataNodeConfig setDataNodeMemoryProportion(String dataNodeMemoryProportion); + + DataNodeConfig setQueryCostStatWindow(int queryCostStatWindow); } diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBCurrentQueriesIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBCurrentQueriesIT.java 
new file mode 100644 index 000000000000..66941ec784f8 --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/informationschema/IoTDBCurrentQueriesIT.java @@ -0,0 +1,205 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.relational.it.query.recent.informationschema; + +import org.apache.iotdb.commons.conf.CommonDescriptor; +import org.apache.iotdb.db.queryengine.execution.QueryState; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.TableLocalStandaloneIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.END_TIME_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.NUMS; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATEMENT_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATE_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.USER_TABLE_MODEL; +import static org.apache.iotdb.commons.schema.table.InformationSchema.getSchemaTables; +import static org.apache.iotdb.db.it.utils.TestUtils.createUser; +import static org.apache.iotdb.itbase.env.BaseEnv.TABLE_SQL_DIALECT; +import static org.junit.Assert.fail; + +@RunWith(IoTDBTestRunner.class) +@Category({TableLocalStandaloneIT.class}) +// This IT will run at least 60s, so we only run it in 1C1D +public class IoTDBCurrentQueriesIT { + private static final int CURRENT_QUERIES_COLUMN_NUM = + getSchemaTables().get("current_queries").getColumnNum(); + private static final int QUERIES_COSTS_HISTOGRAM_COLUMN_NUM = + getSchemaTables().get("queries_costs_histogram").getColumnNum(); + private static final String ADMIN_NAME = + CommonDescriptor.getInstance().getConfig().getDefaultAdminName(); + private static final String ADMIN_PWD = + CommonDescriptor.getInstance().getConfig().getAdminPassword(); + + @BeforeClass + public static void setUp() throws Exception { + EnvFactory.getEnv().getConfig().getDataNodeConfig().setQueryCostStatWindow(1); + EnvFactory.getEnv().initClusterEnvironment(); + createUser("test", "test123123456"); + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + @Test + public void 
testCurrentQueries() { + try { + Connection connection = + EnvFactory.getEnv().getConnection(ADMIN_NAME, ADMIN_PWD, BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement(); + statement.execute("USE information_schema"); + + // 1. query current_queries table + String sql = "SELECT * FROM current_queries"; + ResultSet resultSet = statement.executeQuery(sql); + ResultSetMetaData metaData = resultSet.getMetaData(); + Assert.assertEquals(CURRENT_QUERIES_COLUMN_NUM, metaData.getColumnCount()); + int rowNum = 0; + while (resultSet.next()) { + Assert.assertEquals(QueryState.RUNNING.name(), resultSet.getString(STATE_TABLE_MODEL)); + Assert.assertEquals(null, resultSet.getString(END_TIME_TABLE_MODEL)); + Assert.assertEquals(sql, resultSet.getString(STATEMENT_TABLE_MODEL)); + Assert.assertEquals(ADMIN_NAME, resultSet.getString(USER_TABLE_MODEL)); + rowNum++; + } + Assert.assertEquals(1, rowNum); + resultSet.close(); + + // 2. query queries_costs_histogram table + sql = "SELECT * FROM queries_costs_histogram"; + resultSet = statement.executeQuery(sql); + metaData = resultSet.getMetaData(); + Assert.assertEquals(QUERIES_COSTS_HISTOGRAM_COLUMN_NUM, metaData.getColumnCount()); + rowNum = 0; + int queriesCount = 0; + while (resultSet.next()) { + int nums = resultSet.getInt(NUMS); + if (nums > 0) { + queriesCount++; + } + rowNum++; + } + Assert.assertEquals(1, queriesCount); + Assert.assertEquals(61, rowNum); + + // 3. requery current_queries table + sql = "SELECT * FROM current_queries"; + resultSet = statement.executeQuery(sql); + metaData = resultSet.getMetaData(); + Assert.assertEquals(CURRENT_QUERIES_COLUMN_NUM, metaData.getColumnCount()); + rowNum = 0; + int finishedQueries = 0; + while (resultSet.next()) { + if (QueryState.FINISHED.name().equals(resultSet.getString(STATE_TABLE_MODEL))) { + finishedQueries++; + } + rowNum++; + } + // three rows in the result, 2 FINISHED and 1 RUNNING + Assert.assertEquals(3, rowNum); + Assert.assertEquals(2, finishedQueries); + resultSet.close(); + + // 4. test the expired QueryInfo was evicted + Thread.sleep(61_001); + resultSet = statement.executeQuery(sql); + rowNum = 0; + while (resultSet.next()) { + rowNum++; + } + // one row in the result, current query + Assert.assertEquals(1, rowNum); + resultSet.close(); + + sql = "SELECT * FROM queries_costs_histogram"; + resultSet = statement.executeQuery(sql); + queriesCount = 0; + while (resultSet.next()) { + int nums = resultSet.getInt(NUMS); + if (nums > 0) { + queriesCount++; + } + } + // the last current_queries table query was recorded, others are evicted + Assert.assertEquals(1, queriesCount); + } catch (Exception e) { + fail(e.getMessage()); + } + + // 5. test privilege + testPrivilege(); + } + + private void testPrivilege() { + // 1. 
test current_queries table + try (Connection connection = + EnvFactory.getEnv().getConnection("test", "test123123456", TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + String sql = "SELECT * FROM information_schema.current_queries"; + + // another user executes a query + try (Connection connection2 = + EnvFactory.getEnv().getConnection(ADMIN_NAME, ADMIN_PWD, BaseEnv.TABLE_SQL_DIALECT)) { + ResultSet resultSet = connection2.createStatement().executeQuery(sql); + resultSet.close(); + } catch (Exception e) { + fail(e.getMessage()); + } + + // current user query current_queries table + ResultSet resultSet = statement.executeQuery(sql); + int rowNum = 0; + while (resultSet.next()) { + rowNum++; + } + // only current query in the result + Assert.assertEquals(1, rowNum); + } catch (SQLException e) { + fail(e.getMessage()); + } + + // 2. test queries_costs_histogram table + try (Connection connection = + EnvFactory.getEnv().getConnection("test", "test123123456", TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + statement.executeQuery("SELECT * FROM information_schema.queries_costs_histogram"); + } catch (SQLException e) { + Assert.assertEquals( + "803: Access Denied: No permissions for this operation, please add privilege SYSTEM", + e.getMessage()); + } + } +} diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java index e7bab16ad1ff..4736d9b0521d 100644 --- a/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/schema/IoTDBDatabaseIT.java @@ -399,6 +399,7 @@ public void testInformationSchema() throws SQLException { "config_nodes,INF,", "configurations,INF,", "connections,INF,", + "current_queries,INF,", "data_nodes,INF,", "databases,INF,", "functions,INF,", @@ -407,6 +408,7 @@ public void testInformationSchema() throws SQLException { "pipe_plugins,INF,", "pipes,INF,", "queries,INF,", + "queries_costs_histogram,INF,", "regions,INF,", "subscriptions,INF,", "tables,INF,", @@ -634,12 +636,14 @@ public void testInformationSchema() throws SQLException { "information_schema,config_nodes,INF,USING,null,SYSTEM VIEW,", "information_schema,data_nodes,INF,USING,null,SYSTEM VIEW,", "information_schema,connections,INF,USING,null,SYSTEM VIEW,", + "information_schema,current_queries,INF,USING,null,SYSTEM VIEW,", + "information_schema,queries_costs_histogram,INF,USING,null,SYSTEM VIEW,", "test,test,INF,USING,test,BASE TABLE,", "test,view_table,100,USING,null,VIEW FROM TREE,"))); TestUtils.assertResultSetEqual( statement.executeQuery("count devices from tables where status = 'USING'"), "count(devices),", - Collections.singleton("19,")); + Collections.singleton("21,")); TestUtils.assertResultSetEqual( statement.executeQuery( "select * from columns where table_name = 'queries' or database = 'test'"), diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java index 7591e9d3cabf..3f19a2101a80 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBConfig.java @@ -814,6 +814,9 @@ public class IoTDBConfig { /** time cost(ms) threshold for slow query. 
Unit: millisecond */ private long slowQueryThreshold = 10000; + /** time window threshold for record of history queries. Unit: minute */ + private int queryCostStatWindow = 0; + private int patternMatchingThreshold = 1000000; /** @@ -2628,6 +2631,14 @@ public void setSlowQueryThreshold(long slowQueryThreshold) { this.slowQueryThreshold = slowQueryThreshold; } + public int getQueryCostStatWindow() { + return queryCostStatWindow; + } + + public void setQueryCostStatWindow(int queryCostStatWindow) { + this.queryCostStatWindow = queryCostStatWindow; + } + public boolean isEnableIndex() { return enableIndex; } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java index bd579b24c3f4..e51f445c5678 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/conf/IoTDBDescriptor.java @@ -813,6 +813,11 @@ public void loadProperties(TrimProperties properties) throws BadNodeUrlException properties.getProperty( "slow_query_threshold", String.valueOf(conf.getSlowQueryThreshold())))); + conf.setQueryCostStatWindow( + Integer.parseInt( + properties.getProperty( + "query_cost_stat_window", String.valueOf(conf.getQueryCostStatWindow())))); + conf.setDataRegionNum( Integer.parseInt( properties.getProperty("data_region_num", String.valueOf(conf.getDataRegionNum())))); diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/rest/v2/impl/RestApiServiceImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/rest/v2/impl/RestApiServiceImpl.java index 66be144edefb..b657523987dc 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/rest/v2/impl/RestApiServiceImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/rest/v2/impl/RestApiServiceImpl.java @@ -212,7 +212,8 @@ public Response executeFastLastQueryStatement( t = e; return Response.ok().entity(ExceptionHandler.tryCatchException(e)).build(); } finally { - long costTime = System.nanoTime() - startTime; + long endTime = System.nanoTime(); + long costTime = endTime - startTime; StatementType statementType = Optional.ofNullable(statement) @@ -227,7 +228,18 @@ public Response executeFastLastQueryStatement( if (queryId != null) { COORDINATOR.cleanupQueryExecution(queryId); } else { - recordQueries(() -> costTime, new FastLastQueryContentSupplier(prefixPathList), t); + IClientSession clientSession = SESSION_MANAGER.getCurrSession(); + + Supplier contentOfQuerySupplier = new FastLastQueryContentSupplier(prefixPathList); + COORDINATOR.recordCurrentQueries( + null, + startTime / 1_000_000, + endTime / 1_000_000, + costTime, + contentOfQuerySupplier, + clientSession.getUsername(), + clientSession.getClientAddress()); + recordQueries(() -> costTime, contentOfQuerySupplier, t); } } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java index 07c2800799c8..167a1fa914fd 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/thrift/impl/ClientRPCServiceImpl.java @@ -1050,13 +1050,23 @@ public TSExecuteStatementResp executeFastLastDataQueryForOnePrefixPath( resp.setMoreData(false); - long costTime 
= System.nanoTime() - startTime; + long endTime = System.nanoTime(); + long costTime = endTime - startTime; CommonUtils.addStatementExecutionLatency( OperationType.EXECUTE_QUERY_STATEMENT, StatementType.FAST_LAST_QUERY.name(), costTime); CommonUtils.addQueryLatency(StatementType.FAST_LAST_QUERY, costTime); - recordQueries( - () -> costTime, () -> String.format("thrift fastLastQuery %s", prefixPath), null); + + String statement = String.format("thrift fastLastQuery %s", prefixPath); + COORDINATOR.recordCurrentQueries( + null, + startTime / 1_000_000, + endTime / 1_000_000, + costTime, + () -> statement, + clientSession.getUsername(), + clientSession.getClientAddress()); + recordQueries(() -> costTime, () -> statement, null); return resp; } catch (final Exception e) { return RpcUtils.getTSExecuteStatementResp( diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/QueryId.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/QueryId.java index a59ce8334bf5..44e67aa7ba60 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/QueryId.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/common/QueryId.java @@ -19,6 +19,7 @@ package org.apache.iotdb.db.queryengine.common; +import org.apache.iotdb.db.conf.IoTDBDescriptor; import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNodeId; import org.apache.tsfile.utils.ReadWriteIOUtils; @@ -37,6 +38,8 @@ public class QueryId { public static final QueryId MOCK_QUERY_ID = QueryId.valueOf("mock_query_id"); + private static final int DATANODE_ID = IoTDBDescriptor.getInstance().getConfig().getDataNodeId(); + private final String id; private int nextPlanNodeIndex; @@ -67,6 +70,10 @@ public String getId() { return id; } + public static int getDataNodeId() { + return DATANODE_ID; + } + @Override public String toString() { return id; diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java index daceffce6b7b..76d84c741de5 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/InformationSchemaContentSupplierFactory.java @@ -68,6 +68,7 @@ import org.apache.iotdb.db.protocol.session.IClientSession; import org.apache.iotdb.db.protocol.session.SessionManager; import org.apache.iotdb.db.queryengine.common.ConnectionInfo; +import org.apache.iotdb.db.queryengine.common.QueryId; import org.apache.iotdb.db.queryengine.plan.Coordinator; import org.apache.iotdb.db.queryengine.plan.execution.IQueryExecution; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.relational.ShowCreateViewTask; @@ -166,6 +167,10 @@ public static Iterator getSupplier( return new DataNodesSupplier(dataTypes, userEntity); case InformationSchema.CONNECTIONS: return new ConnectionsSupplier(dataTypes, userEntity); + case InformationSchema.CURRENT_QUERIES: + return new CurrentQueriesSupplier(dataTypes, userEntity); + case InformationSchema.QUERIES_COSTS_HISTOGRAM: + return new QueriesCostsHistogramSupplier(dataTypes, userEntity); default: throw new UnsupportedOperationException("Unknown table: " + 
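        // The CURRENT_QUERIES and QUERIES_COSTS_HISTOGRAM cases added above are hit when a
        // client queries the new information_schema tables; a minimal SQL usage sketch
        // (session setup omitted, statements mirror IoTDBCurrentQueriesIT from this series):
        //   USE information_schema;
        //   SELECT * FROM current_queries;          -- served by CurrentQueriesSupplier
        //   SELECT * FROM queries_costs_histogram;  -- served by QueriesCostsHistogramSupplier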
tableName); } @@ -201,14 +206,11 @@ protected void constructLine() { final IQueryExecution queryExecution = queryExecutions.get(nextConsumedIndex); if (queryExecution.getSQLDialect().equals(IClientSession.SqlDialect.TABLE)) { - final String[] splits = queryExecution.getQueryId().split("_"); - final int dataNodeId = Integer.parseInt(splits[splits.length - 1]); - columnBuilders[0].writeBinary(BytesUtils.valueOf(queryExecution.getQueryId())); columnBuilders[1].writeLong( TimestampPrecisionUtils.convertToCurrPrecision( queryExecution.getStartExecutionTime(), TimeUnit.MILLISECONDS)); - columnBuilders[2].writeInt(dataNodeId); + columnBuilders[2].writeInt(QueryId.getDataNodeId()); columnBuilders[3].writeFloat( (float) (currTime - queryExecution.getStartExecutionTime()) / 1000); columnBuilders[4].writeBinary( @@ -1181,4 +1183,140 @@ public boolean hasNext() { return sessionConnectionIterator.hasNext(); } } + + private static class CurrentQueriesSupplier extends TsBlockSupplier { + private int nextConsumedIndex; + private List queriesInfo; + + private CurrentQueriesSupplier(final List dataTypes, final UserEntity userEntity) { + super(dataTypes); + queriesInfo = Coordinator.getInstance().getCurrentQueriesInfo(); + try { + accessControl.checkUserGlobalSysPrivilege(userEntity); + } catch (final AccessDeniedException e) { + queriesInfo = + queriesInfo.stream() + .filter(iQueryInfo -> userEntity.getUsername().equals(iQueryInfo.getUser())) + .collect(Collectors.toList()); + } + } + + @Override + protected void constructLine() { + final Coordinator.StatedQueriesInfo queryInfo = queriesInfo.get(nextConsumedIndex); + columnBuilders[0].writeBinary(BytesUtils.valueOf(queryInfo.getQueryId())); + columnBuilders[1].writeBinary(BytesUtils.valueOf(queryInfo.getQueryState())); + columnBuilders[2].writeLong( + TimestampPrecisionUtils.convertToCurrPrecision( + queryInfo.getStartTime(), TimeUnit.MILLISECONDS)); + if (queryInfo.getEndTime() == Coordinator.QueryInfo.DEFAULT_END_TIME) { + columnBuilders[3].appendNull(); + } else { + columnBuilders[3].writeLong( + TimestampPrecisionUtils.convertToCurrPrecision( + queryInfo.getEndTime(), TimeUnit.MILLISECONDS)); + } + columnBuilders[4].writeInt(QueryId.getDataNodeId()); + columnBuilders[5].writeFloat(queryInfo.getCostTime()); + columnBuilders[6].writeBinary(BytesUtils.valueOf(queryInfo.getStatement())); + columnBuilders[7].writeBinary(BytesUtils.valueOf(queryInfo.getUser())); + columnBuilders[8].writeBinary(BytesUtils.valueOf(queryInfo.getClientHost())); + resultBuilder.declarePosition(); + nextConsumedIndex++; + } + + @Override + public boolean hasNext() { + return nextConsumedIndex < queriesInfo.size(); + } + } + + private static class QueriesCostsHistogramSupplier extends TsBlockSupplier { + private int nextConsumedIndex; + private static final Binary[] BUCKETS = + new Binary[] { + BytesUtils.valueOf("[0,1)"), + BytesUtils.valueOf("[1,2)"), + BytesUtils.valueOf("[2,3)"), + BytesUtils.valueOf("[3,4)"), + BytesUtils.valueOf("[4,5)"), + BytesUtils.valueOf("[5,6)"), + BytesUtils.valueOf("[6,7)"), + BytesUtils.valueOf("[7,8)"), + BytesUtils.valueOf("[8,9)"), + BytesUtils.valueOf("[9,10)"), + BytesUtils.valueOf("[10,11)"), + BytesUtils.valueOf("[11,12)"), + BytesUtils.valueOf("[12,13)"), + BytesUtils.valueOf("[13,14)"), + BytesUtils.valueOf("[14,15)"), + BytesUtils.valueOf("[15,16)"), + BytesUtils.valueOf("[16,17)"), + BytesUtils.valueOf("[17,18)"), + BytesUtils.valueOf("[18,19)"), + BytesUtils.valueOf("[19,20)"), + BytesUtils.valueOf("[20,21)"), + 
BytesUtils.valueOf("[21,22)"), + BytesUtils.valueOf("[22,23)"), + BytesUtils.valueOf("[23,24)"), + BytesUtils.valueOf("[24,25)"), + BytesUtils.valueOf("[25,26)"), + BytesUtils.valueOf("[26,27)"), + BytesUtils.valueOf("[27,28)"), + BytesUtils.valueOf("[28,29)"), + BytesUtils.valueOf("[29,30)"), + BytesUtils.valueOf("[30,31)"), + BytesUtils.valueOf("[31,32)"), + BytesUtils.valueOf("[32,33)"), + BytesUtils.valueOf("[33,34)"), + BytesUtils.valueOf("[34,35)"), + BytesUtils.valueOf("[35,36)"), + BytesUtils.valueOf("[36,37)"), + BytesUtils.valueOf("[37,38)"), + BytesUtils.valueOf("[38,39)"), + BytesUtils.valueOf("[39,40)"), + BytesUtils.valueOf("[40,41)"), + BytesUtils.valueOf("[41,42)"), + BytesUtils.valueOf("[42,43)"), + BytesUtils.valueOf("[43,44)"), + BytesUtils.valueOf("[44,45)"), + BytesUtils.valueOf("[45,46)"), + BytesUtils.valueOf("[46,47)"), + BytesUtils.valueOf("[47,48)"), + BytesUtils.valueOf("[48,49)"), + BytesUtils.valueOf("[49,50)"), + BytesUtils.valueOf("[50,51)"), + BytesUtils.valueOf("[51,52)"), + BytesUtils.valueOf("[52,53)"), + BytesUtils.valueOf("[53,54)"), + BytesUtils.valueOf("[54,55)"), + BytesUtils.valueOf("[55,56)"), + BytesUtils.valueOf("[56,57)"), + BytesUtils.valueOf("[57,58)"), + BytesUtils.valueOf("[58,59)"), + BytesUtils.valueOf("[59,60)"), + BytesUtils.valueOf("60+") + }; + private final int[] currentQueriesCostHistogram; + + private QueriesCostsHistogramSupplier( + final List dataTypes, final UserEntity userEntity) { + super(dataTypes); + accessControl.checkUserGlobalSysPrivilege(userEntity); + currentQueriesCostHistogram = Coordinator.getInstance().getCurrentQueriesCostHistogram(); + } + + @Override + protected void constructLine() { + columnBuilders[0].writeBinary(BUCKETS[nextConsumedIndex]); + columnBuilders[1].writeInt(currentQueriesCostHistogram[nextConsumedIndex]); + resultBuilder.declarePosition(); + nextConsumedIndex++; + } + + @Override + public boolean hasNext() { + return nextConsumedIndex < 61; + } + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java index 7708f6c18cda..78cdee720da2 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/Coordinator.java @@ -26,6 +26,7 @@ import org.apache.iotdb.commons.client.sync.SyncDataNodeInternalServiceClient; import org.apache.iotdb.commons.concurrent.IoTDBThreadPoolFactory; import org.apache.iotdb.commons.concurrent.ThreadName; +import org.apache.iotdb.commons.concurrent.threadpool.ScheduledExecutorUtil; import org.apache.iotdb.commons.conf.CommonConfig; import org.apache.iotdb.commons.conf.CommonDescriptor; import org.apache.iotdb.commons.conf.IoTDBConstant; @@ -42,6 +43,7 @@ import org.apache.iotdb.db.queryengine.common.QueryId; import org.apache.iotdb.db.queryengine.common.SessionInfo; import org.apache.iotdb.db.queryengine.execution.QueryIdGenerator; +import org.apache.iotdb.db.queryengine.execution.QueryState; import org.apache.iotdb.db.queryengine.plan.analyze.IPartitionFetcher; import org.apache.iotdb.db.queryengine.plan.analyze.lock.DataNodeSchemaLockManager; import org.apache.iotdb.db.queryengine.plan.analyze.schema.ISchemaFetcher; @@ -141,23 +143,37 @@ import org.apache.iotdb.db.utils.SetThreadName; import org.apache.thrift.TBase; +import org.apache.tsfile.utils.Accountable; +import org.apache.tsfile.utils.RamUsageEstimator; import 
org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; +import java.util.concurrent.BlockingDeque; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutorService; +import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; import java.util.function.LongSupplier; import java.util.function.Supplier; +import java.util.stream.Collectors; import static org.apache.iotdb.commons.utils.StatusUtils.needRetry; +import static org.apache.iotdb.db.queryengine.plan.Coordinator.QueryInfo.DEFAULT_END_TIME; import static org.apache.iotdb.db.utils.CommonUtils.getContentOfRequest; +import static org.apache.tsfile.utils.RamUsageEstimator.shallowSizeOfInstance; +import static org.apache.tsfile.utils.RamUsageEstimator.sizeOfCharArray; /** * The coordinator for MPP. It manages all the queries which are executed in current Node. And it @@ -203,12 +219,32 @@ public class Coordinator { private final ConcurrentHashMap queryExecutionMap; + private final BlockingDeque currentQueriesInfo = new LinkedBlockingDeque<>(); + private final AtomicInteger[] currentQueriesCostHistogram = new AtomicInteger[61]; + private final ScheduledExecutorService retryFailTasksExecutor = + IoTDBThreadPoolFactory.newSingleThreadScheduledExecutor( + ThreadName.EXPIRED_QUERIES_INFO_CLEAR.getName()); + private final StatementRewrite statementRewrite; private final List logicalPlanOptimizers; private final List distributionPlanOptimizers; private final DataNodeLocationSupplierFactory.DataNodeLocationSupplier dataNodeLocationSupplier; private final TypeManager typeManager; + { + for (int i = 0; i < 61; i++) { + currentQueriesCostHistogram[i] = new AtomicInteger(); + } + + ScheduledExecutorUtil.safelyScheduleWithFixedDelay( + retryFailTasksExecutor, + this::clearExpiredQueriesInfoTask, + 1_000L, + 1_000L, + TimeUnit.MILLISECONDS); + LOGGER.info("Expired-Queries-Info-Clear thread is successfully started."); + } + static { coordinatorMemoryBlock = IoTDBDescriptor.getInstance() @@ -625,12 +661,22 @@ public void cleanupQueryExecution( try (SetThreadName threadName = new SetThreadName(queryExecution.getQueryId())) { LOGGER.debug("[CleanUpQuery]]"); queryExecution.stopAndCleanup(t); + boolean isUserQuery = queryExecution.isQuery() && queryExecution.isUserQuery(); + Supplier contentOfQuerySupplier = + new ContentOfQuerySupplier(nativeApiRequest, queryExecution); + if (isUserQuery) { + recordCurrentQueries( + queryExecution.getQueryId(), + queryExecution.getStartExecutionTime(), + System.currentTimeMillis(), + queryExecution.getTotalExecutionTime(), + contentOfQuerySupplier, + queryExecution.getUser(), + queryExecution.getClientHostname()); + } queryExecutionMap.remove(queryId); - if (queryExecution.isQuery() && queryExecution.isUserQuery()) { - recordQueries( - queryExecution::getTotalExecutionTime, - new ContentOfQuerySupplier(nativeApiRequest, queryExecution), - t); + if (isUserQuery) { + recordQueries(queryExecution::getTotalExecutionTime, contentOfQuerySupplier, t); } } } @@ -722,4 +768,249 @@ public DataNodeLocationSupplierFactory.DataNodeLocationSupplier getDataNodeLocat public ExecutorService 
getDispatchExecutor() {
    return dispatchExecutor;
  }
+
+  /** Record query info in an in-memory data structure. */
+  public void recordCurrentQueries(
+      String queryId,
+      long startTime,
+      long endTime,
+      long costTimeInNs,
+      Supplier<String> contentOfQuerySupplier,
+      String user,
+      String clientHost) {
+    if (CONFIG.getQueryCostStatWindow() <= 0) {
+      return;
+    }
+
+    if (queryId == null) {
+      // the fast last query API executeFastLastDataQueryForOnePrefixPath enters here
+      queryId = queryIdGenerator.createNextQueryId().getId();
+    }
+
+    // ns -> s
+    float costTimeInSeconds = costTimeInNs * 1e-9f;
+
+    QueryInfo queryInfo =
+        new QueryInfo(
+            queryId,
+            startTime,
+            endTime,
+            costTimeInSeconds,
+            contentOfQuerySupplier.get(),
+            user,
+            clientHost);
+
+    while (!coordinatorMemoryBlock.allocate(RamUsageEstimator.sizeOfObject(queryInfo))) {
+      // try to release memory from the head of the queue
+      QueryInfo queryInfoToRelease = currentQueriesInfo.poll();
+      if (queryInfoToRelease == null) {
+        // no element left in the queue and the memory is still not enough, skip this record
+        return;
+      } else {
+        // release memory and remove the entry from the histogram
+        coordinatorMemoryBlock.release(RamUsageEstimator.sizeOfObject(queryInfoToRelease));
+        unrecordInHistogram(queryInfoToRelease.costTime);
+      }
+    }
+
+    currentQueriesInfo.addLast(queryInfo);
+    recordInHistogram(costTimeInSeconds);
+  }
+
+  private void recordInHistogram(float costTimeInSeconds) {
+    int bucket = (int) costTimeInSeconds;
+    if (bucket < 60) {
+      currentQueriesCostHistogram[bucket].getAndIncrement();
+    } else {
+      currentQueriesCostHistogram[60].getAndIncrement();
+    }
+  }
+
+  private void unrecordInHistogram(float costTimeInSeconds) {
+    int bucket = (int) costTimeInSeconds;
+    if (bucket < 60) {
+      currentQueriesCostHistogram[bucket].getAndDecrement();
+    } else {
+      currentQueriesCostHistogram[60].getAndDecrement();
+    }
+  }
+
+  private void clearExpiredQueriesInfoTask() {
+    int queryCostStatWindow = CONFIG.getQueryCostStatWindow();
+    if (queryCostStatWindow <= 0) {
+      return;
+    }
+
+    // QueryInfo entries older than the expiration time will be cleared
+    long expiredTime = System.currentTimeMillis() - queryCostStatWindow * 60 * 1_000L;
+    // peek the head; if the head QueryInfo is still inside the time window, return directly
+    QueryInfo queryInfo = currentQueriesInfo.peekFirst();
+    if (queryInfo == null || queryInfo.endTime >= expiredTime) {
+      return;
+    }
+
+    queryInfo = currentQueriesInfo.poll();
+    while (queryInfo != null) {
+      if (queryInfo.endTime < expiredTime) {
+        // out of the time window, clear queryInfo
+        coordinatorMemoryBlock.release(RamUsageEstimator.sizeOfObject(queryInfo));
+        unrecordInHistogram(queryInfo.costTime);
+        queryInfo = currentQueriesInfo.poll();
+      } else {
+        // the head of the queue is not expired, add it back
+        currentQueriesInfo.addFirst(queryInfo);
+        // there is no more candidate to clear
+        return;
+      }
+    }
+  }
+
+  public List<StatedQueriesInfo> getCurrentQueriesInfo() {
+    List<IQueryExecution> runningQueries = getAllQueryExecutions();
+    Set<String> runningQueryIdSet =
+        runningQueries.stream().map(IQueryExecution::getQueryId).collect(Collectors.toSet());
+    List<StatedQueriesInfo> result = new ArrayList<>();
+
+    // add info of history queries that still fall inside the time window
+    Iterator<QueryInfo> historyQueriesIterator = currentQueriesInfo.iterator();
+    Set<String> repetitionQueryIdSet = new HashSet<>();
+    long currentTime = System.currentTimeMillis();
+    long needRecordTime = currentTime - CONFIG.getQueryCostStatWindow() * 60 * 1_000L;
+    while (historyQueriesIterator.hasNext()) {
+      QueryInfo queryInfo = historyQueriesIterator.next();
+      if (queryInfo.endTime < needRecordTime) {
+        //
out of time window, ignore it + } else { + if (runningQueryIdSet.contains(queryInfo.queryId)) { + repetitionQueryIdSet.add(queryInfo.queryId); + } + result.add(new StatedQueriesInfo(QueryState.FINISHED, queryInfo)); + } + } + + // add Running queries info after remove the repetitions which has recorded in History queries + result.addAll( + runningQueries.stream() + .filter(queryExecution -> !repetitionQueryIdSet.contains(queryExecution.getQueryId())) + .map( + queryExecution -> + new StatedQueriesInfo( + QueryState.RUNNING, + queryExecution.getQueryId(), + queryExecution.getStartExecutionTime(), + DEFAULT_END_TIME, + (currentTime - queryExecution.getStartExecutionTime()) / 1000, + queryExecution.getExecuteSQL().orElse("UNKNOWN"), + queryExecution.getUser(), + queryExecution.getClientHostname())) + .collect(Collectors.toList())); + return result; + } + + public int[] getCurrentQueriesCostHistogram() { + return Arrays.stream(currentQueriesCostHistogram).mapToInt(AtomicInteger::get).toArray(); + } + + public static class QueryInfo implements Accountable { + public static final long DEFAULT_END_TIME = -1L; + private static final long INSTANCE_SIZE = shallowSizeOfInstance(QueryInfo.class); + + private final String queryId; + + // unit: millisecond + private final long startTime; + private final long endTime; + // unit: second + private final float costTime; + + private final String statement; + private final String user; + private final String clientHost; + + public QueryInfo( + String queryId, + long startTime, + long endTime, + float costTime, + String statement, + String user, + String clientHost) { + this.queryId = queryId; + this.startTime = startTime; + this.endTime = endTime; + this.costTime = costTime; + this.statement = statement; + this.user = user; + this.clientHost = clientHost; + } + + public String getClientHost() { + return clientHost; + } + + public String getUser() { + return user; + } + + public long getStartTime() { + return startTime; + } + + public long getEndTime() { + return endTime; + } + + public float getCostTime() { + return costTime; + } + + public String getQueryId() { + return queryId; + } + + public String getStatement() { + return statement; + } + + @Override + public long ramBytesUsed() { + return INSTANCE_SIZE + + sizeOfCharArray(statement.length()) + + sizeOfCharArray(user.length()) + + sizeOfCharArray(clientHost.length()); + } + } + + public static class StatedQueriesInfo extends QueryInfo { + private final QueryState queryState; + + private StatedQueriesInfo(QueryState queryState, QueryInfo queryInfo) { + super( + queryInfo.queryId, + queryInfo.startTime, + queryInfo.endTime, + queryInfo.costTime, + queryInfo.statement, + queryInfo.user, + queryInfo.clientHost); + this.queryState = queryState; + } + + private StatedQueriesInfo( + QueryState queryState, + String queryId, + long startTime, + long endTime, + long costTime, + String statement, + String user, + String clientHost) { + super(queryId, startTime, endTime, costTime, statement, user, clientHost); + this.queryState = queryState; + } + + public String getQueryState() { + return queryState.name(); + } + } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/IQueryExecution.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/IQueryExecution.java index 98257c24293e..e98f016767f6 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/IQueryExecution.java +++ 
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/IQueryExecution.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/IQueryExecution.java
index 98257c24293e..e98f016767f6 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/IQueryExecution.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/IQueryExecution.java
@@ -77,4 +77,6 @@ public interface IQueryExecution {
   IClientSession.SqlDialect getSQLDialect();
 
   String getUser();
+
+  String getClientHostname();
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/QueryExecution.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/QueryExecution.java
index 2c1657e839e2..4734db5850a4 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/QueryExecution.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/QueryExecution.java
@@ -700,6 +700,11 @@ public String getUser() {
     return context.getSession().getUserName();
   }
 
+  @Override
+  public String getClientHostname() {
+    return context.getCliHostname();
+  }
+
   public MPPQueryContext getContext() {
     return context;
   }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/ConfigExecution.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/ConfigExecution.java
index 1880924f297a..823a620820fd 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/ConfigExecution.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/ConfigExecution.java
@@ -353,4 +353,8 @@ public IClientSession.SqlDialect getSQLDialect() {
   public String getUser() {
     return context.getSession().getUserName();
   }
+
+  public String getClientHostname() {
+    return context.getCliHostname();
+  }
 }
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java
index f8cf497546e6..d7d755ddc1da 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/DataNodeLocationSupplierFactory.java
@@ -86,6 +86,8 @@ public List getDataNodeLocations(final String tableName) {
       switch (tableName) {
         case InformationSchema.QUERIES:
         case InformationSchema.CONNECTIONS:
+        case InformationSchema.CURRENT_QUERIES:
+        case InformationSchema.QUERIES_COSTS_HISTOGRAM:
           return getReadableDataNodeLocations();
         case InformationSchema.DATABASES:
         case InformationSchema.TABLES:
diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/MergeTreeSortOperatorTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/MergeTreeSortOperatorTest.java
index 4ca38a5c8d9f..8fc7d01437cf 100644
--- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/MergeTreeSortOperatorTest.java
+++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/MergeTreeSortOperatorTest.java
@@ -1910,5 +1910,10 @@ public QueryType getQueryType() {
     public boolean isUserQuery() {
       return false;
     }
+
+    @Override
+    public String getClientHostname() {
+      return SessionConfig.DEFAULT_HOST;
+    }
   }
 }
diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/PlanTester.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/PlanTester.java
index a4cef177d813..850755702016 100644
--- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/PlanTester.java
+++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/PlanTester.java
@@ -83,6 +83,8 @@ public class PlanTester {
     public List getDataNodeLocations(String table) {
       switch (table) {
         case "queries":
+        case "current_queries":
+        case "queries_costs_histogram":
           return ImmutableList.of(
               genDataNodeLocation(1, "192.0.1.1"), genDataNodeLocation(2, "192.0.1.2"));
         default:
diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/informationschema/CurrentQueriesTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/informationschema/CurrentQueriesTest.java
new file mode 100644
index 000000000000..1e3163321ec4
--- /dev/null
+++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/informationschema/CurrentQueriesTest.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iotdb.db.queryengine.plan.relational.planner.informationschema;
+
+import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.PlanTester;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.Test;
+
+import java.util.Optional;
+
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.BIN;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.CLIENT_IP;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.COST_TIME;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.DATA_NODE_ID_TABLE_MODEL;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.END_TIME_TABLE_MODEL;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.NUMS;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.QUERY_ID_TABLE_MODEL;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.START_TIME_TABLE_MODEL;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATEMENT_TABLE_MODEL;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATE_TABLE_MODEL;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.USER_TABLE_MODEL;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanAssert.assertPlan;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.collect;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.exchange;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.infoSchemaTableScan;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.output;
+
+public class CurrentQueriesTest {
+  private final PlanTester planTester = new PlanTester();
+
+  @Test
+  public void testCurrentQueries() {
+    LogicalQueryPlan logicalQueryPlan =
+        planTester.createPlan("select * from information_schema.current_queries");
+    assertPlan(
+        logicalQueryPlan,
+        output(
+            infoSchemaTableScan(
+                "information_schema.current_queries",
+                Optional.empty(),
+                ImmutableList.of(
+                    QUERY_ID_TABLE_MODEL,
+                    STATE_TABLE_MODEL,
+                    START_TIME_TABLE_MODEL,
+                    END_TIME_TABLE_MODEL,
+                    DATA_NODE_ID_TABLE_MODEL,
+                    COST_TIME,
+                    STATEMENT_TABLE_MODEL,
+                    USER_TABLE_MODEL,
+                    CLIENT_IP))));
+
+    //                           - Exchange
+    // Output - Collect - Exchange
+    assertPlan(planTester.getFragmentPlan(0), output(collect(exchange(), exchange())));
+    // TableScan
+    assertPlan(
+        planTester.getFragmentPlan(1),
+        infoSchemaTableScan("information_schema.current_queries", Optional.of(1)));
+    // TableScan
+    assertPlan(
+        planTester.getFragmentPlan(2),
+        infoSchemaTableScan("information_schema.current_queries", Optional.of(2)));
+  }
+
+  @Test
+  public void testQueriesCostsHistogram() {
+    LogicalQueryPlan logicalQueryPlan =
+        planTester.createPlan("select * from information_schema.queries_costs_histogram");
+    assertPlan(
+        logicalQueryPlan,
+        output(
+            infoSchemaTableScan(
+                "information_schema.queries_costs_histogram",
+                Optional.empty(),
+                ImmutableList.of(BIN, NUMS))));
+
+    //                           - Exchange
+    // Output - Collect - Exchange
+    assertPlan(planTester.getFragmentPlan(0), output(collect(exchange(), exchange())));
+    // TableScan
+    assertPlan(
+        planTester.getFragmentPlan(1),
+        infoSchemaTableScan("information_schema.queries_costs_histogram", Optional.of(1)));
+    // TableScan
+    assertPlan(
+        planTester.getFragmentPlan(2),
+        infoSchemaTableScan("information_schema.queries_costs_histogram", Optional.of(2)));
+  }
+}
diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ShowQueriesTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/informationschema/ShowQueriesTest.java
similarity index 94%
rename from iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ShowQueriesTest.java
rename to iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/informationschema/ShowQueriesTest.java
index 63b7b783d747..7161c68f4e9e 100644
--- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ShowQueriesTest.java
+++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/informationschema/ShowQueriesTest.java
@@ -16,7 +16,7 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-package org.apache.iotdb.db.queryengine.plan.relational.analyzer;
+package org.apache.iotdb.db.queryengine.plan.relational.planner.informationschema;
 
 import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan;
 import org.apache.iotdb.db.queryengine.plan.relational.planner.PlanTester;
@@ -32,7 +32,9 @@
 import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.QUERY_ID_TABLE_MODEL;
 import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.START_TIME_TABLE_MODEL;
 import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATEMENT;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.STATEMENT_TABLE_MODEL;
 import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.USER;
+import static org.apache.iotdb.commons.schema.column.ColumnHeaderConstant.USER_TABLE_MODEL;
 import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanAssert.assertPlan;
 import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.collect;
 import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.exchange;
@@ -61,8 +63,8 @@ public void testNormal() {
                     START_TIME_TABLE_MODEL,
                     DATA_NODE_ID_TABLE_MODEL,
                     ELAPSED_TIME_TABLE_MODEL,
-                    STATEMENT.toLowerCase(Locale.ENGLISH),
-                    USER.toLowerCase(Locale.ENGLISH)))));
+                    STATEMENT_TABLE_MODEL,
+                    USER_TABLE_MODEL))));
 
     //                           - Exchange
     // Output - Collect - Exchange
diff --git a/iotdb-core/node-commons/src/assembly/resources/conf/iotdb-system.properties.template b/iotdb-core/node-commons/src/assembly/resources/conf/iotdb-system.properties.template
index 5a32d6a12314..0a9a47ee8907 100644
--- a/iotdb-core/node-commons/src/assembly/resources/conf/iotdb-system.properties.template
+++ b/iotdb-core/node-commons/src/assembly/resources/conf/iotdb-system.properties.template
@@ -1080,6 +1080,12 @@ max_tsblock_line_number=1000
 # Datatype: long
 slow_query_threshold=10000
 
+# Time window threshold (unit: minute) for recording history queries.
+# effectiveMode: hot_reload
+# Datatype: int
+# Privilege: SYSTEM
+query_cost_stat_window=0
+
 # The max executing time of query. unit: ms
 # effectiveMode: restart
 # Datatype: int
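Note that the feature ships disabled: with the default query_cost_stat_window=0, recordCurrentQueries returns immediately and nothing is retained. To keep, for example, one hour of finished-query history, an operator would set the window (in minutes) in iotdb-system.properties; the value below is only an illustration:

    # keep per-query cost records for the last 60 minutes
    query_cost_stat_window=60

Since the property's effectiveMode is hot_reload, the window can be adjusted without restarting the node.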
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java
index 390e9f80e9b3..6f9f95ca8fe8 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/concurrent/ThreadName.java
@@ -35,6 +35,7 @@ public enum ThreadName {
   FRAGMENT_INSTANCE_NOTIFICATION("Fragment-Instance-Notification"),
   FRAGMENT_INSTANCE_DISPATCH("Fragment-Instance-Dispatch"),
   DRIVER_TASK_SCHEDULER_NOTIFICATION("Driver-Task-Scheduler-Notification"),
+  EXPIRED_QUERIES_INFO_CLEAR("Expired-Queries-Info-Clear"),
   // -------------------------- MPP --------------------------
   MPP_COORDINATOR_SCHEDULED_EXECUTOR("MPP-Coordinator-Scheduled-Executor"),
   MPP_DATA_EXCHANGE_TASK_EXECUTOR("MPP-Data-Exchange-Task-Executors"),
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java
index 67850d991cb6..0459d4d2c86e 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/column/ColumnHeaderConstant.java
@@ -214,11 +214,19 @@ private ColumnHeaderConstant() {
   public static final String CLIENT_IP = "client_ip";
 
   public static final String QUERY_ID_TABLE_MODEL = "query_id";
-  public static final String QUERY_ID_START_TIME_TABLE_MODEL = "start_time";
   public static final String DATA_NODE_ID_TABLE_MODEL = "datanode_id";
   public static final String START_TIME_TABLE_MODEL = "start_time";
   public static final String ELAPSED_TIME_TABLE_MODEL = "elapsed_time";
 
+  // column names for current_queries and queries_costs_histogram
+  public static final String STATE_TABLE_MODEL = "state";
+  public static final String END_TIME_TABLE_MODEL = "end_time";
+  public static final String COST_TIME = "cost_time";
+  public static final String STATEMENT_TABLE_MODEL = "statement";
+  public static final String USER_TABLE_MODEL = "user";
+  public static final String BIN = "bin";
+  public static final String NUMS = "nums";
+
   public static final String TABLE_NAME_TABLE_MODEL = "table_name";
   public static final String TABLE_TYPE_TABLE_MODEL = "table_type";
   public static final String COLUMN_NAME_TABLE_MODEL = "column_name";
diff --git a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java
index 243bc41c40ce..b8e03423d61c 100644
--- a/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java
+++ b/iotdb-core/node-commons/src/main/java/org/apache/iotdb/commons/schema/table/InformationSchema.java
@@ -50,6 +50,8 @@ public class InformationSchema {
   public static final String CONFIG_NODES = "config_nodes";
   public static final String DATA_NODES = "data_nodes";
   public static final String CONNECTIONS = "connections";
+  public static final String CURRENT_QUERIES = "current_queries";
+  public static final String QUERIES_COSTS_HISTOGRAM = "queries_costs_histogram";
 
   static {
     final TsTable queriesTable = new TsTable(QUERIES);
@@ -57,17 +59,15 @@ public class InformationSchema {
         new TagColumnSchema(ColumnHeaderConstant.QUERY_ID_TABLE_MODEL, TSDataType.STRING));
     queriesTable.addColumnSchema(
         new AttributeColumnSchema(
-            ColumnHeaderConstant.QUERY_ID_START_TIME_TABLE_MODEL, TSDataType.TIMESTAMP));
+            ColumnHeaderConstant.START_TIME_TABLE_MODEL, TSDataType.TIMESTAMP));
     queriesTable.addColumnSchema(
         new AttributeColumnSchema(ColumnHeaderConstant.DATA_NODE_ID_TABLE_MODEL, TSDataType.INT32));
     queriesTable.addColumnSchema(
         new AttributeColumnSchema(ColumnHeaderConstant.ELAPSED_TIME_TABLE_MODEL, TSDataType.FLOAT));
     queriesTable.addColumnSchema(
-        new AttributeColumnSchema(
-            ColumnHeaderConstant.STATEMENT.toLowerCase(Locale.ENGLISH), TSDataType.STRING));
+        new AttributeColumnSchema(ColumnHeaderConstant.STATEMENT_TABLE_MODEL, TSDataType.STRING));
     queriesTable.addColumnSchema(
-        new AttributeColumnSchema(
-            ColumnHeaderConstant.USER.toLowerCase(Locale.ENGLISH), TSDataType.STRING));
+        new AttributeColumnSchema(ColumnHeaderConstant.USER_TABLE_MODEL, TSDataType.STRING));
     queriesTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME);
     schemaTables.put(QUERIES, queriesTable);
@@ -361,6 +361,37 @@ public class InformationSchema {
         new AttributeColumnSchema(ColumnHeaderConstant.CLIENT_IP, TSDataType.STRING));
     connectionsTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME);
     schemaTables.put(CONNECTIONS, connectionsTable);
+
+    final TsTable currentQueriesTable = new TsTable(CURRENT_QUERIES);
+    currentQueriesTable.addColumnSchema(
+        new TagColumnSchema(ColumnHeaderConstant.QUERY_ID_TABLE_MODEL, TSDataType.STRING));
+    currentQueriesTable.addColumnSchema(
+        new AttributeColumnSchema(ColumnHeaderConstant.STATE_TABLE_MODEL, TSDataType.STRING));
+    currentQueriesTable.addColumnSchema(
+        new AttributeColumnSchema(
+            ColumnHeaderConstant.START_TIME_TABLE_MODEL, TSDataType.TIMESTAMP));
+    currentQueriesTable.addColumnSchema(
+        new AttributeColumnSchema(ColumnHeaderConstant.END_TIME_TABLE_MODEL, TSDataType.TIMESTAMP));
+    currentQueriesTable.addColumnSchema(
+        new AttributeColumnSchema(ColumnHeaderConstant.DATA_NODE_ID_TABLE_MODEL, TSDataType.INT32));
+    currentQueriesTable.addColumnSchema(
+        new AttributeColumnSchema(ColumnHeaderConstant.COST_TIME, TSDataType.FLOAT));
+    currentQueriesTable.addColumnSchema(
+        new AttributeColumnSchema(ColumnHeaderConstant.STATEMENT_TABLE_MODEL, TSDataType.STRING));
+    currentQueriesTable.addColumnSchema(
+        new AttributeColumnSchema(ColumnHeaderConstant.USER_TABLE_MODEL, TSDataType.STRING));
+    currentQueriesTable.addColumnSchema(
+        new AttributeColumnSchema(ColumnHeaderConstant.CLIENT_IP, TSDataType.STRING));
+    currentQueriesTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME);
+    schemaTables.put(CURRENT_QUERIES, currentQueriesTable);
+
+    final TsTable queriesCostsHistogramTable = new TsTable(QUERIES_COSTS_HISTOGRAM);
+    queriesCostsHistogramTable.addColumnSchema(
+        new TagColumnSchema(ColumnHeaderConstant.BIN, TSDataType.STRING));
+    queriesCostsHistogramTable.addColumnSchema(
+        new AttributeColumnSchema(ColumnHeaderConstant.NUMS, TSDataType.INT32));
+    queriesCostsHistogramTable.removeColumnSchema(TsTable.TIME_COLUMN_NAME);
+    schemaTables.put(QUERIES_COSTS_HISTOGRAM, queriesCostsHistogramTable);
   }
 
   public static Map<String, TsTable> getSchemaTables() {
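With the schema registered above, the new statistics become queryable through information_schema. A hypothetical session against the two tables (column names taken from the schema definitions above):

    -- finished and still-running queries inside the configured window
    SELECT query_id, state, start_time, end_time, datanode_id,
           cost_time, statement, user, client_ip
      FROM information_schema.current_queries;

    -- cost distribution: one-second bins 0..59 plus the >=60 s overflow bin
    SELECT bin, nums
      FROM information_schema.queries_costs_histogram;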
From 22a8fa1621acbbe10c655be02378a55722d6bd43 Mon Sep 17 00:00:00 2001
From: JackieTien97
Date: Wed, 10 Dec 2025 19:42:17 +0800
Subject: [PATCH 12/13] Bump tsfile version to 2.2.0-251210-SNAPSHOT

---
 pom.xml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pom.xml b/pom.xml
index d9df3de1bf7b..66c7116a9fb5 100644
--- a/pom.xml
+++ b/pom.xml
@@ -173,7 +173,7 @@
     0.14.1
     1.9
     1.5.6-3
-    2.2.0-251209-SNAPSHOT
+    2.2.0-251210-SNAPSHOT