【TensorFlow 2】使用简单的CNN网络实现图像分类

不废话,直接上代码…

一、数据读取和构建数据集

import os
import pandas as pd
import numpy as np
import tensorflow as tf
import warnings
warnings.filterwarnings('ignore')
from tensorflow.keras.preprocessing import image_dataset_from_directory

train_dir = '/home/workspace/train/'

# Only one labelled dataset is available, so the validation set has to be
# carved out of the training directory.
#
# image_dataset_from_directory shuffles by default and reshuffles on every
# iteration, so splitting the resulting dataset with take()/skip() does not
# give a fixed boundary: samples can move between the two parts from one
# epoch to the next. Letting the loader do the split (validation_split +
# subset, with the same seed in both calls) keeps the two subsets disjoint.
split_seed = 42
train_dataset = image_dataset_from_directory(
    train_dir,
    validation_split=0.2,
    subset='training',
    seed=split_seed,
    batch_size=8,
    image_size=(64, 64),
)
valid_dataset = image_dataset_from_directory(
    train_dir,
    validation_split=0.2,
    subset='validation',
    seed=split_seed,
    batch_size=8,
    image_size=(64, 64),
)
print(train_dataset.class_names)

# Total number of batches over both subsets; kept so that any later code
# reading `batch_num` still finds it.
batch_num = train_dataset.cardinality() + valid_dataset.cardinality()

二、搭建CNN网络、训练并保存

# Small CNN: three Conv2D + MaxPooling2D stages followed by a dense head.
# NOTE(review): the network contains no rescaling layer, so it is trained on
# whatever pixel range the input pipeline delivers - confirm that is intended.
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(64, 64, 3)),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(2, 2),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    # Two output classes with softmax, which is why the loss below is
    # SparseCategoricalCrossentropy.
    tf.keras.layers.Dense(2, activation='softmax'),
])

model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(1e-4),
    metrics=['accuracy'],
)

# Model.fit accepts a tf.data.Dataset directly. fit_generator is deprecated
# and no longer exists in recent TensorFlow/Keras releases.
history = model.fit(
    train_dataset,
    epochs=10,
    validation_data=valid_dataset,
)

model.save('model_1.h5')

三、加载模型并预测测试集

def get_and_decode_img(img_path):
    """Load one image file and turn it into a single-image batch for the model.

    Args:
        img_path: Path of the image file to read (decoded as JPEG).

    Returns:
        A float32 tensor of shape (1, 64, 64, 3). Pixel values are not
        rescaled after decoding.
    """
    img = tf.io.read_file(img_path)
    img_tensor = tf.image.decode_jpeg(img, channels=3)
    img_tensor = tf.image.resize(img_tensor, [64, 64])
    # Add a leading batch axis so every dataset element is a batch of one.
    img_tensor = tf.expand_dims(img_tensor, axis=0)
    img_tensor = tf.cast(img_tensor, dtype=tf.float32)
    return img_tensor

# Reload the trained network from disk.
model = tf.keras.models.load_model('model_1.h5')

# File names in the test directory, and the full path of each one.
test_data_name_list = os.listdir('/home/workspace/test/')
test_data_path_list = [
    os.path.join('/home/workspace/test', file_name)
    for file_name in test_data_name_list
]

# Each path is decoded into a single-image batch by get_and_decode_img.
path_dataset = tf.data.Dataset.from_tensor_slices(test_data_path_list)
test_dataset = path_dataset.map(get_and_decode_img)

# The predicted class is the index of the largest output per image.
pred = model.predict(test_dataset)
pred_label = np.argmax(pred, axis=-1)

# NOTE(review): assumes index 0 is 'cat' and index 1 is 'dog', matching the
# class order printed during training - confirm.
index_to_class = {0: 'cat', 1: 'dog'}
test_df = pd.DataFrame({'name': test_data_name_list, 'label': pred_label})
test_df['label'] = test_df['label'].map(index_to_class)
# The true label is the part of the file name before the first '.'.
test_df['true_label'] = [name.split('.')[0] for name in test_df['name']]

from sklearn.metrics import classification_report
print(classification_report(test_df['true_label'], test_df['label']))

你可能感兴趣的:(CV,图像分类)