上一章中我们是用以下代码来进行对象创建的。这段代码并没有什么问题,每次新增一个层的时候在此处添加els if即可。不过还是可以有更优雅一些的实现。
LayerBase *LayerFactory(string classname)
{
if (classname == "conv")
return new LayerConv();
else if (classname == "pooling")
return new LayerPooling();
else if (classname == "softmax")
return new LayerSoftmax();
else
return nullptr;
}
计划构造一张map映射表,key为类名或者ID,value为可返回对象的函数指针。这样上面的函数就可以变成这一个样子。
// LayerFactory.h
#include
using std::function;
class LayerFactory
{
public:
static LayerBase *Create(string str){
function <LayerBase*()> fun = s_CreateMap[str];
if (nullptr == fun)
return nullptr;
return fun();
}
static void RegisterCreater(string classname, function<LayerBase*()> creater)
{
s_CreateMap[classname] = creater;
}
private:
static map<string, function<LayerBase*()>> s_CreateMap;
};
// LayerFactory.cpp
#include"LayerFactory.h"
LayerBase *LayerFactoryMap(string classname)
{
return LayerFactory::Create(classname);
}
s_CreateMap
就是这个映射表,最直接的想法是把每个类的构造函数赋予s_CreateMap
,可是C++是无法获取指向类的构造函数的函数指针。
所以转变思路,用另一个函数(比如叫fwrap
)把构造函数封起来,然后再把fwrap
的地址赋予map表即可。如果在每个layer类里面定义这样的fwrap
函数。虽然成员函数可以获得函数指针,但其在调用的时必须要有具体对象(因为参数列表中需要传this指针),这样调用就显得麻烦了。所以最后就只有全局函数或者类的静态函数符合要求。
以Conv层举例,得到以下实现:
class AUTO_FACTORY_Conv
{
public:
AUTO_FACTORY_Conv(){ LayerFactory::RegisterCreater("conv", AUTO_FACTORY_Conv::CreateLayer); };
static LayerBase* CreateLayer() { return new LayerConv(); }
};
可以看到这里把注册语句放到了AUTO_FACTORY_Conv 类的构造函数中。这是为了只要声明一个对象即可自动去注册一个字符串和对象构建器的映射关系。
比如:
static AUTO_FACTORY_Conv g_obj_for_register_conv;
那CreateLayer
函数就是上面所描述的fwrap
了。
可以发现这段代码非常模式化,按照c++惯例用宏来总结下。
#define REGISTER_LAYER_CREATE(idname,classname) \
class AUTO_FACTORY_##idname \
{ \
public: \
AUTO_FACTORY_##idname(){ LayerFactory::RegisterCreater(#idname, AUTO_FACTORY_##idname::CreateLayer); }; \
static LayerBase* CreateLayer() { return new classname(); } \
}; \
static AUTO_FACTORY_##idname class_creater_register_##idname;
最终完整代码如下:
// LayerFactory.h
#include
using std::function;
LayerBase *LayerFactory(string classname);
class LayerFactory
{
public:
static LayerBase *Create(string str){
function <LayerBase*()> fun = s_CreateMap[str];
if (nullptr == fun)
return nullptr;
return fun();
}
static void RegisterCreater(string classname, function<LayerBase*()> creater)
{
s_CreateMap[classname] = creater;
}
private:
static map<string, function<LayerBase*()>> s_CreateMap;
};
#define REGISTER_LAYER_CREATE(idname,classname) \
class AUTO_FACTORY_##idname \
{ \
public: \
AUTO_FACTORY_##idname(){ LayerFactory::RegisterCreater(#idname, AUTO_FACTORY_##idname::CreateLayer); }; \
static LayerBase* CreateLayer() { return new classname(); } \
}; \
static AUTO_FACTORY_##idname class_creater_register_##idname;
// LayerFactory.cpp
#include"layerFactory.h"
map<string, function<LayerBase*()>> LayerFactory::s_CreateMap;
// register
REGISTER_LAYER_CREATE(conv,LayerConv)
REGISTER_LAYER_CREATE(pooling, LayerPooling)
REGISTER_LAYER_CREATE(softmax, LayerSoftmax)
LayerBase *LayerFactoryMap(string classname)
{
return LayerFactory::Create(classname);
}
REGISTER_LAYER_CREATE
宏最终应用在layerFactory.cpp中。即用于注册的类声明只有这个cpp文件可见,且定义的帮助注册的类也是置为static的,这些都是为了把对象构建器的细节封装到这个编译单元内。1.工厂模式
2.这种构建map的方式是从某篇博客学来的,但不记得出处回头补上