MultiPatchFormer Improvement | Model Modification | Accuracy Gains | Replacing the Linear Layers of the Semi-Autoregressive Decoder with the MLP Structure from TiDE (Code + Modification Tutorial)

1. Introduction

MultiPatchFormer is a Transformer-based model for multivariate time series forecasting. It uses multi-scale temporal embeddings to capture temporal dependencies at different scales and includes a channel encoder that models the relationships and interactions among the input series.

To mitigate overfitting, MultiPatchFormer uses a simple but effective trick: instead of mapping the model dimension to the full prediction window with a single linear layer, it decodes the extracted information in consecutive steps through a chain of linear layers. This semi-autoregressive, multi-step linear decoder reduces the influence of overfitting and noise.
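To make the contrast concrete, here is a minimal sketch with made-up dimensions (not the paper's code): a single linear head maps the representation to the whole horizon at once, while the multi-step decoder predicts the horizon chunk by chunk and feeds earlier chunks back in as inputs.

import torch
import torch.nn as nn

# Made-up sizes, for illustration only
d_model, pred_len = 128, 96
h = torch.randn(4, 7, d_model)            # [batch, channels, d_model] channel encoding

# (a) Single linear head: the whole horizon in one shot
single_head = nn.Linear(d_model, pred_len)
y_single = single_head(h)                 # [4, 7, 96]

# (b) Multi-step linear decoder: predict pred_len // 8 steps at a time and
#     condition each later head on the chunks predicted so far
step = pred_len // 8
head_1 = nn.Linear(d_model, step)
head_2 = nn.Linear(d_model + step, step)
y1 = head_1(h)                            # first 12 future steps
y2 = head_2(torch.cat((h, y1), dim=-1))   # next 12 steps, conditioned on y1
# ...and so on for the remaining heads, then concatenate all chunks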

TiDE (Time-series Dense Encoder) is a model for long-term time series forecasting. It encodes the history of the time series together with covariates using dense MLPs, and then decodes the series, again with dense MLPs, using future covariates.

This post replaces the linear layers in MultiPatchFormer's semi-autoregressive decoder with the MLP (residual block) structure from TiDE, with the goal of improving forecasting accuracy.
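Concretely, the change keeps MultiPatchFormer's eight decoder heads and only swaps each one from a plain linear layer to the TiDE-style residual MLP block (the ResBlock defined in Section 3). For the first head, the swap looks roughly like this:

# Before: one linear projection per decoding step
# self.out_linear_1 = nn.Linear(self.d_model, self.pred_len // 8)

# After: a TiDE residual MLP block with the same input and output sizes
# self.out_linear_1 = ResBlock(self.d_model, self.d_hidden, self.pred_len // 8)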

MultiPatchFormer paper: https://www.nature.com/articles/s41598-024-82417-4

TiDE paper: https://arxiv.org/pdf/2304.08424.pdf

Code: https://github.com/thuml/Time-Series-Library

2. Model Diagrams

Figure 1: MLP semi-autoregressive model architecture


Figure 2: The MLP (residual block) structure in TiDE

3. Core Code

Define the MLP (TiDE residual block)

import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    """LayerNorm with an optional bias (F.layer_norm is called directly so the bias can be disabled)."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class ResBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1, bias=True):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias)
        self.fc3 = nn.Linear(input_dim, output_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.ln = LayerNorm(output_dim, bias=bias)
        self.input_dim = input_dim
        self.output_dim = output_dim

    def forward(self, x):
        # TiDE-style residual block: an MLP branch (fc1 -> ReLU -> fc2 -> dropout)
        # plus a linear skip connection (fc3), followed by LayerNorm.
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.dropout(out)
        out = out + self.fc3(x)
        out = self.ln(out)
        return out
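
Continuing with the definitions above, a quick shape check (made-up sizes) confirms that the block is a drop-in replacement for a linear layer as far as tensor shapes are concerned: only the last dimension changes from input_dim to output_dim.

block = ResBlock(input_dim=128, hidden_dim=256, output_dim=12)
x = torch.randn(4, 7, 128)   # [batch, channels, d_model]
print(block(x).shape)        # torch.Size([4, 7, 12])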

Modify the semi-autoregressive decoder definitions in MultiPatchFormer

        # Eight TiDE ResBlock heads decode the horizon in chunks of pred_len // 8.
        # Head k sees the channel encoding plus all previously predicted chunks,
        # so its input width is d_model + (k - 1) * (pred_len // 8).
        self.out_linear_1 = ResBlock(self.d_model, self.d_hidden, self.pred_len // 8)
        self.out_linear_2 = ResBlock(
            self.d_model + self.pred_len // 8, self.d_hidden, self.pred_len // 8
        )
        self.out_linear_3 = ResBlock(
            self.d_model + 2 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
        )
        self.out_linear_4 = ResBlock(
            self.d_model + 3 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
        )
        self.out_linear_5 = ResBlock(
            self.d_model + 4 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
        )
        self.out_linear_6 = ResBlock(
            self.d_model + 5 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
        )
        self.out_linear_7 = ResBlock(
            self.d_model + 6 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
        )
        self.out_linear_8 = ResBlock(
            self.d_model + 7 * (self.pred_len // 8),
            self.d_hidden,
            # the last head absorbs the remainder so the chunks sum to pred_len
            self.pred_len - 7 * (self.pred_len // 8),
        )
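
The input width of head k is d_model plus the (k - 1) chunks predicted before it, and the last head absorbs the remainder so that the eight chunks always sum to pred_len. A short check makes the arithmetic explicit:

for pred_len in (96, 100):
    step = pred_len // 8
    chunks = [step] * 7 + [pred_len - 7 * step]
    print(pred_len, chunks, sum(chunks))
# 96  -> eight chunks of 12                       (total 96)
# 100 -> seven chunks of 12 and a final 16-chunk  (total 100)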

Full code

import torch
import torch.nn as nn
import math
from einops import rearrange

from layers.SelfAttention_Family import AttentionLayer, FullAttention
import torch.nn.functional as F

class LayerNorm(nn.Module):
    """LayerNorm with an optional bias (F.layer_norm is called directly so the bias can be disabled)."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

class ResBlock(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, dropout=0.1, bias=True):
        super().__init__()

        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=bias)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=bias)
        self.fc3 = nn.Linear(input_dim, output_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.ln = LayerNorm(output_dim, bias=bias)
        self.input_dim = input_dim
        self.output_dim = output_dim

    def forward(self, x):
        # TiDE-style residual block: an MLP branch (fc1 -> ReLU -> fc2 -> dropout)
        # plus a linear skip connection (fc3), followed by LayerNorm.
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.dropout(out)
        out = out + self.fc3(x)
        out = self.ln(out)
        return out

class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_hidden: int = 512):
        super(FeedForward, self).__init__()

        self.linear_1 = torch.nn.Linear(d_model, d_hidden)
        self.linear_2 = torch.nn.Linear(d_hidden, d_model)
        self.activation = torch.nn.GELU()

    def forward(self, x):
        x = self.linear_1(x)
        x = self.activation(x)
        x = self.linear_2(x)

        return x


class Encoder(nn.Module):
    def __init__(
        self,
        d_model: int,
        mha: AttentionLayer,
        d_hidden: int,
        dropout: float = 0,
        channel_wise=False,
    ):
        super(Encoder, self).__init__()

        self.channel_wise = channel_wise
        if self.channel_wise:
            self.conv = torch.nn.Conv1d(
                in_channels=d_model,
                out_channels=d_model,
                kernel_size=1,
                stride=1,
                padding=0,
                padding_mode="reflect",
            )
        self.MHA = mha
        self.feedforward = FeedForward(d_model=d_model, d_hidden=d_hidden)
        self.dropout = torch.nn.Dropout(p=dropout)
        self.layerNormal_1 = torch.nn.LayerNorm(d_model)
        self.layerNormal_2 = torch.nn.LayerNorm(d_model)

    def forward(self, x):
        residual = x
        q = residual
        if self.channel_wise:
            x_r = self.conv(x.permute(0, 2, 1)).transpose(1, 2)
            k = x_r
            v = x_r
        else:
            k = residual
            v = residual
        x, score = self.MHA(q, k, v, attn_mask=None)
        x = self.dropout(x)
        x = self.layerNormal_1(x + residual)

        residual = x
        x = self.feedforward(residual)
        x = self.dropout(x)
        x = self.layerNormal_2(x + residual)

        return x, score


class Model(nn.Module):
    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.d_channel = configs.enc_in
        self.N = configs.e_layers
        # Embedding
        self.d_model = configs.d_model
        self.d_hidden = configs.d_ff
        self.n_heads = configs.n_heads
        self.mask = True
        self.dropout = configs.dropout

        self.stride1 = 8
        self.patch_len1 = 8
        self.stride2 = 8
        self.patch_len2 = 16
        self.stride3 = 7
        self.patch_len3 = 24
        self.stride4 = 6
        self.patch_len4 = 32
        self.patch_num1 = int((self.seq_len - self.patch_len2) // self.stride2) + 2
        self.padding_patch_layer1 = nn.ReplicationPad1d((0, self.stride1))
        self.padding_patch_layer2 = nn.ReplicationPad1d((0, self.stride2))
        self.padding_patch_layer3 = nn.ReplicationPad1d((0, self.stride3))
        self.padding_patch_layer4 = nn.ReplicationPad1d((0, self.stride4))

        self.shared_MHA = nn.ModuleList(
            [
                AttentionLayer(
                    FullAttention(mask_flag=self.mask),
                    d_model=self.d_model,
                    n_heads=self.n_heads,
                )
                for _ in range(self.N)
            ]
        )

        self.shared_MHA_ch = nn.ModuleList(
            [
                AttentionLayer(
                    FullAttention(mask_flag=self.mask),
                    d_model=self.d_model,
                    n_heads=self.n_heads,
                )
                for _ in range(self.N)
            ]
        )

        self.encoder_list = nn.ModuleList(
            [
                Encoder(
                    d_model=self.d_model,
                    mha=self.shared_MHA[ll],
                    d_hidden=self.d_hidden,
                    dropout=self.dropout,
                    channel_wise=False,
                )
                for ll in range(self.N)
            ]
        )

        self.encoder_list_ch = nn.ModuleList(
            [
                Encoder(
                    d_model=self.d_model,
                    mha=self.shared_MHA_ch[0],
                    d_hidden=self.d_hidden,
                    dropout=self.dropout,
                    channel_wise=True,
                )
                for ll in range(self.N)
            ]
        )

        pe = torch.zeros(self.patch_num1, self.d_model)
        for pos in range(self.patch_num1):
            for i in range(0, self.d_model, 2):
                wavelength = 10000 ** ((2 * i) / self.d_model)
                pe[pos, i] = math.sin(pos / wavelength)
                pe[pos, i + 1] = math.cos(pos / wavelength)
        pe = pe.unsqueeze(0)  # add a batch dimension to the pe matrix
        self.register_buffer("pe", pe)

        self.embedding_channel = nn.Conv1d(
            in_channels=self.d_model * self.patch_num1,
            out_channels=self.d_model,
            kernel_size=1,
        )

        self.embedding_patch_1 = torch.nn.Conv1d(
            in_channels=1,
            out_channels=self.d_model // 4,
            kernel_size=self.patch_len1,
            stride=self.stride1,
        )
        self.embedding_patch_2 = torch.nn.Conv1d(
            in_channels=1,
            out_channels=self.d_model // 4,
            kernel_size=self.patch_len2,
            stride=self.stride2,
        )
        self.embedding_patch_3 = torch.nn.Conv1d(
            in_channels=1,
            out_channels=self.d_model // 4,
            kernel_size=self.patch_len3,
            stride=self.stride3,
        )
        self.embedding_patch_4 = torch.nn.Conv1d(
            in_channels=1,
            out_channels=self.d_model // 4,
            kernel_size=self.patch_len4,
            stride=self.stride4,
        )

        # Eight TiDE ResBlock heads decode the horizon in chunks of pred_len // 8.
        # Head k sees the channel encoding plus all previously predicted chunks,
        # so its input width is d_model + (k - 1) * (pred_len // 8).
        self.out_linear_1 = ResBlock(self.d_model, self.d_hidden, self.pred_len // 8)
        self.out_linear_2 = ResBlock(
            self.d_model + self.pred_len // 8, self.d_hidden, self.pred_len // 8
        )
        self.out_linear_3 = ResBlock(
            self.d_model + 2 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
        )
        self.out_linear_4 = ResBlock(
            self.d_model + 3 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
        )
        self.out_linear_5 = ResBlock(
            self.d_model + 4 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
        )
        self.out_linear_6 = ResBlock(
            self.d_model + 5 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
        )
        self.out_linear_7 = ResBlock(
            self.d_model + 6 * (self.pred_len // 8), self.d_hidden, self.pred_len // 8
        )
        self.out_linear_8 = ResBlock(
            self.d_model + 7 * (self.pred_len // 8),
            self.d_hidden,
            # the last head absorbs the remainder so the chunks sum to pred_len
            self.pred_len - 7 * (self.pred_len // 8),
        )

        self.remap = torch.nn.Linear(self.d_model, self.seq_len)

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        # Normalization
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
        x_enc /= stdev

        # Multi-scale embedding
        x_i = x_enc.permute(0, 2, 1)

        x_i_p1 = x_i
        x_i_p2 = self.padding_patch_layer2(x_i)
        x_i_p3 = self.padding_patch_layer3(x_i)
        x_i_p4 = self.padding_patch_layer4(x_i)
        encoding_patch1 = self.embedding_patch_1(
            rearrange(x_i_p1, "b c l -> (b c) l").unsqueeze(-1).permute(0, 2, 1)
        ).permute(0, 2, 1)
        encoding_patch2 = self.embedding_patch_2(
            rearrange(x_i_p2, "b c l -> (b c) l").unsqueeze(-1).permute(0, 2, 1)
        ).permute(0, 2, 1)
        encoding_patch3 = self.embedding_patch_3(
            rearrange(x_i_p3, "b c l -> (b c) l").unsqueeze(-1).permute(0, 2, 1)
        ).permute(0, 2, 1)
        encoding_patch4 = self.embedding_patch_4(
            rearrange(x_i_p4, "b c l -> (b c) l").unsqueeze(-1).permute(0, 2, 1)
        ).permute(0, 2, 1)

        encoding_patch = (
            torch.cat(
                (encoding_patch1, encoding_patch2, encoding_patch3, encoding_patch4),
                dim=-1,
            )
            + self.pe
        )
        # Temporal encoding
        for i in range(self.N):
            encoding_patch = self.encoder_list[i](encoding_patch)[0]

        # Channel-wise encoding
        x_patch_c = rearrange(
            encoding_patch, "(b c) p d -> b c (p d)", b=x_enc.shape[0], c=self.d_channel
        )
        x_ch = self.embedding_channel(x_patch_c.permute(0, 2, 1)).transpose(
            1, 2
        )  # [b c d]

        encoding_1_ch = self.encoder_list_ch[0](x_ch)[0]

        # Semi-autoregressive multi-step decoding (TiDE ResBlock heads)
        forecast_ch1 = self.out_linear_1(encoding_1_ch)
        forecast_ch2 = self.out_linear_2(
            torch.cat((encoding_1_ch, forecast_ch1), dim=-1)
        )
        forecast_ch3 = self.out_linear_3(
            torch.cat((encoding_1_ch, forecast_ch1, forecast_ch2), dim=-1)
        )
        forecast_ch4 = self.out_linear_4(
            torch.cat((encoding_1_ch, forecast_ch1, forecast_ch2, forecast_ch3), dim=-1)
        )
        forecast_ch5 = self.out_linear_5(
            torch.cat(
                (encoding_1_ch, forecast_ch1, forecast_ch2, forecast_ch3, forecast_ch4),
                dim=-1,
            )
        )
        forecast_ch6 = self.out_linear_6(
            torch.cat(
                (
                    encoding_1_ch,
                    forecast_ch1,
                    forecast_ch2,
                    forecast_ch3,
                    forecast_ch4,
                    forecast_ch5,
                ),
                dim=-1,
            )
        )
        forecast_ch7 = self.out_linear_7(
            torch.cat(
                (
                    encoding_1_ch,
                    forecast_ch1,
                    forecast_ch2,
                    forecast_ch3,
                    forecast_ch4,
                    forecast_ch5,
                    forecast_ch6,
                ),
                dim=-1,
            )
        )
        forecast_ch8 = self.out_linear_8(
            torch.cat(
                (
                    encoding_1_ch,
                    forecast_ch1,
                    forecast_ch2,
                    forecast_ch3,
                    forecast_ch4,
                    forecast_ch5,
                    forecast_ch6,
                    forecast_ch7,
                ),
                dim=-1,
            )
        )

        final_forecast = torch.cat(
            (
                forecast_ch1,
                forecast_ch2,
                forecast_ch3,
                forecast_ch4,
                forecast_ch5,
                forecast_ch6,
                forecast_ch7,
                forecast_ch8,
            ),
            dim=-1,
        ).permute(0, 2, 1)

        # De-Normalization
        dec_out = final_forecast * (
            stdev[:, 0].unsqueeze(1).repeat(1, self.pred_len, 1)
        )
        dec_out = dec_out + (means[:, 0].unsqueeze(1).repeat(1, self.pred_len, 1))
        return dec_out

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        if (
            self.task_name == "long_term_forecast"
            or self.task_name == "short_term_forecast"
        ):
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len :, :]  # [B, L, D]
        if self.task_name == "imputation":
            raise NotImplementedError(
                "Task imputation for MultiPatchFormer is temporarily not supported"
            )
        if self.task_name == "anomaly_detection":
            raise NotImplementedError(
                "Task anomaly_detection for MultiPatchFormer is temporarily not supported"
            )
        if self.task_name == "classification":
            raise NotImplementedError(
                "Task classification for MultiPatchFormer is temporarily not supported"
            )
        return None
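
Since the file imports layers.SelfAttention_Family, it is meant to live inside the Time-Series-Library code base (for example saved as models/MultiPatchFormer.py and run from the repository root). As a rough smoke test, the model can be instantiated with a hypothetical configuration object whose field names follow the attributes read in __init__; in the framework these values normally come from run.py's command-line arguments:

from types import SimpleNamespace

configs = SimpleNamespace(
    task_name="long_term_forecast",
    seq_len=96,
    pred_len=96,
    enc_in=7,       # number of input channels
    e_layers=2,
    d_model=128,    # kept divisible by 4: four patch embeddings of d_model // 4 are concatenated
    d_ff=256,
    n_heads=8,
    dropout=0.1,
)

model = Model(configs)
x_enc = torch.randn(4, configs.seq_len, configs.enc_in)
out = model(x_enc, None, None, None)
print(out.shape)    # expected: torch.Size([4, 96, 7])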
