mmdetection Registry

目录

mmdetection Registry类

Registry在mmdectection中使用可分三步

1、实例化Registry

2、利用Registry类中的函数register_module()对要注册的类进行装饰,即可获取类

3、利用Registry对象初始化已注册的类,即得到已注册类对象

在mmdetection中,注册器的实例化过程如下


mmdetection Registry类

了解Registry类之前,咱们直接上代码,先看看mmdetection Registry类怎么使用,对其功能比较直观的感受,废话不多说,代码如下:

import torch.nn as nn
from functools import partial
import inspect

# 定义注册器Registry类
class Registry(object):

    def __init__(self, name):
        self._name = name # 注册器容器(类)的对象名,如果为BACKBONE,则name可写为backbone,属于自定义
        self._module_dict = dict()  # 注册器容器,存放已注册的类

    def __repr__(self):
        format_str = self.__class__.__name__ + '(name={}, items={})'.format(
            self._name, list(self._module_dict.keys()))
        return format_str

    @property
    def name(self):  # 类属性,用于获取某个注册器名
        return self._name

    @property
    def module_dict(self):  # 类属性,用于获取某个注册器容器,存放已注册的类
        return self._module_dict

    def get(self, key):  # 根据注册器中已注册的类的名字,获取类
        return self._module_dict.get(key, None)

    def _register_module(self, module_class, force=False):
        """Register a module.

        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, but got {}'.format(
                type(module_class)))
        module_name = module_class.__name__ # 类名
        if not force and module_name in self._module_dict:
            raise KeyError('{} is already registered in {}'.format(
                module_name, self.name))
        # 将即将注册的类存放至注册器容器_module_dict中
        self._module_dict[module_name] = module_class

    def register_module(self, cls=None, force=False):
        if cls is None:
            return partial(self.register_module, force=force)
        self._register_module(cls, force=force)
        return cls

# 实例化Registry
HEAD = Registry("head")

# Registry使用
@HEAD.register_module
class SOLOv2Head(nn.Module):
    def __init__(self):
        super(SOLOv2Head, self).__init__()  # 调用父类的__init__()
        print("SOLOv2Head init..")

@HEAD.register_module
class SOLOv2HeadMask(nn.Module):
    def __init__(self):
        super(SOLOv2HeadMask, self).__init__()  # 调用父类的__init__()
        print("SOLOv2HeadMask init..")

print(HEAD.module_dict) 
# {'SOLOv2Head': ,
# 'SOLOv2HeadMask': }
print(HEAD.name)  # head
print(HEAD.get('SOLOv2HeadMask')()) # SOLOv2HeadMask init..  SOLOv2HeadMask()

输出如下:

{'SOLOv2Head': , 'SOLOv2HeadMask': }
head
SOLOv2HeadMask init..
SOLOv2HeadMask()

Registry功能作用,可简单总结为:

        利用Registry类对将要注册的类进行管理,其注册的原理,则是使用Registry类中register_module()函数作为函数装饰器,对将要注册的类进行装饰(即获取到该类),在按照类名,使用dict对注册的类进行存储!

Registry在mmdectection中使用可分三步

1、实例化Registry

HEAD = Registry("head")


2、利用Registry类中的函数register_module()对要注册的类进行装饰,即可获取类

@HEAD.register_module
class SOLOv2Head(nn.Module):
    def __init__(self):
        super(SOLOv2Head, self).__init__()
        print("SOLOv2Head init..")

3、利用Registry对象初始化已注册的类,即得到已注册类对象

cls = HEAD.get('SOLOv2HeadMask')
print(cls)
print("-----------------------")
print(cls())

输出:

-----------------------
SOLOv2HeadMask init..
SOLOv2HeadMask()

在mmdetection中,注册器的实例化过程如下

首先,实例化注册器

from mmdet.utils import Registry

BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')

其次,在代码模块中,例如:

--models

        -- __init__.py

添加:

from .registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                       ROI_EXTRACTORS, SHARED_HEADS)

最后,使用实例化后的注册器(例SOLOv2Head)

from ..registry import HEADS

...

@HEADS.register_module
class SOLOv2Head(nn.Module):

    def __init__(self,
                 num_classes,
                 in_channels,
                 seg_feat_channels=256,
                 stacked_convs=4,
                 strides=(4, 8, 16, 32, 64),
                 base_edge_list=(16, 32, 64, 128, 256),
                 scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
                 sigma=0.2,
                 num_grids=None,
                 ins_out_channels=64,
                 loss_ins=None,
                 loss_cate=None,
                 conv_cfg=None,
                 norm_cfg=None,
                 use_dcn_in_tower=False,
                 type_dcn=None):
        super(SOLOv2Head, self).__init__()
        pass

你可能感兴趣的:(pytorch,mmdetection,python,计算机视觉,目标检测,人工智能,pytorch)