前言
需求:业务需求要求求出score的最大值(max),最小值(min),均值(mean),标准差(stddev),中位数。需求的前四个值Spark自带函数可以解决,唯独中位数没有,所以需要自定义一个聚合函数。
实现方法以及代码
自定义函数需要继承UserDefinedAggregateFunction
class MiddleValueUDAF extends UserDefinedAggregateFunction{
// 输入参数的数据类型
override def inputSchema: StructType = {
DataTypes.createStructType(util.Arrays
.asList((DataTypes.createStructField("score",DataTypes.StringType,true))))
}
/**
*
* 更新 可以认为一个一个地将组内的字段值传递进来 实现拼接的逻辑
* buffer.getInt(0)获取的是上一次聚合后的值
* 相当于map端的combiner,combiner就是对每一个map task的处理结果进行一次小聚合
* 大聚和发生在reduce端.
* 这里即是:在进行聚合的时候,每当有新的值进来,对分组后的聚合如何进行计算
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer.update(0,Integer.valueOf(buffer.get(0).toString)+Integer.valueOf(input.get(0).toString))
buffer.update(0,buffer.get(0)+","+input.get(0).toString)
}
// buffer中的数据类型
override def bufferSchema: StructType = {
DataTypes.createStructType(util.Arrays
.asList((DataTypes.createStructField("summ",DataTypes.StringType,true))))
}
/**
* 合并其他部分结果
* 合并 update操作,可能是针对一个分组内的部分数据,在某个节点上发生的 但是可能一个分组内的数据,会分布在多个节点上处理
* 此时就要用merge操作,将各个节点上分布式拼接好的串,合并起来
* buffer1.getInt(0) : 大聚合的时候 上一次聚合后的值
* buffer2.getInt(0) : 这次计算传入进来的update的结果
* 这里即是:最后在分布式节点完成后需要进行全局级别的Merge操作
* 也可以是一个节点里面的多个executor合并
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1.update(0,Integer.valueOf(buffer1.get(0).toString)+Integer.valueOf(buffer2.get(0).toString))
buffer1.update(0,buffer1.get(0)+","+buffer2.get(0).toString)
}
//初始化
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer.update(0,"")
}
// 确保一致性 一般用true,用以标记针对给定的一组输入,UDAF是否总是生成相同的结果
override def deterministic: Boolean = {
true
}
//计算逻辑
override def evaluate(buffer: Row): Any = {
val intArray = buffer.get(0).toString.replaceAll(",,",",").substring(1)
val list = intArray.split(",").map(_.toDouble).toList.sorted
val len = list.size
var mid = 0d
if (len % 2 == 0)
mid = (list(len / 2 - 1) + list(len / 2)) / 2
else
mid = list(len / 2)
mid
}
// 返回值的类型
override def dataType: DataType = {
DataTypes.DoubleType
}