MultiPatchFormer是一种基于Transformer的多元时间序列预测模型。它使用多尺度时间嵌入来捕捉不同尺度的时间相关性,并包含一个通道编码器,用于建模各输入序列(通道)之间的关系及其交互作用。
MultiPatchFormer 采用了一种简单但有效的方法来缓解过拟合:它不使用单个线性层把模型维度直接映射到整个预测窗口,而是把预测窗口切分成若干段,由多个线性层依次解码;每一步都以编码特征和此前各步已得到的预测作为输入。这种方法被称为"半自回归解码"(multi-step linear decoder),可以减少过拟合和噪声的影响。
TiDE(Time-series Dense Encoder)是一种用于长期时间序列预测的时间序列密集编码器。它使用密集的多层感知机(MLP)对时间序列的历史值及协变量进行编码,再使用密集的MLP结合未来协变量进行解码。
本文采用TiDE中的残差MLP块(ResBlock:两层MLP + 线性残差连接 + LayerNorm)替换MultiPatchFormer半自回归解码中的线性层,以期提升模型的预测准确率。
MultiPatchFormer论文地址:https://www.nature.com/articles/s41598-024-82417-4
TiDE论文地址:https://arxiv.org/pdf/2304.08424.pdf
代码地址:https://github.com/thuml/Time-Series-Library
MLP-半自回归解码模型架构
定义MLP(TiDE残差块)
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional bias.

    Written by hand because ``nn.LayerNorm`` historically had no
    ``bias=False`` switch (newer PyTorch releases add one).

    Args:
        ndim: size of the normalized (last) dimension.
        bias: when falsy, no additive bias parameter is created and
            ``self.bias`` is ``None``.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        # Affine scale starts at 1 so the layer is a pure normalization at init.
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        normalized_shape = self.weight.shape
        return F.layer_norm(
            input, normalized_shape, weight=self.weight, bias=self.bias, eps=1e-5
        )
class ResBlock(nn.Module):
    """TiDE-style residual MLP block.

    Maps the last dimension from ``input_dim`` to ``output_dim``::

        y = LayerNorm(Dropout(fc2(ReLU(fc1(x)))) + fc3(x))

    ``fc3`` is a learned linear skip connection, so the input and output
    widths may differ.

    Args:
        input_dim: size of the last dimension of ``x``.
        hidden_dim: width of the hidden layer.
        output_dim: size of the last dimension of the result.
        dropout: dropout probability, applied to the MLP branch only.
        bias: whether the linear layers and the layer norm carry a bias.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1, bias=True):
        super().__init__()
        # The Linear layers are created in a fixed order: that order decides
        # which random numbers each one draws at initialization.
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias)
        self.fc3 = nn.Linear(input_dim, output_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.ln = LayerNorm(output_dim, bias=bias)
        self.input_dim, self.output_dim = input_dim, output_dim

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        branch = self.dropout(self.fc2(hidden))
        return self.ln(branch + self.fc3(x))
修改MultiPatchFormer中半自回归解码层的定义
# Decoder step k (1-based) receives the channel encoding plus the k-1 chunks
# already predicted, each pred_len // 8 wide. Multiply the *floored* chunk size,
# i.e. k * (pred_len // 8). Writing k * pred_len // 8 would evaluate as
# (k * pred_len) // 8, which disagrees with the concatenated width whenever
# pred_len % 8 != 0 (e.g. pred_len = 18).
self.out_linear_1 = ResBlock(self.d_model, self.d_hidden, self.pred_len // 8)
self.out_linear_2 = ResBlock(
    self.d_model + self.pred_len // 8, self.d_hidden, self.pred_len // 8
)
self.out_linear_3 = ResBlock(
    self.d_model + 2 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
)
self.out_linear_4 = ResBlock(
    self.d_model + 3 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
)
self.out_linear_5 = ResBlock(
    self.d_model + 4 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
)
self.out_linear_6 = ResBlock(
    self.d_model + 5 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
)
self.out_linear_7 = ResBlock(
    self.d_model + 6 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
)
# The last step absorbs the remainder so the eight chunks sum to pred_len.
self.out_linear_8 = ResBlock(
    self.d_model + 7 * (self.pred_len // 8),
    self.d_hidden,
    self.pred_len - 7 * (self.pred_len // 8),
)
完整代码
import torch
import torch.nn as nn
import math
from einops import rearrange
from layers.SelfAttention_Family import AttentionLayer, FullAttention
import torch.nn.functional as F
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional bias.

    Written by hand because ``nn.LayerNorm`` historically had no
    ``bias=False`` switch (newer PyTorch releases add one).

    Args:
        ndim: size of the normalized (last) dimension.
        bias: when falsy, no additive bias parameter is created and
            ``self.bias`` is ``None``.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        # Affine scale starts at 1 so the layer is a pure normalization at init.
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        normalized_shape = self.weight.shape
        return F.layer_norm(
            input, normalized_shape, weight=self.weight, bias=self.bias, eps=1e-5
        )
class ResBlock(nn.Module):
    """TiDE-style residual MLP block.

    Maps the last dimension from ``input_dim`` to ``output_dim``::

        y = LayerNorm(Dropout(fc2(ReLU(fc1(x)))) + fc3(x))

    ``fc3`` is a learned linear skip connection, so the input and output
    widths may differ.

    Args:
        input_dim: size of the last dimension of ``x``.
        hidden_dim: width of the hidden layer.
        output_dim: size of the last dimension of the result.
        dropout: dropout probability, applied to the MLP branch only.
        bias: whether the linear layers and the layer norm carry a bias.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1, bias=True):
        super().__init__()
        # The Linear layers are created in a fixed order: that order decides
        # which random numbers each one draws at initialization.
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias)
        self.fc3 = nn.Linear(input_dim, output_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.ln = LayerNorm(output_dim, bias=bias)
        self.input_dim, self.output_dim = input_dim, output_dim

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        branch = self.dropout(self.fc2(hidden))
        return self.ln(branch + self.fc3(x))
class FeedForward(nn.Module):
    """Position-wise two-layer MLP: ``d_model -> d_hidden -> GELU -> d_model``.

    Args:
        d_model: input and output feature size (last dimension).
        d_hidden: hidden width; defaults to 512.
    """

    def __init__(self, d_model: int, d_hidden: int = 512):
        super().__init__()
        # Creation order (expand, then project back) fixes the init RNG draws.
        self.linear_1 = nn.Linear(d_model, d_hidden)
        self.linear_2 = nn.Linear(d_hidden, d_model)
        self.activation = nn.GELU()

    def forward(self, x):
        expanded = self.activation(self.linear_1(x))
        return self.linear_2(expanded)
class Encoder(nn.Module):
    """Post-norm Transformer encoder layer with an optional channel-mixing branch.

    Attention uses the input itself as the query. Keys and values are either
    the input as well, or, when ``channel_wise`` is set, the input passed
    through a kernel-size-1 ``Conv1d`` over the ``d_model`` feature axis. An
    attention sub-layer and a ``FeedForward`` sub-layer follow, each with
    dropout, a residual add and a ``LayerNorm`` applied after the add.

    Args:
        d_model: feature size of the last dimension of the input.
        mha: attention module, called as ``mha(q, k, v, attn_mask=None)``;
            it must return ``(output, score)``.
        d_hidden: hidden width of the feed-forward sub-layer.
        dropout: dropout probability shared by both sub-layers.
        channel_wise: whether keys/values go through the 1x1 convolution.
    """

    def __init__(
        self,
        d_model: int,
        mha: AttentionLayer,
        d_hidden: int,
        dropout: float = 0,
        channel_wise=False,
    ):
        super().__init__()
        self.channel_wise = channel_wise
        if channel_wise:
            # With padding=0 the padding_mode setting never takes effect.
            self.conv = nn.Conv1d(
                d_model,
                d_model,
                kernel_size=1,
                stride=1,
                padding=0,
                padding_mode="reflect",
            )
        self.MHA = mha
        self.feedforward = FeedForward(d_model=d_model, d_hidden=d_hidden)
        self.dropout = nn.Dropout(p=dropout)
        self.layerNormal_1 = nn.LayerNorm(d_model)
        self.layerNormal_2 = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return ``(encoded, attention_score)``; ``x``'s last dim must be ``d_model``."""
        if self.channel_wise:
            # Conv1d wants features on dim 1, so swap them in and back out.
            mixed = self.conv(x.permute(0, 2, 1)).transpose(1, 2)
            key = value = mixed
        else:
            key = value = x
        attended, score = self.MHA(x, key, value, attn_mask=None)
        hidden = self.layerNormal_1(self.dropout(attended) + x)
        out = self.layerNormal_2(self.dropout(self.feedforward(hidden)) + hidden)
        return out, score
class Model(nn.Module):
    """MultiPatchFormer whose semi-autoregressive decoder uses TiDE ResBlocks.

    Pipeline (see ``forecast``):

    1. Per-series instance normalization.
    2. Four-scale 1-D convolutional patch embeddings (``d_model // 4``
       channels each) plus a fixed sinusoidal positional table.
    3. ``e_layers`` temporal Transformer encoders.
    4. One channel-wise encoder.
    5. An eight-step semi-autoregressive ``ResBlock`` decoder.
    6. De-normalization.

    Args:
        configs: config object. Reads ``task_name``, ``seq_len``,
            ``pred_len``, ``enc_in``, ``e_layers``, ``d_model``, ``d_ff``,
            ``n_heads`` and ``dropout``.

    Raises:
        ValueError: if ``d_model`` is not a multiple of 4. The four patch
            embeddings must concatenate back to exactly ``d_model`` features.
    """

    # Number of chunks the forecast horizon is split into by the decoder.
    _DECODE_STEPS = 8

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.d_channel = configs.enc_in
        self.N = configs.e_layers
        # Embedding
        self.d_model = configs.d_model
        self.d_hidden = configs.d_ff
        self.n_heads = configs.n_heads
        self.mask = True
        self.dropout = configs.dropout

        if self.d_model % 4 != 0:
            raise ValueError(
                "d_model must be divisible by 4 (four patch embeddings of "
                f"d_model // 4 channels are concatenated), got {self.d_model}"
            )

        # Kernel length and stride of each of the four patch-embedding scales.
        self.stride1 = 8
        self.patch_len1 = 8
        self.stride2 = 8
        self.patch_len2 = 16
        self.stride3 = 7
        self.patch_len3 = 24
        self.stride4 = 6
        self.patch_len4 = 32
        # NOTE(review): the patch count is derived from scale 2 only. Every
        # scale must produce this many patches for the concat in ``forecast``
        # to work; this holds for seq_len = 96. Confirm for other seq_len.
        self.patch_num1 = int((self.seq_len - self.patch_len2) // self.stride2) + 2
        # Right-pad each series by one stride, replicating the last value.
        # padding_patch_layer1 is never applied (scale 1 is embedded unpadded).
        self.padding_patch_layer1 = nn.ReplicationPad1d((0, self.stride1))
        self.padding_patch_layer2 = nn.ReplicationPad1d((0, self.stride2))
        self.padding_patch_layer3 = nn.ReplicationPad1d((0, self.stride3))
        self.padding_patch_layer4 = nn.ReplicationPad1d((0, self.stride4))

        self.shared_MHA = nn.ModuleList(
            [
                AttentionLayer(
                    FullAttention(mask_flag=self.mask),
                    d_model=self.d_model,
                    n_heads=self.n_heads,
                )
                for _ in range(self.N)
            ]
        )
        self.shared_MHA_ch = nn.ModuleList(
            [
                AttentionLayer(
                    FullAttention(mask_flag=self.mask),
                    d_model=self.d_model,
                    n_heads=self.n_heads,
                )
                for _ in range(self.N)
            ]
        )
        self.encoder_list = nn.ModuleList(
            [
                Encoder(
                    d_model=self.d_model,
                    mha=self.shared_MHA[ll],
                    d_hidden=self.d_hidden,
                    dropout=self.dropout,
                    channel_wise=False,
                )
                for ll in range(self.N)
            ]
        )
        # NOTE(review): every channel encoder shares shared_MHA_ch[0], and
        # ``forecast`` only applies encoder_list_ch[0]. The remaining entries
        # are kept so existing checkpoints keep loading.
        self.encoder_list_ch = nn.ModuleList(
            [
                Encoder(
                    d_model=self.d_model,
                    mha=self.shared_MHA_ch[0],
                    d_hidden=self.d_hidden,
                    dropout=self.dropout,
                    channel_wise=True,
                )
                for ll in range(self.N)
            ]
        )

        # Fixed sinusoidal positional table of shape [1, patch_num1, d_model].
        # NOTE(review): ``i`` already steps over even indices, so the exponent
        # 2 * i / d_model doubles the usual i / d_model. Kept as-is to preserve
        # trained behaviour.
        pe = torch.zeros(self.patch_num1, self.d_model)
        for pos in range(self.patch_num1):
            for i in range(0, self.d_model, 2):
                wavelength = 10000 ** ((2 * i) / self.d_model)
                pe[pos, i] = math.sin(pos / wavelength)
                pe[pos, i + 1] = math.cos(pos / wavelength)
        pe = pe.unsqueeze(0)  # add a batch dimension
        self.register_buffer("pe", pe)

        self.embedding_channel = nn.Conv1d(
            in_channels=self.d_model * self.patch_num1,
            out_channels=self.d_model,
            kernel_size=1,
        )
        self.embedding_patch_1 = torch.nn.Conv1d(
            in_channels=1,
            out_channels=self.d_model // 4,
            kernel_size=self.patch_len1,
            stride=self.stride1,
        )
        self.embedding_patch_2 = torch.nn.Conv1d(
            in_channels=1,
            out_channels=self.d_model // 4,
            kernel_size=self.patch_len2,
            stride=self.stride2,
        )
        self.embedding_patch_3 = torch.nn.Conv1d(
            in_channels=1,
            out_channels=self.d_model // 4,
            kernel_size=self.patch_len3,
            stride=self.stride3,
        )
        self.embedding_patch_4 = torch.nn.Conv1d(
            in_channels=1,
            out_channels=self.d_model // 4,
            kernel_size=self.patch_len4,
            stride=self.stride4,
        )

        # Semi-autoregressive decoder out_linear_1 .. out_linear_8.
        # Step k (1-based) sees the channel encoding plus the k - 1 chunks
        # already predicted, so its input width is d_model + (k - 1) * chunk.
        # Using the floored chunk size (rather than (k - 1) * pred_len // 8)
        # keeps these widths equal to what ``forecast`` concatenates when
        # pred_len % 8 != 0. The last step absorbs the remainder so the chunks
        # sum to pred_len. Modules are registered in order 1..8 under their
        # original names, so state_dict keys are unchanged.
        chunk = self.pred_len // self._DECODE_STEPS
        for k in range(1, self._DECODE_STEPS + 1):
            in_dim = self.d_model + (k - 1) * chunk
            if k < self._DECODE_STEPS:
                out_dim = chunk
            else:
                out_dim = self.pred_len - (self._DECODE_STEPS - 1) * chunk
            setattr(self, f"out_linear_{k}", ResBlock(in_dim, self.d_hidden, out_dim))

        # NOTE(review): remap is never used by ``forecast``; kept for checkpoint
        # compatibility.
        self.remap = torch.nn.Linear(self.d_model, self.seq_len)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Predict the next ``pred_len`` steps for every channel.

        Args:
            x_enc: input window indexed [batch, time, channel]. Statistics are
                taken over dim 1, and the channel count must equal ``enc_in``.
            x_mark_enc, x_dec, x_mark_dec: accepted for interface
                compatibility; unused.

        Returns:
            Tensor of shape [batch, pred_len, channel] in the input's scale.
        """
        # Normalization (per series, over time)
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc = x_enc / stdev

        # Multi-scale embedding: each channel becomes its own univariate series.
        x_i = x_enc.permute(0, 2, 1)  # [b, c, l]
        scale_inputs = (
            x_i,  # scale 1 is embedded without padding
            self.padding_patch_layer2(x_i),
            self.padding_patch_layer3(x_i),
            self.padding_patch_layer4(x_i),
        )
        embedders = (
            self.embedding_patch_1,
            self.embedding_patch_2,
            self.embedding_patch_3,
            self.embedding_patch_4,
        )
        scale_encodings = [
            # [(b c), 1, l] -> conv -> [(b c), patches, d_model // 4]
            embed(
                rearrange(x_scale, "b c l -> (b c) l").unsqueeze(-1).permute(0, 2, 1)
            ).permute(0, 2, 1)
            for embed, x_scale in zip(embedders, scale_inputs)
        ]
        encoding_patch = torch.cat(scale_encodings, dim=-1) + self.pe

        # Temporal encoding
        for encoder in self.encoder_list:
            encoding_patch = encoder(encoding_patch)[0]

        # Channel-wise encoding
        x_patch_c = rearrange(
            encoding_patch, "(b c) p d -> b c (p d)", b=x_enc.shape[0], c=self.d_channel
        )
        x_ch = self.embedding_channel(x_patch_c.permute(0, 2, 1)).transpose(1, 2)  # [b c d]
        encoding_1_ch = self.encoder_list_ch[0](x_ch)[0]

        # Semi-autoregressive decoding: each step conditions on all earlier chunks.
        chunks = []
        for k in range(1, self._DECODE_STEPS + 1):
            decoder = getattr(self, f"out_linear_{k}")
            chunks.append(decoder(torch.cat((encoding_1_ch, *chunks), dim=-1)))
        final_forecast = torch.cat(chunks, dim=-1).permute(0, 2, 1)  # [b, pred_len, c]

        # De-normalization: stdev and means are [b, 1, c] and broadcast over time.
        dec_out = final_forecast * stdev
        dec_out = dec_out + means
        return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch on ``task_name``; only forecasting tasks are implemented.

        Returns:
            [batch, pred_len, channel] forecasts for long/short-term
            forecasting, or ``None`` for an unrecognized task name.

        Raises:
            NotImplementedError: for imputation, anomaly detection and
                classification.
        """
        if self.task_name in ("long_term_forecast", "short_term_forecast"):
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len :, :]  # [B, L, D]
        if self.task_name in ("imputation", "anomaly_detection", "classification"):
            raise NotImplementedError(
                f"Task {self.task_name} for MultiPatchFormer is temporarily not supported"
            )
        return None