pyspark: Binary Classification with a Decision Tree

# First, set the data path; it depends on how pyspark was started
# (local mode reads from the local filesystem, otherwise from HDFS)
global Path
if sc.master[0:5] == 'local':
    Path = 'file:/home/swt/pythonwork/PythonProject/'
else:
    Path = 'hdfs://localhost:9000/user/swt/'
import numpy as np
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree

# Load the data
print('load data...')
rawDataWithHeader = sc.textFile(Path + 'data/train.tsv')
# drop the header row, strip the quotation marks, and split on tabs
header = rawDataWithHeader.first()
rawData = rawDataWithHeader.filter(lambda x: x != header)
rData = rawData.map(lambda x: x.replace("\"", ""))
lines = rData.map(lambda x: x.split("\t"))
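# A peek at one parsed record clarifies the extraction functions below:
# field 0 is the URL, field 3 is the page category, fields 4 onward are
# numerical columns, and the last field is the 0/1 label (evergreen or not)
lines.first()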

def extract_features(field, categoriesMap, featureEnd):
    # one-hot encode the category in column 3
    categoryIdx = categoriesMap[field[3]]
    categoryFeatures = np.zeros(len(categoriesMap))
    categoryFeatures[categoryIdx] = 1
    # columns 4..featureEnd are numerical features; "?" means missing
    numericalFeatures = [convert_float(x) for x in field[4:featureEnd]]
    return np.concatenate((categoryFeatures, numericalFeatures))

def convert_float(x):
    # treat missing values ("?") as 0
    return (0 if x == "?" else float(x))

def extract_label(field):
    # the 0/1 label is the last column
    label = field[-1]
    return float(label)

# build a map from each distinct category string (column 3) to a column index
categoriesMap = lines.map(lambda fields: fields[3]).distinct().zipWithIndex().collectAsMap()
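# For reference, categoriesMap is a plain dict from category string to
# one-hot column index, something like {'business': 0, 'recreation': 1, ...}
# (the exact strings and indices depend on your copy of train.tsv), so a
# 'recreation' row gets a one-hot block [0, 1, 0, ...] ahead of its
# numerical columns
print(categoriesMap)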

# MLlib's tree algorithms expect LabeledPoint records, so convert every row
labelpointRDD = lines.map(lambda r: LabeledPoint(extract_label(r), extract_features(r, categoriesMap, len(r)-1)))
labelpointRDD.take(1)
[LabeledPoint(0.0, [1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.789131,2.055555556,0.676470588,0.205882353,0.047058824,0.023529412,0.443783175,0.0,0.0,0.09077381,0.0,0.245831182,0.003883495,1.0,1.0,24.0,0.0,5424.0,170.0,8.0,0.152941176,0.079129575])]
# Split the data into training, validation, and test sets
(trainData, validationData, testData) = labelpointRDD.randomSplit([8, 1, 1])
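# randomSplit normalizes the weights, so [8, 1, 1] yields roughly an
# 80/10/10 split; pass a seed (e.g. randomSplit([8, 1, 1], seed=42))
# if you need the split to be reproducible across runs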
# Cache all three splits so they are not recomputed on every action
trainData.persist()
validationData.persist()
testData.persist()
# Train the decision tree
model = DecisionTree.trainClassifier(trainData, numClasses=2, categoricalFeaturesInfo={}, impurity='entropy', maxDepth=5, maxBins=5)
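# Optional: print the learned split rules to sanity-check the tree
# (output omitted here)
print(model.toDebugString())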

def predictData(sc, model, categoriesMap):
    print('load data...')
    rawDataWithHeader = sc.textFile(Path + 'data/test.tsv')
    header = rawDataWithHeader.first()
    rawData = rawDataWithHeader.filter(lambda x: x != header)
    rData = rawData.map(lambda x: x.replace("\"", ""))
    lines = rData.map(lambda x: x.split('\t'))
    print("sum is " + str(lines.count()))
    # test.tsv has no label column, so the features run to len(r)
    dataRDD = lines.map(lambda r: (r[0], extract_features(r, categoriesMap, len(r))))
    DescDict = {
        0: 'ephemeral',
        1: 'evergreen'
    }
    for data in dataRDD.take(10):
        predictResult = model.predict(data[1])
        print(str(data[0]) + '\n' + '====>predict' + str(predictResult) + ' is ' + str(DescDict[predictResult]) + '\n')
print("=======predict========")
predictData(sc,model,categoriesMap)
=======predict========
load data...
sum is 3171
http://www.lynnskitchenadventures.com/2009/04/homemade-enchilada-sauce.html
====>predict1.0 is evergreen

http://lolpics.se/18552-stun-grenade-ar
====>predict0.0 is ephemeral

http://www.xcelerationfitness.com/treadmills.html
====>predict0.0 is ephemeral

http://www.bloomberg.com/news/2012-02-06/syria-s-assad-deploys-tactics-of-father-to-crush-revolt-threatening-reign.html
====>predict0.0 is ephemeral

http://www.wired.com/gadgetlab/2011/12/stem-turns-lemons-and-limes-into-juicy-atomizers/
====>predict0.0 is ephemeral

http://www.latimes.com/health/boostershots/la-heb-fat-tax-denmark-20111013,0,2603132.story
====>predict0.0 is ephemeral

http://www.howlifeworks.com/a/a?AG_ID=1186&cid=7340ci
====>predict0.0 is ephemeral

http://romancingthestoveblog.wordpress.com/2010/01/13/sweet-potato-ravioli-with-lemon-sage-brown-butter-sauce/
====>predict1.0 is evergreen

http://www.funniez.net/Funny-Pictures/turn-men-down.html
====>predict0.0 is ephemeral

http://youfellasleepwatchingadvd.com/
====>predict0.0 is ephemeral
# Evaluate the model: compute AUC on the validation set
score = model.predict(validationData.map(lambda p: p.features))
scoreAndLabels = score.zip(validationData.map(lambda p: p.label))

from pyspark.mllib.evaluation import BinaryClassificationMetrics
metrics = BinaryClassificationMetrics(scoreAndLabels)
print("AUC="+str(metrics.areaUnderROC))
AUC=0.6577085453395308
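# An AUC of ~0.66 leaves room for improvement. A possible next step
# (a minimal sketch, not part of the run above; it reuses the DecisionTree
# and BinaryClassificationMetrics imports from earlier): wrap training plus
# AUC evaluation in a helper and sweep one hyperparameter at a time.
def trainEvaluateModel(trainData, validationData, impurity, maxDepth, maxBins):
    model = DecisionTree.trainClassifier(trainData, numClasses=2, categoricalFeaturesInfo={},
                                         impurity=impurity, maxDepth=maxDepth, maxBins=maxBins)
    score = model.predict(validationData.map(lambda p: p.features))
    scoreAndLabels = score.zip(validationData.map(lambda p: p.label))
    return BinaryClassificationMetrics(scoreAndLabels).areaUnderROC

for depth in [3, 5, 10, 15]:
    auc = trainEvaluateModel(trainData, validationData, 'entropy', depth, 5)
    print('maxDepth=' + str(depth) + ' AUC=' + str(auc))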
