CVPR 2023 Hybrid Tutorial: All Things ViTs之DINO attention map

All Things ViTs系列讲座从ViT视觉模型注意力机制出发,本文给出DINO attention map可视化部分阅读学习体会.

课程视频与课件: https://all-things-vits.github.io/atv/
代码:https://colab.research.google.com/github/all-things-vits/code-samples/blob/main/probing/dino_attention_maps.ipynb
文献: Emerging Properties in Self-Supervised Vision Transformers

1. 总述

这部分代码实际与CVPR 2023 Hybrid Tutorial: All Things ViTs之mean attention distance (MAD)比较类似.Emerging Properties in Self-Supervised Vision Transformers这篇论文主要想论述采用DINO(self-distillation with no labels)的训练方式可以达到类似监督学习的训练效果,因此通过可视化注意力机制图的方式来说明DINO学习方法是学到了需要的知识.这部分代码的核心是可视化最后一个block中CLS token与其他图像patch token之间的注意力图.即展示每个patch与CLS之间的关联性,如Fig 1:
CVPR 2023 Hybrid Tutorial: All Things ViTs之DINO attention map_第1张图片
Fig 1 DINO attention map

2.关键代码

def get_attention_scores(image: Image, model: torch.nn.Module, processor):
    """Extracts attention scores given an image, a model,
    and its processor."""
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    return outputs.attentions # [[1, 12, 197, 197]*12]

这部分代码将图像输入ViT网络,并得到输出的每个head的注意力分数(outputs.attentions).outputs.attentions是一个tuple,其中包括12个维度为[1, 12, 197, 197]的tensor.这个tensor可理解如下,其中12为head的数量,197是token的数量.197*197表示每个token之间的注意力分数.

def process_attention_map(
    image: torch.Tensor, attention_scores, block_id=11, patch_size=16
):
    """
    Processes the attention scores such that they can be overlaid on the input image.

    Args:
        image (torch.Tensor): The input image tensor.
        attention_scores (Tuple[torch.Tensor]): Tuple of attention scores.
        block_id (int, optional): The block ID. Default is 11. 11 is the last
            transformer block for a DINO base model.
        patch_size (int, optional): The size of the patches. Default is 16.

    Returns:
        numpy.ndarray: The processed attention map as a NumPy array.
    """
    height, width = image.shape[2:]
    w_featmap = width // patch_size
    h_featmap = height // patch_size

    num_heads = attention_scores[block_id].shape[1]  # Number of attention heads. 12

    # Taking the representations from CLS token.
    attentions = attention_scores[block_id][0, :, 0, 1:].reshape(num_heads, -1)# only visual cls tokens of each head
    print(attentions.shape) # [12, 196], 196 is the number of patches(tokens), 12 is the number of heads, attention of cls token to each patch in one head

    # Reshape the attention scores to resemble mini patches.
    attentions = attentions.reshape(num_heads, w_featmap, h_featmap) # [12, 14, 14]
    print(attentions.shape)

    # Resize the attention patches to 224x224 (224: 14x16)
    attentions = F.resize(
        attentions, size=(h_featmap * patch_size, w_featmap * patch_size)
    )
    return attentions.numpy()

这部分代码则是抽取CLS token与其他patch token之间的attention矩阵,并上采样到与原始图像一样大并绘制.

3. One more thing

这部分代码可视化比较好理解,接下来可以更全面看看DINO可视化效果:
CVPR 2023 Hybrid Tutorial: All Things ViTs之DINO attention map_第2张图片
可以看到,经过DINO训练,ViT能够有效关注图像中有意义的部分.并且不同的head在关注不同的区域.后续经过简单微调,即可泛化至下游任务.

你可能感兴趣的:(深度学习,人工智能,python,计算机视觉)