CBOW(Continuous Bag-of-Words)
"The cat sits on the mat"
["the", "cat", "on", "the"]
(上下文) → 输出预测:"sits"
Skip-gram
"sits"
→ 输出预测:["the", "cat", "on", "the"]
"bank"
可能指河岸或银行)。import collections
import math
import random
import sys
import time
import torch.utils.data as Data
import torch
from torch import nn
# Word2Vec提取相似文本
# 以只读方式打开文件 读取单词 存储在列表中
with open('HarryPotter.txt','r') as f:
lines = f.readlines()
raw_dataset=[st.split() for st in lines]
"""
tk for st in raw_dataset for tk in st 等价于
result = []
for st in raw_dataset: # 遍历每一条文本(每个子列表)
for tk in st: # 遍历文本中的每一个词(token)
result.append(tk) # 将词加入结果列表
"""
# tk是token的缩写
counter = collections.Counter([tk for st in raw_dataset for tk in st])
# 过滤低频词 只保留在数据集中至少出现5次的词
counter = dict(filter(lambda x: x[1] >= 5, counter.items()))
# 生成从索引到词的映射表 idx_to_token
"""
counter.items() 返回的是词频字典的所有键值对(词 + 出现次数)。
tk for tk, _ in counter.items() 遍历这些键值对,只取词(忽略频率),生成一个列表。
结果:idx_to_token 是一个列表,索引是词在列表中的位置,值是对应的词。
"""
idx_to_token = [tk for tk, _ in counter.items()]
# 生成从词到索引的映射表 token_to_idx
"""
使用 enumerate 遍历 idx_to_token 列表,得到每个词及其对应的索引。
构建一个字典,键是词(token),值是该词对应的索引(index)。
"""
token_to_idx = {tk: idx for idx, tk in enumerate(idx_to_token)}
# 将原始文本转换为索引表示 dataset raw_dataset中的单词在这一步被转换为对应的idx
"""
外层循环遍历每条原始文本 st(如句子或段落)。
内层循环遍历每个词 tk,如果这个词存在于 token_to_idx 中,则将其转换为对应的索引。
最终结果是一个二维列表,其中每个子列表是一条文本对应的词索引序列。
"""
dataset = [[token_to_idx[tk] for tk in st if tk in token_to_idx]
for st in raw_dataset]
# 统计所有保留词的总数量 num_tokens
num_tokens = sum([len(st) for st in dataset])
'''
二次采样操作
降频操作: 越高频率的词一般意义不大,根据公式高频词越容易被过滤。既不希望超高频被完全过滤,又希望减少高频词对训练的影响。
'''
def discard(idx):
return random.uniform(0, 1) < 1 - math.sqrt(
1e-4 / counter[idx_to_token[idx]] * num_tokens)
'''
提取中心词和背景词 将与中心词距离不超过背景窗口大小的词作为背景词
get_centers_and_contexts 函数提取所有的中心词和背景词
每次在整数1和max_window_size之间随机均匀采样一个整数作为背景窗口大小
'''
def get_centers_and_contexts(dataset, max_window_size):
"""
从给定的数据集中获取中心和上下文
:param dataset: 数据集
:param max_window_size: 最大背景窗口
:return: 一个包含中心词和上下文的元组(centers, contexts)
"""
# 中心词和背景词列表
centers, contexts = [], []
# 遍历数据集中的每个字符串st
for st in dataset:
# 跳过长度小于2的句子
if len(st) < 2:
continue
# 将符合要求的单词添加到中心词列表中
centers += st
# 对于每个中心词
for center_i in range(len(st)):
# 随机选择一个背景窗口大小
window_size = random.randint(1, max_window_size)
# 生成一个包含中心词索引和周围索引的列表
'''
center_i:当前中心词在句子 st 中的索引位置。
window_size:随机选择的窗口大小(范围为 1 到 max_window_size)。
len(st):当前句子中词的总数量。
1. center_i - window_size 计算窗口的起始索引(左边界),不能小于 0,所以用 max(0, ...) 来限制。
2. center_i + 1 + window_size 计算窗口的结束索引(右边界),注意:
加 1 是因为 Python 的 range(start, end) 是左闭右开区间;所以 center_i + 1 表示从中心词开始,至少包括它本身;
再加上 window_size 就能取到右边最多 window_size 个词;
使用 min(..., len(st)) 防止超出句子长度。
3. range(...) + list(...) 把这个索引范围转换成一个整数列表,即所有可能的上下文词的位置。
'''
indices = list(range(max(0, center_i - window_size), min(len(st), center_i + 1 + window_size)))
# 将中心词索引从列表中移除
indices.remove(center_i)
contexts.append([st[idx] for idx in indices])
# 返回中心和上下文列表
return centers, contexts
# 假设最大背景窗口大小为5 提取中心词和背景词
all_centers, all_contexts = get_centers_and_contexts(dataset, 5)
"""
负采样近似加快程序运行时间
对于一对中心词和背景词 随机采样5个噪声词
噪声词采样率P(w)设为w词频与总词频之比的0.75次方
"""
def get_negatives(all_contexts, sampling_weights, K):
"""
用于获取负样本
:param all_contexts: 所有上下文的列表
:param sampling_weights: 采样权重的列表
:param K: 所需的负样本数量
:return: all_negatives 包含负样本的列表
"""
all_negatives, neg_candidates, i = [], [], 0
# 生成一个候选噪声词的列表,其中每个噪声词的采样权重与原始词的采样权重相同
population = list(range(len(sampling_weights)))
for contexts in all_contexts:
# 初始化一个空列表用于存储负样本
negatives = []
# 当负样本数量小于上下文数量与k的乘积时,继续循环
while len(negatives) < len(contexts) * K:
# 如果i等于候选噪声词列表的长度,则重新生成候选噪声词列表
if i == len(neg_candidates):
i, neg_candidates = 0, random.choices(
population, sampling_weights, k=int(1e5))
# 获取下一个负样本和更新i
neg, i = neg_candidates[i], i + 1
# 如果负样本不在上下文中,则添加到负样本列表中
if neg not in set(contexts):
negatives.append(neg)
# 将当前上下文的负样本添加到all_negatives列表中
all_negatives.append(negatives)
return all_negatives
# 计算采样权重 使用计数器中每个元素的0.75次方
sampling_weights = [counter[w] ** 0.75 for w in idx_to_token]
# 获取负样本
all_negatives = get_negatives(all_contexts, sampling_weights, 5)
# 小批量读取函数batchify
"""
小批量输入data是一个列表 其中每个元素分别为中心词center、背景词context和噪声词negative
"""
def batchify(data):
max_len = max(len(c) + len(n) for _, c, n in data)
centers, contexts_negatives, masks, labels = [], [], [], []
for center, context, negative in data:
cur_len = len(context) + len(negative)
centers += [center]
contexts_negatives += [context + negative + [0] * (max_len - cur_len)]
masks += [[1] * cur_len + [0] * (max_len - cur_len)]
labels += [[1] * len(context) + [0] * (max_len - len(context))]
batch = (torch.tensor(centers).view(-1, 1), torch.tensor(contexts_negatives),
torch.tensor(masks), torch.tensor(labels))
return batch
class MyDataset(torch.utils.data.Dataset):
def __init__(self, centers, contexts, negatives):
assert len(centers) == len(contexts) == len(negatives)
self.centers = centers
self.contexts = contexts
self.negatives = negatives
def __getitem__(self, index):
return (self.centers[index], self.contexts[index], self.negatives[index])
def __len__(self):
return len(self.centers)
# 定义批次大小 并根据操作系统设置线程数
batch_size = 256
num_workers = 0 if sys.platform.startswith('win32') else -1
# 创建数据集
dataset = MyDataset(all_centers, all_contexts, all_negatives)
# 创建数据加载器
data_iter = Data.DataLoader(dataset, batch_size, shuffle=True,
collate_fn=batchify,
num_workers=num_workers)
# 遍历数据集
for batch in data_iter:
for name, data in zip(['centers', 'contexts_negatives', 'masks', 'labels'], batch):
print(name, 'shape:', data.shape)
break
# 搭建网络模型
#采用交叉熵损失函数
class SigmoidBinaryCrossEntropyLoss(nn.Module):
def __init__(self):
super(SigmoidBinaryCrossEntropyLoss, self).__init__()
def forward(self, inputs, targets, mask=None):
inputs, targets, mask = inputs.float(), targets.float(), mask.float()
res = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction="none", weight=mask)
res = res.sum(dim=1) / mask.float().sum(dim=1)
return res
loss = SigmoidBinaryCrossEntropyLoss()
# 定义sigmd函数
def sigmd(x):
return - math.log(1 / (1 + math.exp(-x)))
# 设置嵌入向量的大小
embed_size = 200
# 创建一个神经网络
"""
嵌入层 将输入的索引映射到嵌入向量 嵌入向量的维度为embed_size
nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size)
nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size)
"""
net = nn.Sequential(nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size),
nn.Embedding(num_embeddings=len(idx_to_token), embedding_dim=embed_size))
"""
center 中心词
contexts_and_negatives 上下文词和负样本
embed_v 将输入映射到嵌入向量的函数
embed_u 将输入映射到嵌入向量的函数
"""
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):
v = embed_v(center)
u = embed_u(contexts_and_negatives)
pred = torch.bmm(v, u.permute(0, 2, 1))
return pred
"""
训练网络模型
"""
def train(net, lr, num_epochs):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("train on", device)
net = net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
for epoch in range(num_epochs):
start, l_sum, n = time.time(), 0.0, 0
for batch in data_iter:
center, context_negative, mask, label = [d.to(device) for d in batch]
pred = skip_gram(center, context_negative, net[0], net[1])
l = loss(pred.view(label.shape), label, mask).mean()
optimizer.zero_grad()
l.backward()
optimizer.step()
l_sum += l.cpu().item()
n += 1
print('epoch %d, loss %.2f, time %.2fs'
% (epoch + 1, l_sum / n, time.time() - start))
train(net, 0.01, 5)
# 定义函数用于获取给定查询令牌相似的令牌
def get_similar_tokens(query_token, k, embed):
W = embed.weight.data
x = W[token_to_idx[query_token]]
cos = torch.matmul(W, x) / (torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9).sqrt()
_, topk = torch.topk(cos, k=k + 1)
topk = topk.cpu().numpy()
for i in topk[1:]:
print('余弦相似度 = %.3f: %s' % (cos[i], (idx_to_token[i])))
# 调用函数获取与'Dursley'最相似的5个令牌
get_similar_tokens('Dursley', 5, net[0])
centers shape: torch.Size([256, 1])
contexts_negatives shape: torch.Size([256, 60])
masks shape: torch.Size([256, 60])
labels shape: torch.Size([256, 60])
train on cuda
epoch 1, loss 3.27, time 236.44s
epoch 2, loss 1.17, time 237.80s
epoch 3, loss 0.70, time 236.57s
epoch 4, loss 0.53, time 241.82s
epoch 5, loss 0.45, time 241.26s
余弦相似度 = 0.290: was
余弦相似度 = 0.285: one.
余弦相似度 = 0.282: tight
余弦相似度 = 0.278: moment,
余弦相似度 = 0.271: loudly,