This is an automated email from the ASF dual-hosted git repository.

yongzao pushed a commit to branch refactor_timerxl
in repository https://gitbox.apache.org/repos/asf/iotdb.git

commit 1049340c432c9f3b5d90db130b7555d4d7036d8b
Author: Yongzao <[email protected]>
AuthorDate: Thu Jun 5 21:47:08 2025 +0800

    almost done
---
 .../ainode/ainode/TimerXL/layers/Attn_Bias.py      | 108 ----
 .../ainode/TimerXL/layers/Attn_Projection.py       | 127 -----
 iotdb-core/ainode/ainode/TimerXL/layers/Embed.py   | 290 ----------
 .../ainode/TimerXL/layers/SelfAttention_Family.py  | 207 --------
 .../ainode/TimerXL/layers/Transformer_EncDec.py    | 329 ------------
 .../ainode/ainode/TimerXL/layers/__init__.py       |  17 -
 .../ainode/ainode/TimerXL/models/__init__.py       |  17 -
 .../ainode/ainode/TimerXL/models/timer_xl.py       | 446 ----------------
 .../ainode/core/manager/inference_manager.py       |   3 +-
 .../ainode/core/model/built_in_model_factory.py    |   8 +-
 .../{TimerXL => core/model/timerxl}/__init__.py    |   2 +-
 .../model/timerxl}/configuration_timer.py          |  31 +-
 .../ainode/core/model/timerxl/modeling_timer.py    | 591 +++++++++++++++++++++
 .../core/model/timerxl/ts_generation_mixin.py      | 297 +++++++++++
 iotdb-core/ainode/pyproject.toml                   |   2 +-
 15 files changed, 911 insertions(+), 1564 deletions(-)
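
The rename entries above imply an import-path change for downstream code. A minimal before/after sketch; the old path is confirmed by the removed files below, while the new path is only inferred from the `{TimerXL => core/model/timerxl}` rename lines and should be treated as an assumption:

    # Old location, as imported by the files removed below:
    from ainode.TimerXL.models.configuration_timer import TimerxlConfig

    # Presumed new location after this commit (inferred from the diffstat):
    from ainode.core.model.timerxl.configuration_timer import TimerxlConfig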

diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py b/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py
deleted file mode 100644
index 3c5ad760032..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import abc
-import math
-
-import torch
-from einops import rearrange
-from torch import nn
-
-
-class AttentionBias(nn.Module, abc.ABC):
-    def __init__(self, dim: int, num_heads: int):
-        super().__init__()
-        assert num_heads > 0 and dim % num_heads == 0
-
-        self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-
-    @abc.abstractmethod
-    def forward(self, query_id, kv_id): ...
-
-
-class BinaryAttentionBias(AttentionBias):
-    def __init__(self, dim: int, num_heads: int):
-        super().__init__(dim, num_heads)
-        self.emb = nn.Embedding(num_embeddings=2, embedding_dim=self.num_heads)
-
-    def forward(self, query_id, kv_id):
-        ind = torch.eq(query_id.unsqueeze(-1), kv_id.unsqueeze(-2))
-        weight = rearrange(self.emb.weight, "two num_heads -> two num_heads 1 1")
-        bias = ~ind * weight[:1] + ind * weight[1:]
-        return bias
-
-
-def _relative_position_bucket(
-    relative_position, bidirectional=True, num_buckets=32, max_distance=128
-):
-    relative_buckets = 0
-    if bidirectional:
-        num_buckets //= 2
-        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
-        relative_position = torch.abs(relative_position)
-    else:
-        relative_position = -torch.min(
-            relative_position, torch.zeros_like(relative_position)
-        )
-
-    max_exact = num_buckets // 2
-    is_small = relative_position < max_exact
-    relative_position_if_large = max_exact + (
-        torch.log(relative_position.float() / max_exact)
-        / math.log(max_distance / max_exact)
-        * (num_buckets - max_exact)
-    ).to(torch.long)
-    relative_position_if_large = torch.min(
-        relative_position_if_large,
-        torch.full_like(relative_position_if_large, num_buckets - 1),
-    )
-
-    relative_buckets += torch.where(
-        is_small, relative_position, relative_position_if_large
-    )
-    return relative_buckets
-
-
-class T5AttentionBias(AttentionBias):
-    def __init__(self, dim: int, num_heads: int):
-        super().__init__(dim, num_heads)
-        self.num_buckets = 32
-        self.max_distance = 32
-        self.relative_attention_bias = nn.Embedding(self.num_buckets, 1)
-
-    def forward(self, n_vars, n_tokens):
-        context_position = torch.arange(
-            n_tokens,
-            dtype=torch.long,
-        )[:, None]
-        memory_position = torch.arange(
-            n_tokens,
-            dtype=torch.long,
-        )[None, :]
-        relative_position = memory_position - context_position
-        bucket = _relative_position_bucket(
-            relative_position=relative_position,
-            bidirectional=False,
-            num_buckets=self.num_buckets,
-            max_distance=self.max_distance,
-        ).to(self.relative_attention_bias.weight.device)
-        bias = self.relative_attention_bias(bucket).squeeze(-1)
-        bias = bias.reshape(1, 1, bias.shape[0], bias.shape[1])
-        mask1 = torch.ones((n_vars, n_vars), dtype=torch.bool).to(bias.device)
-        final_bias = torch.kron(mask1, bias)
-        return final_bias
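
A small usage sketch of the removed `_relative_position_bucket`, assuming the function above is in scope; the offsets are illustrative. With `bidirectional=False` and `max_distance=32` (the arguments `T5AttentionBias` passes), future positions collapse to bucket 0, small past offsets map one-to-one, and larger distances share log-spaced buckets:

    import torch

    # Signed offsets of key positions relative to a query position.
    relative_position = torch.arange(-4, 5).unsqueeze(0)  # shape [1, 9]

    buckets = _relative_position_bucket(
        relative_position, bidirectional=False, num_buckets=32, max_distance=32
    )
    # Positive (future) offsets clamp to 0; distances below num_buckets // 2
    # bucket exactly, larger ones fall into logarithmic buckets.
    print(buckets)
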
diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py b/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py
deleted file mode 100644
index 18e2b29c3d6..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import abc
-from functools import cached_property
-
-import torch
-from einops import einsum, rearrange, repeat
-from torch import nn
-
-
-class Projection(nn.Module, abc.ABC):
-    def __init__(self, proj_width: int, num_heads: int, **kwargs):
-        super().__init__()
-        self.proj_width = proj_width
-        self.num_heads = num_heads
-
-    @abc.abstractmethod
-    def forward(self, x, seq_id): ...
-
-
-class RotaryProjection(Projection):
-    def __init__(
-        self, *, proj_width: int, num_heads: int, max_len: int = 512, base: int = 10000
-    ):
-        super().__init__(proj_width, num_heads)
-        assert (
-            self.proj_width % 2 == 0
-        ), f"proj_width must be even, got {self.proj_width}"
-        self.register_buffer(
-            "theta",
-            1.0
-            / torch.pow(
-                base,
-                torch.arange(0, self.proj_width, 2, dtype=torch.float)
-                / self.proj_width,
-            ),
-            persistent=False,
-        )
-        self.register_buffer("cos", None, persistent=False)
-        self.register_buffer("sin", None, persistent=False)
-        self._init_freq(max_len=max_len)
-
-    def _init_freq(self, max_len: int):
-        if self.cos is None or self.cos.size(-2) < max_len:
-            position = torch.arange(
-                max_len, device=self.theta.device, dtype=self.theta.dtype
-            )
-            m_theta = einsum(position, self.theta, "length, width -> length width")
-            m_theta = repeat(m_theta, "length width -> length (width 2)")
-            self.register_buffer("cos", torch.cos(m_theta), persistent=False)
-            self.register_buffer("sin", torch.sin(m_theta), persistent=False)
-
-    @staticmethod
-    def _rotate(x):
-        x1, x2 = rearrange(x, "... (dim r) -> r ... dim", r=2)
-        return rearrange([-x2, x1], "r ... dim -> ... (dim r)", r=2)  # noqa
-
-    def forward(self, x, seq_id):
-        self._init_freq(max_len=seq_id.max() + 1)
-        rot_cos = self.cos[seq_id]
-        rot_sin = self.sin[seq_id]
-        return rot_cos * x + rot_sin * self._rotate(x)
-
-
-class QueryKeyProjection(nn.Module):
-    def __init__(
-        self, dim: int, num_heads: int, proj_layer, kwargs=None, partial_factor=None
-    ):
-        super().__init__()
-        if partial_factor is not None:
-            assert (
-                0.0 <= partial_factor[0] < partial_factor[1] <= 1.0
-            ), f"got {partial_factor[0]}, {partial_factor[1]}"
-        assert num_heads > 0 and dim % num_heads == 0
-
-        self.head_dim = dim // num_heads
-        self.partial_factor = partial_factor
-        self.query_proj = proj_layer(
-            proj_width=self.proj_width,
-            num_heads=num_heads,
-            **(kwargs or {}),
-        )
-        self.key_proj = self.query_proj
-
-    @cached_property
-    def proj_width(self) -> int:
-        if self.partial_factor is None:
-            return self.head_dim
-        return int(self.head_dim * (self.partial_factor[1] - self.partial_factor[0]))
-
-    @cached_property
-    def split_sizes(self):
-        if self.partial_factor is None:
-            return 0, self.head_dim, 0
-        return (
-            int(self.partial_factor[0] * self.head_dim),
-            self.proj_width,
-            int((1.0 - self.partial_factor[1]) * self.head_dim),
-        )
-
-    def forward(self, query, key, query_id, kv_id):
-        if self.partial_factor is not None:
-            queries = list(query.split(self.split_sizes, dim=-1))
-            keys = list(key.split(self.split_sizes, dim=-1))
-            queries[1] = self.query_proj(queries[1], seq_id=query_id)
-            keys[1] = self.key_proj(keys[1], seq_id=kv_id)
-            query = torch.cat(queries, dim=-1)
-            key = torch.cat(keys, dim=-1)
-        else:
-            query = self.query_proj(query, seq_id=query_id)
-            key = self.key_proj(key, seq_id=kv_id)
-        return query, key
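
As a sanity check on the rotate-half trick in `RotaryProjection._rotate`, here is a standalone reimplementation without einops (not the project's code): applying cos·x + sin·rotate(x) rotates each interleaved (x1, x2) pair by θ, which preserves the vector norm.

    import torch

    def rotate_interleaved(x: torch.Tensor) -> torch.Tensor:
        # Maps interleaved pairs (x1, x2) -> (-x2, x1) along the last dim,
        # mirroring RotaryProjection._rotate.
        x1, x2 = x[..., 0::2], x[..., 1::2]
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    theta = torch.tensor(0.3)
    x = torch.randn(8)  # even width, as asserted in RotaryProjection.__init__
    rotated = torch.cos(theta) * x + torch.sin(theta) * rotate_interleaved(x)
    assert torch.allclose(x.norm(), rotated.norm(), atol=1e-6)  # rotations preserve norm
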
diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Embed.py b/iotdb-core/ainode/ainode/TimerXL/layers/Embed.py
deleted file mode 100644
index 8c3cf570baf..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/layers/Embed.py
+++ /dev/null
@@ -1,290 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import math
-
-import torch
-import torch.nn as nn
-from torch.jit import is_scripting
-
-from ainode.TimerXL.models.configuration_timer import TimerxlConfig
-
-
-class PositionalEmbedding(nn.Module):
-    def __init__(self, d_model, max_len=6500):
-        super(PositionalEmbedding, self).__init__()
-        # Compute the positional encodings once in log space.
-        pe = torch.zeros(max_len, d_model).float()
-        pe.require_grad = False
-
-        position = torch.arange(0, max_len).float().unsqueeze(1)
-        div_term = (
-            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
-        ).exp()
-
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-
-        pe = pe.unsqueeze(0)
-        self.register_buffer("pe", pe)
-
-    def forward(self, x):
-        return self.pe[:, : x.size(1)]
-
-
-class TokenEmbedding(nn.Module):
-    def __init__(self, c_in, d_model):
-        super(TokenEmbedding, self).__init__()
-        padding = 1 if torch.__version__ >= "1.5.0" else 2
-        self.tokenConv = nn.Conv1d(
-            in_channels=c_in,
-            out_channels=d_model,
-            kernel_size=3,
-            padding=padding,
-            padding_mode="circular",
-            bias=False,
-        )
-        for m in self.modules():
-            if isinstance(m, nn.Conv1d):
-                nn.init.kaiming_normal_(
-                    m.weight, mode="fan_in", nonlinearity="leaky_relu"
-                )
-
-    def forward(self, x):
-        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
-        return x
-
-
-class FixedEmbedding(nn.Module):
-    def __init__(self, c_in, d_model):
-        super(FixedEmbedding, self).__init__()
-
-        w = torch.zeros(c_in, d_model).float()
-        w.require_grad = False
-
-        position = torch.arange(0, c_in).float().unsqueeze(1)
-        div_term = (
-            torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
-        ).exp()
-
-        w[:, 0::2] = torch.sin(position * div_term)
-        w[:, 1::2] = torch.cos(position * div_term)
-
-        self.emb = nn.Embedding(c_in, d_model)
-        self.emb.weight = nn.Parameter(w, requires_grad=False)
-
-    def forward(self, x):
-        return self.emb(x).detach()
-
-
-class TemporalEmbedding(nn.Module):
-    def __init__(self, d_model, embed_type="fixed", freq="h"):
-        super(TemporalEmbedding, self).__init__()
-
-        minute_size = 4
-        hour_size = 24
-        weekday_size = 7
-        day_size = 32
-        month_size = 13
-
-        Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding
-        if freq == "t":
-            self.minute_embed = Embed(minute_size, d_model)
-        self.hour_embed = Embed(hour_size, d_model)
-        self.weekday_embed = Embed(weekday_size, d_model)
-        self.day_embed = Embed(day_size, d_model)
-        self.month_embed = Embed(month_size, d_model)
-
-    def forward(self, x):
-        x = x.long()
-        minute_x = (
-            self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0
-        )
-        hour_x = self.hour_embed(x[:, :, 3])
-        weekday_x = self.weekday_embed(x[:, :, 2])
-        day_x = self.day_embed(x[:, :, 1])
-        month_x = self.month_embed(x[:, :, 0])
-
-        return hour_x + weekday_x + day_x + month_x + minute_x
-
-
-class TimeFeatureEmbedding(nn.Module):
-    def __init__(self, d_model, embed_type="timeF", freq="h"):
-        super(TimeFeatureEmbedding, self).__init__()
-
-        freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3}
-        d_inp = freq_map[freq]
-        self.embed = nn.Linear(d_inp, d_model, bias=False)
-
-    def forward(self, x):
-        return self.embed(x)
-
-
-class DataEmbedding(nn.Module):
-    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
-        super(DataEmbedding, self).__init__()
-
-        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
-        self.position_embedding = PositionalEmbedding(d_model=d_model)
-        self.temporal_embedding = (
-            TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
-            if embed_type != "timeF"
-            else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
-        )
-        self.dropout = nn.Dropout(p=dropout)
-
-    def forward(self, x, x_mark):
-        if x_mark is None:
-            x = self.value_embedding(x) + self.position_embedding(x)
-        else:
-            x = (
-                self.value_embedding(x)
-                + self.temporal_embedding(x_mark)
-                + self.position_embedding(x)
-            )
-        return self.dropout(x)
-
-
-class DataEmbedding_inverted(nn.Module):
-    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
-        super(DataEmbedding_inverted, self).__init__()
-        self.value_embedding = nn.Linear(c_in, d_model)
-        self.dropout = nn.Dropout(p=dropout)
-
-    def forward(self, x, x_mark):
-        x = x.permute(0, 2, 1)
-        # x: [Batch Variate Time]
-        if x_mark is None:
-            x = self.value_embedding(x)
-        else:
-            x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
-        # x: [Batch Variate d_model]
-        return self.dropout(x)
-
-
-class DataEmbedding_wo_pos(nn.Module):
-    def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
-        super(DataEmbedding_wo_pos, self).__init__()
-
-        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
-        self.position_embedding = PositionalEmbedding(d_model=d_model)
-        self.temporal_embedding = (
-            TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
-            if embed_type != "timeF"
-            else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
-        )
-        self.dropout = nn.Dropout(p=dropout)
-
-    def forward(self, x, x_mark):
-        if x_mark is None:
-            x = self.value_embedding(x)
-        else:
-            x = self.value_embedding(x) + self.temporal_embedding(x_mark)
-        return self.dropout(x)
-
-
-class PatchEmbedding(nn.Module):
-    def __init__(self, d_model, patch_len, stride, padding, dropout):
-        super(PatchEmbedding, self).__init__()
-        # Patching
-        self.patch_len = patch_len
-        self.stride = stride
-        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
-
-        # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space
-        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
-
-        # Positional embedding
-        self.position_embedding = PositionalEmbedding(d_model)
-
-        # Residual dropout
-        self.dropout = nn.Dropout(dropout)
-
-    def forward(self, x):
-        # do patching
-        n_vars = x.shape[1]
-        x = self.padding_patch_layer(x)
-        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
-        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
-        # Input encoding
-        x = self.value_embedding(x) + self.position_embedding(x)
-        return self.dropout(x), n_vars
-
-
-class TimerPatchEmbedding(nn.Module):
-    def __init__(self, config: TimerxlConfig):
-        super().__init__()
-        self.input_token_len = config.input_token_len
-        self.emb = nn.Linear(config.input_token_len, config.hidden_size, bias=False)
-
-    def forward(self, hidden_state: torch.Tensor):
-        hidden_state = hidden_state.unfold(
-            dimension=-1, size=self.input_token_len, step=self.input_token_len
-        )
-        return self.emb(hidden_state)
-
-
-class TimeMoeRotaryEmbedding(torch.nn.Module):
-    def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
-        super().__init__()
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        self.max_seq_len_cached: int = 0
-        inv_freq = 1.0 / (
-            self.base
-            ** (
-                torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
-                / self.dim
-            )
-        )
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
-
-        # Build here to make `torch.jit.trace` work.
-        self._set_cos_sin_cache(
-            seq_len=max_position_embeddings,
-            device=self.inv_freq.device,
-            dtype=torch.get_default_dtype(),
-        )
-
-    def _set_cos_sin_cache(
-        self, seq_len: int, device: torch.device, dtype: torch.dtype
-    ):
-        self.max_seq_len_cached = int(seq_len)
-        t = torch.arange(
-            self.max_seq_len_cached, device=device, dtype=torch.int64
-        ).type_as(self.inv_freq)
-
-        freqs = torch.outer(t, self.inv_freq)
-        # Different from paper, but it uses a different permutation in order to obtain the same calculation
-        emb = torch.cat((freqs, freqs), dim=-1)
-        if not is_scripting():
-            self.register_buffer("cos_cached", emb.cos().to(dtype), 
persistent=False)
-            self.register_buffer("sin_cached", emb.sin().to(dtype), 
persistent=False)
-        else:
-            self.cos_cached = emb.cos().to(dtype)
-            self.sin_cached = emb.sin().to(dtype)
-
-    def forward(self, x, seq_len: int = 0):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        if seq_len > self.max_seq_len_cached:
-            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
-        return (
-            self.cos_cached[:seq_len].to(dtype=x.dtype),
-            self.sin_cached[:seq_len].to(dtype=x.dtype),
-        )
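
For intuition on the removed `TimerPatchEmbedding`: `unfold` cuts the series into non-overlapping tokens of `input_token_len` points, and a bias-free linear layer projects each token to the model width. A minimal standalone sketch with illustrative sizes (not the config's actual values):

    import torch
    from torch import nn

    input_token_len, hidden_size = 96, 512  # hypothetical sizes
    emb = nn.Linear(input_token_len, hidden_size, bias=False)

    series = torch.randn(2, 672)  # [batch, seq_len], seq_len divisible by 96
    tokens = series.unfold(dimension=-1, size=input_token_len, step=input_token_len)
    print(tokens.shape)       # torch.Size([2, 7, 96])  -> 7 tokens per series
    print(emb(tokens).shape)  # torch.Size([2, 7, 512])
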
diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/SelfAttention_Family.py b/iotdb-core/ainode/ainode/TimerXL/layers/SelfAttention_Family.py
deleted file mode 100644
index 4a2fb0d27e0..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/layers/SelfAttention_Family.py
+++ /dev/null
@@ -1,207 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-from math import sqrt
-from typing import Any, Optional, Tuple
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import repeat
-
-from ainode.core.util.huggingface_cache import Cache, DynamicCache
-from ainode.core.util.masking import (
-    TimerCovariateMask,
-    TimerMultivariateMask,
-    TriangularCausalMask,
-)
-from ainode.TimerXL.layers.Attn_Bias import BinaryAttentionBias
-from ainode.TimerXL.layers.Attn_Projection import QueryKeyProjection, RotaryProjection
-from ainode.TimerXL.layers.Embed import TimeMoeRotaryEmbedding
-from ainode.TimerXL.models.configuration_timer import TimerxlConfig
-
-
-def rotate_half(x):
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2 :]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
-    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
-    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
-    q_embed = (q * cos) + (rotate_half(q) * sin)
-    k_embed = (k * cos) + (rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-class FullAttention(nn.Module):
-    def __init__(
-        self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False
-    ):
-        super(FullAttention, self).__init__()
-        self.scale = scale
-        self.mask_flag = mask_flag
-        self.output_attention = output_attention
-        self.dropout = nn.Dropout(attention_dropout)
-
-    def forward(
-        self,
-        queries,
-        keys,
-        values,
-        attn_mask,
-        n_vars=None,
-        n_tokens=None,
-        tau=None,
-        delta=None,
-    ):
-        B, L, H, E = queries.shape
-        _, S, _, D = values.shape
-        scale = self.scale or 1.0 / sqrt(E)
-
-        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
-
-        if self.mask_flag:
-            if attn_mask is None:
-                attn_mask = TriangularCausalMask(B, L, device=queries.device)
-
-            scores.masked_fill_(attn_mask.mask, -np.inf)
-
-        A = self.dropout(torch.softmax(scale * scores, dim=-1))
-        V = torch.einsum("bhls,bshd->blhd", A, values)
-
-        if self.output_attention:
-            return V.contiguous(), A
-        else:
-            return V.contiguous(), None
-
-
-class TimerAttention(nn.Module):
-    def __init__(self, config: TimerxlConfig, layer_idx: Optional[int] = None):
-        super().__init__()
-        self.layer_idx = layer_idx
-        self.hidden_size = config.hidden_size
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.attention_dropout = config.attention_dropout
-        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
-        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
-        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
-        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
-        self.rotary_emb = TimeMoeRotaryEmbedding(
-            self.head_dim, max_position_embeddings=config.max_position_embeddings
-        )
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional["Cache"] = None,
-    ) -> Tuple[torch.Tensor, Optional["Cache"]]:
-        bsz, q_len, _ = hidden_states.size()
-
-        query_states = self.q_proj(hidden_states)
-        key_states = self.k_proj(hidden_states)
-        value_states = self.v_proj(hidden_states)
-
-        query_states = query_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-        key_states = key_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-        value_states = value_states.view(
-            bsz, q_len, self.num_heads, self.head_dim
-        ).transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin, position_ids
-        )
-
-        if past_key_value is not None:
-            key_states, value_states = past_key_value.update(
-                key_states, value_states, self.layer_idx
-            )
-
-        attn_output = F.scaled_dot_product_attention(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            dropout_p=self.attention_dropout,
-        )
-
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
-        attn_output = self.o_proj(attn_output)
-
-        return attn_output, past_key_value
-
-
-class AttentionLayer(nn.Module):
-    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
-        super(AttentionLayer, self).__init__()
-
-        d_keys = d_keys or (d_model // n_heads)
-        d_values = d_values or (d_model // n_heads)
-
-        self.inner_attention = attention
-        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
-        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
-        self.value_projection = nn.Linear(d_model, d_values * n_heads)
-        self.out_projection = nn.Linear(d_values * n_heads, d_model)
-        self.n_heads = n_heads
-
-    def forward(
-        self,
-        queries,
-        keys,
-        values,
-        attn_mask,
-        n_vars=None,
-        n_tokens=None,
-        tau=None,
-        delta=None,
-    ):
-        B, L, _ = queries.shape
-        _, S, _ = keys.shape
-        H = self.n_heads
-
-        queries = self.query_projection(queries).view(B, L, H, -1)
-        keys = self.key_projection(keys).view(B, S, H, -1)
-        values = self.value_projection(values).view(B, S, H, -1)
-
-        out, attn = self.inner_attention(
-            queries,
-            keys,
-            values,
-            attn_mask,
-            n_vars=n_vars,
-            n_tokens=n_tokens,
-            tau=tau,
-            delta=delta,
-        )
-        out = out.view(B, L, -1)
-
-        return self.out_projection(out), attn
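
The removed `rotate_half`/`apply_rotary_pos_emb` pair follows the half-split rotary convention common in HuggingFace-style attention (contrast with the interleaved pairs in `Attn_Projection`). A short sketch of how `TimerAttention` feeds them, assuming the two functions above are in scope; shapes and base are illustrative:

    import torch

    bsz, n_heads, seq_len, head_dim = 1, 8, 16, 64
    q = torch.randn(bsz, n_heads, seq_len, head_dim)
    k = torch.randn(bsz, n_heads, seq_len, head_dim)

    # Stand-ins for TimeMoeRotaryEmbedding's cos_cached / sin_cached tables.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
    emb = torch.cat((freqs, freqs), dim=-1)  # [seq_len, head_dim]
    cos, sin = emb.cos(), emb.sin()

    position_ids = torch.arange(seq_len).unsqueeze(0)  # [1, seq_len]
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    # unsqueeze_dim=1 broadcasts cos/sin over the head dimension.
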
diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Transformer_EncDec.py b/iotdb-core/ainode/ainode/TimerXL/layers/Transformer_EncDec.py
deleted file mode 100644
index d5bad30ea05..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/layers/Transformer_EncDec.py
+++ /dev/null
@@ -1,329 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-from typing import Optional, Tuple
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from ainode.core.util.activation import ACT2FN
-from ainode.core.util.huggingface_cache import Cache, DynamicCache
-from ainode.TimerXL.layers.SelfAttention_Family import TimerAttention
-from ainode.TimerXL.models.configuration_timer import TimerxlConfig
-
-
-class EncoderLayer(nn.Module):
-    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
-        super(EncoderLayer, self).__init__()
-        d_ff = d_ff or 4 * d_model
-        self.attention = attention
-        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
-        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
-        self.norm1 = nn.LayerNorm(d_model)
-        self.norm2 = nn.LayerNorm(d_model)
-        self.dropout = nn.Dropout(dropout)
-        self.activation = F.relu if activation == "relu" else F.gelu
-
-    def forward(self, x, attn_mask=None, tau=None, delta=None):
-        new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
-        x = x + self.dropout(new_x)
-
-        y = x = self.norm1(x)
-        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
-        y = self.dropout(self.conv2(y).transpose(-1, 1))
-
-        return self.norm2(x + y), attn
-
-
-class DecoderLayer(nn.Module):
-    def __init__(
-        self,
-        self_attention,
-        cross_attention,
-        d_model,
-        d_ff=None,
-        dropout=0.1,
-        activation="relu",
-    ):
-        super(DecoderLayer, self).__init__()
-        d_ff = d_ff or 4 * d_model
-        self.self_attention = self_attention
-        self.cross_attention = cross_attention
-        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
-        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
-        self.norm1 = nn.LayerNorm(d_model)
-        self.norm2 = nn.LayerNorm(d_model)
-        self.norm3 = nn.LayerNorm(d_model)
-        self.dropout = nn.Dropout(dropout)
-        self.activation = F.relu if activation == "relu" else F.gelu
-
-    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
-        x = x + self.dropout(
-            self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
-        )
-        x = self.norm1(x)
-
-        x = x + self.dropout(
-            self.cross_attention(
-                x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
-            )[0]
-        )
-
-        y = x = self.norm2(x)
-        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
-        y = self.dropout(self.conv2(y).transpose(-1, 1))
-
-        return self.norm3(x + y)
-
-
-class DecoderOnlyLayer(nn.Module):
-    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
-        super(DecoderOnlyLayer, self).__init__()
-        d_ff = d_ff or 4 * d_model
-        self.attention = attention
-        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
-        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
-        self.norm1 = nn.LayerNorm(d_model)
-        self.norm2 = nn.LayerNorm(d_model)
-        self.dropout = nn.Dropout(dropout)
-        self.activation = F.relu if activation == "relu" else F.gelu
-
-    def forward(self, x, attn_mask=None, tau=None, delta=None):
-        new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
-        x = x + self.dropout(new_x)
-
-        y = x = self.norm1(x)
-        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
-        y = self.dropout(self.conv2(y).transpose(-1, 1))
-
-        return self.norm2(x + y), attn
-
-
-class TimerLayer(nn.Module):
-    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
-        super(TimerLayer, self).__init__()
-        d_ff = d_ff or 4 * d_model
-        self.attention = attention
-        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
-        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
-        self.norm1 = nn.LayerNorm(d_model)
-        self.norm2 = nn.LayerNorm(d_model)
-        self.dropout = nn.Dropout(dropout)
-        self.activation = F.relu if activation == "relu" else F.gelu
-
-    def forward(self, x, n_vars, n_tokens, attn_mask=None, tau=None, delta=None):
-        new_x, attn = self.attention(
-            x,
-            x,
-            x,
-            n_vars=n_vars,
-            n_tokens=n_tokens,
-            attn_mask=attn_mask,
-            tau=tau,
-            delta=delta,
-        )
-        x = x + self.dropout(new_x)
-
-        y = x = self.norm1(x)
-        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
-        y = self.dropout(self.conv2(y).transpose(-1, 1))
-
-        return self.norm2(x + y), attn
-
-
-class Encoder(nn.Module):
-    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
-        super(Encoder, self).__init__()
-        self.attn_layers = nn.ModuleList(attn_layers)
-        self.conv_layers = (
-            nn.ModuleList(conv_layers) if conv_layers is not None else None
-        )
-        self.norm = norm_layer
-
-    def forward(self, x, attn_mask=None, tau=None, delta=None):
-        # x [B, L, D]
-        attns = []
-        if self.conv_layers is not None:
-            for i, (attn_layer, conv_layer) in enumerate(
-                zip(self.attn_layers, self.conv_layers)
-            ):
-                delta = delta if i == 0 else None
-                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
-                x = conv_layer(x)
-                attns.append(attn)
-            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
-            attns.append(attn)
-        else:
-            for attn_layer in self.attn_layers:
-                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
-                attns.append(attn)
-
-        if self.norm is not None:
-            x = self.norm(x)
-
-        return x, attns
-
-
-class Decoder(nn.Module):
-    def __init__(self, layers, norm_layer=None, projection=None):
-        super(Decoder, self).__init__()
-        self.layers = nn.ModuleList(layers)
-        self.norm = norm_layer
-        self.projection = projection
-
-    def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
-        for layer in self.layers:
-            x = layer(
-                x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta
-            )
-
-        if self.norm is not None:
-            x = self.norm(x)
-
-        if self.projection is not None:
-            x = self.projection(x)
-        return x
-
-
-class DecoderOnly(nn.Module):
-    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
-        super(DecoderOnly, self).__init__()
-        self.attn_layers = nn.ModuleList(attn_layers)
-        self.conv_layers = (
-            nn.ModuleList(conv_layers) if conv_layers is not None else None
-        )
-        self.norm = norm_layer
-
-    def forward(self, x, attn_mask=None, tau=None, delta=None):
-        # x [B, L, D]
-        attns = []
-        if self.conv_layers is not None:
-            for i, (attn_layer, conv_layer) in enumerate(
-                zip(self.attn_layers, self.conv_layers)
-            ):
-                delta = delta if i == 0 else None
-                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
-                x = conv_layer(x)
-                attns.append(attn)
-            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
-            attns.append(attn)
-        else:
-            for attn_layer in self.attn_layers:
-                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
-                attns.append(attn)
-
-        if self.norm is not None:
-            x = self.norm(x)
-
-        return x, attns
-
-
-class TimerBlock(nn.Module):
-    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
-        super(TimerBlock, self).__init__()
-        self.attn_layers = nn.ModuleList(attn_layers)
-        self.conv_layers = (
-            nn.ModuleList(conv_layers) if conv_layers is not None else None
-        )
-        self.norm = norm_layer
-
-    def forward(self, x, n_vars, n_tokens, attn_mask=None, tau=None, delta=None):
-        # x [B, L, D]
-        attns = []
-        if self.conv_layers is not None:
-            for i, (attn_layer, conv_layer) in enumerate(
-                zip(self.attn_layers, self.conv_layers)
-            ):
-                delta = delta if i == 0 else None
-                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
-                x = conv_layer(x)
-                attns.append(attn)
-            x, attn = self.attn_layers[-1](x, n_vars, n_tokens, tau=tau, delta=None)
-            attns.append(attn)
-        else:
-            for attn_layer in self.attn_layers:
-                x, attn = attn_layer(
-                    x, n_vars, n_tokens, attn_mask=attn_mask, tau=tau, delta=delta
-                )
-                attns.append(attn)
-
-        if self.norm is not None:
-            x = self.norm(x)
-
-        return x, attns
-
-
-class TimerMLP(nn.Module):
-    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.intermediate_size = intermediate_size
-        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
-        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
-        self.act_fn = ACT2FN[hidden_act]
-
-    def forward(self, hidden_state):
-        return self.down_proj(
-            self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)
-        )
-
-
-class TimerDecoderLayer(nn.Module):
-    def __init__(self, config: TimerxlConfig, layer_idx: int):
-        super().__init__()
-        self.self_attn = TimerAttention(config, layer_idx)
-
-        self.ffn_layer = TimerMLP(
-            hidden_size=config.hidden_size,
-            intermediate_size=config.intermediate_size,
-            hidden_act=config.hidden_act,
-        )
-        self.norm1 = torch.nn.LayerNorm(config.hidden_size)
-        self.norm2 = torch.nn.LayerNorm(config.hidden_size)
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Cache] = None,
-        use_cache: bool = False,
-    ) -> Tuple[torch.FloatTensor, Optional[Cache]]:
-        residual = hidden_states
-
-        # Self Attention
-        hidden_states, present_key_value = self.self_attn(
-            hidden_states=hidden_states,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_value=past_key_value,
-        )
-
-        hidden_states = residual + hidden_states
-        hidden_states = self.norm1(hidden_states)
-
-        # Fully Connected
-        residual = hidden_states
-        hidden_states = self.ffn_layer(hidden_states)
-        hidden_states = residual + hidden_states
-        hidden_states = self.norm2(hidden_states)
-
-        if not use_cache:
-            present_key_value = None
-        return hidden_states, present_key_value
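
The removed `TimerMLP` is the gated (SwiGLU-style) feed-forward block common in recent decoder stacks: down(act(gate(x)) * up(x)). A self-contained sketch with the activation fixed to SiLU instead of being resolved through `ACT2FN[hidden_act]`; sizes are illustrative:

    import torch
    from torch import nn

    class GatedMLP(nn.Module):
        # Structurally identical to the removed TimerMLP.
        def __init__(self, hidden_size: int, intermediate_size: int):
            super().__init__()
            self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
            self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
            self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
            self.act_fn = nn.SiLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

    x = torch.randn(2, 7, 512)
    print(GatedMLP(512, 2048)(x).shape)  # torch.Size([2, 7, 512])
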
diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/__init__.py b/iotdb-core/ainode/ainode/TimerXL/layers/__init__.py
deleted file mode 100644
index 2a1e720805f..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/layers/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
diff --git a/iotdb-core/ainode/ainode/TimerXL/models/__init__.py b/iotdb-core/ainode/ainode/TimerXL/models/__init__.py
deleted file mode 100644
index 2a1e720805f..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/models/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
diff --git a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
deleted file mode 100644
index b3962a052a5..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
+++ /dev/null
@@ -1,446 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-import os
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
-
-import torch
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file as load_safetensors
-from torch import nn
-
-from ainode.core.log import Logger
-from ainode.core.util.huggingface_cache import Cache, DynamicCache
-from ainode.core.util.masking import prepare_4d_causal_attention_mask
-from ainode.TimerXL.layers.Embed import TimerPatchEmbedding
-from ainode.TimerXL.layers.Transformer_EncDec import TimerDecoderLayer
-from ainode.TimerXL.models.configuration_timer import TimerxlConfig
-
-logger = Logger()
-
-
-@dataclass
-class Output:
-    outputs: torch.Tensor
-    past_key_values: Optional[Any] = None
-
-
-class TimerModel(nn.Module):
-    def __init__(self, config: TimerxlConfig):
-        super().__init__()
-        self.config = config
-        self.embed_layer = TimerPatchEmbedding(config)
-        self.layers = nn.ModuleList(
-            [
-                TimerDecoderLayer(config, layer_idx)
-                for layer_idx in range(config.num_hidden_layers)
-            ]
-        )
-        self.norm = torch.nn.LayerNorm(config.hidden_size)
-        self.gradient_checkpointing = False
-
-    def forward(
-        self,
-        input_ids: torch.FloatTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        use_cache: bool = None,
-    ):
-        # input_ids is the input of time series, its shape is [batch_size, seq_len]
-
-        if input_ids is not None:
-            batch_size, seq_length = input_ids.shape
-        else:
-            raise ValueError(
-                "You have to specify either decoder_input_ids or 
decoder_inputs_embeds"
-            )
-
-        inputs_embeds = self.embed_layer(input_ids)
-
-        seq_length = inputs_embeds.shape[1]
-
-        past_key_values_length = 0
-
-        if use_cache:
-            use_legacy_cache = not isinstance(past_key_values, Cache)
-            if use_legacy_cache:
-                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_usable_length(seq_length)
-
-        if position_ids is None:
-            device = input_ids.device if input_ids is not None else inputs_embeds.device
-            position_ids = torch.arange(
-                past_key_values_length,
-                seq_length + past_key_values_length,
-                dtype=torch.long,
-                device=device,
-            )
-            position_ids = position_ids.view(-1, seq_length)
-        else:
-            position_ids = position_ids.view(-1, seq_length).long()
-
-        # 4d mask is passed through the layers
-        attention_mask = prepare_4d_causal_attention_mask(
-            attention_mask,
-            (batch_size, seq_length),
-            inputs_embeds,
-            past_key_values_length,
-        )
-
-        hidden_states = inputs_embeds
-
-        # decoder layers
-        next_decoder_cache = None
-
-        for decoder_layer in self.layers:
-            layer_outputs = decoder_layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_value=past_key_values,
-                use_cache=use_cache,
-            )
-
-            hidden_states = layer_outputs[0]
-
-            if use_cache:
-                next_decoder_cache = layer_outputs[1]
-
-        hidden_states = self.norm(hidden_states)
-
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache()
-                if use_legacy_cache
-                else next_decoder_cache
-            )
-
-        return Output(outputs=hidden_states, past_key_values=next_cache)
-
-
-class TimerForPrediction(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.config = config
-        self.model = TimerModel(self.config)
-        lm_head_list = []
-        self.output_token_len_map = {}
-        for i, output_token_len in enumerate(self.config.output_token_lens):
-            lm_head_list.append(
-                nn.Linear(self.config.hidden_size, output_token_len, bias=False)
-            )
-            self.output_token_len_map[output_token_len] = i
-        self.lm_heads = nn.ModuleList(lm_head_list)
-        self.loss_function = torch.nn.MSELoss(reduction="none")
-
-    def forward(
-        self,
-        input_ids: torch.FloatTensor = None,
-        attention_mask: Optional[torch.Tensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[List[torch.FloatTensor]] = None,
-        use_cache: Optional[bool] = None,
-        max_output_length: Optional[int] = None,
-        revin: Optional[bool] = True,
-    ):
-        if revin:
-            means, stdev = input_ids.mean(dim=-1, keepdim=True), input_ids.std(
-                dim=-1, keepdim=True
-            )
-            input_ids = (input_ids - means) / stdev
-
-        outputs = self.model(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            position_ids=position_ids,
-            past_key_values=past_key_values,
-            use_cache=use_cache,
-        )
-        hidden_states = outputs.outputs
-
-        if max_output_length is None:
-            output_token_len = self.config.output_token_lens[0]
-            max_output_length = output_token_len
-        else:
-            output_token_len = self.config.output_token_lens[0]
-            for h in self.config.output_token_lens[1:]:
-                if h > max_output_length:
-                    break
-                else:
-                    output_token_len = h
-
-        lm_head = self.lm_heads[self.output_token_len_map[output_token_len]]
-        predictions = lm_head(hidden_states)[:, -1, :]
-
-        if output_token_len > max_output_length:
-            predictions = predictions[:, :max_output_length]
-        if revin:
-            predictions = predictions * stdev + means
-
-        return Output(predictions, outputs.past_key_values)
-
-
-class Model(nn.Module):
-    """
-    Timer-XL: Long-Context Transformers for Unified Time Series Forecasting
-
-    Paper: https://arxiv.org/abs/2410.04803
-
-    GitHub: https://github.com/thuml/Timer-XL
-
-    Citation: @article{liu2024timer,
-        title={Timer-XL: Long-Context Transformers for Unified Time Series Forecasting},
-        author={Liu, Yong and Qin, Guo and Huang, Xiangdong and Wang, Jianmin and Long, Mingsheng},
-        journal={arXiv preprint arXiv:2410.04803},
-        year={2024}
-    }
-    """
-
-    def __init__(self, config: TimerxlConfig):
-        super().__init__()
-        self.config = config  # can't be scripted by torch
-
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model = TimerForPrediction(config).to(self.device)
-
-        if config.ckpt_path is not None and config.ckpt_path != "":
-            if config.ckpt_path.endswith(".pt") or 
config.ckpt_path.endswith(".pth"):
-                state_dict = torch.load(config.ckpt_path)
-            elif config.ckpt_path.endswith(".safetensors"):
-                if not os.path.exists(config.ckpt_path):
-                    logger.info(
-                        f"Checkpoint not found at {config.ckpt_path}, 
downloading from HuggingFace..."
-                    )
-                    repo_id = "thuml/timer-base-84m"
-                    try:
-                        config.ckpt_path = hf_hub_download(
-                            repo_id=repo_id,
-                            filename=os.path.basename(config.ckpt_path),
-                            local_dir=os.path.dirname(config.ckpt_path),
-                        )
-                        logger.info(f"Got checkpoint to {config.ckpt_path}")
-                    except Exception as e:
-                        logger.error(
-                            f"Failed to download checkpoint to 
{config.ckpt_path} due to {e}"
-                        )
-                        raise e
-                state_dict = load_safetensors(config.ckpt_path)
-            else:
-                raise ValueError("unsupported model weight type")
-            # If there is no key beginning with 'model.model' in state_dict, add a 'model.' before all keys. (The model code here has an additional layer of encapsulation compared to the code on huggingface.)
-            if not any(k.startswith("model.model") for k in state_dict.keys()):
-                state_dict = {"model." + k: v for k, v in state_dict.items()}
-            self.load_state_dict(state_dict, strict=True)
-
-    def set_device(self, device):
-        self.model.to(device)
-        self.device = next(self.model.parameters()).device
-
-    def inference(self, x, max_new_tokens: int = 96):
-        # x.shape: [L, C], type: DataFrame
-        # here we only expect C=1 temporarily
-        # change [L, C=1] to [batchsize=1, L]
-        self.device = next(self.model.parameters()).device
-
-        x = torch.tensor(
-            x, dtype=next(self.model.parameters()).dtype, device=self.device
-        )
-        x = x.view(1, -1)
-
-        preds = self.forward(x, max_new_tokens)
-        preds = preds.detach().cpu().numpy()
-
-        return preds
-
-    def forward(self, x, max_new_tokens: int = 96):
-        # self.config.is_encoder_decoder = False
-        self.eval()
-        self.device = next(self.model.parameters()).device
-
-        if len(x.shape) == 2:
-            batch_size, cur_len = x.shape
-            if cur_len < self.config.input_token_len:
-                raise ValueError(
-                    f"Input length must be at least 
{self.config.input_token_len}"
-                )
-            elif cur_len % self.config.input_token_len != 0:
-                new_len = (
-                    cur_len // self.config.input_token_len
-                ) * self.config.input_token_len
-                x = x[:, -new_len:]
-        else:
-            raise ValueError("Input shape must be: [batch_size, seq_len]")
-
-        use_cache = self.config.use_cache
-        all_input_ids = x
-
-        attention_mask = self.prepare_attention_mask_for_generation(all_input_ids)
-        all_input_ids_length = all_input_ids.shape[-1]
-        max_length = max_new_tokens + all_input_ids_length
-
-        all_input_ids = all_input_ids.to(self.device)
-        batch_size, cur_len = all_input_ids.shape
-
-        unfinished_sequences = torch.ones(
-            batch_size, dtype=torch.long, device=all_input_ids.device
-        )
-        cache_position = torch.arange(cur_len, device=all_input_ids.device)
-        true_seq_len = cur_len // self.config.input_token_len
-        attention_mask = attention_mask[:, -true_seq_len:]
-
-        this_peer_finished = False
-        past_key_values = None
-        position_ids = None
-        while not this_peer_finished:
-            (input_ids, position_ids, past_key_values, attention_mask, revin) = (
-                self.prepare_inputs_for_generation(
-                    all_input_ids,
-                    past_key_values=past_key_values,
-                    attention_mask=attention_mask,
-                    # position_ids=position_ids     # Wrong?!
-                    position_ids=None,  # True?!    based on huggingface code
-                )
-            )
-
-            input_length = all_input_ids.shape[1]
-
-            # forward pass to get next token
-            outputs = self.model(
-                input_ids,
-                attention_mask=attention_mask,
-                position_ids=position_ids,
-                past_key_values=past_key_values,
-                use_cache=use_cache,
-                max_output_length=max_length - input_length,
-                revin=revin,
-            )
-
-            next_tokens = outputs.outputs
-
-            # update generated ids, model inputs, and length for next step
-            horizon_length = next_tokens.shape[1] // self.config.input_token_len
-
-            all_input_ids = torch.cat([all_input_ids, next_tokens], dim=-1)
-            (past_key_values, attention_mask, cache_position) = (
-                self._update_model_kwargs_for_generation(
-                    outputs,
-                    attention_mask=attention_mask,
-                    horizon_length=horizon_length,
-                    cache_position=cache_position,
-                )
-            )
-
-            unfinished_sequences = unfinished_sequences & (
-                all_input_ids.shape[1] < max_length
-            )
-            this_peer_finished = unfinished_sequences.max() == 0
-
-        if all_input_ids.shape[1] > max_length:
-            all_input_ids = all_input_ids[:, :max_length]
-
-        return all_input_ids[:, -(max_length - cur_len) :]
-
-    def prepare_attention_mask_for_generation(
-        self,
-        inputs: torch.Tensor,
-    ) -> torch.LongTensor:
-        return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
-
-    def prepare_inputs_for_generation(
-        self,
-        input_ids,
-        past_key_values=None,
-        attention_mask=None,
-        revin=True,
-        position_ids=None,
-    ):
-        # Omit tokens covered by past_key_values
-        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                if isinstance(past_key_values, DynamicCache):
-                    past_length = past_key_values.seen_tokens
-                else:
-                    past_length = cache_length
-
-                max_cache_length = past_key_values.get_max_length()
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
-            if attention_mask is not None and attention_mask.shape[1] > (
-                input_ids.shape[1] // self.config.input_token_len
-            ):
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < (input_ids.shape[1] // self.config.input_token_len):
-                input_ids = input_ids[:, past_length * self.config.input_token_len :]
-            # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + (input_ids.shape[1] // self.config.input_token_len)
-                > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
-
-        if attention_mask is not None and position_ids is None:
-            # create position_ids on the fly for batch generation
-            position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
-            if past_key_values:
-                position_ids = position_ids[
-                    :, -(input_ids.shape[1] // self.config.input_token_len) :
-                ]
-
-        return (input_ids, position_ids, past_key_values, attention_mask, revin)
-
-    def _update_model_kwargs_for_generation(
-        self,
-        outputs,
-        attention_mask=None,
-        cache_position=None,
-        horizon_length: int = 1,
-    ) -> Dict[str, Any]:
-        # update past_key_values
-        past_key_values = outputs.past_key_values
-
-        # update attention mask
-        if attention_mask is not None:
-            attention_mask = torch.cat(
-                [
-                    attention_mask,
-                    attention_mask.new_ones((attention_mask.shape[0], horizon_length)),
-                ],
-                dim=-1,
-            )
-
-        if cache_position is not None:
-            cache_position = cache_position[-1:] + horizon_length
-
-        return (past_key_values, attention_mask, cache_position)
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 175961e7313..8d2b9a83d3d 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -57,7 +57,8 @@ class TimerXLStrategy(InferenceStrategy):
         data = full_data[1][0]
         if data.dtype.byteorder not in ("=", "|"):
             data = data.byteswap().newbyteorder()
-        output = self.model.inference(data, int(predict_length))
+        seqs = torch.tensor(data).unsqueeze(0).float()
+        output = self.model.generate(seqs, max_new_tokens=predict_length)
         df = pd.DataFrame(output[0])
         return convert_to_binary(df)
 
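For reference, the new prediction path can be exercised on its own; a minimal
sketch (the randomly initialized model, the series length, and the 96-step
horizon are illustrative assumptions, not part of this change):

    import numpy as np
    import torch
    from ainode.core.model.timerxl.configuration_timer import TimerConfig
    from ainode.core.model.timerxl.modeling_timer import TimerForPrediction

    model = TimerForPrediction(TimerConfig())        # untrained weights; API demo only
    data = np.random.rand(672).astype(np.float32)    # [L] univariate series
    seqs = torch.tensor(data).unsqueeze(0).float()   # [batch_size=1, L]
    preds = model.generate(seqs, max_new_tokens=96)  # [1, 96] forecast tensor
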
diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
index 6298fb6a1db..d257d50ce74 100644
--- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -39,10 +39,10 @@ from ainode.core.exception import (
     WrongAttributeTypeError,
 )
 from ainode.core.log import Logger
+from ainode.core.model.timerxl import modeling_timer
+from ainode.core.model.timerxl.configuration_timer import TimerConfig
 from ainode.core.model.sundial import modeling_sundial
 from ainode.core.model.sundial.configuration_sundial import SundialConfig
-from ainode.TimerXL.models import timer_xl
-from ainode.TimerXL.models.configuration_timer import TimerxlConfig
 
 logger = Logger()
 
@@ -113,7 +113,9 @@ def fetch_built_in_model(model_id, inference_attributes):
     elif model_id == BuiltInModelType.STRAY.value:
         model = STRAYModel(attributes)
     elif model_id == BuiltInModelType.TIMER_XL.value:
-        model = timer_xl.Model(TimerxlConfig.from_dict(attributes))
+        model = modeling_timer.TimerForPrediction(
+            TimerConfig.from_dict(attributes)
+        )
     elif model_id == BuiltInModelType.SUNDIAL.value:
         model = modeling_sundial.SundialForPrediction(
             SundialConfig.from_dict(attributes)
diff --git a/iotdb-core/ainode/ainode/TimerXL/__init__.py b/iotdb-core/ainode/ainode/core/model/timerxl/__init__.py
similarity index 99%
rename from iotdb-core/ainode/ainode/TimerXL/__init__.py
rename to iotdb-core/ainode/ainode/core/model/timerxl/__init__.py
index 2a1e720805f..4b8ee97fad2 100644
--- a/iotdb-core/ainode/ainode/TimerXL/__init__.py
+++ b/iotdb-core/ainode/ainode/core/model/timerxl/__init__.py
@@ -14,4 +14,4 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-#
+#
\ No newline at end of file
diff --git a/iotdb-core/ainode/ainode/TimerXL/models/configuration_timer.py b/iotdb-core/ainode/ainode/core/model/timerxl/configuration_timer.py
similarity index 68%
rename from iotdb-core/ainode/ainode/TimerXL/models/configuration_timer.py
rename to iotdb-core/ainode/ainode/core/model/timerxl/configuration_timer.py
index ac5034aa85e..ccc0c9d6a13 100644
--- a/iotdb-core/ainode/ainode/TimerXL/models/configuration_timer.py
+++ b/iotdb-core/ainode/ainode/core/model/timerxl/configuration_timer.py
@@ -15,27 +15,29 @@
 # specific language governing permissions and limitations
 # under the License.
 #
+
 from typing import List
+from transformers import PretrainedConfig
 
 
-class TimerxlConfig:
-    model_type = "timerxl"
+class TimerConfig(PretrainedConfig):
+    model_type = "timer"
+    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
         self,
-        input_token_len: int = 96,  # how many points as a token, don't change
-        hidden_size: int = 1024,  # model hidden size
-        intermediate_size: int = 2048,  # ffn middle size
-        output_token_lens: List[int] = [96],  # how many points as a token, don't change
+        input_token_len: int = 1,
+        hidden_size: int = 1024,
+        intermediate_size: int = 2048,
+        output_token_lens: List[int] = [1, 8, 32, 64],
         num_hidden_layers: int = 8,
         num_attention_heads: int = 8,
-        hidden_act: str = "silu",  # activation function
-        use_cache: bool = True,  # kv cache
-        rope_theta: int = 10000,  # ROBE parameter
+        hidden_act: str = "silu",
+        use_cache: bool = True,
+        rope_theta: int = 10000,
         attention_dropout: float = 0.0,
-        initializer_range: float = 0.02,  # be of no use, because we already have weights
+        initializer_range: float = 0.02,
         max_position_embeddings: int = 10000,
-        ckpt_path: str = None,  # weight path
         **kwargs,
     ):
         self.input_token_len = input_token_len
@@ -50,12 +52,7 @@ class TimerxlConfig:
         self.attention_dropout = attention_dropout
         self.initializer_range = initializer_range
         self.max_position_embeddings = max_position_embeddings
-        self.ckpt_path = ckpt_path
 
         super().__init__(
             **kwargs,
-        )
-
-    @classmethod
-    def from_dict(cls, config_dict: dict) -> "TimerxlConfig":
-        return cls(**config_dict)
+        )
\ No newline at end of file
diff --git a/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py b/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py
new file mode 100644
index 00000000000..584d841d641
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py
@@ -0,0 +1,591 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+from typing import Optional, Tuple, List, Union
+import torch
+from torch import nn
+import torch.nn.functional as F
+from transformers import PreTrainedModel, Cache, DynamicCache
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast
+from .configuration_timer import TimerConfig
+from .ts_generation_mixin import TSGenerationMixin
+
+
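+# Rotary position embedding (RoPE) helpers: rotate_half negates and swaps the two
+# halves of the last dimension so that (q * cos) + (rotate_half(q) * sin) applies
+# the rotation in real arithmetic.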
+def rotate_half(x):
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
+
+
+class TimerPatchEmbedding(nn.Module):
+    def __init__(self, config: TimerConfig):
+        super().__init__()
+        self.input_token_len = config.input_token_len
+        self.emb = nn.Linear(config.input_token_len,
+                             config.hidden_size, bias=False)
+
+    def forward(self, hidden_state: torch.Tensor):
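+        # unfold splits the series into non-overlapping patches of input_token_len
+        # points; each patch is embedded as one token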
+        hidden_state = hidden_state.unfold(
+            dimension=-1, size=self.input_token_len, step=self.input_token_len)
+        return self.emb(hidden_state)
+
+
+class TimerPointEmbedding(nn.Module):
+    def __init__(self, config: TimerConfig):
+        super().__init__()
+        self.emb_layer = nn.Linear(
+            config.input_token_len, config.hidden_size, bias=False)
+        self.gate_layer = nn.Linear(
+            config.input_token_len, config.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
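+        # SwiGLU-style gating: the activated gate branch modulates the linear embedding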
+        emb = self.act_fn(self.gate_layer(x)) * self.emb_layer(x)
+        return emb
+
+
+class TimeMoeRotaryEmbedding(torch.nn.Module):
+    def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
+        super().__init__()
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(
+            0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device,
+                         dtype=torch.int64).type_as(self.inv_freq)
+
+        freqs = torch.outer(t, self.inv_freq)
+        # Different from paper, but it uses a different permutation in order to obtain the same calculation
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer(
+            "cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer(
+            "sin_cached", emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(
+                seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype),
+            self.sin_cached[:seq_len].to(dtype=x.dtype),
+        )
+
+
+class TimerAttention(nn.Module):
+    def __init__(self, config: TimerConfig, layer_idx: Optional[int] = None):
+        super().__init__()
+        self.layer_idx = layer_idx
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.attention_dropout = config.attention_dropout
+        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+        self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+        self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+        self.rotary_emb = TimeMoeRotaryEmbedding(
+            self.head_dim, max_position_embeddings=config.max_position_embeddings, base=config.rope_theta)
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Cache] = None,
+            output_attentions: bool = False,
+            **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(
+            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(
+            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(
+            bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value.get_usable_length(
+                kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            key_states, value_states = past_key_value.update(
+                key_states, value_states, self.layer_idx)
+
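+        # the 4D causal mask prepared by the decoder is passed directly to SDPA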
+        attn_output = F.scaled_dot_product_attention(
+            query_states, key_states, value_states, attention_mask,
+            dropout_p=self.attention_dropout if self.training else 0.0)
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+
+        # F.scaled_dot_product_attention does not expose attention weights
+        attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
+class TimerMLP(nn.Module):
+    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.gate_proj = nn.Linear(
+            self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(
+            self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(
+            self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[hidden_act]
+
+    def forward(self, hidden_state):
+        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
+class TimerDecoderLayer(nn.Module):
+    def __init__(self, config: TimerConfig, layer_idx: int):
+        super().__init__()
+        self.self_attn = TimerAttention(config, layer_idx)
+
+        self.ffn_layer = TimerMLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+        )
+        self.norm1 = torch.nn.LayerNorm(config.hidden_size)
+        self.norm2 = torch.nn.LayerNorm(config.hidden_size)
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_value: Optional[Tuple[torch.Tensor]] = None,
+            output_attentions: Optional[bool] = False,
+            use_cache: Optional[bool] = False,
+            **kwargs,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]:
+        residual = hidden_states
+
+        # Self Attention
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            output_attentions=output_attentions,
+            use_cache=use_cache,
+        )
+        hidden_states = residual + hidden_states
+        hidden_states = self.norm1(hidden_states)
+
+        # Fully Connected
+        residual = hidden_states
+        hidden_states = self.ffn_layer(hidden_states)
+        hidden_states = residual + hidden_states
+        hidden_states = self.norm2(hidden_states)
+
+        if not output_attentions:
+            self_attn_weights = None
+
+        if not use_cache:
+            present_key_value = None
+        return hidden_states, self_attn_weights, present_key_value
+
+
+class TimerPreTrainedModel(PreTrainedModel):
+    config_class = TimerConfig
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["TimeMoeDecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn_2 = True
+    _supports_sdpa = False
+    _supports_cache_class = True
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, torch.nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, torch.nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+
+
+class TimerModel(TimerPreTrainedModel):
+    def __init__(self, config: TimerConfig):
+        super().__init__(config)
+        self.embed_layer = TimerPatchEmbedding(config)
+        self.layers = nn.ModuleList(
+            [TimerDecoderLayer(config, layer_idx)
+             for layer_idx in range(config.num_hidden_layers)]
+        )
+        self.norm = torch.nn.LayerNorm(config.hidden_size)
+        self.gradient_checkpointing = False
+
+    def forward(
+            self,
+            input_ids: torch.FloatTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, MoeModelOutputWithPast]:
+        # input_ids is the raw time series input; its shape is [batch_size, seq_len]
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError(
+                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError(
+                "You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_layer(input_ids)
+            seq_length = inputs_embeds.shape[1]
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                use_cache = False
+
+        past_key_values_length = 0
+
+        if use_cache:
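+            # accept legacy tuple-style caches and convert them to the Cache API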
+            use_legacy_cache = not isinstance(past_key_values, Cache)
+            if use_legacy_cache:
+                past_key_values = DynamicCache.from_legacy_cache(
+                    past_key_values)
+            past_key_values_length = past_key_values.get_usable_length(
+                seq_length)
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            # position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+            position_ids = position_ids.view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        # 4d mask is passed through the layers
+        attention_mask = _prepare_4d_causal_attention_mask(
+            attention_mask,
+            (batch_size, seq_length),
+            inputs_embeds,
+            past_key_values_length,
+            sliding_window=None,
+        )
+
+        hidden_states = inputs_embeds
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = None
+
+        for decoder_layer in self.layers:
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
+            if self.gradient_checkpointing and self.training:
+                layer_outputs = self._gradient_checkpointing_func(
+                    decoder_layer.__call__,
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    past_key_values,
+                    output_attentions,
+                    use_cache,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+            if use_cache:
+                next_decoder_cache = layer_outputs[2]
+
+        hidden_states = self.norm(hidden_states)
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = None
+        if use_cache:
+            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+                if v is not None
+            )
+        return MoeModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin):
+    def __init__(self, config: TimerConfig):
+        super().__init__(config)
+        self.config = config
+        self.model = TimerModel(self.config)
+        lm_head_list = []
+        self.output_token_len_map = {}
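+        # one linear prediction head per supported output horizon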
+        for i, output_token_len in enumerate(self.config.output_token_lens):
+            lm_head_list.append(
+                nn.Linear(self.config.hidden_size, output_token_len, bias=False))
+            self.output_token_len_map[output_token_len] = i
+        self.lm_heads = nn.ModuleList(lm_head_list)
+        self.loss_function = torch.nn.MSELoss(reduction='none')
+        self.post_init()
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
+    def forward(
+            self,
+            input_ids: torch.FloatTensor = None,
+            attention_mask: Optional[torch.Tensor] = None,
+            position_ids: Optional[torch.LongTensor] = None,
+            past_key_values: Optional[List[torch.FloatTensor]] = None,
+            inputs_embeds: Optional[torch.FloatTensor] = None,
+            labels: Optional[torch.FloatTensor] = None,
+            loss_masks: Optional[torch.FloatTensor] = None,
+            use_cache: Optional[bool] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            return_dict: Optional[bool] = None,
+            max_output_length: Optional[int] = None,
+            revin: Optional[bool] = False,
+    ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
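+        # RevIN-style instance normalization: standardize each input series here,
+        # then de-standardize the predictions before returning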
+        if revin:
+            mean, std = input_ids.mean(dim=-1, keepdim=True), input_ids.std(dim=-1, keepdim=True)
+            input_ids = (input_ids - mean) / std
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
+        predictions = None
+
+        loss = None
+        if labels is not None:
+            ar_loss = 0.0
+            for lm_head, output_token_len in zip(self.lm_heads, self.config.output_token_lens):
+                one_predictions = lm_head(hidden_states)
+                one_loss = self.calc_ar_loss(
+                    one_predictions, labels, loss_masks, output_token_len)
+                ar_loss += one_loss
+                if predictions is None:
+                    predictions = one_predictions
+            loss = ar_loss / len(self.config.output_token_lens)
+        else:
+            if max_output_length is None:
+                output_token_len = self.config.output_token_lens[0]
+                max_output_length = output_token_len
+            else:
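+                # keep the largest configured horizon that still fits max_output_length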
+                output_token_len = self.config.output_token_lens[0]
+                for h in self.config.output_token_lens[1:]:
+                    if h > max_output_length:
+                        break
+                    else:
+                        output_token_len = h
+            lm_head = self.lm_heads[self.output_token_len_map[output_token_len]]
+            predictions = lm_head(hidden_states)[:, -1, :]
+            if output_token_len > max_output_length:
+                predictions = predictions[:, :max_output_length]
+            if revin:
+                predictions = predictions * std + mean
+        if not return_dict:
+            output = (predictions,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return MoeCausalLMOutputWithPast(
+            loss=loss,
+            logits=predictions,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def calc_ar_loss(self, predictions, labels, loss_masks, output_token_len):
+        seq_len = predictions.shape[1] * self.config.input_token_len
+        labels = labels[:, :seq_len -
+                            self.config.input_token_len + output_token_len]
+        shift_labels = labels.unfold(
+            dimension=-1, size=output_token_len, step=self.config.input_token_len)
+
+        # Calculate loss with mask
+        losses = self.loss_function(predictions, shift_labels).mean(dim=-1)
+        if loss_masks is not None:
+            losses = losses * loss_masks
+            loss = losses.sum() / loss_masks.sum()
+        else:
+            loss = torch.mean(losses)
+
+        return loss
+
+    def prepare_inputs_for_generation(
+            self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, revin=True, **kwargs
+    ):
+        # Omit tokens covered by past_key_values
+        if past_key_values is not None:
+            if isinstance(past_key_values, Cache):
+                cache_length = past_key_values.get_seq_length()
+                if isinstance(past_key_values, DynamicCache):
+                    past_length = past_key_values.seen_tokens
+                else:
+                    past_length = cache_length
+
+                max_cache_length = past_key_values.get_max_length()
+            else:
+                cache_length = past_length = past_key_values[0][0].shape[2]
+                max_cache_length = None
+
+            # Keep only the unprocessed tokens:
+            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
+            # input)
+            if attention_mask is not None and attention_mask.shape[1] > (
+                    input_ids.shape[1] // self.config.input_token_len):
+                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+            # input_ids based on the past_length.
+            elif past_length < (input_ids.shape[1] // self.config.input_token_len):
+                input_ids = input_ids[:, past_length * self.config.input_token_len:]
+            # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
+
+            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+            if (
+                    max_cache_length is not None
+                    and attention_mask is not None
+                    and cache_length + (input_ids.shape[1] // self.config.input_token_len) > max_cache_length
+            ):
+                attention_mask = attention_mask[:, -max_cache_length:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -(input_ids.shape[1] // self.config.input_token_len):]
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+                "revin": revin
+            }
+        )
+        return model_inputs
\ No newline at end of file
diff --git a/iotdb-core/ainode/ainode/core/model/timerxl/ts_generation_mixin.py b/iotdb-core/ainode/ainode/core/model/timerxl/ts_generation_mixin.py
new file mode 100644
index 00000000000..07f16bd9a04
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/timerxl/ts_generation_mixin.py
@@ -0,0 +1,297 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import warnings
+from typing import Any, Dict, List, Optional, Union, Callable
+import torch
+from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
+from transformers.generation import validate_stopping_criteria, EosTokenCriteria
+from transformers.generation.utils import GenerateNonBeamOutput, GenerateEncoderDecoderOutput, \
+    GenerateDecoderOnlyOutput, GenerationConfig, GenerateOutput
+from transformers.utils import ModelOutput
+
+
+class TSGenerationMixin(GenerationMixin):
+
+    @torch.no_grad()
+    def generate(
+            self,
+            inputs: Optional[torch.Tensor] = None,
+            generation_config: Optional[GenerationConfig] = None,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
+            synced_gpus: Optional[bool] = None,
+            assistant_model: Optional["PreTrainedModel"] = None,
+            streamer: Optional["BaseStreamer"] = None,
+            negative_prompt_ids: Optional[torch.Tensor] = None,
+            negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+            **kwargs,
+    ) -> Union[GenerateOutput, torch.LongTensor]:
+        if len(inputs.shape) == 2:
+            batch_size, cur_len = inputs.shape
+            if cur_len < self.config.input_token_len:
+                raise ValueError(
+                    f"Input length must be at least 
{self.config.input_token_len}")
+            elif cur_len % self.config.input_token_len != 0:
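+                # truncate from the left so the context is a whole number of input tokens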
+                new_len = (cur_len // self.config.input_token_len) * \
+                          self.config.input_token_len
+                inputs = inputs[:, -new_len:]
+        else:
+            raise ValueError('Input shape must be: [batch_size, seq_len]')
+        return super().generate(inputs=inputs, generation_config=generation_config, logits_processor=logits_processor,
+                                stopping_criteria=stopping_criteria, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+                                synced_gpus=synced_gpus, assistant_model=assistant_model, streamer=streamer,
+                                negative_prompt_ids=negative_prompt_ids,
+                                negative_prompt_attention_mask=negative_prompt_attention_mask, **kwargs)
+
+    def _greedy_search(
+            self,
+            input_ids: torch.Tensor,
+            logits_processor: Optional[LogitsProcessorList] = None,
+            stopping_criteria: Optional[StoppingCriteriaList] = None,
+            max_length: Optional[int] = None,
+            pad_token_id: Optional[int] = None,
+            eos_token_id: Optional[Union[int, List[int]]] = None,
+            output_attentions: Optional[bool] = None,
+            output_hidden_states: Optional[bool] = None,
+            output_scores: Optional[bool] = None,
+            output_logits: Optional[bool] = None,
+            return_dict_in_generate: Optional[bool] = None,
+            synced_gpus: bool = False,
+            streamer: Optional["BaseStreamer"] = None,
+            **model_kwargs,
+    ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
+        input_ids = input_ids.to(self.device)
+        batch_size, cur_len = input_ids.shape
+        # init values
+        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+        if max_length is not None:
+            warnings.warn(
+                "`max_length` is deprecated in this function, use"
+                " `stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])` instead.",
+                UserWarning,
+            )
+            stopping_criteria = validate_stopping_criteria(
+                stopping_criteria, max_length)
+        pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+        if eos_token_id is not None:
+            stopping_criteria.append(
+                EosTokenCriteria(eos_token_id=eos_token_id))
+        else:
+            # remove when the method is totally private
+            # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+            eos_token_id = [
+                criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id")
+            ]
+            eos_token_id = eos_token_id[0] if eos_token_id else None
+            if eos_token_id is None and self.generation_config.eos_token_id is not None:
+                eos_token_id = self.generation_config.eos_token_id
+                stopping_criteria.append(
+                    EosTokenCriteria(eos_token_id=eos_token_id))
+
+        if isinstance(eos_token_id, int):
+            eos_token_id = [eos_token_id]
+        output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+        output_attentions = (
+            output_attentions if output_attentions is not None else self.generation_config.output_attentions
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
+        )
+        return_dict_in_generate = (
+            return_dict_in_generate
+            if return_dict_in_generate is not None
+            else self.generation_config.return_dict_in_generate
+        )
+
+        # init attention / hidden states / scores tuples
+        raw_logits = () if (return_dict_in_generate and output_logits) else None
+        scores = () if (return_dict_in_generate and output_scores) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = () if (
+                return_dict_in_generate and output_hidden_states) else None
+
+        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = model_kwargs["encoder_outputs"].get(
+                "attentions") if output_attentions else None
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get(
+                    "hidden_states") if output_hidden_states else None
+            )
+
+        # keep track of which sequences are already finished
+        if "inputs_embeds" in model_kwargs:
+            cur_len = model_kwargs["inputs_embeds"].shape[1]
+        this_peer_finished = False
+        unfinished_sequences = torch.ones(
+            batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs["cache_position"] = torch.arange(
+            cur_len, device=input_ids.device)
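+        # the attention mask is indexed per input token (patch), not per raw time point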
+        true_seq_len = cur_len // self.config.input_token_len
+        model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, 
-true_seq_len:]
+        max_length = stopping_criteria.max_length
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, 
device=input_ids.device):
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(
+                input_ids, **model_kwargs)
+
+            input_length = input_ids.shape[1]
+
+            # forward pass to get next token
+            outputs = self(
+                **model_inputs,
+                return_dict=True,
+                output_attentions=output_attentions,
+                output_hidden_states=output_hidden_states,
+                max_output_length=max_length - input_length,
+            )
+
+            if synced_gpus and this_peer_finished:
+                continue  # don't waste resources running the code we don't need
+
+            next_token_logits = outputs.logits
+
+            # pre-process distribution
+            next_tokens_scores = logits_processor(input_ids, next_token_logits)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_tokens_scores,)
+                if output_logits:
+                    raw_logits += (next_token_logits,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,) if 
self.config.is_encoder_decoder else (
+                            outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # no argmax here: time series generation is continuous, so the regression
+            # output itself is the next block of values
+            # next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+            next_tokens = next_tokens_scores
+
+            # finished sequences should have their next token be a padding token
+            if eos_token_id is not None:
+                if pad_token_id is None:
+                    raise ValueError(
+                        "If `eos_token_id` is defined, make sure that 
`pad_token_id` is defined.")
+                next_tokens = next_tokens * unfinished_sequences + \
+                              pad_token_id * (1 - unfinished_sequences)
+
+            # update generated ids, model inputs, and length for next step
+            horizon_length = next_tokens.shape[1] // self.config.input_token_len
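+            # horizon_length: how many input tokens the newly generated block spans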
+
+            input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                horizon_length=horizon_length,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(
+                input_ids, scores)
+            this_peer_finished = unfinished_sequences.max() == 0
+
+        if input_ids.shape[1] > max_length:
+            input_ids = input_ids[:, :max_length]
+
+        if streamer is not None:
+            streamer.end()
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                return GenerateEncoderDecoderOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+            else:
+                return GenerateDecoderOnlyOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+        else:
+            return input_ids[:, -(max_length - cur_len):]
+
+    def _update_model_kwargs_for_generation(
+            self,
+            outputs: ModelOutput,
+            model_kwargs: Dict[str, Any],
+            horizon_length: int = 1,
+            is_encoder_decoder: bool = False,
+            standardize_cache_format: bool = False,
+    ) -> Dict[str, Any]:
+        # update past_key_values
+        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )
+        if getattr(outputs, "state", None) is not None:
+            model_kwargs["state"] = outputs.state
+
+        # update token_type_ids with last value
+        if "token_type_ids" in model_kwargs:
+            token_type_ids = model_kwargs["token_type_ids"]
+            model_kwargs["token_type_ids"] = torch.cat(
+                [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
+
+        if not is_encoder_decoder:
+            # update attention mask
+            if "attention_mask" in model_kwargs:
+                attention_mask = model_kwargs["attention_mask"]
+                model_kwargs["attention_mask"] = torch.cat(
+                    [attention_mask, attention_mask.new_ones((attention_mask.shape[0], horizon_length))], dim=-1
+                )
+        else:
+            # update decoder attention mask
+            if "decoder_attention_mask" in model_kwargs:
+                decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+                model_kwargs["decoder_attention_mask"] = torch.cat(
+                    [decoder_attention_mask, decoder_attention_mask.new_ones(
+                        (decoder_attention_mask.shape[0], horizon_length))],
+                    dim=-1,
+                )
+
+        if "cache_position" in model_kwargs and model_kwargs["cache_position"] 
is not None:
+            model_kwargs["cache_position"] = 
model_kwargs["cache_position"][-1:] + horizon_length
+
+        return model_kwargs
\ No newline at end of file
diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml
index e8cacb42a9f..5b0d02f466e 100644
--- a/iotdb-core/ainode/pyproject.toml
+++ b/iotdb-core/ainode/pyproject.toml
@@ -21,7 +21,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "apache-iotdb-ainode"
-version = "2.0.4.dev"
+version = "2.0.5.dev"
 description = "Apache IoTDB AINode"
 readme = "README.md"
 authors = ["Apache Software Foundation <[email protected]>"]
