使用Spark进行决策树二元分类

关于使用决策树进行分类,我想大家都是再熟悉不过了,这里,我们直接讲解如何使用Spar进行决策树二分类。

1、数据介绍

这里我们使用Kaggle上一个名叫StumbleUpon Evergreen Classification Challenge的比赛数据,StumbleUpon 是一个个性化搜索引擎,会根据用户的兴趣和评分推荐感兴趣的网页。网页分为两种,一是暂时性的,另一种是长青的,我们需要对这些网页进行一个二元分类。

数据下载的地址为:
https://www.kaggle.com/c/stumbleupon/data
我们使用train.tsv 和 test.tsv

我们可以查看数据的格式:

数据说明如下:

字段 说明
0-2列 网址,网页ID和样板文字
3-25列 特征字段,其中第3列是页面的类别(并不是最后的分类,而是说页面是体育类,娱乐类等等),属于String型,其他列是数值型,存在一定的数据缺失
26列 类别字段,1代表长青网页,0代表暂时网页

2、数据处理

从上面的说明来看,我们主要根据3-25列的特征来预测网页的类别,我们要做如下的处理,首先第3列是网页类别,我们需要将其转换为数值型特征,比如将其转换为one-hot类型,其次,缺失值进行处理。最后,在一些分类算法中,可能需要对数据进行归一化,但是在决策树模型中是不需要的。

2.1 数据导入

我们需要创建一个sparkSession,读入数据:

val spark = SparkSession.builder().appName("BinaryDecisionTree").getOrCreate()
val rawDataWithHeader = spark.sparkContext.textFile("data/classify/train.tsv")

导入的数据的第一行是字段名,是我们不需要的,所以需要将这一行去掉:

val rawData = rawDataWithHeader.mapPartitionsWithIndex{(idx,iter)=>if(idx==0) iter.drop(1) else iter}

随后,读入的一行是一个字符串,我们使用map 方法将其转换为数组:

val lines = rawData.map(_.split("\t"))

2.2 转换为训练所需数据

模型训练所需要的数据类型是RDD[LabeledPoint],我们的RDD的每一行是一个LabelPoint,LabelPoint由label和features所组成。我们直接通过代码来看是如何一步步将现有数据转换为RDD[LabeledPoint]的:

val categoriesMap = lines.map(fields=>fields(3)).distinct.collect.zipWithIndex.toMap
    val labeledPointRDD = lines.map { fields =>
        val trFields = fields.map(_.replaceAll("\"",""))
        val categoryFeaturesArray = Array.ofDim[Double](categoriesMap.size)
        val index = categoriesMap(fields(3))
        categoryFeaturesArray(index) = 1
        val numericalFeatures = trFields.slice(4,fields.size-1).map(d=>if(d=="?") 0.0 else d.toDouble)
        val label = trFields(fields.size-1).toInt
        LabeledPoint(label,Vectors.dense(categoryFeaturesArray++numericalFeatures))
    }

首先,我们创建了一个页面类别的对照表,类型为Map[String,Int],随后对lines中的每一行进行处理,首先去掉了其中的",随后将类别转换为one-hot类型,再与剩下的数字类特征进行拼接,在这之中,我们把缺失值替换为0.

2.3分割数据

接下来,我们就需要将数据分为训练集,验证集和测试集:

val Array(trainData,validationData,testData) = labeledPointRDD.randomSplit(Array(0.8,0.1,0.1))
return (trainData,validationData,testData,categoriesMap)

3、模型训练

我们主要使用DecisionTree.trainClassifier方法进行模型的训练,该方法需要提供的几个参数依次是:

参数 说明
input:RDD[LabeledPoint] 训练数据
numClasses:Int 分类数目
categoricalFeaturesInfo :Map[Int,Int] Map[Int,Int]()
Impurity:String 选用信息熵还是基尼系数进行评估,gini代表基尼系数,entropy代表信息熵
maxDepth:Int 决策树的最大深度
maxBins:Int 决策树每一个结点的最大分支数目
val model = DecisionTree.trainClassifier(trainData,2,Map[Int,Int](),impurity,maxDepth,maxBins)

4、模型评估

这里我们主要是用AUC来对模型进行评估:

val scoreAndLabel = validationData.map{data =>
      val score = model.predict(data.features)
      (score,data.label)
    }
    val Metrics = new BinaryClassificationMetrics(scoreAndLabel)
    val AUC = Metrics.areaUnderROC()
    AUC

5、模型预测

val predict = model.predict(Features).toInt
      var predictDesc = {predict match{
        case 0=>"暂时性网页(ephemeral)";
        case 1=>"长青网页(evergreen)"
      }}
      println("网址:"+url+"==>"+predictDesc)

6、完整代码

package mllib.DecisionTree

import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.SparkSession
import org.apache.spark.rdd.RDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.tree.DecisionTree
import org.joda.time.format._
import org.joda.time._
import org.joda.time.Duration



object BinaryDecesionTree {
  def setLogger ={
    Logger.getLogger("org").setLevel(Level.OFF)
    Logger.getLogger("com").setLevel(Level.OFF)
    System.setProperty("spark.ui.showConsoleProgress","false")
    Logger.getRootLogger().setLevel(Level.OFF)
  }

  def PrepareData(spark:SparkSession):(RDD[LabeledPoint],RDD[LabeledPoint],RDD[LabeledPoint],Map[String,Int])={

    val rawDataWithHeader = spark.sparkContext.textFile("data/classify/train.tsv")
    val rawData = rawDataWithHeader.mapPartitionsWithIndex{(idx,iter)=>if(idx==0) iter.drop(1) else iter}
    val lines = rawData.map(_.split("\t"))
    println("共计:"+lines.count()+"条")

    val categoriesMap = lines.map(fields=>fields(3)).distinct.collect.zipWithIndex.toMap
    val labeledPointRDD = lines.map { fields =>
        val trFields = fields.map(_.replaceAll("\"",""))
        val categoryFeaturesArray = Array.ofDim[Double](categoriesMap.size)
        val index = categoriesMap(fields(3))
        categoryFeaturesArray(index) = 1
        val numericalFeatures = trFields.slice(4,fields.size-1).map(d=>if(d=="?") 0.0 else d.toDouble)
        val label = trFields(fields.size-1).toInt
        LabeledPoint(label,Vectors.dense(categoryFeaturesArray++numericalFeatures))
    }

    val Array(trainData,validationData,testData) = labeledPointRDD.randomSplit(Array(0.8,0.1,0.1))
    return (trainData,validationData,testData,categoriesMap)
  }

  def trainModel(trainData:RDD[LabeledPoint],impurity:String,maxDepth:Int,maxBins:Int):(DecisionTreeModel,Double)={
    val startTime = new DateTime()
    val model = DecisionTree.trainClassifier(trainData,2,Map[Int,Int](),impurity,maxDepth,maxBins)
    val endTime = new DateTime()
    val duration = new Duration(startTime,endTime)
    (model,duration.getMillis())
  }

  def evaluateModel(model:DecisionTreeModel,validationData:RDD[LabeledPoint]):Double ={
    val scoreAndLabel = validationData.map{data =>
      val score = model.predict(data.features)
      (score,data.label)
    }
    val Metrics = new BinaryClassificationMetrics(scoreAndLabel)
    val AUC = Metrics.areaUnderROC()
    AUC
  }

  def trainEvaluate(trainData:RDD[LabeledPoint],validationData:RDD[LabeledPoint]):DecisionTreeModel= {
    println("训练开始>>>>")
    val (model,time) = trainModel(trainData,"entropy",10,10)
    println("训练完成,所需时间:"+time+"毫秒")
    val AUC = evaluateModel(model,validationData)
    println("评估结果AUC="+AUC)
    return model

  }

  def PredictData(spark: SparkSession, model: DecisionTreeModel, categoriesMap: Map[String, Int])={
    val rawDataWithHeader = spark.sparkContext.textFile("data/classify/test.tsv")
    val rawData = rawDataWithHeader.mapPartitionsWithIndex{(idx,iter)=>if(idx==0) iter.drop(1) else iter}
    val lines = rawData.map(_.split("\t"))
    println("共计:"+lines.count.toString()+"条")
    val dataRDD = lines.take(20).map{fields =>
      val trFields = fields.map(_.replaceAll("\"",""))
      val categoryFeaturesArray = Array.ofDim[Double](categoriesMap.size)
      val index = categoriesMap(fields(3))
      categoryFeaturesArray(index) = 1
      val numericalFeatures = trFields.slice(4,fields.size).map(d => if(d=="?") 0.0 else d.toDouble)
      val label = 0
      val url = trFields(0)
      val Features = Vectors.dense(categoryFeaturesArray++numericalFeatures)
      val predict = model.predict(Features).toInt
      var predictDesc = {predict match{
        case 0=>"暂时性网页(ephemeral)";
        case 1=>"长青网页(evergreen)"
      }}
      println("网址:"+url+"==>"+predictDesc)
    }
  }
  def main(args:Array[String]) = {
    setLogger
    val spark = SparkSession.builder().appName("BinaryDecisionTree").getOrCreate()
    println("RunDecisionTreeBinary")
    println("===============数据准备阶段===============")
    val (trainData,validationData,testData,categoriesMap) = PrepareData(spark)
    trainData.persist();validationData.persist();testData.persist();
    println("===============训练评估阶段===============")
    val model = trainEvaluate(trainData,validationData)
    println("===============测试阶段===============")
    val auc = evaluateModel(model,testData)
    println("===============预测数据===============")
    PredictData(spark,model,categoriesMap)


  }
}

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,686评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,668评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,160评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,736评论 1 284
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,847评论 6 386
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,043评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,129评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,872评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,318评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,645评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,777评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,470评论 4 333
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,126评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,861评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,095评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,589评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,687评论 2 351

推荐阅读更多精彩内容