pytorch特征可视化——特征图

一、核心函数

第一行的参数是必须的,第二行的参数是save_image这个函数调用make_grid时用的,make_grid的作用是将多通道的特征拼接成一个网格状排列的大图,make_grid返回的是一个tensor,save_image直接将图片保存在filename路径。

# tensor 的shape要变成[C, 1, H, W]
torchvision.utils.save_image(tensor, filename,
		nrow=8, padding=2, normalize=False, range=None, scale_each=False, pad_value=0)

简单的例子

a = torch.randn((1, 32, 256, 512))
file_path = 'result/a.png'
a = a.permute(1, 0, 2, 3)
torchvision.utils.save_image(a, file_path)

二、优雅的使用方式

上述的简单例子可以适用于任意情况,只要写在模型的forward()函数内即可,如果不想写在forward()内,则可以用另外的方式实现,结合torch.nn.Module.register_forward_hook函数。

1. torch.nn.Module.register_forward_hook

函数声明:register_forward_pre_hook(hook: Callable[..., None])
括号中的参数是一个函数名,姑且称之为hook_func,hook_func需要自己实现。

hook_func可从前向过程中接收到三个参数:hook_func(module, input, output)。其中module指的是模块的名称,调用register_forward_hook后会自己绑定,比如对于卷积层,module是Conv2d,注意module带有具体的参数。input和output就是我们所需的特征图,这二者分别是module的输入和输出,输入可能有多个(比如concate层就有多个输入),输出只有一个,所以input是一个tuple,其中每一个元素都是一个tensor,而输出就是一个tensor。一般而言output可能更经常拿来做分析。我们可以在hook_func中将特征图画出来并保存为图片,所以hook_func是实现可视化的关键。

2. 具体使用方法

# 定义全局字典存储features
features = {}

def get_activation(name):
    """
    :param name: 特征名字
    :return: 
    """
    def hook_func(module, input, output):
        """
        :param module: 调用register_forward_hook后自动绑定
        :param input: 输入
        :param output: 输出
        :return: 
        """
        # 这里可以写自己所需要的处理方式
        features[name] = output.detach()
    return hook_func

# 可视化features
def show_features(features):
    """
    :param features: 特征字典 
    :return: 
    """
    for keys, values in features.items():
        # 将特征的shape变成[C, 1, H, W]
        values = values.permute(1, 0, 2, 3)
        # 保存路径
        image_name = 'result/' + keys + '.png'
        torchvision.utils.save_image(values, image_name)

if __name__ == '__main__':
	# 创建模型
	model = nn.DataParallel(.......)

	# 为模型的某一层或者某一模块注册hook
	model.module.my_block_or_layer.register_forward_hook(get_activation('my_block_or_layer'))
	
	# 调用模型前向传播
	.................

	# 保存特征图
	show_features(features)

特征图可视化结果

pytorch特征可视化——特征图_第1张图片

你可能感兴趣的:(pytorch,深度学习)