digit classification


最近开始学习machine learning方面的内容,大致浏览了一遍《machine learning in action》一书,大概了解了一些常用的算法如knn,svm等具体式干啥的。

在kaggle上看到一个练手的项目:digit classification,又有良好的数据,于是打算用这个项目把各种算法都跑一遍,加深自己对各算法的研究,该文会不断更新。。。。。。


我们的数据集是mnist,链接:http://yann.lecun.com/exdb/mnist/

mnist的结构如下,选取train-images

TRAINING SET IMAGE FILE (train-images-idx3-ubyte):

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  60000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel

呐,由于智商比较拙急,看了这个形式居然没看明白,所以特此写出要注意的点,送给同样没看明白的童鞋。

首先该数据是以二进制存储的,我们读取的时候要以'rb'方式读取,其次,真正的数据只有[value]这一项,其他的[type]等只是来描述的,并不真正在数据文件里面。

由offset我们可以看出真正的pixel式从16开始的,一个int 32字节,所以在读取pixel之前我们要读取4个 32 bit integer,也就是magic number,number of images,number of rows,number of columns,读取二进制文件用struct比较方便,struct.unpack_from('>IIII',buf,index)表示按照大端方式读取4个int.

虽然数据集网站写着“Users of Intel processors and other low-endian machines must flip the bytes of the header.”,而我的电脑就是intel处理器,但是我尝试了一把还是得用大端方式读,读出来才是“2051 60000 28 28”,用小端方式读取就不正确了,这个小小实验一把就行。


下面先把数据文件直观的表现出来,用matplotlib把二进制文件用图像表现出来。具体如下:


# -*- coding:utf-8
import numpy as np 
import struct
import matplotlib.pyplot as plt 

filename = 'train-images.idx3-ubyte'
binfile = open(filename,'rb')#以二进制方式打开
buf = binfile.read()

index = 0
magic, numImages, numRows, numColums = struct.unpack_from('>IIII',buf,index)#读取4个32 int
print magic,' ',numImages,' ',numRows,' ',numColums
index += struct.calcsize('>IIII')


im = struct.unpack_from('>784B',buf,index)#每张图是28*28=784Byte,这里只显示第一张图
index += struct.calcsize('>784B' )

im = np.array(im)
im = im.reshape(28,28)
print im

fig = plt.figure()
plt.imshow(im,cmap = 'binary')#黑白显示
plt.show()

结果如下(很明显式一个5字):

digit classification_第1张图片

好,既然能够图形化表示数字,基本上该数字的feature我们都能拿到,一下就是用各种机器学习算法来做digit classification。

以上是显示一个图片,而我们的各种学习算法一定是处理多个图片的,那么有必要单独写一个文件来处理数据问题,即

ReadData.py:(代码和上面不少相似,这里只给链接)

https://github.com/wangyuquan/MachineLearning/blob/master/digitClassification/ReadData.py

(1)KNN:

      KNN可以说是最简单的分类器了。首先给予训练数据,并且给予这些训练数据的类别,然后当有一个新的数据进来的时候,根据最临近的k个节点的分类情况,例如k=5,就

取离新数据最近的5个点,假设3个类别是A,2个是B,按照多数服从少数的原则,将新数据赋予类别A.

具体代码如下:

from sklearn import neighbors
from ReadData import *
import datetime
startTime = datetime.datetime.now();
trainImage = getTrainImage()
trainLabel = getTrainLabel()
testImage = getTestImage()
testLabel = getTestLabel()
knn = neighbors.KNeighborsClassifier(algorithm = 'auto',leaf_size = 30,n_neighbors=3,warn_on_equidistant = True,weights = 'uniform')
knn.fit(trainImage,trainLabel)
match = 0;
for i in xrange(len(testLabel)):
	predictLabel = knn.predict(testImage[i])[0]
	print i,' ',
	print predictLabel,' ', testLabel[i]
	if(predictLabel==testLabel[i]):
		match += 1

endTime = datetime.datetime.now()
print 'use time: '+str(endTime-startTime)
print 'error rate: '+ str(1-(match*1.0/len(testLabel)))

效果如下:


错误率为3.91%,这里我的k选的是3,接下来我会尝试不同的k,来看不同k带来的不同影响。

优点:简单易懂,实现方便(指具体算法实现,当然我们在用的时候都是直接调用库啦:))

缺点:算法复杂度为K*M(train data的大小)*N(test data的大小),耗时较长,因为每个新数据都要与所以training data计算下距离,大数据情况下可能吃不消


(2)SVM:

svm可以写的还真的挺多的,这里只写一些比较重要的,并且需要记住的几点

首先svm本质上是一个二元分类器,但是也可以扩展来出来多元分类(先一分为二,剩下的再分,还有DAG等方法)。svm可以处理线性可分或者不可分的数据:线性可分的话,举个例子在二维平面数据可以被一条直线一分为二,主要工作就是找这条直线(维数多了就是超平面),然后通过优化方法找到直线或者超平面,使得margin最大(margin指点到分割面的最短距离),对于线性不可分的数据来说,采用升维的方法,也就是大家指的核函数,把数据变成基本可分,引入松弛变量来处理某些噪声点,也就是所谓的离群点,用惩罚因子来代表我们对离群点的重视度,越大越重视。至今还没真正理解svm的算法复杂度,等待进一步的学习。以下代码还是直接调用库函数,有空自己写一个简单版的.(这个没有进行多种调参,因为只是想学习svm的原理)

具体代码如下:

from sklearn import svm
from ReadData import *
import datetime
startTime = datetime.datetime.now();
trainImage = getTrainImage()
trainLabel = getTrainLabel()
testImage = getTestImage()
testLabel = getTestLabel()
clf = svm.SVC(kernel='rbf')
clf.fit(trainImage,trainLabel)
match = 0
for i in range(len(testImage)):
	predictResult = int(clf.predict(testImage[i])[0])
	if(predictResult==testLabel[i]):
		match += 1
	print i,' ',predictResult ,' ',testLabel[i]

endTime = datetime.datetime.now()
print 'match:',match
print 'use time: '+str(endTime-startTime)
print 'error rate: '+ str(1-(match*1.0/len(testImage)))

结果如下:


3:Navie Bayes

贝叶斯分类器均建立在贝叶斯公式上:P(x|y) = P(y|x)*P(x)/P(y)

在分类问题中,x代表类别,y代表特征向量,P(x|y)很难求,但是可以通过等号后面几者来求,P(y)是确定的,可以忽略,求某个特征向量最可能属于哪个类是,只要比出哪个类的概率最大就行,在比较是分母都一样,所以可以忽略,只需要求 P(y|x)*P(x),P(x)表示类别的概率, P(y|x)表示该类下特征向量是y的概率,这个有点难求。但是朴素贝叶斯假定,每个feature之间相互无关,即 P(y|x) = P(y1|x)*P(y2|x)****P(yn|x),这样就能求得该变量,该假设十分简单有效。

具体代码如下:

from sklearn.naive_bayes import MultinomialNB
from ReadData import *
import datetime
startTime = datetime.datetime.now();
trainImage = getTrainImage()
trainLabel = getTrainLabel()
testImage = getTestImage()
testLabel = getTestLabel()
mnb = MultinomialNB()
mnb.fit(trainImage,trainLabel)
match = 0
for i in range(len(testImage)):
	predictResult = int(mnb.predict(testImage[i])[0])
	if(predictResult==testLabel[i]):
		match += 1
	print i,' ',predictResult ,' ',testLabel[i]

endTime = datetime.datetime.now()
print 'match:',match
print 'use time: '+str(endTime-startTime)
print 'error rate: '+ str(1-(match*1.0/len(testImage)))


 结果如图: 
  


错误率较高,原因很明显,我们将每个像素作为feature,这些feature之间肯定是有关的,不能假设为独立

navie bayes适合各个feature独立的分类:)

(4):decision tree

决策树是一个很能直观理解的分类过程。不需要额外的参数。具体怎么构造决策树,选取哪个feature作为决策点,主要算法有ID3和C4.5。

具体代码如下:

from sklearn import tree
from ReadData import *
import datetime
startTime = datetime.datetime.now();
trainImage = getTrainImage()
trainLabel = getTrainLabel()
testImage = getTestImage()
testLabel = getTestLabel()
clf = tree.DecisionTreeClassifier()
clf.fit(trainImage,trainLabel)
match = 0
for i in range(len(testImage)):
	predictResult = int(clf.predict(testImage[i])[0])
	if(predictResult==testLabel[i]):
		match += 1
	print i,' ',predictResult ,' ',testLabel[i]

endTime = datetime.datetime.now()
print 'match:',match
print 'use time: '+str(endTime-startTime)
print 'error rate: '+ str(1-(match*1.0/len(testImage)))

结果如下:

结果也不是特别令人满意,原因和上一个差不多,我们构造的feature不太适合用决策树来分类。

(5)TODO:

你可能感兴趣的:(算法,python)