Grouped Top N with Spark RDDs
See the code below.
Test data:
aa 49
bb 11
cc 34
aa 22
bb 67
cc 29
aa 36
bb 33
cc 30
aa 11
bb 44
cc 49
Spark RDD Code
package cn.ted.secondarySort
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
/**
* Author: LiYahui
* Date: Created in 2019/3/1 10:57
* Description: grouped Top N with Spark RDD operators; requires sorting within each group
* Version: V1.0
*/
object GroupedTopN {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession.builder()
.appName(s"${this.getClass.getSimpleName}")
.master("local[2]")
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.sql.parquet.compression.codec", "gzip")
.getOrCreate()
val sc: SparkContext = spark.sparkContext
sc.setLogLevel("WARN")
val inputPath = "F:\\LocalFileForTest\\topN"
//-------------------------------------------------------------------------------------------
// Method 1: collect the result back to the driver as an array
val resultArray: Array[(String, List[String])] = sc.textFile(inputPath)
.map(_.split(" "))
.map(line => (line(0), line(1)))
.groupByKey()
.map(line => {
(line._1, line._2.toList.sortWith(_.toInt > _.toInt).take(3)) // sort numerically in descending order and keep the top 3
}).collect()
// print the collected results on the driver
for (ele <- resultArray) {
println("Result element: " + ele)
}
/**
* Result element: (aa,List(49, 36, 22))
* Result element: (bb,List(67, 44, 33))
* Result element: (cc,List(49, 34, 30))
*/
//--------------------------------------------------------------------------------------------------
// Method 2: keep the result as an RDD and convert it to a DataFrame, which is more practical for real development
val resultRDD: RDD[(String, List[String])] = sc.textFile(inputPath)
.map(_.split(" "))
.map(line => (line(0), line(1)))
.groupByKey()
.map(line => {
(line._1, line._2.toList.sortWith(_.toInt > _.toInt).take(2)) // sort numerically in descending order and keep the top 2
})
// convert the RDD to a DataFrame directly via toDF
import spark.implicits._
val frame: DataFrame = resultRDD.toDF("key", "value")
frame.show()
/**
* +---+--------+
* |key| value|
* +---+--------+
* | aa|[49, 36]|
* | bb|[67, 44]|
* | cc|[49, 34]|
* +---+--------+
*/
//----------------------------------------------------------------------------------------------
println("===============华丽分割线=====================")
// Method 3: build the DataFrame with Spark Core, using an explicit schema
// flatten the grouped top N into an RDD[Row] that can then be turned into a DataFrame
val tempRow: RDD[Row] = resultRDD.flatMap(line => {
val key: String = line._1.toString
val value: List[String] = line._2
flatMapTransformRow(key, value)
})
// define the DataFrame schema
val schema = StructType(List(
StructField("key", StringType, false),
StructField("value", StringType, false)
))
val tempDF: DataFrame = spark.createDataFrame(tempRow, schema)
tempDF.show()
/**
* +---+-----+
* |key|value|
* +---+-----+
* | aa| 49|
* | aa| 36|
* | bb| 67|
* | bb| 44|
* | cc| 49|
* | cc| 34|
* +---+-----+
*/
//-------------------------------------------------------------------------
spark.stop()
sc.stop()
}
/**
* Flatten one (key, values) pair into one Row per value (columns to rows).
*
* @param key   the group key
* @param value the top-N values belonging to this key
* @return one Row(key, value) per element of value
*/
def flatMapTransformRow(key: String, value: List[String]): Seq[Row] = {
// accumulate the final rows to return
var resultRow: Seq[Row] = Seq[Row]()
for (ele <- value) {
// append one Row per value, keeping the key on every row
resultRow = resultRow :+ Row(key, ele)
}
resultRow
}
/**
* Source data:
* aa 49
* bb 11
* cc 34
* aa 22
* bb 67
* cc 29
* aa 36
* bb 33
* cc 30
* aa 11
* bb 44
* cc 49
*
* Requirements:
* 1. Group the data above by key
* 2. Sort the values within each group
* 3. Take the top 3 values of each group and return them as key-value pairs
*/
}
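Both RDD variants above rely on groupByKey, which shuffles every value of a key to a single executor and materializes the whole group in memory before sorting; that is fine for this toy file but can become a problem for keys with many values. Below is a minimal sketch of a more scalable alternative based on aggregateByKey, which only ever keeps the current top N per key while combining (the helper name topNPerKey and the use of Int values are my own choices, not from the original code):

import org.apache.spark.rdd.RDD

// Keep only the n largest values per key while aggregating,
// so no complete per-key list is ever materialized.
def topNPerKey(pairs: RDD[(String, Int)], n: Int): RDD[(String, List[Int])] = {
  pairs.aggregateByKey(List.empty[Int])(
    // fold one value into a partial top-n list
    (acc, v) => (v :: acc).sorted(Ordering[Int].reverse).take(n),
    // merge two partial top-n lists coming from different partitions
    (a, b) => (a ++ b).sorted(Ordering[Int].reverse).take(n)
  )
}

// Usage against the same input file, taking the top 3 per key:
// val pairs = sc.textFile(inputPath).map(_.split(" ")).map(a => (a(0), a(1).toInt))
// topNPerKey(pairs, 3).collect().foreach(println)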
Spark SQL Code
The code is as follows:
- A word of advice: you need to be able to write this code yourself, and SQL-style code even more so. Interviewers often ask you to hand-write it, so your SQL skills need regular practice.
package cn.ted.secondarySort
import org.apache.spark.SparkContext
import org.apache.spark.sql.expressions.{Window, WindowSpec}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
/**
* Author: LiYahui
* Date: Created in 2019/3/1 10:57
* Description: grouped Top N with Spark SQL; requires sorting within each group
* Version: V1.0
*/
object GroupedTopN {
def main(args: Array[String]): Unit = {
val spark: SparkSession = SparkSession.builder()
.appName(s"${this.getClass.getSimpleName}")
.master("local[2]")
.config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
.config("spark.sql.parquet.compression.codec", "gzip")
.getOrCreate()
val sc: SparkContext = spark.sparkContext
sc.setLogLevel("WARN")
val inputPath = "F:\\LocalFileForTest\\topN"
// toDF and the $"..." column syntax below require the session's implicit conversions
import spark.implicits._
val tmpDF: DataFrame = sc.textFile(inputPath).map(line => {
val arr: Array[String] = line.split(" ")
(arr(0), arr(1))
}).toDF("key", "value")
// DSL style (DataFrame API)
import org.apache.spark.sql.functions.row_number
// note: reference the sort column with $ so that .desc can be applied to it
val windowRule: WindowSpec = Window.partitionBy("key").orderBy($"value".desc)
val resultDS: Dataset[Row] = tmpDF.withColumn("rank", row_number().over(windowRule))
.where("rank<3")
resultDS.show()
/**
* +---+-----+----+
* |key|value|rank|
* +---+-----+----+
* | cc| 49| 1|
* | cc| 34| 2|
* | bb| 67| 1|
* | bb| 44| 2|
* | aa| 49| 1|
* | aa| 36| 2|
* +---+-----+----+
*/
tmpDF.createOrReplaceTempView("tmp")
// SQL style: the window-function column cannot be filtered in the same SELECT, so wrap it in a subquery
val ranksql = "select key, value, rank from (select key, value, row_number() over(partition by key order by value desc) as rank from tmp) t where rank < 3"
spark.sql(ranksql).show()
/**
* +---+-----+----+
* |key|value|rank|
* +---+-----+----+
* | cc| 49| 1|
* | cc| 34| 2|
* | bb| 67| 1|
* | bb| 44| 2|
* | aa| 49| 1|
* | aa| 36| 2|
* +---+-----+----+
*/
spark.stop()
sc.stop()
}
}
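One design note on the window-function approach: row_number() always assigns distinct ranks, so if two values of the same key tie, only one of them survives a "rank < 3" filter. When tied values should all be kept, rank() or dense_rank() can be swapped in. A minimal sketch against the same tmpDF from the listing above (the column name rnk is my own; the cast to int makes the ordering numeric rather than lexicographic, since value is read in as a string):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.dense_rank

// dense_rank gives tied values the same rank with no gaps,
// so "rnk < 3" keeps every value that ties for 1st or 2nd place within its key
val tieFriendly = tmpDF
  .withColumn("rnk", dense_rank().over(Window.partitionBy("key").orderBy($"value".cast("int").desc)))
  .where("rnk < 3")
tieFriendly.show()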