使用resnet18预训练模型实时检测摄像头画面中的物体(画面显示英文类名)

imagenet_class_index.cs文件下载

https://download.csdn.net/download/qq_42864343/88492936

代码

import os

import numpy as np
import pandas as pd

import cv2 # opencv-python
from tqdm import tqdm # 进度条
from PIL import Image # pillow
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn.functional as F
from torchvision import models
import time
# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device:', device)
# 载入预训练图像分类模型
model = models.resnet18(pretrained=True)
model = model.eval()
model = model.to(device)
# 将idx与类名相对应
df = pd.read_csv('data/imagenet_class_index.csv')
idx_to_labels = {}
for idx, row in df.iterrows():
    # 英文类名
    idx_to_labels[row['ID']] = row['class']
# 图像预处理
from torchvision import transforms

# 测试集图像预处理-RCTN:缩放裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize(
                                         mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
                                     ])
# 处理一帧的函数,供后面调用
def process_frame(img):

    '''
    输入摄像头拍摄画面bgr-array,输出图像分类预测结果bgr-array
    '''

    # 记录该帧开始处理的时间
    start_time = time.time()

    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR转RGB
    img_pil = Image.fromarray(img_rgb) # array 转 PIL
    input_img = test_transform(img_pil).unsqueeze(0).to(device) # 预处理
    pred_logits = model(input_img) # 执行前向预测,得到所有类别的 logit 预测分数
    pred_softmax = F.softmax(pred_logits, dim=1) # 对 logit 分数做 softmax 运算

    top_n = torch.topk(pred_softmax, 5) # 取置信度最大的 n 个结果
    pred_ids = top_n[1].cpu().detach().numpy().squeeze() # 解析预测类别
    confs = top_n[0].cpu().detach().numpy().squeeze() # 解析置信度

    # 在图像上写字
    for i in range(len(confs)):
        pred_class = idx_to_labels[pred_ids[i]]
        text = '{:<15} {:>.3f}'.format(pred_class, confs[i])

        # 图片,添加的文字,左上角坐标,字体,字体大小,颜色,线宽,线型
        img = cv2.putText(img, text, (50, 160 + 80 * i), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 0, 255), 4, cv2.LINE_AA)

    # 记录该帧处理完毕的时间
    end_time = time.time()
    # 计算每秒处理图像帧数FPS
    FPS = 1/(end_time - start_time)
    # 图片,添加的文字,左上角坐标,字体,字体大小,颜色,线宽,线型
    img = cv2.putText(img, 'FPS  '+str(int(FPS)), (50, 80), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 0, 255), 4, cv2.LINE_AA)

    return img
# 调用摄像头处理摄像头中的画面
def view_video(video_path):
    # 设置显示窗口的大小
    width,height = 800,600

    video = cv2.VideoCapture(video_path)
    '''把摄像头设置为1980 x 1080'''
    video.set(cv2.CAP_PROP_FRAME_WIDTH,1920)
    video.set(cv2.CAP_PROP_FRAME_HEIGHT,1080)
    video.set(cv2.CAP_PROP_FOURCC,cv2.VideoWriter.fourcc('M','J','P','G'))

    if video.isOpened():
        '''
            video.read() 一帧一帧地读取
            open 得到的是一个布尔值,就是 True 或者 False
            frame 得到当前这一帧的图像
        '''
        open, frame = video.read()
    else:
        open = False

    while open:
        ret, frame = video.read()
        # 如果读到的帧数不为空,那么就继续读取,如果为空,就退出
        if frame is None:
            break
        if ret == True:
            # !!!处理帧函数
            frame = process_frame(frame)
            cv2.namedWindow('video',cv2.WINDOW_NORMAL)
            cv2.imshow("video", frame)
            # 50毫秒内判断是否受到esc按键的信息
            if cv2.waitKey(50) & 0xFF == 27:
                break
    video.release()
    cv2.destroyAllWindows()
# linux usb摄像头的Id一般为1
view_video(1)

你可能感兴趣的:(opencv,opencv,人工智能,计算机视觉)