迁移学习:用现成网络跑自己的数据。做法是保留已有网络除输出层以外各层的权重,只把输出层的类别(class)个数改成自己任务的类别数,然后以已有网络的权值为起点训练自己的网络。
下面以 keras 2.1.5 / VGG16 为例。
from keras import Model
from keras import backend as K
from keras import initializers
from keras import optimizers
from keras.applications.vgg16 import VGG16
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.layers import Dropout, Flatten, Dense
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
# Input geometry matches the original VGG16 network: 224x224 RGB,
# channels-last.
input_shape = (224, 224, 3)

# Pixel values are rescaled from [0, 255] to [0, 1] for both splits.
# No geometric augmentation is enabled here; shear/zoom/flip options of
# ImageDataGenerator can be turned on later if the training set is small.
train_datagen = ImageDataGenerator(rescale=1. / 255)
test_datagen = ImageDataGenerator(rescale=1. / 255)

# One sub-directory per class under ./data/train/ and ./data/test/;
# class_mode='categorical' yields one-hot label vectors.
train_generator = train_datagen.flow_from_directory(
    directory='./data/train/',
    target_size=input_shape[:-1],
    color_mode='rgb',
    classes=None,
    class_mode='categorical',
    batch_size=10,
    shuffle=True,
)
test_generator = test_datagen.flow_from_directory(
    directory='./data/test/',
    target_size=input_shape[:-1],
    batch_size=10,
    class_mode='categorical',
)
# Build the full VGG16 network (include_top=True keeps the dense head)
# but with a 10-way output layer instead of the original 1000-way
# ImageNet classifier. weights=None builds the model with randomly
# initialized weights; the pre-trained convolutional weights are loaded
# afterwards via load_weights(..., by_name=True).
base_model = VGG16(input_shape = input_shape,
include_top = True,
classes = 10,
weights = None
)
print('Model loaded.')
# Rename the output layer so the later by-name weight loading does NOT
# find a matching layer for it and therefore skips it, leaving the
# fresh random 10-way weights in place.
base_model.layers[-1].name = 'pred'
# Inspect the output layer's default kernel-initializer configuration
# (shown below: glorot_uniform). NOTE(review): the return value of this
# expression is discarded; wrap it in print() to actually see it when
# running as a script.
base_model.layers[-1].kernel_initializer.get_config()
上一行的 get_config() 将会得到(即默认的 glorot_uniform 初始化配置):
{'distribution': 'uniform', 'mode': 'fan_avg', 'scale': 1.0, 'seed': None}
# Re-initialize the new 10-way output layer with Glorot-normal weights.
# BUG FIX: merely assigning `kernel_initializer` on an already-built
# layer does NOT change its existing weight values (the kernel was
# created at build time) — so draw fresh values from the initializer
# and write them back explicitly with set_weights().
pred_layer = base_model.layers[-1]
pred_layer.kernel_initializer = initializers.glorot_normal()
kernel, bias = pred_layer.get_weights()
new_kernel = K.eval(pred_layer.kernel_initializer(kernel.shape))
pred_layer.set_weights([new_kernel, bias])
# Load pre-trained ImageNet weights for every layer whose name matches;
# the renamed 'pred' layer has no match and keeps its fresh init.
base_model.load_weights('./vgg16_weights_tf_dim_ordering_tf_kernels.h5', by_name = True)
# Compile with SGD: Nesterov momentum 0.9, lr 0.01, per-update decay
# 1e-4 — a conservative setting for fine-tuning on top of pre-trained
# convolutional weights.
base_model.compile(
    optimizer=optimizers.SGD(lr=0.01, decay=1e-4, momentum=0.9, nesterov=True),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
# Callbacks for fine-tuning.
# BUG FIX: ModelCheckpoint's first argument must be a FILE path, not a
# bare directory — Keras calls model.save(filepath) and writing an HDF5
# file to './' fails at the end of the first epoch. Use a templated
# filename so each epoch's checkpoint is kept (save_best_only=False).
check = ModelCheckpoint('./checkpoint.{epoch:02d}-{val_loss:.2f}.h5',
                        monitor='val_loss',
                        verbose=0,
                        save_best_only=False,
                        save_weights_only=False,
                        mode='auto',
                        period=1)
# Stop training as soon as val_loss fails to improve (patience=0).
stop = EarlyStopping(monitor='val_loss',
                     min_delta=0,
                     patience=0,
                     verbose=0,
                     mode='auto')
# Fine-tune: train on the directory generators, validating each epoch.
# steps per epoch default to len(generator) since DirectoryIterator is
# a keras Sequence.
base_model.fit_generator(
    generator=train_generator,
    epochs=5,
    verbose=1,
    validation_data=test_generator,
    shuffle=True,
    callbacks=[check, stop],
)
# BUG FIX: the trained model object is named `base_model`; the original
# `model.save_weights(...)` raised NameError (no `model` was defined).
base_model.save_weights('fine_tuned_net.h5')