lstm 数据输入问题

lstm 我有 20*6 条数据,20个样本,每个样本6条历史数据,每条数据有5个值,我送给网络输入时应该是20*6*5 还是 6*20*5

你的数据是:

  • 20 个样本(batch size = 20)

  • 每个样本有 6 条历史数据(sequence length = 6)

  • 每条数据有 5 个值(input size = 5)


✅ 正确的输入形状是:

(20, 6, 5) # 即 batch_size=20, seq_len=6, input_size=5

前提是你在定义 LSTM 时设置了:

nn.LSTM(input_size=5, hidden_size=..., batch_first=True)

你可能感兴趣的:(python基础,lstm,人工智能)