Python SVM 手写数字识别 流程

Python SVM 手写数字识别 流程

数据采集和标志

scikit-learn 自带了一些数据集,其中一个是数字识别图片的数据。

%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets
digits = datasets.load_digits() # 加载数据
# 把数据所代表的图片显示出来
images_and_labels = list(zip(digits.images, digits.target))
plt.figure(figsize=(8, 6))
for index, (image, label) in enumerate(images_and_labels[:8]):
    plt.subplot(2, 4, index + 1)
    plt.axis('off')
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.title('Digit: %i' % label, fontsize=20);

特征选择

\qquad 针对一个手写的图片数据,应该怎么样选择特征呢?一个直观的方法是,直接使用图片的每个像素点作为一个特征。比如一个图片 200 × 200 200 \times 200 200×200 的分辨率,那么我们就有 40000 个特征,即特征向量的长度是 40000。
\qquad 实际上,scikit-learn 使用 Numpy 的 array 对象来表示数据,所有的图片数据保存在 digits.images里,每个元素都是一个 8 × 8 8 \times 8 8×8 尺寸的灰阶图片。我们在进行机器学习时,需要把数据保存为 样本个数 × \times × 特征个数 格式的 array 对象,针对手写识别这个案例, scikit-learn 已经为我们转换好了,它就保存在digit.data 数据里,可以通过 digits.data.shape 来查看它的数据格式为:

print("shape of raw image data: {0}".format(digits.images.shape))
print("shape of data: {0}".format(digits.data.shape))

\qquad 输出为:

shape of raw image data: (1797, 8, 8)
shape of data: (1797, 64)

数据清洗

\qquad 人们不可能在 8 × 8 8 \times 8 8×8 这么小的分辨率的图片上写出数字,在采集数据的时候,是让用户在一个大图片上写出这些数字,如果图片是 200 × 200 200 \times 200 200×200 分辨率,那么一个训练样例就有 40000 个特征,计算量将是巨大的。为了减少计算量,也为了模型的稳定性,我们需要把 200 × 200 200 \times 200 200×200 的图片缩小为 8 × 8 8\times 8 8×8 的图片。这个过程就是数据清洗,即把采集到的、不适用来做机器学习训练的数据进行预处理,从而转换为适合机器学习的数据。

模型选择

\qquad 不同的机器学习算法模型针对特定的机器学习应用有不同的效率,我们使用 支持向量机 来作为手写识别算法的模型。

模型训练

\qquad 在开始我们的模型训练之前,我们需要先把数据集分成 训练数据集测试数据集

# 把数据分成训练数据集和测试数据集
from sklearn.model_selection import train_test_split
Xtrain, Xtest, Ytrain, Ytest = train_test_split(digits.data, digits.target, test_size=0.20, random_state=2);

\qquad 接着,使用训练数据集 Xtrain 和 Ytrain 来训练模型。

# 使用支持向量机来训练模型
from sklearn import svm
clf = svm.SVC(gamma=0.001, C=100., probability=True)
clf.fit(Xtrain, Ytrain);

\qquad 训练完成后,clf 对象就会包含我们训练出来的模型参数,可以使用这个模型对象来进行预测。

模型测试

\qquad 我们来测试一下训练出来的模型的准确度。一个直观的方法是,我们用训练出来的模型 clf 预测测试数据集,然后把预测结果 Ypred 和真正的结果 Ytest 比较,看看多少个是正确的,这样就能评估出模型的准确度了。所幸,scikit-learn 提供了现成的方法来完成这箱工作:

# 评估模型的准确度
from sklearn.metrics import accuracy_score
Ypred = clf.predict(Xtest);
accuracy_score(Ytest, Ypred)
0.9777777777777777
clf.score(Xtest, Ytest)
0.9777777777777777

\qquad 除此之外,还可以直接把测试数据集里面的部分图片显示出来,并且在图片的左下角显示预测值,右下角显示真实值。

# 查看预测的情况
fig, axes = plt.subplots(4, 4, figsize=(8, 8))
fig.subplots_adjust(hspace=0.1, wspace=0.1)

for i, ax in enumerate(axes.flat):
    ax.imshow(Xtest[i].reshape(8, 8), cmap=plt.cm.gray_r, interpolation='nearest')
    ax.text(0.05, 0.05, str(Ypred[i]), fontsize=32,
            transform=ax.transAxes,
            color='green' if Ypred[i] == Ytest[i] else 'red')
    ax.text(0.8, 0.05, str(Ytest[i]), fontsize=32,
            transform=ax.transAxes,
            color='black')
    ax.set_xticks([])
    ax.set_yticks([])

Python SVM 手写数字识别 流程_第1张图片

# Xtest[4] 的各种可能性
clf.predict_proba(Xtest[4].reshape(1, -1))
array([[0.00430518, 0.02350992, 0.01561862, 0.01075997, 0.03301238,
        0.02790444, 0.00478012, 0.15184996, 0.58209275, 0.14616666]])

模型保存与加载

\qquad 当我们对模型的准确度感到满意后,我们就可以把模型保存下来。这样下次需要预测时,可以直接加载模型来进行预测,而不是重新训练一次模型。可以使用下面的代码来保存模型:

# 保存模型参数
from sklearn.externals import joblib
joblib.dump(clf, 'digits_svm.pkl');

\qquad 当我们需要这个模型来进行预测时,直接加载模型即可进行预测。

# 导入模型参数,直接进行预测
clf = joblib.load('digits_svm.pkl')
Ypred = clf.predict(Xtest);
clf.score(Xtest, Ytest)

\qquad 笔者电脑上的输出结果是:

0.9777777777777777

你可能感兴趣的:(Python,机器学习)