支持向量机（SVM）学习笔记。摘录了书中的关键知识点，并加入了一些自己的理解。
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Python 代码及数据集主要参照《机器学习实战》一书。
#author=altman
import numpy as np
import copy as cp
import matplotlib.pyplot as plt
'''
读取数据
'''
def readData(fileName):
    """Load 2-D samples and their labels from a whitespace-separated text file.

    Every non-blank line is parsed as floats; the first two columns are the
    features and the last column is the class label.

    Parameters
    ----------
    fileName : str
        Path of the data file (e.g. ``testSet.txt``).

    Returns
    -------
    matrix : list of [float, float]
        Feature rows, one per non-blank line.
    labels : list of float
        Class label of each row.
    """
    matrix = []
    labels = []
    # ``with`` closes the file even if parsing fails (the handle used to leak).
    with open(fileName) as fr:
        for line in fr:
            # split() with no argument accepts tabs or runs of spaces, so the
            # tab-separated originals and space-separated copies both parse.
            fields = line.split()
            if not fields:
                continue  # skip blank lines, e.g. a trailing newline at EOF
            data = [float(v) for v in fields]
            matrix.append([data[0], data[1]])
            labels.append(data[-1])
    return matrix, labels
def clipAlpha(aj, H, L):
    """Clip ``aj`` into the interval [L, H].

    The upper bound H is applied first and the lower bound L second, so when
    L > H the result is L.
    """
    capped = H if aj > H else aj
    return L if L > capped else capped
def kernelTrans(X, A, kTup):
    """Evaluate the kernel between every row of ``X`` and the single row ``A``.

    Parameters
    ----------
    X : np.matrix, shape (m, n)
        Samples, one per row.
    A : np.matrix, shape (1, n)
        The sample to compare against.
    kTup : sequence
        ``('lin', _)`` for the linear kernel x . a, or ``('rbf', sigma)`` for
        exp(-||x - a||^2 / sigma^2).

    Returns
    -------
    np.matrix, shape (m, 1)
        Kernel value for each row of ``X``.

    Raises
    ------
    ValueError
        If the kernel name is not recognised (an all-zero kernel used to be
        returned silently, which made training meaningless).
    """
    kind = kTup[0]
    if kind == 'lin':
        return X * A.T
    if kind == 'rbf':
        # Squared Euclidean distance of every row to A, computed in one
        # vectorized step instead of a Python loop over rows.
        diff = X - A
        sqDist = np.multiply(diff, diff).sum(axis=1)
        return np.exp(sqDist / (-1 * kTup[1] ** 2))
    raise ValueError("unrecognised kernel: %r" % (kind,))
'''
数据结构
'''
class svmStrcut(object):
    """Mutable state shared by the SMO solver functions.

    Attributes
    ----------
    C : box constraint, 0 <= alpha <= C
    data : training samples, one row per sample (``m`` rows)
    labels : class labels; the solver reads them as ``labels[0, k]``
    toler : tolerance for the KKT checks
    m : number of samples
    b : bias term, starts at 0.0
    alphas : (m, 1) matrix of Lagrange multipliers, all starting at 0
    eCache : (m, 2) error cache; column 0 is a "valid" flag, column 1 the
        cached error
    K : (m, m) kernel matrix, column i = kernel(data, data[i])
    """

    def __init__(self, data, labels, C, toler, kTup):
        self.C, self.toler = C, toler
        self.data, self.labels = data, labels
        self.m = self.data.shape[0]
        self.b = 0.0
        self.alphas = np.asmatrix(np.zeros((self.m, 1)))
        self.eCache = np.asmatrix(np.zeros((self.m, 2)))
        # Precompute the whole kernel matrix once, one column per sample.
        self.K = np.asmatrix(np.zeros((self.m, self.m)))
        for col in range(self.m):
            self.K[:, col] = kernelTrans(self.data, self.data[col, :], kTup)
'''
计算误差
error=g(x)-y
'''
def calcEk(sT, k):
    """Return the prediction error E_k = g(x_k) - y_k for sample ``k``.

    g(x_k) = sum_i alpha_i * y_i * K(x_i, x_k) + b, using the precomputed
    kernel column ``sT.K[:, k]``.
    """
    alphaY = np.multiply(sT.alphas, sT.labels.T)  # alpha_i * y_i per sample
    prediction = (alphaY.T * sT.K[:, k] + sT.b).item()
    return prediction - float(sT.labels[0, k])
'''
选择第二个要更新的 alpha（启发式：在误差缓存有效的样本中选使 |Ei - Ej| 最大的 j）
'''
def selectJ(i, sT, ei):
    """Choose the second alpha index ``j`` to optimise together with ``i``.

    Heuristic: among samples whose error-cache entry is marked valid, pick
    the one maximising |ei - ej| (the largest step). If no such sample gives
    a positive |ei - ej|, pick ``j != i`` uniformly at random.

    Parameters
    ----------
    i : int
        Index of the first alpha.
    sT : svmStrcut
        Solver state; ``sT.eCache[i]`` is overwritten with ``[1, ei]``.
    ei : float
        Current error of sample ``i``.

    Returns
    -------
    (j, ej)
        Chosen index and its current error.

    Raises
    ------
    ValueError
        If there are fewer than two samples (no valid ``j`` exists; this used
        to loop forever).
    """
    # Mark i's cache entry as valid and remember its error.
    sT.eCache[i] = [1, ei]
    validEcacheList = np.nonzero(np.asarray(sT.eCache[:, 0]))[0]
    if len(validEcacheList) > 1:
        maxK = -1
        maxDeltaE = 0
        ej = 0
        for k in validEcacheList:
            if k == i:
                continue
            ek = calcEk(sT, k)
            deltaE = abs(ei - ek)
            if deltaE > maxDeltaE:
                maxDeltaE = deltaE
                ej = ek
                maxK = k
        # If every |ei - ek| was 0, maxK stayed -1 (which would silently index
        # the last sample with ej=0); fall through to random choice instead.
        if maxK != -1:
            return maxK, ej
    if sT.m < 2:
        raise ValueError("SMO needs at least two samples to choose a second alpha")
    # np.random.uniform(sT.m) returned a float, not a usable index. Draw an
    # integer from the m-1 indices other than i: pick from [0, m-1) and shift
    # past i, so j != i without a retry loop.
    j = np.random.randint(0, sT.m - 1)
    if j >= i:
        j += 1
    return j, calcEk(sT, j)
def updateEk(sT, k):
    """Recompute the error of sample ``k`` and store it in the error cache.

    The cache row becomes ``[1, E_k]``, i.e. entry ``k`` is flagged valid.
    """
    sT.eCache[k] = [1, calcEk(sT, k)]
'''
内循环：选择第二个 alpha，并成对更新 alpha_i、alpha_j 和偏置 b
'''
def innerLoop(i, sT):
    """Try one SMO step on the pair (alpha_i, alpha_j) for sample ``i``.

    Parameters
    ----------
    i : int
        Index of the first alpha, chosen by the outer loop.
    sT : svmStrcut
        Solver state; ``sT.alphas``, ``sT.b`` and ``sT.eCache`` are updated
        in place.

    Returns
    -------
    int
        1 if a pair of alphas was changed, otherwise 0.
    """
    ei = calcEk(sT, i)
    yi = sT.labels[0, i]
    # Only optimise alpha_i if sample i violates the KKT conditions (beyond
    # ``toler``) and alpha_i can still move in the required direction.
    if not ((yi * ei < -sT.toler and sT.alphas[i] < sT.C) or
            (yi * ei > sT.toler and sT.alphas[i] > 0)):
        return 0
    # Choose the second alpha j based on ei.
    j, ej = selectJ(i, sT, ei)
    yj = sT.labels[0, j]
    # Remember the old alphas; they are needed for alpha_i and b below.
    alphaOldI = sT.alphas[i].copy()
    alphaOldJ = sT.alphas[j].copy()
    # Bounds [L, H] for the new alpha_j: both alphas must stay in [0, C]
    # while yi*alpha_i + yj*alpha_j stays constant.
    if yi != yj:
        L = max(0, sT.alphas[j] - sT.alphas[i])
        H = min(sT.C, sT.C + sT.alphas[j] - sT.alphas[i])
    else:
        L = max(0, sT.alphas[j] + sT.alphas[i] - sT.C)
        H = min(sT.C, sT.alphas[j] + sT.alphas[i])
    if L == H:
        print("L==H")
        return 0
    # eta = K_ii + K_jj - 2*K_ij; the analytic update needs eta > 0.
    eta = sT.K[i, i] + sT.K[j, j] - 2 * sT.K[i, j]
    if eta <= 0:
        print("eta<=0")
        return 0
    # Unconstrained optimum for alpha_j, then clipped into [L, H].
    sT.alphas[j] = alphaOldJ + yj * (ei - ej) / eta
    sT.alphas[j] = clipAlpha(sT.alphas[j], H, L)
    updateEk(sT, j)
    if abs(sT.alphas[j] - alphaOldJ) < 0.00001:
        # Step too small: restore alpha_j. alpha_i is not compensated on this
        # path, so keeping the tiny change would break sum(alpha * y) == 0.
        sT.alphas[j] = alphaOldJ
        updateEk(sT, j)
        print("j not move enough")
        return 0
    # Move alpha_i by the same amount in the opposite direction (scaled by
    # yi*yj) so that yi*alpha_i + yj*alpha_j is unchanged.
    sT.alphas[i] = alphaOldI + yi * yj * (alphaOldJ - sT.alphas[j])
    updateEk(sT, i)
    # Update the bias from whichever alpha is strictly inside (0, C).
    b1 = (sT.b - ei
          - yi * (sT.alphas[i] - alphaOldI) * sT.K[i, i]
          - yj * (sT.alphas[j] - alphaOldJ) * sT.K[i, j])
    b2 = (sT.b - ej
          - yi * (sT.alphas[i] - alphaOldI) * sT.K[i, j]
          - yj * (sT.alphas[j] - alphaOldJ) * sT.K[j, j])
    if (0 < sT.alphas[i]) and (sT.C > sT.alphas[i]):
        sT.b = b1
    elif (0 < sT.alphas[j]) and (sT.C > sT.alphas[j]):
        sT.b = b2
    else:
        sT.b = (b1 + b2) / 2.0
    return 1
def smoP(dataMatIn, classLabels, C, toler, maxIter, kTup=('lin', 0)):
    """Train an SVM with SMO (sequential minimal optimization).

    Each outer iteration first tries the non-bound sample (0 < alpha < C)
    with the largest |error|; if that changes nothing, it sweeps over all
    samples. Training stops when a full sweep changes no alpha pair or after
    ``maxIter`` iterations.

    Parameters
    ----------
    dataMatIn : sequence of sequences of float
        Training samples, one row per sample.
    classLabels : sequence of float
        Class label of each sample (+1 / -1 in the accompanying datasets).
    C : float
        Box constraint, 0 <= alpha <= C.
    toler : float
        Tolerance for the KKT checks.
    maxIter : int
        Maximum number of outer iterations.
    kTup : sequence
        Kernel description forwarded to ``kernelTrans``, e.g. ``('lin', 0)``
        or ``('rbf', sigma)``.

    Returns
    -------
    (b, alphas)
        Bias term and the (m, 1) matrix of Lagrange multipliers.
    """
    sT = svmStrcut(np.asmatrix(dataMatIn), np.asmatrix(classLabels), C, toler, kTup)
    count = 0
    while count < maxIter:
        alphaPairsChanged = 0
        # Non-bound samples, i.e. those with 0 < alpha < C. (This assignment
        # had been swallowed into the preceding comment, causing a NameError.)
        alphasArr = np.asarray(sT.alphas)
        nonBoundIs = np.nonzero((alphasArr > 0) * (alphasArr < C))[0]
        # For a non-bound sample KKT requires E = 0, so the worst violator is
        # the one with the largest |E| (negative errors count too).
        worstI = None
        worstErr = 0.0
        for i in nonBoundIs:
            error = abs(calcEk(sT, i))
            if error > worstErr:
                worstErr = error
                worstI = i
        if worstI is not None:
            alphaPairsChanged += innerLoop(worstI, sT)
        # Fall back to a full sweep over all samples.
        if alphaPairsChanged == 0:
            for i in range(sT.m):
                alphaPairsChanged += innerLoop(i, sT)
                print("fullSet, iter: %d i:%d, pairs changed %d" % (count, i, alphaPairsChanged))
        count += 1
        print("iteration number: %d" % count)
        if alphaPairsChanged == 0:
            break
    return sT.b, sT.alphas
#计算权值,w=sum(alpha*label*x)
def calcWs(alphas, dataArr, classLabels):
    """Return the primal weight vector w = sum_i alpha_i * y_i * x_i.

    Parameters
    ----------
    alphas : (m, 1) matrix of Lagrange multipliers (as returned by ``smoP``).
    dataArr : sequence of m samples with n features each.
    classLabels : sequence of m labels.

    Returns
    -------
    numpy.ndarray of shape (n, 1).
    """
    samples = np.asmatrix(dataArr)
    ys = np.asmatrix(classLabels).transpose()
    numSamples, numFeatures = np.shape(samples)
    weights = np.zeros((numFeatures, 1))
    for idx in range(numSamples):
        coef = alphas[idx] * ys[idx]
        weights += np.multiply(coef, samples[idx, :].T)
    return weights
def show(data, labels, w, b):
    """Scatter-plot the samples by class and draw the line w[0]*x + w[1]*y + b = 0.

    Parameters
    ----------
    data : 2-D samples indexed as ``data[i, 0]`` / ``data[i, 1]``.
    labels : class labels; entries equal to 1 get red edges, all others
        black edges.
    w : weight vector; only ``w[0]`` and ``w[1]`` are used.
    b : bias term.

    The line spans the smallest to the largest first coordinate. It is solved
    for y, so it divides by ``w[1]``.
    """
    posX, posY, negX, negY = [], [], [], []
    for idx, label in enumerate(labels):
        xs, ys = (posX, posY) if label == 1 else (negX, negY)
        xs.append(data[idx, 0])
        ys.append(data[idx, 1])
    plt.scatter(posX, posY, edgecolors='r')
    plt.scatter(negX, negY, edgecolors='k')
    lowX = np.min(data[:, 0])
    highX = np.max(data[:, 0])
    # y on the separating line at both ends of the x range.
    lowY = float(-(lowX * w[0] + b) / w[1])
    highY = float(-(highX * w[0] + b) / w[1])
    plt.plot([lowX, highX], [lowY, highY], '-g')
    plt.show()
def _rbfErrorRate(datMat, labelArr, sVs, weightsSV, b, k1):
    """Return the fraction of rows of ``datMat`` misclassified by the RBF model.

    ``weightsSV`` holds label * alpha for each support vector row in ``sVs``.
    An empty ``datMat`` yields 0.0 instead of dividing by zero.
    """
    m = np.shape(datMat)[0]
    if m == 0:
        return 0.0
    errorCount = 0
    for i in range(m):
        kernelEval = kernelTrans(sVs, datMat[i, :], ('rbf', k1))
        predict = kernelEval.T * weightsSV + b
        if np.sign(predict) != np.sign(labelArr[i]):
            errorCount += 1
    return float(errorCount) / m


def testRbf(k1=1.3):
    """Demo: train an RBF-kernel SVM and report training and test error rates.

    Trains on ``testSetRBF.txt`` (C=200, toler=0.0001, maxIter=10000) and
    evaluates on that file and on ``testSetRBF2.txt``.

    Parameters
    ----------
    k1 : float
        Width sigma of the RBF kernel.
    """
    dataArr, labelArr = readData('testSetRBF.txt')
    b, alphas = smoP(dataArr, labelArr, 200, 0.0001, 10000, ('rbf', k1))
    datMat = np.asmatrix(dataArr)
    labelMat = np.asmatrix(labelArr).transpose()
    # Support vectors are the samples with alpha > 0.
    svInd = np.nonzero(np.asarray(alphas) > 0)[0]
    sVs = datMat[svInd]
    labelSV = labelMat[svInd]
    print("there are %d Support Vectors" % np.shape(sVs)[0])
    # label * alpha per support vector; identical for every prediction, so
    # computed once.
    weightsSV = np.multiply(labelSV, alphas[svInd])
    trainRate = _rbfErrorRate(datMat, labelArr, sVs, weightsSV, b, k1)
    print("the training error rate is: %f" % trainRate)
    testArr, testLabels = readData('testSetRBF2.txt')
    testRate = _rbfErrorRate(np.asmatrix(testArr), testLabels, sVs, weightsSV, b, k1)
    print("the test error rate is: %f" % testRate)
def test():
    """Demo: train a linear SVM on testSet.txt and plot the separating line.

    Uses C=0.6, toler=0.0001 and at most 40 outer iterations.
    """
    samples, targets = readData('testSet.txt')
    bias, alphas = smoP(samples, targets, 0.6, 0.0001, 40, ('lin', 0))
    weights = np.array(calcWs(alphas, samples, targets))
    show(np.asmatrix(samples), np.array(targets), weights, bias)
def main():
    """Script entry point; currently a no-op (the demos are test() and testRbf())."""


if __name__ == '__main__':
    main()
testSet.txt数据集
3.542485 1.977398 -1
3.018896 2.556416 -1
7.551510 -1.580030 1
2.114999 -0.004466 -1
8.127113 1.274372 1
7.108772 -0.986906 1
8.610639 2.046708 1
2.326297 0.265213 -1
3.634009 1.730537 -1
0.341367 -0.894998 -1
3.125951 0.293251 -1
2.123252 -0.783563 -1
0.887835 -2.797792 -1
7.139979 -2.329896 1
1.696414 -1.212496 -1
8.117032 0.623493 1
8.497162 -0.266649 1
4.658191 3.507396 -1
8.197181 1.545132 1
1.208047 0.213100 -1
1.928486 -0.321870 -1
2.175808 -0.014527 -1
7.886608 0.461755 1
3.223038 -0.552392 -1
3.628502 2.190585 -1
7.407860 -0.121961 1
7.286357 0.251077 1
2.301095 -0.533988 -1
-0.232542 -0.547690 -1
3.457096 -0.082216 -1
3.023938 -0.057392 -1
8.015003 0.885325 1
8.991748 0.923154 1
7.916831 -1.781735 1
7.616862 -0.217958 1
2.450939 0.744967 -1
7.270337 -2.507834 1
1.749721 -0.961902 -1
1.803111 -0.176349 -1
8.804461 3.044301 1
1.231257 -0.568573 -1
2.074915 1.410550 -1
-0.743036 -1.736103 -1
3.536555 3.964960 -1
8.410143 0.025606 1
7.382988 -0.478764 1
6.960661 -0.245353 1
8.234460 0.701868 1
8.168618 -0.903835 1
1.534187 -0.622492 -1
9.229518 2.066088 1
7.886242 0.191813 1
2.893743 -1.643468 -1
1.870457 -1.040420 -1
5.286862 -2.358286 1
6.080573 0.418886 1
2.544314 1.714165 -1
6.016004 -3.753712 1
0.926310 -0.564359 -1
0.870296 -0.109952 -1
2.369345 1.375695 -1
1.363782 -0.254082 -1
7.279460 -0.189572 1
1.896005 0.515080 -1
8.102154 -0.603875 1
2.529893 0.662657 -1
1.963874 -0.365233 -1
8.132048 0.785914 1
8.245938 0.372366 1
6.543888 0.433164 1
-0.236713 -5.766721 -1
8.112593 0.295839 1
9.803425 1.495167 1
1.497407 -0.552916 -1
1.336267 -1.632889 -1
9.205805 -0.586480 1
1.966279 -1.840439 -1
8.398012 1.584918 1
7.239953 -1.764292 1
7.556201 0.241185 1
9.015509 0.345019 1
8.266085 -0.230977 1
8.545620 2.788799 1
9.295969 1.346332 1
2.404234 0.570278 -1
2.037772 0.021919 -1
1.727631 -0.453143 -1
1.979395 -0.050773 -1
8.092288 -1.372433 1
1.667645 0.239204 -1
9.854303 1.365116 1
7.921057 -1.327587 1
8.500757 1.492372 1
1.339746 -0.291183 -1
3.107511 0.758367 -1
2.609525 0.902979 -1
3.263585 1.367898 -1
2.912122 -0.202359 -1
1.731786 0.589096 -1
2.387003 1.573131 -1
testSetRBF.txt数据集
-0.214824 0.662756 -1.000000
-0.061569 -0.091875 1.000000
0.406933 0.648055 -1.000000
0.223650 0.130142 1.000000
0.231317 0.766906 -1.000000
-0.748800 -0.531637 -1.000000
-0.557789 0.375797 -1.000000
0.207123 -0.019463 1.000000
0.286462 0.719470 -1.000000
0.195300 -0.179039 1.000000
-0.152696 -0.153030 1.000000
0.384471 0.653336 -1.000000
-0.117280 -0.153217 1.000000
-0.238076 0.000583 1.000000
-0.413576 0.145681 1.000000
0.490767 -0.680029 -1.000000
0.199894 -0.199381 1.000000
-0.356048 0.537960 -1.000000
-0.392868 -0.125261 1.000000
0.353588 -0.070617 1.000000
0.020984 0.925720 -1.000000
-0.475167 -0.346247 -1.000000
0.074952 0.042783 1.000000
0.394164 -0.058217 1.000000
0.663418 0.436525 -1.000000
0.402158 0.577744 -1.000000
-0.449349 -0.038074 1.000000
0.619080 -0.088188 -1.000000
0.268066 -0.071621 1.000000
-0.015165 0.359326 1.000000
0.539368 -0.374972 -1.000000
-0.319153 0.629673 -1.000000
0.694424 0.641180 -1.000000
0.079522 0.193198 1.000000
0.253289 -0.285861 1.000000
-0.035558 -0.010086 1.000000
-0.403483 0.474466 -1.000000
-0.034312 0.995685 -1.000000
-0.590657 0.438051 -1.000000
-0.098871 -0.023953 1.000000
-0.250001 0.141621 1.000000
-0.012998 0.525985 -1.000000
0.153738 0.491531 -1.000000
0.388215 -0.656567 -1.000000
0.049008 0.013499 1.000000
0.068286 0.392741 1.000000
0.747800 -0.066630 -1.000000
0.004621 -0.042932 1.000000
-0.701600 0.190983 -1.000000
0.055413 -0.024380 1.000000
0.035398 -0.333682 1.000000
0.211795 0.024689 1.000000
-0.045677 0.172907 1.000000
0.595222 0.209570 -1.000000
0.229465 0.250409 1.000000
-0.089293 0.068198 1.000000
0.384300 -0.176570 1.000000
0.834912 -0.110321 -1.000000
-0.307768 0.503038 -1.000000
-0.777063 -0.348066 -1.000000
0.017390 0.152441 1.000000
-0.293382 -0.139778 1.000000
-0.203272 0.286855 1.000000
0.957812 -0.152444 -1.000000
0.004609 -0.070617 1.000000
-0.755431 0.096711 -1.000000
-0.526487 0.547282 -1.000000
-0.246873 0.833713 -1.000000
0.185639 -0.066162 1.000000
0.851934 0.456603 -1.000000
-0.827912 0.117122 -1.000000
0.233512 -0.106274 1.000000
0.583671 -0.709033 -1.000000
-0.487023 0.625140 -1.000000
-0.448939 0.176725 1.000000
0.155907 -0.166371 1.000000
0.334204 0.381237 -1.000000
0.081536 -0.106212 1.000000
0.227222 0.527437 -1.000000
0.759290 0.330720 -1.000000
0.204177 -0.023516 1.000000
0.577939 0.403784 -1.000000
-0.568534 0.442948 -1.000000
-0.011520 0.021165 1.000000
0.875720 0.422476 -1.000000
0.297885 -0.632874 -1.000000
-0.015821 0.031226 1.000000
0.541359 -0.205969 -1.000000
-0.689946 -0.508674 -1.000000
-0.343049 0.841653 -1.000000
0.523902 -0.436156 -1.000000
0.249281 -0.711840 -1.000000
0.193449 0.574598 -1.000000
-0.257542 -0.753885 -1.000000
-0.021605 0.158080 1.000000
0.601559 -0.727041 -1.000000
-0.791603 0.095651 -1.000000
-0.908298 -0.053376 -1.000000
0.122020 0.850966 -1.000000
-0.725568 -0.292022 -1.000000
testSetRBF2.txt数据集
0.676771 -0.486687 -1.000000
0.008473 0.186070 1.000000
-0.727789 0.594062 -1.000000
0.112367 0.287852 1.000000
0.383633 -0.038068 1.000000
-0.927138 -0.032633 -1.000000
-0.842803 -0.423115 -1.000000
-0.003677 -0.367338 1.000000
0.443211 -0.698469 -1.000000
-0.473835 0.005233 1.000000
0.616741 0.590841 -1.000000
0.557463 -0.373461 -1.000000
-0.498535 -0.223231 -1.000000
-0.246744 0.276413 1.000000
-0.761980 -0.244188 -1.000000
0.641594 -0.479861 -1.000000
-0.659140 0.529830 -1.000000
-0.054873 -0.238900 1.000000
-0.089644 -0.244683 1.000000
-0.431576 -0.481538 -1.000000
-0.099535 0.728679 -1.000000
-0.188428 0.156443 1.000000
0.267051 0.318101 1.000000
0.222114 -0.528887 -1.000000
0.030369 0.113317 1.000000
0.392321 0.026089 1.000000
0.298871 -0.915427 -1.000000
-0.034581 -0.133887 1.000000
0.405956 0.206980 1.000000
0.144902 -0.605762 -1.000000
0.274362 -0.401338 1.000000
0.397998 -0.780144 -1.000000
0.037863 0.155137 1.000000
-0.010363 -0.004170 1.000000
0.506519 0.486619 -1.000000
0.000082 -0.020625 1.000000
0.057761 -0.155140 1.000000
0.027748 -0.553763 -1.000000
-0.413363 -0.746830 -1.000000
0.081500 -0.014264 1.000000
0.047137 -0.491271 1.000000
-0.267459 0.024770 1.000000
-0.148288 -0.532471 -1.000000
-0.225559 -0.201622 1.000000
0.772360 -0.518986 -1.000000
-0.440670 0.688739 -1.000000
0.329064 -0.095349 1.000000
0.970170 -0.010671 -1.000000
-0.689447 -0.318722 -1.000000
-0.465493 -0.227468 -1.000000
-0.049370 0.405711 1.000000
-0.166117 0.274807 1.000000
0.054483 0.012643 1.000000
0.021389 0.076125 1.000000
-0.104404 -0.914042 -1.000000
0.294487 0.440886 -1.000000
0.107915 -0.493703 -1.000000
0.076311 0.438860 1.000000
0.370593 -0.728737 -1.000000
0.409890 0.306851 -1.000000
0.285445 0.474399 -1.000000
-0.870134 -0.161685 -1.000000
-0.654144 -0.675129 -1.000000
0.285278 -0.767310 -1.000000
0.049548 -0.000907 1.000000
0.030014 -0.093265 1.000000
-0.128859 0.278865 1.000000
0.307463 0.085667 1.000000
0.023440 0.298638 1.000000
0.053920 0.235344 1.000000
0.059675 0.533339 -1.000000
0.817125 0.016536 -1.000000
-0.108771 0.477254 1.000000
-0.118106 0.017284 1.000000
0.288339 0.195457 1.000000
0.567309 -0.200203 -1.000000
-0.202446 0.409387 1.000000
-0.330769 -0.240797 1.000000
-0.422377 0.480683 -1.000000
-0.295269 0.326017 1.000000
0.261132 0.046478 1.000000
-0.492244 -0.319998 -1.000000
-0.384419 0.099170 1.000000
0.101882 -0.781145 -1.000000
0.234592 -0.383446 1.000000
-0.020478 -0.901833 -1.000000
0.328449 0.186633 1.000000
-0.150059 -0.409158 1.000000
-0.155876 -0.843413 -1.000000
-0.098134 -0.136786 1.000000
0.110575 -0.197205 1.000000
0.219021 0.054347 1.000000
0.030152 0.251682 1.000000
0.033447 -0.122824 1.000000
-0.686225 -0.020779 -1.000000
-0.911211 -0.262011 -1.000000
0.572557 0.377526 -1.000000
-0.073647 -0.519163 -1.000000
-0.281830 -0.797236 -1.000000
-0.555263 0.126232 -1.000000