Repository: flink Updated Branches: refs/heads/release-1.2 f2240eb93 -> 07865aaf8
[FLINK-5224] [table] Improve UDTF: emit rows directly instead of buffering them This closes #3118. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/07865aaf Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/07865aaf Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/07865aaf Branch: refs/heads/release-1.2 Commit: 07865aaf8f583f5dff79acab503b5a46bdf77179 Parents: f2240eb Author: Jark Wu <wuchong...@alibaba-inc.com> Authored: Fri Jan 13 21:53:49 2017 +0800 Committer: twalthr <twal...@apache.org> Committed: Thu Jan 26 11:02:16 2017 +0100 ---------------------------------------------------------------------- .../flink/table/codegen/CodeGenerator.scala | 94 ++++++++++ .../codegen/calls/TableFunctionCallGen.scala | 1 - .../apache/flink/table/codegen/generated.scala | 16 ++ .../flink/table/functions/TableFunction.scala | 20 +- .../flink/table/plan/nodes/FlinkCorrelate.scala | 188 +++++++++++++------ .../plan/nodes/dataset/DataSetCorrelate.scala | 31 +-- .../nodes/datastream/DataStreamCorrelate.scala | 31 +-- .../table/runtime/CorrelateFlatMapRunner.scala | 65 +++++++ .../table/runtime/TableFunctionCollector.scala | 80 ++++++++ 9 files changed, 404 insertions(+), 122 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index 13fe4c3..d49d7a0 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala 
@@ -39,6 +39,7 @@ import org.apache.flink.table.codegen.Indenter.toISC import org.apache.flink.table.codegen.calls.FunctionGenerator import org.apache.flink.table.codegen.calls.ScalarOperators._ import org.apache.flink.table.functions.UserDefinedFunction +import org.apache.flink.table.runtime.TableFunctionCollector import org.apache.flink.table.typeutils.TypeConverter import org.apache.flink.table.typeutils.TypeCheckUtils._ @@ -129,6 +130,10 @@ class CodeGenerator( // (inputTerm, index) -> expr private val reusableInputUnboxingExprs = mutable.Map[(String, Int), GeneratedExpression]() + // set of constructor statements that will be added only once + // we use a LinkedHashSet to keep the insertion order + private val reusableConstructorStatements = mutable.LinkedHashSet[(String, String)]() + /** * @return code block of statements that need to be placed in the member area of the Function * (e.g. member variables and their initialization) @@ -160,6 +165,20 @@ class CodeGenerator( } /** + * @return code block of constructor statements for the Function + */ + def reuseConstructorCode(className: String): String = { + reusableConstructorStatements.map { case (params, body) => + s""" + |public $className($params) throws Exception { + | this(); + | $body + |} + |""".stripMargin + }.mkString("", "\n", "\n") + } + + /** * @return term of the (casted and possibly boxed) first input */ var input1Term = "in1" @@ -257,6 +276,8 @@ class CodeGenerator( ${reuseInitCode()} } + ${reuseConstructorCode(funcName)} + @Override public ${samHeader._1} throws Exception { ${samHeader._2.mkString("\n")} @@ -326,6 +347,52 @@ class CodeGenerator( } /** + * Generates a [[TableFunctionCollector]] that can be passed to Java compiler. + * + * @param name Class name of the table function collector. Must not be unique but has to be a + * valid Java class identifier. 
+ * @param bodyCode body code for the collector method + * @param collectedType The type information of the element collected by the collector + * @return instance of GeneratedCollector + */ + def generateTableFunctionCollector( + name: String, + bodyCode: String, + collectedType: TypeInformation[Any]) + : GeneratedCollector = { + + val className = newName(name) + val input1TypeClass = boxedTypeTermForTypeInfo(input1) + val input2TypeClass = boxedTypeTermForTypeInfo(collectedType) + + val funcCode = j""" + public class $className extends ${classOf[TableFunctionCollector[_]].getCanonicalName} { + + ${reuseMemberCode()} + + public $className() throws Exception { + ${reuseInitCode()} + } + + @Override + public void collect(Object record) throws Exception { + super.collect(record); + $input1TypeClass $input1Term = ($input1TypeClass) getInput(); + $input2TypeClass $input2Term = ($input2TypeClass) record; + ${reuseInputUnboxingCode()} + $bodyCode + } + + @Override + public void close() { + } + } + """.stripMargin + + GeneratedCollector(className, funcCode) + } + + /** * Generates an expression that converts the first input (and second input) into the given type. * If two inputs are converted, the second input is appended. If objects or variables can * be reused, they will be added to reusable code sections internally. The evaluation result @@ -1415,6 +1482,33 @@ class CodeGenerator( fieldTerm } + + /** + * Adds a reusable constructor statement with the given parameter types. 
+ * + * @param parameterTypes The parameter types to construct the function + * @return member variable terms + */ + def addReusableConstructor(parameterTypes: Class[_]*): Array[String] = { + val parameters = mutable.ListBuffer[String]() + val fieldTerms = mutable.ListBuffer[String]() + val body = mutable.ListBuffer[String]() + + parameterTypes.zipWithIndex.foreach { case (t, index) => + val classQualifier = t.getCanonicalName + val fieldTerm = newName(s"instance_${classQualifier.replace('.', '$')}") + val field = s"transient $classQualifier $fieldTerm = null;" + reusableMemberStatements.add(field) + fieldTerms += fieldTerm + parameters += s"$classQualifier arg$index" + body += s"$fieldTerm = arg$index;" + } + + reusableConstructorStatements.add((parameters.mkString(","), body.mkString("", "\n", "\n"))) + + fieldTerms.toArray + } + /** * Adds a reusable array to the member area of the generated [[Function]]. */ http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala index 50c569f..6e44f55 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala @@ -69,7 +69,6 @@ class TableFunctionCallGen( val functionCallCode = s""" |${parameters.map(_.code).mkString("\n")} - |$functionReference.clear(); |$functionReference.eval(${parameters.map(_.resultTerm).mkString(", ")}); |""".stripMargin 
http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala index 0d60dc1..b4c293d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala @@ -40,4 +40,20 @@ object GeneratedExpression { val NO_CODE = "" } +/** + * Describes a generated [[org.apache.flink.api.common.functions.Function]] + * + * @param name class name of the generated Function. + * @param returnType the type information of the result type + * @param code code of the generated Function. + * @tparam T type of function + */ case class GeneratedFunction[T](name: String, returnType: TypeInformation[Any], code: String) + +/** + * Describes a generated [[org.apache.flink.util.Collector]]. + * + * @param name class name of the generated Collector. + * @param code code of the generated Collector. 
+ */ +case class GeneratedCollector(name: String, code: String) http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala index 653793e..d4c5021 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala @@ -18,10 +18,9 @@ package org.apache.flink.table.functions -import java.util - import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.expressions.{Expression, TableFunctionCall} +import org.apache.flink.util.Collector /** * Base class for a user-defined table function (UDTF). A user-defined table functions works on @@ -99,27 +98,28 @@ abstract class TableFunction[T] extends UserDefinedFunction { // ---------------------------------------------------------------------------------------------- - private val rows: util.ArrayList[T] = new util.ArrayList[T]() - /** * Emit an output row. * * @param row the output row */ protected def collect(row: T): Unit = { - // cache rows for now, maybe immediately process them further - rows.add(row) + collector.collect(row) } + // ---------------------------------------------------------------------------------------------- + /** - * Internal use. Get an iterator of the buffered rows. + * The code generated collector used to emit row. */ - def getRowsIterator = rows.iterator() + private var collector: Collector[T] = _ /** - * Internal use. Clear buffered rows. + * Internal use. Sets the current collector. 
*/ - def clear() = rows.clear() + private[flink] final def setCollector(collector: Collector[T]): Unit = { + this.collector = collector + } // ---------------------------------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala index fc69493..c986602 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala @@ -22,11 +22,11 @@ import org.apache.calcite.rex.{RexCall, RexNode} import org.apache.calcite.sql.SemiJoinType import org.apache.flink.api.common.functions.FlatMapFunction import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.table.codegen.{CodeGenerator, GeneratedExpression, GeneratedFunction} +import org.apache.flink.table.codegen.{CodeGenerator, GeneratedCollector, GeneratedExpression, GeneratedFunction} import org.apache.flink.table.codegen.CodeGenUtils.primitiveDefaultValue import org.apache.flink.table.codegen.GeneratedExpression.{ALWAYS_NULL, NO_CODE} import org.apache.flink.table.functions.utils.TableSqlFunction -import org.apache.flink.table.runtime.FlatMapRunner +import org.apache.flink.table.runtime.{CorrelateFlatMapRunner, TableFunctionCollector} import org.apache.flink.table.typeutils.TypeConverter._ import org.apache.flink.table.api.{TableConfig, TableException} @@ -37,15 +37,22 @@ import scala.collection.JavaConverters._ */ trait FlinkCorrelate { - private[flink] def functionBody( - generator: CodeGenerator, + /** + * 
Creates the [[CorrelateFlatMapRunner]] to execute the join of input table + * and user-defined table function. + */ + private[flink] def correlateMapFunction( + config: TableConfig, + inputTypeInfo: TypeInformation[Any], udtfTypeInfo: TypeInformation[Any], rowType: RelDataType, + joinType: SemiJoinType, rexCall: RexCall, condition: Option[RexNode], - config: TableConfig, - joinType: SemiJoinType, - expectedType: Option[TypeInformation[Any]]): String = { + expectedType: Option[TypeInformation[Any]], + pojoFieldMapping: Option[Array[Int]], // udtf return type pojo field mapping + ruleDescription: String) + : CorrelateFlatMapRunner[Any, Any] = { val returnType = determineReturnType( rowType, @@ -53,24 +60,72 @@ trait FlinkCorrelate { config.getNullCheck, config.getEfficientTypeUsage) - val (input1AccessExprs, input2AccessExprs) = generator.generateCorrelateAccessExprs + val flatMap = generateFunction( + config, + inputTypeInfo, + udtfTypeInfo, + returnType, + rowType, + joinType, + rexCall, + pojoFieldMapping, + ruleDescription) + + val collector = generateCollector( + config, + inputTypeInfo, + udtfTypeInfo, + returnType, + rowType, + condition, + pojoFieldMapping) + + new CorrelateFlatMapRunner[Any, Any]( + flatMap.name, + flatMap.code, + collector.name, + collector.code, + flatMap.returnType) + + } + + /** + * Generates the flat map function to run the user-defined table function. 
+ */ + private def generateFunction( + config: TableConfig, + inputTypeInfo: TypeInformation[Any], + udtfTypeInfo: TypeInformation[Any], + returnType: TypeInformation[Any], + rowType: RelDataType, + joinType: SemiJoinType, + rexCall: RexCall, + pojoFieldMapping: Option[Array[Int]], + ruleDescription: String) + : GeneratedFunction[FlatMapFunction[Any, Any]] = { + + val functionGenerator = new CodeGenerator( + config, + false, + inputTypeInfo, + Some(udtfTypeInfo), + None, + pojoFieldMapping) - val call = generator.generateExpression(rexCall) + val (input1AccessExprs, input2AccessExprs) = functionGenerator.generateCorrelateAccessExprs + + val collectorTerm = functionGenerator + .addReusableConstructor(classOf[TableFunctionCollector[_]]) + .head + + val call = functionGenerator.generateExpression(rexCall) var body = s""" - |${call.code} - |java.util.Iterator iter = ${call.resultTerm}.getRowsIterator(); - """.stripMargin + |${call.resultTerm}.setCollector($collectorTerm); + |${call.code} + |""".stripMargin - if (joinType == SemiJoinType.INNER) { - // cross join - body += - s""" - |if (!iter.hasNext()) { - | return; - |} - """.stripMargin - } else if (joinType == SemiJoinType.LEFT) { + if (joinType == SemiJoinType.LEFT) { // left outer join // in case of left outer join and the returned row of table function is empty, @@ -82,63 +137,78 @@ trait FlinkCorrelate { NO_CODE, x.resultType) } - val outerResultExpr = generator.generateResultExpression( + val outerResultExpr = functionGenerator.generateResultExpression( input1AccessExprs ++ input2NullExprs, returnType, rowType.getFieldNames.asScala) body += s""" - |if (!iter.hasNext()) { - | ${outerResultExpr.code} - | ${generator.collectorTerm}.collect(${outerResultExpr.resultTerm}); - | return; - |} - """.stripMargin - } else { + |boolean hasOutput = $collectorTerm.isCollected(); + |if (!hasOutput) { + | ${outerResultExpr.code} + | ${functionGenerator.collectorTerm}.collect(${outerResultExpr.resultTerm}); + |} + 
|""".stripMargin + } else if (joinType != SemiJoinType.INNER) { throw TableException(s"Unsupported SemiJoinType: $joinType for correlate join.") } + functionGenerator.generateFunction( + ruleDescription, + classOf[FlatMapFunction[Any, Any]], + body, + returnType) + } + + /** + * Generates table function collector. + */ + private[flink] def generateCollector( + config: TableConfig, + inputTypeInfo: TypeInformation[Any], + udtfTypeInfo: TypeInformation[Any], + returnType: TypeInformation[Any], + rowType: RelDataType, + condition: Option[RexNode], + pojoFieldMapping: Option[Array[Int]]) + : GeneratedCollector = { + + val generator = new CodeGenerator( + config, + false, + inputTypeInfo, + Some(udtfTypeInfo), + None, + pojoFieldMapping) + + val (input1AccessExprs, input2AccessExprs) = generator.generateCorrelateAccessExprs + val crossResultExpr = generator.generateResultExpression( input1AccessExprs ++ input2AccessExprs, returnType, rowType.getFieldNames.asScala) - val projection = if (condition.isEmpty) { + val collectorCode = if (condition.isEmpty) { s""" - |${crossResultExpr.code} - |${generator.collectorTerm}.collect(${crossResultExpr.resultTerm}); - """.stripMargin + |${crossResultExpr.code} + |getCollector().collect(${crossResultExpr.resultTerm}); + |""".stripMargin } else { val filterGenerator = new CodeGenerator(config, false, udtfTypeInfo) filterGenerator.input1Term = filterGenerator.input2Term val filterCondition = filterGenerator.generateExpression(condition.get) s""" - |${filterGenerator.reuseInputUnboxingCode()} - |${filterCondition.code} - |if (${filterCondition.resultTerm}) { - | ${crossResultExpr.code} - | ${generator.collectorTerm}.collect(${crossResultExpr.resultTerm}); - |} - |""".stripMargin + |${filterGenerator.reuseInputUnboxingCode()} + |${filterCondition.code} + |if (${filterCondition.resultTerm}) { + | ${crossResultExpr.code} + | getCollector().collect(${crossResultExpr.resultTerm}); + |} + |""".stripMargin } - val outputTypeClass = 
udtfTypeInfo.getTypeClass.getCanonicalName - body += - s""" - |while (iter.hasNext()) { - | $outputTypeClass ${generator.input2Term} = ($outputTypeClass) iter.next(); - | $projection - |} - """.stripMargin - body - } - - private[flink] def correlateMapFunction( - genFunction: GeneratedFunction[FlatMapFunction[Any, Any]]) - : FlatMapRunner[Any, Any] = { - - new FlatMapRunner[Any, Any]( - genFunction.name, - genFunction.code, - genFunction.returnType) + generator.generateTableFunctionCollector( + "TableFunctionCollector", + collectorCode, + udtfTypeInfo) } private[flink] def selectToString(rowType: RelDataType): String = { http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala index fa1afc3..5a75e5d 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala @@ -24,11 +24,9 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} import org.apache.calcite.rex.{RexCall, RexNode} import org.apache.calcite.sql.SemiJoinType -import org.apache.flink.api.common.functions.FlatMapFunction import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.DataSet import org.apache.flink.table.api.BatchTableEnvironment -import org.apache.flink.table.codegen.CodeGenerator import org.apache.flink.table.functions.utils.TableSqlFunction import org.apache.flink.table.plan.nodes.FlinkCorrelate import 
org.apache.flink.table.typeutils.TypeConverter._ @@ -93,11 +91,6 @@ class DataSetCorrelate( : DataSet[Any] = { val config = tableEnv.getConfig - val returnType = determineReturnType( - getRowType, - expectedType, - config.getNullCheck, - config.getEfficientTypeUsage) // we do not need to specify input type val inputDS = inputNode.asInstanceOf[DataSetRel].translateToPlan(tableEnv) @@ -108,31 +101,17 @@ class DataSetCorrelate( val pojoFieldMapping = sqlFunction.getPojoFieldMapping val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]] - val generator = new CodeGenerator( + val mapFunc = correlateMapFunction( config, - false, inputDS.getType, - Some(udtfTypeInfo), - None, - Some(pojoFieldMapping)) - - val body = functionBody( - generator, udtfTypeInfo, getRowType, + joinType, rexCall, condition, - config, - joinType, - expectedType) - - val genFunction = generator.generateFunction( - ruleDescription, - classOf[FlatMapFunction[Any, Any]], - body, - returnType) - - val mapFunc = correlateMapFunction(genFunction) + expectedType, + Some(pojoFieldMapping), + ruleDescription) inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType)) } http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala index a2d167b..bd65954 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala @@ -23,9 +23,7 @@ import 
org.apache.calcite.rel.logical.LogicalTableFunctionScan import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} import org.apache.calcite.rex.{RexCall, RexNode} import org.apache.calcite.sql.SemiJoinType -import org.apache.flink.api.common.functions.FlatMapFunction import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.table.codegen.CodeGenerator import org.apache.flink.table.functions.utils.TableSqlFunction import org.apache.flink.table.plan.nodes.FlinkCorrelate import org.apache.flink.table.typeutils.TypeConverter._ @@ -87,11 +85,6 @@ class DataStreamCorrelate( : DataStream[Any] = { val config = tableEnv.getConfig - val returnType = determineReturnType( - getRowType, - expectedType, - config.getNullCheck, - config.getEfficientTypeUsage) // we do not need to specify input type val inputDS = inputNode.asInstanceOf[DataStreamRel].translateToPlan(tableEnv) @@ -102,31 +95,17 @@ class DataStreamCorrelate( val pojoFieldMapping = sqlFunction.getPojoFieldMapping val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]] - val generator = new CodeGenerator( + val mapFunc = correlateMapFunction( config, - false, inputDS.getType, - Some(udtfTypeInfo), - None, - Some(pojoFieldMapping)) - - val body = functionBody( - generator, udtfTypeInfo, getRowType, + joinType, rexCall, condition, - config, - joinType, - expectedType) - - val genFunction = generator.generateFunction( - ruleDescription, - classOf[FlatMapFunction[Any, Any]], - body, - returnType) - - val mapFunc = correlateMapFunction(genFunction) + expectedType, + Some(pojoFieldMapping), + ruleDescription) inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType)) } http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala ---------------------------------------------------------------------- diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala new file mode 100644 index 0000000..4e803da --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.table.runtime + +import org.apache.flink.api.common.functions.{FlatMapFunction, RichFlatMapFunction} +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.java.typeutils.ResultTypeQueryable +import org.apache.flink.configuration.Configuration +import org.apache.flink.table.codegen.Compiler +import org.apache.flink.util.Collector +import org.slf4j.{Logger, LoggerFactory} + +class CorrelateFlatMapRunner[IN, OUT]( + flatMapName: String, + flatMapCode: String, + collectorName: String, + collectorCode: String, + @transient returnType: TypeInformation[OUT]) + extends RichFlatMapFunction[IN, OUT] + with ResultTypeQueryable[OUT] + with Compiler[Any] { + + val LOG: Logger = LoggerFactory.getLogger(this.getClass) + + private var function: FlatMapFunction[IN, OUT] = _ + private var collector: TableFunctionCollector[_] = _ + + override def open(parameters: Configuration): Unit = { + LOG.debug(s"Compiling TableFunctionCollector: $collectorName \n\n Code:\n$collectorCode") + val clazz = compile(getRuntimeContext.getUserCodeClassLoader, collectorName, collectorCode) + LOG.debug("Instantiating TableFunctionCollector.") + collector = clazz.newInstance().asInstanceOf[TableFunctionCollector[_]] + + LOG.debug(s"Compiling FlatMapFunction: $flatMapName \n\n Code:\n$flatMapCode") + val flatMapClazz = compile(getRuntimeContext.getUserCodeClassLoader, flatMapName, flatMapCode) + val constructor = flatMapClazz.getConstructor(classOf[TableFunctionCollector[_]]) + LOG.debug("Instantiating FlatMapFunction.") + function = constructor.newInstance(collector).asInstanceOf[FlatMapFunction[IN, OUT]] + } + + override def flatMap(in: IN, out: Collector[OUT]): Unit = { + collector.setCollector(out) + collector.setInput(in) + collector.reset() + function.flatMap(in, out) + } + + override def getProducedType: TypeInformation[OUT] = returnType +} 
http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala new file mode 100644 index 0000000..c9cca47 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.runtime + +import org.apache.flink.util.Collector + +/** + * The basic implementation of collector for [[org.apache.flink.table.functions.TableFunction]]. + */ +abstract class TableFunctionCollector[T] extends Collector[T] { + + private var input: Any = _ + private var collector: Collector[_] = _ + private var collected: Boolean = _ + + /** + * Sets the input row from left table, + * which will be used to cross join with the result of table function. 
+ */ + def setInput(input: Any): Unit = { + this.input = input + } + + /** + * Gets the input value from left table, + * which will be used to cross join with the result of table function. + */ + def getInput: Any = { + input + } + + /** + * Sets the current collector, which is used to emit the final row. + */ + def setCollector(collector: Collector[_]): Unit = { + this.collector = collector + } + + /** + * Gets the internal collector which is used to emit the final row. + */ + def getCollector: Collector[_] = { + this.collector + } + + /** + * Resets the flag to indicate whether [[collect(T)]] has been called. + */ + def reset(): Unit = { + collected = false + } + + /** + * Whether [[collect(T)]] has been called. + * + * @return True if [[collect(T)]] has been called. + */ + def isCollected: Boolean = collected + + override def collect(record: T): Unit = { + collected = true + } +} + +