Classification Networks: Handling Class Imbalance with FocalLoss

The training and test code is shown below (the complete code comes from "CNN从搭建到部署实战"):
train.py

import torch
import torchvision
import time
import argparse
import importlib
from loss import FocalLoss


def parse_args():
    parser = argparse.ArgumentParser('training')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size in training')
    parser.add_argument('--num_epochs', default=5, type=int, help='number of epoch in training')
    parser.add_argument('--model',  default='lenet', help='model name [default: lenet]')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    batch_size = args.batch_size
    num_epochs = args.num_epochs
    model = importlib.import_module('models.'+args.model) 
        
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = model.net.to(device)

    loss = torch.nn.CrossEntropyLoss()
    
    if args.model == 'mlp':
        optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    else:
        optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
          
    train_path = r'./Datasets/mnist_png/training'
    test_path = r'./Datasets/mnist_png/testing'
    transform_list = [torchvision.transforms.Grayscale(num_output_channels=1), torchvision.transforms.ToTensor()]
    if args.model == 'alexnet' or args.model == 'vgg':
        transform_list.append(torchvision.transforms.Resize(size=224))
    if args.model == 'googlenet' or args.model == 'resnet':
        transform_list.append(torchvision.transforms.Resize(size=96))
    transform = torchvision.transforms.Compose(transform_list)

    train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transform)
    test_dataset = torchvision.datasets.ImageFolder(test_path, transform=transform)

    train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    for epoch in range(num_epochs):
        train_l, train_acc, test_acc, m, n, batch_count, start = 0.0, 0.0, 0.0, 0, 0, 0, time.time()
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l += l.cpu().item()
            train_acc += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            m += y.shape[0]
            batch_count += 1
        net.eval()  # switch to evaluation mode once, before the test loop
        with torch.no_grad():
            for X, y in test_iter:
                test_acc += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                n += y.shape[0]
        net.train()  # switch back to training mode for the next epoch
        print('epoch %d, loss %.6f, train acc %.3f, test acc %.3f, time %.1fs'% (epoch, train_l / batch_count, train_acc / m, test_acc / n, time.time() - start))
        torch.save(net, args.model+".pth")

test.py

import cv2
import torch
import argparse
import importlib
from pathlib import Path
import torchvision.transforms.functional


def parse_args():
    parser = argparse.ArgumentParser('testing')
    parser.add_argument('--model',  default='lenet', help='model name [default: lenet]')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    model = importlib.import_module('models.' + args.model) 
        
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
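    # train.py saves the whole model object with torch.save(net, ...), so torch.load returns the network itself.
    # map_location keeps loading working when the checkpoint was saved on a different device.
    # Note: on PyTorch >= 2.6, loading a full pickled model also needs weights_only=False.
    net = torch.load(args.model + '.pth', map_location=device).to(device)
    net.eval()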

    with torch.no_grad():
        imgs_path = Path(r"./Datasets/mnist_png/testing/6/").glob("*")
        acc = 0
        count = 0
        for img_path in imgs_path:
            img = cv2.imread(str(img_path), 0)
            if args.model == 'alexnet' or args.model == 'vgg':  
                img = cv2.resize(img, (224,224))
            if args.model == 'googlenet' or args.model == 'resnet':
                img = cv2.resize(img, (96,96))
            img_tensor = torchvision.transforms.functional.to_tensor(img)
            img_tensor = torch.unsqueeze(img_tensor, 0)
            #print(net(img_tensor.to(device)).argmax(dim=1).item())
            if(net(img_tensor.to(device)).argmax(dim=1).item()==6):
                acc += 1
            count+=1
    print(acc/count)

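The dataset is MNIST handwritten digit recognition. The training set contains the following number of images for digits 0 to 9: 0 (5923), 1 (6742), 2 (5958), 3 (6131), 4 (5842), 5 (5421), 6 (5918), 7 (6265), 8 (5851), 9 (5949). The test set contains: 0 (980), 1 (1135), 2 (1032), 3 (1010), 4 (982), 5 (892), 6 (958), 7 (1028), 8 (974), 9 (1009). The classes are therefore roughly balanced. The test script measures accuracy on digit 6 only, because the number of 6s in the training set is changed later for comparison. To save time, training runs for only 5 epochs.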
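
The per-class counts above can be checked with a few lines of Python. This is a minimal sketch: `count_per_class` is a hypothetical helper that is not part of the original code, and the toy labels only stand in for a real label list.

```python
from collections import Counter
from typing import Dict, Iterable, List


def count_per_class(targets: Iterable[int], class_names: List[str]) -> Dict[str, int]:
    """Return {class name: number of samples} for a sequence of integer labels."""
    counts = Counter(targets)
    return {name: counts.get(idx, 0) for idx, name in enumerate(class_names)}


# Toy labels standing in for a real label list (hypothetical data).
toy_targets = [0, 1, 1, 2, 2, 2]
toy_classes = ["0", "1", "2"]
print(count_per_class(toy_targets, toy_classes))  # {'0': 1, '1': 2, '2': 3}
```

With the `ImageFolder` datasets from train.py, the call would presumably be `count_per_class(train_dataset.targets, train_dataset.classes)`.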
Training results:

epoch 0, loss 1.443379, train acc 0.529, test acc 0.877, time 23.4s
epoch 1, loss 0.314123, train acc 0.913, test acc 0.939, time 22.1s
epoch 2, loss 0.174050, train acc 0.949, test acc 0.960, time 21.9s
epoch 3, loss 0.122714, train acc 0.963, test acc 0.971, time 21.8s
epoch 4, loss 0.096798, train acc 0.971, test acc 0.975, time 21.8s

Test results:

0.9780793319415448

Now reduce the number of 6s in the training set to 59 images (1/100 of the original) to simulate a dataset in which one class is underrepresented.
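
The article does not say how the 6s were removed (for example, by deleting files on disk). One possible way to do it in code, without touching the files, is to pick the indices to keep and wrap the dataset in a subset. The helper name `subsample_class_indices` and the fixed seed are assumptions made for this sketch.

```python
import random
from typing import List, Sequence


def subsample_class_indices(targets: Sequence[int], class_idx: int, keep: int,
                            seed: int = 0) -> List[int]:
    """Return dataset indices that keep only `keep` random samples of `class_idx`.

    Samples of every other class are all kept, in their original order.
    """
    rng = random.Random(seed)
    class_positions = [i for i, t in enumerate(targets) if t == class_idx]
    kept = set(rng.sample(class_positions, min(keep, len(class_positions))))
    return [i for i, t in enumerate(targets) if t != class_idx or i in kept]


# Toy labels: 10 samples each of classes 0, 1 and 2 (hypothetical data).
toy_targets = [i % 3 for i in range(30)]
kept_indices = subsample_class_indices(toy_targets, class_idx=2, keep=1)
print(len(kept_indices))  # 21: 20 samples of classes 0 and 1, plus a single sample of class 2
```

For the experiment here, the indices from `subsample_class_indices(train_dataset.targets, 6, 59)` could be passed to `torch.utils.data.Subset(train_dataset, indices)` before building the `DataLoader`.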
Training results:

epoch 0, loss 2.200247, train acc 0.131, test acc 0.373, time 20.8s
epoch 1, loss 0.579792, train acc 0.840, test acc 0.855, time 20.5s
epoch 2, loss 0.177890, train acc 0.950, test acc 0.872, time 20.3s
epoch 3, loss 0.128251, train acc 0.963, test acc 0.880, time 20.5s
epoch 4, loss 0.103937, train acc 0.969, test acc 0.888, time 20.7s

Test results:

0.04801670146137787

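As the logs show, the overall test accuracy reported during training drops by about 9 percentage points (0.975 to 0.888), while the accuracy on digit 6 alone collapses by about 93 points (from 97.8% to 4.8%), which is dismal. The two numbers are consistent: digit 6 makes up 958 of the 10,000 test images, so losing about 93% of them costs roughly 0.0958 × 0.93 ≈ 0.09 of overall accuracy. The training accuracy barely moves (0.971 to 0.969) because almost no 6s are left in the training set.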
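
Because overall accuracy hides a failure on a single class, it can help to report accuracy per class instead of hard-coding digit 6 as test.py does. The following is a small pure-Python sketch; `per_class_accuracy` is a hypothetical helper and the predictions are made-up toy values.

```python
from typing import Dict, Sequence


def per_class_accuracy(preds: Sequence[int], labels: Sequence[int],
                       num_classes: int) -> Dict[int, float]:
    """Return {class index: accuracy} for every class that appears in `labels`."""
    correct = [0] * num_classes
    total = [0] * num_classes
    for p, y in zip(preds, labels):
        total[y] += 1
        correct[y] += int(p == y)
    return {c: correct[c] / total[c] for c in range(num_classes) if total[c] > 0}


# Toy example: class 2 is rare and always misclassified.
labels = [0, 0, 0, 1, 1, 1, 2, 2]
preds = [0, 0, 0, 1, 1, 1, 0, 1]
overall = sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)
print(overall)                               # 0.75
print(per_class_accuracy(preds, labels, 3))  # {0: 1.0, 1: 1.0, 2: 0.0}
```

In the toy data the overall accuracy still looks acceptable at 0.75 even though class 2 is never predicted correctly, which is the same pattern as in the logs above.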

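Introduce the FocalLoss module, saved as loss.py so that train.py can import it with "from loss import FocalLoss" (based on https://github.com/QunBB/DeepLearning/blob/main/trick/unbalance/loss_pt.py):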

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Union


class FocalLoss(nn.Module):
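    def __init__(self, alpha: Union[List[float], float], gamma: Optional[int] = 2, with_logits: Optional[bool] = True):
        """
        :param alpha: weight of each class (a list with one entry per class, or a single number)
        :param gamma: focusing parameter; larger values down-weight well-classified samples more strongly
        :param with_logits: True if the input is raw logits, in which case softmax or sigmoid is applied inside the loss
        """
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        # Accept ints as well as floats: torch.FloatTensor(3) would create an uninitialized tensor of size 3.
        self.alpha = torch.FloatTensor([alpha]) if isinstance(alpha, (int, float)) else torch.FloatTensor(alpha)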
        self.smooth = 1e-8
        self.with_logits = with_logits

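    def _binary_class(self, input, target):
        prob = torch.sigmoid(input) if self.with_logits else input
        target = target.view_as(prob).float()
        # p_t: predicted probability of the true class (prob for positives, 1 - prob for negatives)
        pt = prob * target + (1.0 - prob) * (1.0 - target)
        alpha = self.alpha.to(target.device)
        loss = -alpha * torch.pow(1.0 - pt, self.gamma) * torch.log(pt + self.smooth)
        return loss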

    def _multiple_class(self, input, target):
        prob = F.softmax(input, dim=1) if self.with_logits else input

        alpha = self.alpha.to(target.device)
        alpha = alpha.gather(0, target)

        target = target.view(-1, 1)

        prob = prob.gather(1, target).view(-1) + self.smooth  # avoid nan
        logpt = torch.log(prob)

        loss = -alpha * torch.pow(torch.sub(1.0, prob), self.gamma) * logpt
        return loss

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        :param input: model output with shape [bs, num_classes]
        :param target: integer class labels with shape [bs]
        :return: the focal loss averaged over the batch
        """
        if len(input.shape) > 1 and input.shape[-1] != 1:
            loss = self._multiple_class(input, target)
        else:
            loss = self._binary_class(input, target)

        return loss.mean()
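
For reference, the `_multiple_class` method above follows the standard focal loss formula, written out here together with a small worked example. In the formula, p_t is the predicted probability of the true class (`prob` after the `gather`), α_t is the weight of the true class (`alpha.gather(0, target)`), and γ is `self.gamma`. The worked values assume γ = 2 (the default of the class) and α_t = 1.

```latex
\begin{aligned}
\mathrm{FL}(p_t) &= -\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t) \\
p_t = 0.9:\quad \mathrm{FL} &= 1 \cdot (0.1)^2 \cdot (-\ln 0.9) \approx 0.01 \times 0.1054 \approx 0.00105 \\
p_t = 0.1:\quad \mathrm{FL} &= 1 \cdot (0.9)^2 \cdot (-\ln 0.1) \approx 0.81 \times 2.3026 \approx 1.865
\end{aligned}
```

With plain cross-entropy (γ = 0) the same two samples would contribute about 0.1054 and 2.3026, a ratio of roughly 22. With γ = 2 the ratio grows to roughly 1770, so hard samples dominate the loss far more strongly.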

Then change line 26 of train.py (loss = torch.nn.CrossEntropyLoss()) to:

    loss = FocalLoss([1, 1, 1, 1, 1, 1, 100, 1, 1, 1])

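The numbers in the list are the weights of the 10 classes; digit 6 gets a weight of 100 and every other digit a weight of 1.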
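
The weight of 100 is hand-picked, but it is close to what a simple inverse-frequency heuristic would suggest. The sketch below is an illustration only: `inverse_frequency_alpha` is a hypothetical helper, and the counts are rounded to about 6000 images per class rather than taken exactly from the dataset.

```python
from typing import List


def inverse_frequency_alpha(class_counts: List[int]) -> List[float]:
    """Weight each class by (largest class count / its own count)."""
    largest = max(class_counts)
    return [largest / c for c in class_counts]


# Rounded, illustrative counts: about 6000 images per digit, 59 for digit 6.
counts = [6000] * 6 + [59] + [6000] * 3
alpha = inverse_frequency_alpha(counts)
print([round(a, 1) for a in alpha])
```

The resulting list gives digit 6 a weight slightly above 100 and every other digit a weight of 1.0. It could be passed directly as `FocalLoss(alpha)`.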
Training results:

epoch 0, loss 2.045273, train acc 0.137, test acc 0.467, time 20.7s
epoch 1, loss 0.510476, train acc 0.810, test acc 0.907, time 21.3s
epoch 2, loss 0.148246, train acc 0.922, test acc 0.941, time 21.1s
epoch 3, loss 0.099026, train acc 0.944, test acc 0.953, time 21.2s
epoch 4, loss 0.075481, train acc 0.954, test acc 0.959, time 21.3s

Test results:

0.9196242171189979

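The comparison shows that FocalLoss effectively mitigates the class imbalance problem: accuracy on digit 6 recovers from 4.8% to 92.0%, against 97.8% with the balanced training set. It does not eliminate the gap, of course; a sufficiently balanced, high-quality dataset is still the better option.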
