This is an automated email from the ASF dual-hosted git repository.

CRZbulabula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new c9c16a9b75f [AINode] Integrate MOMENT as a builtin forecasting model  
(#17386)
c9c16a9b75f is described below

commit c9c16a9b75fe42ea527a3dbf275ae84da9353226
Author: Elia LIU <[email protected]>
AuthorDate: Mon Apr 27 20:55:14 2026 +1000

    [AINode] Integrate MOMENT as a builtin forecasting model  (#17386)
---
 LICENSE                                            |  10 +
 .../apache/iotdb/ainode/utils/AINodeTestUtils.java |   4 +-
 .../ainode/iotdb/ainode/core/model/model_info.py   |  13 +
 .../iotdb/ainode/core/model/moment/__init__.py     |  17 +
 .../core/model/moment/configuration_moment.py      |  71 ++++
 .../ainode/core/model/moment/modeling_moment.py    | 415 +++++++++++++++++++++
 .../ainode/core/model/moment/pipeline_moment.py    | 176 +++++++++
 scripts/conf/confignode-env.sh                     |   2 +-
 scripts/conf/datanode-env.sh                       |   2 +-
 9 files changed, 707 insertions(+), 3 deletions(-)

diff --git a/LICENSE b/LICENSE
index 07a2ad8da4d..81000948cc4 100644
--- a/LICENSE
+++ b/LICENSE
@@ -360,3 +360,13 @@ Project page: 
https://github.com/SalesforceAIResearch/uni2ts
 License: https://github.com/SalesforceAIResearch/uni2ts/blob/main/LICENSE.txt
 
 
--------------------------------------------------------------------------------
+
+The following files include code modified from MOMENT project.
+
+./iotdb-core/ainode/iotdb/ainode/core/model/moment/*
+
+The MOMENT is open source software licensed under the MIT License
+Project page: https://github.com/moment-timeseries-foundation-model/moment
+License: 
https://github.com/moment-timeseries-foundation-model/moment/blob/main/LICENSE
+
+--------------------------------------------------------------------------------
diff --git 
a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
 
b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
index bf758a083d4..d69a301c066 100644
--- 
a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
+++ 
b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
@@ -60,7 +60,9 @@ public class AINodeTestUtils {
               new AbstractMap.SimpleEntry<>(
                   "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", 
"active")),
               new AbstractMap.SimpleEntry<>(
-                  "toto", new FakeModelInfo("toto", "toto", "builtin", 
"active")))
+                  "toto", new FakeModelInfo("toto", "toto", "builtin", 
"active")),
+              new AbstractMap.SimpleEntry<>(
+                  "moment", new FakeModelInfo("moment", "moment", "builtin", 
"active")))
           .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
 
   public static final Map<String, FakeModelInfo> BUILTIN_MODEL_MAP;
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
index 642986c42d2..0c63b356e7a 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
@@ -173,4 +173,17 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
         },
         transformers_registered=True,
     ),
+    "moment": ModelInfo(
+        model_id="moment",
+        category=ModelCategory.BUILTIN,
+        state=ModelStates.INACTIVE,
+        model_type="moment",
+        pipeline_cls="pipeline_moment.MomentPipeline",
+        repo_id="AutonLab/MOMENT-1-large",
+        auto_map={
+            "AutoConfig": "configuration_moment.MomentConfig",
+            "AutoModelForCausalLM": "modeling_moment.MomentForPrediction",
+        },
+        transformers_registered=True,
+    ),
 }
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/moment/__init__.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/moment/__init__.py
new file mode 100644
index 00000000000..2a1e720805f
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/moment/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/model/moment/configuration_moment.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/moment/configuration_moment.py
new file mode 100644
index 00000000000..e249d620ff8
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/moment/configuration_moment.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+# This file contains code adapted from the MOMENT project
+# (https://github.com/moment-timeseries-foundation-model/moment),
+# originally licensed under the MIT License.
+
+from typing import Optional
+
+from transformers import PretrainedConfig
+
+
+class MomentConfig(PretrainedConfig):
+    """
+    Configuration class for the MOMENT time series foundation model.
+
+    MOMENT (A Family of Open Time-series Foundation Models) is developed by
+    Auton Lab, Carnegie Mellon University. It uses a T5 encoder-only backbone
+    with patch-based input embedding and RevIN normalization for multi-task
+    time series analysis including forecasting, classification, anomaly
+    detection and imputation.
+
+    Reference: https://arxiv.org/abs/2402.03885
+    """
+
+    model_type = "moment"
+
+    def __init__(
+        self,
+        seq_len: int = 512,
+        patch_len: int = 8,
+        patch_stride_len: int = 8,
+        d_model: Optional[int] = None,
+        transformer_backbone: str = "google/flan-t5-large",
+        forecast_horizon: int = 96,
+        revin_affine: bool = False,
+        t5_config: Optional[dict] = None,
+        **kwargs,
+    ):
+        self.seq_len = seq_len
+        self.patch_len = patch_len
+        self.patch_stride_len = patch_stride_len
+        self.transformer_backbone = transformer_backbone
+        self.forecast_horizon = forecast_horizon
+        self.revin_affine = revin_affine
+        self.t5_config = t5_config
+
+        # Infer d_model: prefer explicit value, then t5_config, then default
+        if d_model is not None:
+            self.d_model = d_model
+        elif t5_config is not None and "d_model" in t5_config:
+            self.d_model = t5_config["d_model"]
+        else:
+            self.d_model = 1024  # Default for MOMENT-1-large
+
+        super().__init__(**kwargs)
diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/model/moment/modeling_moment.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/moment/modeling_moment.py
new file mode 100644
index 00000000000..a3b358a21e3
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/moment/modeling_moment.py
@@ -0,0 +1,415 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+# This file contains code adapted from the MOMENT project
+# (https://github.com/moment-timeseries-foundation-model/moment),
+# originally licensed under the MIT License.
+
+import json
+import os
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel, T5Config, T5EncoderModel
+
+from iotdb.ainode.core.log import Logger
+
+from .configuration_moment import MomentConfig
+
+logger = Logger()
+
+
+@dataclass
+class MomentOutput:
+    forecast: Optional[torch.Tensor] = None
+    reconstruction: Optional[torch.Tensor] = None
+    embeddings: Optional[torch.Tensor] = None
+    input_mask: Optional[torch.Tensor] = None
+
+
+class RevIN(nn.Module):
+    """Reversible Instance Normalization for time series."""
+
+    def __init__(self, n_features: int, affine: bool = False, eps: float = 
1e-5):
+        super().__init__()
+        self.n_features = n_features
+        self.affine = affine
+        self.eps = eps
+        if self.affine:
+            self.affine_weight = nn.Parameter(torch.ones(self.n_features))
+            self.affine_bias = nn.Parameter(torch.zeros(self.n_features))
+
+    def forward(
+        self, x: torch.Tensor, mode: str, mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
+        # x: [batch, n_channels, seq_len]
+        # mask: [batch, seq_len] - 1 for observed, 0 for padding (optional)
+        if mode == "norm":
+            self._get_statistics(x, mask)
+            x = self._normalize(x)
+        elif mode == "denorm":
+            x = self._denormalize(x)
+        return x
+
+    def _get_statistics(self, x: torch.Tensor, mask: Optional[torch.Tensor] = 
None):
+        if mask is not None:
+            # Expand mask to match x: [batch, 1, seq_len] -> broadcast over 
n_channels
+            m = mask.unsqueeze(1).float()  # [batch, 1, seq_len]
+            count = m.sum(dim=-1, keepdim=True).clamp(min=1)
+            self.mean = (x * m).sum(dim=-1, keepdim=True) / count
+            self.stdev = torch.sqrt(
+                ((x - self.mean) * m).pow(2).sum(dim=-1, keepdim=True) / count
+                + self.eps
+            )
+            self.mean = self.mean.detach()
+            self.stdev = self.stdev.detach()
+        else:
+            self.mean = torch.mean(x, dim=-1, keepdim=True).detach()
+            self.stdev = torch.sqrt(
+                torch.var(x, dim=-1, keepdim=True, unbiased=False) + self.eps
+            ).detach()
+
+    def _normalize(self, x: torch.Tensor) -> torch.Tensor:
+        x = (x - self.mean) / self.stdev
+        if self.affine:
+            x = x * self.affine_weight.unsqueeze(0).unsqueeze(-1)
+            x = x + self.affine_bias.unsqueeze(0).unsqueeze(-1)
+        return x
+
+    def _denormalize(self, x: torch.Tensor) -> torch.Tensor:
+        if self.affine:
+            x = x - self.affine_bias.unsqueeze(0).unsqueeze(-1)
+            x = x / (self.affine_weight.unsqueeze(0).unsqueeze(-1) + self.eps)
+        x = x * self.stdev
+        x = x + self.mean
+        return x
+
+
+class Patching(nn.Module):
+    """Unfold a 1-D time series into fixed-size patches."""
+
+    def __init__(self, patch_len: int = 8, stride: int = 8):
+        super().__init__()
+        self.patch_len = patch_len
+        self.stride = stride
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: [batch, n_channels, seq_len]
+        # out: [batch, n_channels, n_patches, patch_len]
+        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
+        return x
+
+
+class PatchEmbedding(nn.Module):
+    """Linear projection of patches to model dimension."""
+
+    def __init__(self, d_model: int, patch_len: int):
+        super().__init__()
+        self.proj = nn.Linear(patch_len, d_model)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: [batch, n_channels, n_patches, patch_len]
+        # out: [batch, n_channels, n_patches, d_model]
+        return self.proj(x)
+
+
+class ForecastingHead(nn.Module):
+    """Linear head that projects flattened patch embeddings to forecast 
horizon."""
+
+    def __init__(self, d_model: int, n_patches: int, forecast_horizon: int):
+        super().__init__()
+        self.flatten = nn.Flatten(start_dim=-2)
+        self.proj = nn.Linear(d_model * n_patches, forecast_horizon)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # x: [batch, n_channels, n_patches, d_model]
+        x = self.flatten(x)  # [batch, n_channels, n_patches * d_model]
+        return self.proj(x)  # [batch, n_channels, forecast_horizon]
+
+
+class MomentBackbone(nn.Module):
+    """
+    Core MOMENT architecture.
+
+    Architecture:
+        Input [batch, n_channels, seq_len]
+        -> RevIN normalization
+        -> Patching (unfold into fixed-size patches)
+        -> Patch embedding (linear projection to d_model)
+        -> T5 Encoder (self-attention layers)
+        -> Task-specific head (forecasting: linear projection)
+        -> RevIN denormalization
+        -> Output [batch, n_channels, forecast_horizon]
+    """
+
+    def __init__(self, config: MomentConfig):
+        super().__init__()
+        self.config = config
+        self.seq_len = config.seq_len
+        self.patch_len = config.patch_len
+        self.patch_stride_len = config.patch_stride_len
+        self.d_model = config.d_model
+        self.forecast_horizon = config.forecast_horizon
+
+        self.n_patches = (self.seq_len - self.patch_len) // 
self.patch_stride_len + 1
+
+        # RevIN normalization
+        self.revin = RevIN(n_features=1, affine=config.revin_affine)
+
+        # Patching and embedding
+        self.patching = Patching(patch_len=self.patch_len, 
stride=self.patch_stride_len)
+        self.patch_embedding = PatchEmbedding(
+            d_model=self.d_model, patch_len=self.patch_len
+        )
+
+        # Positional embedding for patches
+        self.position_embedding = nn.Embedding(self.n_patches, self.d_model)
+
+        # Mask embedding (for masked reconstruction during pre-training)
+        self.mask_embedding = nn.Parameter(torch.zeros(self.d_model))
+
+        # T5 encoder backbone
+        t5_config = getattr(config, "t5_config", None)
+        if t5_config is not None:
+            encoder_config = T5Config(**t5_config)
+        else:
+            encoder_config = 
T5Config.from_pretrained(config.transformer_backbone)
+            encoder_config.d_model = self.d_model
+        self.encoder = T5EncoderModel(encoder_config)
+
+        # Layer norm before head
+        self.layer_norm = nn.LayerNorm(self.d_model)
+
+        # Forecasting head
+        self.head = ForecastingHead(
+            d_model=self.d_model,
+            n_patches=self.n_patches,
+            forecast_horizon=self.forecast_horizon,
+        )
+
+    def forward(
+        self,
+        x_enc: torch.Tensor,
+        input_mask: Optional[torch.Tensor] = None,
+    ) -> MomentOutput:
+        """
+        Forward pass for forecasting.
+
+        Args:
+            x_enc: [batch_size, n_channels, seq_len]
+            input_mask: [batch_size, seq_len] - 1 for observed, 0 for padding
+
+        Returns:
+            MomentOutput with forecast field
+        """
+        batch_size, n_channels, seq_len = x_enc.shape
+
+        # Handle input_mask
+        if input_mask is None:
+            input_mask = torch.ones(batch_size, seq_len, device=x_enc.device)
+
+        # RevIN normalization (channel-independent)
+        # Reshape to process each channel independently
+        x = x_enc.reshape(batch_size * n_channels, 1, seq_len)
+        # Expand input_mask to match reshaped x: repeat for each channel
+        revin_mask = input_mask.unsqueeze(1).expand(-1, n_channels, -1)
+        revin_mask = revin_mask.reshape(batch_size * n_channels, seq_len)
+        x = self.revin(x, mode="norm", mask=revin_mask)
+        x = x.reshape(batch_size, n_channels, seq_len)
+
+        # Patching: [batch, n_channels, n_patches, patch_len]
+        x = self.patching(x)
+
+        # Patch embedding: [batch, n_channels, n_patches, d_model]
+        x = self.patch_embedding(x)
+
+        # Apply input mask at patch level
+        patch_mask = self._create_patch_mask(input_mask)
+        mask_embed = self.mask_embedding.unsqueeze(0).unsqueeze(0).unsqueeze(0)
+        x = x * patch_mask.unsqueeze(-1) + mask_embed * (1.0 - 
patch_mask.unsqueeze(-1))
+
+        # Position embedding
+        positions = torch.arange(self.n_patches, device=x.device)
+        x = x + self.position_embedding(positions).unsqueeze(0).unsqueeze(0)
+
+        # Flatten batch and channel dims for T5 encoder
+        # [batch * n_channels, n_patches, d_model]
+        x = x.reshape(batch_size * n_channels, self.n_patches, self.d_model)
+
+        # T5 encoder forward
+        enc_output = self.encoder(inputs_embeds=x).last_hidden_state
+
+        # Layer norm
+        enc_output = self.layer_norm(enc_output)
+
+        # Restore channel dim: [batch, n_channels, n_patches, d_model]
+        enc_output = enc_output.reshape(
+            batch_size, n_channels, self.n_patches, self.d_model
+        )
+
+        # Forecasting head: [batch, n_channels, forecast_horizon]
+        forecast = self.head(enc_output)
+
+        # RevIN denormalization
+        forecast = forecast.reshape(batch_size * n_channels, 1, 
self.forecast_horizon)
+        forecast = self.revin(forecast, mode="denorm")
+        forecast = forecast.reshape(batch_size, n_channels, 
self.forecast_horizon)
+
+        return MomentOutput(forecast=forecast, embeddings=enc_output)
+
+    def _create_patch_mask(self, input_mask: torch.Tensor) -> torch.Tensor:
+        """Convert per-timestep mask to per-patch mask via average pooling."""
+        # input_mask: [batch, seq_len]
+        # output: [batch, 1, n_patches] with values in [0, 1]
+        mask = input_mask.unsqueeze(1)  # [batch, 1, seq_len]
+        mask = mask.unfold(
+            dimension=-1, size=self.patch_len, step=self.patch_stride_len
+        )
+        # [batch, 1, n_patches, patch_len]
+        mask = mask.mean(dim=-1)  # [batch, 1, n_patches]
+        mask = (mask > 0.5).float()
+        return mask
+
+
+class MomentPreTrainedModel(PreTrainedModel):
+    """Abstract base class for all MOMENT model variants."""
+
+    config_class = MomentConfig
+    base_model_prefix = "moment"
+    supports_gradient_checkpointing = False
+
+    def _init_weights(self, module):
+        pass
+
+
+class MomentForPrediction(MomentPreTrainedModel):
+    """
+    MOMENT model for time series forecasting, wrapped as a HuggingFace 
PreTrainedModel.
+
+    Loads the pre-trained MOMENT backbone from safetensors and configures
+    the forecasting head for a given horizon.
+
+    Reference: https://huggingface.co/AutonLab/MOMENT-1-large
+    """
+
+    def __init__(self, config: MomentConfig):
+        super().__init__(config)
+        self.moment = MomentBackbone(config)
+        self.post_init()
+
+    def forward(
+        self,
+        x_enc: torch.Tensor,
+        input_mask: Optional[torch.Tensor] = None,
+    ) -> MomentOutput:
+        return self.moment(x_enc=x_enc, input_mask=input_mask)
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
+        """
+        Load MomentForPrediction from a local directory containing
+        ``config.json`` and ``model.safetensors``.
+
+        The upstream MOMENT checkpoint uses a flat state-dict structure
+        (keys like ``revin.affine_weight``, ``encoder.encoder.block.0...``).
+        This method handles mapping those keys into our nested ``moment.*``
+        structure.
+        """
+        # Pop kwargs injected by load_transformers_model that are not
+        # relevant to our custom loading logic.
+        kwargs.pop("config", None)
+        kwargs.pop("trust_remote_code", None)
+
+        if not os.path.isdir(pretrained_model_name_or_path):
+            raise ValueError(
+                f"pretrained_model_name_or_path must be a local directory, "
+                f"got: {pretrained_model_name_or_path}"
+            )
+
+        config_file = os.path.join(pretrained_model_name_or_path, 
"config.json")
+        safetensors_file = os.path.join(
+            pretrained_model_name_or_path, "model.safetensors"
+        )
+
+        # Load config
+        config_dict = {}
+        if os.path.exists(config_file):
+            with open(config_file, "r") as f:
+                config_dict = json.load(f)
+
+        # Extract t5_config if present in the upstream config
+        t5_config = config_dict.pop("t5_config", None)
+
+        forecast_horizon = kwargs.pop("forecast_horizon", 96)
+
+        # Build MomentConfig with t5_config stored for backbone construction
+        config = MomentConfig(
+            seq_len=config_dict.get("seq_len", 512),
+            patch_len=config_dict.get("patch_len", 8),
+            patch_stride_len=config_dict.get("patch_stride_len", 8),
+            d_model=config_dict.get("d_model"),
+            transformer_backbone=config_dict.get(
+                "transformer_backbone", "google/flan-t5-large"
+            ),
+            forecast_horizon=forecast_horizon,
+            revin_affine=config_dict.get("revin_affine", False),
+            t5_config=t5_config,
+        )
+
+        # Instantiate model
+        instance = cls.__new__(cls)
+        MomentPreTrainedModel.__init__(instance, config)
+        instance.moment = MomentBackbone(config)
+        instance.post_init()
+
+        # Load weights
+        if not os.path.exists(safetensors_file):
+            raise FileNotFoundError(
+                f"Model checkpoint not found at: {safetensors_file}"
+            )
+
+        import safetensors.torch as safetorch
+
+        state_dict = safetorch.load_file(safetensors_file, device="cpu")
+
+        # Map upstream flat keys to our nested moment.* structure
+        mapped_state_dict = {
+            f"moment.{key}": value for key, value in state_dict.items()
+        }
+
+        # Load with strict=False to skip mismatched head weights
+        model_state = instance.state_dict()
+        filtered = {
+            k: v
+            for k, v in mapped_state_dict.items()
+            if k in model_state and v.shape == model_state[k].shape
+        }
+        instance.load_state_dict(filtered, strict=False)
+        instance.eval()
+
+        logger.info(
+            f"Loaded MOMENT model from {pretrained_model_name_or_path} "
+            f"({len(filtered)}/{len(model_state)} keys matched)"
+        )
+        return instance
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
diff --git 
a/iotdb-core/ainode/iotdb/ainode/core/model/moment/pipeline_moment.py 
b/iotdb-core/ainode/iotdb/ainode/core/model/moment/pipeline_moment.py
new file mode 100644
index 00000000000..67108d736de
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/moment/pipeline_moment.py
@@ -0,0 +1,176 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+# This file contains code adapted from the MOMENT project
+# (https://github.com/moment-timeseries-foundation-model/moment),
+# originally licensed under the MIT License.
+
+import torch
+
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import 
ForecastPipeline
+from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.model.model_info import ModelInfo
+
+logger = Logger()
+
+# MOMENT requires a fixed input length of 512 timesteps
+MOMENT_SEQ_LEN = 512
+
+
+class MomentPipeline(ForecastPipeline):
+    """
+    Inference pipeline for the MOMENT time series foundation model.
+
+    MOMENT processes fixed-length (512) univariate patches through a T5 encoder
+    and produces forecasts via a single-shot linear head. Each channel/variate
+    is processed independently (channel-independent design).
+
+    The pipeline handles:
+    - Padding/truncating inputs to the required 512 length
+    - Constructing input masks for padded positions
+    - Iterative forecasting for horizons beyond the model's native capacity
+    """
+
+    def __init__(self, model_info: ModelInfo, **model_kwargs):
+        super().__init__(model_info, **model_kwargs)
+
+    def _preprocess(self, inputs, **infer_kwargs) -> dict:
+        """
+        Preprocess input data for MOMENT.
+
+        Converts the list of input dicts into a single tensor padded/truncated
+        to MOMENT's required sequence length of 512. Also constructs an
+        input_mask indicating which timesteps are observed vs padded.
+
+        Parameters
+        ----------
+        inputs : list of dict
+            Each dict has key ``"targets"`` with a tensor of shape
+            ``(target_count, input_length)``.
+
+        Returns
+        -------
+        dict
+            ``"x_enc"``: tensor of shape ``[batch, n_channels, 512]``
+            ``"input_mask"``: tensor of shape ``[batch, 512]``
+        """
+        if inputs[0].get("past_covariates") or 
inputs[0].get("future_covariates"):
+            logger.warning(
+                "MomentPipeline does not support covariates; they will be 
ignored."
+            )
+
+        batch_tensors = []
+        batch_masks = []
+
+        for item in inputs:
+            targets = item["targets"]  # [target_count, input_length]
+            if targets.ndim == 1:
+                targets = targets.unsqueeze(0)
+
+            n_channels, input_length = targets.shape
+
+            if input_length >= MOMENT_SEQ_LEN:
+                # Truncate: take the last MOMENT_SEQ_LEN timesteps
+                x = targets[:, -MOMENT_SEQ_LEN:]
+                mask = torch.ones(MOMENT_SEQ_LEN, device=targets.device)
+            else:
+                # Left-pad with zeros
+                pad_len = MOMENT_SEQ_LEN - input_length
+                x = torch.nn.functional.pad(targets, (pad_len, 0), value=0.0)
+                mask = torch.cat(
+                    [
+                        torch.zeros(pad_len, device=targets.device),
+                        torch.ones(input_length, device=targets.device),
+                    ]
+                )
+
+            batch_tensors.append(x)
+            batch_masks.append(mask)
+
+        x_enc = torch.stack(batch_tensors, dim=0)  # [batch, n_channels, 512]
+        input_mask = torch.stack(batch_masks, dim=0)  # [batch, 512]
+
+        return {"x_enc": x_enc, "input_mask": input_mask}
+
+    def forecast(self, inputs: dict, **infer_kwargs) -> list[torch.Tensor]:
+        """
+        Run MOMENT forecasting inference.
+
+        For output_length <= model forecast_horizon, a single forward pass
+        suffices. For longer horizons, iterative (autoregressive) forecasting
+        is used: each step's predictions are appended to the context window
+        and fed back as input.
+
+        Parameters
+        ----------
+        inputs : dict
+            Contains ``"x_enc"`` and ``"input_mask"`` from _preprocess.
+        infer_kwargs : dict
+            ``output_length`` (int): desired forecast length, default 96.
+
+        Returns
+        -------
+        list of torch.Tensor
+            Each tensor has shape ``[n_channels, output_length]``.
+        """
+        output_length = infer_kwargs.get("output_length", 96)
+        x_enc = inputs["x_enc"].to(self.model.device)
+        input_mask = inputs["input_mask"].to(self.model.device)
+
+        model_horizon = self.model.config.forecast_horizon
+        batch_size, n_channels, seq_len = x_enc.shape
+
+        if output_length <= model_horizon:
+            # Single-shot inference
+            with torch.no_grad():
+                output = self.model(x_enc=x_enc, input_mask=input_mask)
+            forecasts = output.forecast[:, :, :output_length]
+        else:
+            # Iterative forecasting for long horizons
+            forecasts_list = []
+            remaining = output_length
+            current_x = x_enc
+            current_mask = input_mask
+
+            while remaining > 0:
+                with torch.no_grad():
+                    output = self.model(x_enc=current_x, 
input_mask=current_mask)
+                step_forecast = output.forecast[:, :, : min(model_horizon, 
remaining)]
+                forecasts_list.append(step_forecast)
+                remaining -= step_forecast.shape[-1]
+
+                if remaining > 0:
+                    # Slide context window: append forecast, drop oldest
+                    step_len = step_forecast.shape[-1]
+                    current_x = torch.cat(
+                        [current_x[:, :, step_len:], step_forecast], dim=-1
+                    )
+                    current_mask = torch.ones(batch_size, seq_len, 
device=x_enc.device)
+
+            forecasts = torch.cat(forecasts_list, dim=-1)
+
+        # Split batch into list of per-sample tensors
+        return [forecasts[i] for i in range(batch_size)]
+
+    def _postprocess(
+        self, outputs: list[torch.Tensor], **infer_kwargs
+    ) -> list[torch.Tensor]:
+        """
+        Postprocess outputs. Each tensor is already [n_channels, 
output_length].
+        """
+        return outputs
diff --git a/scripts/conf/confignode-env.sh b/scripts/conf/confignode-env.sh
index edf587b8c2b..3b2117209fc 100644
--- a/scripts/conf/confignode-env.sh
+++ b/scripts/conf/confignode-env.sh
@@ -192,7 +192,7 @@ else
     JAVA=java
 fi
 
-if [ -z $JAVA ] ; then
+if [ -z "$JAVA" ] ; then
     echo Unable to find java executable. Check JAVA_HOME and PATH environment 
variables.  > /dev/stderr
     exit 1;
 fi
diff --git a/scripts/conf/datanode-env.sh b/scripts/conf/datanode-env.sh
index 31206e550d3..f0f7da3e79b 100755
--- a/scripts/conf/datanode-env.sh
+++ b/scripts/conf/datanode-env.sh
@@ -198,7 +198,7 @@ else
     JAVA=java
 fi
 
-if [ -z $JAVA ] ; then
+if [ -z "$JAVA" ] ; then
     echo Unable to find java executable. Check JAVA_HOME and PATH environment 
variables.  > /dev/stderr
     exit 1;
 fi

Reply via email to