详解PyG中的ToSLIC变换

详解PyG中的ToSLIC变换

PyG是一个基于PyTorch的图神经网络库,提供了丰富的数据处理、图转换和图模型的功能。本文将介绍PyG中的一个图转换函数ToSLIC,它可以将一张图片转换为一个超像素图,并生成相应的数据对象。

前言

PyG是一个开源的Python库,用于深度学习任务中的图神经网络(GNN)建模和训练。该库包括多个GNN模型和与图相关的数据结构和算法。在本篇文章中,我将介绍PyG中的ToSLIC模块,它是一个用于图像分割的超像素分割算法。

什么是超像素图?

图像分割是计算机视觉领域的一个重要问题,它的目标是将图像分成若干个相似的区域,每个区域具有一定的语义信息。图像分割在许多领域都有应用,如医学影像、自动驾驶、图像检索等。

超像素分割是一种常见的图像分割方法,它将图像中的像素划分为若干个相似的区域,这些区域被称为超像素。与像素相比,超像素更具有代表性和可解释性,并且可以降低图像分割的复杂度。

超像素图是一种对图片进行分割的方法,将图片中相似的像素聚合成一个个小区域,称为超像素。每个超像素可以看作是图片中的一个节点,它具有一定的特征(如颜色、位置等),并与其他超像素有一定的关系(如邻接、距离等)。这样,我们就可以把一张图片看作是一个图结构,从而利用图神经网络来进行分析和处理。

[外链图片转存中…(img-FcTPnHjV-1700015661802)]

如何使用PyG中的ToSLIC函数?

PyG中的ToSLIC模块实现了一种基于SLIC(Simple Linear Iterative Clustering)算法的超像素分割方法。该算法将图像划分为若干个相邻的块,每个块具有相同的颜色或者灰度级别。ToSLIC模块使用PyTorch实现,并可以直接集成到PyG的GNN模型中,用于图像分割任务。

ToSLIC模块主要由以下几个步骤组成:

  1. 图像预处理:将原始图像转换为LAB色彩空间,并对其进行归一化处理。

  2. 超像素初始化:在图像上随机选取若干个像素作为超像素中心,根据这些中心像素计算每个像素与哪个超像素最近,将其归入该超像素。

  3. 超像素迭代:重复进行以下两个步骤,直到收敛:

    a. 计算每个超像素的中心点,并更新其位置;

    b. 对每个像素,计算其与每个超像素中心点之间的距离,并将其归入最近的超像素中。

  4. 超像素合并:根据超像素之间的距离,将相邻的超像素合并成一个更大的超像素。

进一步说,ToSLIC函数是PyG中提供的一个图转换函数,它使用了skimage库中的slic算法来实现图片到超像素图的转换。ToSLIC函数接受一个torch.Tensor类型的图片作为输入,并返回一个torch_geometric.data.Data类型的数据对象作为输出。输出对象包含以下属性:

  • x: 一个二维张量,表示每个超像素节点的特征向量。默认情况下,特征向量是每个超像素节点在RGB空间下的平均颜色值。
  • pos: 一个二维张量,表示每个超像素节点在原始图片上的位置坐标。
  • seg: 一个二维张量(可选),表示原始图片上每个像素所属于哪个超像素节点。
  • img: 一个四维张量(可选),表示原始图片。

ToSLIC函数还可以接受一些额外参数来调整slic算法和输出对象:

  • add_seg: 一个布尔值(默认为False),表示是否在输出对象中添加seg属性。
  • add_img: 一个布尔值(默认为False),表示是否在输出对象中添加img属性。
  • **kwargs: 其他参数,用于调整slic算法。具体参见skimage.segmentation.slic文档。

下面给出一个简单的例子:

from torchvision.datasets import MNIST
import torchvision.transforms as T
from torch_geometric.transforms import ToSLIC

transform = T.Compose([
T.ToTensor(),
ToSLIC(n_segments=75, add_seg=True)
])
dataset = MNIST('/tmp/MNIST', download=True, transform=transform)
data = dataset[0] # data is a Data object with x, pos and seg attributes

这段代码首先从torchvision库中加载了MNIST数据集,并定义了一个组合变换transform。transform包含两个步骤:第一步是将PIL.Image类型的图片转换为torch.Tensor类型;第二步是将torch.Tensor类型的图片转换为Data类型,并指定要生成75个超像素节点,并在输出对象中添加seg属性。然后我们从数据集中取出第一张图片,并应用transform变换得到data对象。

应用

在训练GNN模型时,可以使用这个数据集作为输入数据,并将其转换为图形数据格式。例如,可以将图像中的超像素视为节点,超像素之间的邻接关系视为边。在PyG中,可以使用Data类来表示图形数据,其中包括节点特征、边索引和边特征等信息。

下面是一个使用ToSLIC模块进行图像分割的例子:

import torch
import torch.nn.functional as F
from torch_geometric.datasets import MNISTSuperpixels
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
from torch_geometric.data import DataLoader

class GCN(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCN, self).__init__(aggr='add')
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        x = self.lin(x)
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        return norm.view(-1, 1) * x_j

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCN(3, 16)
        self.conv2 = GCN(16, 32)
        self.conv3 = GCN(32, 64)
        self.lin = torch.nn.Linear(64, 10)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = global_mean_pool(x, batch)
        x = self.lin(x)
        return F.log_softmax(x, dim=1)

dataset = MNISTSuperpixels(root='~/datasets/MNIST', train=True, transform=ToSLIC())
loader = DataLoader(dataset, batch_size=64, shuffle=True)

model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(100):
    model.train()
    for data in loader:
        optimizer.zero_grad()
        out = model(data.x.float(), data.edge_index)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()

    model.eval()
    correct = 0
    for data in loader:
        out = model(data.x.float(), data.edge_index)
        pred = out.argmax(dim=1)
        correct += pred.eq(data.y).sum().item()
    acc = correct / len(dataset)
    print(f'Epoch: {epoch}, Accuracy: {acc:.4f}')

在这个例子中,我们使用了一个包含三个GCN层的模型来对超像素图进行分类,其中输入节点特征的维度为3,代表超像素的LAB颜色空间的三个通道。我们使用全局池化层对每个超像素的特征进行聚合,最终输出一个10维向量,代表对应MNIST数字图像的类别概率分布。训练过程中,我们使用ToSLIC将原始图像转换为超像素图,并使用DataLoader将其加载到模型中进行训练。

GCN类中,我们重写了message()forward()方法。在message()中,我们将每个节点的特征按照邻接矩阵中的边权重进行加权平均,以获取该节点的邻居节点特征的信息。在forward()中,我们首先对邻接矩阵添加自环,然后对每个节点的特征进行线性变换。接下来,我们计算每个边的权重,以便在message()中进行加权平均。最后,我们调用propagate()方法来执行信息传递操作。

Net类中,我们定义了一个包含三个GCN层和一个全连接层的模型,并在每个GCN层后应用了一个ReLU激活函数。在模型的最后一层,我们使用global_mean_pool()对所有超像素的特征进行全局平均池化,以获取整个图像的特征表示。

训练过程中,我们使用Adam优化器来最小化交叉熵损失,并在每个epoch结束时计算模型的准确率。由于我们使用了ToSLIC模块对原始图像进行了超像素分割,因此我们可以将每个超像素视为一个节点,并使用图卷积神经网络来对其进行分类。

超像素的特征进行全局平均池化,以获取整个图像的特征表示。

训练过程中,我们使用Adam优化器来最小化交叉熵损失,并在每个epoch结束时计算模型的准确率。由于我们使用了ToSLIC模块对原始图像进行了超像素分割,因此我们可以将每个超像素视为一个节点,并使用图卷积神经网络来对其进行分类。

在实际应用中,ToSLIC模块可以与其他PyG中的模块结合使用,例如SAGEConv、GATConv等,以实现更复杂的图卷积神经网络。此外,我们还可以使用ToSLIC模块将图像分割应用于其他任务,例如目标检测、图像生成等。

你可能感兴趣的:(深度学习)