【Mask R-CNN】(八):代码理解demo.ipynb

首先,导入包。

import os
import sys
import random
import math
import numpy as np
import skimage.io
import matplotlib
import matplotlib.pyplot as plt

#设置根目录
ROOT_DIR = os.path.abspath("../")

#导入Mask R-CNN
sys.path.append(ROOT_DIR)
from mrcnn import utils
import mrcnn.model as modellib
from mrcnn import visualize
#导入coco的配置
sys.path.append(os.path.join(ROOT_DIR, "samples/coco/"))
import coco
%matplotlib inline 

#设置写log的目录以及保存训练model的目录
MODEL_DIR = os.path.join(ROOT_DIR, "logs")

#预训练模型的目录
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")
#如果文件不存在,则下载COCO预训练权重
if not os.path.exists(COCO_MODEL_PATH):
    utils.download_trained_weights(COCO_MODEL_PATH)

#测试图片目录
IMAGE_DIR = os.path.join(ROOT_DIR, "images")

这里将会使用一个在MS_COCO数据集上训练好的模型。这个模型的配置信息位于coco.py文件的CocoConfig类中。

在进行预测时,需要对这个配置做一些小的改动以匹配当前的任务。具体做法是,继承CocoConfig类并更改你需要的属性。

class InferenceConfig(coco.CocoConfig):
    #将batch size设置为1,因为在预测时一次只需要处理一副图像
    #Batch size = GPU_COUNT * IMAGES_PER_GPU
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

config = InferenceConfig()
#将配置信息打印出来
config.display()

接下来就创建模型对象并加载训练好的权重。

#创建预测模型对象.
model = modellib.MaskRCNN(mode="inference", model_dir=MODEL_DIR, config=config)

#加载MS-COCO训练权重
model.load_weights(COCO_MODEL_PATH, by_name=True)

该模型分类物体并返回类别ID,类别ID以整型值来区分每一类物体。

有些数据集为类别赋予整型值以便于区分,但是有些数据集没有这样做。ID通常是顺序连续的,但不绝对。为了增强一致性并且在同一时间可以训练多个源数据,这里的Dataset class为每个类分配了自己的整型ID。

所以要注意将class IDs映射到正确的class names。

这里我们并不想下载所有的COCO数据集来测试这个demo,所以将所有的class names预先存在一个list中。class name的索引即代表了它们的ID。

# COCO Class names
#索引即代表ID.例如,要获取teddy bear的ID
# use: class_names.index('teddy bear')
class_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
               'bus', 'train', 'truck', 'boat', 'traffic light',
               'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',
               'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',
               'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
               'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
               'kite', 'baseball bat', 'baseball glove', 'skateboard',
               'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
               'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
               'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
               'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
               'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
               'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',
               'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',
               'teddy bear', 'hair drier', 'toothbrush']

一切准备好后,就可以开始预测。从测试文件夹内随机选取一副图像。

#随机选择图像
file_names = next(os.walk(IMAGE_DIR))[2]
#读取图像
image = skimage.io.imread(os.path.join(IMAGE_DIR, random.choice(file_names)))

#运行检测
results = model.detect([image], verbose=1)

#可视化检测结果
r = results[0]
visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'], 
                            class_names, r['scores'])

首先会打印出相关信息。

Processing 1 images
image                    shape: (476, 640, 3)         min:    0.00000  max:  255.00000
molded_images            shape: (1, 1024, 1024, 3)    min: -123.70000  max:  120.30000
image_metas              shape: (1, 89)               min:    0.00000  max: 1024.00000

接着会显示检测结果,包含bounding box, probability和mask。

【Mask R-CNN】(八):代码理解demo.ipynb_第1张图片

 

你可能感兴趣的:(深度学习,tensorflow,Mask,RCNN)