lstm 输入数据维度_[mcj]pytorch中LSTM的输入输出解释||LSTM输入输出详解

最近想了解一些关于LSTM的相关知识,在进行代码测试的时候,有个地方一直比较疑惑,关于LSTM的输入和输出问题。一直不清楚在pytorch里面该如何定义LSTM的输入和输出。首先看个pytorch官方的例子:

# 首先导入LSTM需要的相关模块

import torch

import torch.nn as nn # 神经网络模块

# 数据向量维数10, 隐藏元维度20, 2个LSTM层串联(如果是1,可以省略,默认为1)

rnn = nn.LSTM(10, 20, 2)

# 序列长度seq_len=5, batch_size=3, 数据向量维数=10

input = torch.randn(5, 3, 10)

# 初始化的隐藏元和记忆元,通常它们的维度是一样的

# 2个LSTM层,batch_size=3,隐藏元维度20

h0 = torch.randn(2, 3, 20)

c0 = torch.randn(2, 3, 20)

# 这里有2层lstm,output是最后一层lstm的每个词向量对应隐藏层的输出,其与层数无关,只与序列长度相关

# hn,cn是所有层最后一个隐藏元和记忆元的输出

output, (hn, cn) = rnn(input, (h0, c0))

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

# 首先导入LSTM需要的相关模块

importtorch

importtorch.nnasnn# 神经网络模块<

你可能感兴趣的:(lstm,输入数据维度)