torch.full()
是 PyTorch 中用于创建指定形状、所有元素值都相同的新张量的核心函数。它在深度学习中有广泛应用,尤其是在初始化张量和创建特殊数据结构时。
torch.full(size, fill_value, *,
dtype=None,
layout=torch.strided,
device=None,
requires_grad=False,
pin_memory=False,
memory_format=torch.contiguous_format)
参数 | 类型 | 描述 |
---|---|---|
size | tuple of ints | 定义张量形状的整数元组(如 (3, 4) 表示 3行4列) |
fill_value | scalar | 填充张量的值(整型或浮点型) |
dtype | torch.dtype (可选) | 张量的数据类型(默认根据 fill_value 类型推断) |
device | torch.device (可选) | 张量所在设备(CPU/GPU)(默认使用当前设备) |
requires_grad | bool (可选) | 是否需要计算梯度(默认 False) |
layout | torch.layout (可选) | 张量布局(默认 strided) |
pin_memory | bool (可选) | 是否使用锁页内存(默认 False) |
memory_format | torch.memory_format (可选) | 内存格式(默认 contiguous_format) |
创建满足以下条件的张量:
使用示例
#基础用法
import torch
#创建 2x3 的张量,所有元素值为 5
= torch.full((2, 3), 5)
print(a)
tensor([[5, 5, 5],
[5, 5, 5]])
创建 3x3 的浮点数张量,所有元素值为 3.14
= torch.full((3, 3), 3.14)
print(b)
tensor([[3.1400, 3.1400, 3.1400],
[3.1400, 3.1400, 3.1400],
[3.1400, 3.1400, 3.1400]])
高级用法
指定数据类型
= torch.full((2, 2), 1.5, dtype=torch.float16)
print(c)
tensor([[1.5000, 1.5000],
[1.5000, 1.5000]], dtype=torch.float16)
创建在GPU上的张量
= torch.full((3,), 10, device='cuda')
print(d)
tensor([10, 10, 10], device='cuda:0')
创建需要梯度的张量
= torch.full((2, 3), 0.1, requires_grad=True)
print(e.requires_grad) # True
创建4维张量(如批量大小×通道×高度×宽度)
= torch.full((2, 3, 4, 4), 0) # 创建全零掩码
print(f.shape) # torch.Size([2, 3, 4, 4])
与相似函数的对比
函数 | 描述 | 主要区别 |
---|---|---|
torch.full() | 直接指定形状填充 | 基本版本,灵活 |
torch.full_like() | 参考其他张量形状填充 | 复制其他张量的形状和属性 |
torch.ones() | 创建全1张量 | 固定值=1 |
torch.zeros() | 创建全0张量 | 固定值=0 |
torch.empty() | 创建未初始化张量 | 元素值随机 |
torch.tensor() | 从数据创建张量 | 可包含不同值 |
#初始化偏置项为0.1
bias = torch.full((128,), 0.1)
#初始化掩码为1
mask = torch.full((64, 64), 1)
#创建注意力掩码(1表示有效位置)
attn_mask = torch.full((batch_size, seq_len), 1.0)
#创建特殊值矩阵(如填充无效值)
invalid_mask = torch.full(data.shape, float('-inf'))
#直接在GPU上创建用于计算的张量
gpu_tensor = torch.full((100, 100), 0.5, device='cuda')
#创建需要高精度的常数张量
high_precision = torch.full((10,), 0.123456789, dtype=torch.float64)
当未指定 dtype 时:
# 整数4会被转换为浮点数4.0
tensor = torch.full((3,), 4, dtype=torch.float32)
print(tensor) # tensor([4., 4., 4.])
当需要复制同样值时,使用 torch.full() 比使用 Python 列表更高效:
# 不推荐:低效
bad_tensor = torch.tensor([[5]100]100)
# 推荐:高效
good_tensor = torch.full((100, 100), 5)
在模型中作为参数使用时需要设置 requires_grad=True:
trainable_scalar = torch.full((1,), 0.5, requires_grad=True)
PyTorch 的 torch.full() 与 NumPy 的 np.full() 功能相似,但针对深度学习进行了优化:
import numpy as np
import torch
#NumPy 版本
np_array = np.full((3, 3), 5)
print(type(np_array)) #
#PyTorch 版本
torch_tensor = torch.full((3, 3), 5)
print(type(torch_tensor)) #
主要区别:
torch.full()
是一个功能强大且高效的工具,用于创建全相同值的张量:
它在神经网络开发中常用于初始化张量、创建掩码、设置特殊值和建立模型参数,是PyTorch张量操作工具箱中不可缺少的一部分。