Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions .github/workflows/cluster-it-1c1d1a.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ jobs:

steps:
- uses: actions/checkout@v4
- name: Build AINode
shell: bash
run: mvn clean package -DskipTests -P with-ainode
- name: IT Test
shell: bash
run: |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,4 +137,10 @@ public DataNodeConfig setDataNodeMemoryProportion(String dataNodeMemoryProportio
setProperty("datanode_memory_proportion", dataNodeMemoryProportion);
return this;
}

@Override
public DataNodeConfig setQueryCostStatWindow(int queryCostStatWindow) {
setProperty("query_cost_stat_window", String.valueOf(queryCostStatWindow));
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ public class AINodeWrapper extends AbstractNodeWrapper {
private static final String PROPERTIES_FILE = "iotdb-ainode.properties";
public static final String CONFIG_PATH = "conf";
public static final String SCRIPT_PATH = "sbin";
public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/weights";
public static final String BUILT_IN_MODEL_PATH = "data/ainode/models/builtin";
public static final String CACHE_BUILT_IN_MODEL_PATH = "/data/ainode/models/weights";

private void replaceAttribute(String[] keys, String[] values, String filePath) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,9 @@ public DataNodeConfig setDeleteWalFilesPeriodInMs(long deleteWalFilesPeriodInMs)
public DataNodeConfig setDataNodeMemoryProportion(String dataNodeMemoryProportion) {
return this;
}

@Override
public DataNodeConfig setQueryCostStatWindow(int queryCostStatWindow) {
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,6 @@ DataNodeConfig setLoadTsFileAnalyzeSchemaMemorySizeInBytes(
DataNodeConfig setDeleteWalFilesPeriodInMs(long deleteWalFilesPeriodInMs);

DataNodeConfig setDataNodeMemoryProportion(String dataNodeMemoryProportion);

DataNodeConfig setQueryCostStatWindow(int queryCostStatWindow);
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,13 @@ private void updateConfig(Statement statement, int timeout) throws SQLException
statement.setQueryTimeout(timeout);
}

/**
* Executes a SQL query on all read statements in parallel.
*
* <p>Note: For PreparedStatement EXECUTE queries, use the write connection directly instead,
* because PreparedStatements are session-scoped and this method may route queries to different
* nodes where the PreparedStatement doesn't exist.
*/
@Override
public ResultSet executeQuery(String sql) throws SQLException {
return new ClusterTestResultSet(readStatements, readEndpoints, sql, queryTimeout);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.iotdb.ainode.it;

import org.apache.iotdb.ainode.utils.AINodeTestUtils;
import org.apache.iotdb.it.env.EnvFactory;
import org.apache.iotdb.it.framework.IoTDBTestRunner;
import org.apache.iotdb.itbase.category.AIClusterIT;
import org.apache.iotdb.itbase.env.BaseEnv;

import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;

import java.sql.Connection;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;

import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader;
import static org.apache.iotdb.db.it.utils.TestUtils.prepareData;

@RunWith(IoTDBTestRunner.class)
@Category({AIClusterIT.class})
public class AINodeCallInferenceIT {

private static final String[] WRITE_SQL_IN_TREE =
new String[] {
"CREATE DATABASE root.AI",
"CREATE TIMESERIES root.AI.s0 WITH DATATYPE=FLOAT, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s1 WITH DATATYPE=DOUBLE, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s2 WITH DATATYPE=INT32, ENCODING=RLE",
"CREATE TIMESERIES root.AI.s3 WITH DATATYPE=INT64, ENCODING=RLE",
};

private static final String CALL_INFERENCE_SQL_TEMPLATE =
"CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)";
private static final int DEFAULT_INPUT_LENGTH = 256;
private static final int DEFAULT_OUTPUT_LENGTH = 48;

@BeforeClass
public static void setUp() throws Exception {
// Init 1C1D1A cluster environment
EnvFactory.getEnv().initClusterEnvironment(1, 1);
prepareData(WRITE_SQL_IN_TREE);
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
for (int i = 0; i < 2880; i++) {
statement.execute(
String.format(
"INSERT INTO root.AI(timestamp,s0,s1,s2,s3) VALUES(%d,%f,%f,%d,%d)",
i, (float) i, (double) i, i, i));
}
}
}

@AfterClass
public static void tearDown() throws Exception {
EnvFactory.getEnv().cleanClusterEnvironment();
}

@Test
public void callInferenceTest() throws SQLException {
for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
callInferenceTest(statement, modelInfo);
}
}
}

public void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo)
throws SQLException {
// Invoke call inference for specified models, there should exist result.
for (int i = 0; i < 4; i++) {
String callInferenceSQL =
String.format(
CALL_INFERENCE_SQL_TEMPLATE,
modelInfo.getModelId(),
i,
DEFAULT_INPUT_LENGTH,
DEFAULT_OUTPUT_LENGTH);
try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) {
ResultSetMetaData resultSetMetaData = resultSet.getMetaData();
checkHeader(resultSetMetaData, "Time,output");
int count = 0;
while (resultSet.next()) {
count++;
}
// Ensure the call inference return results
Assert.assertEquals(DEFAULT_OUTPUT_LENGTH, count);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.iotdb.ainode.it;

import org.apache.iotdb.ainode.utils.AINodeTestUtils;
import org.apache.iotdb.it.env.EnvFactory;
import org.apache.iotdb.it.framework.IoTDBTestRunner;
import org.apache.iotdb.itbase.category.AIClusterIT;
import org.apache.iotdb.itbase.env.BaseEnv;

import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;

import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_LTSM_MAP;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice;
import static org.apache.iotdb.ainode.utils.AINodeTestUtils.concurrentInference;

@RunWith(IoTDBTestRunner.class)
@Category({AIClusterIT.class})
public class AINodeConcurrentForecastIT {

private static final Logger LOGGER = LoggerFactory.getLogger(AINodeConcurrentForecastIT.class);

private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE =
"SELECT * FROM FORECAST(model_id=>'%s', input=>(SELECT time,s FROM root.AI) ORDER BY time, forecast_length=>%d)";

@BeforeClass
public static void setUp() throws Exception {
// Init 1C1D1A cluster environment
EnvFactory.getEnv().initClusterEnvironment(1, 1);
prepareDataForTableModel();
}

@AfterClass
public static void tearDown() throws Exception {
EnvFactory.getEnv().cleanClusterEnvironment();
}

private static void prepareDataForTableModel() throws SQLException {
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
statement.execute("CREATE DATABASE root");
statement.execute("CREATE TABLE root.AI (s DOUBLE FIELD)");
for (int i = 0; i < 2880; i++) {
statement.execute(
String.format(
"INSERT INTO root.AI(time, s) VALUES(%d, %f)", i, Math.sin(i * Math.PI / 1440)));
}
}
}

@Test
public void concurrentGPUForecastTest() throws SQLException, InterruptedException {
for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_LTSM_MAP.values()) {
concurrentGPUForecastTest(modelInfo);
}
}

public void concurrentGPUForecastTest(AINodeTestUtils.FakeModelInfo modelInfo)
throws SQLException, InterruptedException {
final int forecastLength = 512;
try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT);
Statement statement = connection.createStatement()) {
// Single forecast request can be processed successfully
final String forecastSQL =
String.format(
FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, modelInfo.getModelId(), forecastLength);
final int threadCnt = 10;
final int loop = 100;
final String devices = "0,1";
statement.execute(
String.format("LOAD MODEL %s TO DEVICES '%s'", modelInfo.getModelId(), devices));
checkModelOnSpecifiedDevice(statement, modelInfo.getModelId(), devices);
long startTime = System.currentTimeMillis();
concurrentInference(statement, forecastSQL, threadCnt, loop, forecastLength);
long endTime = System.currentTimeMillis();
LOGGER.info(
String.format(
"Model %s concurrent inference %d reqs (%d threads, %d loops) in GPU takes time: %dms",
modelInfo.getModelId(), threadCnt * loop, threadCnt, loop, endTime - startTime));
statement.execute(
String.format("UNLOAD MODEL %s FROM DEVICES '%s'", modelInfo.getModelId(), devices));
checkModelNotOnSpecifiedDevice(statement, modelInfo.getModelId(), devices);
}
}
}
Loading
Loading