Keras YOLOv3:利用预训练模型为自己的数据生成 XML 标注文件

修改yolo_video.py:

直接上代码:

# Image mode: taken when the parsed command-line namespace has a truthy
# `image` attribute. Instead of the per-image detect_img call (kept below as
# a comment), this branch runs detect_xml to write annotation XML files.
if FLAGS.image:
        """
        Image detection mode, disregard any remaining command line arguments
        """
        print("Image detection mode")
        # Warn that the input/output arguments are not used in this mode.
        if "input" in FLAGS:
            print(" Ignoring remaining command line arguments: " + FLAGS.input + "," + FLAGS.output)
        
        # Every attribute of FLAGS is forwarded to YOLO as a keyword argument.
        detect_xml(YOLO(**vars(FLAGS)))      
        # Call that detect_xml replaces:
        # detect_img(YOLO(**vars(FLAGS)))
def detect_xml(yolo,
               image_dir='/home/user07/MyProjects/yolo/keras-yolo3-master/datasets/tjzt/JPEGImages',
               xml_dir='/home/user07/MyProjects/yolo/keras-yolo3-master/datasets/tjPersonXml'):
    """Generate one annotation XML file per image found under ``image_dir``.

    Each image path returned by ``get_all_file(image_dir)`` is opened and
    passed to ``yolo.detect_xml`` together with the output path (without
    extension) ``<xml_dir>/<image file name without extension>``.

    Args:
        yolo: Detector object exposing ``detect_xml(image, lxml)`` and
            ``close_session()``.
        image_dir: Directory handed to ``get_all_file`` to list the images.
            NOTE(review): ``get_all_file`` is defined elsewhere; it is assumed
            to return an iterable of image file paths -- confirm.
        xml_dir: Directory that receives the XML files; created if missing.

    The detector session is always closed, even when an image fails.
    """
    try:
        # Without this, the first write would fail on a fresh output folder.
        os.makedirs(xml_dir, exist_ok=True)
        for num_count, img in enumerate(get_all_file(image_dir), start=1):
            # basename handles the platform's path separators, unlike split('/').
            stem = os.path.splitext(os.path.basename(img))[0]
            lxml = os.path.join(xml_dir, stem)
            # Context manager releases the image file handle after detection.
            with Image.open(img) as image:
                yolo.detect_xml(image, lxml)
            print('这是第%d张图' % num_count)
    finally:
        yolo.close_session()

修改yolo.py:

def detect_xml(self, image, lxml):
        """Detect objects in ``image`` and write the 'person' boxes to ``<lxml>.xml``.

        The image is letterboxed to the model input size, run through the
        session, and every detection whose class name is ``'person'`` is
        written as an ``<object>`` entry of a Pascal VOC-style annotation
        document. The ``folder``, ``filename`` and ``path`` fields are filled
        with the placeholder text ``blank``; ``depth`` is always ``3``.

        Args:
            image: Image object providing ``width``, ``height`` and ``size``
                (as read below) -- presumably a PIL image; confirm with caller.
            lxml: Output path without extension; ``.xml`` is appended.

        Returns:
            None. The only result is the XML file written to disk.
        """
        if self.model_image_size != (None, None):
            assert self.model_image_size[0]%32 == 0, 'Multiples of 32 required'
            assert self.model_image_size[1]%32 == 0, 'Multiples of 32 required'
            boxed_image = letterbox_image(image, tuple(reversed(self.model_image_size)))
        else:
            # No fixed input size: round each side down to a multiple of 32.
            new_image_size = (image.width - (image.width % 32),
                              image.height - (image.height % 32))
            boxed_image = letterbox_image(image, new_image_size)
        image_data = np.array(boxed_image, dtype='float32')

        image_data /= 255.
        image_data = np.expand_dims(image_data, 0)  # Add batch dimension.

        out_boxes, out_scores, out_classes = self.sess.run(
            [self.boxes, self.scores, self.classes],
            feed_dict={
                self.yolo_model.input: image_data,
                self.input_image_shape: [image.size[1], image.size[0]],
                K.learning_phase(): 0
            })

        doc = xml.dom.minidom.Document()

        def add_container(parent, tag):
            """Append an empty ``<tag>`` element to ``parent`` and return it."""
            element = doc.createElement(tag)
            parent.appendChild(element)
            return element

        def add_text_element(parent, tag, text):
            """Append ``<tag>text</tag>`` to ``parent``."""
            add_container(parent, tag).appendChild(doc.createTextNode(str(text)))

        annotation = add_container(doc, 'annotation')
        add_text_element(annotation, 'folder', 'blank')
        add_text_element(annotation, 'filename', 'blank')
        add_text_element(annotation, 'path', 'blank')

        source = add_container(annotation, 'source')
        # The VOC tag is 'database'; the previous 'databass' was a typo.
        add_text_element(source, 'database', 'Unknown')

        size = add_container(annotation, 'size')
        add_text_element(size, 'width', image.width)
        add_text_element(size, 'height', image.height)
        add_text_element(size, 'depth', '3')

        add_text_element(annotation, 'segmented', '0')

        for i, c in reversed(list(enumerate(out_classes))):
            predicted_class = self.class_names[c]
            # Only 'person' detections are exported; other classes are dropped.
            if str(predicted_class) != 'person':
                continue

            lobject = add_container(annotation, 'object')
            add_text_element(lobject, 'name', predicted_class)
            add_text_element(lobject, 'pose', 'Unspecified')
            add_text_element(lobject, 'truncated', '0')
            add_text_element(lobject, 'difficult', '0')

            # Boxes are unpacked as (top, left, bottom, right); round to the
            # nearest pixel and clamp to the image bounds.
            top, left, bottom, right = out_boxes[i]
            top = max(0, np.floor(top + 0.5).astype('int32'))
            left = max(0, np.floor(left + 0.5).astype('int32'))
            bottom = min(image.size[1], np.floor(bottom + 0.5).astype('int32'))
            right = min(image.size[0], np.floor(right + 0.5).astype('int32'))

            bndbox = add_container(lobject, 'bndbox')
            add_text_element(bndbox, 'xmin', left)
            add_text_element(bndbox, 'ymin', top)
            add_text_element(bndbox, 'xmax', right)
            add_text_element(bndbox, 'ymax', bottom)

        # writexml's encoding argument only sets the XML declaration, so the
        # file itself must be opened as UTF-8 for the two to agree.
        with open('%s.xml' % lxml, 'w', encoding='utf-8') as fp:
            doc.writexml(fp, indent='\t', addindent='\t', newl='\n', encoding='utf-8')

将代码中的图片路径和 XML 保存路径修改为自己的路径即可。最后按如下方式调用:

python yolo_video.py --image

即可。

你可能感兴趣的:(yolo)