This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
     new 5095862f75f [AINode]: Integrate toto as a builtin forecasting model (#17322)
5095862f75f is described below
commit 5095862f75f33a1f5c800666ceacd0c6a3b042e3
Author: Grace Li <[email protected]>
AuthorDate: Fri Mar 27 03:07:19 2026 -0400
[AINode]: Integrate toto as a builtin forecasting model (#17322)
---
NOTICE | 9 +
.../apache/iotdb/ainode/utils/AINodeTestUtils.java | 4 +-
iotdb-core/ainode/build_binary.py | 8 +-
.../ainode/iotdb/ainode/core/model/model_info.py | 13 +
.../iotdb/ainode/core/model/toto/__init__.py | 17 +
.../ainode/core/model/toto/configuration_toto.py | 78 ++++
.../iotdb/ainode/core/model/toto/data/__init__.py | 20 +
.../ainode/core/model/toto/data/util/__init__.py | 20 +
.../ainode/core/model/toto/data/util/dataset.py | 127 ++++++
.../ainode/core/model/toto/inference/__init__.py | 20 +
.../ainode/core/model/toto/inference/forecaster.py | 452 +++++++++++++++++++++
.../iotdb/ainode/core/model/toto/model/__init__.py | 20 +
.../ainode/core/model/toto/model/attention.py | 276 +++++++++++++
.../iotdb/ainode/core/model/toto/model/backbone.py | 258 ++++++++++++
.../ainode/core/model/toto/model/distribution.py | 112 +++++
.../ainode/core/model/toto/model/embedding.py | 83 ++++
.../ainode/core/model/toto/model/feed_forward.py | 35 ++
.../iotdb/ainode/core/model/toto/model/fusion.py | 58 +++
.../iotdb/ainode/core/model/toto/model/rope.py | 94 +++++
.../iotdb/ainode/core/model/toto/model/scaler.py | 328 +++++++++++++++
.../iotdb/ainode/core/model/toto/model/toto.py | 157 +++++++
.../ainode/core/model/toto/model/transformer.py | 318 +++++++++++++++
.../iotdb/ainode/core/model/toto/model/util.py | 251 ++++++++++++
.../iotdb/ainode/core/model/toto/modeling_toto.py | 167 ++++++++
.../iotdb/ainode/core/model/toto/pipeline_toto.py | 144 +++++++
iotdb-core/ainode/pyproject.toml | 1 +
26 files changed, 3066 insertions(+), 4 deletions(-)
diff --git a/NOTICE b/NOTICE
index fa52a36987f..429495c377b 100644
--- a/NOTICE
+++ b/NOTICE
@@ -17,6 +17,15 @@ grant the users the right to the use of patent under the requirement of Apache 2
============================================================================
+This product includes source code derived from the DataDog/toto project:
+
+ Toto – Timeseries-Optimized Transformer for Observability
+ Copyright 2025 Datadog, Inc.
+ Licensed under the Apache License, Version 2.0
+ https://github.com/DataDog/toto
+
+============================================================================
+
Apache Commons Collections
Copyright 2001-2019 The Apache Software Foundation
diff --git a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
index e41d3d4e0f9..bf758a083d4 100644
--- a/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
+++ b/integration-test/src/test/java/org/apache/iotdb/ainode/utils/AINodeTestUtils.java
@@ -58,7 +58,9 @@ public class AINodeTestUtils {
          new AbstractMap.SimpleEntry<>(
              "chronos2", new FakeModelInfo("chronos2", "t5", "builtin", "active")),
          new AbstractMap.SimpleEntry<>(
-             "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")))
+             "moirai2", new FakeModelInfo("moirai2", "moirai", "builtin", "active")),
+         new AbstractMap.SimpleEntry<>(
+             "toto", new FakeModelInfo("toto", "toto", "builtin", "active")))
      .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
  public static final Map<String, FakeModelInfo> BUILTIN_MODEL_MAP;
public static final Map<String, FakeModelInfo> BUILTIN_MODEL_MAP;
diff --git a/iotdb-core/ainode/build_binary.py b/iotdb-core/ainode/build_binary.py
index c943de41581..f3b7fa1cedf 100644
--- a/iotdb-core/ainode/build_binary.py
+++ b/iotdb-core/ainode/build_binary.py
@@ -423,7 +423,7 @@ def install_dependencies(venv_python, venv_dir, script_dir):
[str(poetry_exe), "lock"],
cwd=str(script_dir),
env=venv_env,
- check=True,
+ check=False,
capture_output=True,
text=True,
)
@@ -431,6 +431,9 @@ def install_dependencies(venv_python, venv_dir, script_dir):
print(result.stdout)
if result.stderr:
print(result.stderr)
+ if result.returncode != 0:
+ print(f"ERROR: poetry lock failed with exit code {result.returncode}")
+ sys.exit(1)
verify_poetry_env() # Verify after lock
accelerator = detect_accelerator()
@@ -438,11 +441,10 @@ def install_dependencies(venv_python, venv_dir, script_dir):
print("Running poetry install...")
subprocess.run(
- [str(poetry_exe), "lock"],
+ [str(poetry_exe), "install", "--no-root"],
cwd=str(script_dir),
env=venv_env,
check=True,
- capture_output=True,
text=True,
)
verify_poetry_env() # Verify before install
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
index da752cbd784..642986c42d2 100644
--- a/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/model_info.py
@@ -160,4 +160,17 @@ BUILTIN_HF_TRANSFORMERS_MODEL_MAP = {
},
transformers_registered=True,
),
+ "toto": ModelInfo(
+ model_id="toto",
+ category=ModelCategory.BUILTIN,
+ state=ModelStates.INACTIVE,
+ model_type="toto",
+ pipeline_cls="pipeline_toto.TotoPipeline",
+ repo_id="Datadog/Toto-Open-Base-1.0",
+ auto_map={
+ "AutoConfig": "configuration_toto.TotoConfig",
+ "AutoModelForCausalLM": "modeling_toto.TotoForPrediction",
+ },
+ transformers_registered=True,
+ ),
}
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py
new file mode 100644
index 00000000000..2a1e720805f
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/__init__.py
@@ -0,0 +1,17 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py
new file mode 100644
index 00000000000..2a00fcc3be4
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/configuration_toto.py
@@ -0,0 +1,78 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import List, Optional
+
+from transformers import PretrainedConfig
+
+
+class TotoConfig(PretrainedConfig):
+ """
+ Configuration class for the Toto time series forecasting model.
+
+    Toto (Time Series Optimized Transformer for Observability) is a foundation model
+    for multivariate time series forecasting developed by Datadog. It uses a decoder-only
+    architecture with per-variate patch-based causal scaling, proportional time-variate
+    factorized attention, and a Student-T mixture prediction head.
+
+ Reference: https://github.com/DataDog/toto
+ """
+
+ model_type = "toto"
+
+ def __init__(
+ self,
+ patch_size: int = 32,
+ stride: int = 32,
+ embed_dim: int = 1024,
+ num_layers: int = 18,
+ num_heads: int = 16,
+ mlp_hidden_dim: int = 2816,
+ dropout: float = 0.0,
+ spacewise_every_n_layers: int = 3,
+ scaler_cls: str = "per_variate_causal",
+ output_distribution_classes: Optional[List[str]] = None,
+ output_distribution_kwargs: Optional[dict] = None,
+ spacewise_first: bool = True,
+ use_memory_efficient_attention: bool = True,
+ stabilize_with_global: bool = True,
+ scale_factor_exponent: float = 10.0,
+ **kwargs,
+ ):
+ self.patch_size = patch_size
+ self.stride = stride
+ self.embed_dim = embed_dim
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ self.mlp_hidden_dim = mlp_hidden_dim
+ self.dropout = dropout
+ self.spacewise_every_n_layers = spacewise_every_n_layers
+ self.scaler_cls = scaler_cls
+ self.output_distribution_classes = output_distribution_classes or [
+ "student_t_mixture"
+ ]
+ # k_components=5 is the default used by Datadog/Toto-Open-Base-1.0
+ self.output_distribution_kwargs = output_distribution_kwargs or {
+ "k_components": 5
+ }
+ self.spacewise_first = spacewise_first
+ self.use_memory_efficient_attention = use_memory_efficient_attention
+ self.stabilize_with_global = stabilize_with_global
+ self.scale_factor_exponent = scale_factor_exponent
+
+ super().__init__(**kwargs)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/__init__.py
new file mode 100644
index 00000000000..ba26b1edd94
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/__init__.py
new file mode 100644
index 00000000000..ba26b1edd94
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/dataset.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/dataset.py
new file mode 100644
index 00000000000..6bccf35988c
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/data/util/dataset.py
@@ -0,0 +1,127 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from functools import reduce
+from typing import NamedTuple
+
+import numpy as np
+import torch
+import torch.utils.data
+from einops import repeat
+from jaxtyping import Bool, Float, Int, Shaped
+
+
+def pad_array(
+ values: Shaped[torch.Tensor, "*batch variates series_len"], # noqa: F722
+ patch_stride: int,
+) -> Shaped[torch.Tensor, "*batch variates padded_length"]: # noqa: F722
+ """
+ Makes sure that the series length is divisible by the patch_stride
+ by adding left-padding.
+ """
+ if isinstance(values, np.ndarray):
+ values = torch.from_numpy(values)
+ series_len = values.shape[-1]
+ padded_length = int(np.ceil(series_len / patch_stride) * patch_stride)
+ if values.ndim == 2:
+        padded_values = torch.zeros((values.shape[0], padded_length), dtype=values.dtype, device=values.device)
+ elif values.ndim == 3:
+ padded_values = torch.zeros(
+ (values.shape[0], values.shape[1], padded_length),
+ dtype=values.dtype,
+ device=values.device,
+ )
+ else:
+ raise ValueError(f"Unsupported number of dimensions: {values.ndim}")
+ padded_values[..., -series_len:] = values
+
+ return padded_values
+
+
+def pad_id_mask(
+ id_mask: Int[torch.Tensor, "*batch variates series_len"], # noqa: F722
+ patch_stride: int,
+) -> Int[torch.Tensor, "*batch variates padded_length"]: # noqa: F722
+ """
+ Makes sure that the series length is divisible by the patch_stride
+ by adding left-padding to the id mask.
+ """
+ series_len = id_mask.shape[-1]
+ padded_length = int(np.ceil(series_len / patch_stride) * patch_stride)
+ padding_amount = padded_length - series_len
+ left_edge: Int[torch.Tensor, "*batch variates"] = id_mask[..., 0] # noqa:
F722
+ if id_mask.ndim == 2:
+ padding = repeat(
+ left_edge,
+ "variates -> variates padding_amount",
+ padding_amount=padding_amount,
+ )
+ id_mask = torch.cat([padding, id_mask], dim=1)
+ elif id_mask.ndim == 3:
+ padding = repeat(
+ left_edge,
+ "batch variates -> batch variates padding_amount",
+ padding_amount=padding_amount,
+ )
+ id_mask = torch.cat([padding, id_mask], dim=2)
+ else:
+ raise ValueError(f"Unsupported number of dimensions: {id_mask.ndim}")
+
+ return id_mask
+
+
+class MaskedTimeseries(NamedTuple):
+ series: Float[torch.Tensor, "*batch variates series_len"] # noqa: F722
+ padding_mask: Bool[torch.Tensor, "*batch variates series_len"] # noqa:
F722
+ id_mask: Int[torch.Tensor, "*batch variates #series_len"] # noqa: F722
+ timestamp_seconds: Int[torch.Tensor, "*batch variates series_len"] #
noqa: F722
+ time_interval_seconds: Int[torch.Tensor, "*batch variates"] # noqa: F722
+ num_exogenous_variables: int = 0
+
+ def to(self, device: torch.device) -> "MaskedTimeseries":
+ return MaskedTimeseries(
+ series=self.series.to(device),
+ padding_mask=self.padding_mask.to(device),
+ id_mask=self.id_mask.to(device),
+ timestamp_seconds=self.timestamp_seconds.to(device),
+ time_interval_seconds=self.time_interval_seconds.to(device),
+ num_exogenous_variables=self.num_exogenous_variables,
+ )
+
+
+def is_extreme_value(t: torch.Tensor) -> torch.Tensor:
+ if torch.is_floating_point(t):
+ max_value = torch.finfo(t.dtype).max
+ else:
+ max_value = torch.iinfo(t.dtype).max
+
+ return reduce(
+ torch.logical_or,
+ (
+ torch.isinf(t),
+ torch.isnan(t),
+ t.abs() >= max_value / 2,
+ ),
+ )
+
+
+def replace_extreme_values(t: torch.Tensor, replacement: float = 0.0) -> torch.Tensor:
+    return torch.where(is_extreme_value(t), torch.tensor(replacement, dtype=t.dtype, device=t.device), t)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py
new file mode 100644
index 00000000000..ba26b1edd94
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py
new file mode 100644
index 00000000000..2a9db2aa629
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/inference/forecaster.py
@@ -0,0 +1,452 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from dataclasses import dataclass
+from typing import cast
+
+import numpy as np
+import torch
+from einops import rearrange, repeat
+from jaxtyping import Bool, Float, Int
+from torch.distributions import Distribution, TransformedDistribution
+from torch.distributions.transforms import AffineTransform
+
+from ..data.util.dataset import (
+ MaskedTimeseries,
+ pad_array,
+ pad_id_mask,
+ replace_extreme_values,
+)
+from ..model.backbone import TotoBackbone
+
+
+class AffineTransformed(TransformedDistribution):
+ """
+ Thin wrapper around TransformedDistribution with AffineTransform,
+ replacing gluonts.torch.distributions.AffineTransformed.
+ """
+
+ def __init__(self, base_distribution, loc=0.0, scale=1.0):
+ super().__init__(base_distribution, AffineTransform(loc=loc,
scale=scale))
+
+ @property
+ def mean(self):
+ loc = self.transforms[0].loc
+ scale = self.transforms[0].scale
+ return loc + scale * self.base_dist.mean
+
+ # Note: Do NOT override sample() here. TransformedDistribution.sample()
correctly
+ # calls base_dist.sample() (not rsample), which works for
non-reparameterizable
+ # distributions like MixtureSameFamily.
+
+
+@dataclass(frozen=True)
+class Forecast:
+ mean: Float[torch.Tensor, "batch variate future_time_steps"]
+ samples: Float[torch.Tensor, "batch variate future_time_steps samples"] |
None = (
+ None
+ )
+
+ def quantile(
+ self, q: float | torch.Tensor
+ ) -> Float[torch.Tensor, "batch variate future_time_steps"]:
+ assert self.samples is not None, "samples must be provided to compute
quantiles"
+ assert isinstance(q, (float, torch.Tensor)), "q must be a float or a
tensor"
+ if isinstance(q, float):
+ q = torch.tensor(q, device=self.samples.device,
dtype=self.samples.dtype)
+ return self.samples.quantile(q, dim=-1)
+
+ @property
+ def median(self) -> Float[torch.Tensor, "batch variate future_time_steps"]:
+ return self.quantile(0.5)
+
+ @property
+ def std(self) -> Float[torch.Tensor, "batch variate future_time_steps"]:
+ assert (
+ self.samples is not None
+ ), "samples must be provided to compute standard deviation"
+ return self.samples.std(dim=-1)
+
+
+class TotoForecaster:
+ """
+ A forecaster class for the Toto model that handles autoregressive decoding
+ for time series forecasting.
+ """
+
+ model: TotoBackbone
+
+ def __init__(self, model: TotoBackbone):
+ self.model = model
+ self.model.eval()
+
+ def forecast(
+ self,
+ inputs: MaskedTimeseries,
+ prediction_length: int,
+ num_samples: int | None = None,
+ samples_per_batch: int = 10,
+ use_kv_cache: bool = True,
+ future_exogenous_variables: (
+ Float[torch.Tensor, "batch exogenous_variables future_time_steps"]
| None
+ ) = None,
+ ) -> Forecast:
+ if len(inputs.series.shape) == 2:
+ batch = cast(MaskedTimeseries,
torch.utils.data.default_collate([inputs]))
+ else:
+ batch = inputs
+
+ if (
+ future_exogenous_variables is not None
+ and len(future_exogenous_variables.shape) == 2
+ ):
+ future_exogenous_variables =
future_exogenous_variables.unsqueeze(0)
+
+ series = pad_array(batch.series, self.model.patch_embed.stride)
+ padding_mask = pad_array(batch.padding_mask,
self.model.patch_embed.stride)
+ id_mask = batch.id_mask
+ if id_mask is not None:
+ id_mask = pad_id_mask(batch.id_mask, self.model.patch_embed.stride)
+ timestamp_seconds = pad_array(
+ batch.timestamp_seconds, self.model.patch_embed.stride
+ )
+ time_interval_seconds: Int[torch.Tensor, "batch variate series_len"] =
(
+ torch.as_tensor(
+ batch.time_interval_seconds, device=series.device,
dtype=torch.int
+ )
+ )
+
+ if num_samples is not None:
+ samples = self.generate_samples(
+ inputs=series,
+ prediction_length=prediction_length,
+ num_samples=num_samples,
+ timestamp_seconds=timestamp_seconds,
+ time_interval_seconds=time_interval_seconds,
+ input_padding_mask=padding_mask,
+ id_mask=id_mask,
+ sampling_batch_size=samples_per_batch,
+ use_kv_cache=use_kv_cache,
+ future_exogenous_variables=future_exogenous_variables,
+ num_exogenous_variables=batch.num_exogenous_variables,
+ )
+ mean = samples.mean(dim=-1)
+ else:
+ mean = self.generate_mean(
+ inputs=series,
+ prediction_length=prediction_length,
+ timestamp_seconds=timestamp_seconds,
+ time_interval_seconds=time_interval_seconds,
+ input_padding_mask=padding_mask,
+ id_mask=id_mask,
+ use_kv_cache=use_kv_cache,
+ future_exogenous_variables=future_exogenous_variables,
+ num_exogenous_variables=batch.num_exogenous_variables,
+ )
+ samples = None
+
+ return Forecast(mean=mean, samples=samples)
+
+ def assert_ev_compatibility(
+ self,
+ inputs,
+ future_exogenous_variables,
+ prediction_length,
+ num_exogenous_variables,
+ ) -> None:
+ assert future_exogenous_variables.shape[-1] == prediction_length
+ assert future_exogenous_variables.shape[0] == inputs.shape[0]
+ assert num_exogenous_variables == future_exogenous_variables.shape[-2]
+
+ def round_ft_ev(self, future_exogenous_variables, T_rounded):
+ B, V_ev, T_future = future_exogenous_variables.shape
+ dtype = future_exogenous_variables.dtype
+ device = future_exogenous_variables.device
+ padding = torch.zeros(B, V_ev, T_rounded - T_future, device=device,
dtype=dtype)
+ return torch.cat([future_exogenous_variables, padding], dim=-1)
+
+ @torch.no_grad()
+ def generate_mean(
+ self,
+ inputs: Float[torch.Tensor, "batch variate time_steps"],
+ prediction_length: int,
+ timestamp_seconds: Int[torch.Tensor, "batch variate time_steps"],
+ time_interval_seconds: Int[torch.Tensor, "batch variate"],
+ input_padding_mask: (
+ Bool[torch.Tensor, "batch variate time_steps"] | None
+ ) = None,
+ id_mask: Float[torch.Tensor, "batch #variate time_steps"] | None =
None,
+ use_kv_cache: bool = False,
+ future_exogenous_variables=None,
+ num_exogenous_variables: int = 0,
+ ) -> Float[torch.Tensor, "batch variate time_steps"]:
+ if input_padding_mask is None:
+ input_padding_mask = torch.ones_like(
+ inputs, dtype=torch.bool, device=inputs.device
+ )
+ if id_mask is None:
+ id_mask = torch.zeros_like(inputs, dtype=torch.int,
device=inputs.device)
+
+ if future_exogenous_variables is not None:
+ self.assert_ev_compatibility(
+ inputs,
+ future_exogenous_variables,
+ prediction_length,
+ num_exogenous_variables,
+ )
+
+ patch_size = self.model.patch_embed.stride
+ rounded_steps = int(np.ceil(prediction_length / patch_size) *
patch_size)
+ if rounded_steps > prediction_length and future_exogenous_variables is
not None:
+ future_exogenous_variables = self.round_ft_ev(
+ future_exogenous_variables, rounded_steps
+ )
+ start_index = inputs.shape[-1]
+ end_index = start_index + prediction_length
+
+ dummy_padding = torch.ones(
+ (input_padding_mask.shape[0], input_padding_mask.shape[1],
patch_size),
+ device=inputs.device,
+ dtype=torch.bool,
+ )
+ dummy_id_mask = repeat(
+ id_mask[:, :, -1:],
+ "batch variates 1 -> batch variates patch_size",
+ patch_size=patch_size,
+ )
+ if use_kv_cache:
+ kv_cache = self.model.allocate_kv_cache(
+ batch_size=inputs.shape[0],
+ num_variates=inputs.shape[1],
+ max_time_steps=inputs.shape[2] + rounded_steps,
+ device=inputs.device,
+ dtype=inputs.dtype,
+ )
+ else:
+ kv_cache = None
+
+ scaling_prefix_length = inputs.shape[-1]
+
+ for idx in range(rounded_steps // patch_size):
+ base_distr, loc, scale = self.model(
+ inputs=inputs,
+ input_padding_mask=input_padding_mask,
+ id_mask=id_mask,
+ kv_cache=kv_cache,
+ scaling_prefix_length=scaling_prefix_length,
+ num_exogenous_variables=num_exogenous_variables,
+ )
+ distr = self.create_affine_transformed(base_distr, loc, scale)
+
+ samples = replace_extreme_values(distr.mean[:, :, -patch_size:])
+
+ if future_exogenous_variables is not None:
+ start, stop = idx * patch_size, (idx + 1) * patch_size
+ samples[:, -num_exogenous_variables:] =
future_exogenous_variables[
+ :, :, start:stop
+ ]
+
+ inputs = torch.cat([inputs, samples], dim=-1)
+ id_mask = torch.cat([id_mask, dummy_id_mask], dim=-1)
+ input_padding_mask = torch.cat([input_padding_mask,
dummy_padding], dim=-1)
+ for _ in range(patch_size):
+ next_timestamp = timestamp_seconds[:, :, -1] +
time_interval_seconds
+ timestamp_seconds = torch.cat(
+ [timestamp_seconds, next_timestamp.unsqueeze(-1)], dim=-1
+ )
+
+ return inputs.detach()[:, :, start_index:end_index]
+
+ @torch.no_grad()
+ def generate_samples(
+ self,
+ inputs: Float[torch.Tensor, "batch variate time_steps"],
+ prediction_length: int,
+ num_samples: int,
+ timestamp_seconds: Int[torch.Tensor, "batch variate time_steps"],
+ time_interval_seconds: Int[torch.Tensor, "batch variate"],
+ input_padding_mask: (
+ Bool[torch.Tensor, "batch variate time_steps"] | None
+ ) = None,
+ id_mask: Float[torch.Tensor, "batch #variate time_steps"] | None =
None,
+ sampling_batch_size: int = 10,
+ use_kv_cache: bool = False,
+ future_exogenous_variables=None,
+ num_exogenous_variables: int = 0,
+ ) -> Float[torch.Tensor, "batch variate time_steps samples"]:
+ if input_padding_mask is None:
+ input_padding_mask = torch.ones_like(
+ inputs, dtype=torch.bool, device=inputs.device
+ )
+ if id_mask is None:
+ id_mask = torch.zeros_like(inputs, dtype=torch.int,
device=inputs.device)
+
+ if future_exogenous_variables is not None:
+ self.assert_ev_compatibility(
+ inputs,
+ future_exogenous_variables,
+ prediction_length,
+ num_exogenous_variables,
+ )
+
+ assert (
+ num_samples % sampling_batch_size == 0
+ ), "num_samples must be divisible by sampling_batch_size"
+ num_batches = num_samples // sampling_batch_size
+
+ patch_size = self.model.patch_embed.patch_size
+ rounded_steps = int(np.ceil(prediction_length / patch_size) *
patch_size)
+ if rounded_steps > prediction_length and future_exogenous_variables is
not None:
+ future_exogenous_variables = self.round_ft_ev(
+ future_exogenous_variables, rounded_steps
+ )
+ start_index = inputs.shape[-1]
+ end_index = start_index + prediction_length
+
+ dummy_padding = torch.ones(
+ (
+ input_padding_mask.shape[0] * sampling_batch_size,
+ input_padding_mask.shape[1],
+ patch_size,
+ ),
+ dtype=torch.bool,
+ device=inputs.device,
+ )
+ dummy_id_mask = repeat(
+ id_mask[:, :, -1:],
+ "batch variates 1 -> (sampling_batch_size batch) variates
patch_size",
+ sampling_batch_size=sampling_batch_size,
+ patch_size=patch_size,
+ )
+ inputs = repeat(
+ inputs,
+ "batch variates seq_len -> (sampling_batch_size batch) variates
seq_len",
+ sampling_batch_size=sampling_batch_size,
+ )
+ if future_exogenous_variables is not None:
+ future_exogenous_variables = repeat(
+ future_exogenous_variables,
+ "batch exogenous_variables future_time_steps ->
(sampling_batch_size batch) exogenous_variables future_time_steps",
+ sampling_batch_size=sampling_batch_size,
+ )
+ input_padding_mask = repeat(
+ input_padding_mask,
+ "batch variates seq_len -> (sampling_batch_size batch) variates
seq_len",
+ sampling_batch_size=sampling_batch_size,
+ )
+ id_mask = repeat(
+ id_mask,
+ "batch variates seq_len -> (sampling_batch_size batch) variates
seq_len",
+ sampling_batch_size=sampling_batch_size,
+ )
+ timestamp_seconds = repeat(
+ timestamp_seconds,
+ "batch variates seq_len -> (sampling_batch_size batch) variates
seq_len",
+ sampling_batch_size=sampling_batch_size,
+ )
+ time_interval_seconds = repeat(
+ time_interval_seconds,
+ "batch variates -> (sampling_batch_size batch) variates",
+ sampling_batch_size=sampling_batch_size,
+ )
+
+ all_samples = []
+ if use_kv_cache:
+ kv_cache = self.model.allocate_kv_cache(
+ batch_size=inputs.shape[0],
+ num_variates=inputs.shape[1],
+ max_time_steps=inputs.shape[2] + rounded_steps,
+ device=inputs.device,
+ dtype=inputs.dtype,
+ )
+ else:
+ kv_cache = None
+
+ scaling_prefix_length = inputs.shape[-1]
+
+ for _ in range(num_batches):
+ batch_inputs = torch.clone(inputs)
+ batch_input_padding_mask = torch.clone(input_padding_mask)
+ batch_id_mask = torch.clone(id_mask)
+ batch_timestamp_seconds = torch.clone(timestamp_seconds)
+
+ for idx in range(rounded_steps // patch_size):
+ base_distr, loc, scale = self.model(
+ inputs=batch_inputs,
+ input_padding_mask=batch_input_padding_mask,
+ id_mask=batch_id_mask,
+ kv_cache=kv_cache,
+ scaling_prefix_length=scaling_prefix_length,
+ num_exogenous_variables=num_exogenous_variables,
+ )
+ distr = self.create_affine_transformed(base_distr, loc, scale)
+
+ sample = distr.sample()
+ assert sample is not None
+
+ samples = replace_extreme_values(sample[:, :, -patch_size:])
+
+ if future_exogenous_variables is not None:
+ start, stop = idx * patch_size, (idx + 1) * patch_size
+ samples[:, -num_exogenous_variables:] =
future_exogenous_variables[
+ :, :, start:stop
+ ]
+ batch_inputs = torch.cat([batch_inputs, samples], dim=-1)
+ batch_id_mask = torch.cat([batch_id_mask, dummy_id_mask],
dim=-1)
+ batch_input_padding_mask = torch.cat(
+ [batch_input_padding_mask, dummy_padding], dim=-1
+ )
+ for _ in range(patch_size):
+ next_timestamp = (
+ batch_timestamp_seconds[:, :, -1] +
time_interval_seconds
+ )
+ batch_timestamp_seconds = torch.cat(
+ [batch_timestamp_seconds,
next_timestamp.unsqueeze(-1)], dim=-1
+ )
+ all_samples.append(batch_inputs)
+ if kv_cache is not None:
+ kv_cache.reset()
+
+ outputs = torch.cat(all_samples, dim=0)
+ unfolded_outputs = rearrange(
+ outputs,
+ "(samples batch) variates seq_len -> batch variates seq_len
samples",
+ samples=num_samples,
+ ).detach()
+
+ return unfolded_outputs[:, :, start_index:end_index, :]
+
+ @staticmethod
+ def create_affine_transformed(
+ base_distr: Distribution, loc: torch.Tensor, scale: torch.Tensor
+ ) -> Distribution:
+ base_shape = base_distr.mean.shape
+ base_time_dim = base_shape[-1]
+ loc_time_dim = loc.shape[-1]
+
+ if loc_time_dim == 1:
+ return AffineTransformed(base_distr, loc=loc, scale=scale)
+
+ return AffineTransformed(
+ base_distr,
+ loc=loc[:, :, -base_time_dim:],
+ scale=scale[:, :, -base_time_dim:],
+ )
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py
new file mode 100644
index 00000000000..ba26b1edd94
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py
new file mode 100644
index 00000000000..80f6d381ff2
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/attention.py
@@ -0,0 +1,276 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+import logging
+import warnings
+from enum import Enum
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+from einops import rearrange
+from jaxtyping import Bool, Float, Int
+
+from .rope import TimeAwareRotaryEmbedding
+
+if TYPE_CHECKING:
+ from .util import KVCache
+
+log = logging.getLogger(__name__)
+
+try:
+ from xformers.ops import LowerTriangularMask, memory_efficient_attention
+
+ XFORMERS_AVAILABLE = True
+ log.info("xFormers Memory-Efficient Attention available.")
+except ImportError:
+ warnings.warn(
+ "xFormers Memory-Efficient Attention not available. "
+ "Falling back to native PyTorch scaled_dot_product_attention.",
+ ImportWarning,
+ )
+
+ XFORMERS_AVAILABLE = False
+
+from torch.nn.functional import scaled_dot_product_attention
+
+
class AttentionAxis(Enum):
    """Which axis of a (batch, variate, seq_len, embed_dim) input attention runs over."""

    TIME = 1  # attend across time steps (per variate)
    SPACE = 2  # attend across variates (per time step)
+
+
class BaseMultiheadAttention(torch.nn.Module):
    """
    Multi-head self-attention over one axis of a (batch, variate, seq_len, embed_dim)
    tensor.

    Subclasses select the axis via the class attribute ``attention_axis``:
    AttentionAxis.TIME attends across time steps (variates folded into the batch),
    AttentionAxis.SPACE attends across variates (time steps folded into the batch).
    When ``use_memory_efficient_attention`` is set, the xFormers
    memory_efficient_attention kernel is used; otherwise PyTorch's
    scaled_dot_product_attention.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float,
        rotary_emb: Optional[TimeAwareRotaryEmbedding],
        use_memory_efficient_attention: bool,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert (
            embed_dim % num_heads == 0
        ), "Embedding dimension must be divisible by number of heads."
        self.head_dim = embed_dim // num_heads
        self.rotary_emb = rotary_emb

        # Fused projection: a single matmul yields Q, K and V.
        self.wQKV = torch.nn.Linear(embed_dim, embed_dim * 3)
        self.dropout = dropout
        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.wO = torch.nn.Linear(embed_dim, embed_dim)

        assert not (
            not XFORMERS_AVAILABLE and self.use_memory_efficient_attention
        ), "XFORMERS_AVAILABLE is False, so use_memory_efficient_attention must be False"

        # attention_axis is a class attribute supplied by the concrete subclass.
        if not hasattr(self, "attention_axis") or self.attention_axis not in (
            AttentionAxis.TIME,
            AttentionAxis.SPACE,
        ):
            raise ValueError(
                "Child class must define attention_axis as AttentionAxis.TIME or AttentionAxis.SPACE."
            )

    def rearrange_inputs(
        self, inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"]
    ) -> Float[torch.Tensor, "... embed_dim"]:
        """Fold the non-attended axis into the batch dimension."""
        pattern = (
            "batch variate seq_len embed_dim -> (batch variate) seq_len embed_dim"
            if self.attention_axis == AttentionAxis.TIME
            else "batch variate seq_len embed_dim -> (batch seq_len) variate embed_dim"
        )
        return rearrange(inputs, pattern)

    def get_qkv(self, inputs: torch.Tensor) -> tuple[torch.Tensor, ...]:
        """Project inputs to Q, K, V and split heads.

        The xFormers kernel expects (batch, seq, n_heads, head_dim) while the
        PyTorch SDPA path expects (batch, n_heads, seq, head_dim), hence the
        four layout patterns.
        """
        if (
            self.attention_axis == AttentionAxis.TIME
            and self.use_memory_efficient_attention
        ):
            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate seq_len n_heads head_dim"
        elif (
            self.attention_axis == AttentionAxis.TIME
            and not self.use_memory_efficient_attention
        ):
            pattern = "batch_X_variate seq_len (qkv head_dim n_heads) -> qkv batch_X_variate n_heads seq_len head_dim"
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and self.use_memory_efficient_attention
        ):
            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len variate n_heads head_dim"
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and not self.use_memory_efficient_attention
        ):
            pattern = "batch_X_seq_len variate (qkv head_dim n_heads) -> qkv batch_X_seq_len n_heads variate head_dim"

        qkv = self.wQKV(inputs.contiguous())
        return rearrange(
            qkv, pattern, qkv=3, head_dim=self.head_dim, n_heads=self.num_heads
        ).unbind(dim=0)

    def positional_embedding(self, q, k, v, kv_cache, layer_idx):
        """Apply rotary position embeddings (time axis only) and update the KV cache.

        Returns the possibly-rotated q, the full cached k/v, and the position
        offset of this step within the cached sequence.
        """
        seq_pos_offset = 0
        if self.rotary_emb is not None and self.attention_axis == AttentionAxis.TIME:
            if kv_cache is not None:
                # Continue the rotary phase from where the cache left off.
                seq_pos_offset = kv_cache.seq_len(layer_idx)
            q, k = self.rotary_emb.rotate_queries_and_keys(
                q, k, seq_pos_offset=seq_pos_offset
            )

        if kv_cache is not None and self.attention_axis == AttentionAxis.TIME:
            kv_cache.append(layer_idx, (k, v))
            k, v = kv_cache[layer_idx]

        # Kernels require contiguous tensors; cached k/v may differ in dtype.
        q = q.contiguous()
        k = k.contiguous().to(q.dtype)
        v = v.contiguous().to(q.dtype)

        return q, k, v, seq_pos_offset

    def rearrange_output(
        self, output: torch.Tensor, batch: int, variate: int, seq_len: int
    ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]:
        """Undo the folding from rearrange_inputs/get_qkv and merge heads back."""
        if (
            self.attention_axis == AttentionAxis.TIME
            and self.use_memory_efficient_attention
        ):
            pattern = "(batch variate) seq_len n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
        elif (
            self.attention_axis == AttentionAxis.TIME
            and not self.use_memory_efficient_attention
        ):
            pattern = "(batch variate) n_heads seq_len head_dim -> batch variate seq_len (n_heads head_dim)"
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and self.use_memory_efficient_attention
        ):
            pattern = "(batch seq_len) variate n_heads head_dim -> batch variate seq_len (n_heads head_dim)"
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and not self.use_memory_efficient_attention
        ):
            pattern = "(batch seq_len) n_heads variate head_dim -> batch variate seq_len (n_heads head_dim)"

        return rearrange(output, pattern, batch=batch, variate=variate, seq_len=seq_len)

    def run_attention(
        self, attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate
    ):
        """Dispatch to the appropriate attention kernel.

        Tensor masks are sliced to the query window implied by the KV-cache
        offset and to the current key/value length.
        """
        q_dim_start, q_dim_end = seq_pos_offset, seq_pos_offset + seq_len
        # Sequence length sits at dim 1 in the xFormers layout, dim 2 for SDPA.
        kv_dim_start, kv_dim_end = 0, (
            v.shape[1] if self.use_memory_efficient_attention else v.shape[2]
        )
        if (
            self.attention_axis == AttentionAxis.TIME
            and self.use_memory_efficient_attention
        ):
            # Without an explicit mask: causal mask on the first (prefill) step,
            # no mask during single-step decode (queries may attend to all cache).
            attention_mask = (
                attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else LowerTriangularMask() if seq_pos_offset == 0 else None
            )
            return memory_efficient_attention(
                q, k, v, attn_bias=attention_mask, p=dropout
            )
        elif (
            self.attention_axis == AttentionAxis.TIME
            and not self.use_memory_efficient_attention
        ):
            attention_mask = (
                attention_mask[..., q_dim_start:q_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else None
            )
            return scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attention_mask,
                dropout_p=dropout,
                is_causal=(attention_mask is None and seq_pos_offset == 0),
            )
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and self.use_memory_efficient_attention
        ):
            # Space-wise attention is bidirectional: no causal mask.
            attention_mask = (
                attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else None
            )
            return memory_efficient_attention(
                q, k, v, attn_bias=attention_mask, p=dropout
            )
        elif (
            self.attention_axis == AttentionAxis.SPACE
            and not self.use_memory_efficient_attention
        ):
            attention_mask = (
                attention_mask[..., kv_dim_start:kv_dim_end, kv_dim_start:kv_dim_end]
                if torch.is_tensor(attention_mask)
                else None
            )
            return scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask, dropout_p=dropout, is_causal=False
            )

    def forward(
        self,
        layer_idx: int,
        inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"],
        attention_mask: Optional[
            Union[
                Bool[torch.Tensor, "batch_X_variate n_heads seq_len seq_len"],
                Bool[torch.Tensor, "batch_X_seq_len n_heads variate variate"],
            ]
        ] = None,
        kv_cache: Optional["KVCache"] = None,
    ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]:
        """Run one attention layer over the configured axis.

        Dropout is applied only in training mode.
        """
        batch_size, variate, seq_len, _ = inputs.shape
        dropout = self.dropout if self.training else 0.0

        rearranged_inputs = self.rearrange_inputs(inputs)
        q, k, v = self.get_qkv(rearranged_inputs)

        q, k, v, seq_pos_offset = self.positional_embedding(
            q, k, v, kv_cache, layer_idx
        )

        output = self.run_attention(
            attention_mask, q, k, v, seq_pos_offset, dropout, seq_len, variate
        )

        output = self.rearrange_output(output, batch_size, variate, seq_len)
        return self.wO(output)
+
+
class TimeWiseMultiheadAttention(BaseMultiheadAttention):
    # Attends across time steps; each variate is handled independently.
    attention_axis = AttentionAxis.TIME


class SpaceWiseMultiheadAttention(BaseMultiheadAttention):
    # Attends across variates; each time step is handled independently.
    attention_axis = AttentionAxis.SPACE


# Union alias for annotations (PEP 604 `|` between classes yields typing.Union).
MultiHeadAttention = TimeWiseMultiheadAttention | SpaceWiseMultiheadAttention
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py
new file mode 100644
index 00000000000..84fa537e3fc
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/backbone.py
@@ -0,0 +1,258 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from math import ceil
+from typing import NamedTuple, Optional, Type, cast
+
+import torch
+from einops import rearrange, repeat
+from jaxtyping import Bool, Float, Int
+
+from .distribution import DISTRIBUTION_CLASSES_LOOKUP, DistributionOutput
+from .embedding import PatchEmbedding
+from .fusion import Fusion
+from .scaler import scaler_types
+from .transformer import Transformer
+from .util import KVCache
+
+
class TotoOutput(NamedTuple):
    """
    Output of the Toto model. Contains the output distribution, the location parameters,
    and the scale parameters.
    """

    # Predictive distribution produced by the output head (in scaled space).
    distribution: torch.distributions.Distribution
    # Location returned by the scaler; used to map samples back to data space.
    loc: Float[torch.Tensor, "batch variate"]
    # Scale returned by the scaler; used together with `loc` to de-normalize.
    scale: Float[torch.Tensor, "batch variate"]
+
+
class TotoBackbone(torch.nn.Module):
    """
    Toto (Timeseries-Optimized Transformer for Observability) is a transformer-based model
    for multivariate time series forecasting.

    Pipeline: causal scaling -> patch embedding -> time/space transformer ->
    unembedding back to per-time-step features -> parametric output distribution.
    """

    def __init__(
        self,
        patch_size: int,
        stride: int,
        embed_dim: int,
        num_layers: int,
        num_heads: int,
        mlp_hidden_dim: int,
        dropout: float,
        spacewise_every_n_layers: int,
        scaler_cls: str,
        output_distribution_classes: list[str],
        spacewise_first: bool = True,
        output_distribution_kwargs: dict | None = None,
        use_memory_efficient_attention: bool = True,
        stabilize_with_global: bool = True,
        scale_factor_exponent: float = 10.0,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        # Variate-label fusion stays disabled until enable_variate_labels().
        self.fusion: Optional[Fusion] = None
        self.num_prepended_tokens: int = 0
        self.target_variate_label: Optional[torch.nn.Parameter] = None
        self.exogenous_variate_label: Optional[torch.nn.Parameter] = None

        # NOTE(review): the "<class '...'>" key appears to be a serialized
        # str(type) coming from checkpoint configs; the short key is an alias.
        if scaler_cls in (
            "<class 'model.scaler.CausalPatchStdMeanScaler'>",
            "per_variate_causal_patch",
        ):
            # The causal patch scaler takes extra stabilization arguments.
            self.scaler = scaler_types[scaler_cls](
                patch_size=patch_size,
                stabilize_with_global=stabilize_with_global,
                scale_factor_exponent=scale_factor_exponent,
            )
        else:
            self.scaler = scaler_types[scaler_cls]()

        self.patch_embed = PatchEmbedding(patch_size, stride, embed_dim)
        self.dropout = dropout
        self.num_layers = num_layers
        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.transformer = Transformer(
            embed_dim=embed_dim,
            num_heads=num_heads,
            num_layers=self.num_layers,
            mlp_hidden_dim=mlp_hidden_dim,
            dropout=dropout,
            spacewise_every_n_layers=spacewise_every_n_layers,
            spacewise_first=spacewise_first,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            fusion=self.fusion,
        )
        # Expands each patch token back into patch_size per-time-step embeddings.
        self.unembed = torch.nn.Linear(embed_dim, embed_dim * patch_size)

        output_distribution_classes_ = [
            DISTRIBUTION_CLASSES_LOOKUP[c] for c in output_distribution_classes
        ]
        # Only the first configured distribution class is instantiated.
        self.output_distribution = output_distribution_classes_[0](
            embed_dim, **(output_distribution_kwargs or {})
        )

    def allocate_kv_cache(
        self,
        batch_size: int,
        num_variates: int,
        max_time_steps: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> KVCache:
        """Pre-allocate a KV cache sized for `max_time_steps` of decoding."""
        return KVCache(
            batch_size=batch_size,
            num_variates=num_variates,
            transformer_layers=list(self.transformer.layers),
            num_layers=self.num_layers,
            embed_dim=self.embed_dim,
            num_heads=cast(int, self.transformer.layers[0].num_heads),
            # One cache slot per patch token.
            max_seq_len=ceil(max_time_steps / self.patch_embed.stride),
            device=device,
            dtype=dtype,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
        )

    def backbone(
        self,
        inputs: Float[torch.Tensor, "batch variate time_steps"],
        input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"],
        id_mask: Float[torch.Tensor, "batch #variate time_steps"],
        kv_cache: Optional[KVCache] = None,
        scaling_prefix_length: Optional[int] = None,
        num_exogenous_variables: int = 0,
    ) -> tuple[
        Float[torch.Tensor, "batch variates time_steps embed_dim"],
        Float[torch.Tensor, "batch variates time_steps"],
        Float[torch.Tensor, "batch variates time_steps"],
    ]:
        """Scale, embed and transform the inputs.

        Returns per-time-step embeddings plus the scaler's loc/scale so the
        caller can map the output distribution back to data space.
        """
        scaled_inputs, loc, scale = self.scaler(
            inputs,
            weights=torch.ones_like(inputs, device=inputs.device),
            padding_mask=input_padding_mask,
            prefix_length=scaling_prefix_length,
        )

        if kv_cache is not None:
            # When decoding incrementally, drop the prefix already represented
            # in the cache and feed only the newest stride of time steps.
            kv_cache_len_tensor = kv_cache.current_len(0)
            kv_cache_len = (
                int(kv_cache_len_tensor)
                if isinstance(kv_cache_len_tensor, torch.Tensor)
                else kv_cache_len_tensor
            )
            prefix_len = max(
                0, self.patch_embed.stride * (kv_cache_len - self.num_prepended_tokens)
            )

            scaled_inputs = scaled_inputs[:, :, prefix_len:]

            assert (prefix_len == 0) or (
                scaled_inputs.shape[-1] == self.patch_embed.stride
            ), "Must decode one step at a time."

            input_padding_mask = input_padding_mask[:, :, prefix_len:]
            id_mask = id_mask[:, :, prefix_len:]

        embeddings, reduced_id_mask = self.patch_embed(scaled_inputs, id_mask)

        variate_label_embeds = self.build_variate_label_embeds(
            num_exogenous_variables, embeddings
        )

        original_seq_len = embeddings.shape[2]
        transformed = self.transformer(
            embeddings,
            reduced_id_mask,
            kv_cache,
            variate_label_embeds=variate_label_embeds,
        )
        # Strip any label tokens the fusion step prepended to the sequence.
        added_tokens = transformed.shape[2] - original_seq_len
        if added_tokens > 0:
            transformed = transformed[:, :, added_tokens:]

        flattened: Float[torch.Tensor, "batch variates new_seq_len embed_dim"] = (
            rearrange(
                self.unembed(transformed),
                "batch variates seq_len (patch_size embed_dim) -> batch variates (seq_len patch_size) embed_dim",
                embed_dim=self.embed_dim,
            )
        )
        return flattened, loc, scale

    def forward(
        self,
        inputs: Float[torch.Tensor, "batch variate time_steps"],
        input_padding_mask: Bool[torch.Tensor, "batch variate time_steps"],
        id_mask: Float[torch.Tensor, "batch #variate time_steps"],
        kv_cache: Optional[KVCache] = None,
        scaling_prefix_length: Optional[int] = None,
        num_exogenous_variables: int = 0,
    ) -> TotoOutput:
        """Full forward pass: backbone followed by the output distribution head."""
        flattened, loc, scale = self.backbone(
            inputs,
            input_padding_mask,
            id_mask,
            kv_cache,
            scaling_prefix_length,
            num_exogenous_variables,
        )

        return TotoOutput(self.output_distribution(flattened), loc, scale)

    @property
    def device(self):
        # Device of the first parameter; assumes all parameters share a device.
        return next(self.parameters()).device

    def enable_variate_labels(self) -> None:
        """Turn on variate-label fusion: one prepended label token per variate,
        backed by two learned embeddings (target vs. exogenous)."""
        self.fusion = Fusion()
        self.num_prepended_tokens = 1
        self.target_variate_label = torch.nn.Parameter(torch.randn(self.embed_dim))
        self.exogenous_variate_label = torch.nn.Parameter(torch.randn(self.embed_dim))
        # Re-attach: the transformer captured self.fusion (None) at __init__ time.
        if hasattr(self, "transformer") and self.transformer is not None:
            self.transformer.fusion = self.fusion

    def build_variate_label_embeds(
        self,
        num_exogenous_variables: int,
        embeddings: Float[torch.Tensor, "batch variate seq_len embed_dim"],
    ) -> Optional[Float[torch.Tensor, "batch variate 1 embed_dim"]]:
        """Build one label token per variate, or None when fusion is disabled.

        The trailing `num_exogenous_variables` variates get the exogenous
        label; all others get the target label.
        """
        if self.fusion is None:
            return None

        assert self.target_variate_label is not None
        assert self.exogenous_variate_label is not None

        batch_size, num_variates, _, _ = embeddings.shape

        target_variate_label = repeat(
            self.target_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates
        ).to(device=embeddings.device, dtype=embeddings.dtype)
        exogenous_variate_label = repeat(
            self.exogenous_variate_label, "d -> b v 1 d", b=batch_size, v=num_variates
        ).to(device=embeddings.device, dtype=embeddings.dtype)
        exog_mask = torch.zeros(
            1, num_variates, 1, 1, dtype=torch.bool, device=embeddings.device
        )
        if num_exogenous_variables > 0:
            # Convention: exogenous variates occupy the last positions.
            exog_mask[:, -num_exogenous_variables:] = True
        return torch.where(exog_mask, exogenous_variate_label, target_variate_label)
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py
new file mode 100644
index 00000000000..f34bd4afdf0
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/distribution.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from abc import ABC
+
+import torch
+import torch.nn.functional as F
+from torch.distributions import TransformedDistribution
+from torch.distributions.transforms import AffineTransform
+
+
class AffineTransformed(TransformedDistribution):
    """
    TransformedDistribution specialized to one affine map ``y = loc + scale * x``.
    Drop-in replacement for gluonts.torch.distributions.AffineTransformed,
    exposing the same interface: mean, variance, sample(), log_prob().
    """

    def __init__(self, base_distribution, loc=0.0, scale=1.0):
        affine = AffineTransform(loc=loc, scale=scale)
        super().__init__(base_distribution, affine)

    @property
    def mean(self):
        """E[loc + scale * X] = loc + scale * E[X]."""
        transform = self.transforms[0]
        return transform.loc + transform.scale * self.base_dist.mean

    # Deliberately no sample() override: TransformedDistribution.sample()
    # delegates to base_dist.sample() (not rsample), which also works for
    # non-reparameterizable bases such as MixtureSameFamily.
+
+
class DistributionOutput(ABC, torch.nn.Module):
    # Marker base class for heads mapping embeddings to an output distribution.
    pass
+
+
class StudentTOutput(DistributionOutput):
    """Head projecting embeddings to the parameters of a Student's t distribution."""

    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # One scalar head each for degrees of freedom, location and scale.
        self.df = torch.nn.Linear(embed_dim, 1)
        self.loc_proj = torch.nn.Linear(embed_dim, 1)
        self.scale_proj = torch.nn.Linear(embed_dim, 1)

    def forward(self, inputs, loc=None, scale=None):
        """Build a StudentT over the embedding dim; optionally affine-transform it."""
        tiny = torch.finfo(inputs.dtype).eps
        # softplus keeps the parameter positive; +2 guarantees finite variance.
        degrees = F.softplus(self.df(inputs)).clamp_min(tiny).squeeze(-1) + 2.0
        mu = self.loc_proj(inputs).squeeze(-1)
        sigma = F.softplus(self.scale_proj(inputs)).clamp_min(tiny).squeeze(-1)

        dist = torch.distributions.StudentT(degrees, mu, sigma, validate_args=False)

        if loc is None or scale is None:
            return dist
        return AffineTransformed(dist, loc=loc, scale=scale)
+
+
class MixtureOfStudentTsOutput(DistributionOutput):
    """Head projecting embeddings to a k-component mixture of Student's t distributions."""

    def __init__(self, embed_dim, k_components):
        super().__init__()
        self.embed_dim = embed_dim
        self.k_components = k_components

        # One k-wide head per mixture parameter.
        self.df = torch.nn.Linear(embed_dim, k_components)
        self.loc_proj = torch.nn.Linear(embed_dim, k_components)
        self.scale_proj = torch.nn.Linear(embed_dim, k_components)
        self.mixture_weights = torch.nn.Linear(embed_dim, k_components)

    def forward(self, inputs, loc=None, scale=None):
        """Return a MixtureSameFamily of StudentT components.

        `loc` and `scale` are accepted for interface parity with other heads
        but are not used here (visible from the body below).
        """
        tiny = torch.finfo(inputs.dtype).eps
        degrees = F.softplus(self.df(inputs)).clamp_min(tiny) + 2.0
        component_loc = self.loc_proj(inputs)
        component_scale = F.softplus(self.scale_proj(inputs)).clamp_min(tiny)

        components = torch.distributions.StudentT(
            degrees, component_loc, component_scale, validate_args=False
        )
        weights = torch.distributions.Categorical(
            probs=F.softmax(self.mixture_weights(inputs), dim=-1)
        )
        return torch.distributions.MixtureSameFamily(weights, components)
+
+
# Maps config identifiers to DistributionOutput classes. The "<class '...'>"
# keys presumably match serialized str(type) values found in existing model
# configs — verify against the checkpoints; the short keys are aliases.
DISTRIBUTION_CLASSES_LOOKUP = {
    "<class 'model.distribution.StudentTOutput'>": StudentTOutput,
    "<class 'model.distribution.MixtureOfStudentTsOutput'>": MixtureOfStudentTsOutput,
    # Short-form aliases for convenience
    "student_t": StudentTOutput,
    "student_t_mixture": MixtureOfStudentTsOutput,
}
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py
new file mode 100644
index 00000000000..fc7eadac9af
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/embedding.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from typing import Optional
+
+import torch
+from jaxtyping import Float, Int, Num
+
+
def patchify_id_mask(id_mask, patch_size):
    """Collapse an id mask into non-overlapping patches of ``patch_size`` steps.

    Every time step inside a patch must carry the same id; the per-patch id
    is returned. Raises AssertionError when a patch mixes ids.
    """
    windows = id_mask.unfold(dimension=-1, size=patch_size, step=patch_size)
    lo = windows.min(-1).values
    hi = windows.max(-1).values
    assert torch.eq(lo, hi).all(), "Patches cannot span multiple datasets"
    return lo
+
+
class PatchEmbedding(torch.nn.Module):
    """
    Multivariate time series patch embedding.
    Patchifies each variate separately.
    """

    def __init__(self, patch_size: int, stride: int, embed_dim: int):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.stride = stride
        # Linear map from a raw patch of values into the embedding space.
        self.projection = torch.nn.Linear(self.patch_size, self.embed_dim)

    def _patchify(self, x):
        # Windows of length patch_size taken every `stride` steps along time.
        return x.unfold(dimension=-1, size=self.patch_size, step=self.stride)

    def forward(self, x, id_mask):
        """Embed patches of x; reduce id_mask to one id per patch.

        Returns (embeddings of shape [batch, variate, seq_len, embed_dim],
        per-patch id mask of shape [batch, variate, seq_len]).
        """
        assert (
            x.shape[-1] % self.patch_size == 0
        ), f"Series length ({x.shape=}) must be divisible by ({self.patch_size=})"

        patches = self._patchify(x)
        id_patches = self._patchify(id_mask)

        per_patch_ids = id_patches.min(-1).values
        assert torch.eq(
            per_patch_ids, id_patches.max(-1).values
        ).all(), "Patches cannot span multiple datasets"

        return self.projection(patches), per_patch_ids
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py
new file mode 100644
index 00000000000..024a8bed727
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/feed_forward.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+import torch
+import torch.nn.functional as F
+
+
class SwiGLU(torch.nn.Module):
    """
    SwiGLU activation (https://arxiv.org/abs/2002.05202).
    NOTE: x should be 2x the size you want
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The gate is taken from the FIRST half of the last dim — unusual
        # ordering, kept to match xFormers' convention.
        gate, value = torch.chunk(x, 2, dim=-1)
        return value * F.silu(gate)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py
new file mode 100644
index 00000000000..cfe364ac91e
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/fusion.py
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+from jaxtyping import Float
+
+
class Fusion(torch.nn.Module):
    """
    Prepends variate label embeddings to the input embeddings along the sequence dimension.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, embeddings, variate_label_embeds=None):
        """Return embeddings with L2-normalized label tokens prepended on the
        sequence axis (dim 2); a no-op when no labels are given."""
        if variate_label_embeds is None:
            return embeddings

        labels = F.normalize(variate_label_embeds, p=2, dim=-1)
        labels = labels.to(
            dtype=embeddings.dtype, device=embeddings.device, non_blocking=True
        )
        return torch.cat([labels, embeddings], dim=2)
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py
new file mode 100644
index 00000000000..96e62517077
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/rope.py
@@ -0,0 +1,94 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+from typing import Optional
+
+import torch
+from einops import rearrange
+from jaxtyping import Int
+from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb
+from rotary_embedding_torch.rotary_embedding_torch import default
+
+
def exists(val):
    """Return True when *val* is not None (rotary-embedding helper idiom)."""
    return not (val is None)
+
+
class TimeAwareRotaryEmbedding(RotaryEmbedding):
    """
    A variant of the rotary position embedding that (optionally) uses the time index
    to compute the sinusoidal and cosine embeddings. Useful for time series data.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # If the parent stored `freqs` as a Parameter, remove it and register as a buffer
        # so the frequencies are not trainable and are excluded from the state
        # dict (persistent=False); checkpoints therefore never restore them.
        if hasattr(self, "freqs") and isinstance(self.freqs, torch.nn.Parameter):
            freqs_data = self.freqs.data
            self._parameters.pop("freqs")
            self.register_buffer("freqs", freqs_data, persistent=False)

    def rotate_queries_and_keys(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_dim: Optional[int] = None,
        seq_pos: Optional[Int[torch.Tensor, "... seq_len"]] = None,
        seq_pos_offset: int = 0,
    ):
        """Apply xpos-style rotary embeddings to queries and keys.

        Unlike the parent implementation, the rotation angles can be driven by
        an explicit per-position index (``seq_pos``, e.g. real time indices)
        plus an integer ``seq_pos_offset``, instead of the implicit
        0..seq_len-1 positions.
        """
        if seq_dim is None:
            seq_dim = self.default_seq_dim

        # This decoupled query/key rotation relies on xpos scaling.
        assert self.use_xpos
        device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim]

        # Fall back to default integer positions when no explicit index given.
        seq = default(seq_pos, self.get_seq_pos(seq_len, dtype=dtype, device=device))
        seq = seq + seq_pos_offset

        freqs = self.forward(seq)

        scale = self.get_scale(seq).to(dtype)

        # When heads precede the sequence axis (seq_dim == -3), broadcast the
        # per-position freqs/scales across the head dimension.
        if seq_dim == -3:
            num_heads = q.shape[-2]
            freqs = freqs.unsqueeze(1).expand(-1, num_heads, -1)
            scale = scale.unsqueeze(1).expand(-1, num_heads, -1)

        # xpos: queries scaled by `scale`, keys by its inverse.
        rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim)
        rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim)

        rotated_q = rotated_q.type(q.dtype)
        rotated_k = rotated_k.type(k.dtype)

        return rotated_q, rotated_k

    def get_scale(
        self,
        t: torch.Tensor,
    ):
        """Compute xpos decay scales for (possibly non-contiguous) positions ``t``."""
        assert self.use_xpos

        # NOTE(review): `//` binds tighter than `-`, so this computes
        # t - (max(t) // 2), i.e. positions shifted by half the maximum index.
        # Presumably intended (centering); confirm against upstream DataDog/toto.
        power = (t - t.max(-1).values.unsqueeze(-1) // 2) / self.scale_base

        scale = self.scale ** rearrange(power, "... n -> ... n 1")
        # Duplicate to cover the paired (rotated) feature halves.
        scale = torch.cat((scale, scale), dim=-1)

        return scale
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py
new file mode 100644
index 00000000000..e640e3ef3a2
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/scaler.py
@@ -0,0 +1,328 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+import warnings
+from typing import Tuple
+
+import torch
+from einops import reduce, repeat
+
+
class Scaler(torch.nn.Module):
    """
    Minimal base class replacing gluonts.torch.scaler.Scaler.
    Provides a __call__ interface for scaling data.
    """

    # Subclasses implement __call__(data, padding_mask, weights, prefix_length)
    # returning (scaled_data, loc, scale); no shared state lives here.
    pass
+
+
+class StdMeanScaler(Scaler):
+ """
+ Scales data to have zero mean and unit variance along a given dimension.
+ """
+
+ def __init__(
+ self,
+ dim: int = -1,
+ keepdim: bool = True,
+ minimum_scale: float = 1e-3,
+ ) -> None:
+ super().__init__()
+ self.dim = dim
+ self.keepdim = keepdim
+ self.minimum_scale = minimum_scale
+
+ def __call__(
+ self,
+ data: torch.Tensor,
+ padding_mask: torch.Tensor,
+ weights: torch.Tensor,
+ prefix_length: int | None = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ assert data.shape == weights.shape, "data and weights must have same
shape"
+ with torch.no_grad():
+ if prefix_length is not None:
+ prefix_mask = torch.zeros_like(weights)
+ prefix_mask[..., :prefix_length] = 1.0
+ weights = weights * prefix_mask
+
+ weights = weights * padding_mask
+
+ try:
+ high_precision_data = data.to(torch.float64)
+ except TypeError:
+ warnings.warn(
+ f"Float64 is not supported by device {data.device}. "
+ "Using float32 instead for accumulating denominator in
input scaler. "
+ "This may lead to overflow issues if the data contains
extreme values.",
+ RuntimeWarning,
+ )
+ high_precision_data = data.to(torch.float32)
+
+ denominator = (
+ weights.sum(self.dim, keepdim=self.keepdim)
+ .clamp_min(1.0)
+ .to(high_precision_data.dtype)
+ )
+ means = (high_precision_data * weights).sum(
+ self.dim, keepdim=self.keepdim
+ ) / denominator
+ means = torch.nan_to_num(means)
+
+ variance = (((high_precision_data - means) * weights) ** 2).sum(
+ self.dim, keepdim=self.keepdim
+ ) / denominator
+ scale = torch.sqrt(variance + self.minimum_scale).to(data.dtype)
+ loc = means.to(data.dtype)
+
+ return (data - loc) / scale, loc, scale
+
+
def compute_causal_statistics(
    data: torch.Tensor,
    weights: torch.Tensor,
    padding_mask: torch.Tensor,
    dim: int,
    minimum_scale: float,
    use_bessel_correction: bool = True,
    stabilize_with_global: bool = False,
    scale_factor_exponent: float = 10.0,
    prefix_length: int | None = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute running (causal) mean and scale along the last axis.

    For every position t the returned statistics use only positions <= t
    (weighted by ``weights`` and ``padding_mask``), via a cumulative-sum
    formulation of Welford's online variance algorithm.

    Args:
        data: values of shape [..., time_steps].
        weights: per-element weights, same shape as ``data``.
        padding_mask: 1 for valid entries, 0 for padding.
        dim: reduction axis; only -1 is supported.
        minimum_scale: variance floor added before the square root.
        use_bessel_correction: divide M2 by (n - 1) instead of n.
        stabilize_with_global: clamp the causal scale to within a factor of
            10**scale_factor_exponent of the global (whole-series) scale.
        scale_factor_exponent: half-width (in decades) of the clamping band.
        prefix_length: when stabilizing, restrict the *global* statistics to
            the first ``prefix_length`` steps.

    Returns:
        (causal_means, causal_scale), shaped like ``data`` and cast back to
        its dtype.
    """
    assert dim == -1, "compute_causal_statistics only supports dim=-1 (last dimension)"

    with torch.no_grad():
        # Padded positions contribute nothing to the running statistics.
        weights = weights * padding_mask

        # Accumulate in float64 where supported to limit round-off in cumsums.
        try:
            high_precision_data = data.to(torch.float64)
            high_precision_weights = weights.to(torch.float64)
        except TypeError:
            warnings.warn(
                f"Float64 is not supported by device {data.device}. "
                "Using float32 instead for causal scaler calculations.",
                RuntimeWarning,
            )
            high_precision_data = data.to(torch.float32)
            high_precision_weights = weights.to(torch.float32)

        # cumsum lacks a deterministic CUDA kernel: temporarily allow
        # non-deterministic algorithms and restore the setting in `finally`.
        prev_deterministic = torch.are_deterministic_algorithms_enabled()
        if prev_deterministic and data.device.type == "cuda":
            torch.use_deterministic_algorithms(False)

        try:
            weighted_data = high_precision_weights * high_precision_data

            # Running weight totals and running weighted sums.
            cum_weights = torch.cumsum(high_precision_weights, dim=dim)
            cum_values = torch.cumsum(weighted_data, dim=dim)

            denominator = cum_weights.clamp_min(1.0)
            causal_means = cum_values / denominator

            # Mean of everything strictly *before* each position (0 at t=0).
            shifted_means = torch.zeros_like(causal_means)
            shifted_means[..., 1:] = causal_means[..., :-1]

            # Welford-style update accumulated with a cumsum:
            # M2_t = sum_i w_i * (x_i - mean_{i-1}) * (x_i - mean_i).
            delta = high_precision_data - shifted_means
            increment = (
                delta * (high_precision_data - causal_means) * high_precision_weights
            )
            m_2 = torch.cumsum(increment, dim=dim)

            if use_bessel_correction:
                causal_variance = m_2 / torch.clamp(denominator - 1.0, min=1.0)
            else:
                causal_variance = m_2 / denominator

            causal_scale = torch.sqrt(causal_variance + minimum_scale)

            if stabilize_with_global:
                # Optionally compute the global statistics from a prefix only.
                if prefix_length is not None:
                    prefix_mask = torch.zeros_like(weights)
                    prefix_mask[..., :prefix_length] = 1.0
                    weighted_data = weighted_data * prefix_mask
                    weights = weights * prefix_mask
                    padding_mask = padding_mask * prefix_mask

                # Allowed band: global_scale * 10**(+/- scale_factor_exponent).
                scale_factor_min = 10.0 ** (-scale_factor_exponent)
                scale_factor_max = 10.0**scale_factor_exponent

                global_denominator = (
                    (weights * padding_mask).sum(dim, keepdim=True).clamp_min(1.0)
                )
                global_means = (weighted_data).sum(
                    dim, keepdim=True
                ) / global_denominator
                global_means = torch.nan_to_num(global_means)

                global_variance = (
                    ((high_precision_data - global_means) * weights * padding_mask) ** 2
                ).sum(dim, keepdim=True) / global_denominator
                global_scale = torch.sqrt(global_variance + minimum_scale)

                # Clamp the causal scale into the band around the global
                # scale, never dropping below minimum_scale.
                expanded_global_scale = global_scale.expand_as(causal_scale)
                min_allowed_scale = expanded_global_scale * scale_factor_min
                max_allowed_scale = expanded_global_scale * scale_factor_max

                causal_scale = torch.clamp(
                    causal_scale,
                    min=torch.max(
                        torch.tensor(minimum_scale, device=causal_scale.device),
                        min_allowed_scale,
                    ),
                    max=max_allowed_scale,
                )

            causal_means = causal_means.to(data.dtype)
            causal_scale = causal_scale.to(data.dtype)

        finally:
            # Restore the caller's determinism setting.
            if prev_deterministic and data.device.type == "cuda":
                torch.use_deterministic_algorithms(True)

    return causal_means, causal_scale
+
+
class CausalStdMeanScaler(Scaler):
    """Causally standardize data along time: each position is normalized with
    the running mean/std of everything up to and including it, as produced by
    :func:`compute_causal_statistics`.
    """

    def __init__(
        self,
        dim: int = -1,
        minimum_scale: float = 0.1,
        use_bessel_correction: bool = True,
        stabilize_with_global: bool = False,
        scale_factor_exponent: float = 10.0,
    ) -> None:
        super().__init__()
        assert dim == -1, "CausalStdMeanScaler only supports dim=-1 (last dimension)"
        self.dim = dim
        self.minimum_scale = minimum_scale
        self.use_bessel_correction = use_bessel_correction
        self.stabilize_with_global = stabilize_with_global
        self.scale_factor_exponent = scale_factor_exponent

    def __call__(
        self,
        data: torch.Tensor,
        padding_mask: torch.Tensor,
        weights: torch.Tensor,
        prefix_length: int | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert data.shape == weights.shape, "data and weights must have same shape"
        assert (
            len(data.shape) == 3
        ), "Input data must have shape [batch, variates, time_steps]"

        # Delegate the heavy lifting to the shared cumulative-statistics helper.
        loc, scale = compute_causal_statistics(
            data,
            weights,
            padding_mask,
            self.dim,
            self.minimum_scale,
            self.use_bessel_correction,
            self.stabilize_with_global,
            self.scale_factor_exponent,
            prefix_length,
        )

        return (data - loc) / scale, loc, scale
+
+
class CausalPatchStdMeanScaler(Scaler):
    """Causal scaling at patch granularity.

    Every non-overlapping patch of ``patch_size`` steps is normalized with the
    causal statistics taken at that patch's *last* time step, so all positions
    inside a patch share a single loc/scale pair.
    """

    def __init__(
        self,
        dim: int = -1,
        patch_size: int = 32,
        minimum_scale: float = 0.1,
        use_bessel_correction: bool = True,
        stabilize_with_global: bool = False,
        scale_factor_exponent: float = 10.0,
    ) -> None:
        super().__init__()
        assert (
            dim == -1
        ), "CausalPatchStdMeanScaler only supports dim=-1 (last dimension)"
        self.dim = dim
        self.patch_size = patch_size
        self.minimum_scale = minimum_scale
        self.use_bessel_correction = use_bessel_correction
        self.stabilize_with_global = stabilize_with_global
        self.scale_factor_exponent = scale_factor_exponent

    def __call__(
        self,
        data: torch.Tensor,
        padding_mask: torch.Tensor,
        weights: torch.Tensor,
        prefix_length: int | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        assert data.shape == weights.shape, "data and weights must have same shape"
        assert (
            len(data.shape) == 3
        ), "Input data must have shape [batch, variates, time_steps]"

        with torch.no_grad():
            time_steps = data.shape[-1]
            assert (
                time_steps % self.patch_size == 0
            ), f"Time steps ({time_steps}) must be divisible by patch size ({self.patch_size})"

            causal_means, causal_scale = compute_causal_statistics(
                data,
                weights,
                padding_mask,
                -1,
                self.minimum_scale,
                self.use_bessel_correction,
                self.stabilize_with_global,
                self.scale_factor_exponent,
                prefix_length,
            )

            # Pick the running statistics at the final position of each patch...
            last_in_patch = slice(self.patch_size - 1, None, self.patch_size)
            patch_means = causal_means[..., last_in_patch]
            patch_scales = causal_scale[..., last_in_patch]

            # ...and broadcast them across every step of that patch.
            patch_means = patch_means.repeat_interleave(self.patch_size, dim=-1)
            patch_scales = patch_scales.repeat_interleave(self.patch_size, dim=-1)

        return (data - patch_means) / patch_scales, patch_means, patch_scales
+
+
# for deserialization of SafeTensors checkpoints
# Maps the serialized class-name strings (and the short aliases accepted in
# config.json) to the concrete Scaler classes defined above.
scaler_types = {
    "<class 'model.scaler.StdMeanScaler'>": StdMeanScaler,
    "<class 'model.scaler.CausalStdMeanScaler'>": CausalStdMeanScaler,
    "<class 'model.scaler.CausalPatchStdMeanScaler'>": CausalPatchStdMeanScaler,
    # Short aliases used in config.json
    "per_variate": StdMeanScaler,
    "per_variate_causal": CausalStdMeanScaler,
    "per_variate_causal_patch": CausalPatchStdMeanScaler,
}
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py
new file mode 100644
index 00000000000..61595334171
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/toto.py
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+import json
+import os
+import re
+from pathlib import Path
+from typing import Dict, Optional, Union
+
+import safetensors.torch as safetorch
+import torch
+
+from .attention import XFORMERS_AVAILABLE
+from .backbone import TotoBackbone
+from .transformer import XFORMERS_SWIGLU_AVAILABLE
+
+
+class Toto(torch.nn.Module):
+ """
+ PyTorch module for Toto (Timeseries-Optimized Transformer for
Observability).
+ This class is used internally for checkpoint loading logic.
+ """
+
+ def __init__(
+ self,
+ patch_size: int,
+ stride: int,
+ embed_dim: int,
+ num_layers: int,
+ num_heads: int,
+ mlp_hidden_dim: int,
+ dropout: float,
+ spacewise_every_n_layers: int,
+ scaler_cls: str,
+ output_distribution_classes: list[str],
+ spacewise_first: bool = True,
+ output_distribution_kwargs: dict | None = None,
+ use_memory_efficient_attention: bool = True,
+ stabilize_with_global: bool = True,
+ scale_factor_exponent: float = 10.0,
+ **model_kwargs,
+ ):
+ super().__init__()
+ self.model = TotoBackbone(
+ patch_size=patch_size,
+ stride=stride,
+ embed_dim=embed_dim,
+ num_layers=num_layers,
+ num_heads=num_heads,
+ mlp_hidden_dim=mlp_hidden_dim,
+ dropout=dropout,
+ spacewise_every_n_layers=spacewise_every_n_layers,
+ scaler_cls=scaler_cls,
+ output_distribution_classes=output_distribution_classes,
+ spacewise_first=spacewise_first,
+ output_distribution_kwargs=output_distribution_kwargs,
+ use_memory_efficient_attention=use_memory_efficient_attention,
+ stabilize_with_global=stabilize_with_global,
+ scale_factor_exponent=scale_factor_exponent,
+ )
+
+ @classmethod
+ def load_from_checkpoint(
+ cls,
+ checkpoint_path,
+ map_location: str = "cpu",
+ strict=True,
+ **model_kwargs,
+ ):
+ if os.path.isdir(checkpoint_path):
+ safetensors_file = os.path.join(checkpoint_path,
"model.safetensors")
+ else:
+ safetensors_file = checkpoint_path
+
+ if os.path.exists(safetensors_file):
+ model_state = safetorch.load_file(safetensors_file,
device=map_location)
+ else:
+ raise FileNotFoundError(
+ f"Model checkpoint not found at: {safetensors_file}"
+ )
+
+ config_file = os.path.join(checkpoint_path, "config.json")
+ config = {}
+ if os.path.exists(config_file):
+ with open(config_file, "r") as f:
+ config = json.load(f)
+
+ config.update(model_kwargs)
+
+ remapped_state_dict = cls._map_state_dict_keys(
+ model_state,
+ XFORMERS_SWIGLU_AVAILABLE
+ and not config.get("pre_xformers_checkpoint", False),
+ )
+
+ if not XFORMERS_AVAILABLE and config.get(
+ "use_memory_efficient_attention", True
+ ):
+ config["use_memory_efficient_attention"] = False
+
+ instance = cls(**config)
+ instance.to(map_location)
+
+ filtered_remapped_state_dict = {
+ k: v
+ for k, v in remapped_state_dict.items()
+ if k in instance.state_dict() and not
k.endswith("rotary_emb.freqs")
+ }
+
+ instance.load_state_dict(filtered_remapped_state_dict, strict=strict)
+ return instance
+
+ @staticmethod
+ def _map_state_dict_keys(state_dict, use_fused_swiglu):
+ if use_fused_swiglu:
+ remap_keys = {
+ "mlp.0.weight": "mlp.0.w12.weight",
+ "mlp.0.bias": "mlp.0.w12.bias",
+ "mlp.2.weight": "mlp.0.w3.weight",
+ "mlp.2.bias": "mlp.0.w3.bias",
+ }
+ else:
+ remap_keys = {
+ "mlp.0.w12.weight": "mlp.0.weight",
+ "mlp.0.w12.bias": "mlp.0.bias",
+ "mlp.0.w3.weight": "mlp.2.weight",
+ "mlp.0.w3.bias": "mlp.2.bias",
+ }
+
+ def replace_key(text):
+ for pattern, replacement in remap_keys.items():
+ text = re.sub(pattern, replacement, text)
+ return text
+
+ return {replace_key(k): v for k, v in state_dict.items()}
+
+ @property
+ def device(self):
+ return next(self.model.parameters()).device
diff --git
a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py
new file mode 100644
index 00000000000..58220c30e62
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/transformer.py
@@ -0,0 +1,318 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+import warnings
+from typing import Literal, Optional, Union, cast
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from jaxtyping import Bool, Float, Int
+from rotary_embedding_torch import RotaryEmbedding
+
+from .attention import (
+ AttentionAxis,
+ MultiHeadAttention,
+ SpaceWiseMultiheadAttention,
+ TimeWiseMultiheadAttention,
+)
+from .feed_forward import SwiGLU
+from .fusion import Fusion
+from .rope import TimeAwareRotaryEmbedding
+from .util import KVCache, RMSNorm, make_batched_block_mask
+
# Prefer xFormers' fused SwiGLU kernel when the library is installed; fall
# back to the pure-PyTorch SwiGLU defined in .feed_forward otherwise.
try:
    from xformers.ops.swiglu_op import SwiGLU as SwiGLU_fused

    XFORMERS_SWIGLU_AVAILABLE = True
except ImportError:
    warnings.warn(
        "xFormers fused SwiGLU kernel not found. "
        "Using native PyTorch implementation for feed-forward layers.",
        ImportWarning,
    )
    XFORMERS_SWIGLU_AVAILABLE = False
+
+
class TransformerLayer(torch.nn.Module):
    """Pre-norm transformer block: attention + SwiGLU MLP, each wrapped in a
    residual connection. Attention runs either across time or across variates
    depending on ``attention_axis``.
    """

    embed_dim: int
    num_heads: int
    mlp_hidden_dim: int
    dropout: float
    attention_axis: AttentionAxis

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        dropout: float,
        rotary_emb: RotaryEmbedding = None,
        attention_axis: AttentionAxis = AttentionAxis.TIME,
        RMS_norm: bool = True,
        use_memory_efficient_attention: bool = True,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.mlp_hidden_dim = mlp_hidden_dim
        self.dropout = dropout
        self.attention_axis = attention_axis

        # Pre-normalization layers (RMSNorm by default, LayerNorm otherwise).
        if RMS_norm:
            self.norm1: Union[RMSNorm, torch.nn.LayerNorm] = RMSNorm(embed_dim)
            self.norm2: Union[RMSNorm, torch.nn.LayerNorm] = RMSNorm(embed_dim)
        else:
            self.norm1 = torch.nn.LayerNorm(embed_dim)
            self.norm2 = torch.nn.LayerNorm(embed_dim)

        self.attention: MultiHeadAttention

        if attention_axis == AttentionAxis.TIME:
            # Rotary position embeddings only make sense along the time axis.
            self.attention = TimeWiseMultiheadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                dropout=dropout,
                rotary_emb=rotary_emb,
                use_memory_efficient_attention=use_memory_efficient_attention,
            )
        elif attention_axis == AttentionAxis.SPACE:
            self.attention = SpaceWiseMultiheadAttention(
                embed_dim=embed_dim,
                num_heads=num_heads,
                dropout=dropout,
                rotary_emb=None,
                use_memory_efficient_attention=use_memory_efficient_attention,
            )
        else:
            raise ValueError("Invalid attention axis")

        # Feed-forward: fused xFormers SwiGLU when available; otherwise the
        # unfused form Linear(2*hidden) -> SwiGLU gate -> Linear(out).
        if XFORMERS_SWIGLU_AVAILABLE:
            self.mlp = torch.nn.Sequential(
                SwiGLU_fused(in_features=embed_dim, hidden_features=mlp_hidden_dim),
                torch.nn.Dropout(dropout),
            )
        else:
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(embed_dim, 2 * mlp_hidden_dim),
                SwiGLU(),
                torch.nn.Linear(mlp_hidden_dim, embed_dim),
                torch.nn.Dropout(dropout),
            )

    def forward(
        self,
        layer_idx: int,
        inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"],
        attention_mask: Optional[
            Union[
                Bool[torch.Tensor, "batch seq_len variate variate"],
                Bool[torch.Tensor, "batch #variate seq_len seq_len"],
            ]
        ] = None,
        kv_cache: Optional[KVCache] = None,
    ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]:
        """Apply the pre-norm attention and MLP sub-blocks with residuals.

        ``layer_idx`` identifies this layer inside ``kv_cache`` during
        incremental decoding.
        """
        pre_norm_1 = self.norm1(inputs)
        hidden_state = (
            inputs
            + self.attention(
                layer_idx, pre_norm_1, attention_mask, kv_cache
            ).contiguous()
        )

        pre_norm_2 = self.norm2(hidden_state)
        return hidden_state + self.mlp(pre_norm_2)
+
+
class Transformer(torch.nn.Module):
    """Stack of TransformerLayers alternating time-wise and space-wise
    attention according to ``spacewise_every_n_layers``. Optionally fuses
    variate-label tokens into the sequence before the first layer.
    """

    def __init__(
        self,
        num_layers: int,
        embed_dim: int,
        num_heads: int,
        mlp_hidden_dim: int,
        dropout: float,
        spacewise_every_n_layers: int,
        spacewise_first: bool,
        use_memory_efficient_attention: bool = True,
        *,
        fusion: Optional[Fusion] = None,
    ):
        super().__init__()

        assert (
            embed_dim % num_heads == 0
        ), "Embedding dimension must be divisible by number of heads."

        # One shared rotary embedding, reused by every time-wise layer.
        self.rotary_emb = TimeAwareRotaryEmbedding(
            embed_dim // num_heads,
            use_xpos=True,
            cache_if_possible=True,
            seq_before_head_dim=use_memory_efficient_attention,
        )
        attention_axes = self._get_layer_types(
            num_layers, spacewise_every_n_layers, spacewise_first
        )

        self.use_memory_efficient_attention = use_memory_efficient_attention
        self.fusion = fusion

        self.layers = torch.nn.ModuleList(
            [
                TransformerLayer(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_hidden_dim=mlp_hidden_dim,
                    dropout=dropout,
                    rotary_emb=self.rotary_emb,
                    attention_axis=attention_axes[i],
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                )
                for i in range(num_layers)
            ]
        )

    def _get_mask(
        self,
        num_heads: int,
        dtype: torch.dtype,
        id_mask: Optional[torch.Tensor] = None,
    ) -> Union[
        Bool[torch.Tensor, "batch num_heads seq_len seq_len"],
        Float[torch.Tensor, "batch num_heads seq_len seq_len"],
        Bool[torch.Tensor, "batch num_heads variate variate"],
        Float[torch.Tensor, "batch num_heads variate variate"],
    ]:
        """Build the space-wise attention mask: variates may only attend to
        other variates sharing the same id at the same time step."""
        if id_mask is None:
            raise ValueError("id_mask must be provided for spacewise masks.")

        # Block-diagonal boolean mask over the variate axis.
        mask = make_batched_block_mask(id_mask.transpose(-1, -2))

        if self.use_memory_efficient_attention:
            # xFormers expects an additive float bias whose last dimension is
            # a multiple of 8: pad, then map True->0.0 / False->-inf.
            mask = self._pad_to_multiple(mask)
            mask = (
                mask.float()
                .masked_fill(~mask, float("-inf"))
                .masked_fill(mask, 0.0)
                .to(dtype)
            )

        # Fold the time axis into the batch so each time step gets its own
        # variate-by-variate mask, broadcast over heads.
        mask = rearrange(
            mask,
            "batch seq_len variate1 variate2 -> (batch seq_len) 1 variate1 variate2",
        )
        return mask.expand(-1, num_heads, -1, -1).contiguous()

    def _pad_to_multiple(
        self,
        tensor: torch.Tensor,
        multiple: int = 8,
        causal: bool = False,
    ) -> torch.Tensor:
        """Pad the trailing two mask dimensions up to a multiple of
        ``multiple`` (xFormers kernel alignment requirement). The ``causal``
        variant pads with a lower-triangular pattern instead of zeros."""
        pad_amount = (multiple - tensor.shape[-1] % multiple) % multiple
        if pad_amount > 0:
            new_size = tensor.shape[-1] + pad_amount
            if causal:
                full_mask = torch.tril(
                    torch.ones(
                        (new_size, new_size), dtype=tensor.dtype, device=tensor.device
                    )
                )
                full_mask[: tensor.shape[-1], : tensor.shape[-1]] = tensor
                tensor = full_mask
            else:
                tensor = F.pad(tensor, (0, pad_amount, 0, pad_amount))
        return tensor

    def _get_layer_types(
        self,
        num_layers: int,
        spacewise_every_n_layers: int,
        spacewise_first: bool,
    ) -> list[AttentionAxis]:
        """Return the attention axis for each layer: one SPACE layer per
        block of ``spacewise_every_n_layers`` (or all TIME when -1)."""
        if spacewise_every_n_layers == -1:
            return [AttentionAxis.TIME] * num_layers
        assert num_layers % spacewise_every_n_layers == 0

        block = [AttentionAxis.TIME] * (spacewise_every_n_layers - 1)

        if spacewise_first:
            block = [AttentionAxis.SPACE] + block
        else:
            block = block + [AttentionAxis.SPACE]

        return block * (num_layers // spacewise_every_n_layers)

    def forward(
        self,
        inputs: Float[torch.Tensor, "batch variate seq_len embed_dim"],
        id_mask: Float[torch.Tensor, "batch #variate seq_len"],
        kv_cache: Optional[KVCache] = None,
        variate_label_embeds: Optional[
            Float[torch.Tensor, "batch variate 1 embed_dim"]
        ] = None,
    ) -> Float[torch.Tensor, "batch variate seq_len embed_dim"]:

        # Prepend variate-label tokens only on the first (prefill) step: once
        # the KV cache holds entries, the labels are already part of it.
        if self.fusion is not None and variate_label_embeds is not None:
            should_apply_fusion = True
            if kv_cache is not None:
                kv_len_tensor = kv_cache.current_len(0)
                kv_len = (
                    int(kv_len_tensor)
                    if isinstance(kv_len_tensor, torch.Tensor)
                    else kv_len_tensor
                )
                should_apply_fusion = kv_len == 0
            if should_apply_fusion:
                inputs = self.fusion(inputs, variate_label_embeds=variate_label_embeds)

        batch, _, seq_len, _ = inputs.shape

        # Fusion may have lengthened the sequence; left-pad id_mask by
        # replicating its first column so its time axis matches.
        if id_mask is not None and id_mask.shape[-1] != seq_len:
            added = int(seq_len - id_mask.shape[-1])
            if added > 0:
                pad_slice = id_mask[..., :1]
                id_mask = torch.cat([pad_slice.expand(-1, -1, added), id_mask], dim=-1)

        # NOTE(review): this total (cached + current) length is computed but
        # never read below — confirm whether it was meant to feed the masks.
        seq_len = (kv_cache.seq_len(1) if kv_cache else 0) + seq_len

        num_heads: int = cast(int, self.layers[0].num_heads)

        # Time-wise layers receive no explicit mask here.
        timewise_attention_mask = None

        spacewise_attention_mask = self._get_mask(
            num_heads=num_heads,
            dtype=inputs.dtype,
            id_mask=id_mask,
        )

        # Run the stack, giving each layer the mask matching its axis.
        for layer_idx, layer in enumerate(self.layers):
            inputs = layer(
                layer_idx,
                inputs,
                (
                    timewise_attention_mask
                    if layer.attention_axis == AttentionAxis.TIME
                    else spacewise_attention_mask
                ),
                kv_cache,
            )
        return inputs
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py
new file mode 100644
index 00000000000..d913329e7e8
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/model/util.py
@@ -0,0 +1,251 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# This file includes code derived from DataDog/toto
+# (https://github.com/DataDog/toto), licensed under the Apache-2.0 License.
+# Copyright 2025 Datadog, Inc.
+
+import warnings
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, List, Optional, TypeAlias, Union
+
+import torch
+from einops import rearrange
+from jaxtyping import Float, Int
+
+from .attention import TimeWiseMultiheadAttention
+
+if TYPE_CHECKING:
+ from .transformer import TransformerLayer
+
# Use xFormers' fused RMSNorm kernels when present; otherwise fall back to a
# pure-PyTorch path and a stub `_is_triton_available` that reports False.
try:
    from xformers import _is_triton_available
    from xformers.ops.rmsnorm import rms_norm, rms_norm_add

    XFORMERS_RMSNORM_AVAILABLE = True
except ImportError:
    warnings.warn(
        "xFormers fused RMSNorm implementation not available. Will not use "
        "optimized kernel for inference.",
        ImportWarning,
    )

    def _is_triton_available():
        return False

    XFORMERS_RMSNORM_AVAILABLE = False
+
+
class RMSNorm(torch.nn.Module):
    """
    Root-mean-square layer normalization.

    Wraps xFormers' fused rms_norm kernel for eval/frozen mode and falls back
    to a pure-PyTorch implementation for training or when the xFormers/Triton
    kernels are unavailable.
    """

    def __init__(self, dim: int, include_weight: bool = True, eps: float = 1e-8):
        super(RMSNorm, self).__init__()
        self.eps = eps
        if include_weight:
            # Learnable per-channel gain, initialized to 1.
            self.scale: Optional[torch.nn.Parameter] = torch.nn.Parameter(
                torch.ones(dim)
            )
        else:
            self.scale = None

    def _use_fused_kernel(self) -> bool:
        """True when the fused xFormers kernel may be used: not training (or
        the gain is frozen) AND the fused implementation actually exists."""
        frozen = self.scale is not None and not self.scale.requires_grad
        return (
            ((not self.training) or frozen)
            and XFORMERS_RMSNORM_AVAILABLE
            and _is_triton_available()
        )

    def forward(self, x: torch.Tensor):
        if self._use_fused_kernel():
            return rms_norm(x, self.scale, self.eps)

        # Pure-PyTorch fallback: x / rms(x), optionally scaled by the gain.
        x_normed = x / torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
        return x_normed if self.scale is None else x_normed * self.scale

    def increment_and_forward_(self, x: torch.Tensor, y: torch.Tensor):
        """Fused add + normalize: equivalent to forward(x + y).

        Fix: the original called xFormers' rms_norm_add without checking that
        the fused kernel is available, raising NameError on installations
        without xFormers/Triton; it now shares forward's availability check.
        """
        if self._use_fused_kernel():
            return rms_norm_add(x, y, self.scale, self.eps)
        return self.forward(x + y)
+
+
def make_batched_block_mask(t: torch.Tensor) -> torch.Tensor:
    """Build a boolean block mask from a tensor of ids.

    Entry [..., i, j] is True iff t[..., i] == t[..., j], so positions that
    share the same id along the last axis may attend to each other.
    """
    rows = t.unsqueeze(-1)  # (..., d, 1)
    cols = t.unsqueeze(-2)  # (..., 1, d)
    return rows == cols
+
+
# Shape aliases for the per-layer attention cache; the batch and variate axes
# are flattened together into the leading dimension.
K: TypeAlias = Float[
    torch.Tensor, "batch_size_X_num_variates num_heads seq_len head_dim"
]
V: TypeAlias = Float[
    torch.Tensor, "batch_size_X_num_variates num_heads seq_len head_dim"
]
# A (keys, values) pair as stored/retrieved by KVCache.
KV: TypeAlias = tuple[K, V]
+
+
@dataclass
class KVCache:
    """
    Key/Value cache for storing intermediate attention values during multistep
    inference. Only stores KV cache for timewise layers, skipping spacewise
    layers.
    """

    batch_size: int
    num_variates: int
    transformer_layers: List["TransformerLayer"]
    num_layers: int
    embed_dim: int
    num_heads: int
    max_seq_len: int
    device: torch.device = torch.device("cpu")
    dtype: torch.dtype = torch.float32
    # Selects the storage layout: sequence-before-heads for the xFormers
    # memory-efficient path, heads-before-sequence otherwise.
    use_memory_efficient_attention: bool = True

    # Preallocated key storage; first axis indexes the time-wise layer slot.
    _keys: Union[
        Float[
            torch.Tensor,
            "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim",
        ],
        Float[
            torch.Tensor,
            "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim",
        ],
    ] = field(init=False)

    # Preallocated value storage, same layout as _keys.
    _values: Union[
        Float[
            torch.Tensor,
            "time_layer_count batch_size_X_num_variates max_seq_len num_heads head_dim",
        ],
        Float[
            torch.Tensor,
            "time_layer_count batch_size_X_num_variates num_heads max_seq_len head_dim",
        ],
    ] = field(init=False)

    # Per-slot fill level (how many positions are populated).
    _current_idx: Int[torch.Tensor, "time_layer_count"] = field(init=False)
    # Maps a transformer layer index to its cache slot.
    _layer_cache_map: Int[torch.Tensor, "num_layers"] = field(init=False)

    def __post_init__(self):
        """Allocate the cache tensors and map time-wise layers to slots."""
        assert (
            self.embed_dim % self.num_heads == 0
        ), "embed_dim must be divisible by num_heads"
        head_dim = self.embed_dim // self.num_heads

        # Only time-wise attention layers get a cache slot.
        time_layer_indices = [
            i
            for i in range(self.num_layers)
            if isinstance(
                self.transformer_layers[i].attention, TimeWiseMultiheadAttention
            )
        ]

        # Keep at least one slot so the tensors below are never empty.
        time_layer_count = max(1, len(time_layer_indices))
        if self.use_memory_efficient_attention:
            shape = (
                time_layer_count,
                self.batch_size * self.num_variates,
                self.max_seq_len,
                self.num_heads,
                head_dim,
            )
        else:
            shape = (
                time_layer_count,
                self.batch_size * self.num_variates,
                self.num_heads,
                self.max_seq_len,
                head_dim,
            )
        self._keys = torch.zeros(shape, device=self.device, dtype=self.dtype)
        self._values = torch.zeros_like(self._keys)
        self._current_idx = torch.zeros(
            time_layer_count, device=self.device, dtype=torch.int
        )
        # NOTE(review): indices default to 0, so spacewise layers also map to
        # slot 0 — callers must only index the cache with time-wise layers.
        self._layer_cache_map = torch.zeros(
            (self.num_layers,), dtype=torch.int, device=self.device
        )
        for cache_idx, layer_idx in enumerate(time_layer_indices):
            self._layer_cache_map[layer_idx] = int(cache_idx)

    def __getitem__(self, layer_idx: int) -> KV:
        """Return views of the populated (keys, values) for ``layer_idx``."""
        cache_idx = int(self._layer_cache_map[layer_idx].item())
        end_idx = int(self._current_idx[cache_idx].item())

        # Slice along whichever axis holds the sequence in this layout.
        if self.use_memory_efficient_attention:
            return (
                self._keys[cache_idx, :, :end_idx, :, :],
                self._values[cache_idx, :, :end_idx, :, :],
            )
        else:
            return (
                self._keys[cache_idx, :, :, :end_idx, :],
                self._values[cache_idx, :, :, :end_idx, :],
            )

    def current_len(self, cache_idx: int) -> int:
        """Number of populated positions in cache slot ``cache_idx``."""
        return (
            int(self._current_idx[cache_idx].item())
            if self._current_idx.numel() > 0
            else 0
        )

    def seq_len(self, layer_idx: int) -> int:
        """Number of cached positions for transformer layer ``layer_idx``."""
        cache_idx = int(self._layer_cache_map[layer_idx].item())
        return self.current_len(cache_idx)

    def append(self, layer_idx: int, kv: KV):
        """Append new (keys, values) after the currently cached positions."""
        cache_idx = int(self._layer_cache_map[layer_idx].item())
        keys, values = kv

        assert keys.shape == values.shape, "keys and values must have the same shape"
        assert (
            keys.shape[0] == self.batch_size * self.num_variates
        ), "keys and values must have batch_size * num_variates as their first dimension"

        # Heads axis position depends on the storage layout.
        if self.use_memory_efficient_attention:
            assert keys.shape[2] == self.num_heads
        else:
            assert keys.shape[1] == self.num_heads
        assert keys.shape[3] == self.embed_dim // self.num_heads

        start_idx = self._current_idx[cache_idx]
        if self.use_memory_efficient_attention:
            end_idx = start_idx + keys.shape[1]
        else:
            end_idx = start_idx + keys.shape[2]
        assert (
            end_idx <= self.max_seq_len
        ), f"max_seq_len exceeded {end_idx} > {self.max_seq_len}, keys.shape: {keys.shape}"

        # Write in place into the preallocated buffers and advance the cursor.
        if self.use_memory_efficient_attention:
            self._keys[cache_idx, :, start_idx:end_idx, :, :] = keys
            self._values[cache_idx, :, start_idx:end_idx, :, :] = values
        else:
            self._keys[cache_idx, :, :, start_idx:end_idx, :] = keys
            self._values[cache_idx, :, :, start_idx:end_idx, :] = values

        self._current_idx[cache_idx] = end_idx

    def reset(self):
        """Clear all cached keys/values and rewind every slot to empty."""
        self._keys.zero_()
        self._values.zero_()
        self._current_idx.zero_()
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py
new file mode 100644
index 00000000000..08fda1c3c72
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/modeling_toto.py
@@ -0,0 +1,167 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import json
+import os
+
+import safetensors.torch as safetorch
+from transformers import PreTrainedModel
+
+from iotdb.ainode.core.log import Logger
+
+from .configuration_toto import TotoConfig
+from .model.attention import XFORMERS_AVAILABLE
+from .model.backbone import TotoBackbone
+from .model.toto import Toto
+from .model.transformer import XFORMERS_SWIGLU_AVAILABLE
+
+logger = Logger()
+
+
class TotoPreTrainedModel(PreTrainedModel):
    """Common abstract base class shared by all Toto model variants."""

    config_class = TotoConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False

    def _init_weights(self, module):
        """No-op: weights always come from the pretrained checkpoint, so no
        random initialisation is performed."""
        return
+
+
class TotoForPrediction(TotoPreTrainedModel):
    """
    Toto (Timeseries-Optimized Transformer for Observability) model for time
    series prediction.

    Integrates the Toto backbone with AINode's model loading mechanism using the
    transformers PreTrainedModel interface. Weights are loaded directly from the
    Datadog/Toto-Open-Base-1.0 safetensors checkpoint.

    The backbone is stored as ``self.model`` so that safetensors key prefixes
    (``model.*``) map directly to parameters without any renaming.

    Reference: https://huggingface.co/Datadog/Toto-Open-Base-1.0
    """

    def __init__(self, config: TotoConfig):
        super().__init__(config)
        # Backbone stored as self.model so safetensors keys (model.*) match directly.
        self.model = TotoBackbone(
            patch_size=config.patch_size,
            stride=config.stride,
            embed_dim=config.embed_dim,
            num_layers=config.num_layers,
            num_heads=config.num_heads,
            mlp_hidden_dim=config.mlp_hidden_dim,
            dropout=config.dropout,
            spacewise_every_n_layers=config.spacewise_every_n_layers,
            scaler_cls=config.scaler_cls,
            output_distribution_classes=config.output_distribution_classes,
            output_distribution_kwargs=config.output_distribution_kwargs,
            spacewise_first=config.spacewise_first,
            use_memory_efficient_attention=config.use_memory_efficient_attention,
            stabilize_with_global=config.stabilize_with_global,
            scale_factor_exponent=config.scale_factor_exponent,
        )
        self.post_init()

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Load TotoForPrediction from a local directory containing ``config.json``
        and ``model.safetensors``.

        This override is required because:
        1. The safetensors file uses legacy SwiGLU key names that need remapping.
        2. The config uses class-path strings for ``scaler_cls`` and
           ``output_distribution_classes`` that must not be filtered out.

        Args:
            pretrained_model_name_or_path (str): Path to a local directory.
            **kwargs: Extra key/value pairs merged into the config before construction.

        Returns:
            TotoForPrediction: Fully initialised and weight-loaded model in eval mode.

        Raises:
            ValueError: If the given path is not a local directory.
            FileNotFoundError: If ``model.safetensors`` is missing from the directory.
        """
        if not os.path.isdir(pretrained_model_name_or_path):
            raise ValueError(
                f"pretrained_model_name_or_path must be a local directory, "
                f"got: {pretrained_model_name_or_path}"
            )
        config_file = os.path.join(pretrained_model_name_or_path, "config.json")
        safetensors_file = os.path.join(
            pretrained_model_name_or_path, "model.safetensors"
        )

        # ── Load config ──────────────────────────────────────────────────────
        config_dict: dict = {}
        if os.path.exists(config_file):
            # JSON is defined as UTF-8; be explicit so a non-UTF-8 locale
            # default (e.g. on Windows) cannot break checkpoint loading.
            with open(config_file, "r", encoding="utf-8") as f:
                config_dict = json.load(f)
        config_dict.update(kwargs)

        # Disable xFormers memory-efficient attention if the library is absent.
        if not XFORMERS_AVAILABLE and config_dict.get(
            "use_memory_efficient_attention", True
        ):
            config_dict["use_memory_efficient_attention"] = False

        config = TotoConfig(**config_dict)

        # ── Instantiate model ─────────────────────────────────────────────────
        instance = cls(config)

        # ── Load safetensors weights ──────────────────────────────────────────
        if not os.path.exists(safetensors_file):
            raise FileNotFoundError(
                f"Model checkpoint not found at: {safetensors_file}"
            )

        state_dict = safetorch.load_file(safetensors_file, device="cpu")

        # Remap SwiGLU weight names if the fused xFormers kernel is available.
        use_fused_swiglu = XFORMERS_SWIGLU_AVAILABLE and not config_dict.get(
            "pre_xformers_checkpoint", False
        )
        state_dict = Toto._map_state_dict_keys(state_dict, use_fused_swiglu)

        # Filter to keys that exist in the model, skipping cached rotary buffers.
        model_state = instance.state_dict()
        filtered_state_dict = {
            k: v
            for k, v in state_dict.items()
            if k in model_state and not k.endswith("rotary_emb.freqs")
        }

        # strict=False tolerates key mismatches, but discarding the result
        # would hide a corrupt or mismatched checkpoint entirely — surface
        # any mismatches so bad loads are diagnosable.
        load_result = instance.load_state_dict(filtered_state_dict, strict=False)
        if load_result.missing_keys:
            logger.warning(
                f"Toto checkpoint is missing keys: {load_result.missing_keys}"
            )
        if load_result.unexpected_keys:
            logger.warning(
                f"Toto checkpoint has unexpected keys: {load_result.unexpected_keys}"
            )
        instance.eval()

        logger.info(f"Loaded Toto model from {pretrained_model_name_or_path}")
        return instance

    @property
    def backbone(self):
        """The underlying ``TotoBackbone`` used for inference."""
        return self.model

    @property
    def device(self):
        """Device on which model parameters reside."""
        return next(self.parameters()).device
diff --git a/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py
b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py
new file mode 100644
index 00000000000..c6778a5e90b
--- /dev/null
+++ b/iotdb-core/ainode/iotdb/ainode/core/model/toto/pipeline_toto.py
@@ -0,0 +1,144 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import torch
+
+from iotdb.ainode.core.inference.pipeline.basic_pipeline import
ForecastPipeline
+from iotdb.ainode.core.log import Logger
+from iotdb.ainode.core.model.toto.data.util.dataset import MaskedTimeseries
+from iotdb.ainode.core.model.toto.inference.forecaster import TotoForecaster
+
+logger = Logger()
+
+
class TotoPipeline(ForecastPipeline):
    """
    Inference pipeline for the Toto time series foundation model.

    Raw input tensors are converted into ``MaskedTimeseries`` objects and
    autoregressive decoding is delegated to ``TotoForecaster``. The forecaster
    is built lazily on first use, so constructing the pipeline never requires
    a live model (useful during import / registration time).
    """

    def __init__(self, model_info, **model_kwargs):
        super().__init__(model_info, **model_kwargs)
        # Built lazily on the first forecast() call.
        self._forecaster: TotoForecaster | None = None

    def _get_forecaster(self) -> TotoForecaster:
        """Return the memoised forecaster, building it on first use."""
        forecaster = self._forecaster
        if forecaster is None:
            forecaster = TotoForecaster(self.model.backbone)
            self._forecaster = forecaster
        return forecaster

    def _to_masked_timeseries(self, item):
        """Convert one raw input dict into a ``MaskedTimeseries`` named-tuple."""
        series = item["targets"]
        if series.ndim == 1:
            series = series.unsqueeze(0)

        num_vars, length = series.shape
        dev = series.device

        if "past_covariates" in item or "future_covariates" in item:
            logger.warning(
                "TotoPipeline does not support covariates; they will be ignored."
            )

        # Positions that held NaN are masked out, then zero-filled for the model.
        valid_mask = ~torch.isnan(series)
        series = series.nan_to_num(0.0)

        steps = torch.arange(length, dtype=torch.long, device=dev)
        return MaskedTimeseries(
            series=series,
            padding_mask=valid_mask,
            id_mask=torch.zeros(num_vars, length, dtype=torch.long, device=dev),
            timestamp_seconds=steps.unsqueeze(0).expand(num_vars, length),
            time_interval_seconds=torch.ones(num_vars, dtype=torch.long, device=dev),
        )

    def _preprocess(self, inputs, **infer_kwargs):
        """
        Preprocess input data for Toto.

        Parameters
        ----------
        inputs : list of dict
            Each dictionary contains:
            - 'targets': A tensor (1D or 2D) of shape (input_length,) or
              (target_count, input_length).

        infer_kwargs: Additional keyword arguments for inference, such as:
            - `output_length`(int): Prediction length.

        Returns
        -------
        list of MaskedTimeseries
            Processed inputs compatible with Toto's forecaster.
        """
        return [self._to_masked_timeseries(item) for item in inputs]

    def forecast(self, inputs, **infer_kwargs) -> list[torch.Tensor]:
        """Run autoregressive forecasting over each preprocessed input."""
        horizon = infer_kwargs.get("output_length", 96)
        num_samples = infer_kwargs.get("num_samples", None)
        samples_per_batch = infer_kwargs.get("samples_per_batch", 10)

        forecaster = self._get_forecaster()
        dev = self.model.device

        results: list[torch.Tensor] = []
        for ts in inputs:
            # Move every field of the named-tuple onto the model's device.
            ts = ts._replace(
                series=ts.series.to(dev),
                padding_mask=ts.padding_mask.to(dev),
                id_mask=ts.id_mask.to(dev),
                timestamp_seconds=ts.timestamp_seconds.to(dev),
                time_interval_seconds=ts.time_interval_seconds.to(dev),
            )
            prediction = forecaster.forecast(
                ts,
                prediction_length=horizon,
                num_samples=num_samples,
                samples_per_batch=samples_per_batch,
            )
            mean = prediction.mean
            # Collapse a leading singleton dimension (batch of 1), if present.
            if mean.ndim == 3 and mean.shape[0] == 1:
                mean = mean.squeeze(0)
            results.append(mean)
        return results

    def _postprocess(self, outputs, **infer_kwargs) -> list[torch.Tensor]:
        """Toto forecasts require no additional post-processing."""
        return outputs
diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index 9a142fe7259..0dc630fb0ff 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -117,6 +117,7 @@ setuptools = ">=75.3.0"
joblib = ">=1.4.2"
urllib3 = "2.6.3"
jaxtyping = ">=0.2.24"
+rotary-embedding-torch = ">=0.8.0"
[tool.poetry.scripts]
ainode = "iotdb.ainode.core.script:main"