This is an automated email from the ASF dual-hosted git repository.
yongzao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new 711da9799d3 [AINode] Refactor the built-in TimerXL (#15655)
711da9799d3 is described below
commit 711da9799d31abe1b08429e452f01ea8e8bd4d90
Author: Yongzao <[email protected]>
AuthorDate: Fri Jun 6 20:42:57 2025 +0800
[AINode] Refactor the built-in TimerXL (#15655)
---
.../ainode/ainode/TimerXL/layers/Attn_Bias.py | 108 ----
.../ainode/TimerXL/layers/Attn_Projection.py | 127 ----
iotdb-core/ainode/ainode/TimerXL/layers/Embed.py | 290 ---------
.../ainode/TimerXL/layers/SelfAttention_Family.py | 207 -------
.../ainode/TimerXL/layers/Transformer_EncDec.py | 329 ----------
.../ainode/ainode/TimerXL/layers/__init__.py | 17 -
.../ainode/ainode/TimerXL/models/__init__.py | 17 -
.../ainode/ainode/TimerXL/models/timer_xl.py | 446 --------------
.../ainode/core/manager/inference_manager.py | 4 +-
.../ainode/core/model/built_in_model_factory.py | 6 +-
.../{TimerXL => core/model/timerxl}/__init__.py | 0
.../model/timerxl}/configuration_timer.py | 30 +-
.../ainode/core/model/timerxl/modeling_timer.py | 680 +++++++++++++++++++++
.../core/model/timerxl/ts_generation_mixin.py | 366 +++++++++++
14 files changed, 1066 insertions(+), 1561 deletions(-)
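In summary, this commit replaces the hand-rolled TimerXL package under ainode/TimerXL/ with a transformers-style implementation under ainode/core/model/timerxl/: the configuration becomes a PretrainedConfig subclass, the model is exposed as TimerForPrediction, and autoregressive forecasting moves into a reusable generate() path provided by the new ts_generation_mixin.py. A minimal sketch of the new call path, assuming the class names and import paths shown in this diff (shapes and values are illustrative only):

    import torch

    from ainode.core.model.timerxl.configuration_timer import TimerConfig
    from ainode.core.model.timerxl.modeling_timer import TimerForPrediction

    config = TimerConfig()              # new defaults, e.g. point-level input tokens
    model = TimerForPrediction(config)  # randomly initialized; real use loads weights
    seqs = torch.randn(1, 672)          # [batch, context_length] dummy series
    out = model.generate(seqs, max_new_tokens=96, revin=True)
    print(out.shape)                    # expected: torch.Size([1, 96])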
diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py b/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py
deleted file mode 100644
index 3c5ad760032..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py
+++ /dev/null
@@ -1,108 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import abc
-import math
-
-import torch
-from einops import rearrange
-from torch import nn
-
-
-class AttentionBias(nn.Module, abc.ABC):
- def __init__(self, dim: int, num_heads: int):
- super().__init__()
- assert num_heads > 0 and dim % num_heads == 0
-
- self.num_heads = num_heads
- self.head_dim = dim // num_heads
-
- @abc.abstractmethod
- def forward(self, query_id, kv_id): ...
-
-
-class BinaryAttentionBias(AttentionBias):
- def __init__(self, dim: int, num_heads: int):
- super().__init__(dim, num_heads)
- self.emb = nn.Embedding(num_embeddings=2, embedding_dim=self.num_heads)
-
- def forward(self, query_id, kv_id):
- ind = torch.eq(query_id.unsqueeze(-1), kv_id.unsqueeze(-2))
- weight = rearrange(self.emb.weight, "two num_heads -> two num_heads 1 1")
- bias = ~ind * weight[:1] + ind * weight[1:]
- return bias
-
-
-def _relative_position_bucket(
- relative_position, bidirectional=True, num_buckets=32, max_distance=128
-):
- relative_buckets = 0
- if bidirectional:
- num_buckets //= 2
- relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
- relative_position = torch.abs(relative_position)
- else:
- relative_position = -torch.min(
- relative_position, torch.zeros_like(relative_position)
- )
-
- max_exact = num_buckets // 2
- is_small = relative_position < max_exact
- relative_position_if_large = max_exact + (
- torch.log(relative_position.float() / max_exact)
- / math.log(max_distance / max_exact)
- * (num_buckets - max_exact)
- ).to(torch.long)
- relative_position_if_large = torch.min(
- relative_position_if_large,
- torch.full_like(relative_position_if_large, num_buckets - 1),
- )
-
- relative_buckets += torch.where(
- is_small, relative_position, relative_position_if_large
- )
- return relative_buckets
-
-
-class T5AttentionBias(AttentionBias):
- def __init__(self, dim: int, num_heads: int):
- super().__init__(dim, num_heads)
- self.num_buckets = 32
- self.max_distance = 32
- self.relative_attention_bias = nn.Embedding(self.num_buckets, 1)
-
- def forward(self, n_vars, n_tokens):
- context_position = torch.arange(
- n_tokens,
- dtype=torch.long,
- )[:, None]
- memory_position = torch.arange(
- n_tokens,
- dtype=torch.long,
- )[None, :]
- relative_position = memory_position - context_position
- bucket = _relative_position_bucket(
- relative_position=relative_position,
- bidirectional=False,
- num_buckets=self.num_buckets,
- max_distance=self.max_distance,
- ).to(self.relative_attention_bias.weight.device)
- bias = self.relative_attention_bias(bucket).squeeze(-1)
- bias = bias.reshape(1, 1, bias.shape[0], bias.shape[1])
- mask1 = torch.ones((n_vars, n_vars), dtype=torch.bool).to(bias.device)
- final_bias = torch.kron(mask1, bias)
- return final_bias
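For context on what is deleted here: _relative_position_bucket is the standard T5 bucketing scheme, where small offsets each get their own bucket and larger offsets share logarithmically spaced buckets up to max_distance. A self-contained scalar sketch of the same math (causal case, helper defaults; not code from this commit):

    import math

    def t5_bucket(relative_position, num_buckets=32, max_distance=128):
        # scalar port of the deleted _relative_position_bucket, bidirectional=False
        distance = max(0, -relative_position)  # how far into the past
        max_exact = num_buckets // 2
        if distance < max_exact:
            return distance                    # small offsets map exactly
        bucket = max_exact + int(
            math.log(distance / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        )
        return min(bucket, num_buckets - 1)    # clamp large offsets

    print([t5_bucket(-d) for d in (0, 1, 15, 16, 40, 127, 500)])
    # [0, 1, 15, 16, 23, 31, 31]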
diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py b/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py
deleted file mode 100644
index 18e2b29c3d6..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py
+++ /dev/null
@@ -1,127 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import abc
-from functools import cached_property
-
-import torch
-from einops import einsum, rearrange, repeat
-from torch import nn
-
-
-class Projection(nn.Module, abc.ABC):
- def __init__(self, proj_width: int, num_heads: int, **kwargs):
- super().__init__()
- self.proj_width = proj_width
- self.num_heads = num_heads
-
- @abc.abstractmethod
- def forward(self, x, seq_id): ...
-
-
-class RotaryProjection(Projection):
- def __init__(
- self, *, proj_width: int, num_heads: int, max_len: int = 512, base: int = 10000
- ):
- super().__init__(proj_width, num_heads)
- assert (
- self.proj_width % 2 == 0
- ), f"proj_width must be even, got {self.proj_width}"
- self.register_buffer(
- "theta",
- 1.0
- / torch.pow(
- base,
- torch.arange(0, self.proj_width, 2, dtype=torch.float)
- / self.proj_width,
- ),
- persistent=False,
- )
- self.register_buffer("cos", None, persistent=False)
- self.register_buffer("sin", None, persistent=False)
- self._init_freq(max_len=max_len)
-
- def _init_freq(self, max_len: int):
- if self.cos is None or self.cos.size(-2) < max_len:
- position = torch.arange(
- max_len, device=self.theta.device, dtype=self.theta.dtype
- )
- m_theta = einsum(position, self.theta, "length, width -> length width")
- m_theta = repeat(m_theta, "length width -> length (width 2)")
- self.register_buffer("cos", torch.cos(m_theta), persistent=False)
- self.register_buffer("sin", torch.sin(m_theta), persistent=False)
-
- @staticmethod
- def _rotate(x):
- x1, x2 = rearrange(x, "... (dim r) -> r ... dim", r=2)
- return rearrange([-x2, x1], "r ... dim -> ... (dim r)", r=2) # noqa
-
- def forward(self, x, seq_id):
- self._init_freq(max_len=seq_id.max() + 1)
- rot_cos = self.cos[seq_id]
- rot_sin = self.sin[seq_id]
- return rot_cos * x + rot_sin * self._rotate(x)
-
-
-class QueryKeyProjection(nn.Module):
- def __init__(
- self, dim: int, num_heads: int, proj_layer, kwargs=None, partial_factor=None
- ):
- super().__init__()
- if partial_factor is not None:
- assert (
- 0.0 <= partial_factor[0] < partial_factor[1] <= 1.0
- ), f"got {partial_factor[0]}, {partial_factor[1]}"
- assert num_heads > 0 and dim % num_heads == 0
-
- self.head_dim = dim // num_heads
- self.partial_factor = partial_factor
- self.query_proj = proj_layer(
- proj_width=self.proj_width,
- num_heads=num_heads,
- **(kwargs or {}),
- )
- self.key_proj = self.query_proj
-
- @cached_property
- def proj_width(self) -> int:
- if self.partial_factor is None:
- return self.head_dim
- return int(self.head_dim * (self.partial_factor[1] - self.partial_factor[0]))
-
- @cached_property
- def split_sizes(self):
- if self.partial_factor is None:
- return 0, self.head_dim, 0
- return (
- int(self.partial_factor[0] * self.head_dim),
- self.proj_width,
- int((1.0 - self.partial_factor[1]) * self.head_dim),
- )
-
- def forward(self, query, key, query_id, kv_id):
- if self.partial_factor is not None:
- queries = list(query.split(self.split_sizes, dim=-1))
- keys = list(key.split(self.split_sizes, dim=-1))
- queries[1] = self.query_proj(queries[1], seq_id=query_id)
- keys[1] = self.key_proj(keys[1], seq_id=kv_id)
- query = torch.cat(queries, dim=-1)
- key = torch.cat(keys, dim=-1)
- else:
- query = self.query_proj(query, seq_id=query_id)
- key = self.key_proj(key, seq_id=kv_id)
- return query, key
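One property of the deleted RotaryProjection worth noting: cos(m*theta) * x + sin(m*theta) * rotate(x) applies a plane rotation to each interleaved channel pair, so it preserves norms and makes query/key dot products depend only on relative positions. A small self-contained check mirroring _rotate above (shapes are illustrative; not code from this commit):

    import torch

    def rotate(x):
        # interleaved pairs (x1, x2) -> (-x2, x1), matching _rotate above
        x1, x2 = x[..., 0::2], x[..., 1::2]
        return torch.stack((-x2, x1), dim=-1).flatten(-2)

    proj_width, max_len = 4, 8
    theta = 1.0 / (10000 ** (torch.arange(0, proj_width, 2).float() / proj_width))
    m_theta = torch.einsum("l,w->lw", torch.arange(max_len).float(), theta)
    m_theta = m_theta.repeat_interleave(2, dim=-1)  # "length width -> length (width 2)"
    cos, sin = m_theta.cos(), m_theta.sin()

    x = torch.randn(max_len, proj_width)
    y = cos * x + sin * rotate(x)  # what RotaryProjection.forward computes per position
    print(torch.allclose(x.norm(dim=-1), y.norm(dim=-1)))  # True: rotation preserves norm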
diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Embed.py b/iotdb-core/ainode/ainode/TimerXL/layers/Embed.py
deleted file mode 100644
index 8c3cf570baf..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/layers/Embed.py
+++ /dev/null
@@ -1,290 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-import math
-
-import torch
-import torch.nn as nn
-from torch.jit import is_scripting
-
-from ainode.TimerXL.models.configuration_timer import TimerxlConfig
-
-
-class PositionalEmbedding(nn.Module):
- def __init__(self, d_model, max_len=6500):
- super(PositionalEmbedding, self).__init__()
- # Compute the positional encodings once in log space.
- pe = torch.zeros(max_len, d_model).float()
- pe.require_grad = False
-
- position = torch.arange(0, max_len).float().unsqueeze(1)
- div_term = (
- torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
- ).exp()
-
- pe[:, 0::2] = torch.sin(position * div_term)
- pe[:, 1::2] = torch.cos(position * div_term)
-
- pe = pe.unsqueeze(0)
- self.register_buffer("pe", pe)
-
- def forward(self, x):
- return self.pe[:, : x.size(1)]
-
-
-class TokenEmbedding(nn.Module):
- def __init__(self, c_in, d_model):
- super(TokenEmbedding, self).__init__()
- padding = 1 if torch.__version__ >= "1.5.0" else 2
- self.tokenConv = nn.Conv1d(
- in_channels=c_in,
- out_channels=d_model,
- kernel_size=3,
- padding=padding,
- padding_mode="circular",
- bias=False,
- )
- for m in self.modules():
- if isinstance(m, nn.Conv1d):
- nn.init.kaiming_normal_(
- m.weight, mode="fan_in", nonlinearity="leaky_relu"
- )
-
- def forward(self, x):
- x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
- return x
-
-
-class FixedEmbedding(nn.Module):
- def __init__(self, c_in, d_model):
- super(FixedEmbedding, self).__init__()
-
- w = torch.zeros(c_in, d_model).float()
- w.require_grad = False
-
- position = torch.arange(0, c_in).float().unsqueeze(1)
- div_term = (
- torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
- ).exp()
-
- w[:, 0::2] = torch.sin(position * div_term)
- w[:, 1::2] = torch.cos(position * div_term)
-
- self.emb = nn.Embedding(c_in, d_model)
- self.emb.weight = nn.Parameter(w, requires_grad=False)
-
- def forward(self, x):
- return self.emb(x).detach()
-
-
-class TemporalEmbedding(nn.Module):
- def __init__(self, d_model, embed_type="fixed", freq="h"):
- super(TemporalEmbedding, self).__init__()
-
- minute_size = 4
- hour_size = 24
- weekday_size = 7
- day_size = 32
- month_size = 13
-
- Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding
- if freq == "t":
- self.minute_embed = Embed(minute_size, d_model)
- self.hour_embed = Embed(hour_size, d_model)
- self.weekday_embed = Embed(weekday_size, d_model)
- self.day_embed = Embed(day_size, d_model)
- self.month_embed = Embed(month_size, d_model)
-
- def forward(self, x):
- x = x.long()
- minute_x = (
- self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0
- )
- hour_x = self.hour_embed(x[:, :, 3])
- weekday_x = self.weekday_embed(x[:, :, 2])
- day_x = self.day_embed(x[:, :, 1])
- month_x = self.month_embed(x[:, :, 0])
-
- return hour_x + weekday_x + day_x + month_x + minute_x
-
-
-class TimeFeatureEmbedding(nn.Module):
- def __init__(self, d_model, embed_type="timeF", freq="h"):
- super(TimeFeatureEmbedding, self).__init__()
-
- freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3}
- d_inp = freq_map[freq]
- self.embed = nn.Linear(d_inp, d_model, bias=False)
-
- def forward(self, x):
- return self.embed(x)
-
-
-class DataEmbedding(nn.Module):
- def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
- super(DataEmbedding, self).__init__()
-
- self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
- self.position_embedding = PositionalEmbedding(d_model=d_model)
- self.temporal_embedding = (
- TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
- if embed_type != "timeF"
- else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
- )
- self.dropout = nn.Dropout(p=dropout)
-
- def forward(self, x, x_mark):
- if x_mark is None:
- x = self.value_embedding(x) + self.position_embedding(x)
- else:
- x = (
- self.value_embedding(x)
- + self.temporal_embedding(x_mark)
- + self.position_embedding(x)
- )
- return self.dropout(x)
-
-
-class DataEmbedding_inverted(nn.Module):
- def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
- super(DataEmbedding_inverted, self).__init__()
- self.value_embedding = nn.Linear(c_in, d_model)
- self.dropout = nn.Dropout(p=dropout)
-
- def forward(self, x, x_mark):
- x = x.permute(0, 2, 1)
- # x: [Batch Variate Time]
- if x_mark is None:
- x = self.value_embedding(x)
- else:
- x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
- # x: [Batch Variate d_model]
- return self.dropout(x)
-
-
-class DataEmbedding_wo_pos(nn.Module):
- def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
- super(DataEmbedding_wo_pos, self).__init__()
-
- self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
- self.position_embedding = PositionalEmbedding(d_model=d_model)
- self.temporal_embedding = (
- TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
- if embed_type != "timeF"
- else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
- )
- self.dropout = nn.Dropout(p=dropout)
-
- def forward(self, x, x_mark):
- if x_mark is None:
- x = self.value_embedding(x)
- else:
- x = self.value_embedding(x) + self.temporal_embedding(x_mark)
- return self.dropout(x)
-
-
-class PatchEmbedding(nn.Module):
- def __init__(self, d_model, patch_len, stride, padding, dropout):
- super(PatchEmbedding, self).__init__()
- # Patching
- self.patch_len = patch_len
- self.stride = stride
- self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
-
- # Backbone, Input encoding: projection of feature vectors onto a d-dim vector space
- self.value_embedding = nn.Linear(patch_len, d_model, bias=False)
-
- # Positional embedding
- self.position_embedding = PositionalEmbedding(d_model)
-
- # Residual dropout
- self.dropout = nn.Dropout(dropout)
-
- def forward(self, x):
- # do patching
- n_vars = x.shape[1]
- x = self.padding_patch_layer(x)
- x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
- x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
- # Input encoding
- x = self.value_embedding(x) + self.position_embedding(x)
- return self.dropout(x), n_vars
-
-
-class TimerPatchEmbedding(nn.Module):
- def __init__(self, config: TimerxlConfig):
- super().__init__()
- self.input_token_len = config.input_token_len
- self.emb = nn.Linear(config.input_token_len, config.hidden_size, bias=False)
-
- def forward(self, hidden_state: torch.Tensor):
- hidden_state = hidden_state.unfold(
- dimension=-1, size=self.input_token_len, step=self.input_token_len
- )
- return self.emb(hidden_state)
-
-
-class TimeMoeRotaryEmbedding(torch.nn.Module):
- def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
- super().__init__()
- self.dim = dim
- self.max_position_embeddings = max_position_embeddings
- self.base = base
- self.max_seq_len_cached: int = 0
- inv_freq = 1.0 / (
- self.base
- ** (
- torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
- / self.dim
- )
- )
- self.register_buffer("inv_freq", inv_freq, persistent=False)
-
- # Build here to make `torch.jit.trace` work.
- self._set_cos_sin_cache(
- seq_len=max_position_embeddings,
- device=self.inv_freq.device,
- dtype=torch.get_default_dtype(),
- )
-
- def _set_cos_sin_cache(
- self, seq_len: int, device: torch.device, dtype: torch.dtype
- ):
- self.max_seq_len_cached = int(seq_len)
- t = torch.arange(
- self.max_seq_len_cached, device=device, dtype=torch.int64
- ).type_as(self.inv_freq)
-
- freqs = torch.outer(t, self.inv_freq)
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
- emb = torch.cat((freqs, freqs), dim=-1)
- if not is_scripting():
- self.register_buffer("cos_cached", emb.cos().to(dtype),
persistent=False)
- self.register_buffer("sin_cached", emb.sin().to(dtype),
persistent=False)
- else:
- self.cos_cached = emb.cos().to(dtype)
- self.sin_cached = emb.sin().to(dtype)
-
- def forward(self, x, seq_len: int = 0):
- # x: [bs, num_attention_heads, seq_len, head_size]
- if seq_len > self.max_seq_len_cached:
- self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
-
- return (
- self.cos_cached[:seq_len].to(dtype=x.dtype),
- self.sin_cached[:seq_len].to(dtype=x.dtype),
- )
diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/SelfAttention_Family.py b/iotdb-core/ainode/ainode/TimerXL/layers/SelfAttention_Family.py
deleted file mode 100644
index 4a2fb0d27e0..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/layers/SelfAttention_Family.py
+++ /dev/null
@@ -1,207 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-from math import sqrt
-from typing import Any, Optional, Tuple
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from einops import repeat
-
-from ainode.core.util.huggingface_cache import Cache, DynamicCache
-from ainode.core.util.masking import (
- TimerCovariateMask,
- TimerMultivariateMask,
- TriangularCausalMask,
-)
-from ainode.TimerXL.layers.Attn_Bias import BinaryAttentionBias
-from ainode.TimerXL.layers.Attn_Projection import QueryKeyProjection, RotaryProjection
-from ainode.TimerXL.layers.Embed import TimeMoeRotaryEmbedding
-from ainode.TimerXL.models.configuration_timer import TimerxlConfig
-
-
-def rotate_half(x):
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
- cos = cos[position_ids].unsqueeze(unsqueeze_dim)
- sin = sin[position_ids].unsqueeze(unsqueeze_dim)
- q_embed = (q * cos) + (rotate_half(q) * sin)
- k_embed = (k * cos) + (rotate_half(k) * sin)
- return q_embed, k_embed
-
-
-class FullAttention(nn.Module):
- def __init__(
- self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False
- ):
- super(FullAttention, self).__init__()
- self.scale = scale
- self.mask_flag = mask_flag
- self.output_attention = output_attention
- self.dropout = nn.Dropout(attention_dropout)
-
- def forward(
- self,
- queries,
- keys,
- values,
- attn_mask,
- n_vars=None,
- n_tokens=None,
- tau=None,
- delta=None,
- ):
- B, L, H, E = queries.shape
- _, S, _, D = values.shape
- scale = self.scale or 1.0 / sqrt(E)
-
- scores = torch.einsum("blhe,bshe->bhls", queries, keys)
-
- if self.mask_flag:
- if attn_mask is None:
- attn_mask = TriangularCausalMask(B, L, device=queries.device)
-
- scores.masked_fill_(attn_mask.mask, -np.inf)
-
- A = self.dropout(torch.softmax(scale * scores, dim=-1))
- V = torch.einsum("bhls,bshd->blhd", A, values)
-
- if self.output_attention:
- return V.contiguous(), A
- else:
- return V.contiguous(), None
-
-
-class TimerAttention(nn.Module):
- def __init__(self, config: TimerxlConfig, layer_idx: Optional[int] = None):
- super().__init__()
- self.layer_idx = layer_idx
- self.hidden_size = config.hidden_size
- self.num_heads = config.num_attention_heads
- self.head_dim = self.hidden_size // self.num_heads
- self.attention_dropout = config.attention_dropout
- self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
- self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
- self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
- self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
- self.rotary_emb = TimeMoeRotaryEmbedding(
- self.head_dim, max_position_embeddings=config.max_position_embeddings
- )
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional["Cache"] = None,
- ) -> Tuple[torch.Tensor, Optional["Cache"]]:
- bsz, q_len, _ = hidden_states.size()
-
- query_states = self.q_proj(hidden_states)
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
-
- query_states = query_states.view(
- bsz, q_len, self.num_heads, self.head_dim
- ).transpose(1, 2)
- key_states = key_states.view(
- bsz, q_len, self.num_heads, self.head_dim
- ).transpose(1, 2)
- value_states = value_states.view(
- bsz, q_len, self.num_heads, self.head_dim
- ).transpose(1, 2)
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
- cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
- query_states, key_states = apply_rotary_pos_emb(
- query_states, key_states, cos, sin, position_ids
- )
-
- if past_key_value is not None:
- key_states, value_states = past_key_value.update(
- key_states, value_states, self.layer_idx
- )
-
- attn_output = F.scaled_dot_product_attention(
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout_p=self.attention_dropout,
- )
-
- attn_output = attn_output.transpose(1, 2).contiguous()
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
- attn_output = self.o_proj(attn_output)
-
- return attn_output, past_key_value
-
-
-class AttentionLayer(nn.Module):
- def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
- super(AttentionLayer, self).__init__()
-
- d_keys = d_keys or (d_model // n_heads)
- d_values = d_values or (d_model // n_heads)
-
- self.inner_attention = attention
- self.query_projection = nn.Linear(d_model, d_keys * n_heads)
- self.key_projection = nn.Linear(d_model, d_keys * n_heads)
- self.value_projection = nn.Linear(d_model, d_values * n_heads)
- self.out_projection = nn.Linear(d_values * n_heads, d_model)
- self.n_heads = n_heads
-
- def forward(
- self,
- queries,
- keys,
- values,
- attn_mask,
- n_vars=None,
- n_tokens=None,
- tau=None,
- delta=None,
- ):
- B, L, _ = queries.shape
- _, S, _ = keys.shape
- H = self.n_heads
-
- queries = self.query_projection(queries).view(B, L, H, -1)
- keys = self.key_projection(keys).view(B, S, H, -1)
- values = self.value_projection(values).view(B, S, H, -1)
-
- out, attn = self.inner_attention(
- queries,
- keys,
- values,
- attn_mask,
- n_vars=n_vars,
- n_tokens=n_tokens,
- tau=tau,
- delta=delta,
- )
- out = out.view(B, L, -1)
-
- return self.out_projection(out), attn
diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Transformer_EncDec.py b/iotdb-core/ainode/ainode/TimerXL/layers/Transformer_EncDec.py
deleted file mode 100644
index d5bad30ea05..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/layers/Transformer_EncDec.py
+++ /dev/null
@@ -1,329 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-from typing import Optional, Tuple
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from ainode.core.util.activation import ACT2FN
-from ainode.core.util.huggingface_cache import Cache, DynamicCache
-from ainode.TimerXL.layers.SelfAttention_Family import TimerAttention
-from ainode.TimerXL.models.configuration_timer import TimerxlConfig
-
-
-class EncoderLayer(nn.Module):
- def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
- super(EncoderLayer, self).__init__()
- d_ff = d_ff or 4 * d_model
- self.attention = attention
- self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
- self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
- self.norm1 = nn.LayerNorm(d_model)
- self.norm2 = nn.LayerNorm(d_model)
- self.dropout = nn.Dropout(dropout)
- self.activation = F.relu if activation == "relu" else F.gelu
-
- def forward(self, x, attn_mask=None, tau=None, delta=None):
- new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
- x = x + self.dropout(new_x)
-
- y = x = self.norm1(x)
- y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
- y = self.dropout(self.conv2(y).transpose(-1, 1))
-
- return self.norm2(x + y), attn
-
-
-class DecoderLayer(nn.Module):
- def __init__(
- self,
- self_attention,
- cross_attention,
- d_model,
- d_ff=None,
- dropout=0.1,
- activation="relu",
- ):
- super(DecoderLayer, self).__init__()
- d_ff = d_ff or 4 * d_model
- self.self_attention = self_attention
- self.cross_attention = cross_attention
- self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
- self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
- self.norm1 = nn.LayerNorm(d_model)
- self.norm2 = nn.LayerNorm(d_model)
- self.norm3 = nn.LayerNorm(d_model)
- self.dropout = nn.Dropout(dropout)
- self.activation = F.relu if activation == "relu" else F.gelu
-
- def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
- x = x + self.dropout(
- self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0]
- )
- x = self.norm1(x)
-
- x = x + self.dropout(
- self.cross_attention(
- x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta
- )[0]
- )
-
- y = x = self.norm2(x)
- y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
- y = self.dropout(self.conv2(y).transpose(-1, 1))
-
- return self.norm3(x + y)
-
-
-class DecoderOnlyLayer(nn.Module):
- def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
- super(DecoderOnlyLayer, self).__init__()
- d_ff = d_ff or 4 * d_model
- self.attention = attention
- self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
- self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
- self.norm1 = nn.LayerNorm(d_model)
- self.norm2 = nn.LayerNorm(d_model)
- self.dropout = nn.Dropout(dropout)
- self.activation = F.relu if activation == "relu" else F.gelu
-
- def forward(self, x, attn_mask=None, tau=None, delta=None):
- new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
- x = x + self.dropout(new_x)
-
- y = x = self.norm1(x)
- y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
- y = self.dropout(self.conv2(y).transpose(-1, 1))
-
- return self.norm2(x + y), attn
-
-
-class TimerLayer(nn.Module):
- def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
- super(TimerLayer, self).__init__()
- d_ff = d_ff or 4 * d_model
- self.attention = attention
- self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
- self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
- self.norm1 = nn.LayerNorm(d_model)
- self.norm2 = nn.LayerNorm(d_model)
- self.dropout = nn.Dropout(dropout)
- self.activation = F.relu if activation == "relu" else F.gelu
-
- def forward(self, x, n_vars, n_tokens, attn_mask=None, tau=None, delta=None):
- new_x, attn = self.attention(
- x,
- x,
- x,
- n_vars=n_vars,
- n_tokens=n_tokens,
- attn_mask=attn_mask,
- tau=tau,
- delta=delta,
- )
- x = x + self.dropout(new_x)
-
- y = x = self.norm1(x)
- y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
- y = self.dropout(self.conv2(y).transpose(-1, 1))
-
- return self.norm2(x + y), attn
-
-
-class Encoder(nn.Module):
- def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
- super(Encoder, self).__init__()
- self.attn_layers = nn.ModuleList(attn_layers)
- self.conv_layers = (
- nn.ModuleList(conv_layers) if conv_layers is not None else None
- )
- self.norm = norm_layer
-
- def forward(self, x, attn_mask=None, tau=None, delta=None):
- # x [B, L, D]
- attns = []
- if self.conv_layers is not None:
- for i, (attn_layer, conv_layer) in enumerate(
- zip(self.attn_layers, self.conv_layers)
- ):
- delta = delta if i == 0 else None
- x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
- x = conv_layer(x)
- attns.append(attn)
- x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
- attns.append(attn)
- else:
- for attn_layer in self.attn_layers:
- x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
- attns.append(attn)
-
- if self.norm is not None:
- x = self.norm(x)
-
- return x, attns
-
-
-class Decoder(nn.Module):
- def __init__(self, layers, norm_layer=None, projection=None):
- super(Decoder, self).__init__()
- self.layers = nn.ModuleList(layers)
- self.norm = norm_layer
- self.projection = projection
-
- def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None):
- for layer in self.layers:
- x = layer(
- x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta
- )
-
- if self.norm is not None:
- x = self.norm(x)
-
- if self.projection is not None:
- x = self.projection(x)
- return x
-
-
-class DecoderOnly(nn.Module):
- def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
- super(DecoderOnly, self).__init__()
- self.attn_layers = nn.ModuleList(attn_layers)
- self.conv_layers = (
- nn.ModuleList(conv_layers) if conv_layers is not None else None
- )
- self.norm = norm_layer
-
- def forward(self, x, attn_mask=None, tau=None, delta=None):
- # x [B, L, D]
- attns = []
- if self.conv_layers is not None:
- for i, (attn_layer, conv_layer) in enumerate(
- zip(self.attn_layers, self.conv_layers)
- ):
- delta = delta if i == 0 else None
- x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
- x = conv_layer(x)
- attns.append(attn)
- x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
- attns.append(attn)
- else:
- for attn_layer in self.attn_layers:
- x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
- attns.append(attn)
-
- if self.norm is not None:
- x = self.norm(x)
-
- return x, attns
-
-
-class TimerBlock(nn.Module):
- def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
- super(TimerBlock, self).__init__()
- self.attn_layers = nn.ModuleList(attn_layers)
- self.conv_layers = (
- nn.ModuleList(conv_layers) if conv_layers is not None else None
- )
- self.norm = norm_layer
-
- def forward(self, x, n_vars, n_tokens, attn_mask=None, tau=None, delta=None):
- # x [B, L, D]
- attns = []
- if self.conv_layers is not None:
- for i, (attn_layer, conv_layer) in enumerate(
- zip(self.attn_layers, self.conv_layers)
- ):
- delta = delta if i == 0 else None
- x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
- x = conv_layer(x)
- attns.append(attn)
- x, attn = self.attn_layers[-1](x, n_vars, n_tokens, tau=tau, delta=None)
- attns.append(attn)
- else:
- for attn_layer in self.attn_layers:
- x, attn = attn_layer(
- x, n_vars, n_tokens, attn_mask=attn_mask, tau=tau, delta=delta
- )
- attns.append(attn)
-
- if self.norm is not None:
- x = self.norm(x)
-
- return x, attns
-
-
-class TimerMLP(nn.Module):
- def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
- super().__init__()
- self.hidden_size = hidden_size
- self.intermediate_size = intermediate_size
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[hidden_act]
-
- def forward(self, hidden_state):
- return self.down_proj(
- self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)
- )
-
-
-class TimerDecoderLayer(nn.Module):
- def __init__(self, config: TimerxlConfig, layer_idx: int):
- super().__init__()
- self.self_attn = TimerAttention(config, layer_idx)
-
- self.ffn_layer = TimerMLP(
- hidden_size=config.hidden_size,
- intermediate_size=config.intermediate_size,
- hidden_act=config.hidden_act,
- )
- self.norm1 = torch.nn.LayerNorm(config.hidden_size)
- self.norm2 = torch.nn.LayerNorm(config.hidden_size)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
- use_cache: bool = False,
- ) -> Tuple[torch.FloatTensor, Optional[Cache]]:
- residual = hidden_states
-
- # Self Attention
- hidden_states, present_key_value = self.self_attn(
- hidden_states=hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_value,
- )
-
- hidden_states = residual + hidden_states
- hidden_states = self.norm1(hidden_states)
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.ffn_layer(hidden_states)
- hidden_states = residual + hidden_states
- hidden_states = self.norm2(hidden_states)
-
- if not use_cache:
- present_key_value = None
- return hidden_states, present_key_value
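TimerMLP above is the gated feed-forward block (SwiGLU-style) used by most recent decoder-only stacks: down_proj(act(gate_proj(x)) * up_proj(x)). A quick shape sketch with the default hidden_act="silu" (dimensions are illustrative):

    import torch
    import torch.nn.functional as F
    from torch import nn

    hidden, inter = 1024, 2048
    gate = nn.Linear(hidden, inter, bias=False)
    up = nn.Linear(hidden, inter, bias=False)
    down = nn.Linear(inter, hidden, bias=False)

    x = torch.randn(2, 7, hidden)      # [batch, tokens, hidden]
    y = down(F.silu(gate(x)) * up(x))  # equivalent to TimerMLP.forward with silu
    print(y.shape)                     # torch.Size([2, 7, 1024])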
diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/__init__.py b/iotdb-core/ainode/ainode/TimerXL/layers/__init__.py
deleted file mode 100644
index 2a1e720805f..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/layers/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
diff --git a/iotdb-core/ainode/ainode/TimerXL/models/__init__.py b/iotdb-core/ainode/ainode/TimerXL/models/__init__.py
deleted file mode 100644
index 2a1e720805f..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/models/__init__.py
+++ /dev/null
@@ -1,17 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
diff --git a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
deleted file mode 100644
index b3962a052a5..00000000000
--- a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py
+++ /dev/null
@@ -1,446 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-#
-
-import os
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple
-
-import torch
-from huggingface_hub import hf_hub_download
-from safetensors.torch import load_file as load_safetensors
-from torch import nn
-
-from ainode.core.log import Logger
-from ainode.core.util.huggingface_cache import Cache, DynamicCache
-from ainode.core.util.masking import prepare_4d_causal_attention_mask
-from ainode.TimerXL.layers.Embed import TimerPatchEmbedding
-from ainode.TimerXL.layers.Transformer_EncDec import TimerDecoderLayer
-from ainode.TimerXL.models.configuration_timer import TimerxlConfig
-
-logger = Logger()
-
-
-@dataclass
-class Output:
- outputs: torch.Tensor
- past_key_values: Optional[Any] = None
-
-
-class TimerModel(nn.Module):
- def __init__(self, config: TimerxlConfig):
- super().__init__()
- self.config = config
- self.embed_layer = TimerPatchEmbedding(config)
- self.layers = nn.ModuleList(
- [
- TimerDecoderLayer(config, layer_idx)
- for layer_idx in range(config.num_hidden_layers)
- ]
- )
- self.norm = torch.nn.LayerNorm(config.hidden_size)
- self.gradient_checkpointing = False
-
- def forward(
- self,
- input_ids: torch.FloatTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- use_cache: bool = None,
- ):
- # input_ids is the input of time series, its shape is [batch_size, seq_len]
-
- if input_ids is not None:
- batch_size, seq_length = input_ids.shape
- else:
- raise ValueError(
- "You have to specify either decoder_input_ids or
decoder_inputs_embeds"
- )
-
- inputs_embeds = self.embed_layer(input_ids)
-
- seq_length = inputs_embeds.shape[1]
-
- past_key_values_length = 0
-
- if use_cache:
- use_legacy_cache = not isinstance(past_key_values, Cache)
- if use_legacy_cache:
- past_key_values = DynamicCache.from_legacy_cache(past_key_values)
- past_key_values_length = past_key_values.get_usable_length(seq_length)
-
- if position_ids is None:
- device = input_ids.device if input_ids is not None else inputs_embeds.device
- position_ids = torch.arange(
- past_key_values_length,
- seq_length + past_key_values_length,
- dtype=torch.long,
- device=device,
- )
- position_ids = position_ids.view(-1, seq_length)
- else:
- position_ids = position_ids.view(-1, seq_length).long()
-
- # 4d mask is passed through the layers
- attention_mask = prepare_4d_causal_attention_mask(
- attention_mask,
- (batch_size, seq_length),
- inputs_embeds,
- past_key_values_length,
- )
-
- hidden_states = inputs_embeds
-
- # decoder layers
- next_decoder_cache = None
-
- for decoder_layer in self.layers:
- layer_outputs = decoder_layer(
- hidden_states,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_value=past_key_values,
- use_cache=use_cache,
- )
-
- hidden_states = layer_outputs[0]
-
- if use_cache:
- next_decoder_cache = layer_outputs[1]
-
- hidden_states = self.norm(hidden_states)
-
- next_cache = None
- if use_cache:
- next_cache = (
- next_decoder_cache.to_legacy_cache()
- if use_legacy_cache
- else next_decoder_cache
- )
-
- return Output(outputs=hidden_states, past_key_values=next_cache)
-
-
-class TimerForPrediction(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.model = TimerModel(self.config)
- lm_head_list = []
- self.output_token_len_map = {}
- for i, output_token_len in enumerate(self.config.output_token_lens):
- lm_head_list.append(
- nn.Linear(self.config.hidden_size, output_token_len, bias=False)
- )
- self.output_token_len_map[output_token_len] = i
- self.lm_heads = nn.ModuleList(lm_head_list)
- self.loss_function = torch.nn.MSELoss(reduction="none")
-
- def forward(
- self,
- input_ids: torch.FloatTensor = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- use_cache: Optional[bool] = None,
- max_output_length: Optional[int] = None,
- revin: Optional[bool] = True,
- ):
- if revin:
- means, stdev = input_ids.mean(dim=-1, keepdim=True), input_ids.std(
- dim=-1, keepdim=True
- )
- input_ids = (input_ids - means) / stdev
-
- outputs = self.model(
- input_ids=input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- use_cache=use_cache,
- )
- hidden_states = outputs.outputs
-
- if max_output_length is None:
- output_token_len = self.config.output_token_lens[0]
- max_output_length = output_token_len
- else:
- output_token_len = self.config.output_token_lens[0]
- for h in self.config.output_token_lens[1:]:
- if h > max_output_length:
- break
- else:
- output_token_len = h
-
- lm_head = self.lm_heads[self.output_token_len_map[output_token_len]]
- predictions = lm_head(hidden_states)[:, -1, :]
-
- if output_token_len > max_output_length:
- predictions = predictions[:, :max_output_length]
- if revin:
- predictions = predictions * stdev + means
-
- return Output(predictions, outputs.past_key_values)
-
-
-class Model(nn.Module):
- """
- Timer-XL: Long-Context Transformers for Unified Time Series Forecasting
-
- Paper: https://arxiv.org/abs/2410.04803
-
- GitHub: https://github.com/thuml/Timer-XL
-
- Citation: @article{liu2024timer,
- title={Timer-XL: Long-Context Transformers for Unified Time Series Forecasting},
- author={Liu, Yong and Qin, Guo and Huang, Xiangdong and Wang, Jianmin and Long, Mingsheng},
- journal={arXiv preprint arXiv:2410.04803},
- year={2024}
- }
- """
-
- def __init__(self, config: TimerxlConfig):
- super().__init__()
- self.config = config # can't be scripted by torch
-
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- self.model = TimerForPrediction(config).to(self.device)
-
- if config.ckpt_path is not None and config.ckpt_path != "":
- if config.ckpt_path.endswith(".pt") or config.ckpt_path.endswith(".pth"):
- state_dict = torch.load(config.ckpt_path)
- elif config.ckpt_path.endswith(".safetensors"):
- if not os.path.exists(config.ckpt_path):
- logger.info(
- f"Checkpoint not found at {config.ckpt_path},
downloading from HuggingFace..."
- )
- repo_id = "thuml/timer-base-84m"
- try:
- config.ckpt_path = hf_hub_download(
- repo_id=repo_id,
- filename=os.path.basename(config.ckpt_path),
- local_dir=os.path.dirname(config.ckpt_path),
- )
- logger.info(f"Got checkpoint to {config.ckpt_path}")
- except Exception as e:
- logger.error(
- f"Failed to download checkpoint to
{config.ckpt_path} due to {e}"
- )
- raise e
- state_dict = load_safetensors(config.ckpt_path)
- else:
- raise ValueError("unsupported model weight type")
- # If there is no key beginning with 'model.model' in state_dict, add a 'model.' before all keys. (The model code here has an additional layer of encapsulation compared to the code on huggingface.)
- if not any(k.startswith("model.model") for k in state_dict.keys()):
- state_dict = {"model." + k: v for k, v in state_dict.items()}
- self.load_state_dict(state_dict, strict=True)
-
- def set_device(self, device):
- self.model.to(device)
- self.device = next(self.model.parameters()).device
-
- def inference(self, x, max_new_tokens: int = 96):
- # x.shape: [L, C], type: DataFrame
- # here we only except C=1 temporarily
- # change [L, C=1] to [batchsize=1, L]
- self.device = next(self.model.parameters()).device
-
- x = torch.tensor(
- x, dtype=next(self.model.parameters()).dtype, device=self.device
- )
- x = x.view(1, -1)
-
- preds = self.forward(x, max_new_tokens)
- preds = preds.detach().cpu().numpy()
-
- return preds
-
- def forward(self, x, max_new_tokens: int = 96):
- # self.config.is_encoder_decoder = False
- self.eval()
- self.device = next(self.model.parameters()).device
-
- if len(x.shape) == 2:
- batch_size, cur_len = x.shape
- if cur_len < self.config.input_token_len:
- raise ValueError(
- f"Input length must be at least
{self.config.input_token_len}"
- )
- elif cur_len % self.config.input_token_len != 0:
- new_len = (
- cur_len // self.config.input_token_len
- ) * self.config.input_token_len
- x = x[:, -new_len:]
- else:
- raise ValueError("Input shape must be: [batch_size, seq_len]")
-
- use_cache = self.config.use_cache
- all_input_ids = x
-
- attention_mask = self.prepare_attention_mask_for_generation(all_input_ids)
- all_input_ids_length = all_input_ids.shape[-1]
- max_length = max_new_tokens + all_input_ids_length
-
- all_input_ids = all_input_ids.to(self.device)
- batch_size, cur_len = all_input_ids.shape
-
- unfinished_sequences = torch.ones(
- batch_size, dtype=torch.long, device=all_input_ids.device
- )
- cache_position = torch.arange(cur_len, device=all_input_ids.device)
- true_seq_len = cur_len // self.config.input_token_len
- attention_mask = attention_mask[:, -true_seq_len:]
-
- this_peer_finished = False
- past_key_values = None
- position_ids = None
- while not this_peer_finished:
- (input_ids, position_ids, past_key_values, attention_mask, revin) = (
- self.prepare_inputs_for_generation(
- all_input_ids,
- past_key_values=past_key_values,
- attention_mask=attention_mask,
- # position_ids=position_ids # Wrong?!
- position_ids=None, # True?! based on huggingface code
- )
- )
-
- input_length = all_input_ids.shape[1]
-
- # forward pass to get next token
- outputs = self.model(
- input_ids,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- use_cache=use_cache,
- max_output_length=max_length - input_length,
- revin=revin,
- )
-
- next_tokens = outputs.outputs
-
- # update generated ids, model inputs, and length for next step
- horizon_length = next_tokens.shape[1] // self.config.input_token_len
-
- all_input_ids = torch.cat([all_input_ids, next_tokens], dim=-1)
- (past_key_values, attention_mask, cache_position) = (
- self._update_model_kwargs_for_generation(
- outputs,
- attention_mask=attention_mask,
- horizon_length=horizon_length,
- cache_position=cache_position,
- )
- )
-
- unfinished_sequences = unfinished_sequences & (
- all_input_ids.shape[1] < max_length
- )
- this_peer_finished = unfinished_sequences.max() == 0
-
- if all_input_ids.shape[1] > max_length:
- all_input_ids = all_input_ids[:, :max_length]
-
- return all_input_ids[:, -(max_length - cur_len) :]
-
- def prepare_attention_mask_for_generation(
- self,
- inputs: torch.Tensor,
- ) -> torch.LongTensor:
- return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device)
-
- def prepare_inputs_for_generation(
- self,
- input_ids,
- past_key_values=None,
- attention_mask=None,
- revin=True,
- position_ids=None,
- ):
- # Omit tokens covered by past_key_values
- if past_key_values is not None:
- if isinstance(past_key_values, Cache):
- cache_length = past_key_values.get_seq_length()
- if isinstance(past_key_values, DynamicCache):
- past_length = past_key_values.seen_tokens
- else:
- past_length = cache_length
-
- max_cache_length = past_key_values.get_max_length()
- else:
- cache_length = past_length = past_key_values[0][0].shape[2]
- max_cache_length = None
-
- # Keep only the unprocessed tokens:
- # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
- # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
- # input)
- if attention_mask is not None and attention_mask.shape[1] > (
- input_ids.shape[1] // self.config.input_token_len
- ):
- input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
- # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
- # input_ids based on the past_length.
- elif past_length < (input_ids.shape[1] // self.config.input_token_len):
- input_ids = input_ids[:, past_length * self.config.input_token_len :]
- # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
-
- # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
- if (
- max_cache_length is not None
- and attention_mask is not None
- and cache_length + (input_ids.shape[1] // self.config.input_token_len)
- > max_cache_length
- ):
- attention_mask = attention_mask[:, -max_cache_length:]
-
- if attention_mask is not None and position_ids is None:
- # create position_ids on the fly for batch generation
- position_ids = attention_mask.long().cumsum(-1) - 1
- position_ids.masked_fill_(attention_mask == 0, 1)
- if past_key_values:
- position_ids = position_ids[
- :, -(input_ids.shape[1] // self.config.input_token_len) :
- ]
-
- return (input_ids, position_ids, past_key_values, attention_mask, revin)
-
- def _update_model_kwargs_for_generation(
- self,
- outputs,
- attention_mask=None,
- cache_position=None,
- horizon_length: int = 1,
- ) -> Dict[str, Any]:
- # update past_key_values
- past_key_values = outputs.past_key_values
-
- # update attention mask
- if attention_mask is not None:
- attention_mask = torch.cat(
- [
- attention_mask,
- attention_mask.new_ones((attention_mask.shape[0], horizon_length)),
- ],
- dim=-1,
- )
-
- if cache_position is not None:
- cache_position = cache_position[-1:] + horizon_length
-
- return (past_key_values, attention_mask, cache_position)
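The hand-written loop deleted above is what the new ts_generation_mixin.py standardizes: predict a block of future points, append it to the running sequence, carry the KV cache forward by horizon_length, and stop once max_length is reached. A simplified self-contained sketch of that control flow (Step, dummy_model, and forecast are illustrative stand-ins, not code from this commit):

    import torch

    class Step:
        def __init__(self, outputs, past_key_values=None):
            self.outputs = outputs                  # [batch, horizon] forecast block
            self.past_key_values = past_key_values

    def dummy_model(ids, past_key_values=None):
        # stand-in for TimerForPrediction: persistence forecast of the last 96 points
        return Step(ids[:, -96:], past_key_values)

    def forecast(model, x, max_new_tokens):
        out, past = x, None
        while out.shape[1] - x.shape[1] < max_new_tokens:
            step = model(out, past_key_values=past)
            out = torch.cat([out, step.outputs], dim=-1)  # append forecast block
            past = step.past_key_values                   # reuse the KV cache next step
        return out[:, x.shape[1] : x.shape[1] + max_new_tokens]

    print(forecast(dummy_model, torch.randn(1, 672), 96).shape)  # torch.Size([1, 96])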
diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
index 175961e7313..eb8becd0f17 100644
--- a/iotdb-core/ainode/ainode/core/manager/inference_manager.py
+++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py
@@ -57,7 +57,9 @@ class TimerXLStrategy(InferenceStrategy):
data = full_data[1][0]
if data.dtype.byteorder not in ("=", "|"):
data = data.byteswap().newbyteorder()
- output = self.model.inference(data, int(predict_length))
+ seqs = torch.tensor(data).unsqueeze(0).float()
+ # TODO: unify model inference input
+ output = self.model.generate(seqs, max_new_tokens=predict_length, revin=True)
df = pd.DataFrame(output[0])
return convert_to_binary(df)
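Two details in this hunk are easy to miss: arrays deserialized from the binary protocol may arrive big-endian while torch.tensor requires native byte order, hence the byteswap().newbyteorder() guard (a NumPy 1.x API, matching the code above), and the raw series is reshaped to a [1, L] batch before generate is called. In isolation:

    import numpy as np
    import torch

    data = np.arange(4, dtype=">f4")                # big-endian floats, e.g. off the wire
    if data.dtype.byteorder not in ("=", "|"):
        data = data.byteswap().newbyteorder()       # convert to native order for torch
    seqs = torch.tensor(data).unsqueeze(0).float()  # [1, L]: a batch of one series
    print(seqs.shape, seqs.dtype)                   # torch.Size([1, 4]) torch.float32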
diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
index 6298fb6a1db..8bd3bfc4800 100644
--- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
+++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py
@@ -41,8 +41,8 @@ from ainode.core.exception import (
from ainode.core.log import Logger
from ainode.core.model.sundial import modeling_sundial
from ainode.core.model.sundial.configuration_sundial import SundialConfig
-from ainode.TimerXL.models import timer_xl
-from ainode.TimerXL.models.configuration_timer import TimerxlConfig
+from ainode.core.model.timerxl import modeling_timer
+from ainode.core.model.timerxl.configuration_timer import TimerConfig
logger = Logger()
@@ -113,7 +113,7 @@ def fetch_built_in_model(model_id, inference_attributes):
elif model_id == BuiltInModelType.STRAY.value:
model = STRAYModel(attributes)
elif model_id == BuiltInModelType.TIMER_XL.value:
- model = timer_xl.Model(TimerxlConfig.from_dict(attributes))
+ model = modeling_timer.TimerForPrediction(TimerConfig.from_dict(attributes))
elif model_id == BuiltInModelType.SUNDIAL.value:
model = modeling_sundial.SundialForPrediction(
SundialConfig.from_dict(attributes)
diff --git a/iotdb-core/ainode/ainode/TimerXL/__init__.py b/iotdb-core/ainode/ainode/core/model/timerxl/__init__.py
similarity index 100%
rename from iotdb-core/ainode/ainode/TimerXL/__init__.py
rename to iotdb-core/ainode/ainode/core/model/timerxl/__init__.py
diff --git a/iotdb-core/ainode/ainode/TimerXL/models/configuration_timer.py b/iotdb-core/ainode/ainode/core/model/timerxl/configuration_timer.py
similarity index 68%
rename from iotdb-core/ainode/ainode/TimerXL/models/configuration_timer.py
rename to iotdb-core/ainode/ainode/core/model/timerxl/configuration_timer.py
index ac5034aa85e..34f9de91b63 100644
--- a/iotdb-core/ainode/ainode/TimerXL/models/configuration_timer.py
+++ b/iotdb-core/ainode/ainode/core/model/timerxl/configuration_timer.py
@@ -15,27 +15,30 @@
# specific language governing permissions and limitations
# under the License.
#
+
from typing import List
+from transformers import PretrainedConfig
+
-class TimerxlConfig:
- model_type = "timerxl"
+class TimerConfig(PretrainedConfig):
+ model_type = "timer"
+ keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
- input_token_len: int = 96, # how many points as a token, don't change
- hidden_size: int = 1024, # model hidden size
- intermediate_size: int = 2048, # ffn middle size
- output_token_lens: List[int] = [96], # how many points as a token, don't change
+ input_token_len: int = 1,
+ hidden_size: int = 1024,
+ intermediate_size: int = 2048,
+ output_token_lens: List[int] = [1, 8, 32, 64],
num_hidden_layers: int = 8,
num_attention_heads: int = 8,
- hidden_act: str = "silu", # activation function
- use_cache: bool = True, # kv cache
- rope_theta: int = 10000, # ROBE parameter
+ hidden_act: str = "silu",
+ use_cache: bool = True,
+ rope_theta: int = 10000,
attention_dropout: float = 0.0,
- initializer_range: float = 0.02, # be of no use, because we already have weights
+ initializer_range: float = 0.02,
max_position_embeddings: int = 10000,
- ckpt_path: str = None, # weight path
**kwargs,
):
self.input_token_len = input_token_len
@@ -50,12 +53,7 @@ class TimerxlConfig:
self.attention_dropout = attention_dropout
self.initializer_range = initializer_range
self.max_position_embeddings = max_position_embeddings
- self.ckpt_path = ckpt_path
super().__init__(
**kwargs,
)
-
- @classmethod
- def from_dict(cls, config_dict: dict) -> "TimerxlConfig":
- return cls(**config_dict)
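With TimerConfig now extending PretrainedConfig, the hand-rolled from_dict classmethod becomes redundant; a small sketch of the inherited round trip:

    from ainode.core.model.timerxl.configuration_timer import TimerConfig

    cfg = TimerConfig(hidden_size=512, num_hidden_layers=4)
    as_dict = cfg.to_dict()                    # serialization provided by PretrainedConfig
    restored = TimerConfig.from_dict(as_dict)  # replaces the removed classmethod
    assert restored.hidden_size == 512 and restored.model_type == "timer"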
diff --git a/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py
b/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py
new file mode 100644
index 00000000000..42b3a82b972
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/timerxl/modeling_timer.py
@@ -0,0 +1,680 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import os
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+from huggingface_hub import hf_hub_download
+from safetensors.torch import load_file as load_safetensors
+from torch import nn
+from transformers import Cache, DynamicCache, PreTrainedModel
+from transformers.activations import ACT2FN
+from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
+from transformers.modeling_outputs import (
+ MoeCausalLMOutputWithPast,
+ MoeModelOutputWithPast,
+)
+
+from ainode.core.log import Logger
+from ainode.core.model.timerxl.configuration_timer import TimerConfig
+from ainode.core.model.timerxl.ts_generation_mixin import TSGenerationMixin
+
+logger = Logger()
+
+
+def rotate_half(x):
+ x1 = x[..., : x.shape[-1] // 2]
+ x2 = x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+ cos = cos[position_ids].unsqueeze(unsqueeze_dim)
+ sin = sin[position_ids].unsqueeze(unsqueeze_dim)
+ q_embed = (q * cos) + (rotate_half(q) * sin)
+ k_embed = (k * cos) + (rotate_half(k) * sin)
+ return q_embed, k_embed
+
+
+class TimerPatchEmbedding(nn.Module):
+ def __init__(self, config: TimerConfig):
+ super().__init__()
+ self.input_token_len = config.input_token_len
+ self.emb = nn.Linear(config.input_token_len, config.hidden_size, bias=False)
+
+ def forward(self, hidden_state: torch.Tensor):
+ hidden_state = hidden_state.unfold(
+ dimension=-1, size=self.input_token_len, step=self.input_token_len
+ )
+ return self.emb(hidden_state)
+
+
+class TimerPointEmbedding(nn.Module):
+ def __init__(self, config: TimerConfig):
+ super().__init__()
+ self.emb_layer = nn.Linear(
+ config.input_token_len, config.hidden_size, bias=False
+ )
+ self.gate_layer = nn.Linear(
+ config.input_token_len, config.hidden_size, bias=False
+ )
+ self.act_fn = ACT2FN[config.hidden_act]
+
+ def forward(self, x):
+ emb = self.act_fn(self.gate_layer(x)) * self.emb_layer(x)
+ return emb
+
+
+class TimeMoeRotaryEmbedding(torch.nn.Module):
+ def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None):
+ super().__init__()
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (
+ self.base
+ ** (
+ torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device)
+ / self.dim
+ )
+ )
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ # Build here to make `torch.jit.trace` work.
+ self._set_cos_sin_cache(
+ seq_len=max_position_embeddings,
+ device=self.inv_freq.device,
+ dtype=torch.get_default_dtype(),
+ )
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(
+ self.max_seq_len_cached, device=device, dtype=torch.int64
+ ).type_as(self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos().to(dtype),
persistent=False)
+ self.register_buffer("sin_cached", emb.sin().to(dtype),
persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+ return (
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
+ )
+
+
+class TimerAttention(nn.Module):
+ def __init__(self, config: TimerConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.layer_idx = layer_idx
+ self.hidden_size = config.hidden_size
+ self.num_heads = config.num_attention_heads
+ self.head_dim = self.hidden_size // self.num_heads
+ self.attention_dropout = config.attention_dropout
+ self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+ self.k_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+ self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
+ self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
+ self.rotary_emb = TimeMoeRotaryEmbedding(
+ self.head_dim, max_position_embeddings=config.max_position_embeddings
+ )
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+ key_states = key_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+ value_states = value_states.view(
+ bsz, q_len, self.num_heads, self.head_dim
+ ).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(
+ query_states, key_states, cos, sin, position_ids
+ )
+
+ if past_key_value is not None:
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx
+ )
+
+ attn_output = F.scaled_dot_product_attention(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout_p=self.attention_dropout,
+ )
+
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
+class TimerMLP(nn.Module):
+ def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str):
+ super().__init__()
+ self.hidden_size = hidden_size
+ self.intermediate_size = intermediate_size
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+ self.act_fn = ACT2FN[hidden_act]
+
+ def forward(self, hidden_state):
+ return self.down_proj(
+ self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)
+ )
+
+
+class TimerDecoderLayer(nn.Module):
+ def __init__(self, config: TimerConfig, layer_idx: int):
+ super().__init__()
+ self.self_attn = TimerAttention(config, layer_idx)
+
+ self.ffn_layer = TimerMLP(
+ hidden_size=config.hidden_size,
+ intermediate_size=config.intermediate_size,
+ hidden_act=config.hidden_act,
+ )
+ self.norm1 = torch.nn.LayerNorm(config.hidden_size)
+ self.norm2 = torch.nn.LayerNorm(config.hidden_size)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = False,
+ **kwargs,
+ ) -> Tuple[
+ torch.FloatTensor,
+ torch.FloatTensor,
+ Optional[torch.FloatTensor],
+ Optional[torch.FloatTensor],
+ ]:
+ residual = hidden_states
+
+ # Self Attention
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+ hidden_states = residual + hidden_states
+ hidden_states = self.norm1(hidden_states)
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.ffn_layer(hidden_states)
+ hidden_states = residual + hidden_states
+ hidden_states = self.norm2(hidden_states)
+
+ if not output_attentions:
+ self_attn_weights = None
+
+ if not use_cache:
+ present_key_value = None
+ return hidden_states, self_attn_weights, present_key_value
+
+
+class TimerPreTrainedModel(PreTrainedModel):
+ config_class = TimerConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["TimeMoeDecoderLayer"]
+ _skip_keys_device_placement = "past_key_values"
+ _supports_flash_attn_2 = True
+ _supports_sdpa = False
+ _supports_cache_class = True
+
+ def _init_weights(self, module):
+ std = self.config.initializer_range
+ if isinstance(module, torch.nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, torch.nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+
+class TimerModel(TimerPreTrainedModel):
+ def __init__(self, config: TimerConfig):
+ super().__init__(config)
+ self.embed_layer = TimerPatchEmbedding(config)
+ self.layers = nn.ModuleList(
+ [
+ TimerDecoderLayer(config, layer_idx)
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = torch.nn.LayerNorm(config.hidden_size)
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ input_ids: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, MoeModelOutputWithPast]:
+ # input_ids is the raw time-series input; its shape is [batch_size, seq_len]
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError(
+ "You cannot specify both decoder_input_ids and
decoder_inputs_embeds at the same time"
+ )
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError(
+ "You have to specify either decoder_input_ids or
decoder_inputs_embeds"
+ )
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_layer(input_ids)
+ seq_length = inputs_embeds.shape[1]
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ use_cache = False
+
+ past_key_values_length = 0
+
+ if use_cache:
+ use_legacy_cache = not isinstance(past_key_values, Cache)
+ if use_legacy_cache:
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length,
+ seq_length + past_key_values_length,
+ dtype=torch.long,
+ device=device,
+ )
+ # position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ position_ids = position_ids.view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window=None,
+ )
+
+ hidden_states = inputs_embeds
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2]
+
+ hidden_states = self.norm(hidden_states)
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = None
+ if use_cache:
+ next_cache = (
+ next_decoder_cache.to_legacy_cache()
+ if use_legacy_cache
+ else next_decoder_cache
+ )
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
+ if v is not None
+ )
+ return MoeModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+
+class TimerForPrediction(TimerPreTrainedModel, TSGenerationMixin):
+ def __init__(self, config: TimerConfig):
+ super().__init__(config)
+ self.config = config
+ self.model = TimerModel(self.config)
+ lm_head_list = []
+ self.output_token_len_map = {}
+ for i, output_token_len in enumerate(self.config.output_token_lens):
+ lm_head_list.append(
+ nn.Linear(self.config.hidden_size, output_token_len, bias=False)
+ )
+ self.output_token_len_map[output_token_len] = i
+ self.lm_heads = nn.ModuleList(lm_head_list)
+ self.loss_function = torch.nn.MSELoss(reduction="none")
+ # TODO: Unify data loader
+ if not os.path.exists(config.ckpt_path):
+ os.mkdir(config.ckpt_path)
+ weights_path = os.path.join(config.ckpt_path, "model.safetensors")
+ if not os.path.exists(weights_path):
+ logger.info(
+ f"Weight not found at {weights_path}, downloading from
HuggingFace..."
+ )
+ repo_id = "thuml/sundial-base-128m"
+ try:
+ hf_hub_download(
+ repo_id=repo_id,
+ filename="model.safetensors",
+ local_dir=config.ckpt_path,
+ )
+ logger.info(f"Got weight to {weights_path}")
+ except Exception as e:
+ logger.error(f"Failed to download weight to {weights_path} due
to {e}")
+ raise e
+ state_dict = load_safetensors(weights_path)
+ self.load_state_dict(state_dict, strict=True)
+
+ def set_decoder(self, decoder):
+ self.model = decoder
+
+ def get_decoder(self):
+ return self.model
+
+ def forward(
+ self,
+ input_ids: torch.FloatTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.FloatTensor] = None,
+ loss_masks: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ max_output_length: Optional[int] = None,
+ revin: Optional[bool] = False,
+ ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
+
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ if revin:
+ mean, std = input_ids.mean(dim=-1, keepdim=True), input_ids.std(
+ dim=-1, keepdim=True
+ )
+ input_ids = (input_ids - mean) / std
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = outputs[0] if not return_dict else outputs.last_hidden_state
+ predictions = None
+
+ loss = None
+ if labels is not None:
+ ar_loss = 0.0
+ for lm_head, output_token_len in zip(
+ self.lm_heads, self.config.output_token_lens
+ ):
+ one_predictions = lm_head(hidden_states)
+ one_loss = self.calc_ar_loss(
+ one_predictions, labels, loss_masks, output_token_len
+ )
+ ar_loss += one_loss
+ if predictions is None:
+ predictions = one_predictions
+ loss = ar_loss / len(self.config.output_token_lens)
+ else:
+ if max_output_length is None:
+ output_token_len = self.config.output_token_lens[0]
+ max_output_length = output_token_len
+ else:
+ output_token_len = self.config.output_token_lens[0]
+ for h in self.config.output_token_lens[1:]:
+ if h > max_output_length:
+ break
+ else:
+ output_token_len = h
+ lm_head = self.lm_heads[self.output_token_len_map[output_token_len]]
+ predictions = lm_head(hidden_states)[:, -1, :]
+ if output_token_len > max_output_length:
+ predictions = predictions[:, :max_output_length]
+ if revin:
+ predictions = predictions * std + mean
+ if not return_dict:
+ output = (predictions,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return MoeCausalLMOutputWithPast(
+ loss=loss,
+ logits=predictions,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ def calc_ar_loss(self, predictions, labels, loss_masks, output_token_len):
+ seq_len = predictions.shape[1] * self.config.input_token_len
+ labels = labels[:, : seq_len - self.config.input_token_len + output_token_len]
+ shift_labels = labels.unfold(
+ dimension=-1, size=output_token_len, step=self.config.input_token_len
+ )
+
+ # Calculate loss with mask
+ losses = self.loss_function(predictions, shift_labels).mean(dim=-1)
+ if loss_masks is not None:
+ losses = losses * loss_masks
+ loss = losses.sum() / loss_masks.sum()
+ else:
+ loss = torch.mean(losses)
+
+ return loss
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ revin=True,
+ **kwargs,
+ ):
+ # Omit tokens covered by past_key_values
+ if past_key_values is not None:
+ if isinstance(past_key_values, Cache):
+ cache_length = past_key_values.get_seq_length()
+ if isinstance(past_key_values, DynamicCache):
+ past_length = past_key_values.seen_tokens
+ else:
+ past_length = cache_length
+
+ max_cache_length = past_key_values.get_max_length()
+ else:
+ cache_length = past_length = past_key_values[0][0].shape[2]
+ max_cache_length = None
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
+ if attention_mask is not None and attention_mask.shape[1] > (
+ input_ids.shape[1] // self.config.input_token_len
+ ):
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < (input_ids.shape[1] // self.config.input_token_len):
+ input_ids = input_ids[:, past_length * self.config.input_token_len :]
+ # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if (
+ max_cache_length is not None
+ and attention_mask is not None
+ and cache_length + (input_ids.shape[1] // self.config.input_token_len)
+ > max_cache_length
+ ):
+ attention_mask = attention_mask[:, -max_cache_length:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[
+ :, -(input_ids.shape[1] // self.config.input_token_len) :
+ ]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_key_values is None:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ model_inputs = {"input_ids": input_ids}
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "attention_mask": attention_mask,
+ "revin": revin,
+ }
+ )
+ return model_inputs
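The for/else head-selection rule in TimerForPrediction.forward is easy to misread; here is a standalone sketch of its behavior as inferred from the diff above (`select_output_token_len` is a hypothetical helper, not part of the patch):

    def select_output_token_len(output_token_lens, max_output_length):
        # Pick the largest configured horizon that does not exceed the remaining budget.
        chosen = output_token_lens[0]
        for h in output_token_lens[1:]:
            if h > max_output_length:
                break
            chosen = h
        return chosen

    assert select_output_token_len([1, 8, 32, 64], 40) == 32  # 64 would overshoot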
diff --git a/iotdb-core/ainode/ainode/core/model/timerxl/ts_generation_mixin.py
b/iotdb-core/ainode/ainode/core/model/timerxl/ts_generation_mixin.py
new file mode 100644
index 00000000000..165d3c55e44
--- /dev/null
+++ b/iotdb-core/ainode/ainode/core/model/timerxl/ts_generation_mixin.py
@@ -0,0 +1,366 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import torch
+from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList
+from transformers.generation import EosTokenCriteria, validate_stopping_criteria
+from transformers.generation.utils import (
+ GenerateDecoderOnlyOutput,
+ GenerateEncoderDecoderOutput,
+ GenerateNonBeamOutput,
+ GenerateOutput,
+ GenerationConfig,
+)
+from transformers.utils import ModelOutput
+
+
+class TSGenerationMixin(GenerationMixin):
+
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[
+ Callable[[int, torch.Tensor], List[int]]
+ ] = None,
+ synced_gpus: Optional[bool] = None,
+ assistant_model: Optional["PreTrainedModel"] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ negative_prompt_ids: Optional[torch.Tensor] = None,
+ negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ if len(inputs.shape) == 2:
+ batch_size, cur_len = inputs.shape
+ if cur_len < self.config.input_token_len:
+ raise ValueError(
+ f"Input length must be at least
{self.config.input_token_len}"
+ )
+ elif cur_len % self.config.input_token_len != 0:
+ new_len = (
+ cur_len // self.config.input_token_len
+ ) * self.config.input_token_len
+ inputs = inputs[:, -new_len:]
+ else:
+ raise ValueError("Input shape must be: [batch_size, seq_len]")
+ return super().generate(
+ inputs=inputs,
+ generation_config=generation_config,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ synced_gpus=synced_gpus,
+ assistant_model=assistant_model,
+ streamer=streamer,
+ negative_prompt_ids=negative_prompt_ids,
+ negative_prompt_attention_mask=negative_prompt_attention_mask,
+ **kwargs,
+ )
+
+ def _greedy_search(
+ self,
+ input_ids: torch.Tensor,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ max_length: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ eos_token_id: Optional[Union[int, List[int]]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_scores: Optional[bool] = None,
+ output_logits: Optional[bool] = None,
+ return_dict_in_generate: Optional[bool] = None,
+ synced_gpus: bool = False,
+ streamer: Optional["BaseStreamer"] = None,
+ **model_kwargs,
+ ) -> Union[GenerateNonBeamOutput, torch.Tensor]:
+ input_ids = input_ids.to(self.device)
+ batch_size, cur_len = input_ids.shape
+ # init values
+ logits_processor = (
+ logits_processor if logits_processor is not None else LogitsProcessorList()
+ )
+ stopping_criteria = (
+ stopping_criteria
+ if stopping_criteria is not None
+ else StoppingCriteriaList()
+ )
+ if max_length is not None:
+ warnings.warn(
+ "`max_length` is deprecated in this function, use"
+ "
`stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])`
instead.",
+ UserWarning,
+ )
+ stopping_criteria = validate_stopping_criteria(
+ stopping_criteria, max_length
+ )
+ pad_token_id = (
+ pad_token_id
+ if pad_token_id is not None
+ else self.generation_config.pad_token_id
+ )
+ if eos_token_id is not None:
+ stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+ else:
+ # remove when the method is totally private
+ # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever
+ eos_token_id = [
+ criteria.eos_token_id.tolist()
+ for criteria in stopping_criteria
+ if hasattr(criteria, "eos_token_id")
+ ]
+ eos_token_id = eos_token_id[0] if eos_token_id else None
+ if eos_token_id is None and self.generation_config.eos_token_id is not None:
+ eos_token_id = self.generation_config.eos_token_id
+ stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id))
+
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ output_scores = (
+ output_scores
+ if output_scores is not None
+ else self.generation_config.output_scores
+ )
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.generation_config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.generation_config.output_hidden_states
+ )
+ return_dict_in_generate = (
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
+ )
+
+ # init attention / hidden states / scores tuples
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
+ scores = () if (return_dict_in_generate and output_scores) else None
+ decoder_attentions = (
+ () if (return_dict_in_generate and output_attentions) else None
+ )
+ cross_attentions = (
+ () if (return_dict_in_generate and output_attentions) else None
+ )
+ decoder_hidden_states = (
+ () if (return_dict_in_generate and output_hidden_states) else None
+ )
+
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+ if return_dict_in_generate and self.config.is_encoder_decoder:
+ encoder_attentions = (
+ model_kwargs["encoder_outputs"].get("attentions")
+ if output_attentions
+ else None
+ )
+ encoder_hidden_states = (
+ model_kwargs["encoder_outputs"].get("hidden_states")
+ if output_hidden_states
+ else None
+ )
+
+ # keep track of which sequences are already finished
+ if "inputs_embeds" in model_kwargs:
+ cur_len = model_kwargs["inputs_embeds"].shape[1]
+ this_peer_finished = False
+ unfinished_sequences = torch.ones(
+ batch_size, dtype=torch.long, device=input_ids.device
+ )
+ model_kwargs["cache_position"] = torch.arange(cur_len,
device=input_ids.device)
+ true_seq_len = cur_len // self.config.input_token_len
+ model_kwargs["attention_mask"] = model_kwargs["attention_mask"][
+ :, -true_seq_len:
+ ]
+ max_length = stopping_criteria.max_length
+ while self._has_unfinished_sequences(
+ this_peer_finished, synced_gpus, device=input_ids.device
+ ):
+ # prepare model inputs
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+ input_length = input_ids.shape[1]
+
+ # forward pass to get next token
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ max_output_length=max_length - input_length,
+ )
+
+ if synced_gpus and this_peer_finished:
+ continue # don't waste resources running the code we don't need
+
+ next_token_logits = outputs.logits
+
+ # pre-process distribution
+ next_tokens_scores = logits_processor(input_ids, next_token_logits)
+
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_scores:
+ scores += (next_tokens_scores,)
+ if output_logits:
+ raw_logits += (next_token_logits,)
+ if output_attentions:
+ decoder_attentions += (
+ (outputs.decoder_attentions,)
+ if self.config.is_encoder_decoder
+ else (outputs.attentions,)
+ )
+ if self.config.is_encoder_decoder:
+ cross_attentions += (outputs.cross_attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (
+ (outputs.decoder_hidden_states,)
+ if self.config.is_encoder_decoder
+ else (outputs.hidden_states,)
+ )
+
+ # argmax
+ # next_tokens = torch.argmax(next_tokens_scores, dim=-1)
+ next_tokens = next_tokens_scores
+
+ # finished sentences should have their next token be a padding token
+ if eos_token_id is not None:
+ if pad_token_id is None:
+ raise ValueError(
+ "If `eos_token_id` is defined, make sure that
`pad_token_id` is defined."
+ )
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+ 1 - unfinished_sequences
+ )
+
+ # update generated ids, model inputs, and length for next step
+ horizon_length = next_tokens.shape[1] // self.config.input_token_len
+
+ input_ids = torch.cat([input_ids, next_tokens], dim=-1)
+ if streamer is not None:
+ streamer.put(next_tokens.cpu())
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs,
+ model_kwargs,
+ horizon_length=horizon_length,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ )
+ unfinished_sequences = unfinished_sequences & ~stopping_criteria(
+ input_ids, scores
+ )
+ this_peer_finished = unfinished_sequences.max() == 0
+
+ if input_ids.shape[1] > max_length:
+ input_ids = input_ids[:, :max_length]
+
+ if streamer is not None:
+ streamer.end()
+
+ if return_dict_in_generate:
+ if self.config.is_encoder_decoder:
+ return GenerateEncoderDecoderOutput(
+ sequences=input_ids,
+ scores=scores,
+ logits=raw_logits,
+ encoder_attentions=encoder_attentions,
+ encoder_hidden_states=encoder_hidden_states,
+ decoder_attentions=decoder_attentions,
+ cross_attentions=cross_attentions,
+ decoder_hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return GenerateDecoderOnlyOutput(
+ sequences=input_ids,
+ scores=scores,
+ logits=raw_logits,
+ attentions=decoder_attentions,
+ hidden_states=decoder_hidden_states,
+ past_key_values=model_kwargs.get("past_key_values"),
+ )
+ else:
+ return input_ids[:, -(max_length - cur_len) :]
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs: ModelOutput,
+ model_kwargs: Dict[str, Any],
+ horizon_length: int = 1,
+ is_encoder_decoder: bool = False,
+ standardize_cache_format: bool = False,
+ ) -> Dict[str, Any]:
+ # update past_key_values
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+ outputs, standardize_cache_format=standardize_cache_format
+ )
+ if getattr(outputs, "state", None) is not None:
+ model_kwargs["state"] = outputs.state
+
+ # update token_type_ids with last value
+ if "token_type_ids" in model_kwargs:
+ token_type_ids = model_kwargs["token_type_ids"]
+ model_kwargs["token_type_ids"] = torch.cat(
+ [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1
+ )
+
+ if not is_encoder_decoder:
+ # update attention mask
+ if "attention_mask" in model_kwargs:
+ attention_mask = model_kwargs["attention_mask"]
+ model_kwargs["attention_mask"] = torch.cat(
+ [
+ attention_mask,
+ attention_mask.new_ones(
+ (attention_mask.shape[0], horizon_length)
+ ),
+ ],
+ dim=-1,
+ )
+ else:
+ # update decoder attention mask
+ if "decoder_attention_mask" in model_kwargs:
+ decoder_attention_mask = model_kwargs["decoder_attention_mask"]
+ model_kwargs["decoder_attention_mask"] = torch.cat(
+ [
+ decoder_attention_mask,
+ decoder_attention_mask.new_ones(
+ (decoder_attention_mask.shape[0], horizon_length)
+ ),
+ ],
+ dim=-1,
+ )
+
+ if (
+ "cache_position" in model_kwargs
+ and model_kwargs["cache_position"] is not None
+ ):
+ model_kwargs["cache_position"] = (
+ model_kwargs["cache_position"][-1:] + horizon_length
+ )
+
+ return model_kwargs
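Putting the pieces together, a rough end-to-end usage sketch, assuming weights are available locally (or can be downloaded from HuggingFace on first use) and that the input length is a multiple of config.input_token_len; none of the names below beyond the two imported classes are part of the patch:

    import torch

    from ainode.core.model.timerxl.configuration_timer import TimerConfig
    from ainode.core.model.timerxl.modeling_timer import TimerForPrediction

    config = TimerConfig.from_dict({"ckpt_path": "/tmp/timerxl"})  # illustrative path
    model = TimerForPrediction(config).eval()

    history = torch.randn(1, 96)  # [batch_size, seq_len]
    with torch.no_grad():
        forecast = model.generate(history, max_new_tokens=32, revin=True)
    print(forecast.shape)  # expected: torch.Size([1, 32])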