Flink UDF
概述
- 什么是UDF
- UDF是User-defined Functions的缩写,即自定义函数。
- UDF种类
- UDF分为三种:Scalar Functions、Table Functions、Aggregation Functions
- Scalar Functions
- Table Functions
- 和上面的Scalar Functions接收的参数个数一样,不同的是可以返回多行,而不是单个值
- Aggregation Functions
- 从名字就可以看出来,这个是搭配GROUP BY一起使用的,将表的一个或多个列的一行或多行数据汇聚到一个值里面,看上去有点拗口,其实可以把它简单理解为SQL中的聚合函数
- Table Aggregation Functions
- 相当于Table Functions和Aggregation Functions的结合体,聚合之后,再返回多行多列
- 为什么要有UDF
- Flink SQL目前提供了很多的内置UDF,主要是为了大家更方便的编写SQL代码完成自己的业务逻辑,具体内置的UDF可以参考官方文档;同时,Flink 也支持注册自己的UDF,下面正式开始我们今天的UDF探索之旅。
Scalar Functions
//不墨迹,我们直接贴代码
package udf;
import org.apache.flink.table.functions.ScalarFunction;
public class TestScalarFunc extends ScalarFunction {
private int factor = 2020;
//和传入数据进行计算的逻辑,参数个数任意
public int eval() {
return factor;
}
public int eval(int a) {
return a * factor;
}
public int eval(int... a) {
int res = 1;
for (int i : a) {
res *= i;
}
return res * factor;
}
}
- 自定义Scalar Functions,需要继承
ScalarFunction
,并且有一个public
的eval()
,方法可以接受任意个数参数,同时也可以在一个类中重载eval()
- 写完UDF之后需要注册到我们的运行环境中,使用姿势有两种:
tEnv.sqlUpdate("CREATE FUNCTION IF NOT EXISTS test AS 'udf.TestScalarFunc'");
tEnv.registerFunction("test",new TestScalarFunc());
- 第一种偏向在纯SQL的环境中使用,比如我们有个Flink SQL的提交平台,只支持纯SQL语句,那我们可以把自己写的UDF打包上传到平台后,通过SQL语句
CREATE FUNCTION IF NOT EXISTS test AS 'udf.TestScalarFunc'
来创建UDF;同时可以把UDF注册到catalog
中,这里先不深入讨论,之后我们说到Flink X Hive的时候再聊吧
- 第二种注册方式,如果我们的类有构造方法,可以通过new 对象的时候传递变量进去,更为灵活一点
Table Functions
package udf;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.calcite.shaded.com.google.common.base.Strings;
import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.types.Row;
public class TestTableFunction extends TableFunction {
private String separator = ",";
public TestTableFunction(String separator) {
this.separator = separator;
}
//和传入数据进行计算的逻辑,参数个数任意
public void eval(String input){
Row row = null;
if (Strings.isNullOrEmpty(input)){
row = new Row(2);
row.setField(0,null);
row.setField(1,0);
collect(row);
}else {
String[] split = input.split(separator);
for (String word : split) {
row = new Row(2);
row.setField(0,word);
row.setField(1,word.length());
collect(row);
}
}
}
@Override
public TypeInformation getResultType() {
return Types.ROW(Types.STRING,Types.INT);
}
}
- 自定义Table Functions,需要继承
TableFunction
,并且有一个public
的eval()
,方法可以接受任意个数参数,同时也可以在一个类中重载eval()
- 因为返回的是
Row
类型,所以需要重写getResultType()
- 在SQL语句中使用时,有两种写法:
select a.age,b.name,b.name_length from t2 a, LATERAL TABLE(test2(a.name_list)) as b(name, name_length)
select a.age,b.name,b.name_length from t2 a LEFT JOIN LATERAL TABLE(test2(a.name_list)) as b(name, name_length) ON TRUE
- 第一种的用法相当于用的是
CROSS JOIN
- 第二种的用法是
LEFT JOIN
Aggregation Functions
package udf;
import org.apache.flink.table.functions.AggregateFunction;
import java.util.Iterator;
public class TestAggregateFunction extends AggregateFunction<Long, TestAggregateFunction.SumAll> {
//返回最终结果
@Override
public Long getValue(SumAll acc) {
return acc.sum;
}
//构建保存中间结果的对象
@Override
public SumAll createAccumulator() {
return new SumAll();
}
//和传入数据进行计算的逻辑
public void accumulate(SumAll acc, long iValue) {
acc.sum += iValue;
}
//减去要撤回的值
public void retract(SumAll acc, long iValue) {
acc.sum -= iValue;
}
//从每个分区把数据取出来然后合并
public void merge(SumAll acc, Iterable<SumAll> it) {
Iterator<SumAll> iter = it.iterator();
while (iter.hasNext()) {
SumAll a = iter.next();
acc.sum += a.sum;
}
}
//重置内存中值时调用
public void resetAccumulator(SumAll acc) {
acc.sum = 0L;
}
public static class SumAll {
public long sum = 0;
}
}
- 自定义Aggregation Functions,需要继承
AggregateFunction
,并且必须要有 以下的方法
-
createAccumulator()
创建一个保留中间结果的数据结构
-
accumulate()
把每个输入行与中间结果进行计算,可以重载
-
getValue()
获取最终结果
- 根据不同的使用情况,还需要以下的方法
-
retract()
用于bounded OVER窗口,即窗口有结束时间
-
merge()
用于多次批量聚合和会话窗口合并
-
resetAccumulator()
用于多次批量聚合时,清空中间结果
Table Aggregation Functions
package udf;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.functions.TableAggregateFunction;
import org.apache.flink.types.Row;
import org.apache.flink.util.Collector;
public class TestTableAggregateFunction extends TableAggregateFunction<Row,TestTableAggregateFunction.Top2> {
//创建保留中间结果的对象
@Override
public Top2 createAccumulator() {
Top2 t = new Top2();
t.f1 = Integer.MIN_VALUE;
t.f2 = Integer.MIN_VALUE;
return t;
}
//与传入值进行计算的方法
public void accumulate(Top2 t, Integer v) {
//如果传入的值比内存中第一个值大,那就用第一个值替换第二个值,传入的值替换第一个值;
//如果传入的值比第二个值大比第一个小,那么就替换第二个值。
if (v > t.f1) {
t.f2 = t.f1;
t.f1 = v;
} else if (v > t.f2) {
t.f2 = v;
}
}
//合并分区的值
public void merge(Top2 t, Iterable<Top2> iterable) {
for (Top2 otherT : iterable) {
accumulate(t, otherT.f1);
accumulate(t, otherT.f2);
}
}
//拿到返回结果的方法
public void emitValue(Top2 t, Collector<Row> out) {
Row row = null;
//发射数据
//如果第一个值不是最小的int值,那就发出去
//如果第二个值不是最小的int值,那就发出去
if (t.f1 != Integer.MIN_VALUE) {
row = new Row(2);
row.setField(0,t.f1);
row.setField(1,1);
out.collect(row);
}
if (t.f2 != Integer.MIN_VALUE) {
row = new Row(2);
row.setField(0,t.f2);
row.setField(1,2);
out.collect(row);
}
}
//撤回流拿结果的方法,会发射撤回数据
public void emitUpdateWithRetract(Top2 t, RetractableCollector<Row> out) {
Row row = null;
//如果新旧值不相等,才需要撤回,不然没必要
//如果旧值不等于int最小值,说明之前发射过数据,需要撤回
//然后将新值发射出去
if (!t.f1.equals(t.oldF1)) {
if (t.oldF1 != Integer.MIN_VALUE) {
row = new Row(2);
row.setField(0,t.oldF1);
row.setField(1,1);
out.retract(row);
}
row = new Row(2);
row.setField(0,t.f1);
row.setField(1,1);
out.collect(row);
t.oldF1 = t.f1;
}
//和上面逻辑一样,只是一个发射f1,一个f2
if (!t.f2.equals(t.oldF2)) {
// if there is an update, retract old value then emit new value.
if (t.oldF2 != Integer.MIN_VALUE) {
row = new Row(2);
row.setField(0,t.oldF2);
row.setField(1,2);
out.retract(row);
}
row = new Row(2);
row.setField(0,t.f2);
row.setField(1,2);
out.collect(row);
t.oldF2 = t.f2;
}
}
//保留中间结果的类
public class Top2{
public Integer f1;
public Integer f2;
public Integer oldF1;
public Integer oldF2;
}
@Override
public TypeInformation<Row> getResultType() {
return Types.ROW(Types.INT,Types.INT);
}
}
- 自定义Table Aggregation Functions,需要继承
TableAggregateFunction
,并且必须要有 以下的方法
-
createAccumulator()
创建一个保留中间结果的数据结构
-
accumulate()
把每个输入行与中间结果进行计算,可以重载
- 根据不同的使用情况,还需要以下的方法
-
retract()
用于bounded OVER窗口,即窗口有结束时间
-
merge()
用于多次批量聚合和会话窗口合并
-
resetAccumulator()
用于多次批量聚合时,清空中间结果
-
emitValue()
用于批量和窗口聚合拿到结果
-
emitUpdateWithRetract()
用于流式计算的撤回流
- 目前Table Aggregation Functions只支持在Table Api中使用
完整代码
//下面贴出来的是主类的代码,具体每个UDF的类上面已经有了
package FlinkSql;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.DataStreamSource;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;
import org.apache.flink.table.api.Table;
import org.apache.flink.types.Row;
import udf.TestAggregateFunction;
import udf.TestScalarFunc;
import udf.TestTableAggregateFunction;
import udf.TestTableFunction;
import static util.FlinkConstant.env;
import static util.FlinkConstant.tEnv;
public class FlinkSql04 {
public static void main(String[] args) throws Exception {
DataStream<Row> source = env.addSource(new RichSourceFunction<Row>() {
@Override
public void run(SourceContext<Row> ctx) throws Exception {
Row row = new Row(3);
row.setField(0, 2);
row.setField(1, 3);
row.setField(2, 3);
ctx.collect(row);
}
@Override
public void cancel() {
}
}).returns(Types.ROW(Types.INT,Types.INT,Types.INT));
tEnv.createTemporaryView("t",source,"a,b,c");
// tEnv.sqlUpdate("CREATE FUNCTION IF NOT EXISTS test AS 'udf.TestScalarFunc'");
tEnv.registerFunction("test",new TestScalarFunc());
Table table = tEnv.sqlQuery("select test() as a,test(a) as b, test(a,b,c) as c from t");
DataStream<Row> res = tEnv.toAppendStream(table, Row.class);
// res.print().name("Scalar Functions Print").setParallelism(1);
DataStream<Row> ds2 = env.addSource(new RichSourceFunction<Row>() {
@Override
public void run(SourceContext<Row> ctx) throws Exception {
Row row = new Row(2);
row.setField(0, 22);
row.setField(1, "aa,b,cdd,dfsfdg,exxxxx");
ctx.collect(row);
}
@Override
public void cancel() {
}
}).returns(Types.ROW(Types.INT, Types.STRING));
tEnv.createTemporaryView("t2",ds2,"age,name_list");
tEnv.registerFunction("test2",new TestTableFunction(","));
// Table table2 = tEnv.sqlQuery("select a.age,b.name,b.name_length from t2 a, LATERAL TABLE(test2(a.name_list)) as b(name, name_length)");
Table table2 = tEnv.sqlQuery("select a.age,b.name,b.name_length from t2 a LEFT JOIN LATERAL TABLE(test2(a.name_list)) as b(name, name_length) ON TRUE");
DataStream<Row> res2 = tEnv.toAppendStream(table2, Row.class);
// res2.print().name("Table Functions Print").setParallelism(1);
DataStream<Row> ds3 = env.addSource(new RichSourceFunction<Row>() {
@Override
public void run(SourceContext<Row> ctx) throws Exception {
Row row1 = new Row(2);
row1.setField(0,"a");
row1.setField(1,1L);
Row row2 = new Row(2);
row2.setField(0,"a");
row2.setField(1,2L);
Row row3 = new Row(2);
row3.setField(0,"b");
row3.setField(1,100L);
ctx.collect(row1);
ctx.collect(row2);
ctx.collect(row3);
}
@Override
public void cancel() {
}
}).returns(Types.ROW(Types.STRING, Types.LONG));
tEnv.createTemporaryView("t3",ds3,"name,cnt");
tEnv.registerFunction("test3",new TestAggregateFunction());
Table table3 = tEnv.sqlQuery("select name,test3(cnt) as mySum from t3 group by name");
DataStream<Tuple2<Boolean, Row>> res3 = tEnv.toRetractStream(table3, Row.class);
// res3.print().name("Aggregate Functions Print").setParallelism(1);
DataStream<Row> ds4 = env.addSource(new RichSourceFunction<Row>() {
@Override
public void run(SourceContext<Row> ctx) throws Exception {
Row row1 = new Row(2);
row1.setField(0,"a");
row1.setField(1,1);
Row row2 = new Row(2);
row2.setField(0,"a");
row2.setField(1,2);
Row row3 = new Row(2);
row3.setField(0,"a");
row3.setField(1,100);
ctx.collect(row1);
ctx.collect(row2);
ctx.collect(row3);
}
@Override
public void cancel() {
}
}).returns(Types.ROW(Types.STRING, Types.INT));
tEnv.createTemporaryView("t4",ds4,"name,cnt");
tEnv.registerFunction("test4",new TestTableAggregateFunction());
Table table4 = tEnv.sqlQuery("select * from t4");
Table table5 = table4.groupBy("name")
.flatAggregate("test4(cnt) as (v,rank)")
.select("name,v,rank");
DataStream<Tuple2<Boolean, Row>> res4 = tEnv.toRetractStream(table5, Row.class);
res4.print().name("Aggregate Functions Print").setParallelism(1);
env.execute("test udf");
}
}