机器学习-K近邻算法

k-近邻(kNN)分类算法体现了“物以类聚”的思想:在已知分类的点中找出与未知分类的点距离最近的前 k 个点,再以这 k 个点中出现次数最多的分类作为未知点的预测分类。
机器学习-K近邻算法_第1张图片

机器学习-K近邻算法_第2张图片

kNN.py

from numpy import  *
import operator
def createDataSet():
    """Build the tiny toy training set used by the demo.

    Returns:
        (group, labels): ``group`` is a 4x2 float array whose rows are the
        sample coordinates; ``labels`` is a list holding the class name
        ('A' or 'B') of the row at the same position.
    """
    points = [
        [1.0, 1.1],
        [1.0, 1.0],
        [0, 0],
        [0, 0.1],
    ]
    classes = ['A', 'A', 'B', 'B']
    group = array(points)
    return group, classes

def classify0(inX, dataSet, labels, k):
    """Classify ``inX`` by a majority vote among its ``k`` nearest training rows.

    Args:
        inX: the point to classify; one value per column of ``dataSet``.
        dataSet: 2-D array of training samples, one sample per row.
        labels: class label of each row of ``dataSet``, indexable by row number.
        k: number of nearest neighbours that take part in the vote.

    Returns:
        The label occurring most often among the ``k`` rows closest to ``inX``
        (Euclidean distance). On a tie in vote count, the label that was first
        seen while walking neighbours from closest to farthest wins.

    Raises:
        IndexError: if ``k`` is larger than the number of rows, or less than 1.
    """
    rowCount = dataSet.shape[0]
    # Replicate inX once per training row, then take row-wise Euclidean distances.
    offsets = tile(inX, (rowCount, 1)) - dataSet
    distances = (offsets ** 2).sum(axis=1) ** 0.5
    nearestFirst = distances.argsort()

    # Labels of the k closest rows, closest first (indexing past the end raises).
    neighbourLabels = [labels[nearestFirst[rank]] for rank in range(k)]

    tally = {}
    for label in neighbourLabels:
        tally[label] = tally.get(label, 0) + 1

    # sorted() is stable even with reverse=True, so among equal counts the
    # label inserted first (i.e. seen at a closer neighbour) stays in front.
    byVotes = sorted(tally.items(), key=operator.itemgetter(1), reverse=True)
    winner, _votes = byVotes[0]
    return winner

def file2matrix(filename):
    """Load a tab-separated data file into a feature matrix and a label list.

    Each non-blank line holds at least three numeric feature columns and an
    integer class label in its last column, all separated by tab characters,
    e.g. ``40920<TAB>8.326976<TAB>0.953952<TAB>3``.

    Args:
        filename: path of the text file to read.

    Returns:
        (returnMat, classLabelVector): ``returnMat`` is an (n, 3) float array
        built from the first three columns of the n non-blank lines;
        ``classLabelVector`` is the list of the n integer labels taken from
        the last column.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a field is not numeric or a line has fewer than three
            feature columns.
    """
    # ``with`` guarantees the handle is closed even if parsing fails.
    with open(filename) as fr:
        strippedLines = [line.strip() for line in fr]
    # Blank lines (e.g. a trailing empty line) carry no sample; skipping them
    # avoids a float-conversion error on the empty string.
    dataLines = [line for line in strippedLines if line]

    returnMat = zeros((len(dataLines), 3))
    classLabelVector = []
    for index, line in enumerate(dataLines):
        listFromLine = line.split('\t')
        # Convert explicitly instead of relying on numpy's str->float casting.
        returnMat[index, :] = [float(value) for value in listFromLine[0:3]]
        classLabelVector.append(int(listFromLine[-1]))
    return returnMat, classLabelVector

def autoNorm(dataSet):
    """Min-max scale every column of ``dataSet`` into the range [0, 1].

    Each value becomes ``(value - column_min) / column_range``.

    Args:
        dataSet: numeric array, one sample per row.

    Returns:
        (normDataSet, ranger, minVals): the scaled array, the per-column
        divisor that was used, and the per-column minimum. New samples can be
        scaled the same way with ``(x - minVals) / ranger``.

    A column whose values are all equal has a range of 0. Its divisor is set
    to 1, so the column normalises to 0 instead of NaN (0/0). The returned
    ``ranger`` carries that 1, so that new samples are scaled consistently.
    """
    minVals = dataSet.min(0)
    maxVals = dataSet.max(0)
    ranger = maxVals - minVals
    # Guard against division by zero for constant columns.
    ranger = where(ranger == 0, 1, ranger)
    normDataSet = (dataSet - minVals) / ranger
    return normDataSet, ranger, minVals

kNNTest.py

from numpy import array

import  kNN
# pip3 install matplotlib
import matplotlib
import matplotlib.pyplot as plt
# Demo 1: classify the point [0, 0] against the four-point toy data set with k=3.
group,labels = kNN.createDataSet()
print(group)
print(labels)
# Despite its name, sortedClassCount holds the single predicted label that
# classify0 returns, not a sorted list of counts.
sortedClassCount = kNN.classify0([0,0],group,labels,3)
print(sortedClassCount)
# reload(kNN) re-loads the already-imported kNN module so the latest version of
# its code is used. This is common while developing: after editing the module,
# the change takes effect without restarting the interpreter.
# (reload is a Python 2 builtin; in Python 3 use importlib.reload(kNN).)
# Demo 2: load the dating data set. The path is relative, so the file is looked
# up in the current working directory.
datingDataMat,datingLabels = kNN.file2matrix('datingTestSet2.txt')
print(datingDataMat,datingLabels)
# Scatter plot of feature column 1 (x) against feature column 2 (y).
# The third and fourth positional arguments are scatter's marker size (s) and
# colour (c). Both are scaled by the class label so the classes look different.
fig=plt.figure()
ax=fig.add_subplot(111)
ax.scatter(datingDataMat[:,1],datingDataMat[:,2],15.0*array(datingLabels),15.0*array(datingLabels))
plt.show()

datingTestSet2.txt

40920	8.326976	0.953952	3
14488	7.153469	1.673904	2
26052	1.441871	0.805124	1

http://www.manning.com/MachineLearninginAction
http://cnblogs.com/rougan/p/9295659.html

你可能感兴趣的:(python,机器学习,近邻算法,人工智能)