基于逻辑回归的图像二分类算法实现(Pytorch版)

基于逻辑回归的图像二分类算法实现(Pytorch版)

  • 数据集
  • 模型代码

数据集

链接:Fast Food Classification Dataset

我们只使用Burger和Pizza这两类。

模型代码

import torch
import torch.nn as nn
from torchvision.models.utils import load_state_dict_from_url

model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
}


class LogisticRegression(nn.Module):
    def __init__(self, input_shape, num_classes=1000):
        super(LogisticRegression, self).__init__()

        self.linear = nn.Linear(input_shape[0] * input_shape[1] * 3, num_classes)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.linear(x)

        return x


def logistic_regression(pretrained=False, progress=True, input_shape=[224, 224], num_classes=1000):
    model = LogisticRegression(input_shape=input_shape, num_classes=num_classes)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls['resnet50'], model_dir='./model_data',
                                              progress=progress)
        model.load_state_dict(state_dict)

    return model

你可能感兴趣的:(图像分类,pytorch,逻辑回归,分类)