pytorch中model模型显示及维度信息(直接打印、summary、make_dot)

在学习Pytorch中,为了更好理解网络结构,需要结合mdoel的图片结构和维度信息才能更好理解。keras中model.summary和plot_model工具就十分好用。在pytorch中,经过多方搜索,下列三种方式有助于自己理解,在此mark一下。其中summary要能知道模型的输入shape,可根据源代码和报错中提示进行尝试。

import torch
from torchviz import make_dot
from torch.autograd import Variable
from torchsummary import summary

model.netG  # 直接打印
summary(model.netG, (3,256,256))  # 每层输出shape
xtmp = Variable(torch.randn(1,3,256,256))
ytmp = model.netG(xtmp)
make_dot(ytmp, params=dict(model.netG.named_parameters())).render('tmp', view=True)  # render用于保存为图片

model.netD
summary(model.netD, (6, 256, 256))
xtmp2 = Variable(torch.randn(1,6,256,256))
ytmp2 = model.netD(xtmp2)
make_dot(ytmp2, params=dict(model.netD.named_parameters())).render('tmp2', view=True)

类似的make_dot,似乎更简洁些, From言有三

import torch
from torch.autograd import Variable
from visualize import  make_dot
x = Variable(torch.randn(1,3,48,48))
model = simpleconv3()
y = model(x)
g = make_dot(y)
g.view()

20200421更新:
HiddenLayer,torchwatch

你可能感兴趣的:(PyTorch)