This is an automated email from the ASF dual-hosted git repository. JackieTien97 pushed a commit to branch speedup-ainode-ci in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit e09264582a1a7050227155f35443f8a3c7c392fa Author: JackieTien97 <[email protected]> AuthorDate: Sat May 16 12:51:12 2026 +0800 Speed up AINode CI by consolidating tests and caching PyInstaller output - Merge 5 AINode IT test classes (DeviceManage, ModelManage, CallInference, Forecast, InstanceManagement) into AINodeSharedClusterIT that shares a single 1C1D1A cluster, reducing cluster startups from 8 to 3 (~20min saved) - Convert AINodeClusterConfigIT from @Before/@After to @BeforeClass/@AfterClass, merging both dialect tests into one method to eliminate a redundant cluster restart - Add hash-based dist caching to build_binary.py: computes SHA256 of AINode source files and skips PyInstaller rebuild when source hasn't changed (~11min saved) - Cache stored at ~/.cache/iotdb-ainode-build/dist-cache/, survives mvn clean --- .../iotdb/ainode/it/AINodeCallInferenceIT.java | 135 ----- .../iotdb/ainode/it/AINodeClusterConfigIT.java | 70 ++- .../iotdb/ainode/it/AINodeDeviceManageIT.java | 96 ---- .../apache/iotdb/ainode/it/AINodeForecastIT.java | 195 -------- .../ainode/it/AINodeInstanceManagementIT.java | 179 ------- .../iotdb/ainode/it/AINodeModelManageIT.java | 215 -------- .../iotdb/ainode/it/AINodeSharedClusterIT.java | 555 +++++++++++++++++++++ iotdb-core/ainode/build_binary.py | 106 +++- 8 files changed, 704 insertions(+), 847 deletions(-) diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java deleted file mode 100644 index 852827aa8ab..00000000000 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeCallInferenceIT.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.ainode.it; - -import org.apache.iotdb.ainode.utils.AINodeTestUtils; -import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; -import org.apache.iotdb.itbase.env.BaseEnv; - -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.SQLException; -import java.sql.Statement; - -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree; - -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) -public class AINodeCallInferenceIT { - - private static final String CALL_INFERENCE_SQL_TEMPLATE = - "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)"; - private static final String CALL_INFERENCE_BY_DEFAULT_SQL_TEMPLATE = - "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT 256\")"; - private static final int DEFAULT_INPUT_LENGTH = 256; - private static final int DEFAULT_OUTPUT_LENGTH = 48; - - @BeforeClass - public static void setUp() throws Exception { - // Init 1C1D1A cluster environment - EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareDataInTree(); - } - - @AfterClass - public static void tearDown() throws Exception { - EnvFactory.getEnv().cleanClusterEnvironment(); - } - - @Test - public void callInferenceTest() throws SQLException { - for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - callInferenceTest(statement, modelInfo); - callInferenceByDefaultTest(statement, modelInfo); - callInferenceErrorTest(statement, modelInfo); - } - } - } - - public static void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) - throws SQLException { - // Invoke call inference for specified models, there should exist result. - for (int i = 0; i < 4; i++) { - String callInferenceSQL = - String.format( - CALL_INFERENCE_SQL_TEMPLATE, - modelInfo.getModelId(), - i, - DEFAULT_INPUT_LENGTH, - DEFAULT_OUTPUT_LENGTH); - try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "Time,output"); - int count = 0; - while (resultSet.next()) { - count++; - } - // Ensure the call inference return results - Assert.assertEquals(DEFAULT_OUTPUT_LENGTH, count); - } - } - } - - public static void callInferenceByDefaultTest( - Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException { - // Invoke call inference for specified models, there should exist result. - for (int i = 0; i < 4; i++) { - String callInferenceSQL = - String.format(CALL_INFERENCE_BY_DEFAULT_SQL_TEMPLATE, modelInfo.getModelId(), i); - try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "output"); - int count = 0; - while (resultSet.next()) { - count++; - } - // Ensure the call inference return results - Assert.assertTrue(count > 0); - } - } - } - - public static void callInferenceErrorTest( - Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) { - String multiVariateSQL = - String.format( - "CALL INFERENCE(%s, \"SELECT s0,s1 FROM root.AI LIMIT 128\", generateTime=true, outputLength=10)", - modelInfo.getModelId()); - errorTest( - statement, - multiVariateSQL, - "701: Call inference function should not contain more than one input column, found [2] input columns."); - } -} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java index e148be6b20a..4963c0828c8 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java @@ -24,9 +24,9 @@ import org.apache.iotdb.it.framework.IoTDBTestRunner; import org.apache.iotdb.itbase.category.AIClusterIT; import org.apache.iotdb.itbase.env.BaseEnv; -import org.junit.After; +import org.junit.AfterClass; import org.junit.Assert; -import org.junit.Before; +import org.junit.BeforeClass; import org.junit.Test; import org.junit.experimental.categories.Category; import org.junit.runner.RunWith; @@ -44,39 +44,56 @@ import static org.junit.Assert.assertEquals; @Category({AIClusterIT.class}) public class AINodeClusterConfigIT { - @Before - public void setUp() throws Exception { - // Init 1C1D1A cluster environment + @BeforeClass + public static void setUp() throws Exception { EnvFactory.getEnv().initClusterEnvironment(1, 1); } - @After - public void tearDown() throws Exception { + @AfterClass + public static void tearDown() throws Exception { EnvFactory.getEnv().cleanClusterEnvironment(); } @Test - public void aiNodeRegisterAndRemoveTestInTree() throws SQLException { + public void aiNodeRegisterAndRemoveTest() throws SQLException { + String show_sql = "SHOW AINODES"; + String title = "NodeID,Status,InternalAddress,InternalPort"; + + // Verify AINode exists via both dialects before removal try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { - aiNodeRegisterAndRemoveTest(statement); + verifyAINodeExists(statement, show_sql, title); + } + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + verifyAINodeExists(statement, show_sql, title); } - } - @Test - public void aiNodeRegisterAndRemoveTestInTable() throws SQLException { + // Remove AINode + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + statement.execute("REMOVE AINODE"); + waitForAINodeRemoval(statement, show_sql, title); + } + + // Verify removal is visible via table dialect as well try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { - aiNodeRegisterAndRemoveTest(statement); + try (ResultSet resultSet = statement.executeQuery(show_sql)) { + checkHeader(resultSet.getMetaData(), title); + int count = 0; + while (resultSet.next()) { + count++; + } + assertEquals(0, count); + } } } - private void aiNodeRegisterAndRemoveTest(Statement statement) throws SQLException { - String show_sql = "SHOW AINODES"; - String title = "NodeID,Status,InternalAddress,InternalPort"; - try (ResultSet resultSet = statement.executeQuery(show_sql)) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, title); + private static void verifyAINodeExists(Statement statement, String showSql, String title) + throws SQLException { + try (ResultSet resultSet = statement.executeQuery(showSql)) { + checkHeader(resultSet.getMetaData(), title); int count = 0; while (resultSet.next()) { assertEquals("2", resultSet.getString(1)); @@ -85,22 +102,23 @@ public class AINodeClusterConfigIT { } assertEquals(1, count); } - String remove_sql = "REMOVE AINODE"; - statement.execute(remove_sql); + } + + private static void waitForAINodeRemoval(Statement statement, String showSql, String title) + throws SQLException { for (int retry = 0; retry < 500; retry++) { - try (ResultSet resultSet = statement.executeQuery(show_sql)) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, title); + try (ResultSet resultSet = statement.executeQuery(showSql)) { + checkHeader(resultSet.getMetaData(), title); int count = 0; while (resultSet.next()) { count++; } if (count == 0) { - return; // Successfully removed the AI node + return; } } try { - Thread.sleep(1000); // Wait before retrying + Thread.sleep(1000); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeDeviceManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeDeviceManageIT.java deleted file mode 100644 index bbffd3cffb0..00000000000 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeDeviceManageIT.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.ainode.it; - -import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; -import org.apache.iotdb.itbase.env.BaseEnv; - -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.Arrays; -import java.util.LinkedList; -import java.util.List; - -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree; - -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) -public class AINodeDeviceManageIT { - - @BeforeClass - public static void setUp() throws Exception { - // Init 1C1D1A cluster environment - EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareDataInTree(); - prepareDataInTable(); - } - - @AfterClass - public static void tearDown() throws Exception { - EnvFactory.getEnv().cleanClusterEnvironment(); - } - - @Test - public void showAIDeviceTestInTree() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - showAIDevicesTest(statement); - } - } - - @Test - public void showAIDeviceTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - showAIDevicesTest(statement); - } - } - - private void showAIDevicesTest(Statement statement) throws SQLException { - final String showSql = "SHOW AI_DEVICES"; - final List<String> expectedDeviceIdList = new LinkedList<>(Arrays.asList("0", "1", "cpu")); - final List<String> expectedDeviceTypeList = - new LinkedList<>(Arrays.asList("cuda", "cuda", "cpu")); - try (ResultSet resultSet = statement.executeQuery(showSql)) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceId,DeviceType"); - while (resultSet.next()) { - String deviceId = resultSet.getString(1); - String deviceType = resultSet.getString(2); - Assert.assertEquals(expectedDeviceIdList.remove(0), deviceId); - Assert.assertEquals(expectedDeviceTypeList.remove(0), deviceType); - } - } - } -} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java deleted file mode 100644 index eb4a981389e..00000000000 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeForecastIT.java +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.ainode.it; - -import org.apache.iotdb.ainode.utils.AINodeTestUtils; -import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; -import org.apache.iotdb.itbase.env.BaseEnv; - -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.SQLException; -import java.sql.Statement; - -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable; - -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) -public class AINodeForecastIT { - - private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = - "SELECT * FROM FORECAST(" - + "model_id=>'%s', " - + "targets=>(SELECT time, s%d FROM db.AI WHERE time<%d ORDER BY time DESC LIMIT %d) ORDER BY time, " - + "output_start_time=>%d, " - + "output_length=>%d, " - + "output_interval=>%d, " - + "timecol=>'%s'" - + ")"; - - @BeforeClass - public static void setUp() throws Exception { - // Init 1C1D1A cluster environment - EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareDataInTable(); - } - - @AfterClass - public static void tearDown() throws Exception { - EnvFactory.getEnv().cleanClusterEnvironment(); - } - - @Test - public void forecastTableFunctionTest() throws SQLException { - for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - forecastTableFunctionTest(statement, modelInfo); - } - } - } - - public static void forecastTableFunctionTest( - Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException { - // Invoke forecast table function for specified models, there should exist result. - for (int i = 0; i < 4; i++) { - String forecastTableFunctionSQL = - String.format( - FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, - modelInfo.getModelId(), - i, - 5760, - 2880, - 5760, - 96, - 1, - "time"); - try (ResultSet resultSet = statement.executeQuery(forecastTableFunctionSQL)) { - int count = 0; - while (resultSet.next()) { - count++; - } - // Ensure the forecast sentence return results - Assert.assertTrue(count > 0); - } - } - } - - @Test - public void forecastTableFunctionErrorTest() throws SQLException { - for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - forecastTableFunctionErrorTest(statement, modelInfo); - } - } - } - - public static void forecastTableFunctionErrorTest( - Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException { - // OUTPUT_START_TIME error - String invalidOutputStartTimeSQL = - String.format( - FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, - modelInfo.getModelId(), - 0, - 5760, - 2880, - 5759, - 96, - 1, - "time"); - errorTest( - statement, - invalidOutputStartTimeSQL, - "701: The OUTPUT_START_TIME should be greater than the maximum timestamp of target time series. Expected greater than [5759] but found [5759]."); - - // OUTPUT_LENGTH error - String invalidOutputLengthSQLWithZero = - String.format( - FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, - modelInfo.getModelId(), - 0, - 5760, - 2880, - 5760, - 0, - 1, - "time"); - errorTest( - statement, invalidOutputLengthSQLWithZero, "701: OUTPUT_LENGTH should be greater than 0"); - - String invalidOutputLengthSQLWithOutOfRange = - String.format( - FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, - modelInfo.getModelId(), - 0, - 5760, - 2880, - 5760, - 2881, - 1, - "time"); - errorTest( - statement, - invalidOutputLengthSQLWithOutOfRange, - "1599: Error occurred while executing forecast:[Attribute output_length expect value between 1 and 2880, got 2881 instead.]"); - - // OUTPUT_INTERVAL error - String invalidOutputIntervalSQL = - String.format( - FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, - modelInfo.getModelId(), - 0, - 5760, - 2880, - 5760, - 96, - -1, - "time"); - errorTest(statement, invalidOutputIntervalSQL, "701: OUTPUT_INTERVAL should be greater than 0"); - - // TIMECOL error - String invalidTimecolSQL2 = - String.format( - FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, - modelInfo.getModelId(), - 0, - 5760, - 2880, - 5760, - 96, - 1, - "s0"); - errorTest( - statement, invalidTimecolSQL2, "701: The type of the column [s0] is not as expected."); - } -} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java deleted file mode 100644 index 8356a055311..00000000000 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeInstanceManagementIT.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.ainode.it; - -import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; -import org.apache.iotdb.itbase.env.BaseEnv; - -import org.junit.AfterClass; -import org.junit.Assert; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.Arrays; -import java.util.HashSet; -import java.util.Set; - -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; - -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) -public class AINodeInstanceManagementIT { - - private static final String TARGET_DEVICES_STR = "0,1"; - private static final Set<String> TARGET_DEVICES = - new HashSet<>(Arrays.asList(TARGET_DEVICES_STR.split(","))); - - @BeforeClass - public static void setUp() throws Exception { - // Init 1C1D1A cluster environment - EnvFactory.getEnv().initClusterEnvironment(1, 1); - } - - @AfterClass - public static void tearDown() throws Exception { - EnvFactory.getEnv().cleanClusterEnvironment(); - } - - @Test - public void basicManagementTestInTreeModel() throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - basicManagementTest(statement); - } - } - - @Test - public void basicManagementTestInTableModel() throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - basicManagementTest(statement); - } - } - - private void basicManagementTest(Statement statement) throws SQLException, InterruptedException { - // Ensure resources - try (ResultSet resultSet = statement.executeQuery("SHOW AI_DEVICES")) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "DeviceId,DeviceType"); - final Set<String> resultDevices = new HashSet<>(); - while (resultSet.next()) { - resultDevices.add(resultSet.getString("DeviceId")); - } - Set<String> expected = new HashSet<>(TARGET_DEVICES); - expected.add("cpu"); - Assert.assertEquals(expected, resultDevices); - } - - // Load sundial to each device - statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES_STR)); - checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR); - // Unload sundial from each device - statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES_STR)); - checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR); - - // Load timer_xl to each device - statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES_STR)); - checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES_STR); - // Unload timer_xl from each device - statement.execute(String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES_STR)); - checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES_STR); - } - - private static final int LOOP_CNT = 10; - - // @Test - public void repeatLoadAndUnloadTest() throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - for (int i = 0; i < LOOP_CNT; i++) { - statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES_STR)); - checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR); - statement.execute( - String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES_STR)); - checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR); - } - } - } - - // @Test - public void concurrentLoadAndUnloadTest() throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - for (int i = 0; i < LOOP_CNT; i++) { - statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES_STR)); - statement.execute( - String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES_STR)); - } - checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR); - } - } - - @Test - public void failTestInTreeModel() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - failTest(statement); - } - } - - @Test - public void failTestInTableModel() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - failTest(statement); - } - } - - private void failTest(Statement statement) { - errorTest( - statement, - "LOAD MODEL unknown TO DEVICES 'cpu,0,1'", - "1504: Model [unknown] is not registered yet. You can use 'SHOW MODELS' to retrieve the available models."); - errorTest( - statement, - "LOAD MODEL sundial TO DEVICES '999'", - "1508: AIDevice ID [999] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices."); - errorTest( - statement, - "UNLOAD MODEL sundial FROM DEVICES '999'", - "1508: AIDevice ID [999] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices."); - errorTest( - statement, - "LOAD MODEL sundial TO DEVICES '0,0'", - "1509: Device ID list contains duplicate entries."); - errorTest( - statement, - "UNLOAD MODEL sundial FROM DEVICES '0,0'", - "1510: Device ID list contains duplicate entries."); - } -} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java deleted file mode 100644 index 42dc8db520e..00000000000 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java +++ /dev/null @@ -1,215 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.ainode.it; - -import org.apache.iotdb.ainode.utils.AINodeTestUtils; -import org.apache.iotdb.ainode.utils.AINodeTestUtils.FakeModelInfo; -import org.apache.iotdb.it.env.EnvFactory; -import org.apache.iotdb.it.framework.IoTDBTestRunner; -import org.apache.iotdb.itbase.category.AIClusterIT; -import org.apache.iotdb.itbase.env.BaseEnv; - -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; -import org.junit.experimental.categories.Category; -import org.junit.runner.RunWith; - -import java.sql.Connection; -import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.SQLException; -import java.sql.Statement; -import java.util.concurrent.TimeUnit; - -import static org.apache.iotdb.ainode.it.AINodeCallInferenceIT.callInferenceTest; -import static org.apache.iotdb.ainode.it.AINodeForecastIT.forecastTableFunctionTest; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable; -import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; - -@RunWith(IoTDBTestRunner.class) -@Category({AIClusterIT.class}) -public class AINodeModelManageIT { - - @BeforeClass - public static void setUp() throws Exception { - // Init 1C1D1A cluster environment - EnvFactory.getEnv().initClusterEnvironment(1, 1); - prepareDataInTree(); - prepareDataInTable(); - } - - @AfterClass - public static void tearDown() throws Exception { - EnvFactory.getEnv().cleanClusterEnvironment(); - } - - @Test - public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - // Test transformers model (chronos2) in tree. - AINodeTestUtils.FakeModelInfo modelInfo = - new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"); - registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2"); - callInferenceTest(statement, modelInfo); - dropUserDefinedModel(statement, modelInfo.getModelId()); - - // Test PytorchModelHubMixin model (mantis) in tree. - modelInfo = new FakeModelInfo("user_mantis", "custom_mantis", "user_defined", "active"); - registerUserDefinedModel(statement, modelInfo, "file:///data/mantis"); - dropUserDefinedModel(statement, modelInfo.getModelId()); - } - } - - @Test - public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - // Test transformers model (chronos2) in table. - AINodeTestUtils.FakeModelInfo modelInfo = - new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"); - registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2"); - forecastTableFunctionTest(statement, modelInfo); - dropUserDefinedModel(statement, modelInfo.getModelId()); - - // Test PytorchModelHubMixin model (mantis) in table. - modelInfo = new FakeModelInfo("user_mantis", "custom_mantis", "user_defined", "active"); - registerUserDefinedModel(statement, modelInfo, "file:///data/mantis"); - dropUserDefinedModel(statement, modelInfo.getModelId()); - } - } - - public static void registerUserDefinedModel( - Statement statement, AINodeTestUtils.FakeModelInfo modelInfo, String uri) - throws SQLException, InterruptedException { - String modelId = modelInfo.getModelId(); - String modelType = modelInfo.getModelType(); - String category = modelInfo.getCategory(); - final String CREATE_MODEL_TEMPLATE = "create model %s using uri \"%s\""; - final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'"; - final String registerSql = String.format(CREATE_MODEL_TEMPLATE, modelId, uri); - final String showSql = String.format("SHOW MODELS %s", modelId); - statement.execute(alterConfigSQL); - statement.execute(registerSql); - boolean loading = true; - for (int retryCnt = 0; retryCnt < 100; retryCnt++) { - try (ResultSet resultSet = statement.executeQuery(showSql)) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State"); - while (resultSet.next()) { - String resultModelId = resultSet.getString(1); - String resultModelType = resultSet.getString(2); - String resultCategory = resultSet.getString(3); - String state = resultSet.getString(4); - assertEquals(modelId, resultModelId); - assertEquals(modelType, resultModelType); - assertEquals(category, resultCategory); - if (state.equals("active")) { - loading = false; - } else if (state.equals("loading")) { - break; - } else { - fail("Unexpected status of model: " + state); - } - } - } - if (!loading) { - break; // Model is loaded successfully - } - TimeUnit.SECONDS.sleep(1); - } - assertFalse(loading); - } - - public static void dropUserDefinedModel(Statement statement, String modelId) throws SQLException { - final String showSql = String.format("SHOW MODELS %s", modelId); - final String dropSql = String.format("DROP MODEL %s", modelId); - statement.execute(dropSql); - try (ResultSet resultSet = statement.executeQuery(showSql)) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State"); - int count = 0; - while (resultSet.next()) { - count++; - } - assertEquals(0, count); - } - } - - @Test - public void dropBuiltInModelErrorTestInTree() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - errorTest(statement, "drop model sundial", "1506: Cannot delete built-in model: sundial"); - } - } - - @Test - public void dropBuiltInModelErrorTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - errorTest(statement, "drop model sundial", "1506: Cannot delete built-in model: sundial"); - } - } - - @Test - public void showBuiltInModelTestInTree() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); - Statement statement = connection.createStatement()) { - showBuiltInModelTest(statement); - } - } - - @Test - public void showBuiltInModelTestInTable() throws SQLException { - try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); - Statement statement = connection.createStatement(); ) { - showBuiltInModelTest(statement); - } - } - - private void showBuiltInModelTest(Statement statement) throws SQLException { - int built_in_model_count = 0; - final String showSql = "SHOW MODELS"; - try (ResultSet resultSet = statement.executeQuery(showSql)) { - ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); - checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State"); - while (resultSet.next()) { - built_in_model_count++; - FakeModelInfo modelInfo = - new FakeModelInfo( - resultSet.getString(1), - resultSet.getString(2), - resultSet.getString(3), - resultSet.getString(4)); - assertTrue(AINodeTestUtils.BUILTIN_MODEL_MAP.containsKey(modelInfo.getModelId())); - assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo); - } - } - assertEquals(AINodeTestUtils.BUILTIN_MODEL_MAP.size(), built_in_model_count); - } -} diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java new file mode 100644 index 00000000000..35b8c6e3aec --- /dev/null +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeSharedClusterIT.java @@ -0,0 +1,555 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.ainode.it; + +import org.apache.iotdb.ainode.utils.AINodeTestUtils; +import org.apache.iotdb.ainode.utils.AINodeTestUtils.FakeModelInfo; +import org.apache.iotdb.it.env.EnvFactory; +import org.apache.iotdb.it.framework.IoTDBTestRunner; +import org.apache.iotdb.itbase.category.AIClusterIT; +import org.apache.iotdb.itbase.env.BaseEnv; + +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Arrays; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.BUILTIN_MODEL_MAP; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkHeader; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelNotOnSpecifiedDevice; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.checkModelOnSpecifiedDevice; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.errorTest; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTable; +import static org.apache.iotdb.ainode.utils.AINodeTestUtils.prepareDataInTree; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** + * Consolidates AINodeDeviceManageIT, AINodeModelManageIT, AINodeCallInferenceIT, + * AINodeForecastIT, and AINodeInstanceManagementIT into a single class that shares one 1C1D1A + * cluster, avoiding 5 redundant cluster startups (~20 min saved). + */ +@RunWith(IoTDBTestRunner.class) +@Category({AIClusterIT.class}) +public class AINodeSharedClusterIT { + + private static final String TARGET_DEVICES_STR = "0,1"; + private static final Set<String> TARGET_DEVICES = + new HashSet<>(Arrays.asList(TARGET_DEVICES_STR.split(","))); + + private static final String CALL_INFERENCE_SQL_TEMPLATE = + "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT %d\", generateTime=true, outputLength=%d)"; + private static final String CALL_INFERENCE_BY_DEFAULT_SQL_TEMPLATE = + "CALL INFERENCE(%s, \"SELECT s%d FROM root.AI LIMIT 256\")"; + private static final int DEFAULT_INPUT_LENGTH = 256; + private static final int DEFAULT_OUTPUT_LENGTH = 48; + + private static final String FORECAST_TABLE_FUNCTION_SQL_TEMPLATE = + "SELECT * FROM FORECAST(" + + "model_id=>'%s', " + + "targets=>(SELECT time, s%d FROM db.AI WHERE time<%d ORDER BY time DESC LIMIT %d) ORDER BY time, " + + "output_start_time=>%d, " + + "output_length=>%d, " + + "output_interval=>%d, " + + "timecol=>'%s'" + + ")"; + + @BeforeClass + public static void setUp() throws Exception { + EnvFactory.getEnv().initClusterEnvironment(1, 1); + prepareDataInTree(); + prepareDataInTable(); + } + + @AfterClass + public static void tearDown() throws Exception { + EnvFactory.getEnv().cleanClusterEnvironment(); + } + + // ========== DeviceManage tests ========== + + @Test + public void showAIDeviceTestInTree() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + showAIDevicesTest(statement); + } + } + + @Test + public void showAIDeviceTestInTable() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + showAIDevicesTest(statement); + } + } + + private void showAIDevicesTest(Statement statement) throws SQLException { + final String showSql = "SHOW AI_DEVICES"; + final List<String> expectedDeviceIdList = new LinkedList<>(Arrays.asList("0", "1", "cpu")); + final List<String> expectedDeviceTypeList = + new LinkedList<>(Arrays.asList("cuda", "cuda", "cpu")); + try (ResultSet resultSet = statement.executeQuery(showSql)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "DeviceId,DeviceType"); + while (resultSet.next()) { + String deviceId = resultSet.getString(1); + String deviceType = resultSet.getString(2); + Assert.assertEquals(expectedDeviceIdList.remove(0), deviceId); + Assert.assertEquals(expectedDeviceTypeList.remove(0), deviceType); + } + } + } + + // ========== ModelManage tests ========== + + @Test + public void userDefinedModelManagementTestInTree() throws SQLException, InterruptedException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + FakeModelInfo modelInfo = + new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"); + registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2"); + callInferenceTest(statement, modelInfo); + dropUserDefinedModel(statement, modelInfo.getModelId()); + + modelInfo = new FakeModelInfo("user_mantis", "custom_mantis", "user_defined", "active"); + registerUserDefinedModel(statement, modelInfo, "file:///data/mantis"); + dropUserDefinedModel(statement, modelInfo.getModelId()); + } + } + + @Test + public void userDefinedModelManagementTestInTable() throws SQLException, InterruptedException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + FakeModelInfo modelInfo = + new FakeModelInfo("user_chronos", "custom_t5", "user_defined", "active"); + registerUserDefinedModel(statement, modelInfo, "file:///data/chronos2"); + forecastTableFunctionTest(statement, modelInfo); + dropUserDefinedModel(statement, modelInfo.getModelId()); + + modelInfo = new FakeModelInfo("user_mantis", "custom_mantis", "user_defined", "active"); + registerUserDefinedModel(statement, modelInfo, "file:///data/mantis"); + dropUserDefinedModel(statement, modelInfo.getModelId()); + } + } + + @Test + public void dropBuiltInModelErrorTestInTree() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + errorTest(statement, "drop model sundial", "1506: Cannot delete built-in model: sundial"); + } + } + + @Test + public void dropBuiltInModelErrorTestInTable() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + errorTest(statement, "drop model sundial", "1506: Cannot delete built-in model: sundial"); + } + } + + @Test + public void showBuiltInModelTestInTree() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + showBuiltInModelTest(statement); + } + } + + @Test + public void showBuiltInModelTestInTable() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + showBuiltInModelTest(statement); + } + } + + private void showBuiltInModelTest(Statement statement) throws SQLException { + int built_in_model_count = 0; + final String showSql = "SHOW MODELS"; + try (ResultSet resultSet = statement.executeQuery(showSql)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State"); + while (resultSet.next()) { + built_in_model_count++; + FakeModelInfo modelInfo = + new FakeModelInfo( + resultSet.getString(1), + resultSet.getString(2), + resultSet.getString(3), + resultSet.getString(4)); + assertTrue(BUILTIN_MODEL_MAP.containsKey(modelInfo.getModelId())); + assertEquals(BUILTIN_MODEL_MAP.get(modelInfo.getModelId()), modelInfo); + } + } + assertEquals(BUILTIN_MODEL_MAP.size(), built_in_model_count); + } + + // ========== CallInference tests ========== + + @Test + public void callInferenceTest() throws SQLException { + for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + callInferenceTest(statement, modelInfo); + callInferenceByDefaultTest(statement, modelInfo); + callInferenceErrorTest(statement, modelInfo); + } + } + } + + public static void callInferenceTest(Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) + throws SQLException { + for (int i = 0; i < 4; i++) { + String callInferenceSQL = + String.format( + CALL_INFERENCE_SQL_TEMPLATE, + modelInfo.getModelId(), + i, + DEFAULT_INPUT_LENGTH, + DEFAULT_OUTPUT_LENGTH); + try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "Time,output"); + int count = 0; + while (resultSet.next()) { + count++; + } + Assert.assertEquals(DEFAULT_OUTPUT_LENGTH, count); + } + } + } + + public static void callInferenceByDefaultTest( + Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException { + for (int i = 0; i < 4; i++) { + String callInferenceSQL = + String.format(CALL_INFERENCE_BY_DEFAULT_SQL_TEMPLATE, modelInfo.getModelId(), i); + try (ResultSet resultSet = statement.executeQuery(callInferenceSQL)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "output"); + int count = 0; + while (resultSet.next()) { + count++; + } + Assert.assertTrue(count > 0); + } + } + } + + public static void callInferenceErrorTest( + Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) { + String multiVariateSQL = + String.format( + "CALL INFERENCE(%s, \"SELECT s0,s1 FROM root.AI LIMIT 128\", generateTime=true, outputLength=10)", + modelInfo.getModelId()); + errorTest( + statement, + multiVariateSQL, + "701: Call inference function should not contain more than one input column, found [2] input columns."); + } + + // ========== Forecast tests ========== + + @Test + public void forecastTableFunctionTest() throws SQLException { + for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + forecastTableFunctionTest(statement, modelInfo); + } + } + } + + public static void forecastTableFunctionTest( + Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException { + for (int i = 0; i < 4; i++) { + String forecastTableFunctionSQL = + String.format( + FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, + modelInfo.getModelId(), + i, + 5760, + 2880, + 5760, + 96, + 1, + "time"); + try (ResultSet resultSet = statement.executeQuery(forecastTableFunctionSQL)) { + int count = 0; + while (resultSet.next()) { + count++; + } + Assert.assertTrue(count > 0); + } + } + } + + @Test + public void forecastTableFunctionErrorTest() throws SQLException { + for (AINodeTestUtils.FakeModelInfo modelInfo : BUILTIN_MODEL_MAP.values()) { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + forecastTableFunctionErrorTest(statement, modelInfo); + } + } + } + + public static void forecastTableFunctionErrorTest( + Statement statement, AINodeTestUtils.FakeModelInfo modelInfo) throws SQLException { + String invalidOutputStartTimeSQL = + String.format( + FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, + modelInfo.getModelId(), + 0, + 5760, + 2880, + 5759, + 96, + 1, + "time"); + errorTest( + statement, + invalidOutputStartTimeSQL, + "701: The OUTPUT_START_TIME should be greater than the maximum timestamp of target time series. Expected greater than [5759] but found [5759]."); + + String invalidOutputLengthSQLWithZero = + String.format( + FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, + modelInfo.getModelId(), + 0, + 5760, + 2880, + 5760, + 0, + 1, + "time"); + errorTest( + statement, invalidOutputLengthSQLWithZero, "701: OUTPUT_LENGTH should be greater than 0"); + + String invalidOutputLengthSQLWithOutOfRange = + String.format( + FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, + modelInfo.getModelId(), + 0, + 5760, + 2880, + 5760, + 2881, + 1, + "time"); + errorTest( + statement, + invalidOutputLengthSQLWithOutOfRange, + "1599: Error occurred while executing forecast:[Attribute output_length expect value between 1 and 2880, got 2881 instead.]"); + + String invalidOutputIntervalSQL = + String.format( + FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, + modelInfo.getModelId(), + 0, + 5760, + 2880, + 5760, + 96, + -1, + "time"); + errorTest(statement, invalidOutputIntervalSQL, "701: OUTPUT_INTERVAL should be greater than 0"); + + String invalidTimecolSQL2 = + String.format( + FORECAST_TABLE_FUNCTION_SQL_TEMPLATE, + modelInfo.getModelId(), + 0, + 5760, + 2880, + 5760, + 96, + 1, + "s0"); + errorTest( + statement, invalidTimecolSQL2, "701: The type of the column [s0] is not as expected."); + } + + // ========== InstanceManagement tests ========== + + @Test + public void instanceBasicManagementTestInTreeModel() + throws SQLException, InterruptedException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + instanceBasicManagementTest(statement); + } + } + + @Test + public void instanceBasicManagementTestInTableModel() + throws SQLException, InterruptedException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + instanceBasicManagementTest(statement); + } + } + + private void instanceBasicManagementTest(Statement statement) + throws SQLException, InterruptedException { + try (ResultSet resultSet = statement.executeQuery("SHOW AI_DEVICES")) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "DeviceId,DeviceType"); + final Set<String> resultDevices = new HashSet<>(); + while (resultSet.next()) { + resultDevices.add(resultSet.getString("DeviceId")); + } + Set<String> expected = new HashSet<>(TARGET_DEVICES); + expected.add("cpu"); + Assert.assertEquals(expected, resultDevices); + } + + statement.execute(String.format("LOAD MODEL sundial TO DEVICES '%s'", TARGET_DEVICES_STR)); + checkModelOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR); + statement.execute(String.format("UNLOAD MODEL sundial FROM DEVICES '%s'", TARGET_DEVICES_STR)); + checkModelNotOnSpecifiedDevice(statement, "sundial", TARGET_DEVICES_STR); + + statement.execute(String.format("LOAD MODEL timer_xl TO DEVICES '%s'", TARGET_DEVICES_STR)); + checkModelOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES_STR); + statement.execute( + String.format("UNLOAD MODEL timer_xl FROM DEVICES '%s'", TARGET_DEVICES_STR)); + checkModelNotOnSpecifiedDevice(statement, "timer_xl", TARGET_DEVICES_STR); + } + + @Test + public void instanceFailTestInTreeModel() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + instanceFailTest(statement); + } + } + + @Test + public void instanceFailTestInTableModel() throws SQLException { + try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); + Statement statement = connection.createStatement()) { + instanceFailTest(statement); + } + } + + private void instanceFailTest(Statement statement) { + errorTest( + statement, + "LOAD MODEL unknown TO DEVICES 'cpu,0,1'", + "1504: Model [unknown] is not registered yet. You can use 'SHOW MODELS' to retrieve the available models."); + errorTest( + statement, + "LOAD MODEL sundial TO DEVICES '999'", + "1508: AIDevice ID [999] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices."); + errorTest( + statement, + "UNLOAD MODEL sundial FROM DEVICES '999'", + "1508: AIDevice ID [999] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices."); + errorTest( + statement, + "LOAD MODEL sundial TO DEVICES '0,0'", + "1509: Device ID list contains duplicate entries."); + errorTest( + statement, + "UNLOAD MODEL sundial FROM DEVICES '0,0'", + "1510: Device ID list contains duplicate entries."); + } + + // ========== Helper methods (from ModelManageIT) ========== + + private static void registerUserDefinedModel( + Statement statement, AINodeTestUtils.FakeModelInfo modelInfo, String uri) + throws SQLException, InterruptedException { + String modelId = modelInfo.getModelId(); + String modelType = modelInfo.getModelType(); + String category = modelInfo.getCategory(); + final String CREATE_MODEL_TEMPLATE = "create model %s using uri \"%s\""; + final String alterConfigSQL = "set configuration \"trusted_uri_pattern\"='.*'"; + final String registerSql = String.format(CREATE_MODEL_TEMPLATE, modelId, uri); + final String showSql = String.format("SHOW MODELS %s", modelId); + statement.execute(alterConfigSQL); + statement.execute(registerSql); + boolean loading = true; + for (int retryCnt = 0; retryCnt < 100; retryCnt++) { + try (ResultSet resultSet = statement.executeQuery(showSql)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State"); + while (resultSet.next()) { + String resultModelId = resultSet.getString(1); + String resultModelType = resultSet.getString(2); + String resultCategory = resultSet.getString(3); + String state = resultSet.getString(4); + assertEquals(modelId, resultModelId); + assertEquals(modelType, resultModelType); + assertEquals(category, resultCategory); + if (state.equals("active")) { + loading = false; + } else if (state.equals("loading")) { + break; + } else { + fail("Unexpected status of model: " + state); + } + } + } + if (!loading) { + break; + } + TimeUnit.SECONDS.sleep(1); + } + assertFalse(loading); + } + + private static void dropUserDefinedModel(Statement statement, String modelId) + throws SQLException { + final String showSql = String.format("SHOW MODELS %s", modelId); + final String dropSql = String.format("DROP MODEL %s", modelId); + statement.execute(dropSql); + try (ResultSet resultSet = statement.executeQuery(showSql)) { + ResultSetMetaData resultSetMetaData = resultSet.getMetaData(); + checkHeader(resultSetMetaData, "ModelId,ModelType,Category,State"); + int count = 0; + while (resultSet.next()) { + count++; + } + assertEquals(0, count); + } + } +} diff --git a/iotdb-core/ainode/build_binary.py b/iotdb-core/ainode/build_binary.py index 56f19f43cf5..48969a66f4b 100644 --- a/iotdb-core/ainode/build_binary.py +++ b/iotdb-core/ainode/build_binary.py @@ -21,7 +21,9 @@ PyInstaller build script (Python version) """ +import hashlib import os +import shutil import subprocess import sys from pathlib import Path @@ -317,6 +319,97 @@ def poetry_install_with_accel(poetry_exe, script_dir, venv_env, accelerator): ) +def compute_source_hash(script_dir): + """ + Compute a SHA256 hash over all files that affect the PyInstaller output. + + Includes Python source files, the spec file, pyproject.toml, and poetry.lock. + """ + hasher = hashlib.sha256() + + hash_targets = [] + + for pattern in ("**/*.py", "**/*.spec"): + hash_targets.extend(script_dir.glob(pattern)) + + for name in ("pyproject.toml", "poetry.lock"): + f = script_dir / name + if f.exists(): + hash_targets.append(f) + + # Also include the thrift/client-py sources that get copied in + client_py_dir = script_dir.parent.parent / "iotdb-client" / "client-py" / "iotdb" + if client_py_dir.is_dir(): + hash_targets.extend(client_py_dir.rglob("*.py")) + + hash_targets.sort(key=lambda p: str(p)) + + for f in hash_targets: + try: + rel = f.relative_to(script_dir) + except ValueError: + rel = f + hasher.update(str(rel).encode()) + hasher.update(f.read_bytes()) + + return hasher.hexdigest() + + +def get_dist_cache_dir(): + """Get the directory used to cache PyInstaller dist output.""" + return get_venv_base_dir() / "dist-cache" + + +def try_restore_dist_cache(script_dir): + """ + Try to restore the dist/ directory from cache. + + Returns True if cache hit, False otherwise. + """ + source_hash = compute_source_hash(script_dir) + cache_dir = get_dist_cache_dir() + hash_file = cache_dir / "source_hash" + cached_dist = cache_dir / "ainode" + dist_dir = script_dir / "dist" / "ainode" + + print(f"Source hash: {source_hash}") + + if hash_file.exists() and cached_dist.is_dir(): + cached_hash = hash_file.read_text().strip() + if cached_hash == source_hash: + print("Cache hit — restoring dist/ from cache, skipping PyInstaller build") + dist_dir.parent.mkdir(parents=True, exist_ok=True) + if dist_dir.exists(): + shutil.rmtree(dist_dir) + shutil.copytree(cached_dist, dist_dir, symlinks=True) + return True + else: + print("Cache miss — source hash changed, will rebuild") + else: + print("No dist cache found, will build from scratch") + + return False + + +def save_dist_cache(script_dir): + """Save the dist/ directory to cache after a successful build.""" + source_hash = compute_source_hash(script_dir) + cache_dir = get_dist_cache_dir() + cached_dist = cache_dir / "ainode" + dist_dir = script_dir / "dist" / "ainode" + + if not dist_dir.is_dir(): + print("Warning: dist/ainode not found, skipping cache save") + return + + cache_dir.mkdir(parents=True, exist_ok=True) + if cached_dist.exists(): + shutil.rmtree(cached_dist) + shutil.copytree(dist_dir, cached_dist, symlinks=True) + (cache_dir / "source_hash").write_text(source_hash) + print(f"Saved dist cache (hash: {source_hash})") + + def build(): """ Execute the complete build process. @@ -325,7 +418,9 @@ def build(): 1. Setup virtual environment (outside project directory) 2. Update pip and install 2.2.1 poetry 3. Install project dependencies (including PyInstaller from pyproject.toml) - 4. Build executable using PyInstaller + 4. Check dist cache — skip PyInstaller if source hasn't changed + 5. Build executable using PyInstaller (if cache miss) + 6. Save dist to cache """ script_dir = Path(__file__).parent @@ -344,6 +439,13 @@ def build(): print("=" * 50) print() + if try_restore_dist_cache(script_dir): + print() + print("=" * 50) + print("Build completed (from cache)!") + print("=" * 50) + return + print("Starting build...") print() @@ -383,6 +485,8 @@ def build(): print(f"\nError: Build failed: {e}") sys.exit(1) + save_dist_cache(script_dir) + print() print("=" * 50) print("Build completed!")
