Implementing Gaussian Naive Bayes in Flink

A previous article covered multinomial naive Bayes, and Spark ML implements both multinomial and Bernoulli naive Bayes. In practice, however, the variables we deal with are not only discrete but also continuous. When applying naive Bayes to such data, we usually assume each variable follows a Gaussian distribution and compute the probabilities under that assumption.
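Concretely, given a feature vector $x = (x_1, \dots, x_n)$, Gaussian naive Bayes assigns the class that maximizes the prior times the product of per-feature normal densities (the standard decision rule, stated here for reference):

\[
\hat{y} = \arg\max_{c}\; P(c)\prod_{i=1}^{n}\frac{1}{\sqrt{2\pi\sigma_{c,i}^{2}}}\exp\!\left(-\frac{(x_i-\mu_{c,i})^{2}}{2\sigma_{c,i}^{2}}\right)
\]

where $\mu_{c,i}$ and $\sigma_{c,i}^{2}$ are the mean and variance of feature $i$ within class $c$, estimated from the training data. This is what the code below computes for each class.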

Flink Implementation

Here we use the iris dataset.

// A case class modeling one iris sample
/**
  * Created by WZZC on 2019/4/26
  **/
case class Iris(species: String, features: Vector[Double])

object Iris {
  // Parse a CSV line of the form "species,feature1,feature2,..."
  def apply(irisInfo: String): Iris = {
    val fields = irisInfo.split(",")
    new Iris(fields.head, fields.tail.map(_.toDouble).toVector)
  }
}
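As a quick check of the parser (the sample line below is illustrative, in the dataset's species-first format):

val sample = Iris("Iris-setosa,5.1,3.5,1.4,0.2")
// sample.species  == "Iris-setosa"
// sample.features == Vector(5.1, 3.5, 1.4, 0.2)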

The Flink DataSet program:

import bean.Iris
import breeze.stats.distributions.Gaussian
import org.apache.flink.api.scala.ExecutionEnvironment
import org.apache.flink.api.scala._



val env = ExecutionEnvironment.getExecutionEnvironment

// Read the iris data and parse each line into an Iris object; this is our training set
val irisDataSet: DataSet[Iris] = env.readTextFile("F:\\DataSource\\iris.csv")
  .map(Iris.apply(_))


// Sample size: 150 samples in three classes (Iris-setosa, Iris-versicolor, Iris-virginica), 50 each
val sampleSize: Long = irisDataSet.count()

// Group the samples by species
val grouped = irisDataSet.groupBy("species")

// Probability density function of the normal distribution
def pdf(x: Double, mu: Double, sigma2: Double) = {
  Gaussian(mu, math.sqrt(sigma2)).pdf(x)
}
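// Quick sanity check (illustrative values): at x = mu the density peaks at
// 1 / (sigma * sqrt(2 * Pi)), so pdf(5.0, 5.0, 0.25) ≈ 0.798 for sigma2 = 0.25 (sigma = 0.5)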

// For each class: (prior probability, per-feature (mean, variance) pairs, class label)
val reduced: Seq[(Double, Vector[(Double, Double)], String)] = grouped
  .reduceGroup(elems => {

    // Compute the per-class mean and variance of each feature
    val irises = elems.toSeq
    val features = irises.map(_.features)

    val num = features.length.toDouble

    // Prior probability of this class
    val pprob = num / sampleSize

    // Element-wise sum of the feature vectors
    val folded: Vector[Double] = features.tail.fold(features.head) { (v1, v2) =>
      v1.zip(v2).map(tp => tp._1 + tp._2)
    }

    // Mean vector of the features
    val muVec = folded.map(_ / num)

    // Squared deviations of each feature from its mean
    val ssr = features.map(feature => {
      feature.zip(muVec).map { case (v1, v2) => math.pow(v1 - v2, 2) }
    })

    // Variance vector of the features
    val sigma2 = ssr.tail.fold(ssr.head) { (v1, v2) =>
      v1.zip(v2).map(tp => tp._1 + tp._2)
    }
      .map(_ / num)

    (pprob, muVec.zip(sigma2), irises.head.species)
  })
  .collect()

// Use the training set itself as the test set
val result = irisDataSet.map(iris => {

  val prob = reduced.map(probs => {

    // Product of the per-feature conditional densities
    val cprob = probs._2.zip(iris.features).map {
      case ((mu, sigma2), x) => pdf(x, mu, sigma2)
    }.product

    // Posterior score: prior times likelihood
    (probs._1 * cprob, probs._3)

  })
    .maxBy(_._1) // pick the class with the highest score

  (iris, prob)

})

// Inspect the results: print only the misclassified records
result.filter(tp => tp._1.species != tp._2._2).print()

Only 6 of the 150 samples are misclassified, for an accuracy of 96%:
(Iris(Iris-versicolor,Vector(6.9, 3.1, 4.9, 1.5)),(0.0174740386794776,Iris-virginica))
(Iris(Iris-versicolor,Vector(5.9, 3.2, 4.8, 1.8)),(0.03325849601325018,Iris-virginica))
(Iris(Iris-versicolor,Vector(6.7, 3.0, 5.0, 1.7)),(0.08116596041911653,Iris-virginica))
(Iris(Iris-virginica,Vector(4.9, 2.5, 4.5, 1.7)),(0.007235803888178925,Iris-versicolor))
(Iris(Iris-virginica,Vector(6.0, 2.2, 5.0, 1.5)),(0.020461656241466127,Iris-versicolor))
(Iris(Iris-virginica,Vector(6.3, 2.8, 5.1, 1.5)),(0.05952315832865549,Iris-versicolor))
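Rather than counting by hand, the accuracy can also be computed directly from the DataSet; a minimal sketch, reusing the result and sampleSize values defined above:

// Count the misclassified records and derive the training accuracy
val errors = result.filter(tp => tp._1.species != tp._2._2).count()
println(s"accuracy = ${1.0 - errors.toDouble / sampleSize}") // 1 - 6/150 = 0.96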

Spark Implementation

import breeze.stats.distributions.Gaussian
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.stat.Summarizer.{
  mean => summaryMean,
  variance => summaryVar
}
import org.apache.spark.sql.functions.udf


 def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName(s"${this.getClass.getSimpleName}")
      .master("local[*]")
      .getOrCreate()

    import spark.implicits._

    // Load the iris data
    val irisData = spark.read
      .option("header", true)
      .option("inferSchema", true)
      .csv("F:\\DataSource\\iris.csv")

    // rollup on "class" adds a grand-total row whose class is null
    val rolluped = irisData.rollup($"class").count()

    // Total sample size, read from the grand-total row
    val sampleSize = rolluped.where($"class".isNull).head().getAs[Long](1)

    // Compute the prior probability of each class
    val pprobMap = rolluped
      .where($"class".isNotNull)
      .withColumn("pprob", $"count" / sampleSize)
      .collect()
      .map(row => {
        row.getAs[String]("class") -> row.getAs[Double]("pprob")
      })
      .toMap


    val schema = irisData.schema
    // Feature column names: every column except the label
    val fts = schema.filterNot(_.name == "class").map(_.name).toArray

    // Assemble the feature columns into a single vector column
    val amountVectorAssembler: VectorAssembler = new VectorAssembler()
      .setInputCols(fts)
      .setOutputCol("features")

    val ftsDF = amountVectorAssembler
      .transform(irisData)
      .select("class", "features")
 
    // Per-class aggregation: mean vector and variance vector of the features
    val irisAggred = ftsDF
      .groupBy($"class")
      .agg(
        summaryMean($"features") as "mfts",
        summaryVar($"features") as "vfts"
      )
 

    // For each class: (per-feature (mean, variance) pairs, class label)
    val cprobs: Array[(Array[(Double, Double)], String)] = irisAggred
      .collect()
      .map(row => {
        val cl = row.getAs[String]("class")
        val mus = row.getAs[DenseVector]("mfts").toArray
        val vars = row.getAs[DenseVector]("vfts").toArray
        (mus.zip(vars), cl)
      })
 

    // Probability density of the normal distribution (same helper as in the Flink version)
    def pdf(x: Double, mu: Double, sigma2: Double) = {
      Gaussian(mu, math.sqrt(sigma2)).pdf(x)
    }

    // For each class, multiply the prior by the product of per-feature
    // Gaussian densities and return the class with the highest score
    val predictUDF = udf((vec: DenseVector) => {
      cprobs
        .map(tp => {
          val tuples: Array[((Double, Double), Double)] = tp._1.zip(vec.toArray)
          val cp: Double = tuples.map {
            case ((mu, sigma2), x) => pdf(x, mu, sigma2)
          }.product
          val pprob: Double = pprobMap.getOrElse(tp._2, 0.0)
          (cp * pprob, tp._2)
        })
        .maxBy(_._1)
        ._2
    })

    val predictDF = ftsDF
      .withColumn("predict", predictUDF($"features"))

    // Show only the misclassified records
    predictDF.where($"class" =!= $"predict").show(truncate = false)

    spark.stop()
  }
The results match those computed by the Flink implementation:
+---------------+-----------------+---------------+
|class          |features         |predict        |
+---------------+-----------------+---------------+
|Iris-versicolor|[6.9,3.1,4.9,1.5]|Iris-virginica |
|Iris-versicolor|[5.9,3.2,4.8,1.8]|Iris-virginica |
|Iris-versicolor|[6.7,3.0,5.0,1.7]|Iris-virginica |
|Iris-virginica |[4.9,2.5,4.5,1.7]|Iris-versicolor|
|Iris-virginica |[6.0,2.2,5.0,1.5]|Iris-versicolor|
|Iris-virginica |[6.3,2.8,5.1,1.5]|Iris-versicolor|
+---------------+-----------------+---------------+

We can compare this with the earlier Bayesian discriminant analysis, which achieved 98% accuracy on the same iris dataset. Because naive Bayes assumes the variables are independent, the computation is much simpler (no covariance matrix is needed), yet its classification performance is only slightly worse than ordinary Bayesian discriminant analysis.
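In multivariate-normal terms, the independence assumption amounts to restricting each class covariance matrix to its diagonal, which is why the joint density factorizes into the per-feature product used above (stated here for intuition):

\[
\Sigma_c = \operatorname{diag}(\sigma_{c,1}^{2},\dots,\sigma_{c,n}^{2})
\;\Longrightarrow\;
p(\mathbf{x}\mid c)=\prod_{i=1}^{n}\mathcal{N}\!\left(x_i;\,\mu_{c,i},\,\sigma_{c,i}^{2}\right)
\]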
