This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch hf-timerxl
in repository https://gitbox.apache.org/repos/asf/iotdb.git
commit a20dc8935683c198534212cc8aa0ceec743c6dba Author: Yongzao <[email protected]> AuthorDate: Tue May 27 10:46:06 2025 +0800 preview --- iotdb-core/ainode/ainode/core/constant.py | 2 +- .../ainode/core/manager/inference_manager.py | 3 +- .../ainode/core/model/built_in_model_factory.py | 13 ++++++-- .../core/model/sundial/configuration_sundial.py | 2 ++ .../ainode/core/model/sundial/modeling_sundial.py | 39 +++++++++++++++++++--- 5 files changed, 49 insertions(+), 10 deletions(-) diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py index b893a9cffa6..79a991664fd 100644 --- a/iotdb-core/ainode/ainode/core/constant.py +++ b/iotdb-core/ainode/ainode/core/constant.py @@ -261,7 +261,7 @@ class AttributeName(Enum): ATTENTION_DROPOUT = "attention_dropout" INITIALIZER_RANGE = "initializer_range" MAX_POSITION_EMBEDDINGS = "max_position_embeddings" - TIMERXL_CKPT_PATH = "ckpt_path" + CKPT_PATH = "ckpt_path" # sundial DROPOUT_RATE = "dropout_rate" diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py index 2a9c93c2b1e..9df60fe1a40 100644 --- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py @@ -58,7 +58,8 @@ class SundialStrategy(InferenceStrategy): if data.dtype.byteorder not in ('=', '|'): data = data.byteswap().newbyteorder() seqs = torch.tensor(data).unsqueeze(0).float() - output = self.model.generate(seqs, max_new_tokens=96, num_samples=10, revin=True) + # TODO: unify model inference input + output = self.model.generate(seqs, max_new_tokens=predict_length, num_samples=10, revin=True) df = pd.DataFrame(output[0].mean(dim=0)) return convert_to_binary(df) diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py index 402f53522a7..f5f847f0eec 100644 --- 
a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py +++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py @@ -37,6 +37,7 @@ from ainode.core.exception import InferenceModelInternalError from ainode.core.exception import WrongAttributeTypeError, NumericalRangeException, StringRangeException, \ ListRangeException, BuiltInModelNotSupportError from ainode.core.log import Logger +from ainode.core.model.sundial.configuration_sundial import SundialConfig logger = Logger() @@ -103,7 +104,7 @@ def fetch_built_in_model(model_id, inference_attributes): elif model_id == BuiltInModelType.TIMER_XL.value: model = timer_xl.Model(TimerxlConfig.from_dict(attributes)) elif model_id == BuiltInModelType.SUNDIAL.value: - model = modeling_sundial.SundialForPrediction.from_pretrained("/Users/yongzaodan/Downloads/models--thuml--sundial-base-128m/snapshots/03a967dd675f0dea6229fdee1bc89e4eb310df19") + model = modeling_sundial.SundialForPrediction(SundialConfig.from_dict(attributes)) else: raise BuiltInModelNotSupportError(model_id) @@ -413,6 +414,12 @@ sundial_attribute_map = { default_low=1, default_high=5000 ), + AttributeName.CKPT_PATH.value: StringAttribute( + name=AttributeName.CKPT_PATH.value, + default_value=os.path.join(os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir(), 'weights', + 'sundial'), + value_choices=[''] + ) } timerxl_attribute_map = { @@ -484,8 +491,8 @@ timerxl_attribute_map = { default_low=1, default_high=50000 ), - AttributeName.TIMERXL_CKPT_PATH.value: StringAttribute( - name=AttributeName.TIMERXL_CKPT_PATH.value, + AttributeName.CKPT_PATH.value: StringAttribute( + name=AttributeName.CKPT_PATH.value, default_value=os.path.join(os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir(), 'weights', 'timerxl', 'model.safetensors'), value_choices=[''] diff --git a/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py index 
2bdb6a0a4b6..41c54ff4a72 100644 --- a/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py +++ b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py @@ -41,6 +41,7 @@ class SundialConfig(PretrainedConfig): flow_loss_depth: int = 3, num_sampling_steps: int = 50, diffusion_batch_mul: int = 4, + ckpt_path: str = None, # weight path **kwargs, ): self.input_token_len = input_token_len @@ -58,6 +59,7 @@ class SundialConfig(PretrainedConfig): self.flow_loss_depth = flow_loss_depth self.num_sampling_steps = num_sampling_steps self.diffusion_batch_mul = diffusion_batch_mul + self.ckpt_path = ckpt_path super().__init__( **kwargs, diff --git a/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py index 346d0f73520..724ae3f1bc1 100644 --- a/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py +++ b/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py @@ -16,6 +16,7 @@ # under the License. 
# +import os from typing import Optional, Tuple, List, Union import torch from torch import nn @@ -28,6 +29,11 @@ from .configuration_sundial import SundialConfig from .ts_generation_mixin import TSGenerationMixin from .flow_loss import FlowLoss +from safetensors.torch import load_file as load_safetensors +from huggingface_hub import hf_hub_download + +from ainode.core.log import Logger +logger = Logger() def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] @@ -416,11 +422,34 @@ class SundialModel(SundialPreTrainedModel): class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin): def __init__(self, config: SundialConfig): super().__init__(config) - self.config = config - self.model = SundialModel(self.config) - self.flow_loss = FlowLoss(self.config.output_token_lens[-1], self.config.hidden_size, - self.config.flow_loss_depth, self.config.hidden_size, self.config.num_sampling_steps) - self.post_init() + # self.config = config + # self.model = SundialModel(self.config) + # self.flow_loss = FlowLoss(self.config.output_token_lens[-1], self.config.hidden_size, + # self.config.flow_loss_depth, self.config.hidden_size, self.config.num_sampling_steps) + # TODO: Unify data loader + if not os.path.exists(config.ckpt_path): + os.mkdir(config.ckpt_path) + if not os.path.exists(os.path.join(config.ckpt_path, "config.json")): + logger.info(f"Config not found at {config.ckpt_path}, downloading from HuggingFace...") + repo_id = "thuml/sundial-base-128m" + try: + hf_hub_download(repo_id=repo_id, filename="config.json", local_dir=config.ckpt_path) + logger.info(f"Got config to {config.ckpt_path}") + except Exception as e: + logger.error(f"Failed to download config to {config.ckpt_path} due to {e}") + raise e + if not os.path.exists(os.path.join(config.ckpt_path, "model.safetensors")): + logger.info(f"Weight not found at {config.ckpt_path}, downloading from HuggingFace...") + repo_id = "thuml/sundial-base-128m" + try: + hf_hub_download(repo_id=repo_id, 
filename="model.safetensors", local_dir=config.ckpt_path) + logger.info(f"Got weight to {config.ckpt_path}") + except Exception as e: + logger.error(f"Failed to download weight to {config.ckpt_path} due to {e}") + raise e + print(config.ckpt_path, type(config.ckpt_path)) + self.from_pretrained(config.ckpt_path) + # self.post_init() def set_decoder(self, decoder): self.model = decoder
