Repository: flink
Updated Branches:
  refs/heads/release-1.2 f2240eb93 -> 07865aaf8


[FLINK-5224] [table] Improve UDTF: emit rows directly instead of buffering them

This closes #3118.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/07865aaf
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/07865aaf
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/07865aaf

Branch: refs/heads/release-1.2
Commit: 07865aaf8f583f5dff79acab503b5a46bdf77179
Parents: f2240eb
Author: Jark Wu <wuchong...@alibaba-inc.com>
Authored: Fri Jan 13 21:53:49 2017 +0800
Committer: twalthr <twal...@apache.org>
Committed: Thu Jan 26 11:02:16 2017 +0100

----------------------------------------------------------------------
 .../flink/table/codegen/CodeGenerator.scala     |  94 ++++++++++
 .../codegen/calls/TableFunctionCallGen.scala    |   1 -
 .../apache/flink/table/codegen/generated.scala  |  16 ++
 .../flink/table/functions/TableFunction.scala   |  20 +-
 .../flink/table/plan/nodes/FlinkCorrelate.scala | 188 +++++++++++++------
 .../plan/nodes/dataset/DataSetCorrelate.scala   |  31 +--
 .../nodes/datastream/DataStreamCorrelate.scala  |  31 +--
 .../table/runtime/CorrelateFlatMapRunner.scala  |  65 +++++++
 .../table/runtime/TableFunctionCollector.scala  |  80 ++++++++
 9 files changed, 404 insertions(+), 122 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index 13fe4c3..d49d7a0 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -39,6 +39,7 @@ import org.apache.flink.table.codegen.Indenter.toISC
 import org.apache.flink.table.codegen.calls.FunctionGenerator
 import org.apache.flink.table.codegen.calls.ScalarOperators._
 import org.apache.flink.table.functions.UserDefinedFunction
+import org.apache.flink.table.runtime.TableFunctionCollector
 import org.apache.flink.table.typeutils.TypeConverter
 import org.apache.flink.table.typeutils.TypeCheckUtils._
 
@@ -129,6 +130,10 @@ class CodeGenerator(
   // (inputTerm, index) -> expr
   private val reusableInputUnboxingExprs = mutable.Map[(String, Int), 
GeneratedExpression]()
 
+  // set of constructor statements that will be added only once
+  // we use a LinkedHashSet to keep the insertion order
+  // each entry is a (parameter list, constructor body) pair registered by
+  // addReusableConstructor and rendered as a constructor by reuseConstructorCode
+  private val reusableConstructorStatements = mutable.LinkedHashSet[(String, 
String)]()
+
   /**
     * @return code block of statements that need to be placed in the member 
area of the Function
     *         (e.g. member variables and their initialization)
@@ -160,6 +165,20 @@ class CodeGenerator(
   }
 
   /**
+    * Renders one additional constructor for every entry registered through
+    * `addReusableConstructor`. Each rendered constructor first delegates to the
+    * no-argument constructor via `this()` and then runs its registered body.
+    *
+    * @param className name of the generated class, used as the constructor name
+    * @return code block of constructor statements for the Function
+    */
+  def reuseConstructorCode(className: String): String = {
+    val constructors = reusableConstructorStatements.map { entry =>
+      val (params, body) = entry
+      s"""
+        |public $className($params) throws Exception {
+        |  this();
+        |  $body
+        |}
+        |""".stripMargin
+    }
+    constructors.mkString("", "\n", "\n")
+  }
+
+  /**
     * @return term of the (casted and possibly boxed) first input
     */
   var input1Term = "in1"
@@ -257,6 +276,8 @@ class CodeGenerator(
           ${reuseInitCode()}
         }
 
+        ${reuseConstructorCode(funcName)}
+
         @Override
         public ${samHeader._1} throws Exception {
           ${samHeader._2.mkString("\n")}
@@ -326,6 +347,52 @@ class CodeGenerator(
   }
 
   /**
+    * Generates a [[TableFunctionCollector]] that can be passed to Java compiler.
+    *
+    * The generated `collect(Object)` first calls `super.collect(record)`. It then binds the
+    * current input (`getInput()`) to `input1Term` and the collected record to `input2Term`,
+    * both cast to their boxed types, before running `bodyCode`. The generated `close()`
+    * does nothing.
+    *
+    * @param name Class name of the table function collector. Need not be unique (a unique
+    *             name is derived from it via `newName`) but has to be a valid Java class
+    *             identifier.
+    * @param bodyCode body code for the collector method
+    * @param collectedType The type information of the element collected by the collector
+    * @return instance of GeneratedCollector
+    */
+  def generateTableFunctionCollector(
+      name: String,
+      bodyCode: String,
+      collectedType: TypeInformation[Any])
+    : GeneratedCollector = {
+
+    val className = newName(name)
+    val input1TypeClass = boxedTypeTermForTypeInfo(input1)
+    val input2TypeClass = boxedTypeTermForTypeInfo(collectedType)
+
+    val funcCode = j"""
+      public class $className extends 
${classOf[TableFunctionCollector[_]].getCanonicalName} {
+
+        ${reuseMemberCode()}
+
+        public $className() throws Exception {
+          ${reuseInitCode()}
+        }
+
+        @Override
+        public void collect(Object record) throws Exception {
+          super.collect(record);
+          $input1TypeClass $input1Term = ($input1TypeClass) getInput();
+          $input2TypeClass $input2Term = ($input2TypeClass) record;
+          ${reuseInputUnboxingCode()}
+          $bodyCode
+        }
+
+        @Override
+        public void close() {
+        }
+      }
+    """.stripMargin
+
+    GeneratedCollector(className, funcCode)
+  }
+
+  /**
     * Generates an expression that converts the first input (and second input) 
into the given type.
     * If two inputs are converted, the second input is appended. If objects or 
variables can
     * be reused, they will be added to reusable code sections internally. The 
evaluation result
@@ -1415,6 +1482,33 @@ class CodeGenerator(
     fieldTerm
   }
 
+
+  /**
+    * Adds a reusable constructor statement with the given parameter types.
+    *
+    * For every parameter type a `transient` member field initialized to `null` is added to
+    * the member area. A single constructor is then registered that takes one argument per
+    * type and assigns each argument to its field; `reuseConstructorCode` renders it.
+    *
+    * NOTE(review): `Class.getCanonicalName` returns null for local/anonymous classes, so such
+    * parameter types are presumably unsupported here.
+    *
+    * @param parameterTypes The parameter types to construct the function, in argument order
+    * @return the generated member field terms, in the same order as `parameterTypes`
+    */
+  def addReusableConstructor(parameterTypes: Class[_]*): Array[String] = {
+    // one (field term, parameter declaration, assignment) triple per parameter, in order
+    val members = parameterTypes.zipWithIndex.map { case (paramType, index) =>
+      val classQualifier = paramType.getCanonicalName
+      // '.' is replaced so the qualified class name can be part of a Java identifier
+      val fieldTerm = newName(s"instance_${classQualifier.replace('.', '$')}")
+      reusableMemberStatements.add(s"transient $classQualifier $fieldTerm = null;")
+      (fieldTerm, s"$classQualifier arg$index", s"$fieldTerm = arg$index;")
+    }
+    val (fieldTerms, parameters, body) = members.unzip3
+
+    reusableConstructorStatements.add(
+      (parameters.mkString(","), body.mkString("", "\n", "\n")))
+
+    fieldTerms.toArray
+  }
+
   /**
     * Adds a reusable array to the member area of the generated [[Function]].
     */

http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala
index 50c569f..6e44f55 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/TableFunctionCallGen.scala
@@ -69,7 +69,6 @@ class TableFunctionCallGen(
+    // No clear() call is generated: the table function now emits its rows straight to the
+    // collector installed with setCollector instead of buffering them.
     val functionCallCode =
       s"""
         |${parameters.map(_.code).mkString("\n")}
-        |$functionReference.clear();
         |$functionReference.eval(${parameters.map(_.resultTerm).mkString(", 
")});
         |""".stripMargin
 

http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala
index 0d60dc1..b4c293d 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala
@@ -40,4 +40,20 @@ object GeneratedExpression {
   val NO_CODE = ""
 }
 
+/**
+  * Describes a generated [[org.apache.flink.api.common.functions.Function]].
+  *
+  * @param name class name of the generated Function.
+  * @param returnType the type information of the result type
+  * @param code code of the generated Function.
+  * @tparam T type of function
+  */
 case class GeneratedFunction[T](name: String, returnType: 
TypeInformation[Any], code: String)
+
+/**
+  * Describes a generated [[org.apache.flink.util.Collector]], e.g. the
+  * [[org.apache.flink.table.runtime.TableFunctionCollector]] subclass produced by
+  * `CodeGenerator.generateTableFunctionCollector`.
+  *
+  * @param name class name of the generated Collector.
+  * @param code Java source code of the generated Collector, compiled at runtime.
+  */
+case class GeneratedCollector(name: String, code: String)

http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala
index 653793e..d4c5021 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/TableFunction.scala
@@ -18,10 +18,9 @@
 
 package org.apache.flink.table.functions
 
-import java.util
-
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.table.expressions.{Expression, TableFunctionCall}
+import org.apache.flink.util.Collector
 
 /**
   * Base class for a user-defined table function (UDTF). A user-defined table 
functions works on
@@ -99,27 +98,28 @@ abstract class TableFunction[T] extends UserDefinedFunction 
{
 
   // 
----------------------------------------------------------------------------------------------
 
-  private val rows: util.ArrayList[T] = new util.ArrayList[T]()
-
   /**
     * Emit an output row.
+    *
+    * The row is handed directly to the collector installed through `setCollector`; it is
+    * not buffered. The collector field is null until `setCollector` has been called, so
+    * calling this before that fails with a NullPointerException.
     *
     * @param row the output row
     */
   protected def collect(row: T): Unit = {
-    // cache rows for now, maybe immediately process them further
-    rows.add(row)
+    collector.collect(row)
   }
 
+  // 
----------------------------------------------------------------------------------------------
+
   /**
-    * Internal use. Get an iterator of the buffered rows.
+    * The code-generated collector used to emit rows.
     */
-  def getRowsIterator = rows.iterator()
+  private var collector: Collector[T] = _
 
   /**
-    * Internal use. Clear buffered rows.
+    * Internal use. Sets the collector that subsequent `collect` calls forward rows to.
+    * The code-generated correlate function calls this before each evaluation of the
+    * table function.
     */
-  def clear() = rows.clear()
+  private[flink] final def setCollector(collector: Collector[T]): Unit = {
+    this.collector = collector
+  }
 
   // 
----------------------------------------------------------------------------------------------
 

http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala
index fc69493..c986602 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkCorrelate.scala
@@ -22,11 +22,11 @@ import org.apache.calcite.rex.{RexCall, RexNode}
 import org.apache.calcite.sql.SemiJoinType
 import org.apache.flink.api.common.functions.FlatMapFunction
 import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.table.codegen.{CodeGenerator, GeneratedExpression, 
GeneratedFunction}
+import org.apache.flink.table.codegen.{CodeGenerator, GeneratedCollector, 
GeneratedExpression, GeneratedFunction}
 import org.apache.flink.table.codegen.CodeGenUtils.primitiveDefaultValue
 import org.apache.flink.table.codegen.GeneratedExpression.{ALWAYS_NULL, 
NO_CODE}
 import org.apache.flink.table.functions.utils.TableSqlFunction
-import org.apache.flink.table.runtime.FlatMapRunner
+import org.apache.flink.table.runtime.{CorrelateFlatMapRunner, 
TableFunctionCollector}
 import org.apache.flink.table.typeutils.TypeConverter._
 import org.apache.flink.table.api.{TableConfig, TableException}
 
@@ -37,15 +37,22 @@ import scala.collection.JavaConverters._
   */
 trait FlinkCorrelate {
 
-  private[flink] def functionBody(
-      generator: CodeGenerator,
+  /**
+    * Creates the [[CorrelateFlatMapRunner]] to execute the join of input table
+    * and user-defined table function.
+    *
+    * Two classes are generated:
+    *  - a flat map function that invokes the table function (see `generateFunction`);
+    *  - a [[TableFunctionCollector]] that turns each row emitted by the table function into
+    *    an output row (see `generateCollector`).
+    * The returned runner compiles both classes when it is opened.
+    *
+    * @param config table config used for code generation
+    * @param inputTypeInfo type of the left input rows
+    * @param udtfTypeInfo type of the rows emitted by the table function
+    * @param rowType output row type of the correlate
+    * @param joinType INNER or LEFT; any other join type makes `generateFunction` throw a
+    *                 [[TableException]]
+    * @param rexCall the table function call
+    * @param condition optional predicate evaluated on every row emitted by the table function
+    * @param expectedType optional expected result type, passed to `determineReturnType`
+    * @param pojoFieldMapping field mapping of the table function's POJO return type, if any
+    * @param ruleDescription passed to `CodeGenerator.generateFunction` for the flat map
+    *                        function (presumably used as its class name — confirm)
+    * @return runner holding the generated code of both classes
+    */
+  private[flink] def correlateMapFunction(
+      config: TableConfig,
+      inputTypeInfo: TypeInformation[Any],
       udtfTypeInfo: TypeInformation[Any],
       rowType: RelDataType,
+      joinType: SemiJoinType,
       rexCall: RexCall,
       condition: Option[RexNode],
-      config: TableConfig,
-      joinType: SemiJoinType,
-      expectedType: Option[TypeInformation[Any]]): String = {
+      expectedType: Option[TypeInformation[Any]],
+      pojoFieldMapping: Option[Array[Int]], // udtf return type pojo field mapping
+      ruleDescription: String)
+    : CorrelateFlatMapRunner[Any, Any] = {
 
     val returnType = determineReturnType(
       rowType,
@@ -53,24 +60,72 @@ trait FlinkCorrelate {
       config.getNullCheck,
       config.getEfficientTypeUsage)
 
-    val (input1AccessExprs, input2AccessExprs) = 
generator.generateCorrelateAccessExprs
+    val flatMap = generateFunction(
+      config,
+      inputTypeInfo,
+      udtfTypeInfo,
+      returnType,
+      rowType,
+      joinType,
+      rexCall,
+      pojoFieldMapping,
+      ruleDescription)
+
+    val collector = generateCollector(
+      config,
+      inputTypeInfo,
+      udtfTypeInfo,
+      returnType,
+      rowType,
+      condition,
+      pojoFieldMapping)
+
+    new CorrelateFlatMapRunner[Any, Any](
+      flatMap.name,
+      flatMap.code,
+      collector.name,
+      collector.code,
+      flatMap.returnType)
+
+  }
+
+  /**
+    * Generates the flat map function to run the user-defined table function.
+    *
+    * The generated class receives the [[TableFunctionCollector]] through an additional
+    * constructor (registered with `addReusableConstructor`). It installs that collector on
+    * the table function via `setCollector` and then evaluates the call.
+    *
+    * Join types:
+    *  - LEFT: if the collector reports that nothing was collected, a row made of the input
+    *    fields and null table function fields (see `input2NullExprs`) is emitted directly.
+    *  - INNER: needs no extra code.
+    *  - Any other join type is rejected with a [[TableException]].
+    */
+  private def generateFunction(
+      config: TableConfig,
+      inputTypeInfo: TypeInformation[Any],
+      udtfTypeInfo: TypeInformation[Any],
+      returnType: TypeInformation[Any],
+      rowType: RelDataType,
+      joinType: SemiJoinType,
+      rexCall: RexCall,
+      pojoFieldMapping: Option[Array[Int]],
+      ruleDescription: String)
+    : GeneratedFunction[FlatMapFunction[Any, Any]] = {
+
+    val functionGenerator = new CodeGenerator(
+      config,
+      false,
+      inputTypeInfo,
+      Some(udtfTypeInfo),
+      None,
+      pojoFieldMapping)
 
-    val call = generator.generateExpression(rexCall)
+    val (input1AccessExprs, input2AccessExprs) = 
functionGenerator.generateCorrelateAccessExprs
+
+    // The generated function gets a constructor taking the TableFunctionCollector. Only one
+    // parameter type is registered, so `.head` is the member field holding that collector.
+    val collectorTerm = functionGenerator
+      .addReusableConstructor(classOf[TableFunctionCollector[_]])
+      .head
+
+    val call = functionGenerator.generateExpression(rexCall)
     var body =
       s"""
-         |${call.code}
-         |java.util.Iterator iter = ${call.resultTerm}.getRowsIterator();
-       """.stripMargin
+        |${call.resultTerm}.setCollector($collectorTerm);
+        |${call.code}
+        |""".stripMargin
 
-    if (joinType == SemiJoinType.INNER) {
-      // cross join
-      body +=
-        s"""
-           |if (!iter.hasNext()) {
-           |  return;
-           |}
-        """.stripMargin
-    } else if (joinType == SemiJoinType.LEFT) {
+    if (joinType == SemiJoinType.LEFT) {
       // left outer join
 
       // in case of left outer join and the returned row of table function is 
empty,
@@ -82,63 +137,78 @@ trait FlinkCorrelate {
           NO_CODE,
           x.resultType)
       }
-      val outerResultExpr = generator.generateResultExpression(
+      val outerResultExpr = functionGenerator.generateResultExpression(
         input1AccessExprs ++ input2NullExprs, returnType, 
rowType.getFieldNames.asScala)
       body +=
         s"""
-           |if (!iter.hasNext()) {
-           |  ${outerResultExpr.code}
-           |  
${generator.collectorTerm}.collect(${outerResultExpr.resultTerm});
-           |  return;
-           |}
-        """.stripMargin
-    } else {
+          |boolean hasOutput = $collectorTerm.isCollected();
+          |if (!hasOutput) {
+          |  ${outerResultExpr.code}
+          |  
${functionGenerator.collectorTerm}.collect(${outerResultExpr.resultTerm});
+          |}
+          |""".stripMargin
+    } else if (joinType != SemiJoinType.INNER) {
       throw TableException(s"Unsupported SemiJoinType: $joinType for correlate 
join.")
     }
 
+    functionGenerator.generateFunction(
+      ruleDescription,
+      classOf[FlatMapFunction[Any, Any]],
+      body,
+      returnType)
+  }
+
+  /**
+    * Generates table function collector.
+    *
+    * The generated collector binds the current left input row and each row emitted by the
+    * table function. It applies the optional `condition` to the emitted row and forwards
+    * the joined result row to `getCollector()`.
+    *
+    * @param config table config used for code generation
+    * @param inputTypeInfo type of the left input rows
+    * @param udtfTypeInfo type of the rows emitted by the table function
+    * @param returnType type of the joined result rows
+    * @param rowType output row type of the correlate (provides the result field names)
+    * @param condition optional predicate on the rows emitted by the table function
+    * @param pojoFieldMapping field mapping of the table function's POJO return type, if any
+    * @return the generated collector (class name and code)
+    */
+  private[flink] def generateCollector(
+      config: TableConfig,
+      inputTypeInfo: TypeInformation[Any],
+      udtfTypeInfo: TypeInformation[Any],
+      returnType: TypeInformation[Any],
+      rowType: RelDataType,
+      condition: Option[RexNode],
+      pojoFieldMapping: Option[Array[Int]])
+    : GeneratedCollector = {
+
+    val generator = new CodeGenerator(
+      config,
+      false,
+      inputTypeInfo,
+      Some(udtfTypeInfo),
+      None,
+      pojoFieldMapping)
+
+    val (input1AccessExprs, input2AccessExprs) = 
generator.generateCorrelateAccessExprs
+
     val crossResultExpr = generator.generateResultExpression(
       input1AccessExprs ++ input2AccessExprs,
       returnType,
       rowType.getFieldNames.asScala)
 
-    val projection = if (condition.isEmpty) {
+    val collectorCode = if (condition.isEmpty) {
       s"""
-         |${crossResultExpr.code}
-         |${generator.collectorTerm}.collect(${crossResultExpr.resultTerm});
-       """.stripMargin
+        |${crossResultExpr.code}
+        |getCollector().collect(${crossResultExpr.resultTerm});
+        |""".stripMargin
     } else {
-      val filterGenerator = new CodeGenerator(config, false, udtfTypeInfo)
+      // The filter reads the table function row as its FIRST input, so it must get the same
+      // POJO field mapping that `generator` uses for that row (there as the second input).
+      // Otherwise field accesses on a POJO return type are not mapped like the projection.
+      val filterGenerator = new CodeGenerator(
+        config,
+        false,
+        udtfTypeInfo,
+        None,
+        pojoFieldMapping)
+      // the generated collector binds the emitted record to input2Term, so read it from there
       filterGenerator.input1Term = filterGenerator.input2Term
       val filterCondition = filterGenerator.generateExpression(condition.get)
+      // NOTE(review): only the filter's input unboxing code is emitted below; any reusable
+      // member/init code it registers is not part of the generated collector class.
       s"""
-         |${filterGenerator.reuseInputUnboxingCode()}
-         |${filterCondition.code}
-         |if (${filterCondition.resultTerm}) {
-         |  ${crossResultExpr.code}
-         |  ${generator.collectorTerm}.collect(${crossResultExpr.resultTerm});
-         |}
-         |""".stripMargin
+        |${filterGenerator.reuseInputUnboxingCode()}
+        |${filterCondition.code}
+        |if (${filterCondition.resultTerm}) {
+        |  ${crossResultExpr.code}
+        |  getCollector().collect(${crossResultExpr.resultTerm});
+        |}
+        |""".stripMargin
     }
 
-    val outputTypeClass = udtfTypeInfo.getTypeClass.getCanonicalName
-    body +=
-      s"""
-         |while (iter.hasNext()) {
-         |  $outputTypeClass ${generator.input2Term} = ($outputTypeClass) 
iter.next();
-         |  $projection
-         |}
-       """.stripMargin
-    body
-  }
-
-  private[flink] def correlateMapFunction(
-      genFunction: GeneratedFunction[FlatMapFunction[Any, Any]])
-    : FlatMapRunner[Any, Any] = {
-
-    new FlatMapRunner[Any, Any](
-      genFunction.name,
-      genFunction.code,
-      genFunction.returnType)
+    generator.generateTableFunctionCollector(
+      "TableFunctionCollector",
+      collectorCode,
+      udtfTypeInfo)
   }
 
   private[flink] def selectToString(rowType: RelDataType): String = {

http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
index fa1afc3..5a75e5d 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
@@ -24,11 +24,9 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery
 import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
 import org.apache.calcite.rex.{RexCall, RexNode}
 import org.apache.calcite.sql.SemiJoinType
-import org.apache.flink.api.common.functions.FlatMapFunction
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.table.api.BatchTableEnvironment
-import org.apache.flink.table.codegen.CodeGenerator
 import org.apache.flink.table.functions.utils.TableSqlFunction
 import org.apache.flink.table.plan.nodes.FlinkCorrelate
 import org.apache.flink.table.typeutils.TypeConverter._
@@ -93,11 +91,6 @@ class DataSetCorrelate(
     : DataSet[Any] = {
 
     val config = tableEnv.getConfig
-    val returnType = determineReturnType(
-      getRowType,
-      expectedType,
-      config.getNullCheck,
-      config.getEfficientTypeUsage)
 
     // we do not need to specify input type
     val inputDS = inputNode.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
@@ -108,31 +101,17 @@ class DataSetCorrelate(
     val pojoFieldMapping = sqlFunction.getPojoFieldMapping
     val udtfTypeInfo = 
sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
 
-    val generator = new CodeGenerator(
+    // correlateMapFunction generates both the flat map function and the table function
+    // collector; the returned runner compiles them in its open() method
+    val mapFunc = correlateMapFunction(
       config,
-      false,
       inputDS.getType,
-      Some(udtfTypeInfo),
-      None,
-      Some(pojoFieldMapping))
-
-    val body = functionBody(
-      generator,
       udtfTypeInfo,
       getRowType,
+      joinType,
       rexCall,
       condition,
-      config,
-      joinType,
-      expectedType)
-
-    val genFunction = generator.generateFunction(
-      ruleDescription,
-      classOf[FlatMapFunction[Any, Any]],
-      body,
-      returnType)
-
-    val mapFunc = correlateMapFunction(genFunction)
+      expectedType,
+      Some(pojoFieldMapping),
+      ruleDescription)
 
     inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, 
relRowType))
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala
index a2d167b..bd65954 100644
--- 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala
@@ -23,9 +23,7 @@ import org.apache.calcite.rel.logical.LogicalTableFunctionScan
 import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
 import org.apache.calcite.rex.{RexCall, RexNode}
 import org.apache.calcite.sql.SemiJoinType
-import org.apache.flink.api.common.functions.FlatMapFunction
 import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.table.codegen.CodeGenerator
 import org.apache.flink.table.functions.utils.TableSqlFunction
 import org.apache.flink.table.plan.nodes.FlinkCorrelate
 import org.apache.flink.table.typeutils.TypeConverter._
@@ -87,11 +85,6 @@ class DataStreamCorrelate(
     : DataStream[Any] = {
 
     val config = tableEnv.getConfig
-    val returnType = determineReturnType(
-      getRowType,
-      expectedType,
-      config.getNullCheck,
-      config.getEfficientTypeUsage)
 
     // we do not need to specify input type
     val inputDS = 
inputNode.asInstanceOf[DataStreamRel].translateToPlan(tableEnv)
@@ -102,31 +95,17 @@ class DataStreamCorrelate(
     val pojoFieldMapping = sqlFunction.getPojoFieldMapping
     val udtfTypeInfo = 
sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
 
-    val generator = new CodeGenerator(
+    // correlateMapFunction generates both the flat map function and the table function
+    // collector; the returned runner compiles them in its open() method
+    val mapFunc = correlateMapFunction(
       config,
-      false,
       inputDS.getType,
-      Some(udtfTypeInfo),
-      None,
-      Some(pojoFieldMapping))
-
-    val body = functionBody(
-      generator,
       udtfTypeInfo,
       getRowType,
+      joinType,
       rexCall,
       condition,
-      config,
-      joinType,
-      expectedType)
-
-    val genFunction = generator.generateFunction(
-      ruleDescription,
-      classOf[FlatMapFunction[Any, Any]],
-      body,
-      returnType)
-
-    val mapFunc = correlateMapFunction(genFunction)
+      expectedType,
+      Some(pojoFieldMapping),
+      ruleDescription)
 
     inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, 
relRowType))
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala
new file mode 100644
index 0000000..4e803da
--- /dev/null
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/CorrelateFlatMapRunner.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime
+
+import org.apache.flink.api.common.functions.{FlatMapFunction, 
RichFlatMapFunction}
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.typeutils.ResultTypeQueryable
+import org.apache.flink.configuration.Configuration
+import org.apache.flink.table.codegen.Compiler
+import org.apache.flink.util.Collector
+import org.slf4j.{Logger, LoggerFactory}
+
+/**
+  * A [[RichFlatMapFunction]] that compiles and runs the code generated for a correlate
+  * (a join of a table with a user-defined table function).
+  *
+  * In `open`, the generated [[TableFunctionCollector]] is compiled and instantiated first.
+  * Then the generated [[FlatMapFunction]] is compiled and instantiated through its
+  * constructor that takes that collector.
+  *
+  * @param flatMapName class name of the generated flat map function
+  * @param flatMapCode Java source code of the generated flat map function
+  * @param collectorName class name of the generated table function collector
+  * @param collectorCode Java source code of the generated table function collector
+  * @param returnType type information of the produced records
+  * @tparam IN type of the input records
+  * @tparam OUT type of the output records
+  */
+class CorrelateFlatMapRunner[IN, OUT](
+    flatMapName: String,
+    flatMapCode: String,
+    collectorName: String,
+    collectorCode: String,
+    @transient returnType: TypeInformation[OUT])
+  extends RichFlatMapFunction[IN, OUT]
+  with ResultTypeQueryable[OUT]
+  with Compiler[Any] {
+
+  val LOG: Logger = LoggerFactory.getLogger(this.getClass)
+
+  private var function: FlatMapFunction[IN, OUT] = _
+  private var collector: TableFunctionCollector[_] = _
+
+  override def open(parameters: Configuration): Unit = {
+    LOG.debug(s"Compiling TableFunctionCollector: $collectorName \n\n Code:\n$collectorCode")
+    val clazz = compile(
+      getRuntimeContext.getUserCodeClassLoader,
+      collectorName,
+      collectorCode)
+    LOG.debug("Instantiating TableFunctionCollector.")
+    collector = clazz.newInstance().asInstanceOf[TableFunctionCollector[_]]
+
+    LOG.debug(s"Compiling FlatMapFunction: $flatMapName \n\n Code:\n$flatMapCode")
+    val flatMapClazz = compile(
+      getRuntimeContext.getUserCodeClassLoader,
+      flatMapName,
+      flatMapCode)
+    // the generated flat map function receives the collector through its constructor
+    val constructor = flatMapClazz.getConstructor(classOf[TableFunctionCollector[_]])
+    LOG.debug("Instantiating FlatMapFunction.")
+    function = constructor.newInstance(collector).asInstanceOf[FlatMapFunction[IN, OUT]]
+  }
+
+  override def flatMap(in: IN, out: Collector[OUT]): Unit = {
+    // Point the shared collector at this call's output and left row, and clear its
+    // "collected" flag so the generated LEFT-join code can detect an empty UDTF result.
+    collector.setCollector(out)
+    collector.setInput(in)
+    collector.reset()
+    function.flatMap(in, out)
+  }
+
+  override def getProducedType: TypeInformation[OUT] = returnType
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/07865aaf/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala
----------------------------------------------------------------------
diff --git 
a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala
 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala
new file mode 100644
index 0000000..c9cca47
--- /dev/null
+++ 
b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/TableFunctionCollector.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.runtime
+
+import org.apache.flink.util.Collector
+
+/**
+  * Base class of the code-generated collectors that receive the rows emitted by a
+  * [[org.apache.flink.table.functions.TableFunction]].
+  *
+  * It holds three pieces of state:
+  *  - the current row of the left input;
+  *  - the downstream [[Collector]] that receives the final joined rows;
+  *  - a flag telling whether anything has been collected since the last `reset()`.
+  */
+abstract class TableFunctionCollector[T] extends Collector[T] {
+
+  /** Current row of the left input. */
+  private var leftRow: Any = _
+
+  /** Collector that receives the final (joined) rows. */
+  private var downstream: Collector[_] = _
+
+  /** Whether `collect` has been invoked since the last `reset()`. */
+  private var emitted: Boolean = false
+
+  /**
+    * Sets the input row from the left table, which is cross joined with the results of the
+    * table function.
+    */
+  def setInput(input: Any): Unit = leftRow = input
+
+  /**
+    * Returns the input row from the left table, which is cross joined with the results of
+    * the table function.
+    */
+  def getInput: Any = leftRow
+
+  /** Sets the collector used to emit the final rows. */
+  def setCollector(collector: Collector[_]): Unit = downstream = collector
+
+  /** Returns the collector used to emit the final rows. */
+  def getCollector: Collector[_] = downstream
+
+  /** Clears the flag reported by `isCollected`. */
+  def reset(): Unit = emitted = false
+
+  /**
+    * Whether `collect` has been called since the last `reset()`.
+    *
+    * @return true if at least one record has been collected
+    */
+  def isCollected: Boolean = emitted
+
+  /**
+    * Marks this collector as having received output. Generated subclasses call this via
+    * `super.collect(record)` before processing the record.
+    */
+  override def collect(record: T): Unit = emitted = true
+}
+
+

Reply via email to