GraphX Label Propagation算法改进

label propagation算法介绍

标签传播算法(label propagation)的核心思想非常简单:相似的数据应该具有相同的label。LP算法包括两大步骤:1)构造关系网;2)标签传播。

  • 算法具体步骤如下:
    1、初始时,给每个节点一个唯一的标签;
    2、每个节点使用其邻居节点的标签中最多的标签来更新自身的标签。
    3、反复执行步骤2,直到每个节点的标签都不再发生变化为止。

  • 一次迭代过程中一个节点标签的更新可以分为同步和异步两种。所谓同步更新,即节点z在第t次迭代的label依据于它的邻居节点在第t-1次迭代时所得的label;异步更新,即节点z在第t次迭代的label依据于第t次迭代已经更新过label的节点和第t次迭代未更新过label的节点在第t-1次迭代时的label。

graphX自带LP算法的缺陷

1、边权重信息不参与计算过程;
2、标签传播结果存在震荡的问题(震荡问题是所有基于BSP模式的框架普遍存在的问题)

关于graphx及BSP可见我另一篇文章 //www.greatytc.com/p/7190123ad329

边权重与无向图支持的改造

  • 基于pregel接口,重新实现了一套传播sendMessage和mergeMessage方法
def sendMessage(e: EdgeTriplet[VertexId, Int]): Iterator[(VertexId, Map[VertexId, Long])] = {
    Iterator((e.srcId, Map(e.dstAttr -> e.attr)), (e.dstId, Map(e.srcAttr -> e.attr)))
  }
def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long])
  : Map[VertexId, Long] = {
    (count1.keySet ++ count2.keySet).map { i =>
      val count1Val = count1.getOrElse(i, 0L)
      val count2Val = count2.getOrElse(i, 0L)
      i -> (count1Val + count2Val)
    }(collection.breakOut)
  }

标签传播震荡问题改造

1、初始化每个节点属性信息,先给每个节点分配不重复标签。如,节点1对应标签1,节点i对应标签i;

2、N个节点,同步找到对应节点邻居,获取此节点邻居标签,找到出现权重最高的标签,若权重最高的标签不止一个,则选择标签值较大的标签赋值给当前节点;

3、若本轮标签重标记后,节点标签不再变化(或者达到设定的最大迭代次数),则迭代停止,否则重复第2步。迭代结果即为RS0;

4、当第3步结束后,以其结果RS0作为节点初始化信息,重新初始化每个节点属性信息,并从第2步开始,再分别迭代1轮、2轮、3轮,结果分别存为 RS1、RS2和RS3;

5、综合RS0、RS1、RS2和RS3的结果,得到最终每个节点的标签结果。如,节点i在RS0、RS1、RS2和RS3中的标签信息分别为(a、b、c、d),选择其中计数最多的标签作为节点i的最终结果,若计数最多的标签不止一个,则选择标签值最大的标签作为节点i最终的标签。

6、至此,label propagation算法结束,每个节点获得的标签即为其最终归属的cluster的id,聚类结束。

效果对比(demo数据)

graphx自带label propagation
  • demo数据展示(边权重表示点之间的亲密度)


    image.gif
  • 期望的聚类结果
image.gif
  • graphx自带LPA聚类结果(共分成4个cluster,不同颜色标注)
image.gif
  • 改进算法的聚类结果


    image.gif

效果对比(通过wifi连接获取的关系数据)

1、外卖标签,数据集中该标签占比0.3965。数据集共23137人。训练集16195人,其中带标签6451人;测试集6942人。其中带标签2723人。

a、graphx自带lp:召回率0.0823,精确率0.5450
b、pregel实现改进版lp:召回率0.2281,精确率0.4909

屏幕快照 2019-08-22 下午8.21.31.png

2、学前教育,数据集中该标签占比0.0281。数据集共23137人。训练集16195人,其中带标签462人;测试集6942人。其中带标签188人。

a、graphx自带lp:召回率0.0,精确率0.0
b、pregel实现改进版lp:召回率0.0426,精确率0.0952

屏幕快照 2019-08-22 下午8.21.58.png

3、炒股,数据集中该标签占比0.2192。数据集共23137人。训练集16195人,其中带标签3499人;测试集6942人。其中带标签1572人。

a、graphx自带lp:召回率0.0204,精确率0.3721
b、pregel实现改进版lp:召回率0.1501,精确率0.3940

屏幕快照 2019-08-22 下午8.22.24.png

4、游戏付费意愿用户,数据集中该标签占比0.1312。数据集共23137人。训练集16195人,其中带标签2137人;测试集6942人。其中带标签898人。

a、graphx自带lp:召回率0.0267,精确率0.2857
b、pregel实现改进版lp:召回率0.1292,精确率0.2736

屏幕快照 2019-08-22 下午8.22.47.png

5、35岁+标签,数据集中该标签占比0.3227。数据集共23137人。训练集16195人,其中带标签5204人;测试集6942人。其中带标签2262人。

a、graphx自带lp:召回率0.0469,精确率0.4953
b、pregel实现改进版lp:召回率0.2604,精确率0.5285


屏幕快照 2019-08-22 下午8.23.10.png

完整代码如下(scala)

package Graph.LPA

import org.apache.spark.graphx._
import org.apache.spark._
import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.LongType


object LPARevolution {
  def sendMessage(e: EdgeTriplet[VertexId, Int]): Iterator[(VertexId, Map[VertexId, Long])] = {
    Iterator((e.srcId, Map(e.dstAttr -> e.attr)), (e.dstId, Map(e.srcAttr -> e.attr)))
  }

  def mergeMessage(count1: Map[VertexId, Long], count2: Map[VertexId, Long])
  : Map[VertexId, Long] = {
    (count1.keySet ++ count2.keySet).map { i =>
      val count1Val = count1.getOrElse(i, 0L)
      val count2Val = count2.getOrElse(i, 0L)
      i -> (count1Val + count2Val)
    }(collection.breakOut)
  }

  // 更新点属性
  def vertexProgram(vid: VertexId, attr: Long, message: Map[VertexId, Long]): VertexId = {
    if(message.isEmpty){
      attr
    }
    else{
      //            print(vid)
      //            println(" 接收到的消息:   ")
      //            println(message)
      //            println("最终选择的是:")
      //            println(message.maxBy(_._2)._1)
      message.maxBy(_._2)._1  // 按照计数排序,然后取第一个
    }

  }


  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    val sc = new SparkContext("yarn","lpa-revolution",conf)
    val hql = new HiveContext(sc)

    // 获取边数据gid,usertags,wifimac,ssid,geohash,day-int
    val edges = hql.sql("select cast(src as bigint), cast(dst as bigint)," +
      " cast(weight as int) from yangy.graph_edge_table_3day_zoom_weight_hz").rdd.
      map(row => Edge(row(0).asInstanceOf[Long], row(1).asInstanceOf[Long], row(2).asInstanceOf[Int]))

    // 获取点数据id_2_label_table
    val users = hql.sql("select cast(id as bigint), user_tags from yangy.id_2_label_table_3day_zoom_weight_hz").
      rdd.map(row => (row(0).asInstanceOf[Long], row(1)))

//    val edges = sc.textFile("file:///home/yangy/data/xh_edge_20190530_8day_1_0.txt").
//      map{line =>
//        val fields = line.split(" ")
//        (Edge(fields(0).toLong, fields(1).toLong, fields(2).toInt))
//      }
//
//    val users = sc.textFile("file:///home/yangy/data/xh_vertex_with_label_1_0.txt").
//      map { line =>
//        val fields = line.split(" ")
//        (fields(0).toLong, fields(1).toLong)
//      }

    val graph = Graph(vertices = users, edges = edges)

    // 图初始化
    val initGraph = graph.mapVertices { case (vid, attr) => vid }

    // 初始化msg
    val initialMessage = Map[VertexId, Long]()

    println("迭代结果:")

    // 分水岭,开始解决社区震荡&孤立点问题
    // ----------------------------------  迭代多轮  -------------------------------------
    val cluster1 = Pregel(initGraph, initialMessage, maxIterations = 100, activeDirection = EdgeDirection.Either)(
      vprog = vertexProgram,
      sendMsg = sendMessage,
      mergeMsg = mergeMessage)


    // =====================================================================
    // 优雅代码的核心部分,基于前面的结果初始化新的图
    // 利用前面迭代结果重新初始化图
    // 以此结果作为基础,后续在此基础上继续迭代
    val users_trans = cluster1.vertices
    val graph_trans = Graph(vertices = users_trans, edges = edges)
    val initGraph_trans = graph_trans.mapVertices { case (vid, attr) => attr}
    // ======================================================================


    // 在基础数据上,额外迭代的轮数
    val cluster2 = Pregel(initGraph_trans, initialMessage, maxIterations = 1, activeDirection = EdgeDirection.Either)(
      vprog = vertexProgram,
      sendMsg = sendMessage,
      mergeMsg = mergeMessage)

    val cluster3 = Pregel(initGraph_trans, initialMessage, maxIterations = 2, activeDirection = EdgeDirection.Either)(
      vprog = vertexProgram,
      sendMsg = sendMessage,
      mergeMsg = mergeMessage)

    val cluster4 = Pregel(initGraph_trans, initialMessage, maxIterations = 3, activeDirection = EdgeDirection.Either)(
      vprog = vertexProgram,
      sendMsg = sendMessage,
      mergeMsg = mergeMessage)


    // 构建label propagation结果dataframe
    val colNames = "id,group_id"
    val schema = StructType(colNames.split(",").map(column => StructField(column, LongType)))

    // 获取每个id的分组信息,字段名是id, group_id
    val groupDf1 = hql.createDataFrame(cluster1.vertices.map(x=> Row(x._1, x._2)), schema)
    val groupDf2 = hql.createDataFrame(cluster2.vertices.map(x=> Row(x._1, x._2)), schema)
    val groupDf3 = hql.createDataFrame(cluster3.vertices.map(x=> Row(x._1, x._2)), schema)
    val groupDf4 = hql.createDataFrame(cluster4.vertices.map(x=> Row(x._1, x._2)), schema)


    val group_union_df = groupDf1.unionAll(groupDf2).unionAll(groupDf3).unionAll(groupDf4)

    // 选取合适的group_id,避免社区震荡
    group_union_df.registerTempTable("group_union_table")

    // 获取不震荡的group归属信息
    val group_no_swing = hql.sql(
      """
        |select t2.id as id,
        |       t2.group_id as group_id
        |from
        |(
        |   select t1.id as id,
        |          t1.group_id as group_id,
        |          rank() over (partition by t1.id order by t1.cnt, t1.group_id desc) as rank
        |   from
        |   (
        |       select id,
        |              group_id,
        |              count(1) as cnt
        |       from group_union_table
        |       group by id,
        |                group_id
        |   ) t1
        |) t2
        |where t2.rank = 1
        |
      """.stripMargin)

    group_no_swing.write.mode("overwrite").
      saveAsTable("yangy.graphx_cluster_zoom_no_swing_hz_100_table")

    group_no_swing.show(20)

    sc.stop()
  }

}

** 原创内容,若要转载请联系本人 **

最后编辑于
©著作权归作者所有,转载或内容合作请联系作者
  • 序言:七十年代末,一起剥皮案震惊了整个滨河市,随后出现的几起案子,更是在滨河造成了极大的恐慌,老刑警刘岩,带你破解...
    沈念sama阅读 212,816评论 6 492
  • 序言:滨河连续发生了三起死亡事件,死亡现场离奇诡异,居然都是意外死亡,警方通过查阅死者的电脑和手机,发现死者居然都...
    沈念sama阅读 90,729评论 3 385
  • 文/潘晓璐 我一进店门,熙熙楼的掌柜王于贵愁眉苦脸地迎上来,“玉大人,你说我怎么就摊上这事。” “怎么了?”我有些...
    开封第一讲书人阅读 158,300评论 0 348
  • 文/不坏的土叔 我叫张陵,是天一观的道长。 经常有香客问我,道长,这世上最难降的妖魔是什么? 我笑而不...
    开封第一讲书人阅读 56,780评论 1 285
  • 正文 为了忘掉前任,我火速办了婚礼,结果婚礼上,老公的妹妹穿的比我还像新娘。我一直安慰自己,他们只是感情好,可当我...
    茶点故事阅读 65,890评论 6 385
  • 文/花漫 我一把揭开白布。 她就那样静静地躺着,像睡着了一般。 火红的嫁衣衬着肌肤如雪。 梳的纹丝不乱的头发上,一...
    开封第一讲书人阅读 50,084评论 1 291
  • 那天,我揣着相机与录音,去河边找鬼。 笑死,一个胖子当着我的面吹牛,可吹牛的内容都是我干的。 我是一名探鬼主播,决...
    沈念sama阅读 39,151评论 3 410
  • 文/苍兰香墨 我猛地睁开眼,长吁一口气:“原来是场噩梦啊……” “哼!你这毒妇竟也来了?” 一声冷哼从身侧响起,我...
    开封第一讲书人阅读 37,912评论 0 268
  • 序言:老挝万荣一对情侣失踪,失踪者是张志新(化名)和其女友刘颖,没想到半个月后,有当地人在树林里发现了一具尸体,经...
    沈念sama阅读 44,355评论 1 303
  • 正文 独居荒郊野岭守林人离奇死亡,尸身上长有42处带血的脓包…… 初始之章·张勋 以下内容为张勋视角 年9月15日...
    茶点故事阅读 36,666评论 2 327
  • 正文 我和宋清朗相恋三年,在试婚纱的时候发现自己被绿了。 大学时的朋友给我发了我未婚夫和他白月光在一起吃饭的照片。...
    茶点故事阅读 38,809评论 1 341
  • 序言:一个原本活蹦乱跳的男人离奇死亡,死状恐怖,灵堂内的尸体忽然破棺而出,到底是诈尸还是另有隐情,我是刑警宁泽,带...
    沈念sama阅读 34,504评论 4 334
  • 正文 年R本政府宣布,位于F岛的核电站,受9级特大地震影响,放射性物质发生泄漏。R本人自食恶果不足惜,却给世界环境...
    茶点故事阅读 40,150评论 3 317
  • 文/蒙蒙 一、第九天 我趴在偏房一处隐蔽的房顶上张望。 院中可真热闹,春花似锦、人声如沸。这庄子的主人今日做“春日...
    开封第一讲书人阅读 30,882评论 0 21
  • 文/苍兰香墨 我抬头看了看天上的太阳。三九已至,却和暖如春,着一层夹袄步出监牢的瞬间,已是汗流浃背。 一阵脚步声响...
    开封第一讲书人阅读 32,121评论 1 267
  • 我被黑心中介骗来泰国打工, 没想到刚下飞机就差点儿被人妖公主榨干…… 1. 我叫王不留,地道东北人。 一个月前我还...
    沈念sama阅读 46,628评论 2 362
  • 正文 我出身青楼,却偏偏与公主长得像,于是被迫代替她去往敌国和亲。 传闻我的和亲对象是个残疾皇子,可洞房花烛夜当晚...
    茶点故事阅读 43,724评论 2 351

推荐阅读更多精彩内容