目录
mmdetection Registry类
Registry在mmdectection中使用可分三步
1、实例化Registry
2、利用Registry类中的函数register_module()对要注册的类进行装饰,即可获取类
3、利用Registry对象初始化已注册的类,即得到已注册类对象
在mmdetection中,注册器的实例化过程如下
了解Registry类之前,咱们直接上代码,先看看mmdetection Registry类怎么使用,对其功能比较直观的感受,废话不多说,代码如下:
import torch.nn as nn
from functools import partial
import inspect
# 定义注册器Registry类
class Registry(object):
def __init__(self, name):
self._name = name # 注册器容器(类)的对象名,如果为BACKBONE,则name可写为backbone,属于自定义
self._module_dict = dict() # 注册器容器,存放已注册的类
def __repr__(self):
format_str = self.__class__.__name__ + '(name={}, items={})'.format(
self._name, list(self._module_dict.keys()))
return format_str
@property
def name(self): # 类属性,用于获取某个注册器名
return self._name
@property
def module_dict(self): # 类属性,用于获取某个注册器容器,存放已注册的类
return self._module_dict
def get(self, key): # 根据注册器中已注册的类的名字,获取类
return self._module_dict.get(key, None)
def _register_module(self, module_class, force=False):
"""Register a module.
Args:
module (:obj:`nn.Module`): Module to be registered.
"""
if not inspect.isclass(module_class):
raise TypeError('module must be a class, but got {}'.format(
type(module_class)))
module_name = module_class.__name__ # 类名
if not force and module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
# 将即将注册的类存放至注册器容器_module_dict中
self._module_dict[module_name] = module_class
def register_module(self, cls=None, force=False):
if cls is None:
return partial(self.register_module, force=force)
self._register_module(cls, force=force)
return cls
# 实例化Registry
HEAD = Registry("head")
# Registry使用
@HEAD.register_module
class SOLOv2Head(nn.Module):
def __init__(self):
super(SOLOv2Head, self).__init__() # 调用父类的__init__()
print("SOLOv2Head init..")
@HEAD.register_module
class SOLOv2HeadMask(nn.Module):
def __init__(self):
super(SOLOv2HeadMask, self).__init__() # 调用父类的__init__()
print("SOLOv2HeadMask init..")
print(HEAD.module_dict)
# {'SOLOv2Head': ,
# 'SOLOv2HeadMask': }
print(HEAD.name) # head
print(HEAD.get('SOLOv2HeadMask')()) # SOLOv2HeadMask init.. SOLOv2HeadMask()
输出如下:
{'SOLOv2Head': , 'SOLOv2HeadMask': }
head
SOLOv2HeadMask init..
SOLOv2HeadMask()
Registry功能作用,可简单总结为:
利用Registry类对将要注册的类进行管理,其注册的原理,则是使用Registry类中register_module()函数作为函数装饰器,对将要注册的类进行装饰(即获取到该类
HEAD = Registry("head")
@HEAD.register_module
class SOLOv2Head(nn.Module):
def __init__(self):
super(SOLOv2Head, self).__init__()
print("SOLOv2Head init..")
cls = HEAD.get('SOLOv2HeadMask')
print(cls)
print("-----------------------")
print(cls())
输出:
-----------------------
SOLOv2HeadMask init..
SOLOv2HeadMask()
首先,实例化注册器
from mmdet.utils import Registry
BACKBONES = Registry('backbone')
NECKS = Registry('neck')
ROI_EXTRACTORS = Registry('roi_extractor')
SHARED_HEADS = Registry('shared_head')
HEADS = Registry('head')
LOSSES = Registry('loss')
DETECTORS = Registry('detector')
其次,在代码模块中,例如:
--models
-- __init__.py
添加:
from .registry import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
ROI_EXTRACTORS, SHARED_HEADS)
最后,使用实例化后的注册器(例SOLOv2Head)
from ..registry import HEADS
...
@HEADS.register_module
class SOLOv2Head(nn.Module):
def __init__(self,
num_classes,
in_channels,
seg_feat_channels=256,
stacked_convs=4,
strides=(4, 8, 16, 32, 64),
base_edge_list=(16, 32, 64, 128, 256),
scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128, 512)),
sigma=0.2,
num_grids=None,
ins_out_channels=64,
loss_ins=None,
loss_cate=None,
conv_cfg=None,
norm_cfg=None,
use_dcn_in_tower=False,
type_dcn=None):
super(SOLOv2Head, self).__init__()
pass