torch.topk()
是 PyTorch 中用于高效提取张量极值的关键函数,可返回指定维度上最大或最小的 k
个元素及其索引。以下从参数、作用和应用场景三方面详解:
函数签名:
torch.topk(input, k, dim=None, largest=True, sorted=True, *, out=None)
参数 | 类型 | 默认值 | 说明 |
---|---|---|---|
input |
Tensor | 必填 | 输入张量 |
k |
int | 必填 | 需返回的元素数量 |
dim |
int/None | None |
操作的维度。若为 None ,函数将展平张量后操作;未指定时默认取最后一维(dim=-1 ) |
largest |
bool | True |
True 返回最大值,False 返回最小值 |
sorted |
bool | True |
True 时结果按值降序(largest=True )或升序(largest=False )排列;False 时不保证顺序 |
out |
Tuple | 可选 | 预分配的输出元组,用于减少内存占用 |
作用:
在张量 input
的维度 dim
上,选取 最大或最小的 k
个元素,返回这些值及其在原张量中的索引。适用场景包括模型预测结果分析、特征选择、数据采样等 。
返回值:
values
:包含极值的张量,形状与 input
相同,但 dim
维度大小变为 k
。indices
:索引张量,与 values
同形,标记极值在 input
中的位置 。示例(二维张量):
import torch
x = torch.tensor([[0.1, 0.8, 0.6],
[0.9, 0.2, 0.3]])
# 每行取最大2个值(dim=1)
values, indices = torch.topk(x, k=2, dim=1)
# values: tensor([[0.8, 0.6], [0.9, 0.3]])
# indices: tensor([[1, 2], [0, 2]])
分类任务中提取预测类别
k
个预测类别及概率: probs = torch.softmax(model_output, dim=1) # 概率张量 [batch, classes]
topk_probs, topk_indices = torch.topk(probs, k=3, dim=1) # 每样本取Top-3类别[3,7](@ref)
推荐系统生成Top-K推荐
k
个物品: user_item_scores = ... # 张量 [users, items]
top_scores, top_items = torch.topk(user_item_scores, k=5, dim=1) # 每用户推荐5个物品[5](@ref)
特征选择与数据压缩
k
个通道,抑制噪声: feature_map = ... # 张量 [batch, channels, H, W]
topk_vals, _ = torch.topk(feature_map, k=10, dim=1) # 每位置保留前10通道[4](@ref)
采样优化(如对抗训练)
k
个样本加强训练: losses = ... # 张量 [batch]
hard_samples = torch.topk(losses, k=batch//2, largest=False) # 取损失最小的样本(即易分样本)[2,6](@ref)
高效近似计算
k
个最相关的键-值对,降低计算复杂度(如Top-K Attention)。torch.topk
通过极值索引提取,兼顾了高效性与灵活性,成为PyTorch中处理排序、推荐、特征选择等任务的核心工具。其参数设计(如 dim
指定维度、largest
切换极值方向)覆盖了多维张量操作需求,而返回值直接关联原始数据位置,便于后续分析或反向传播。
k
的有效性:若 k
超过维度长度(如 dim=1
时 k > 列数
),则返回该维度所有元素。
内存效率:对于大型张量,指定 dim
可显著减少内存占用。
排序开销:当 sorted=False
时可提升速度(但结果无序)。
通过灵活使用 torch.topk
,可高效解决涉及极值选择的各类深度学习任务。