将Github作者endernewtown的基于pascal_voc数据集和coco数据集的代码,迁移至NWPU_VHR数据集。
原代码仓库链接:https://github.com/endernewton/tf-faster-rcnn
修改后的代码仓库:https://github.com/Rhapso/tf-faster-rcnn
初次接触一个较大的tensorflow工程,所以在读源码上下了一番功夫。本文适合初涉tensorflow的用户阅读。本工程的代码结构较大,所以要从宏观入手。我们的切入点是/experiments/scripts
目录下的两个个文件train_faster_rcnn.sh
、test_faster_rcnn.sh
。
NWPU_VHR数据集是一个较小的遥感照片数据集,其数据集目录下包含的文件夹有
/ground_truth
中的信息生成的数据体,当程序内部默认当存在现成的数据体时,直接调用这个数据包。因此,调试程序时,若修改了对/ground_truth
文件夹内数据的读取规则,请清空/annotations_catchs
文件夹。(index1,index2),(index3,index4),class_num
,中间不包括空格。我们在提取的时候需要注意。源代码中还添加了难度(diff)参数,取值范围是如0、1、2的整数,难度参数在NWPU_VHR数据集中没有标注,因此在提取时将难度设置为0即可。下文我们顺藤摸瓜,依次将所需调用的文件修改完成。
这个脚本文件是整个程序训练的入口文件,我们使用它的方法是运行命令./train_faster_rcnn.sh 0 nwpu_vhr res101
。这里0
表示GPU_ID,nwpu_vhr
是我们要使用的数据集,为了要使用这个数据集,我们要在文件case
语句中添加以下代码
nwpu_vhr)
TRAIN_IMDB="NWPU_VHR_train"
TEST_IMDB="NWPU_VHR_val"
STEPSIZE="[35000]"
ITERS=200000
ANCHORS="[1,2,4,8,16]"
RATIOS="[0.5,1,2]"
;;
可以实现使用该数据集的一下基本配置,特别说明,所设置的参数可以自由修改,可以达到不一样的训练效果。我们还要修改源代码的一个地方:
if [ ! -f ${NET_FINAL}.index ]; then
if [[ ! -z ${EXTRA_ARGS_SLUG} ]]; then
CUDA_VISIBLE_DEVICES=${GPU_ID} time python ./tools/trainval_net.py \
--weight ../DOTA/tf-faster-rcnn/data/imagenet_weights/${NET}.ckpt \
--imdb ${TRAIN_IMDB} \
--imdbval ${TEST_IMDB} \
--iters ${ITERS} \
--cfg experiments/cfgs/${NET}.yml \
--tag ${EXTRA_ARGS_SLUG} \
--net ${NET} \
--set ANCHOR_SCALES ${ANCHORS} ANCHOR_RATIOS ${RATIOS} \
TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
else
CUDA_VISIBLE_DEVICES=${GPU_ID} time python ./tools/trainval_net.py \
--weight data/imagenet_weights/${NET}.ckpt \
--imdb ${TRAIN_IMDB} \
--imdbval ${TEST_IMDB} \
--iters ${ITERS} \
--cfg experiments/cfgs/${NET}.yml \
--net ${NET} \
--set ANCHOR_SCALES ${ANCHORS} ANCHOR_RATIOS ${RATIOS} \
TRAIN.STEPSIZE ${STEPSIZE} ${EXTRA_ARGS}
fi
fi
在--weight
词条中需要为我们的网络指定一个初始权重,在这个例程中,我使用的是在imageNet数据集上训练出来的一个现成权重。当然我们可以选择完全随机生成权重等其他方法。初始权重的选择会影响训练的速度和效果。我们发现,train_faster_rcnn.sh
调用了./tools/trainval_net.py
文件,我们继续分析。
之前的修改对于任何数据集都是一样的,接下来的修改就要因地制宜了。trainval_net.py
文件定义了参数接口(应用parse_args()
函数),获取并格式化数据集的函数combined_roidb()
。我们需要修改的就是由combined_roidb()
调用的各个函数,使其返回的变量imdb, roidb
符合CNN网络的格式要求就行了。
这个文件是一个数据集类的接口,get_imdb()
返回一个具体数据集的类。为了适配新增的NWPU_VHR数据集,添加以下代码:
# Set up NWPU_VHR_
__sets['NWPU_VHR_train'] = lambda :NWPU_VHR('train')
__sets['NWPU_VHR_val'] = lambda :NWPU_VHR('test')
NWPU_VHR
类是我们要重新改写的针对NWPU_VHR数据集的类,主要的代码结构跟/lib/datasets/pascal_voc.py
是一致的。
参考/lib/datasets/pascal_voc.py
创建的新文件,这个文件创建了一个基于imdb
类的派生类NWPU_VHR
,在文件开头将voc_eval
来源修改为
from NWPU_VHR_eval import voc_eval
我们在原有代码的基础上,修改以下项目:
这部分集中于__init__
函数里,根据本地环境,修改好相应的的数据集路径、种类、图片扩展名(此例中为.png
)。另外,浏览整个文件,任何涉及到路径的配置都需要注意。debug阶段出错大都是因为路径设置的问题。
不同的数据集数据标注的方法不一,做一个对比。DOTA数据集的标注方式是:714 76 726 95 small-vehicle 0
,NWPU_VHR数据集的标注方式是:(563,478),(630,573),1
。因此我们要针对不同的标注做读取上的调整。我们需要修改的函数是_load_XXXXX_annotation(self,index)
。创建函数
classes = ['airplane', 'ship', 'storage tank', 'baseball diamond', 'tennis court', 'basketball court', 'ground track field', 'harbor', 'bridge', 'vehicle']
def getObjects(FilePath):
objects = []
with open(FilePath) as f:
for line in f.readlines():
lineArray = line.rstrip().split(',')
xmin = float(lineArray[0].strip('(').strip(')'))
ymin = float(lineArray[1].strip('(').strip(')'))
xmax = float(lineArray[2].strip('(').strip(')'))
ymax = float(lineArray[3].strip('(').strip(')'))
assert xmax >= xmin
assert ymax >= ymin
label = classes[int(lineArray[4])-1]
# diff = lineArray[5] ---->NWPU_VHR datasets has no diff label
diff = 0
objects.append({
"wnid": label, # class name, not number
"difficult": diff,
"box": {
"xmin": xmin,
"ymin": ymin,
"xmax": xmax,
"ymax": ymax,
}
})
return objects
接受一个指向标注的文件,返回包含'wnid', 'difficult', 'box'
字段的字典。
数据集标注时的起始坐标可能不同,如(1,1)
或(0,0)
。在函数_load_XXXXX_annotation(self,index)
中,我们要将起始坐标统一为(0,0)
。NWPU_VHR数据集的标注起始坐标是(0,0)
,因此不需要将读取的坐标统一减一。
在网络训练完毕后进行测试时需要函数do_python_eval(self,output_dir = 'output')
。该函数调用的函数voc_eval()
需要我们修改。
参考/lib/datasets/voc_eval.py
创建的新文件。函数voc_eval()
调用了函数parse_rec(filename)
,修改函数
def parse_rec(filename):
""" Parse a NWPU_VHR ground truth file
ground truth file format x1 y1 x2y2 label_name easy_or_not
"""
classes = ['airplane', 'ship', 'storage tank', 'baseball diamond', 'tennis court', 'basketball court', 'ground track field', 'harbor', 'bridge', 'vehicle']
objects = []
with open(filename,'r') as f:
for line in f.readlines():
lineArray = line.rstrip().split(',')
tl = lineArray[0:2]
br = lineArray[2:4]
label = classes[int(lineArray[4])-1]
objects.append({
'name': label,
'bbox': [float(lineArray[0].strip('(').strip(')')),
float(lineArray[1].strip('(').strip(')')),
float(lineArray[2].strip('(').strip(')')),
float(lineArray[3].strip('(').strip(')')),
],
'difficult': 0,
})
return objects
可以成功读取数据集标注。
AP for airplane = 1.0000
AP for ship = 0.8879
AP for storage tank = 0.7590
AP for baseball diamond = 0.9002
AP for tennis court = 0.8923
AP for basketball court = 0.9991
AP for ground track field = 0.9621
AP for harbor = 0.9487
AP for bridge = 0.6115
AP for vehicle = 0.8005
Mean AP = 0.8761
一个介绍Tensor board非常棒的视频:TensorBoard实践介绍