在自然语言生成任务中,不同的采样技术用于从语言模型的输出中选择下一个生成的单词或词语。这些技术包括贪婪采样、温度采样、Top-k采样和核(Nucleus)采样。它们在选择生成单词的过程中有不同的策略,本文将介绍这四种采样方式的区别。
贪婪采样是一种直接选择最可能的下一个词的策略。具体步骤为:
class GreedySampler(Sampler):
def __call__(self, logits: torch.Tensor):
return logits.argmax(dim=-1)
温度采样通过引入一个温度参数来调整输出概率的分布,以控制生成文本的多样性。温度 T
的作用是平滑或锐化概率分布:
T = 1
时,采样为标准随机采样。T < 1
时,概率分布变得更尖锐,模型更倾向于选择最可能的词。T > 1
时,概率分布变得更加平滑,模型会更多地探索低概率的词。class TemperatureSampler(Sampler):
def __init__(self, temperature: float = 1.0):
self.temperature = temperature
def __call__(self, logits: torch.Tensor):
dist = Categorical(logits=logits / self.temperature)
return dist.sample()
Top-k采样限制了每次生成时候的候选词数量,模型只会从概率前k个最高的词中进行采样,而忽略其他可能性较小的词。
class TopKSampler(Sampler):
def __init__(self, k: int, sampler: Sampler):
self.k = k
self.sampler = sampler
def __call__(self, logits: torch.Tensor):
zeros = logits.new_ones(logits.shape) * float('-inf')
values, indices = torch.topk(logits, self.k, dim=-1)
zeros.scatter_(-1, indices, values)
return self.sampler(zeros)
k
值,如果 k
值太小,生成的文本可能会缺乏多样性;如果 k
值太大,则效果与标准采样相似。核采样是一种自适应的采样方法,它选择的候选词集合 V(p)
是满足累计概率和大于或等于给定阈值 p
的最小词汇子集。与Top-k采样不同,核采样的候选词数量不是固定的,而是基于累计概率动态确定的。
假设同样的语境:“今天的天气很”,但这次我们将会有不同的词汇及其概率分布,我们也会使用不同的阈值 ( p ) 来展示如何动态确定选词数量。
按概率从高到低排序并计算累积概率:
这次,我们将选择不同的阈值 ( p ) 来观察核集合如何变化:
在每种情况下,我们从对应的核集合中随机选取一个词作为下一个词。选择的范围和多样性取决于 ( p ) 值的大小,而词的数量是根据这个阈值动态确定的,不是固定的。
class NucleusSampler(Sampler):
"""
## Nucleus 采样器
Nucleus 采样器根据给定的概率 p 选择词汇的一个子集,并从中进行采样。
"""
def __init__(self, p: float, sampler: Sampler):
"""
### 初始化
:param p: 要选择的令牌概率之和,即 p 值。
:param sampler: 用于从选定令牌中进行采样的采样器。
"""
# 保存 p 值
self.p = p
# 保存采样器
self.sampler = sampler
# 初始化 softmax 层,用于将 logits 转换为概率
self.softmax = nn.Softmax(dim=-1)
def __call__(self, logits: torch.Tensor):
"""
### 从 logits 中进行 Nucleus 采样
:param logits: 输入的 logits 张量,形状为 (batch_size, num_tokens)。
:return: 采样得到的令牌索引,形状为 (batch_size,)。
"""
# 获取概率 P(x_i | x_1:i-1)
probs = self.softmax(logits)
# 按降序对概率进行排序,并获取排序后的索引
sorted_probs, indices = torch.sort(probs, dim=-1, descending=True)
# 按排序顺序获取概率的累积总和
cum_sum_probs = torch.cumsum(sorted_probs, dim=-1)
# 找出累积总和小于 p 的令牌
nucleus = cum_sum_probs < self.p
# 在前面加一个 True,这样我们可以在累积概率小于 p 的最小令牌数量之后添加一个令牌
nucleus = torch.cat([nucleus.new_ones(nucleus.shape[:-1] + (1,)), nucleus[..., :-1]], dim=-1)
# 获取对数概率并掩盖非核部分
sorted_log_probs = torch.log(sorted_probs)
sorted_log_probs[~nucleus] = float('-inf')
# 使用采样器从排序后的对数概率中进行采样
sampled_sorted_indexes = self.sampler(sorted_log_probs)
# 获取实际的索引
res = indices.gather(-1, sampled_sorted_indexes.unsqueeze(-1))
# 返回采样得到的令牌索引
return res.squeeze(-1)
p
的选择需要调节,不同任务可能需要不同的 p
值。采样方法 | 优点 | 缺点 |
---|---|---|
贪婪采样 | 简单、高效,始终选择最有可能的词 | 文本生成可能单一,缺乏多样性 |
温度采样 | 通过调整温度控制多样性,适应性强 | 温度的调节需要谨慎,过高或过低的温度可能产生不理想的结果 |
Top-k采样 | 控制候选词数量,避免选择低概率词 | k 值选择需要调节,k 太小可能导致文本单一 |
核采样 | 动态选择候选词集合,更灵活,生成文本质量较高 | 参数 p 需要调节,计算复杂度较高 |
每种采样方式都有其适用的场景,根据具体的应用和对生成文本的要求,可以选择不同的采样策略。