1.下载SSD-Pytorch代码
SSD-pytorch代码链接: https://github.com/amdegroot/ssd.pytorch
git clone https://github.com/amdegroot/ssd.pytorch
2.准备数据集
3.1 config.py
找到config.py文件,
打开修改VOC中的num_classes,根据自己的情况修改:classes+1(背景算一类),
我这里就只有一类,所有是2
第一次调试最好修改一下max_iter,不然迭代次数太大,要好长时间,其他都是一些超参数,可以占时不修改
博主用的VOC格式的数据集,下面修改都是以VOC格式为例
3.3 train.py
下载预训练模型。VGG16_reducedfc.pth
链接: https://pan.baidu.com/s/1EW9qT0nJkE2dK7thn_kPVw 密码: nw6t
–来自百度网盘超级会员V1的分享
3.4 eval.py
添加训练好的模型到eval.py,对模型进行验证,我这里训练好的是ssd300_VOC_500.pth
将下面的
args = parser.parse_args()
修改为
args,unknow= parser.parse_known_args()
运行eval.py只能看到AP值,想要测试自己的图片,在jupyter notebook中运行demo.ipynb
将对应部分的代码,修改为以下这样即可,注意正确添加图片的路径
image = cv2.imread(’…/data/example3.jpg’, cv2.IMREAD_COLOR) # uncomment if dataset not downloaded
from matplotlib import pyplot as plt
from data import VOCDetection, VOC_ROOT, VOCAnnotationTransform
#testset = VOCDetection('./data/example1.jpg', [('2020', 'val')], None, VOCAnnotationTransform())
#img_id = 13
#image = testset.pull_image(img_id)
rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(10,10))
plt.imshow(rgb_image)
plt.show()
bug1:出现维度不匹配的情况
loc_loss += loss_l.data[0] 报错
解决方法:
bug2:自动停止训练
解决方法:
3. load train data部分修改为如上图所示
bug3:可能会出现pytorch版本带来的影响问题
解决方法:根据提示语句,百度修改即可
bug4:运行eval.py可能会出现pytest这种情况
解决方法:将eval.py中的test_net函数名字修改即可,不能出现test关键字,博主修改为set_net成功运行
bug5:训练出现-nan
解决方法:降低学习率
bug6:出现显存不足的问题Runtimeout
解决方法:降低batch_size
bug7:出现数组索引过多的情况
IndexError: too many indices for array
解决方法:因为有些标注的标签没有数据,所有会出现数组索引出错
如果数据比较多,可以用如下脚本排查是哪个标签出现问题(注意修改自己的标签路径)
import argparse
import sys
import cv2
import os
import os.path as osp
import numpy as np
if sys.version_info[0] == 2:
import xml.etree.cElementTree as ET
else:
import xml.etree.ElementTree as ET
parser = argparse.ArgumentParser(
description='Single Shot MultiBox Detector Training With Pytorch')
train_set = parser.add_mutually_exclusive_group()
parser.add_argument('--root', default='data/VOCdevkit/VOC2020' , help='Dataset root directory path')
args = parser.parse_args()
CLASSES = [( # always index 0
'dargon fruit')]
annopath = osp.join('%s', 'Annotations', '%s.{}'.format("xml"))
imgpath = osp.join('%s', 'JPEGImages', '%s.{}'.format("jpg"))
def vocChecker(image_id, width, height, keep_difficult = False):
target = ET.parse(annopath % image_id).getroot()
res = []
for obj in target.iter('object'):
difficult = int(obj.find('difficult').text) == 1
if not keep_difficult and difficult:
continue
name = obj.find('name').text.lower().strip()
bbox = obj.find('bndbox')
pts = ['xmin', 'ymin', 'xmax', 'ymax']
bndbox = []
for i, pt in enumerate(pts):
cur_pt = int(bbox.find(pt).text) - 1
# scale height or width
cur_pt = float(cur_pt) / width if i % 2 == 0 else float(cur_pt) / height
bndbox.append(cur_pt)
print(name)
label_idx = dict(zip(CLASSES, range(len(CLASSES))))[name]
bndbox.append(label_idx)
res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind]
# img_id = target.find('filename').text[:-4]
print(res)
try :
print(np.array(res)[:,4])
print(np.array(res)[:,:4])
except IndexError:
print("\nINDEX ERROR HERE !\n")
exit(0)
return res # [[xmin, ymin, xmax, ymax, label_ind], ... ]
if __name__ == '__main__' :
i = 0
for name in sorted(os.listdir(osp.join(args.root,'Annotations'))):
# as we have only one annotations file per image
i += 1
img = cv2.imread(imgpath % (args.root,name.split('.')[0]))
height, width, channels = img.shape
print("path : {}".format(annopath % (args.root,name.split('.')[0])))
res = vocChecker((args.root, name.split('.')[0]), height, width)
print("Total of annotations : {}".format(i))
之作为学习使用不商用
ref:https://blog.csdn.net/weixin_42447868/article/details/105675158