Introduction to the torch.nn.functional Module in PyTorch

1. Introduction to the torch.nn.functional module


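1.1 Module overview

torch.nn.functional is one of PyTorch's core modules. It exposes neural network operations as plain functions and covers the following:

  • Activation functions: ReLU, Sigmoid, Tanh, and others
  • Convolutions: 1D/2D/3D convolution and transposed convolution
  • Pooling: max pooling, average pooling, adaptive pooling
  • Regularization and normalization: Dropout, BatchNorm
  • Loss functions: cross entropy, mean squared error
  • Tensor operations: padding (pad), cropping, interpolation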
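As a rough illustration of the categories above, the sketch below calls one function from several of them on random tensors. The tensor shapes and variable names are arbitrary choices made for this example and do not come from the original article.

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)              # (batch, channels, height, width)

a = F.relu(x)                            # activation: negatives become 0
p = F.max_pool2d(x, kernel_size=2)       # pooling: halves height and width
q = F.adaptive_avg_pool2d(x, 1)          # adaptive pooling down to 1x1
d = F.dropout(x, p=0.5, training=False)  # dropout does nothing when training=False
padded = F.pad(x, (1, 1, 2, 2))          # width padded by 1+1, height by 2+2
up = F.interpolate(x, scale_factor=2, mode="nearest")  # upsample by a factor of 2

logits = torch.randn(4, 10)              # 4 samples, 10 classes
target = torch.tensor([1, 0, 9, 3])      # class indices
loss = F.cross_entropy(logits, target)   # scalar loss

print(p.shape, q.shape, padded.shape, up.shape, loss.shape)
```

None of these calls creates a layer object; every input, including the dropout mode, is passed as an argument.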

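1.2 Core features and usage scenarios

  1. Functional programming style: call the function directly, with no object to instantiate
import torch
import torch.nn.functional as F

input_tensor = torch.randn(2, 3)  # example input
output = F.relu(input_tensor)  # call the activation function directly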
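  2. Comparison with torch.nn (using ReLU as the example):
    torch.nn.ReLU is a class and must be instantiated as a layer before use
    F.relu is a plain function that can be called directly inside forward
import torch.nn as nn
import torch.nn.functional as F

# Class style (the layer is typically held as an attribute of an nn.Module subclass)
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x)

# Functional style
def forward(x):
    return F.relu(x)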
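  3. Parameter passing: weights and hyperparameters are passed explicitly on every call
# Convolution example (input and weight are 4D tensors)
conv_output = F.conv2d(
    input,
    weight,
    bias=None,
    stride=1,
    padding=0  # every argument is passed explicitly
)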
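To make the comparison concrete, the following sketch checks that a module-style layer and the functional call produce the same result when they share parameters. The layer sizes and the seed are assumptions chosen only for this illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 3, 8, 8)

# Module style: the layer creates and stores its own weight and bias.
conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1)
y_module = conv(x)

# Functional style: the same parameters are passed in explicitly.
y_functional = F.conv2d(x, conv.weight, conv.bias, stride=1, padding=1)

print(torch.allclose(y_module, y_functional))
```

The module version is convenient when the parameters should be registered and trained automatically; the functional version is useful when the weights come from somewhere else.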

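1.3 Typical usage scenarios

  1. Custom layers: combine basic operations flexibly
def custom_layer(x, weight):
    x = F.pad(x, (1, 1, 1, 1))  # pad 1 on the left, right, top, and bottom
    x = F.conv2d(x, weight)
    return F.leaky_relu(x, 0.1)  # negative slope of 0.1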
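  2. Dynamic parameters: implement adaptive logic
def adaptive_dropout(x, training=True):
    # calculate_based_on_input is a placeholder for user-defined logic
    # that returns a dropout probability between 0 and 1
    current_p = calculate_based_on_input(x)
    return F.dropout(x, p=current_p, training=training)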
  3. Working together with nn.Module:
class CustomNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Learnable kernel: 3 output channels, 3 input channels, 5x5 window
        self.conv_weight = nn.Parameter(torch.randn(3, 3, 5, 5))

    def forward(self, x):
        x = F.conv2d(x, self.conv_weight)
        return F.sigmoid(x)  # functional activation
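A short usage sketch for the CustomNet pattern: because the kernel is wrapped in nn.Parameter, the module registers it and autograd fills in its gradient even though the convolution itself is a functional call. The input size (2, 3, 16, 16) is an assumption made for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_weight = nn.Parameter(torch.randn(3, 3, 5, 5))

    def forward(self, x):
        x = F.conv2d(x, self.conv_weight)
        return F.sigmoid(x)

net = CustomNet()
x = torch.randn(2, 3, 16, 16)   # 16x16 input with 3 channels
out = net(x)                    # 5x5 kernel, no padding: 16 - 5 + 1 = 12
out.sum().backward()            # populate gradients for the registered parameter

names = [name for name, _ in net.named_parameters()]
print(names, tuple(out.shape), tuple(net.conv_weight.grad.shape))
```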

1.4 Gradient handling

The differentiable functional operations support automatic differentiation. For example, the backward pass of ReLU, y = max(0, x), is:
$$
\frac{\partial\mathcal{L}}{\partial\mathbf{x}} =
\begin{cases}
0 & \text{if } x \le 0 \\
\frac{\partial\mathcal{L}}{\partial\mathbf{y}} & \text{if } x > 0
\end{cases}
$$
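The ReLU rule above can be checked numerically with autograd. This is a minimal sketch; the input values and the upstream gradient (standing in for the gradient of the loss with respect to y) are made up for the illustration.

```python
import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, -0.5, 0.0, 0.5, 2.0], requires_grad=True)
y = F.relu(x)

# A non-uniform upstream gradient dL/dy makes the masking visible.
upstream = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0])
y.backward(upstream)

# Rule from the formula: pass dL/dy through where x > 0, otherwise 0.
expected = torch.where(x.detach() > 0, upstream, torch.zeros_like(upstream))
print(x.grad)
print(torch.equal(x.grad, expected))
```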

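1.5 Best practices

  1. Switching between training and inference: functions such as F.dropout and F.batch_norm do not know the module's mode, so pass it explicitly
x = F.dropout(x, p=0.5, training=self.training)
  2. Memory: wrap computations that do not need gradients in torch.no_grad()

  3. Device compatibility: make sure all tensors are on the same device (CPU/GPU)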
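The three recommendations can be combined in one small sketch. TinyNet is a hypothetical model invented for this example; the point is only to show the training flag, torch.no_grad(), and device placement working together.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyNet(nn.Module):  # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        x = self.fc(x)
        # Dropout is active only while the module is in training mode.
        return F.dropout(x, p=0.5, training=self.training)

# Keep the model and its inputs on the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyNet().to(device)
x = torch.randn(8, 4, device=device)

model.eval()                 # sets self.training = False
with torch.no_grad():        # no autograd graph is built inside this block
    out1 = model(x)
    out2 = model(x)

print(torch.allclose(out1, out2), out1.requires_grad)
```

With model.train() instead of model.eval(), the two outputs would generally differ because a new dropout mask is sampled on each call.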

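2. torch.nn.functional usage example (convolution)

For an introduction to the concept of convolution, see: Introduction to Convolution: From Mathematical Foundations to Deep Learning Applications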
Reference: the torch.nn.functional documentation.
This section uses torch.nn.functional.conv2d as the example; see the torch.nn.functional.conv2d documentation.

2.1 torch.nn.functional.conv2d syntax

torch.nn.functional.conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) → Tensor

Parameters:

  • input – input tensor of shape (minibatch, in_channels, iH, iW).

  • weight – filters (convolution kernels) of shape (out_channels, in_channels / groups, kH, kW).

  • bias – optional bias tensor of shape (out_channels). Default: None

  • stride – the stride of the convolving kernel. Can be a single number or a tuple (sH, sW). Default: 1

  • padding – implicit padding on both sides of the input. Can be a string {'valid', 'same'}, a single number, or a tuple (padH, padW). Default: 0. padding='valid' is the same as no padding. padding='same' pads the input so the output has the same shape as the input; however, this mode does not support any stride values other than 1.

  • dilation – the spacing between kernel elements. Can be a single number or a tuple (dH, dW). Default: 1

  • groups – split input into groups; both in_channels and out_channels should be divisible by the number of groups. Default: 1

Examples:

# With square kernels and equal stride
filters = torch.randn(8, 4, 3, 3)
inputs = torch.randn(1, 4, 5, 5)
F.conv2d(inputs, filters, padding=1)
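How stride, padding, and dilation interact is easiest to see through the output size. The sketch below uses the output-size formula given in the nn.Conv2d documentation, floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1, and compares it with what F.conv2d returns. The helper name conv2d_out_size and the 9x9 input are assumptions made for this example.

```python
import torch
import torch.nn.functional as F

def conv2d_out_size(size, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial dimension (numeric padding only)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

inputs = torch.randn(1, 4, 9, 9)
filters = torch.randn(8, 4, 3, 3)

results = []
for stride, padding, dilation in [(1, 0, 1), (2, 1, 1), (1, 2, 2), (3, 0, 1)]:
    out = F.conv2d(inputs, filters, stride=stride, padding=padding, dilation=dilation)
    expected = conv2d_out_size(9, 3, stride, padding, dilation)
    results.append((out.shape[-1], expected))
    print(f"stride={stride} padding={padding} dilation={dilation}: "
          f"{tuple(out.shape)} (formula: {expected})")

# padding='same' keeps the spatial size (stride must be 1).
same = F.conv2d(inputs, filters, padding="same")

# groups=2: each filter sees in_channels / groups = 2 input channels.
grouped = F.conv2d(inputs, torch.randn(8, 2, 3, 3), groups=2)
print(tuple(same.shape), tuple(grouped.shape))
```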

2.2 torch.nn.functional.conv2d usage example

  1. Set up the input and the kernel
import torch

input = torch.tensor([[[[2, 5, 3, 1, 1],
                      [7, 0, 6, 4, 1],
                      [3, 2, 6, 0, 0],
                      [5, 1, 8, 3, 1],
                      [2, 1, 0, 5, 1]]]])

kernel = torch.tensor([[[[2, 2, 1],
                       [1, 0, 2],
                       [3, 0, 1]]]])

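Make sure the dimensions of the input and the kernel match what torch.nn.functional.conv2d expects: (minibatch, in_channels, iH, iW) for the input and (out_channels, in_channels / groups, kH, kW) for the kernel. You can add dimensions with extra pairs of brackets "[ ]", or change the shape with torch.reshape(). For example: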

input = torch.tensor([[2, 5, 3, 1, 1],
                      [7, 0, 6, 4, 1],
                      [3, 2, 6, 0, 0],
                      [5, 1, 8, 3, 1],
                      [2, 1, 0, 5, 1]])

At this point input has shape (5, 5).

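# Method 1: wrap the data in two extra pairs of brackets
input = torch.tensor([[[[2, 5, 3, 1, 1],
                        [7, 0, 6, 4, 1],
                        [3, 2, 6, 0, 0],
                        [5, 1, 8, 3, 1],
                        [2, 1, 0, 5, 1]]]])

# Method 2: reshape the (5, 5) tensor to (minibatch, in_channels, iH, iW)
input = torch.reshape(input, (1, 1, 5, 5))

Either method gives input the shape (1, 1, 5, 5).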
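As a quick check, the sketch below builds the 4D tensor with both approaches described above and confirms that they agree. The unsqueeze variant is an additional option that the original text does not mention; it is included only for comparison.

```python
import torch

data = [[2, 5, 3, 1, 1],
        [7, 0, 6, 4, 1],
        [3, 2, 6, 0, 0],
        [5, 1, 8, 3, 1],
        [2, 1, 0, 5, 1]]

nested = torch.tensor([[data]])                              # extra brackets
reshaped = torch.reshape(torch.tensor(data), (1, 1, 5, 5))   # torch.reshape
unsqueezed = torch.tensor(data).unsqueeze(0).unsqueeze(0)    # not in the original text

print(nested.shape, reshaped.shape, unsqueezed.shape)
print(torch.equal(nested, reshaped), torch.equal(nested, unsqueezed))
```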

  2. Call torch.nn.functional.conv2d
import torch
import torch.nn.functional as F

input = torch.tensor([[[[2, 5, 3, 1, 1],
                      [7, 0, 6, 4, 1],
                      [3, 2, 6, 0, 0],
                      [5, 1, 8, 3, 1],
                      [2, 1, 0, 5, 1]]]])

kernel = torch.tensor([[[[2, 2, 1],
                       [1, 0, 2],
                       [3, 0, 1]]]])

output1 = F.conv2d(input, kernel, stride=1)
print(output1)

output2 = F.conv2d(input, kernel, stride=2)
print(output2)

output3 = F.conv2d(input, kernel, stride=1, padding=1)
print(output3)                       

Running the code prints the following output tensors:

tensor([[[[51, 31, 35],
          [58, 24, 52],
          [43, 31, 23]]]])
tensor([[[[51, 35],
          [43, 23]]]])
tensor([[[[10, 35, 11, 24, 13],
          [11, 51, 31, 35,  8],
          [19, 58, 24, 52, 19],
          [11, 43, 31, 23, 18],
          [13, 22, 32, 25, 13]]]])

You can verify these values by hand.
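One way to do that verification in code is a naive sliding-window loop. F.conv2d computes a cross-correlation (the kernel is not flipped), so each output value is the sum of an input window multiplied element-wise by the kernel. The helper manual_conv2d is a hypothetical name written for this sketch, and the tensors are created as float32 here so the example does not depend on integer convolution support.

```python
import torch
import torch.nn.functional as F

input = torch.tensor([[[[2, 5, 3, 1, 1],
                        [7, 0, 6, 4, 1],
                        [3, 2, 6, 0, 0],
                        [5, 1, 8, 3, 1],
                        [2, 1, 0, 5, 1]]]], dtype=torch.float32)

kernel = torch.tensor([[[[2, 2, 1],
                         [1, 0, 2],
                         [3, 0, 1]]]], dtype=torch.float32)

def manual_conv2d(image, kern, stride=1):
    """Naive cross-correlation of one 2D image with one 2D kernel, no padding."""
    kh, kw = kern.shape
    out_h = (image.shape[0] - kh) // stride + 1
    out_w = (image.shape[1] - kw) // stride + 1
    out = torch.zeros(out_h, out_w)
    for i in range(out_h):
        for j in range(out_w):
            top, left = i * stride, j * stride
            window = image[top:top + kh, left:left + kw]
            out[i, j] = (window * kern).sum()
    return out

manual1 = manual_conv2d(input[0, 0], kernel[0, 0], stride=1)
manual2 = manual_conv2d(input[0, 0], kernel[0, 0], stride=2)

print(manual1)
print(torch.equal(manual1, F.conv2d(input, kernel, stride=1)[0, 0]))
print(torch.equal(manual2, F.conv2d(input, kernel, stride=2)[0, 0]))
```

For instance, the top-left value is 2*2 + 2*5 + 1*3 + 1*7 + 0*0 + 2*6 + 3*3 + 0*2 + 1*6 = 51, which matches the first entry of the stride=1 result listed above.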
