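User-defined functions are widely used in SQL because of their expressive power. Flink also supports user-defined functions in its Table/SQL API, and according to the roadmap presented at Flink Forward Asia 2019, later Flink releases will add support for user-defined functions written in Python and compatibility with Hive user-defined functions. This article briefly introduces UDF support in Flink and how it is implemented.
Supported Types of User-Defined Functions
Flink supports three kinds of user-defined functions: UDF, UDTF, and UDAF. Each is implemented by extending a subclass of UserDefinedFunction and defining the required methods. Flink then uses code generation to produce helper classes that call those methods.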
Function type | Purpose | Base class |
---|---|---|
UDF | Transforms field values: one record in, one record out | ScalarFunction |
UDTF | One record in, one or more records out | TableFunction |
UDAF | One or more records in, one record out | AggregateFunction |
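The three input/output shapes in the table can be illustrated without Flink at all. The sketch below uses plain Java collections only; the class and method names (FunctionShapes, scalar, table, aggregate) are made up for illustration and are not part of the Flink API.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Plain-Java illustration of the three shapes; none of these names are Flink APIs.
final class FunctionShapes {

    // UDF shape: one value in, one value out.
    static String scalar(String s) {
        return s.toUpperCase();
    }

    // UDTF shape: one value in, several values out.
    static List<String> table(String s) {
        return new ArrayList<>(Arrays.asList(s.split(",")));
    }

    // UDAF shape: several values in, one value out.
    static long aggregate(List<Long> values) {
        long sum = 0;
        for (long v : values) {
            sum += v;
        }
        return sum;
    }

    public static void main(String[] args) {
        System.out.println(scalar("flink"));
        System.out.println(table("a,b,c"));
        System.out.println(aggregate(Arrays.asList(1L, 2L, 3L)));
    }
}
```

In Flink itself these shapes correspond to the eval method of a ScalarFunction, the eval/collect pair of a TableFunction, and the accumulate/getValue pair of an AggregateFunction.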
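Implementing User-Defined Functions
UDF Implementation
Flink supports UDFs through the ScalarFunction class. An application only needs to extend ScalarFunction and define an eval method in it.
Using it is straightforward. Under the hood, code generation produces the code that calls the eval method. In ProcessOperator, the function field holds the generated function, and its processElement method calls our eval method to perform the transformation.
A simple UDF looks like this: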
import org.apache.flink.table.functions.ScalarFunction;

public class ToUpperCase extends ScalarFunction {
    public String eval(String s) {
        return s.toUpperCase();
    }
}
The call stack is shown in the figure below:
The source of the function produced by code generation is as follows:
public class DataStreamCalcRule$19 extends org.apache.flink.streaming.api.functions.ProcessFunction { final com.cyq.learning.udf.function.ToUpperCase function_com$cyq$learning$udf$function$ToUpperCase$078037202ef3ddcbc4f27e57d323d0fd; final org.apache.flink.types.Row out = new org.apache.flink.types.Row(4); public DataStreamCalcRule$19() throws Exception { function_com$cyq$learning$udf$function$ToUpperCase$078037202ef3ddcbc4f27e57d323d0fd = (com.cyq.learning.udf.function.ToUpperCase) org.apache.flink.table.utils.EncodingUtils.decodeStringToObject( "rO0ABXNyACljb20uY3lxLmxlYXJuaW5nLnVkZi5mdW5jdGlvbi5Ub1VwcGVyQ2FzZViZ_w3c2xj-AgAAeHIAL29yZy5hcGFjaGUuZmxpbmsudGFibGUuZnVuY3Rpb25zLlNjYWxhckZ1bmN0aW9uygzW306qntsCAAB4cgA0b3JnLmFwYWNoZS5mbGluay50YWJsZS5mdW5jdGlvbnMuVXNlckRlZmluZWRGdW5jdGlvboP22EvVTaRLAgAAeHA", org.apache.flink.table.functions.UserDefinedFunction.class); } @Override public void open(org.apache.flink.configuration.Configuration parameters) throws Exception { function_com$cyq$learning$udf$function$ToUpperCase$078037202ef3ddcbc4f27e57d323d0fd.open(new org.apache.flink.table.functions.FunctionContext(getRuntimeContext())); } @Override public void processElement(Object _in1, org.apache.flink.streaming.api.functions.ProcessFunction.Context ctx, org.apache.flink.util.Collector c) throws Exception { org.apache.flink.types.Row in1 = (org.apache.flink.types.Row) _in1; boolean isNull$18 = (java.lang.Long) in1.getField(2) == null; long result$17; if (isNull$18) { result$17 = -1L; } else { result$17 = (java.lang.Long) in1.getField(2); } boolean isNull$11 = (java.lang.String) in1.getField(0) == null; java.lang.String result$10; if (isNull$11) { result$10 = ""; } else { result$10 = (java.lang.String) (java.lang.String) in1.getField(0); } boolean isNull$13 = (java.lang.String) in1.getField(1) == null; java.lang.String result$12; if (isNull$13) { result$12 = ""; } else { result$12 = (java.lang.String) (java.lang.String) in1.getField(1); } if (isNull$11) { out.setField(0, 
null); } else { out.setField(0, result$10); } if (isNull$13) { out.setField(1, null); } else { out.setField(1, result$12); } java.lang.String result$14 = function_com$cyq$learning$udf$function$ToUpperCase$078037202ef3ddcbc4f27e57d323d0fd.eval( isNull$13 ? null : (java.lang.String) result$12); boolean isNull$16 = result$14 == null; java.lang.String result$15; if (isNull$16) { result$15 = ""; } else { result$15 = (java.lang.String) result$14; } if (isNull$16) { out.setField(2, null); } else { out.setField(2, result$15); } if (isNull$18) { out.setField(3, null); } else { out.setField(3, result$17); } c.collect(out); } @Override public void close() throws Exception { function_com$cyq$learning$udf$function$ToUpperCase$078037202ef3ddcbc4f27e57d323d0fd.close(); } }
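The generated class above repeats one pattern for every value: an isNull flag paired with a non-null placeholder, with null restored only when the value is passed to eval or written to the output row. The following is a hand-written, Flink-free sketch of that pattern. The class name NullSafeCallSketch, the Object[] row representation, and the simplified field layout are assumptions for illustration, not what Flink generates.

```java
// Hand-written stand-in for the generated wrapper; names and layout are illustrative only.
final class NullSafeCallSketch {

    // Plays the role of the user's eval method (null-tolerant in this sketch).
    static String eval(String s) {
        return s == null ? null : s.toUpperCase();
    }

    // Mirrors the generated pattern: an isNull flag plus a non-null placeholder per value.
    static Object[] process(Object[] in) {
        Object[] out = new Object[in.length + 1];

        // Read field 1 into a flag and a placeholder value.
        boolean isNullField = in[1] == null;
        String field = isNullField ? "" : (String) in[1];

        // Forward the input fields unchanged.
        System.arraycopy(in, 0, out, 0, in.length);

        // Call eval with null restored when the flag is set.
        String result = eval(isNullField ? null : field);
        boolean isNullResult = result == null;

        // Append the function result as an extra output field.
        out[in.length] = isNullResult ? null : result;
        return out;
    }

    public static void main(String[] args) {
        Object[] row = process(new Object[] {"id-1", "abc"});
        System.out.println(java.util.Arrays.toString(row));
    }
}
```

Note that the generated code passes null to eval when the input field is null, so an eval that calls s.toUpperCase() directly, like ToUpperCase above, would throw a NullPointerException on null input unless it checks for null itself.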
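UDTF Implementation
A Flink UDTF is implemented by extending TableFunction. The key to splitting one row into several rows is TableFunction's collect method, which calls the collect method of the underlying collector to send a record to the downstream operator. When a record arrives, eval can split it and send each resulting record downstream, so that one input produces multiple outputs.
A simple UDTF looks like this: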
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.table.functions.TableFunction;

public class SplitFun extends TableFunction<Tuple2<String, Integer>> {
    private String separator = ",";

    public SplitFun(String separator) {
        this.separator = separator;
    }

    public void eval(String str) {
        // Emit one (token, length) pair per token.
        for (String s : str.split(separator)) {
            collect(new Tuple2<String, Integer>(s, s.length()));
        }
    }
}
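To make the collect mechanism concrete, here is a Flink-free miniature of the relationship between a table function and its collector. MiniTableFunction, MiniSplit, and MiniSplitDemo are hypothetical names invented for this sketch, and a java.util.function.Consumer stands in for Flink's Collector.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Consumer;

// Hypothetical miniature of the TableFunction/collector relationship; not Flink code.
abstract class MiniTableFunction<T> {
    private Consumer<T> collector;

    // The runtime installs a collector before calling eval.
    void setCollector(Consumer<T> collector) {
        this.collector = collector;
    }

    // Each call forwards one output record downstream.
    protected void collect(T record) {
        collector.accept(record);
    }
}

// One input string becomes one output record per comma-separated token.
final class MiniSplit extends MiniTableFunction<String> {
    void eval(String str) {
        for (String s : str.split(",")) {
            collect(s);
        }
    }
}

final class MiniSplitDemo {
    // Wires a list in as the "downstream operator" and runs eval once.
    static List<String> run(String input) {
        List<String> downstream = new ArrayList<>();
        MiniSplit fun = new MiniSplit();
        fun.setCollector(downstream::add);
        fun.eval(input);
        return downstream;
    }

    public static void main(String[] args) {
        System.out.println(run("a,bb,ccc"));
    }
}
```

The point of the indirection is that eval returns nothing: the number of output rows is decided purely by how many times collect is called.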
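The call stack is shown in the figure below:
The generated code consists of two classes, TableFunctionCollector$42 and DataStreamCorrelateRule$28. TableFunctionCollector is the helper collector that sends records to the downstream operator, while DataStreamCorrelateRule calls the TableFunction (SplitFun) from its processElement method.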
- The source of DataStreamCorrelateRule$28 is as follows:
public class DataStreamCorrelateRule$28 extends org.apache.flink.streaming.api.functions.ProcessFunction { final org.apache.flink.table.runtime.TableFunctionCollector instance_org$apache$flink$table$runtime$TableFunctionCollector$22; final com.cyq.learning.udf.function.SplitFun function_com$cyq$learning$udf$function$SplitFun$f9ab41da087dbe9d50d61478f4bc638d; final org.apache.flink.types.Row out = new org.apache.flink.types.Row(6); public DataStreamCorrelateRule$28() throws Exception { function_com$cyq$learning$udf$function$SplitFun$f9ab41da087dbe9d50d61478f4bc638d = (com.cyq.learning.udf.function.SplitFun) org.apache.flink.table.utils.EncodingUtils.decodeStringToObject( "rO0ABXNyACZjb20uY3lxLmxlYXJuaW5nLnVkZi5mdW5jdGlvbi5TcGxpdEZ1bl5QHqGTdoFOAgABTAAJc2VwYXJhdG9ydAASTGphdmEvbGFuZy9TdHJpbmc7eHIALm9yZy5hcGFjaGUuZmxpbmsudGFibGUuZnVuY3Rpb25zLlRhYmxlRnVuY3Rpb27qA9IDIKYg-gIAAUwACWNvbGxlY3RvcnQAIUxvcmcvYXBhY2hlL2ZsaW5rL3V0aWwvQ29sbGVjdG9yO3hyADRvcmcuYXBhY2hlLmZsaW5rLnRhYmxlLmZ1bmN0aW9ucy5Vc2VyRGVmaW5lZEZ1bmN0aW9ug_bYS9VNpEsCAAB4cHB0AAFA", org.apache.flink.table.functions.UserDefinedFunction.class); } public DataStreamCorrelateRule$28(org.apache.flink.table.runtime.TableFunctionCollector arg0) throws Exception { this(); instance_org$apache$flink$table$runtime$TableFunctionCollector$22 = arg0; } @Override public void open(org.apache.flink.configuration.Configuration parameters) throws Exception { function_com$cyq$learning$udf$function$SplitFun$f9ab41da087dbe9d50d61478f4bc638d.open(new org.apache.flink.table.functions.FunctionContext(getRuntimeContext())); } @Override public void processElement(Object _in1, org.apache.flink.streaming.api.functions.ProcessFunction.Context ctx, org.apache.flink.util.Collector c) throws Exception { org.apache.flink.types.Row in1 = (org.apache.flink.types.Row) _in1; boolean isNull$15 = (java.lang.Long) in1.getField(2) == null; long result$14; if (isNull$15) { result$14 = -1L; } else { result$14 = (java.lang.Long) in1.getField(2); } boolean 
isNull$11 = (java.lang.String) in1.getField(0) == null; java.lang.String result$10; if (isNull$11) { result$10 = ""; } else { result$10 = (java.lang.String) (java.lang.String) in1.getField(0); } boolean isNull$13 = (java.lang.String) in1.getField(1) == null; java.lang.String result$12; if (isNull$13) { result$12 = ""; } else { result$12 = (java.lang.String) (java.lang.String) in1.getField(1); } boolean isNull$17 = (java.lang.Long) in1.getField(3) == null; long result$16; if (isNull$17) { result$16 = -1L; } else { result$16 = (long) org.apache.calcite.runtime.SqlFunctions.toLong((java.lang.Long) in1.getField(3)); } function_com$cyq$learning$udf$function$SplitFun$f9ab41da087dbe9d50d61478f4bc638d.setCollector(instance_org$apache$flink$table$runtime$TableFunctionCollector$22); java.lang.String result$25; boolean isNull$26; if (false) { result$25 = ""; isNull$26 = true; } else { boolean isNull$24 = (java.lang.String) in1.getField(1) == null; java.lang.String result$23; if (isNull$24) { result$23 = ""; } else { result$23 = (java.lang.String) (java.lang.String) in1.getField(1); } result$25 = result$23; isNull$26 = isNull$24; } function_com$cyq$learning$udf$function$SplitFun$f9ab41da087dbe9d50d61478f4bc638d.eval(isNull$26 ? 
null : (java.lang.String) result$25); boolean hasOutput = instance_org$apache$flink$table$runtime$TableFunctionCollector$22.isCollected(); if (!hasOutput) { if (isNull$11) { out.setField(0, null); } else { out.setField(0, result$10); } if (isNull$13) { out.setField(1, null); } else { out.setField(1, result$12); } if (isNull$15) { out.setField(2, null); } else { out.setField(2, result$14); } java.lang.Long result$27; if (isNull$17) { result$27 = null; } else { result$27 = result$16; } if (isNull$17) { out.setField(3, null); } else { out.setField(3, result$27); } if (true) { out.setField(4, null); } else { out.setField(4, ""); } if (true) { out.setField(5, null); } else { out.setField(5, -1); } c.collect(out); } } @Override public void close() throws Exception { function_com$cyq$learning$udf$function$SplitFun$f9ab41da087dbe9d50d61478f4bc638d.close(); } }
- The source of TableFunctionCollector$42 is as follows:
public class TableFunctionCollector$42 extends org.apache.flink.table.runtime.TableFunctionCollector { final org.apache.flink.types.Row out = new org.apache.flink.types.Row(6); public TableFunctionCollector$42() throws Exception { } @Override public void open(org.apache.flink.configuration.Configuration parameters) throws Exception { } @Override public void collect(Object record) throws Exception { super.collect(record); org.apache.flink.types.Row in1 = (org.apache.flink.types.Row) getInput(); org.apache.flink.api.java.tuple.Tuple2 in2 = (org.apache.flink.api.java.tuple.Tuple2) record; boolean isNull$34 = (java.lang.Long) in1.getField(2) == null; long result$33; if (isNull$34) { result$33 = -1L; } else { result$33 = (java.lang.Long) in1.getField(2); } boolean isNull$30 = (java.lang.String) in1.getField(0) == null; java.lang.String result$29; if (isNull$30) { result$29 = ""; } else { result$29 = (java.lang.String) (java.lang.String) in1.getField(0); } boolean isNull$32 = (java.lang.String) in1.getField(1) == null; java.lang.String result$31; if (isNull$32) { result$31 = ""; } else { result$31 = (java.lang.String) (java.lang.String) in1.getField(1); } boolean isNull$36 = (java.lang.Long) in1.getField(3) == null; long result$35; if (isNull$36) { result$35 = -1L; } else { result$35 = (long) org.apache.calcite.runtime.SqlFunctions.toLong((java.lang.Long) in1.getField(3)); } if (isNull$30) { out.setField(0, null); } else { out.setField(0, result$29); } if (isNull$32) { out.setField(1, null); } else { out.setField(1, result$31); } if (isNull$34) { out.setField(2, null); } else { out.setField(2, result$33); } java.lang.Long result$41; if (isNull$36) { result$41 = null; } else { result$41 = result$35; } if (isNull$36) { out.setField(3, null); } else { out.setField(3, result$41); } boolean isNull$38 = (java.lang.String) in2.f0 == null; java.lang.String result$37; if (isNull$38) { result$37 = ""; } else { result$37 = (java.lang.String) (java.lang.String) in2.f0; } if 
(isNull$38) { out.setField(4, null); } else { out.setField(4, result$37); } boolean isNull$40 = (java.lang.Integer) in2.f1 == null; int result$39; if (isNull$40) { result$39 = -1; } else { result$39 = (java.lang.Integer) in2.f1; } if (isNull$40) { out.setField(5, null); } else { out.setField(5, result$39); } getCollector().collect(out); } @Override public void close() throws Exception { } }
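Stripped of the null-handling boilerplate, the two generated classes cooperate as follows: the collector appends the function's output fields to the current input row and emits the joined row, and when the function emits nothing, the caller still emits the input row with the function's columns set to null. The sketch below reproduces that behaviour in plain Java. JoinCollectorSketch, the Object[] row representation, and the hard-coded comma split are assumptions made for illustration, not Flink classes.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Illustrative stand-in for the generated collector; names and row layout are assumptions.
final class JoinCollectorSketch {
    private final List<Object[]> downstream = new ArrayList<>();
    private Object[] input;
    private boolean collected;

    // Remember the current input row and reset the "has output" flag.
    void setInput(Object[] input) {
        this.input = input;
        this.collected = false;
    }

    boolean isCollected() {
        return collected;
    }

    // Append the function's output fields to the current input row and emit the result.
    void collect(Object[] functionOutput) {
        collected = true;
        Object[] out = Arrays.copyOf(input, input.length + functionOutput.length);
        System.arraycopy(functionOutput, 0, out, input.length, functionOutput.length);
        downstream.add(out);
    }

    // Plays the role of processElement: run the split, then pad with nulls if nothing came out.
    static List<Object[]> process(Object[] row, int splitField) {
        JoinCollectorSketch collector = new JoinCollectorSketch();
        collector.setInput(row);
        String value = (String) row[splitField];
        if (value != null) {
            for (String s : value.split(",")) {
                collector.collect(new Object[] {s, s.length()});
            }
        }
        if (!collector.isCollected()) {
            // No output for this row: emit it once with two null function columns.
            collector.downstream.add(Arrays.copyOf(row, row.length + 2));
        }
        return collector.downstream;
    }

    public static void main(String[] args) {
        for (Object[] r : process(new Object[] {"k1", "a,bb"}, 1)) {
            System.out.println(Arrays.toString(r));
        }
    }
}
```

The isCollected check is what lets a row with no function output survive in the result instead of being dropped.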
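UDAF Implementation
A Flink UDAF is implemented by extending AggregateFunction and implementing the methods that carry out the aggregation. A UDAF requires more methods than the other function types, and which methods are needed depends on the scenario, so it is more complex to implement. See UserDefinedFunction for reference.
A simple UDAF example might look like this: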
public class AggFunctionTest {
    public static void main(String[] args) throws Exception {
        StreamTableEnvironment streamTableEnvironment = EnvironmentUtil.getStreamTableEnvironment();

        TableSchema tableSchema = new TableSchema(
                new String[]{"product_id", "number", "price", "time"},
                new TypeInformation[]{Types.STRING, Types.LONG, Types.INT, Types.SQL_TIMESTAMP});

        Map<String, String> kafkaProps = new HashMap<>();
        kafkaProps.put("bootstrap.servers", BROKER_SERVERS);
        FlinkKafkaConsumerBase flinkKafkaConsumerBase =
                SourceUtil.createKafkaSourceStream(TOPIC, kafkaProps, tableSchema);
        DataStream<Row> stream = streamTableEnvironment.execEnv().addSource(flinkKafkaConsumerBase);

        streamTableEnvironment.registerFunction("wAvg", new WeightedAvg());
        Table streamTable = streamTableEnvironment.fromDataStream(stream, "product_id,number,price,rowtime.rowtime");
        streamTableEnvironment.registerTable("product", streamTable);

        Table result = streamTableEnvironment.sqlQuery(
                "SELECT product_id, wAvg(number,price) as avgPrice FROM product group by product_id");

        final TableSchema tableSchemaResult = new TableSchema(
                new String[]{"product_id", "avg"},
                new TypeInformation[]{Types.STRING, Types.LONG});
        final TypeInformation<Row> typeInfoResult = tableSchemaResult.toRowType();

        // toRetractStream is needed here because the result table is not an append-only table.
        DataStream ds = streamTableEnvironment.toRetractStream(result, typeInfoResult);
        ds.print();

        streamTableEnvironment.execEnv().execute("Flink Agg Function Test");
    }
}
The accumulator structure and the aggregate function it uses are as follows:
public class WeightedAvgAccum {
    public long sum = 0;
    public int count = 0;
}
import java.util.Iterator;
import org.apache.flink.table.functions.AggregateFunction;

public class WeightedAvg extends AggregateFunction<Long, WeightedAvgAccum> {
    @Override
    public WeightedAvgAccum createAccumulator() {
        return new WeightedAvgAccum();
    }

    @Override
    public Long getValue(WeightedAvgAccum acc) {
        if (acc.count == 0) {
            return null;
        } else {
            return acc.sum / acc.count;
        }
    }

    public void accumulate(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum += iValue * iWeight;
        acc.count += iWeight;
    }

    public void retract(WeightedAvgAccum acc, long iValue, int iWeight) {
        acc.sum -= iValue * iWeight;
        acc.count -= iWeight;
    }

    public void merge(WeightedAvgAccum acc, Iterable<WeightedAvgAccum> it) {
        Iterator<WeightedAvgAccum> iter = it.iterator();
        while (iter.hasNext()) {
            WeightedAvgAccum a = iter.next();
            acc.count += a.count;
            acc.sum += a.sum;
        }
    }

    public void resetAccumulator(WeightedAvgAccum acc) {
        acc.count = 0;
        acc.sum = 0L;
    }
}
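To see how the accumulator evolves, the same arithmetic can be stepped through by hand without Flink. The sketch below copies the logic of WeightedAvg into static methods; WeightedAvgWalkthrough and its nested Acc class are names invented for this walkthrough.

```java
import java.util.Collections;

// Flink-free copy of the accumulator logic, for walking through the lifecycle by hand.
final class WeightedAvgWalkthrough {

    static final class Acc {
        long sum = 0;
        int count = 0;
    }

    // Add one weighted value to the accumulator.
    static void accumulate(Acc acc, long value, int weight) {
        acc.sum += value * weight;
        acc.count += weight;
    }

    // Undo a previous accumulate call with the same arguments.
    static void retract(Acc acc, long value, int weight) {
        acc.sum -= value * weight;
        acc.count -= weight;
    }

    // Fold other partial accumulators into acc.
    static void merge(Acc acc, Iterable<Acc> others) {
        for (Acc other : others) {
            acc.sum += other.sum;
            acc.count += other.count;
        }
    }

    // Integer division, so the result is truncated; null when nothing was accumulated.
    static Long getValue(Acc acc) {
        return acc.count == 0 ? null : Long.valueOf(acc.sum / acc.count);
    }

    public static void main(String[] args) {
        Acc acc = new Acc();
        accumulate(acc, 10L, 1);
        accumulate(acc, 20L, 3);
        System.out.println(getValue(acc)); // (10*1 + 20*3) / 4 = 70 / 4, truncated to 17

        retract(acc, 10L, 1);
        System.out.println(getValue(acc)); // 60 / 3 = 20

        Acc other = new Acc();
        accumulate(other, 40L, 1);
        merge(acc, Collections.singletonList(other));
        System.out.println(getValue(acc)); // 100 / 4 = 25
    }
}
```

Because getValue divides two integral values, the weighted average is truncated rather than rounded, which is worth keeping in mind when the weights are small.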
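Note: The SQL executed in AggFunctionTest is a global (non-windowed) aggregation. To emit its result as a stream we use toRetractStream, that is, retract mode. Because the aggregation is global, its result keeps changing as new records arrive, so toAppendStream cannot be used. For more on toRetractStream, see Table to Stream Conversion.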
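The difference between append and retract output can be shown with a small simulation. A retract stream carries a Boolean flag with each row: true adds a result and false withdraws a previously emitted one. The sketch below keeps a running per-key sum (a simplification chosen for brevity, not the weighted average from the example) and records the messages a retract stream would carry. RetractStreamSketch and its string message format are invented for illustration.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Simulates the messages of a retract stream for a running per-key sum; not Flink code.
final class RetractStreamSketch {

    static List<String> run(String[] keys, long[] values) {
        Map<String, Long> sums = new HashMap<>();
        List<String> out = new ArrayList<>();
        for (int i = 0; i < keys.length; i++) {
            Long prev = sums.get(keys[i]);
            long next = (prev == null ? 0L : prev) + values[i];
            if (prev != null) {
                // Withdraw the result emitted earlier for this key.
                out.add("(false, " + keys[i] + ", " + prev + ")");
            }
            // Emit the updated result.
            out.add("(true, " + keys[i] + ", " + next + ")");
            sums.put(keys[i], next);
        }
        return out;
    }

    public static void main(String[] args) {
        String[] keys = {"p1", "p1", "p2"};
        long[] values = {10L, 5L, 7L};
        for (String message : run(keys, values)) {
            System.out.println(message);
        }
    }
}
```

An append-only stream has no way to express the second message for p1, which is why a continuously updated aggregate needs retract mode.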
The aggregation is executed by the processElement method of GroupAggProcessFunction. Its core logic is as follows:
override def processElement(
    inputC: CRow,
    ctx: ProcessFunction[CRow, CRow]#Context,
    out: Collector[CRow]): Unit = {

  val currentTime = ctx.timerService().currentProcessingTime()
  // register state-cleanup timer
  processCleanupTimer(ctx, currentTime)

  val input = inputC.row

  // get accumulators and input counter
  var accumulators = state.value()
  var inputCnt = cntState.value()

  if (null == accumulators) {
    // Don't create a new accumulator for a retraction message. This
    // might happen if the retraction message is the first message for the
    // key or after a state clean up.
    if (!inputC.change) {
      return
    }
    // first accumulate message
    firstRow = true
    accumulators = function.createAccumulators()
  } else {
    firstRow = false
  }

  if (null == inputCnt) {
    inputCnt = 0L
  }

  // Set group keys value to the final output
  function.setForwardedFields(input, newRow.row)
  function.setForwardedFields(input, prevRow.row)

  // Set previous aggregate result to the prevRow
  function.setAggregationResults(accumulators, prevRow.row)

  // update aggregate result and set to the newRow
  if (inputC.change) {
    inputCnt += 1
    // accumulate input
    function.accumulate(accumulators, input)
    function.setAggregationResults(accumulators, newRow.row)
  } else {
    inputCnt -= 1
    // retract input
    function.retract(accumulators, input)
    function.setAggregationResults(accumulators, newRow.row)
  }

  if (inputCnt != 0) {
    // we aggregated at least one record for this key

    // update the state
    state.update(accumulators)
    cntState.update(inputCnt)

    // if this was not the first row
    if (!firstRow) {
      if (prevRow.row.equals(newRow.row) && !stateCleaningEnabled) {
        // newRow is the same as before and state cleaning is not enabled.
        // We emit nothing
        // If state cleaning is enabled, we have to emit messages to prevent too early
        // state eviction of downstream operators.
        return
      } else {
        // retract previous result
        if (generateRetraction) {
          out.collect(prevRow)
        }
      }
    }
    // emit the new result
    out.collect(newRow)

  } else {
    // we retracted the last record for this key
    // sent out a delete message
    out.collect(prevRow)
    // and clear all state
    state.clear()
    cntState.clear()
  }
}
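The code above shows that the aggregation relies mainly on function to update the accumulators and to fill in the output rows. This function is again produced by code generation, and its source is as follows: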
public final class NonWindowedAggregationHelper$17 extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations { final com.cyq.learning.udf.function.WeightedAvg function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881; private final org.apache.flink.table.runtime.aggregate.SingleElementIterable<com.cyq.learning.udf.function.WeightedAvgAccum> accIt0 = new org.apache.flink.table.runtime.aggregate.SingleElementIterable<com.cyq.learning.udf.function.WeightedAvgAccum>(); public NonWindowedAggregationHelper$17() throws Exception { function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881 = (com.cyq.learning.udf.function.WeightedAvg) org.apache.flink.table.utils.EncodingUtils.decodeStringToObject( "rO0ABXNyACljb20uY3lxLmxlYXJuaW5nLnVkZi5mdW5jdGlvbi5XZWlnaHRlZEF2Z7o-ESHvqA7TAgAAeHIAMm9yZy5hcGFjaGUuZmxpbmsudGFibGUuZnVuY3Rpb25zLkFnZ3JlZ2F0ZUZ1bmN0aW9uYPYkANq5VRgCAAB4cgA0b3JnLmFwYWNoZS5mbGluay50YWJsZS5mdW5jdGlvbnMuVXNlckRlZmluZWRGdW5jdGlvboP22EvVTaRLAgAAeHA", org.apache.flink.table.functions.UserDefinedFunction.class); } public final void open( org.apache.flink.api.common.functions.RuntimeContext ctx) throws Exception { function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881.open(new org.apache.flink.table.functions.FunctionContext(ctx)); } public final void setAggregationResults( org.apache.flink.types.Row accs, org.apache.flink.types.Row output) throws Exception { org.apache.flink.table.functions.AggregateFunction baseClass0 = (org.apache.flink.table.functions.AggregateFunction) function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881; com.cyq.learning.udf.function.WeightedAvgAccum acc0 = (com.cyq.learning.udf.function.WeightedAvgAccum) accs.getField(0); output.setField( 1, baseClass0.getValue(acc0)); } public final void accumulate( org.apache.flink.types.Row accs, org.apache.flink.types.Row input) throws Exception { 
com.cyq.learning.udf.function.WeightedAvgAccum acc0 = (com.cyq.learning.udf.function.WeightedAvgAccum) accs.getField(0); function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881.accumulate(acc0 , (java.lang.Long) input.getField(1), (java.lang.Integer) input.getField(2)); } public final void retract( org.apache.flink.types.Row accs, org.apache.flink.types.Row input) throws Exception { } public final org.apache.flink.types.Row createAccumulators() throws Exception { org.apache.flink.types.Row accs = new org.apache.flink.types.Row(1); com.cyq.learning.udf.function.WeightedAvgAccum acc0 = (com.cyq.learning.udf.function.WeightedAvgAccum) function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881.createAccumulator(); accs.setField( 0, acc0); return accs; } public final void setForwardedFields( org.apache.flink.types.Row input, org.apache.flink.types.Row output) { output.setField( 0, input.getField(0)); } public final org.apache.flink.types.Row createOutputRow() { return new org.apache.flink.types.Row(2); } public final org.apache.flink.types.Row mergeAccumulatorsPair( org.apache.flink.types.Row a, org.apache.flink.types.Row b) { return a; } public final void resetAccumulator( org.apache.flink.types.Row accs) throws Exception { } public final void cleanup() throws Exception { } public final void close() throws Exception { function_com$cyq$learning$udf$function$WeightedAvg$776eb73449c07a5f410bff97fc0bb881.close(); } }
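The source code above is available in the Learning project.
From the analysis above and the generated code, readers can get a rough picture of how user-defined functions are executed. The hardest part to understand is the code generation itself. Because of its performance advantages in SQL execution and the flexibility it offers, code generation is widely used in both Spark and Flink. I will analyze code generation in a later article.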