diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java index 2e62f618091c..e148be6b20a4 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeClusterConfigIT.java @@ -107,7 +107,4 @@ private void aiNodeRegisterAndRemoveTest(Statement statement) throws SQLExceptio } Assert.fail("The target AINode is not removed successfully after all retries."); } - - // TODO: We might need to add remove unknown test in the future, but current infrastructure is too - // hard to implement it. } diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java index b92b80aecf32..3315617e7fda 100644 --- a/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java +++ b/integration-test/src/test/java/org/apache/iotdb/ainode/it/AINodeModelManageIT.java @@ -131,7 +131,7 @@ private void userDefinedModelManagementTest(Statement statement) public void dropBuiltInModelErrorTestInTree() throws SQLException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TREE_SQL_DIALECT); Statement statement = connection.createStatement()) { - errorTest(statement, "drop model sundial", "1510: Cannot delete built-in model: sundial"); + errorTest(statement, "drop model sundial", "1506: Cannot delete built-in model: sundial"); } } @@ -139,7 +139,7 @@ public void dropBuiltInModelErrorTestInTree() throws SQLException { public void dropBuiltInModelErrorTestInTable() throws SQLException { try (Connection connection = EnvFactory.getEnv().getConnection(BaseEnv.TABLE_SQL_DIALECT); Statement statement = connection.createStatement()) { - errorTest(statement, "drop model sundial", "1510: Cannot delete built-in model: sundial"); + errorTest(statement, "drop model sundial", "1506: Cannot delete built-in model: sundial"); } } diff --git a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java index 7f94dd696ac5..6bd0da063b03 100644 --- a/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java +++ b/iotdb-client/service-rpc/src/main/java/org/apache/iotdb/rpc/TSStatusCode.java @@ -244,16 +244,17 @@ public enum TSStatusCode { CQ_UPDATE_LAST_EXEC_TIME_ERROR(1403), // AI - CREATE_MODEL_ERROR(1500), - DROP_MODEL_ERROR(1501), - MODEL_EXIST_ERROR(1502), - GET_MODEL_INFO_ERROR(1503), - NO_REGISTERED_AI_NODE_ERROR(1504), - MODEL_NOT_FOUND_ERROR(1505), - REGISTER_AI_NODE_ERROR(1506), - UNAVAILABLE_AI_DEVICE_ERROR(1507), - AI_NODE_INTERNAL_ERROR(1510), - REMOVE_AI_NODE_ERROR(1511), + NO_REGISTERED_AI_NODE_ERROR(1500), + REGISTER_AI_NODE_ERROR(1501), + REMOVE_AI_NODE_ERROR(1502), + MODEL_EXISTED_ERROR(1503), + MODEL_NOT_EXIST_ERROR(1504), + CREATE_MODEL_ERROR(1505), + DROP_BUILTIN_MODEL_ERROR(1506), + DROP_MODEL_ERROR(1507), + UNAVAILABLE_AI_DEVICE_ERROR(1508), + + AINODE_INTERNAL_ERROR(1599), // In case somebody too lazy to add a new error code // Pipe Plugin CREATE_PIPE_PLUGIN_ERROR(1600), diff --git a/iotdb-core/ainode/iotdb/ainode/core/config.py b/iotdb-core/ainode/iotdb/ainode/core/config.py index e465df7e36d2..b14efa3bedff 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/config.py +++ b/iotdb-core/ainode/iotdb/ainode/core/config.py @@ -46,7 +46,7 @@ AINODE_THRIFT_COMPRESSION_ENABLED, AINODE_VERSION_INFO, ) -from iotdb.ainode.core.exception import BadNodeUrlError +from iotdb.ainode.core.exception import BadNodeUrlException from iotdb.ainode.core.log import Logger from iotdb.ainode.core.util.decorator import singleton from iotdb.thrift.common.ttypes import TEndPoint @@ -437,7 +437,7 @@ def _load_config_from_file(self) -> None: file_configs["ain_cluster_ingress_time_zone"] ) - except BadNodeUrlError: + except BadNodeUrlException: logger.warning("Cannot load AINode conf file, use default configuration.") except Exception as e: @@ -489,7 +489,7 @@ def parse_endpoint_url(endpoint_url: str) -> TEndPoint: """ split = endpoint_url.split(":") if len(split) != 2: - raise BadNodeUrlError(endpoint_url) + raise BadNodeUrlException(endpoint_url) ip = split[0] try: @@ -497,4 +497,4 @@ def parse_endpoint_url(endpoint_url: str) -> TEndPoint: result = TEndPoint(ip, port) return result except ValueError: - raise BadNodeUrlError(endpoint_url) + raise BadNodeUrlException(endpoint_url) diff --git a/iotdb-core/ainode/iotdb/ainode/core/constant.py b/iotdb-core/ainode/iotdb/ainode/core/constant.py index d8f730c829c8..44e76840f73c 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/constant.py +++ b/iotdb-core/ainode/iotdb/ainode/core/constant.py @@ -81,33 +81,18 @@ class TSStatusCode(Enum): SUCCESS_STATUS = 200 REDIRECTION_RECOMMEND = 400 - MODEL_EXIST_ERROR = 1502 - MODEL_NOT_FOUND_ERROR = 1505 - UNAVAILABLE_AI_DEVICE_ERROR = 1507 - AINODE_INTERNAL_ERROR = 1510 + MODEL_EXISTED_ERROR = 1503 + MODEL_NOT_EXIST_ERROR = 1504 + CREATE_MODEL_ERROR = 1505 + DROP_BUILTIN_MODEL_ERROR = 1506 + DROP_MODEL_ERROR = 1507 + UNAVAILABLE_AI_DEVICE_ERROR = 1508 + INVALID_URI_ERROR = 1511 INVALID_INFERENCE_CONFIG = 1512 INFERENCE_INTERNAL_ERROR = 1520 - def get_status_code(self) -> int: - return self.value - + AINODE_INTERNAL_ERROR = 1599 # In case somebody too lazy to add a new error code -class HyperparameterName(Enum): - # Training hyperparameter - LEARNING_RATE = "learning_rate" - EPOCHS = "epochs" - BATCH_SIZE = "batch_size" - USE_GPU = "use_gpu" - NUM_WORKERS = "num_workers" - - # Structure hyperparameter - KERNEL_SIZE = "kernel_size" - INPUT_VARS = "input_vars" - BLOCK_TYPE = "block_type" - D_MODEL = "d_model" - INNER_LAYERS = "inner_layer" - OUTER_LAYERS = "outer_layer" - - def name(self): + def get_status_code(self) -> int: return self.value diff --git a/iotdb-core/ainode/iotdb/ainode/core/exception.py b/iotdb-core/ainode/iotdb/ainode/core/exception.py index 30b9d54dcc7d..b007ee58c484 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/exception.py +++ b/iotdb-core/ainode/iotdb/ainode/core/exception.py @@ -23,7 +23,7 @@ ) -class _BaseError(Exception): +class _BaseException(Exception): """Base class for exceptions in this module.""" def __init__(self): @@ -33,86 +33,74 @@ def __str__(self) -> str: return self.message -class BadNodeUrlError(_BaseError): +class BadNodeUrlException(_BaseException): def __init__(self, node_url: str): + super().__init__() self.message = "Bad node url: {}".format(node_url) -class ModelNotExistError(_BaseError): - def __init__(self, msg: str): - self.message = "Model is not exists: {} ".format(msg) - - -class BadConfigValueError(_BaseError): - def __init__(self, config_name: str, config_value, hint: str = ""): - self.message = "Bad value [{0}] for config {1}. {2}".format( - config_value, config_name, hint - ) - +# ==================== Model Management ==================== -class MissingConfigError(_BaseError): - def __init__(self, config_name: str): - self.message = "Missing config: {}".format(config_name) - -class MissingOptionError(_BaseError): - def __init__(self, config_name: str): - self.message = "Missing task option: {}".format(config_name) +class ModelExistedException(_BaseException): + def __init__(self, model_id: str): + super().__init__() + self.message = "Model {} already exists".format(model_id) -class RedundantOptionError(_BaseError): - def __init__(self, option_name: str): - self.message = "Redundant task option: {}".format(option_name) +class ModelNotExistException(_BaseException): + def __init__(self, model_id: str): + super().__init__() + self.message = "Model {} does not exist".format(model_id) -class WrongTypeConfigError(_BaseError): - def __init__(self, config_name: str, expected_type: str): - self.message = "Wrong type for config: {0}, expected: {1}".format( - config_name, expected_type +class InvalidModelUriException(_BaseException): + def __init__(self, msg: str): + super().__init__() + self.message = ( + "Model registration failed because the specified uri is invalid: {}".format( + msg + ) ) -class UnsupportedError(_BaseError): - def __init__(self, msg: str): - self.message = "{0} is not supported in current version".format(msg) +class BuiltInModelDeletionException(_BaseException): + def __init__(self, model_id: str): + super().__init__() + self.message = "Cannot delete built-in model: {}".format(model_id) -class InvalidUriError(_BaseError): - def __init__(self, uri: str): - self.message = "Invalid uri: {}, there are no {} or {} under this uri.".format( - uri, MODEL_WEIGHTS_FILE_IN_PT, MODEL_CONFIG_FILE_IN_YAML +class BadConfigValueException(_BaseException): + def __init__(self, config_name: str, config_value, hint: str = ""): + super().__init__() + self.message = "Bad value [{0}] for config {1}. {2}".format( + config_value, config_name, hint ) -class InvalidWindowArgumentError(_BaseError): - def __init__(self, window_interval, window_step, dataset_length): - self.message = f"Invalid inference input: window_interval {window_interval}, window_step {window_step}, dataset_length {dataset_length}" - - -class InferenceModelInternalError(_BaseError): +class InferenceModelInternalException(_BaseException): def __init__(self, msg: str): + super().__init__() self.message = "Inference model internal error: {0}".format(msg) -class BuiltInModelNotSupportError(_BaseError): +class BuiltInModelNotSupportException(_BaseException): def __init__(self, msg: str): + super().__init__() self.message = "Built-in model not support: {0}".format(msg) -class BuiltInModelDeletionError(_BaseError): - def __init__(self, model_id: str): - self.message = "Cannot delete built-in model: {0}".format(model_id) - - -class WrongAttributeTypeError(_BaseError): +class WrongAttributeTypeException(_BaseException): def __init__(self, attribute_name: str, expected_type: str): + super().__init__() self.message = "Wrong type for attribute: {0}, expected: {1}".format( attribute_name, expected_type ) -class NumericalRangeException(_BaseError): +class NumericalRangeException(_BaseException): def __init__(self, attribute_name: str, value, min_value, max_value): + super().__init__() self.message = ( "Attribute {0} expect value between {1} and {2}, got {3} instead.".format( attribute_name, min_value, max_value, value @@ -120,35 +108,19 @@ def __init__(self, attribute_name: str, value, min_value, max_value): ) -class StringRangeException(_BaseError): +class StringRangeException(_BaseException): def __init__(self, attribute_name: str, value: str, expect_value): + super().__init__() self.message = "Attribute {0} expect value in {1}, got {2} instead.".format( attribute_name, expect_value, value ) -class ListRangeException(_BaseError): +class ListRangeException(_BaseException): def __init__(self, attribute_name: str, value: list, expected_type: str): + super().__init__() self.message = ( "Attribute {0} expect value type list[{1}], got {2} instead.".format( attribute_name, expected_type, value ) ) - - -class AttributeNotSupportError(_BaseError): - def __init__(self, model_name: str, attribute_name: str): - self.message = "Attribute {0} is not supported in model {1}".format( - attribute_name, model_name - ) - - -# This is used to extract the key message in RuntimeError instead of the traceback message -def runtime_error_extractor(error_message): - pattern = re.compile(r"RuntimeError: (.+)") - match = pattern.search(error_message) - - if match: - return match.group(1) - else: - return "" diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/dispatcher/basic_dispatcher.py b/iotdb-core/ainode/iotdb/ainode/core/inference/dispatcher/basic_dispatcher.py index d06ba55546fe..d9a2d8663e59 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/dispatcher/basic_dispatcher.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/dispatcher/basic_dispatcher.py @@ -16,7 +16,7 @@ # under the License. # -from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.exception import InferenceModelInternalException from iotdb.ainode.core.inference.dispatcher.abstract_dispatcher import ( AbstractDispatcher, ) @@ -41,7 +41,7 @@ def _select_pool_by_hash(self, req, pool_ids) -> int: """ model_id = req.model_id if not pool_ids: - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"No available pools for model {model_id}" ) start_idx = hash(req.req_id) % len(pool_ids) @@ -51,7 +51,7 @@ def _select_pool_by_hash(self, req, pool_ids) -> int: state = self.pool_states[pool_id] if state == PoolState.RUNNING: return pool_id - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"No RUNNING pools available for model {model_id}" ) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py index 8ffa89ffd675..c580a89916d5 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_controller.py @@ -24,7 +24,7 @@ import torch.multiprocessing as mp -from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.exception import InferenceModelInternalException from iotdb.ainode.core.inference.inference_request import ( InferenceRequest, InferenceRequestProxy, @@ -374,7 +374,7 @@ def add_request(self, req: InferenceRequest, infer_proxy: InferenceRequestProxy) if not self.has_request_pools(model_id): logger.error(f"[Inference] No pools found for model {model_id}.") infer_proxy.set_result(None) - raise InferenceModelInternalError( + raise InferenceModelInternalException( "Dispatch request failed, because no inference pools are init." ) # TODO: Implement adaptive scaling based on requests.(e.g. lazy initialization) diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py index a700dcee4733..b85f64d42cc4 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_group.py @@ -19,7 +19,7 @@ import torch.multiprocessing as mp -from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.exception import InferenceModelInternalException from iotdb.ainode.core.inference.dispatcher.basic_dispatcher import BasicDispatcher from iotdb.ainode.core.inference.inference_request import ( InferenceRequest, @@ -90,14 +90,14 @@ def dispatch_request( def get_request_pool(self, pool_id) -> InferenceRequestPool: if pool_id not in self.pool_group: - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"[Inference][Pool-{pool_id}] Pool not found for model {self.model_id}" ) return self.pool_group[pool_id][0] def get_request_queue(self, pool_id) -> mp.Queue: if pool_id not in self.pool_group: - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"[Inference][Pool-{pool_id}] Pool not found for model {self.model_id}" ) return self.pool_group[pool_id][1] diff --git a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py index d2e7292ecd8f..21140cafb1fe 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/inference/pool_scheduler/basic_pool_scheduler.py @@ -20,7 +20,7 @@ import torch -from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.exception import InferenceModelInternalException from iotdb.ainode.core.inference.pool_group import PoolGroup from iotdb.ainode.core.inference.pool_scheduler.abstract_pool_scheduler import ( AbstractPoolScheduler, @@ -113,7 +113,7 @@ def schedule(self, model_id: str) -> List[ScaleAction]: if model_id not in self._request_pool_map: pool_num = estimate_pool_size(self.DEFAULT_DEVICE, model_id) if pool_num <= 0: - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"Not enough memory to run model {model_id}." ) return [ScaleAction(ScaleActionType.SCALE_UP, pool_num, model_id)] diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py index 1ce2e84e0592..34c315274f59 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/inference_manager.py @@ -27,7 +27,7 @@ from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.constant import TSStatusCode from iotdb.ainode.core.exception import ( - InferenceModelInternalError, + InferenceModelInternalException, NumericalRangeException, ) from iotdb.ainode.core.inference.inference_request import ( @@ -161,7 +161,7 @@ def _process_request(self, req): return outputs except Exception as e: logger.error(e) - raise InferenceModelInternalError(str(e)) + raise InferenceModelInternalException(str(e)) finally: with self._result_wrapper_lock: del self._result_wrapper_map[req_id] diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py index 8ffb33d91e2d..ef0846c3d786 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/model_manager.py @@ -16,12 +16,16 @@ # under the License. # -from typing import Any, List, Optional +from typing import Optional from iotdb.ainode.core.constant import TSStatusCode -from iotdb.ainode.core.exception import BuiltInModelDeletionError +from iotdb.ainode.core.exception import ( + BuiltInModelDeletionException, + InvalidModelUriException, + ModelExistedException, + ModelNotExistException, +) from iotdb.ainode.core.log import Logger -from iotdb.ainode.core.model.model_loader import load_model from iotdb.ainode.core.model.model_storage import ModelCategory, ModelInfo, ModelStorage from iotdb.ainode.core.rpc.status import get_status from iotdb.ainode.core.util.decorator import singleton @@ -47,16 +51,15 @@ def register_model( req: TRegisterModelReq, ) -> TRegisterModelResp: try: - if self._model_storage.register_model(model_id=req.modelId, uri=req.uri): - return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS)) - return TRegisterModelResp(get_status(TSStatusCode.AINODE_INTERNAL_ERROR)) - except ValueError as e: + self._model_storage.register_model(model_id=req.modelId, uri=req.uri) + return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS)) + except ModelExistedException as e: return TRegisterModelResp( - get_status(TSStatusCode.INVALID_URI_ERROR, str(e)) + get_status(TSStatusCode.MODEL_EXISTED_ERROR, str(e)) ) - except Exception as e: + except InvalidModelUriException as e: return TRegisterModelResp( - get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) + get_status(TSStatusCode.CREATE_MODEL_ERROR, str(e)) ) def show_models(self, req: TShowModelsReq) -> TShowModelsResp: @@ -67,12 +70,12 @@ def delete_model(self, req: TDeleteModelReq) -> TSStatus: try: self._model_storage.delete_model(req.modelId) return get_status(TSStatusCode.SUCCESS_STATUS) - except BuiltInModelDeletionError as e: - logger.warning(e) - return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) + except ModelNotExistException as e: + return get_status(TSStatusCode.MODEL_NOT_EXIST_ERROR, str(e)) + except BuiltInModelDeletionException as e: + return get_status(TSStatusCode.DROP_BUILTIN_MODEL_ERROR, str(e)) except Exception as e: - logger.warning(e) - return get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) + return get_status(TSStatusCode.DROP_MODEL_ERROR, str(e)) def get_model_info( self, @@ -81,19 +84,9 @@ def get_model_info( ) -> Optional[ModelInfo]: return self._model_storage.get_model_info(model_id, category) - def get_model_infos( - self, - category: Optional[ModelCategory] = None, - model_type: Optional[str] = None, - ) -> List[ModelInfo]: - return self._model_storage.get_model_infos(category, model_type) - def _refresh(self): """Refresh the model list (re-scan the file system)""" self._model_storage.discover_all_models() - def get_registered_models(self) -> List[str]: - return self._model_storage.get_registered_models() - def is_model_registered(self, model_id: str) -> bool: return self._model_storage.is_model_registered(model_id) diff --git a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py index 23a98f26bbff..175168762015 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/manager/utils.py @@ -22,7 +22,7 @@ import torch from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.exception import ModelNotExistError +from iotdb.ainode.core.exception import ModelNotExistException from iotdb.ainode.core.log import Logger from iotdb.ainode.core.manager.model_manager import ModelManager from iotdb.ainode.core.model.model_loader import load_model @@ -86,7 +86,7 @@ def estimate_pool_size(device: torch.device, model_id: str) -> int: logger.error( f"[Inference] Cannot estimate inference pool size on device: {device}, because model: {model_id} is not supported." ) - raise ModelNotExistError(model_id) + raise ModelNotExistException(model_id) system_res = evaluate_system_resources(device) free_mem = system_res["free_mem"] diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py index c42ec98551b8..9f1801b5073a 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_constants.py @@ -24,13 +24,6 @@ MODEL_CONFIG_FILE_IN_YAML = "config.yaml" -# Model file constants -MODEL_WEIGHTS_FILE_IN_SAFETENSORS = "model.safetensors" -MODEL_CONFIG_FILE_IN_JSON = "config.json" -MODEL_WEIGHTS_FILE_IN_PT = "model.pt" -MODEL_CONFIG_FILE_IN_YAML = "config.yaml" - - class ModelCategory(Enum): BUILTIN = "builtin" USER_DEFINED = "user_defined" diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py index a6e3b1f7b5e3..29a7c14c9721 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_loader.py @@ -32,7 +32,7 @@ ) from iotdb.ainode.core.config import AINodeDescriptor -from iotdb.ainode.core.exception import ModelNotExistError +from iotdb.ainode.core.exception import ModelNotExistException from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.model_constants import ModelCategory from iotdb.ainode.core.model.model_info import ModelInfo @@ -131,7 +131,7 @@ def load_model_from_pt(model_info: ModelInfo, **kwargs): model_file = os.path.join(model_path, "model.pt") if not os.path.exists(model_file): logger.error(f"Model file not found at {model_file}.") - raise ModelNotExistError(model_file) + raise ModelNotExistException(model_file) model = torch.jit.load(model_file) if isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not acceleration: return model diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py index 5194ed4df1bd..a79371d2e791 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_storage.py @@ -23,12 +23,16 @@ from pathlib import Path from typing import Dict, List, Optional -from huggingface_hub import hf_hub_download, snapshot_download +from huggingface_hub import hf_hub_download from transformers import AutoConfig, AutoModelForCausalLM from iotdb.ainode.core.config import AINodeDescriptor from iotdb.ainode.core.constant import TSStatusCode -from iotdb.ainode.core.exception import BuiltInModelDeletionError +from iotdb.ainode.core.exception import ( + BuiltInModelDeletionException, + ModelExistedException, + ModelNotExistException, +) from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.model_constants import ( MODEL_CONFIG_FILE_IN_JSON, @@ -43,6 +47,8 @@ ModelInfo, ) from iotdb.ainode.core.model.utils import ( + _fetch_model_from_hf_repo, + _fetch_model_from_local, ensure_init_file, get_parsed_uri, import_class_from_path, @@ -233,12 +239,23 @@ def _process_user_defined_model_directory(self, model_dir: str, model_id: str): # ==================== Registration Methods ==================== - def register_model(self, model_id: str, uri: str) -> bool: + def register_model(self, model_id: str, uri: str): """ - Supported URI formats: - - repo:// (Maybe in the future) - - file:// + Register a user-defined model from a given URI. + Args: + model_id (str): Unique identifier for the model. + uri (str): URI to fetch the model from. + Supported URI formats: + - file:// + - repo:// (Maybe in the future) + Raises: + ModelExistedException: If the model_id already exists. + InvalidModelUriException: If the URI format is invalid. """ + + if self.is_model_registered(model_id): + raise ModelExistedException(model_id) + uri_type = parse_uri_type(uri) parsed_uri = get_parsed_uri(uri) @@ -249,9 +266,9 @@ def register_model(self, model_id: str, uri: str) -> bool: ensure_init_file(model_dir) if uri_type == UriType.REPO: - self._fetch_model_from_hf_repo(parsed_uri, model_dir) + _fetch_model_from_hf_repo(parsed_uri, model_dir) else: - self._fetch_model_from_local(os.path.expanduser(parsed_uri), model_dir) + _fetch_model_from_local(os.path.expanduser(parsed_uri), model_dir) config_path, _ = validate_model_files(model_dir) config = load_model_config_in_json(config_path) @@ -272,7 +289,7 @@ def register_model(self, model_id: str, uri: str) -> bool: self._models[ModelCategory.USER_DEFINED.value][model_id] = model_info if auto_map: - # Transformers model: immediately register to Transformers auto-loading mechanism + # Transformers model: immediately register to Transformers autoloading mechanism success = self._register_transformers_model(model_info) if success: with self._lock_pool.get_lock(model_id).write_lock(): @@ -281,46 +298,15 @@ def register_model(self, model_id: str, uri: str) -> bool: with self._lock_pool.get_lock(model_id).write_lock(): model_info.state = ModelStates.INACTIVE logger.error(f"Failed to register Transformers model {model_id}") - return False else: # Other type models: only log self._register_other_model(model_info) logger.info(f"Successfully registered model {model_id} from URI: {uri}") - return True - def _fetch_model_from_hf_repo(self, repo_id: str, storage_path: str): - logger.info( - f"Downloading model from HuggingFace repository: {repo_id} -> {storage_path}" - ) - # Use snapshot_download to download entire repository (including config.json and model.safetensors) - try: - snapshot_download( - repo_id=repo_id, - local_dir=storage_path, - local_dir_use_symlinks=False, - ) - except Exception as e: - logger.error(f"Failed to download model from HuggingFace: {e}") - raise - - def _fetch_model_from_local(self, source_path: str, storage_path: str): - logger.info(f"Copying model from local path: {source_path} -> {storage_path}") - source_dir = Path(source_path) - if not source_dir.is_dir(): - raise ValueError( - f"Source path does not exist or is not a directory: {source_path}" - ) - - storage_dir = Path(storage_path) - for file in source_dir.iterdir(): - if file.is_file(): - shutil.copy2(file, storage_dir / file.name) - return - - def _register_transformers_model(self, model_info: ModelInfo) -> bool: + def _register_transformers_model(self, model_info: ModelInfo): """ - Register Transformers model to auto-loading mechanism (internal method) + Register Transformers model to autoloading mechanism (internal method) """ auto_map = model_info.auto_map if not auto_map: @@ -350,7 +336,6 @@ def _register_transformers_model(self, model_info: ModelInfo) -> bool: logger.info( f"Registered AutoModelForCausalLM: {config_class.__name__} -> {auto_model_path}" ) - return True except Exception as e: logger.warning( @@ -471,8 +456,16 @@ def show_models(self, req: TShowModelsReq) -> TShowModelsResp: stateMap=state_map, ) - def delete_model(self, model_id: str) -> None: - # Use write lock to protect entire deletion process + def delete_model(self, model_id: str): + """ + Delete a user-defined model by model_id. + Args: + model_id (str): Unique identifier for the model to be deleted. + Raises: + ModelNotExistException: If the model_id does not exist. + BuiltInModelDeletionException: If attempting to delete a built-in model. + Others: Any exceptions raised during file deletion. + """ with self._lock_pool.get_lock(model_id).write_lock(): model_info = None category_value = None @@ -481,30 +474,25 @@ def delete_model(self, model_id: str) -> None: model_info = category_dict[model_id] category_value = cat_value break - if not model_info: logger.warning(f"Model {model_id} does not exist, cannot delete") - return - + raise ModelNotExistException(model_id) if model_info.category == ModelCategory.BUILTIN: - raise BuiltInModelDeletionError(model_id) + logger.warning(f"Model {model_id} is builtin, cannot delete") + raise BuiltInModelDeletionException(model_id) model_info.state = ModelStates.DROPPING model_path = os.path.join( self._models_dir, model_info.category.value, model_id ) - if model_path.exists(): + if os.path.exists(model_path): try: shutil.rmtree(model_path) - logger.info(f"Deleted model directory: {model_path}") + logger.info(f"Model directory is deleted: {model_path}") except Exception as e: logger.error(f"Failed to delete model directory {model_path}: {e}") - raise - - if category_value and model_id in self._models[category_value]: - del self._models[category_value][model_id] - logger.info(f"Model {model_id} has been removed from storage") - - return + raise e + del self._models[category_value][model_id] + logger.info(f"Model {model_id} has been removed from model storage") # ==================== Query Methods ==================== @@ -512,10 +500,14 @@ def get_model_info( self, model_id: str, category: Optional[ModelCategory] = None ) -> Optional[ModelInfo]: """ - Get single model information - - If category is specified, use model_id's lock - If category is not specified, need to traverse all dictionaries, use global lock + Get specified model information. + Args: + model_id (str): Unique identifier for the model. + category (Optional[ModelCategory]): Category of the model (if known). + Returns: + ModelInfo: Information of the specified model. + Raises: + ModelNotExistException: If the model_id does not exist. """ if category: # Category specified, only need to access specific dictionary, use model_id's lock @@ -527,39 +519,7 @@ def get_model_info( for category_dict in self._models.values(): if model_id in category_dict: return category_dict[model_id] - return None - - def get_model_infos( - self, category: Optional[ModelCategory] = None, model_type: Optional[str] = None - ) -> List[ModelInfo]: - """ - Get model information list - - Note: Since we need to traverse all models, use a global lock to protect the entire dictionary structure - For single model access, using model_id-based lock would be more efficient - """ - matching_models = [] - - # For traversal operations, we need to protect the entire dictionary structure - # Use a special lock (using empty string as key) to protect the entire dictionary - with self._lock_pool.get_lock("").read_lock(): - if category and model_type: - for model_info in self._models[category.value].values(): - if model_info.model_type == model_type: - matching_models.append(model_info) - return matching_models - elif category: - return list(self._models[category.value].values()) - elif model_type: - for category_dict in self._models.values(): - for model_info in category_dict.values(): - if model_info.model_type == model_type: - matching_models.append(model_info) - return matching_models - else: - for category_dict in self._models.values(): - matching_models.extend(category_dict.values()) - return matching_models + raise ModelNotExistException(model_id) def is_model_registered(self, model_id: str) -> bool: """Check if model is registered (search in _models)""" @@ -572,11 +532,3 @@ def is_model_registered(self, model_id: str) -> bool: if model_id in category_dict: return True return False - - def get_registered_models(self) -> List[str]: - """Get list of all registered model IDs""" - with self._lock_pool.get_lock("").read_lock(): - model_ids = [] - for category_dict in self._models.values(): - model_ids.extend(category_dict.keys()) - return model_ids diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py index 261de3c9abe7..d9d20545af6f 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/configuration_sktime.py @@ -20,11 +20,11 @@ from typing import Any, Dict, List, Union from iotdb.ainode.core.exception import ( - BuiltInModelNotSupportError, + BuiltInModelNotSupportException, ListRangeException, NumericalRangeException, StringRangeException, - WrongAttributeTypeError, + WrongAttributeTypeException, ) from iotdb.ainode.core.log import Logger @@ -49,7 +49,7 @@ def validate_value(self, value): if value is None: return True # Allow None for optional int parameters if not isinstance(value, int): - raise WrongAttributeTypeError(self.name, "int") + raise WrongAttributeTypeException(self.name, "int") if self.low is not None and self.high is not None: if not (self.low <= value <= self.high): raise NumericalRangeException(self.name, value, self.low, self.high) @@ -57,7 +57,7 @@ def validate_value(self, value): if value is None: return True # Allow None for optional float parameters if not isinstance(value, (int, float)): - raise WrongAttributeTypeError(self.name, "float") + raise WrongAttributeTypeException(self.name, "float") value = float(value) if self.low is not None and self.high is not None: if not (self.low <= value <= self.high): @@ -66,26 +66,26 @@ def validate_value(self, value): if value is None: return True # Allow None for optional str parameters if not isinstance(value, str): - raise WrongAttributeTypeError(self.name, "str") + raise WrongAttributeTypeException(self.name, "str") if self.choices and value not in self.choices: raise StringRangeException(self.name, value, self.choices) elif self.type == "bool": if value is None: return True # Allow None for optional bool parameters if not isinstance(value, bool): - raise WrongAttributeTypeError(self.name, "bool") + raise WrongAttributeTypeException(self.name, "bool") elif self.type == "list": if not isinstance(value, list): - raise WrongAttributeTypeError(self.name, "list") + raise WrongAttributeTypeException(self.name, "list") for item in value: if not isinstance(item, self.value_type): - raise WrongAttributeTypeError(self.name, self.value_type) + raise WrongAttributeTypeException(self.name, self.value_type) elif self.type == "tuple": if not isinstance(value, tuple): - raise WrongAttributeTypeError(self.name, "tuple") + raise WrongAttributeTypeException(self.name, "tuple") for item in value: if not isinstance(item, self.value_type): - raise WrongAttributeTypeError(self.name, self.value_type) + raise WrongAttributeTypeException(self.name, self.value_type) return True def parse(self, string_value: str): @@ -96,14 +96,14 @@ def parse(self, string_value: str): try: return int(string_value) except: - raise WrongAttributeTypeError(self.name, "int") + raise WrongAttributeTypeException(self.name, "int") elif self.type == "float": if string_value.lower() == "none" or string_value.strip() == "": return None try: return float(string_value) except: - raise WrongAttributeTypeError(self.name, "float") + raise WrongAttributeTypeException(self.name, "float") elif self.type == "str": if string_value.lower() == "none" or string_value.strip() == "": return None @@ -116,14 +116,14 @@ def parse(self, string_value: str): elif string_value.lower() == "none" or string_value.strip() == "": return None else: - raise WrongAttributeTypeError(self.name, "bool") + raise WrongAttributeTypeException(self.name, "bool") elif self.type == "list": try: list_value = eval(string_value) except: - raise WrongAttributeTypeError(self.name, "list") + raise WrongAttributeTypeException(self.name, "list") if not isinstance(list_value, list): - raise WrongAttributeTypeError(self.name, "list") + raise WrongAttributeTypeException(self.name, "list") for i in range(len(list_value)): try: list_value[i] = self.value_type(list_value[i]) @@ -136,9 +136,9 @@ def parse(self, string_value: str): try: tuple_value = eval(string_value) except: - raise WrongAttributeTypeError(self.name, "tuple") + raise WrongAttributeTypeException(self.name, "tuple") if not isinstance(tuple_value, tuple): - raise WrongAttributeTypeError(self.name, "tuple") + raise WrongAttributeTypeException(self.name, "tuple") list_value = list(tuple_value) for i in range(len(list_value)): try: @@ -390,7 +390,7 @@ def get_attributes(model_id: str) -> Dict[str, AttributeConfig]: """Get attribute configuration for Sktime model""" model_id = "EXPONENTIAL_SMOOTHING" if model_id == "HOLTWINTERS" else model_id if model_id not in MODEL_CONFIGS: - raise BuiltInModelNotSupportError(model_id) + raise BuiltInModelNotSupportException(model_id) return MODEL_CONFIGS[model_id] diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py index eca812d35ec9..9ddbcab286fa 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sktime/modeling_sktime.py @@ -30,8 +30,8 @@ from sktime.forecasting.trend import STLForecaster from iotdb.ainode.core.exception import ( - BuiltInModelNotSupportError, - InferenceModelInternalError, + BuiltInModelNotSupportException, + InferenceModelInternalException, ) from iotdb.ainode.core.log import Logger @@ -66,7 +66,7 @@ def generate(self, data, **kwargs): output = self._model.predict(fh=range(predict_length)) return np.array(output, dtype=np.float64) except Exception as e: - raise InferenceModelInternalError(str(e)) + raise InferenceModelInternalException(str(e)) class DetectionModel(SktimeModel): @@ -82,7 +82,7 @@ def generate(self, data, **kwargs): else: return np.array(output, dtype=np.int32) except Exception as e: - raise InferenceModelInternalError(str(e)) + raise InferenceModelInternalException(str(e)) class ArimaModel(ForecastingModel): @@ -155,7 +155,7 @@ def generate(self, data, **kwargs): scaled_data = pd.Series(scaled_data.flatten()) return super().generate(scaled_data, **kwargs) except Exception as e: - raise InferenceModelInternalError(str(e)) + raise InferenceModelInternalException(str(e)) # Model factory mapping @@ -176,5 +176,5 @@ def create_sktime_model(model_id: str, **kwargs) -> SktimeModel: attributes = update_attribute({**kwargs}, get_attributes(model_id.upper())) model_class = _MODEL_FACTORY.get(model_id.upper()) if model_class is None: - raise BuiltInModelNotSupportError(model_id) + raise BuiltInModelNotSupportException(model_id) return model_class(attributes) diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py index 85b6f7db2ffe..ee128802d240 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/sundial/pipeline_sundial.py @@ -18,7 +18,7 @@ import torch -from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.exception import InferenceModelInternalException from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline @@ -28,7 +28,7 @@ def __init__(self, model_info, **model_kwargs): def _preprocess(self, inputs): if len(inputs.shape) != 2: - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" ) return inputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py index c0f00b1f5caf..65c6cdd74cd3 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/timer_xl/pipeline_timer.py @@ -18,7 +18,7 @@ import torch -from iotdb.ainode.core.exception import InferenceModelInternalError +from iotdb.ainode.core.exception import InferenceModelInternalException from iotdb.ainode.core.inference.pipeline.basic_pipeline import ForecastPipeline @@ -28,7 +28,7 @@ def __init__(self, model_info, **model_kwargs): def _preprocess(self, inputs): if len(inputs.shape) != 2: - raise InferenceModelInternalError( + raise InferenceModelInternalException( f"[Inference] Input shape must be: [batch_size, seq_len], but receives {inputs.shape}" ) return inputs diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py index 1cd0ee44912d..815232c52b0d 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/model/utils.py +++ b/iotdb-core/ainode/iotdb/ainode/core/model/utils.py @@ -19,25 +19,31 @@ import importlib import json import os.path +import shutil import sys from contextlib import contextmanager +from pathlib import Path from typing import Dict, Tuple +from huggingface_hub import snapshot_download + +from iotdb.ainode.core.exception import InvalidModelUriException +from iotdb.ainode.core.log import Logger from iotdb.ainode.core.model.model_constants import ( MODEL_CONFIG_FILE_IN_JSON, MODEL_WEIGHTS_FILE_IN_SAFETENSORS, UriType, ) +logger = Logger() + def parse_uri_type(uri: str) -> UriType: - if uri.startswith("repo://"): - return UriType.REPO - elif uri.startswith("file://"): + if uri.startswith("file://"): return UriType.FILE else: - raise ValueError( - f"Unsupported URI type: {uri}. Supported formats: repo:// or file://" + raise InvalidModelUriException( + f"Unknown uri type {uri}, currently supporting formats: file://" ) @@ -70,9 +76,13 @@ def validate_model_files(model_dir: str) -> Tuple[str, str]: weights_path = os.path.join(model_dir, MODEL_WEIGHTS_FILE_IN_SAFETENSORS) if not os.path.exists(config_path): - raise ValueError(f"Model config file does not exist: {config_path}") + raise InvalidModelUriException( + f"Model config file does not exist: {config_path}" + ) if not os.path.exists(weights_path): - raise ValueError(f"Model weights file does not exist: {weights_path}") + raise InvalidModelUriException( + f"Model weights file does not exist: {weights_path}" + ) # Create __init__.py file to ensure model directory can be imported as a module init_file = os.path.join(model_dir, "__init__.py") @@ -96,3 +106,32 @@ def ensure_init_file(dir_path: str): if not os.path.exists(init_file): with open(init_file, "w"): pass + + +def _fetch_model_from_local(source_path: str, storage_path: str): + logger.info(f"Copying model from local path: {source_path} -> {storage_path}") + source_dir = Path(source_path) + if not source_dir.exists(): + raise InvalidModelUriException(f"Source path does not exist: {source_path}") + if not source_dir.is_dir(): + raise InvalidModelUriException(f"Source path is not a directory: {source_path}") + storage_dir = Path(storage_path) + for file in source_dir.iterdir(): + if file.is_file(): + shutil.copy2(file, storage_dir / file.name) + + +def _fetch_model_from_hf_repo(repo_id: str, storage_path: str): + logger.info( + f"Downloading model from HuggingFace repository: {repo_id} -> {storage_path}" + ) + # Use snapshot_download to download entire repository (including config.json and model.safetensors) + try: + snapshot_download( + repo_id=repo_id, + local_dir=storage_path, + local_dir_use_symlinks=False, + ) + except Exception as e: + logger.error(f"Failed to download model from HuggingFace: {e}") + raise diff --git a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py index 6c4eedeb99f7..492802fc0600 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py +++ b/iotdb-core/ainode/iotdb/ainode/core/rpc/handler.py @@ -40,7 +40,7 @@ TShowLoadedModelsResp, TShowModelsReq, TShowModelsResp, - TTrainingReq, + TTuningReq, TUnloadModelReq, ) from iotdb.thrift.common.ttypes import TSStatus @@ -56,8 +56,8 @@ def _ensure_device_id_is_available(device_id_list: list[str]) -> TSStatus: for device_id in device_id_list: if device_id not in available_devices: return TSStatus( - code=TSStatusCode.INVALID_URI_ERROR.value, - message=f"Device ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.", + code=TSStatusCode.UNAVAILABLE_AI_DEVICE_ERROR.value, + message=f"AIDevice ID [{device_id}] is not available. You can use 'SHOW AI_DEVICES' to retrieve the available devices.", ) return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) @@ -68,7 +68,9 @@ def __init__(self, ainode): self._model_manager = ModelManager() self._inference_manager = InferenceManager() - def stop(self) -> None: + # ==================== Cluster Management ==================== + + def stop(self): logger.info("Stopping the RPC service handler of IoTDB-AINode...") self._inference_manager.stop() @@ -76,6 +78,17 @@ def stopAINode(self) -> TSStatus: self._ainode.stop() return get_status(TSStatusCode.SUCCESS_STATUS, "AINode stopped successfully.") + def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: + return ClusterManager.get_heart_beat(req) + + def showAIDevices(self) -> TShowAIDevicesResp: + return TShowAIDevicesResp( + status=TSStatus(code=TSStatusCode.SUCCESS_STATUS.value), + deviceIdList=get_available_devices(), + ) + + # ==================== Model Management ==================== + def registerModel(self, req: TRegisterModelReq) -> TRegisterModelResp: return self._model_manager.register_model(req) @@ -109,11 +122,15 @@ def showLoadedModels(self, req: TShowLoadedModelsReq) -> TShowLoadedModelsResp: return TShowLoadedModelsResp(status=status, deviceLoadedModelsMap={}) return self._inference_manager.show_loaded_models(req) - def showAIDevices(self) -> TShowAIDevicesResp: - return TShowAIDevicesResp( - status=TSStatus(code=TSStatusCode.SUCCESS_STATUS.value), - deviceIdList=get_available_devices(), - ) + def _ensure_model_is_registered(self, model_id: str) -> TSStatus: + if not self._model_manager.is_model_registered(model_id): + return TSStatus( + code=TSStatusCode.MODEL_NOT_EXIST_ERROR.value, + message=f"Model [{model_id}] is not registered yet. You can use 'SHOW MODELS' to retrieve the available models.", + ) + return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) + + # ==================== Inference ==================== def inference(self, req: TInferenceReq) -> TInferenceResp: status = self._ensure_model_is_registered(req.modelId) @@ -127,16 +144,7 @@ def forecast(self, req: TForecastReq) -> TForecastResp: return TForecastResp(status, []) return self._inference_manager.forecast(req) - def getAIHeartbeat(self, req: TAIHeartbeatReq) -> TAIHeartbeatResp: - return ClusterManager.get_heart_beat(req) + # ==================== Tuning ==================== - def createTrainingTask(self, req: TTrainingReq) -> TSStatus: + def createTuningTask(self, req: TTuningReq) -> TSStatus: pass - - def _ensure_model_is_registered(self, model_id: str) -> TSStatus: - if not self._model_manager.is_model_registered(model_id): - return TSStatus( - code=TSStatusCode.MODEL_NOT_FOUND_ERROR.value, - message=f"Model [{model_id}] is not registered yet. You can use 'SHOW MODELS' to retrieve the available models.", - ) - return TSStatus(code=TSStatusCode.SUCCESS_STATUS.value) diff --git a/iotdb-core/ainode/iotdb/ainode/core/util/serde.py b/iotdb-core/ainode/iotdb/ainode/core/util/serde.py index f8188209a377..9c6020019fc2 100644 --- a/iotdb-core/ainode/iotdb/ainode/core/util/serde.py +++ b/iotdb-core/ainode/iotdb/ainode/core/util/serde.py @@ -21,7 +21,7 @@ import numpy as np import pandas as pd -from iotdb.ainode.core.exception import BadConfigValueError +from iotdb.ainode.core.exception import BadConfigValueException class TSDataType(Enum): @@ -122,7 +122,7 @@ def _get_type_in_byte(data_type: pd.Series): elif data_type == "text": return b"\x05" else: - raise BadConfigValueError( + raise BadConfigValueException( "data_type", data_type, "data_type should be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text']", @@ -138,7 +138,7 @@ def get_data_type_byte_from_str(value): byte: corresponding data type in [b'\x00', b'\x01', b'\x02', b'\x03', b'\x04', b'\x05'] """ if value not in ["bool", "int32", "int64", "float32", "float64", "text"]: - raise BadConfigValueError( + raise BadConfigValueException( "data_type", value, "data_type should be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text']", diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java index 03402d30c643..b8d98c9c1d72 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/GetModelInfoException.java @@ -23,6 +23,6 @@ public class GetModelInfoException extends ModelException { public GetModelInfoException(String message) { - super(message, TSStatusCode.GET_MODEL_INFO_ERROR); + super(message, TSStatusCode.AINODE_INTERNAL_ERROR); } } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelNotFoundException.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelNotFoundException.java deleted file mode 100644 index 38a5105cded1..000000000000 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/exception/ainode/ModelNotFoundException.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iotdb.db.exception.ainode; - -import org.apache.iotdb.rpc.TSStatusCode; - -public class ModelNotFoundException extends ModelException { - public ModelNotFoundException(String message) { - super(message, TSStatusCode.MODEL_NOT_FOUND_ERROR); - } -} diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java index 5eaffc40af9c..ffad889a2233 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/protocol/client/an/AINodeClient.java @@ -35,7 +35,7 @@ import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; +import org.apache.iotdb.ainode.rpc.thrift.TTuningReq; import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; import org.apache.iotdb.common.rpc.thrift.TAINodeLocation; import org.apache.iotdb.common.rpc.thrift.TEndPoint; @@ -129,8 +129,8 @@ public TAIHeartbeatResp getAIHeartbeat(TAIHeartbeatReq req) { } @Override - public TSStatus createTrainingTask(TTrainingReq req) throws TException { - return executeRemoteCallWithRetry(() -> client.createTrainingTask(req)); + public TSStatus createTuningTask(TTuningReq req) throws TException { + return executeRemoteCallWithRetry(() -> client.createTuningTask(req)); } @Override diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java index 847b5d628808..bc52a3a7e8cb 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TableConfigTaskVisitor.java @@ -62,7 +62,7 @@ import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowRegionTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowVariablesTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateModelTask; -import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTrainingTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTuningTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.DropModelTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.LoadModelTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.ShowAIDevicesTask; @@ -1520,7 +1520,7 @@ protected IConfigTask visitRemoveRegion(RemoveRegion removeRegion, MPPQueryConte protected IConfigTask visitCreateTraining(CreateTraining node, MPPQueryContext context) { context.setQueryType(QueryType.WRITE); accessControl.checkUserGlobalSysPrivilege(context); - return new CreateTrainingTask( + return new CreateTuningTask( node.getModelId(), node.getParameters(), node.getExistingModelId(), node.getTargetSql()); } diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java index cb4d05f1b99a..ed5e2e434f48 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/TreeConfigTaskVisitor.java @@ -68,7 +68,7 @@ import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ShowVariablesTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.UnSetTTLTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateModelTask; -import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTrainingTask; +import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.CreateTuningTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.DropModelTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.LoadModelTask; import org.apache.iotdb.db.queryengine.plan.execution.config.metadata.ai.ShowAIDevicesTask; @@ -926,7 +926,7 @@ public IConfigTask visitCreateTraining( for (PartialPath partialPath : partialPathList) { targetPathPatterns.add(partialPath.getFullPath()); } - return new CreateTrainingTask( + return new CreateTuningTask( createTrainingStatement.getModelId(), createTrainingStatement.getParameters(), createTrainingStatement.getTargetTimeRanges(), diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java index 01f6757f02eb..4d99a3ed8928 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/ClusterConfigTaskExecutor.java @@ -28,7 +28,7 @@ import org.apache.iotdb.ainode.rpc.thrift.TShowLoadedModelsResp; import org.apache.iotdb.ainode.rpc.thrift.TShowModelsReq; import org.apache.iotdb.ainode.rpc.thrift.TShowModelsResp; -import org.apache.iotdb.ainode.rpc.thrift.TTrainingReq; +import org.apache.iotdb.ainode.rpc.thrift.TTuningReq; import org.apache.iotdb.ainode.rpc.thrift.TUnloadModelReq; import org.apache.iotdb.common.rpc.thrift.FunctionType; import org.apache.iotdb.common.rpc.thrift.Model; @@ -3727,7 +3727,7 @@ public SettableFuture unloadModel( } @Override - public SettableFuture createTraining( + public SettableFuture createTuningTask( String modelId, boolean isTableModel, Map parameters, @@ -3736,9 +3736,9 @@ public SettableFuture createTraining( @Nullable String targetSql, @Nullable List pathList) { final SettableFuture future = SettableFuture.create(); - try (final AINodeClient ai = + try (final AINodeClient aiNodeClient = AINodeClientManager.getInstance().borrowClient(AINodeClientManager.AINODE_ID_PLACEHOLDER)) { - final TTrainingReq req = new TTrainingReq(); + final TTuningReq req = new TTuningReq(); req.setModelId(modelId); req.setParameters(parameters); if (existingModelId != null) { @@ -3747,7 +3747,7 @@ public SettableFuture createTraining( if (existingModelId != null) { req.setExistingModelId(existingModelId); } - final TSStatus status = ai.createTrainingTask(req); + final TSStatus status = aiNodeClient.createTuningTask(req); if (TSStatusCode.SUCCESS_STATUS.getStatusCode() != status.getCode()) { future.setException(new IoTDBException(status)); } else { diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java index 85455756b052..2d7c0f6d1f3a 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/executor/IConfigTaskExecutor.java @@ -443,7 +443,7 @@ SettableFuture createTableView( SettableFuture unloadModel(String existingModelId, List deviceIdList); - SettableFuture createTraining( + SettableFuture createTuningTask( String modelId, boolean isTableModel, Map parameters, diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTuningTask.java similarity index 93% rename from iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java rename to iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTuningTask.java index 9c93c5b75779..a66b7e700b75 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTrainingTask.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/execution/config/metadata/ai/CreateTuningTask.java @@ -28,7 +28,7 @@ import java.util.List; import java.util.Map; -public class CreateTrainingTask implements IConfigTask { +public class CreateTuningTask implements IConfigTask { private final String modelId; private final boolean isTableModel; @@ -43,7 +43,7 @@ public class CreateTrainingTask implements IConfigTask { private List> timeRanges; // For table model - public CreateTrainingTask( + public CreateTuningTask( String modelId, Map parameters, String existingModelId, String targetSql) { this.modelId = modelId; this.parameters = parameters; @@ -53,7 +53,7 @@ public CreateTrainingTask( } // For tree model - public CreateTrainingTask( + public CreateTuningTask( String modelId, Map parameters, List> timeRanges, @@ -71,7 +71,7 @@ public CreateTrainingTask( @Override public ListenableFuture execute(IConfigTaskExecutor configTaskExecutor) throws InterruptedException { - return configTaskExecutor.createTraining( + return configTaskExecutor.createTuningTask( modelId, isTableModel, parameters, timeRanges, existingModelId, targetSql, targetPaths); } } diff --git a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift index 1dc2f025f5c3..cda356a948e1 100644 --- a/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift +++ b/iotdb-protocol/thrift-ainode/src/main/thrift/ainode.thrift @@ -73,7 +73,7 @@ struct IDataSchema { 2: optional list timeRange } -struct TTrainingReq { +struct TTuningReq { 1: required string dbType 2: required string modelId 3: required string existingModelId @@ -131,30 +131,27 @@ struct TUnloadModelReq { service IAINodeRPCService { - // -------------- For Config Node -------------- common.TSStatus stopAINode() + TAIHeartbeatResp getAIHeartbeat(TAIHeartbeatReq req) + + TShowAIDevicesResp showAIDevices() + TShowModelsResp showModels(TShowModelsReq req) TShowLoadedModelsResp showLoadedModels(TShowLoadedModelsReq req) - TShowAIDevicesResp showAIDevices() - common.TSStatus deleteModel(TDeleteModelReq req) TRegisterModelResp registerModel(TRegisterModelReq req) - TAIHeartbeatResp getAIHeartbeat(TAIHeartbeatReq req) - - common.TSStatus createTrainingTask(TTrainingReq req) - common.TSStatus loadModel(TLoadModelReq req) common.TSStatus unloadModel(TUnloadModelReq req) - // -------------- For Data Node -------------- - TInferenceResp inference(TInferenceReq req) TForecastResp forecast(TForecastReq req) + + common.TSStatus createTuningTask(TTuningReq req) } \ No newline at end of file