higher
是一个用于 元学习(Meta-Learning) 和 高阶导数(Higher-order gradients) 的 Python 库,专为 PyTorch 设计。它扩展了 PyTorch 的自动微分机制,使得在训练过程中可以动态地计算参数的梯度更新,并把这些更新过程纳入到更高阶的梯度计算中。
higher
主要用于以下场景:
import higher
higher需要自己手动安装,否则报错。
No module named ‘higher’。
你可以通过 pip 安装:
pip install higher
或者从 GitHub 安装最新版本:
git clone https://github.com/facebookresearch/higher
cd higher
pip install -e .
通常在 PyTorch 中,当你进行 optimizer.step()
时,参数的更新是“in-place”的,不会被自动微分系统记录。而 higher
提供了一种方法,可以在不修改原始模型的情况下,对参数更新进行跟踪,从而实现高阶导数的计算。
higher
进行一次可微分的参数更新import torch
import torch.nn as nn
import torch.optim as optim
import higher
# 定义一个简单的模型
model = nn.Linear(2, 1)
opt = optim.SGD(model.parameters(), lr=0.1)
# 使用 higher 包装 optimizer,允许参数更新被记录
with higher.gradient_tracking():
tr_opt = higher.get_trainable_optim(opt, model.parameters())
# 假设我们有一个简单的损失函数
x = torch.randn(4, 2)
y = torch.randn(4, 1)
# 前向传播
pred = model(x)
loss = ((pred - y) ** 2).mean()
# 计算梯度并更新参数(这些更新将被记录用于高阶导数)
tr_opt.zero_grad()
grads = torch.autograd.grad(loss, list(model.parameters()))
tr_opt.step(grads)
# 现在你可以继续在这个更新后的模型上做新的前向/反向传播
new_pred = tr_opt.current_params() @ x.t()
final_loss = ((new_pred - y.t()) ** 2).mean()
# 反向传播最终损失会回传到初始参数和学习率上
final_loss.backward()
higher
特别适合用于实现像 MAML 这样的算法,因为它可以动态生成“更新后的模型”,并且这些更新是可微的。
for task in tasks:
fmodel = copy.deepcopy(model)
with higher.gradient_tracking():
tr_opt = higher.get_trainable_optim(optimizer, fmodel.parameters())
# Inner loop: 在当前任务上做几步梯度更新
for _ in range(inner_steps):
pred = fmodel(x_support)
loss = loss_fn(pred, y_support)
grads = torch.autograd.grad(loss, list(fmodel.parameters()))
tr_opt.step(grads)
# Outer loop: 测试在 query 数据上的表现,并反向传播
pred_query = fmodel(x_query)
outer_loss = loss_fn(pred_query, y_query)
outer_loss.backward() # 这个 loss 对原始 model.parameters() 的梯度会被正确计算
函数 | 描述 |
---|---|
higher.get_trainable_optim(optimizer, params) |
将普通 optimizer 转换为 trainable optimizer,支持梯度更新追踪 |
tr_opt.step(grads) |
手动执行一步带梯度的参数更新 |
tr_opt.current_params() |
获取当前更新后的参数列表 |
higher.gradient_tracking() |
上下文管理器,启用梯度追踪模式 |