【学习】PyTorch中的nn.Embedding的用法

基本理解

nn.Embedding(num_embeddings, embedding_dim)

其中 num_embeddings 是词表的大小,即 len(vocab);embedding_dim 是词向量的维度。

nn.Embedding()产生一个权重矩阵weight,其shape为(num_embeddings, embedding_dim),表示生成num_embeddings个具有embedding_dim大小的嵌入向量;

输入input的形状shape为(batch_size, Seq_len),batch_size表示样本数(NLP句子数),Seq_len表示序列的长度(每个句子单词个数);

nn.Embedding(input)的输出output具有(batch_size,Seq_len,embedding_dim)的形状大小;

示例:

import torch
import torch.nn a

你可能感兴趣的:(深度学习学习记录,学习,pytorch,embedding)