PyTorch 实现目标检测

import os
import random
import pandas as pd
import numpy as np
import cv2
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
# Choose the compute device: the CUDA GPU when one is available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class RoadDataset(Dataset):
    """Map-style dataset of images with a class label and a bounding box.

    Each item is ``(image, class_label, bbox)``:

    * ``image`` is a float32 tensor of shape ``(3, H, W)``, RGB channel
      order, scaled to ``[0, 1]``.
    * ``class_label`` is the raw value taken from ``y``.
    * ``bbox`` is an int64 tensor parsed from the whitespace-separated
      integers in ``bb``.

    Args:
        paths: image file paths (pandas Series or any sequence).
        bb: bounding-box strings, one per image, e.g. ``"x1 y1 x2 y2"``
            (coordinate order assumed from the model's bbox head -- confirm).
        y: class labels, one per image.

    Raises:
        ValueError: if ``paths``, ``bb`` and ``y`` differ in length.
    """

    def __init__(self, paths, bb, y):
        # np.asarray accepts pandas Series (as before) as well as plain lists.
        self.paths = np.asarray(paths)
        self.bb = np.asarray(bb)
        self.y = np.asarray(y)
        if not len(self.paths) == len(self.bb) == len(self.y):
            raise ValueError(
                f"paths, bb and y must have the same length, got "
                f"{len(self.paths)}, {len(self.bb)} and {len(self.y)}"
            )

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _parse_bb(text):
        """Parse a whitespace-separated string of integers into an int64 tensor."""
        # split() without an argument tolerates repeated/leading/trailing
        # whitespace; split(' ') would yield '' and make int() fail.
        return torch.tensor([int(v) for v in str(text).split()])

    def __getitem__(self, idx):
        path = self.paths[idx]
        # Check up front: cv2.imread returns None instead of raising.
        if not os.path.exists(path):
            raise FileNotFoundError(f"图片不存在:{path}")
        img = cv2.imread(str(path))
        if img is None:
            raise ValueError(f"读取图片失败(可能损坏或格式错误):{path}")
        # OpenCV loads BGR; swap to RGB, then scale to [0, 1] as float32.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32) / 255
        # HWC -> CHW, the layout torchvision backbones expect.
        image = torch.from_numpy(img).permute(2, 0, 1)
        return image, self.y[idx], self._parse_bb(self.bb[idx])

# CSV index of the dataset; the columns used below are 'new_path' (image
# file), 'new_bb' (bbox string) and 'class' (label -- must be an integer
# class index for cross_entropy, confirm).
DATASET_CSV = r"D:\DataSet\object_detection\dataset.csv"
BATCH_SIZE = 16

df_train = pd.read_csv(DATASET_CSV, encoding="gbk")

X = df_train[['new_path', 'new_bb']]
Y = df_train['class']

# shuffle=True so random_state actually takes effect (it is ignored when
# shuffle=False). Otherwise the test split would simply be the last 20% of
# the CSV rows, which is biased if the file is ordered by class.
X_train, X_test, y_train, y_test = train_test_split(
    X, Y, test_size=0.2, random_state=42, shuffle=True
)
train_ds = RoadDataset(X_train['new_path'], X_train['new_bb'], y_train)
test_ds = RoadDataset(X_test['new_path'], X_test['new_bb'], y_test)
# The model heads use BatchNorm1d, which raises in training mode on a batch
# of a single sample; drop the last batch only when it would be that size.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=len(train_ds) % BATCH_SIZE == 1,
)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)


#定义目标检测网络
class Detect_model(nn.Module):
    """ResNet-34 backbone with a classification head and a bounding-box head.

    ``forward(x)`` returns ``(class_logits, bbox)``:

    * ``class_logits`` has shape ``(batch, num_classes)``.
    * ``bbox`` has shape ``(batch, 4)``: the regressed box, commented by the
      original author as ``(x1, y1, x2, y2)``.

    Args:
        num_classes: size of the classification head (default 4, as before).
        weights: torchvision weights for the ResNet-34 backbone. Defaults to
            the ImageNet-1k pretrained weights (downloaded on first use).
            Pass ``None`` for a randomly initialised backbone.
    """

    def __init__(self, num_classes=4,
                 weights=models.ResNet34_Weights.IMAGENET1K_V1):
        super().__init__()
        resnet = models.resnet34(weights=weights)
        # Keep the first 8 children (stem + layer1..layer4), dropping ResNet's
        # own avgpool and fc; pooling and heads are done below.
        self.features = nn.Sequential(*list(resnet.children())[:8])

        # The pooled features are 2D (batch, 512), hence BatchNorm1d.
        self.classifier = nn.Sequential(
            nn.BatchNorm1d(512),
            nn.Linear(512, num_classes),
        )
        self.bb = nn.Sequential(
            nn.BatchNorm1d(512),
            nn.Linear(512, 4),  # bounding box (x1, y1, x2, y2)
        )

    def forward(self, x):
        feats = torch.relu(self.features(x))  # (batch, 512, h, w)
        # Functional pooling: avoids building a new module on every call.
        pooled = F.adaptive_avg_pool2d(feats, 1).flatten(1)  # (batch, 512)
        return self.classifier(pooled), self.bb(pooled)

#  ============================
# Hyper-parameters of the training run.
NUM_EPOCHS = 200
LEARNING_RATE = 0.001
# Weight of the bbox L1 term relative to the classification loss. The bbox
# targets are raw integers from the CSV (presumably pixel coordinates --
# confirm), so their L1 error is far larger in magnitude.
BB_LOSS_WEIGHT = 0.001
LOG_EVERY = 10  # print the summed epoch loss every LOG_EVERY epochs

detect_model = Detect_model().to(device)
# Optimise only parameters that require gradients. Nothing is frozen today,
# but this keeps working if part of the backbone is frozen later.
trainable_params = [p for p in detect_model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)
# --------------------------
detect_model.train()
for ep in range(NUM_EPOCHS):
    epoch_loss = 0.0
    for x, y_class, y_bb in train_loader:
        x = x.to(device)
        y_class = y_class.to(device).long()
        y_bb = y_bb.to(device).float()
        pre_y, pre_bb = detect_model(x)
        # Both terms are summed over the batch (not averaged).
        loss_class = F.cross_entropy(pre_y, y_class, reduction='sum')
        loss_bb = F.l1_loss(pre_bb, y_bb, reduction='sum')
        loss = loss_class + BB_LOSS_WEIGHT * loss_bb
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # detach() keeps this on-device; no host sync per step.
        epoch_loss += loss.detach()
    if (ep + 1) % LOG_EVERY == 0:
        print(f"epoch {ep + 1}/{NUM_EPOCHS}, total loss: {float(epoch_loss):.4f}")
# --------------------------------------
# Classification accuracy on the held-out split. Gradients are disabled so
# no autograd graph is built during inference (saves memory and time).
detect_model.eval()
correct = 0
with torch.no_grad():
    for x, y_class, _ in test_loader:  # bbox targets are not evaluated here
        x = x.to(device)
        y_class = y_class.to(device).long()
        pre_y, _ = detect_model(x)
        pre_index = pre_y.argmax(dim=1)
        correct += (pre_index == y_class).sum().item()
correct = correct / len(test_loader.dataset)
print("在测试集上的分类准确率为:{:.2f}%".format(100 * correct))







在测试集上的分类准确率为:90.34%

使用的数据集是 Kaggle 上的 Road Sign Detection 数据集,共有 877 张图片,各图片大小不一。因此代码中读取的 new_path 图片需要预先缩放到相同尺寸,否则默认的 DataLoader 无法把不同尺寸的图片拼成一个批次。

[图片:Pytorch实现目标检测_第1张图片]

注意:因为使用了 OpenCV(cv2.imread)读取图片,各种路径中不能包含中文,这一点一定要注意。

你可能感兴趣的:(Pytorch实现目标检测)