This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch hf-timerxl
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit a20dc8935683c198534212cc8aa0ceec743c6dba
Author: Yongzao <[email protected]>
AuthorDate: Tue May 27 10:46:06 2025 +0800

    preview
---
 iotdb-core/ainode/ainode/core/constant.py          |  2 +-
 .../ainode/core/manager/inference_manager.py       |  3 +-
 .../ainode/core/model/built_in_model_factory.py    | 13 ++++++--
 .../core/model/sundial/configuration_sundial.py    |  2 ++
 .../ainode/core/model/sundial/modeling_sundial.py  | 39 +++++++++++++++++++---
 5 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py
index b893a9cffa6..79a991664fd 100644
--- a/iotdb-core/ainode/ainode/core/constant.py
+++ b/iotdb-core/ainode/ainode/core/constant.py
@@ -261,7 +261,7 @@ class AttributeName(Enum):
     ATTENTION_DROPOUT = "attention_dropout"
     INITIALIZER_RANGE = "initializer_range"
     MAX_POSITION_EMBEDDINGS = "max_position_embeddings"
-    TIMERXL_CKPT_PATH = "ckpt_path"
+    CKPT_PATH = "ckpt_path"
 
     # sundial
     DROPOUT_RATE = "dropout_rate"
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 2a9c93c2b1e..9df60fe1a40 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -58,7 +58,8 @@ class SundialStrategy(InferenceStrategy):
         if data.dtype.byteorder not in ('=', '|'):
             data = data.byteswap().newbyteorder()
         seqs = torch.tensor(data).unsqueeze(0).float()
-        output = self.model.generate(seqs, max_new_tokens=96, num_samples=10, revin=True)
+        # TODO: unify model inference input
+        output = self.model.generate(seqs, max_new_tokens=predict_length, num_samples=10, revin=True)
         df = pd.DataFrame(output[0].mean(dim=0))
         return convert_to_binary(df)
 
diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
index 402f53522a7..f5f847f0eec 100644
--- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -37,6 +37,7 @@ from ainode.core.exception import InferenceModelInternalError
 from ainode.core.exception import WrongAttributeTypeError, NumericalRangeException, StringRangeException, \
     ListRangeException, BuiltInModelNotSupportError
 from ainode.core.log import Logger
+from ainode.core.model.sundial.configuration_sundial import SundialConfig
 
 logger = Logger()
 
@@ -103,7 +104,7 @@ def fetch_built_in_model(model_id, inference_attributes):
     elif model_id == BuiltInModelType.TIMER_XL.value:
         model = timer_xl.Model(TimerxlConfig.from_dict(attributes))
     elif model_id == BuiltInModelType.SUNDIAL.value:
-        model = modeling_sundial.SundialForPrediction.from_pretrained("/Users/yongzaodan/Downloads/models--thuml--sundial-base-128m/snapshots/03a967dd675f0dea6229fdee1bc89e4eb310df19")
+        model = modeling_sundial.SundialForPrediction(SundialConfig.from_dict(attributes))
     else:
         raise BuiltInModelNotSupportError(model_id)
 
@@ -413,6 +414,12 @@ sundial_attribute_map = {
         default_low=1,
         default_high=5000
     ),
+    AttributeName.CKPT_PATH.value: StringAttribute(
+        name=AttributeName.CKPT_PATH.value,
+        default_value=os.path.join(os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir(), 'weights',
+                                   'sundial'),
+        value_choices=['']
+    )
 }
 
 timerxl_attribute_map = {
@@ -484,8 +491,8 @@ timerxl_attribute_map = {
         default_low=1,
         default_high=50000
     ),
-    AttributeName.TIMERXL_CKPT_PATH.value: StringAttribute(
-        name=AttributeName.TIMERXL_CKPT_PATH.value,
+    AttributeName.CKPT_PATH.value: StringAttribute(
+        name=AttributeName.CKPT_PATH.value,
         default_value=os.path.join(os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir(), 'weights',
                                    'timerxl', 'model.safetensors'),
         value_choices=['']
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py
index 2bdb6a0a4b6..41c54ff4a72 100644
--- a/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py
+++ b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py
@@ -41,6 +41,7 @@ class SundialConfig(PretrainedConfig):
         flow_loss_depth: int = 3,
         num_sampling_steps: int = 50,
         diffusion_batch_mul: int = 4,
+        ckpt_path: str = None,              # weight path
         **kwargs,
     ):
         self.input_token_len = input_token_len
@@ -58,6 +59,7 @@ class SundialConfig(PretrainedConfig):
         self.flow_loss_depth = flow_loss_depth
         self.num_sampling_steps = num_sampling_steps
         self.diffusion_batch_mul = diffusion_batch_mul
+        self.ckpt_path = ckpt_path
 
         super().__init__(
             **kwargs,
diff --git a/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py
index 346d0f73520..724ae3f1bc1 100644
--- a/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py
+++ b/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py
@@ -16,6 +16,7 @@
 # under the License.
 #
 
+import os
 from typing import Optional, Tuple, List, Union
 import torch
 from torch import nn
@@ -28,6 +29,11 @@ from .configuration_sundial import SundialConfig
 from .ts_generation_mixin import TSGenerationMixin
 from .flow_loss import FlowLoss
 
+from safetensors.torch import load_file as load_safetensors
+from huggingface_hub import hf_hub_download
+
+from ainode.core.log import Logger
+logger = Logger()
 
 def rotate_half(x):
     x1 = x[..., : x.shape[-1] // 2]
@@ -416,11 +422,34 @@ class SundialModel(SundialPreTrainedModel):
 class SundialForPrediction(SundialPreTrainedModel, TSGenerationMixin):
     def __init__(self, config: SundialConfig):
         super().__init__(config)
-        self.config = config
-        self.model = SundialModel(self.config)
-        self.flow_loss = FlowLoss(self.config.output_token_lens[-1], self.config.hidden_size,
-                                  self.config.flow_loss_depth, self.config.hidden_size, self.config.num_sampling_steps)
-        self.post_init()
+        # self.config = config
+        # self.model = SundialModel(self.config)
+        # self.flow_loss = FlowLoss(self.config.output_token_lens[-1], self.config.hidden_size,
+        #                           self.config.flow_loss_depth, self.config.hidden_size, self.config.num_sampling_steps)
+        # TODO: Unify data loader
+        if not os.path.exists(config.ckpt_path):
+            os.mkdir(config.ckpt_path)
+        if not os.path.exists(os.path.join(config.ckpt_path, "config.json")):
+            logger.info(f"Config not found at {config.ckpt_path}, downloading from HuggingFace...")
+            repo_id = "thuml/sundial-base-128m"
+            try:
+                hf_hub_download(repo_id=repo_id, filename="config.json", local_dir=config.ckpt_path)
+                logger.info(f"Got config to {config.ckpt_path}")
+            except Exception as e:
+                logger.error(f"Failed to download config to {config.ckpt_path} due to {e}")
+                raise e
+        if not os.path.exists(os.path.join(config.ckpt_path, "model.safetensors")):
+            logger.info(f"Weight not found at {config.ckpt_path}, downloading from HuggingFace...")
+            repo_id = "thuml/sundial-base-128m"
+            try:
+                hf_hub_download(repo_id=repo_id, filename="model.safetensors", local_dir=config.ckpt_path)
+                logger.info(f"Got weight to {config.ckpt_path}")
+            except Exception as e:
+                logger.error(f"Failed to download weight to {config.ckpt_path} due to {e}")
+                raise e
+        print(config.ckpt_path, type(config.ckpt_path))
+        self.from_pretrained(config.ckpt_path)
+        # self.post_init()
 
     def set_decoder(self, decoder):
         self.model = decoder

Reply via email to