[FLINK-6242] [table] Add code generation for DataSet Aggregates

This closes #3735.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/3b4542b8
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/3b4542b8
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/3b4542b8

Branch: refs/heads/master
Commit: 3b4542b8f0981f01e42c861bccbc67c8b3a20fdd
Parents: 4024aff
Author: shaoxuan-wang <wshaox...@gmail.com>
Authored: Tue Apr 18 21:45:49 2017 +0800
Committer: Fabian Hueske <fhue...@apache.org>
Committed: Fri Apr 21 21:28:54 2017 +0200

----------------------------------------------------------------------
 .../flink/table/codegen/CodeGenerator.scala     | 224 +++++++-----
 .../plan/nodes/dataset/DataSetAggregate.scala   |  11 +-
 .../nodes/dataset/DataSetWindowAggregate.scala  |  25 +-
 .../table/runtime/aggregate/AggregateUtil.scala | 359 +++++++++++--------
 .../runtime/aggregate/DataSetAggFunction.scala  | 109 ++----
 .../aggregate/DataSetFinalAggFunction.scala     | 121 ++-----
 .../aggregate/DataSetPreAggFunction.scala       |  93 ++---
 ...SetSessionWindowAggReduceGroupFunction.scala | 116 ++----
 ...aSetSessionWindowAggregatePreProcessor.scala | 145 +++-----
 ...tSlideTimeWindowAggReduceGroupFunction.scala | 177 +++------
 ...SetSlideWindowAggReduceCombineFunction.scala |  98 ++---
 ...taSetSlideWindowAggReduceGroupFunction.scala | 123 +++----
 ...umbleCountWindowAggReduceGroupFunction.scala |  89 ++---
 ...mbleTimeWindowAggReduceCombineFunction.scala |  88 ++---
 ...TumbleTimeWindowAggReduceGroupFunction.scala |  99 ++---
 .../aggregate/DataSetWindowAggMapFunction.scala |  64 ++--
 .../aggregate/GeneratedAggregations.scala       |  32 +-
 .../scala/batch/table/AggregationsITCase.scala  |   3 +-
 ...ProcessingOverRangeProcessFunctionTest.scala |   8 +-
 19 files changed, 878 insertions(+), 1106 deletions(-)
----------------------------------------------------------------------
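
The central change: instead of interpreting arrays of aggregate functions and
index mappings at runtime, every DataSet aggregation operator now delegates to
a single code-generated class. For orientation, a minimal sketch of the
contract that class fulfills, reconstructed from the method signatures
generated in CodeGenerator.scala below (the actual abstract class lives in
GeneratedAggregations.scala, whose diff is not part of this excerpt):

    import org.apache.flink.types.Row

    // Sketch only: method set reconstructed from the generated signatures.
    abstract class GeneratedAggregations {
      def setAggregationResults(accumulators: Row, output: Row): Unit
      def accumulate(accumulators: Row, input: Row): Unit
      def retract(accumulators: Row, input: Row): Unit
      def createAccumulators(): Row
      def setForwardedFields(input: Row, output: Row): Unit
      def setConstantFlags(output: Row): Unit
      def createOutputRow(): Row
      def mergeAccumulatorsPair(a: Row, b: Row): Row
      def resetAccumulator(accumulators: Row): Unit
    }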


http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index c6e3c9a..510a870 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -250,57 +250,88 @@ class CodeGenerator(
     * @param aggregates  All aggregate functions
    * @param aggFields   Indexes of the input fields for all aggregate functions
     * @param aggMapping  The mapping of aggregates to output fields
+    * @param partialResults A flag defining whether final or partial results (accumulators) are set
+    *                       to the output row.
     * @param fwdMapping  The mapping of input fields to output fields
+    * @param mergeMapping An optional mapping to specify the accumulators to merge. If not set, we
+    *                     assume that both rows have the accumulators at the same position.
+    * @param constantFlags An optional parameter to define where to set constant boolean flags in
+    *                      the output row.
     * @param outputArity The number of fields in the output row.
     *
     * @return A GeneratedAggregationsFunction
     */
   def generateAggregations(
-     name: String,
-     generator: CodeGenerator,
-     inputType: RelDataType,
-     aggregates: Array[AggregateFunction[_ <: Any]],
-     aggFields: Array[Array[Int]],
-     aggMapping: Array[Int],
-     fwdMapping: Array[(Int, Int)],
-     outputArity: Int)
+      name: String,
+      generator: CodeGenerator,
+      inputType: RelDataType,
+      aggregates: Array[AggregateFunction[_ <: Any]],
+      aggFields: Array[Array[Int]],
+      aggMapping: Array[Int],
+      partialResults: Boolean,
+      fwdMapping: Array[Int],
+      mergeMapping: Option[Array[Int]],
+      constantFlags: Option[Array[(Int, Boolean)]],
+      outputArity: Int)
   : GeneratedAggregationsFunction = {
 
-    def genSetAggregationResults(
-      accTypes: Array[String],
-      aggs: Array[String],
-      aggMapping: Array[Int]): String = {
+    // get unique function name
+    val funcName = newName(name)
+    // register UDAGGs
+    val aggs = aggregates.map(a => generator.addReusableFunction(a))
+    // get java types of accumulators
+    val accTypes = aggregates.map { a =>
+      a.getClass.getMethod("createAccumulator").getReturnType.getCanonicalName
+    }
+
+    // get java types of input fields
+    val javaTypes = inputType.getFieldList
+      .map(f => FlinkTypeFactory.toTypeInfo(f.getType))
+      .map(t => t.getTypeClass.getCanonicalName)
+    // get parameter lists for aggregation functions
+    val parameters = aggFields.map {inFields =>
+      val fields = for (f <- inFields) yield s"(${javaTypes(f)}) input.getField($f)"
+      fields.mkString(", ")
+    }
+
+    def genSetAggregationResults: String = {
 
       val sig: String =
         j"""
-            |  public void setAggregationResults(
-            |    org.apache.flink.types.Row accs,
-            |    org.apache.flink.types.Row output)""".stripMargin
+           |  public final void setAggregationResults(
+           |    org.apache.flink.types.Row accs,
+           |    org.apache.flink.types.Row output)""".stripMargin
 
       val setAggs: String = {
         for (i <- aggs.indices) yield
-          j"""
-             |    org.apache.flink.table.functions.AggregateFunction baseClass$i =
-             |      (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
-             |
-             |    output.setField(
-             |      ${aggMapping(i)},
-             |      baseClass$i.getValue((${accTypes(i)}) accs.getField($i)));""".stripMargin
+
+          if (partialResults) {
+            j"""
+               |    output.setField(
+               |      ${aggMapping(i)},
+               |      (${accTypes(i)}) accs.getField($i));""".stripMargin
+          } else {
+            j"""
+               |    org.apache.flink.table.functions.AggregateFunction baseClass$i =
+               |      (org.apache.flink.table.functions.AggregateFunction) ${aggs(i)};
+               |
+               |    output.setField(
+               |      ${aggMapping(i)},
+               |      baseClass$i.getValue((${accTypes(i)}) accs.getField($i)));""".stripMargin
+          }
       }.mkString("\n")
 
-      j"""$sig {
+      j"""
+         |$sig {
          |$setAggs
          |  }""".stripMargin
     }
 
-    def genAccumulate(
-     accTypes: Array[String],
-     aggs: Array[String],
-     parameters: Array[String]): String = {
+    def genAccumulate: String = {
 
       val sig: String =
         j"""
-            |  public void accumulate(
+            |  public final void accumulate(
             |    org.apache.flink.types.Row accs,
             |    org.apache.flink.types.Row input)""".stripMargin
 
@@ -317,14 +348,11 @@ class CodeGenerator(
          |  }""".stripMargin
     }
 
-    def genRetract(
-      accTypes: Array[String],
-      aggs: Array[String],
-      parameters: Array[String]): String = {
+    def genRetract: String = {
 
       val sig: String =
         j"""
-            |  public void retract(
+            |  public final void retract(
             |    org.apache.flink.types.Row accs,
             |    org.apache.flink.types.Row input)""".stripMargin
 
@@ -341,12 +369,11 @@ class CodeGenerator(
          |  }""".stripMargin
     }
 
-    def genCreateAccumulators(
-        aggs: Array[String]): String = {
+    def genCreateAccumulators: String = {
 
       val sig: String =
         j"""
-           |  public org.apache.flink.types.Row createAccumulators()
+           |  public final org.apache.flink.types.Row createAccumulators()
            |    """.stripMargin
       val init: String =
         j"""
@@ -373,22 +400,24 @@ class CodeGenerator(
          |  }""".stripMargin
     }
 
-    def genSetForwardedFields(
-        forwardMapping: Array[(Int, Int)]): String = {
+    def genSetForwardedFields: String = {
 
       val sig: String =
         j"""
-           |  public void setForwardedFields(
+           |  public final void setForwardedFields(
            |    org.apache.flink.types.Row input,
            |    org.apache.flink.types.Row output)
            |    """.stripMargin
+
       val forward: String = {
-        for (i <- forwardMapping.indices) yield
-          j"""
-             |    output.setField(
-             |      ${forwardMapping(i)._1},
-             |      input.getField(${forwardMapping(i)._2}));"""
-            .stripMargin
+        for (i <- fwdMapping.indices if fwdMapping(i) >= 0) yield
+          {
+            j"""
+               |    output.setField(
+               |      $i,
+               |      input.getField(${fwdMapping(i)}));"""
+              .stripMargin
+          }
       }.mkString("\n")
 
       j"""$sig {
@@ -396,20 +425,44 @@ class CodeGenerator(
          |  }""".stripMargin
     }
 
-    def genCreateOutputRow(outputArity: Int): String = {
+    def genSetConstantFlags: String = {
+
+      val sig: String =
+        j"""
+           |  public final void setConstantFlags(org.apache.flink.types.Row output)
+           |    """.stripMargin
+
+      val setFlags: String = if (constantFlags.isDefined) {
+        {
+          for (cf <- constantFlags.get) yield {
+            j"""
+               |    output.setField(${cf._1}, ${if (cf._2) "true" else "false"});"""
+              .stripMargin
+          }
+        }.mkString("\n")
+      } else {
+        ""
+      }
+
+      j"""$sig {
+         |$setFlags
+         |  }""".stripMargin
+    }
+
+    def genCreateOutputRow: String = {
       j"""
-         |  public org.apache.flink.types.Row createOutputRow() {
+         |  public final org.apache.flink.types.Row createOutputRow() {
          |    return new org.apache.flink.types.Row($outputArity);
          |  }""".stripMargin
     }
 
-    def genMergeAccumulatorsPair(
-        accTypes: Array[String],
-        aggs: Array[String]): String = {
+    def genMergeAccumulatorsPair: String = {
+
+      val mapping = mergeMapping.getOrElse(aggs.indices.toArray)
 
       val sig: String =
         j"""
-           |  public org.apache.flink.types.Row mergeAccumulatorsPair(
+           |  public final org.apache.flink.types.Row mergeAccumulatorsPair(
            |    org.apache.flink.types.Row a,
            |    org.apache.flink.types.Row b)
            """.stripMargin
@@ -417,7 +470,7 @@ class CodeGenerator(
         for (i <- aggs.indices) yield
           j"""
              |    ${accTypes(i)} aAcc$i = (${accTypes(i)}) a.getField($i);
-             |    ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField($i);
+             |    ${accTypes(i)} bAcc$i = (${accTypes(i)}) b.getField(${mapping(i)});
              |    accList$i.set(0, aAcc$i);
              |    accList$i.set(1, bAcc$i);
              |    a.setField(
@@ -430,75 +483,76 @@ class CodeGenerator(
            |      return a;
            """.stripMargin
 
-      j"""$sig {
+      j"""
+         |$sig {
          |$merge
          |$ret
          |  }""".stripMargin
     }
 
-    def genMergeList(accTypes: Array[String]): String = {
+    def genMergeList: String = {
       {
         for (i <- accTypes.indices) yield
           j"""
-             |    java.util.ArrayList<${accTypes(i)}> accList$i;
+             |    private final java.util.ArrayList<${accTypes(i)}> accList$i =
+             |      new java.util.ArrayList<${accTypes(i)}>(2);
              """.stripMargin
       }.mkString("\n")
     }
 
-    def initMergeList(
-        accTypes: Array[String],
-        aggs: Array[String]): String = {
+    def initMergeList: String = {
       {
         for (i <- accTypes.indices) yield
           j"""
-             |    accList$i = new java.util.ArrayList<${accTypes(i)}>(2);
              |    accList$i.add(${aggs(i)}.createAccumulator());
              |    accList$i.add(${aggs(i)}.createAccumulator());
              """.stripMargin
       }.mkString("\n")
     }
 
-    // get unique function name
-    val funcName = newName(name)
-    // register UDAGGs
-    val aggs = aggregates.map(a => generator.addReusableFunction(a))
-    // get java types of accumulators
-    val accTypes = aggregates.map { a =>
-      a.getClass.getMethod("createAccumulator").getReturnType.getCanonicalName
-    }
+    def genResetAccumulator: String = {
 
-    // get java types of input fields
-    val javaTypes = inputType.getFieldList
-      .map(f => FlinkTypeFactory.toTypeInfo(f.getType))
-      .map(t => t.getTypeClass.getCanonicalName)
-    // get parameter lists for aggregation functions
-    val parameters = aggFields.map {inFields =>
-      val fields = for (f <- inFields) yield s"(${javaTypes(f)}) input.getField($f)"
-      fields.mkString(", ")
+      val sig: String =
+        j"""
+           |  public final void resetAccumulator(
+           |    org.apache.flink.types.Row accs)""".stripMargin
+
+      val reset: String = {
+        for (i <- aggs.indices) yield
+          j"""
+             |    ${aggs(i)}.resetAccumulator(
+             |      ((${accTypes(i)}) accs.getField($i)));""".stripMargin
+      }.mkString("\n")
+
+      j"""$sig {
+         |$reset
+         |  }""".stripMargin
     }
 
     var funcCode =
       j"""
-         |public class $funcName
+         |public final class $funcName
         |  extends org.apache.flink.table.runtime.aggregate.GeneratedAggregations {
          |
          |  ${reuseMemberCode()}
-         |  ${genMergeList(accTypes)}
+         |  $genMergeList
          |  public $funcName() throws Exception {
          |    ${reuseInitCode()}
-         |    ${initMergeList(accTypes, aggs)}
+         |    $initMergeList
          |  }
          |  ${reuseConstructorCode(funcName)}
          |
          """.stripMargin
 
-    funcCode += genSetAggregationResults(accTypes, aggs, aggMapping) + "\n"
-    funcCode += genAccumulate(accTypes, aggs, parameters) + "\n"
-    funcCode += genRetract(accTypes, aggs, parameters) + "\n"
-    funcCode += genCreateAccumulators(aggs) + "\n"
-    funcCode += genSetForwardedFields(fwdMapping) + "\n"
-    funcCode += genCreateOutputRow(outputArity) + "\n"
-    funcCode += genMergeAccumulatorsPair(accTypes, aggs) + "\n"
+    funcCode += genSetAggregationResults + "\n"
+    funcCode += genAccumulate + "\n"
+    funcCode += genRetract + "\n"
+    funcCode += genCreateAccumulators + "\n"
+    funcCode += genSetForwardedFields + "\n"
+    funcCode += genSetConstantFlags + "\n"
+    funcCode += genCreateOutputRow + "\n"
+    funcCode += genMergeAccumulatorsPair + "\n"
+    funcCode += genResetAccumulator + "\n"
     funcCode += "}"
 
     GeneratedAggregationsFunction(funcName, funcCode)
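
To make the templates above concrete, here is a hand-written Scala equivalent
of what generateAggregations emits for a hypothetical single SUM aggregate:
grouping key at input field 0 forwarded to output field 0, summed value read
from input field 1, result written to output field 1. This is illustrative
only; the generator actually emits Java source that is compiled at runtime,
and the accumulator is simplified here to a boxed Long in a one-field Row.

    import org.apache.flink.types.Row

    class SumAggregationsSketch extends GeneratedAggregations {
      override def setAggregationResults(accs: Row, output: Row): Unit =
        output.setField(1, accs.getField(0)) // final value of a sum is its accumulator
      override def accumulate(accs: Row, input: Row): Unit =
        accs.setField(0,
          accs.getField(0).asInstanceOf[Long] + input.getField(1).asInstanceOf[Long])
      override def retract(accs: Row, input: Row): Unit =
        accs.setField(0,
          accs.getField(0).asInstanceOf[Long] - input.getField(1).asInstanceOf[Long])
      override def createAccumulators(): Row = {
        val accs = new Row(1)
        accs.setField(0, 0L)
        accs
      }
      override def setForwardedFields(input: Row, output: Row): Unit =
        output.setField(0, input.getField(0))
      override def setConstantFlags(output: Row): Unit = {} // no grouping sets here
      override def createOutputRow(): Row = new Row(2)
      override def mergeAccumulatorsPair(a: Row, b: Row): Row = {
        a.setField(0,
          a.getField(0).asInstanceOf[Long] + b.getField(0).asInstanceOf[Long])
        a // merged in place and returned, exactly as the template does
      }
      override def resetAccumulator(accs: Row): Unit = accs.setField(0, 0L)
    }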

http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
index 5a4aa59..b92775c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
@@ -29,6 +29,7 @@ import org.apache.flink.api.java.DataSet
 import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.table.api.BatchTableEnvironment
 import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.codegen.CodeGenerator
 import org.apache.flink.table.plan.nodes.CommonAggregate
 import org.apache.flink.table.runtime.aggregate.{AggregateUtil, DataSetPreAggFunction}
 import org.apache.flink.table.runtime.aggregate.AggregateUtil.CalcitePair
@@ -89,19 +90,25 @@ class DataSetAggregate(
 
   override def translateToPlan(tableEnv: BatchTableEnvironment): DataSet[Row] = {
 
+    val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
+
+    val generator = new CodeGenerator(
+      tableEnv.getConfig,
+      false,
+      inputDS.getType)
+
     val (
       preAgg: Option[DataSetPreAggFunction],
       preAggType: Option[TypeInformation[Row]],
       finalAgg: GroupReduceFunction[Row, Row]
       ) = AggregateUtil.createDataSetAggregateFunctions(
+        generator,
         namedAggregates,
         inputType,
         rowRelDataType,
         grouping,
         inGroupingSet)
 
-    val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
-
     val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)
 
     val rowTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(getRowType).asInstanceOf[RowTypeInfo]
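
For context, a simplified sketch of how the triple returned by
createDataSetAggregateFunctions is typically wired into the DataSet program
(assumed shape, not the literal remainder of translateToPlan; rowTypeInfo and
aggString as defined above):

    // Grouped aggregation; the combiner path applies when preAgg is defined.
    val result: DataSet[Row] = preAgg match {
      case Some(pre) =>
        inputDS
          .groupBy(grouping: _*)                 // group on original key positions
          .combineGroup(pre)                     // partial agg, emits accumulators
          .returns(preAggType.get)
          .groupBy(grouping.indices.toArray: _*) // keys now at the front of the row
          .reduceGroup(finalAgg)                 // merge accumulators, finalize
          .returns(rowTypeInfo)
          .name(aggString)
      case None =>
        inputDS
          .groupBy(grouping: _*)
          .reduceGroup(finalAgg)
          .returns(rowTypeInfo)
          .name(aggString)
    }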

http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
index a94deb1..96c427e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
@@ -28,6 +28,7 @@ import org.apache.flink.api.java.typeutils.ResultTypeQueryable
 import org.apache.flink.table.api.BatchTableEnvironment
 import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
 import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.table.codegen.CodeGenerator
 import org.apache.flink.table.plan.logical._
 import org.apache.flink.table.plan.nodes.CommonAggregate
 import org.apache.flink.table.runtime.aggregate.AggregateUtil.{CalcitePair, _}
@@ -109,21 +110,28 @@ class DataSetWindowAggregate(
 
     val inputDS = getInput.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
 
+    val generator = new CodeGenerator(
+      tableEnv.getConfig,
+      false,
+      inputDS.getType)
+
     // whether identifiers are matched case-sensitively
     val caseSensitive = tableEnv.getFrameworkConfig.getParserConfig.caseSensitive()
 
     window match {
       case EventTimeTumblingGroupWindow(_, _, size) =>
         createEventTimeTumblingWindowDataSet(
+          generator,
           inputDS,
           isTimeInterval(size.resultType),
           caseSensitive)
 
       case EventTimeSessionGroupWindow(_, _, gap) =>
-        createEventTimeSessionWindowDataSet(inputDS, caseSensitive)
+        createEventTimeSessionWindowDataSet(generator, inputDS, caseSensitive)
 
       case EventTimeSlidingGroupWindow(_, _, size, slide) =>
         createEventTimeSlidingWindowDataSet(
+          generator,
           inputDS,
           isTimeInterval(size.resultType),
           asLong(size),
@@ -139,17 +147,20 @@ class DataSetWindowAggregate(
   }
 
   private def createEventTimeTumblingWindowDataSet(
+      generator: CodeGenerator,
       inputDS: DataSet[Row],
       isTimeWindow: Boolean,
       isParserCaseSensitive: Boolean): DataSet[Row] = {
 
     val mapFunction = createDataSetWindowPrepareMapFunction(
+      generator,
       window,
       namedAggregates,
       grouping,
       inputType,
       isParserCaseSensitive)
     val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
+      generator,
       window,
       namedAggregates,
       inputType,
@@ -195,6 +206,7 @@ class DataSetWindowAggregate(
   }
 
   private[this] def createEventTimeSessionWindowDataSet(
+      generator: CodeGenerator,
       inputDS: DataSet[Row],
       isParserCaseSensitive: Boolean): DataSet[Row] = {
 
@@ -203,6 +215,7 @@ class DataSetWindowAggregate(
 
     // create mapFunction for initializing the aggregations
     val mapFunction = createDataSetWindowPrepareMapFunction(
+      generator,
       window,
       namedAggregates,
       grouping,
@@ -229,6 +242,7 @@ class DataSetWindowAggregate(
       if (groupingKeys.length > 0) {
         // create groupCombineFunction for combine the aggregations
         val combineGroupFunction = createDataSetWindowAggregationCombineFunction(
+          generator,
           window,
           namedAggregates,
           inputType,
@@ -236,6 +250,7 @@ class DataSetWindowAggregate(
 
         // create groupReduceFunction for calculating the aggregations
         val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
+          generator,
           window,
           namedAggregates,
           inputType,
@@ -257,6 +272,7 @@ class DataSetWindowAggregate(
       } else {
         // non-grouping window
         val mapPartitionFunction = createDataSetWindowAggregationMapPartitionFunction(
+          generator,
           window,
           namedAggregates,
           inputType,
@@ -264,6 +280,7 @@ class DataSetWindowAggregate(
 
         // create groupReduceFunction for calculating the aggregations
         val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
+          generator,
           window,
           namedAggregates,
           inputType,
@@ -288,6 +305,7 @@ class DataSetWindowAggregate(
 
         // create groupReduceFunction for calculating the aggregations
         val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
+          generator,
           window,
           namedAggregates,
           inputType,
@@ -303,6 +321,7 @@ class DataSetWindowAggregate(
       } else {
         // non-grouping window
         val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
+          generator,
           window,
           namedAggregates,
           inputType,
@@ -320,6 +339,7 @@ class DataSetWindowAggregate(
   }
 
   private def createEventTimeSlidingWindowDataSet(
+      generator: CodeGenerator,
       inputDS: DataSet[Row],
       isTimeWindow: Boolean,
       size: Long,
@@ -330,6 +350,7 @@ class DataSetWindowAggregate(
     // create MapFunction for initializing the aggregations
     // it aligns the rowtime for pre-tumbling in case of a time-window for partial aggregates
     val mapFunction = createDataSetWindowPrepareMapFunction(
+      generator,
       window,
       namedAggregates,
       grouping,
@@ -365,6 +386,7 @@ class DataSetWindowAggregate(
         // create GroupReduceFunction
         // for pre-tumbling and replicating/omitting the content for each pane
         val prepareReduceFunction = createDataSetSlideWindowPrepareGroupReduceFunction(
+          generator,
           window,
           namedAggregates,
           grouping,
@@ -401,6 +423,7 @@ class DataSetWindowAggregate(
 
     // create GroupReduceFunction for final aggregation and conversion to output row
     val aggregateReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
+      generator,
       window,
       namedAggregates,
       inputType,
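
Each branch above assembles the same basic pipeline; roughly (a sketch with
assumed names and field positions, not the literal code): the prepare
MapFunction creates and partially fills accumulators per row and, for time
windows, aligns the rowtime to the window start; rows are then grouped on the
keys plus the aligned rowtime and reduced to the final window results.

    // General shape of the tumbling event-time window translation.
    val keysPlusWindowStart: Array[Int] = grouping :+ rowTimeFieldPos // assumed position
    val windowed: DataSet[Row] = inputDS
      .map(mapFunction)                    // accumulate per row, align rowtime
      .groupBy(keysPlusWindowStart: _*)    // one group per key and window
      .reduceGroup(groupReduceFunction)    // merge accumulators, emit final rows
      .returns(rowTypeInfo)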

http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index da57153..2c503c6 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -82,7 +82,7 @@ object AggregateUtil {
     val aggregationStateType: RowTypeInfo =
       createDataSetAggregateBufferDataType(Array(), aggregates, inputType)
 
-    val forwardMapping = (0 until inputType.getFieldCount).map(x => (x, x)).toArray
+    val forwardMapping = (0 until inputType.getFieldCount).toArray
     val aggMapping = aggregates.indices.map(x => x + inputType.getFieldCount).toArray
     val outputArity = inputType.getFieldCount + aggregates.length
 
@@ -93,7 +93,10 @@ object AggregateUtil {
       aggregates,
       aggFields,
       aggMapping,
+      partialResults = false,
       forwardMapping,
+      None,
+      None,
       outputArity
     )
 
@@ -153,7 +156,7 @@ object AggregateUtil {
     val aggregationStateType: RowTypeInfo = createAccumulatorRowType(aggregates)
     val inputRowType = FlinkTypeFactory.toInternalRowTypeInfo(inputType).asInstanceOf[RowTypeInfo]
 
-    val forwardMapping = (0 until inputType.getFieldCount).map(x => (x, x)).toArray
+    val forwardMapping = (0 until inputType.getFieldCount).toArray
     val aggMapping = aggregates.indices.map(x => x + inputType.getFieldCount).toArray
     val outputArity = inputType.getFieldCount + aggregates.length
 
@@ -164,7 +167,10 @@ object AggregateUtil {
       aggregates,
       aggFields,
       aggMapping,
+      partialResults = false,
       forwardMapping,
+      None,
+      None,
       outputArity
     )
 
@@ -225,6 +231,7 @@ object AggregateUtil {
     * NOTE: this function is only used for time based window on batch tables.
     */
   def createDataSetWindowPrepareMapFunction(
+      generator: CodeGenerator,
       window: LogicalWindow,
       namedAggregates: Seq[CalcitePair[AggregateCall, String]],
       groupings: Array[Int],
@@ -249,7 +256,7 @@ object AggregateUtil {
         val timeFieldPos = getTimeFieldPosition(time, inputType, isParserCaseSensitive)
         (timeFieldPos, Some(asLong(size)))
 
-      case EventTimeTumblingGroupWindow(_, time, size) =>
+      case EventTimeTumblingGroupWindow(_, time, _) =>
         val timeFieldPos = getTimeFieldPosition(time, inputType, isParserCaseSensitive)
         (timeFieldPos, None)
 
@@ -272,10 +279,25 @@ object AggregateUtil {
         throw new UnsupportedOperationException(s"$window is currently not supported on batch")
     }
 
-    new DataSetWindowAggMapFunction(
+    val aggMapping = aggregates.indices.toArray.map(_ + groupings.length)
+    val outputArity = aggregates.length + groupings.length + 1
+
+    val genFunction = generator.generateAggregations(
+      "DataSetAggregatePrepareMapHelper",
+      generator,
+      inputType,
       aggregates,
       aggFieldIndexes,
+      aggMapping,
+      partialResults = true,
       groupings,
+      None,
+      None,
+      outputArity
+    )
+
+    new DataSetWindowAggMapFunction(
+      genFunction,
       timeFieldPos,
       tumbleTimeWindowSize,
       mapReturnType)
@@ -309,6 +331,7 @@ object AggregateUtil {
    * NOTE: this function is only used for sliding windows with partial aggregates on batch tables.
     */
   def createDataSetSlideWindowPrepareGroupReduceFunction(
+      generator: CodeGenerator,
       window: LogicalWindow,
       namedAggregates: Seq[CalcitePair[AggregateCall, String]],
       groupings: Array[Int],
@@ -316,10 +339,10 @@ object AggregateUtil {
       isParserCaseSensitive: Boolean)
     : RichGroupReduceFunction[Row, Row] = {
 
-    val aggregates = transformToAggregateFunctions(
+    val (aggFieldIndexes, aggregates) = transformToAggregateFunctions(
       namedAggregates.map(_.getKey),
       inputType,
-      needRetraction = false)._2
+      needRetraction = false)
 
     val returnType: RowTypeInfo = createDataSetAggregateBufferDataType(
       groupings,
@@ -327,13 +350,27 @@ object AggregateUtil {
       inputType,
       Some(Array(BasicTypeInfo.LONG_TYPE_INFO)))
 
+    val keysAndAggregatesArity = groupings.length + namedAggregates.length
+
     window match {
       case EventTimeSlidingGroupWindow(_, _, size, slide) if isTimeInterval(size.resultType) =>
         // sliding time-window for partial aggregations
-        new DataSetSlideTimeWindowAggReduceGroupFunction(
+        val genFunction = generator.generateAggregations(
+          "DataSetAggregatePrepareMapHelper",
+          generator,
+          inputType,
           aggregates,
-          groupings.length,
-          returnType.getArity - 1,
+          aggFieldIndexes,
+          aggregates.indices.map(_ + groupings.length).toArray,
+          partialResults = true,
+          groupings,
+          Some(aggregates.indices.map(_ + groupings.length).toArray),
+          None,
+          keysAndAggregatesArity + 1
+        )
+        new DataSetSlideTimeWindowAggReduceGroupFunction(
+          genFunction,
+          keysAndAggregatesArity,
           asLong(size),
           asLong(slide),
           returnType)
@@ -400,6 +437,7 @@ object AggregateUtil {
     * NOTE: this function is only used for window on batch tables.
     */
   def createDataSetWindowAggregationGroupReduceFunction(
+      generator: CodeGenerator,
       window: LogicalWindow,
       namedAggregates: Seq[CalcitePair[AggregateCall, String]],
       inputType: RelDataType,
@@ -414,19 +452,37 @@ object AggregateUtil {
       inputType,
       needRetraction = false)
 
-    // the mapping relation between field index of intermediate aggregate Row and output Row.
-    val groupingOffsetMapping = getGroupKeysMapping(inputType, outputType, groupings)
+    val aggMapping = aggregates.indices.toArray.map(_ + groupings.length)
+
+    val genPreAggFunction = generator.generateAggregations(
+      "GroupingWindowAggregateHelper",
+      generator,
+      inputType,
+      aggregates,
+      aggFieldIndexes,
+      aggMapping,
+      partialResults = true,
+      groupings,
+      Some(aggregates.indices.map(_ + groupings.length).toArray),
+      None,
+      outputType.getFieldCount
+    )
 
-    // the mapping relation between aggregate function index in list and its corresponding
-    // field index in output Row.
-    val aggOffsetMapping = getAggregateMapping(namedAggregates, outputType)
+    val genFinalAggFunction = generator.generateAggregations(
+      "GroupingWindowAggregateHelper",
+      generator,
+      inputType,
+      aggregates,
+      aggFieldIndexes,
+      aggMapping,
+      partialResults = false,
+      groupings.indices.toArray,
+      Some(aggregates.indices.map(_ + groupings.length).toArray),
+      None,
+      outputType.getFieldCount
+    )
 
-    if (groupingOffsetMapping.length != groupings.length ||
-      aggOffsetMapping.length != namedAggregates.length) {
-      throw new TableException(
-        "Could not find output field in input data type " +
-          "or aggregate functions.")
-    }
+    val keysAndAggregatesArity = groupings.length + namedAggregates.length
 
     window match {
       case EventTimeTumblingGroupWindow(_, _, size) if isTimeInterval(size.resultType) =>
@@ -435,41 +491,33 @@ object AggregateUtil {
         if (doAllSupportPartialMerge(aggregates)) {
           // for incremental aggregations
           new DataSetTumbleTimeWindowAggReduceCombineFunction(
+            genPreAggFunction,
+            genFinalAggFunction,
             asLong(size),
             startPos,
             endPos,
-            aggregates,
-            groupingOffsetMapping,
-            aggOffsetMapping,
-            outputType.getFieldCount)
+            keysAndAggregatesArity)
         }
         else {
           // for non-incremental aggregations
           new DataSetTumbleTimeWindowAggReduceGroupFunction(
+            genFinalAggFunction,
             asLong(size),
             startPos,
             endPos,
-            aggregates,
-            groupingOffsetMapping,
-            aggOffsetMapping,
             outputType.getFieldCount)
         }
       case EventTimeTumblingGroupWindow(_, _, size) =>
         // tumbling count window
         new DataSetTumbleCountWindowAggReduceGroupFunction(
-          asLong(size),
-          aggregates,
-          groupingOffsetMapping,
-          aggOffsetMapping,
-          outputType.getFieldCount)
+          genFinalAggFunction,
+          asLong(size))
 
       case EventTimeSessionGroupWindow(_, _, gap) =>
         val (startPos, endPos) = computeWindowStartEndPropertyPos(properties)
         new DataSetSessionWindowAggReduceGroupFunction(
-          aggregates,
-          groupingOffsetMapping,
-          aggOffsetMapping,
-          outputType.getFieldCount,
+          genFinalAggFunction,
+          keysAndAggregatesArity,
           startPos,
           endPos,
           asLong(gap),
@@ -480,10 +528,9 @@ object AggregateUtil {
         if (doAllSupportPartialMerge(aggregates)) {
           // for partial aggregations
           new DataSetSlideWindowAggReduceCombineFunction(
-            aggregates,
-            groupingOffsetMapping,
-            aggOffsetMapping,
-            outputType.getFieldCount,
+            genPreAggFunction,
+            genFinalAggFunction,
+            keysAndAggregatesArity,
             startPos,
             endPos,
             asLong(size))
@@ -491,10 +538,8 @@ object AggregateUtil {
         else {
           // for non-partial aggregations
           new DataSetSlideWindowAggReduceGroupFunction(
-            aggregates,
-            groupingOffsetMapping,
-            aggOffsetMapping,
-            outputType.getFieldCount,
+            genFinalAggFunction,
+            keysAndAggregatesArity,
             startPos,
             endPos,
             asLong(size))
@@ -502,10 +547,8 @@ object AggregateUtil {
 
       case EventTimeSlidingGroupWindow(_, _, size, _) =>
         new DataSetSlideWindowAggReduceGroupFunction(
-            aggregates,
-            groupingOffsetMapping,
-            aggOffsetMapping,
-            outputType.getFieldCount,
+            genFinalAggFunction,
+            keysAndAggregatesArity,
             None,
             None,
             asLong(size))
@@ -537,15 +580,20 @@ object AggregateUtil {
     *
     */
   def createDataSetWindowAggregationMapPartitionFunction(
+    generator: CodeGenerator,
     window: LogicalWindow,
     namedAggregates: Seq[CalcitePair[AggregateCall, String]],
     inputType: RelDataType,
     groupings: Array[Int]): MapPartitionFunction[Row, Row] = {
 
-    val aggregates = transformToAggregateFunctions(
+    val (aggFieldIndexes, aggregates) = transformToAggregateFunctions(
       namedAggregates.map(_.getKey),
       inputType,
-      needRetraction = false)._2
+      needRetraction = false)
+
+    val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
+
+    val keysAndAggregatesArity = groupings.length + namedAggregates.length
 
     window match {
       case EventTimeSessionGroupWindow(_, _, gap) =>
@@ -556,9 +604,23 @@ object AggregateUtil {
             inputType,
             Option(Array(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO)))
 
-        new DataSetSessionWindowAggregatePreProcessor(
+        val genFunction = generator.generateAggregations(
+          "GroupingWindowAggregateHelper",
+          generator,
+          inputType,
           aggregates,
-          groupings,
+          aggFieldIndexes,
+          aggMapping,
+          partialResults = true,
+          groupings.indices.toArray,
+          Some(aggregates.indices.map(_ + groupings.length).toArray),
+          None,
+          groupings.length + aggregates.length + 2
+        )
+
+        new DataSetSessionWindowAggregatePreProcessor(
+          genFunction,
+          keysAndAggregatesArity,
           asLong(gap),
           combineReturnType)
       case _ =>
@@ -585,16 +647,21 @@ object AggregateUtil {
     *
     */
   private[flink] def createDataSetWindowAggregationCombineFunction(
+      generator: CodeGenerator,
       window: LogicalWindow,
       namedAggregates: Seq[CalcitePair[AggregateCall, String]],
       inputType: RelDataType,
       groupings: Array[Int])
     : GroupCombineFunction[Row, Row] = {
 
-    val aggregates = transformToAggregateFunctions(
+    val (aggFieldIndexes, aggregates) = transformToAggregateFunctions(
       namedAggregates.map(_.getKey),
       inputType,
-      needRetraction = false)._2
+      needRetraction = false)
+
+    val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
+
+    val keysAndAggregatesArity = groupings.length + namedAggregates.length
 
     window match {
 
@@ -606,9 +673,23 @@ object AggregateUtil {
             inputType,
             Option(Array(BasicTypeInfo.LONG_TYPE_INFO, BasicTypeInfo.LONG_TYPE_INFO)))
 
-        new DataSetSessionWindowAggregatePreProcessor(
+        val genFunction = generator.generateAggregations(
+          "GroupingWindowAggregateHelper",
+          generator,
+          inputType,
           aggregates,
-          groupings,
+          aggFieldIndexes,
+          aggMapping,
+          partialResults = true,
+          groupings.indices.toArray,
+          Some(aggregates.indices.map(_ + groupings.length).toArray),
+          None,
+          groupings.length + aggregates.length + 2
+        )
+
+        new DataSetSessionWindowAggregatePreProcessor(
+          genFunction,
+          keysAndAggregatesArity,
           asLong(gap),
           combineReturnType)
 
@@ -625,6 +706,7 @@ object AggregateUtil {
     * respective output type are generated as well.
     */
   private[flink] def createDataSetAggregateFunctions(
+      generator: CodeGenerator,
       namedAggregates: Seq[CalcitePair[AggregateCall, String]],
       inputType: RelDataType,
       outputType: RelDataType,
@@ -645,51 +727,91 @@ object AggregateUtil {
       outputType
     )
 
-    val groupingSetsMapping: Array[(Int, Int)] = if (inGroupingSet) {
-      getGroupingSetsIndicatorMapping(inputType, outputType)
+    val constantFlags: Option[Array[(Int, Boolean)]] =
+    if (inGroupingSet) {
+
+      val groupingSetsMapping = getGroupingSetsIndicatorMapping(inputType, outputType)
+      val nonNullKeysFields = gkeyOutMapping.map(_._1)
+      val flags = for ((in, out) <- groupingSetsMapping) yield
+        (out, !nonNullKeysFields.contains(in))
+      Some(flags)
     } else {
-      Array()
+      None
     }
 
-    if (doAllSupportPartialMerge(aggregates)) {
+    val aggOutFields = aggOutMapping.map(_._1)
 
-      // compute grouping key and aggregation positions
-      val gkeyInFields = gkeyOutMapping.map(_._2)
-      val gkeyOutFields = gkeyOutMapping.map(_._1)
-      val aggOutFields = aggOutMapping.map(_._1)
+    if (doAllSupportPartialMerge(aggregates)) {
 
       // compute preaggregation type
-      val preAggFieldTypes = gkeyInFields
+      val preAggFieldTypes = gkeyOutMapping.map(_._2)
         .map(inputType.getFieldList.get(_).getType)
         .map(FlinkTypeFactory.toTypeInfo) ++ createAccumulatorType(aggregates)
       val preAggRowType = new RowTypeInfo(preAggFieldTypes: _*)
 
+      val genPreAggFunction = generator.generateAggregations(
+        "DataSetAggregatePrepareMapHelper",
+        generator,
+        inputType,
+        aggregates,
+        aggInFields,
+        aggregates.indices.map(_ + groupings.length).toArray,
+        partialResults = true,
+        groupings,
+        None,
+        None,
+        groupings.length + aggregates.length
+      )
+
+      // compute mapping of forwarded grouping keys
+      val gkeyMapping: Array[Int] = if (gkeyOutMapping.nonEmpty) {
+        val gkeyOutFields = gkeyOutMapping.map(_._1)
+        val mapping = Array.fill[Int](gkeyOutFields.max + 1)(-1)
+        gkeyOutFields.zipWithIndex.foreach(m => mapping(m._1) = m._2)
+        mapping
+      } else {
+        new Array[Int](0)
+      }
+
+      val genFinalAggFunction = generator.generateAggregations(
+        "DataSetAggregateFinalHelper",
+        generator,
+        inputType,
+        aggregates,
+        aggInFields,
+        aggOutFields,
+        partialResults = false,
+        gkeyMapping,
+        Some(aggregates.indices.map(_ + groupings.length).toArray),
+        constantFlags,
+        outputType.getFieldCount
+      )
+
       (
-        Some(new DataSetPreAggFunction(
-          aggregates,
-          aggInFields,
-          gkeyInFields
-        )),
+        Some(new DataSetPreAggFunction(genPreAggFunction)),
         Some(preAggRowType),
-        new DataSetFinalAggFunction(
-          aggregates,
-          aggOutFields,
-          gkeyOutFields,
-          groupingSetsMapping,
-          outputType.getFieldCount)
+        new DataSetFinalAggFunction(genFinalAggFunction)
       )
     }
     else {
+      val genFunction = generator.generateAggregations(
+        "DataSetAggregateHelper",
+        generator,
+        inputType,
+        aggregates,
+        aggInFields,
+        aggOutFields,
+        partialResults = false,
+        groupings,
+        None,
+        constantFlags,
+        outputType.getFieldCount
+      )
+
       (
         None,
         None,
-        new DataSetAggFunction(
-          aggregates,
-          aggInFields,
-          aggOutMapping,
-          gkeyOutMapping,
-          groupingSetsMapping,
-          outputType.getFieldCount)
+        new DataSetAggFunction(genFunction)
       )
     }
 
@@ -768,15 +890,12 @@ object AggregateUtil {
       aggregates,
       aggFields,
       aggMapping,
-      Array(),
-      outputArity)
-
-    val aggregateMapping = getAggregateMapping(namedAggregates, outputType)
-
-    if (aggregateMapping.length != namedAggregates.length) {
-      throw new TableException(
-        "Could not find output field in input data type or aggregate 
functions.")
-    }
+      partialResults = false,
+      Array(), // no fields are forwarded
+      None,
+      None,
+      outputArity
+    )
 
     val aggResultTypes = namedAggregates.map(a => FlinkTypeFactory.toTypeInfo(a.left.getType))
 
@@ -1169,62 +1288,6 @@ object AggregateUtil {
     new RowTypeInfo(aggTypes: _*)
   }
 
-  // Find the mapping between the index of aggregate list and aggregated value index in output Row.
-  private def getAggregateMapping(
-    namedAggregates: Seq[CalcitePair[AggregateCall, String]],
-    outputType: RelDataType): Array[(Int, Int)] = {
-
-    // the mapping relation between aggregate function index in list and its corresponding
-    // field index in output Row.
-    var aggOffsetMapping = ArrayBuffer[(Int, Int)]()
-
-    outputType.getFieldList.zipWithIndex.foreach {
-      case (outputFieldType, outputIndex) =>
-        namedAggregates.zipWithIndex.foreach {
-          case (namedAggCall, aggregateIndex) =>
-            if (namedAggCall.getValue.equals(outputFieldType.getName) &&
-              namedAggCall.getKey.getType.equals(outputFieldType.getType)) {
-              aggOffsetMapping += ((outputIndex, aggregateIndex))
-            }
-        }
-    }
-
-    aggOffsetMapping.toArray
-  }
-
-  // Find the mapping between the index of group key in intermediate aggregate Row and its index
-  // in output Row.
-  private def getGroupKeysMapping(
-    inputDatType: RelDataType,
-    outputType: RelDataType,
-    groupKeys: Array[Int]): Array[(Int, Int)] = {
-
-    // the mapping relation between field index of intermediate aggregate Row and output Row.
-    var groupingOffsetMapping = ArrayBuffer[(Int, Int)]()
-
-    outputType.getFieldList.zipWithIndex.foreach {
-      case (outputFieldType, outputIndex) =>
-        inputDatType.getFieldList.zipWithIndex.foreach {
-          // find the field index in input data type.
-          case (inputFieldType, inputIndex) =>
-            if (outputFieldType.getName.equals(inputFieldType.getName) &&
-              outputFieldType.getType.equals(inputFieldType.getType)) {
-              // as aggregated field in output data type would not have a matched field in
-              // input data, so if inputIndex is not -1, it must be a group key. Then we can
-              // find the field index in buffer data by the group keys index mapping between
-              // input data and buffer data.
-              for (i <- groupKeys.indices) {
-                if (inputIndex == groupKeys(i)) {
-                  groupingOffsetMapping += ((outputIndex, i))
-                }
-              }
-            }
-        }
-    }
-
-    groupingOffsetMapping.toArray
-  }
-
   private def getTimeFieldPosition(
     timeField: Expression,
     inputType: RelDataType,
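
One subtle piece above is the constantFlags computation for grouping sets: an
indicator field is constantly true exactly when its grouping key is not
produced by the current aggregation, i.e. the key is null in this grouping
set. A small worked example with hypothetical field positions:

    // One grouping set that only groups on key a.
    // gkeyOutMapping pairs a produced key's output position with its input position.
    val gkeyOutMapping = Array((0, 0))              // key a: output field 0 <- input field 0
    // groupingSetsMapping pairs a key position with its indicator's output position.
    val groupingSetsMapping = Array((0, 4), (1, 5)) // indicators i$a at 4, i$b at 5
    val nonNullKeysFields = gkeyOutMapping.map(_._1)
    val flags = for ((in, out) <- groupingSetsMapping)
      yield (out, !nonNullKeysFields.contains(in))
    // flags == Array((4, false), (5, true)):
    // the indicator at output field 5 is constant true because key b is absent.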

http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala
index 867943e..5f459f9 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetAggFunction.scala
@@ -21,101 +21,66 @@ import java.lang.Iterable
 
 import org.apache.flink.api.common.functions.RichGroupReduceFunction
 import org.apache.flink.configuration.Configuration
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
 import org.apache.flink.types.Row
-import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.util.Collector
+import org.slf4j.LoggerFactory
 
 /**
  * [[RichGroupReduceFunction]] to compute aggregates that do not support pre-aggregation for batch
   * (DataSet) queries.
   *
-  * @param aggregates The aggregate functions.
-  * @param aggInFields The positions of the aggregation input fields.
-  * @param gkeyOutMapping The mapping of group keys between input and output positions.
-  * @param aggOutMapping  The mapping of aggregates to output positions.
-  * @param groupingSetsMapping The mapping of grouping set keys between input and output positions.
-  * @param finalRowArity The arity of the final resulting row.
+  * @param genAggregations Code-generated [[GeneratedAggregations]]
   */
 class DataSetAggFunction(
-    private val aggregates: Array[AggregateFunction[_ <: Any]],
-    private val aggInFields: Array[Array[Int]],
-    private val aggOutMapping: Array[(Int, Int)],
-    private val gkeyOutMapping: Array[(Int, Int)],
-    private val groupingSetsMapping: Array[(Int, Int)],
-    private val finalRowArity: Int)
-  extends RichGroupReduceFunction[Row, Row] {
-
-  Preconditions.checkNotNull(aggregates)
-  Preconditions.checkNotNull(aggInFields)
-  Preconditions.checkNotNull(aggOutMapping)
-  Preconditions.checkNotNull(gkeyOutMapping)
-  Preconditions.checkNotNull(groupingSetsMapping)
+    private val genAggregations: GeneratedAggregationsFunction)
+  extends RichGroupReduceFunction[Row, Row]
+    with Compiler[GeneratedAggregations] {
 
   private var output: Row = _
+  private var accumulators: Row = _
 
-  private var intermediateGKeys: Option[Array[Int]] = None
-  private var accumulators: Array[Accumulator] = _
+  val LOG = LoggerFactory.getLogger(this.getClass)
+  private var function: GeneratedAggregations = _
 
   override def open(config: Configuration) {
-    accumulators = new Array(aggregates.length)
-    output = new Row(finalRowArity)
-
-    if (!groupingSetsMapping.isEmpty) {
-      intermediateGKeys = Some(gkeyOutMapping.map(_._1))
-    }
+    LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+                s"Code:\n$genAggregations.code")
+    val clazz = compile(
+      getClass.getClassLoader,
+      genAggregations.name,
+      genAggregations.code)
+    LOG.debug("Instantiating AggregateHelper.")
+    function = clazz.newInstance()
+
+    output = function.createOutputRow()
+    accumulators = function.createAccumulators()
   }
 
   override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
 
-    // create accumulators
-    var i = 0
-    while (i < aggregates.length) {
-      accumulators(i) = aggregates(i).createAccumulator()
-      i += 1
-    }
+    // reset accumulators
+    function.resetAccumulator(accumulators)
 
     val iterator = records.iterator()
 
+    var record: Row = null
     while (iterator.hasNext) {
-      val record = iterator.next()
+      record = iterator.next()
 
       // accumulate
-      i = 0
-      while (i < aggregates.length) {
-        aggregates(i).accumulate(accumulators(i), record.getField(aggInFields(i)(0)))
-        i += 1
-      }
-
-      // check if this record is the last record
-      if (!iterator.hasNext) {
-        // set group keys value to final output
-        i = 0
-        while (i < gkeyOutMapping.length) {
-          val (out, in) = gkeyOutMapping(i)
-          output.setField(out, record.getField(in))
-          i += 1
-        }
-
-        // set agg results to output
-        i = 0
-        while (i < aggOutMapping.length) {
-          val (out, in) = aggOutMapping(i)
-          output.setField(out, aggregates(in).getValue(accumulators(in)))
-          i += 1
-        }
-
-        // set grouping set flags to output
-        if (intermediateGKeys.isDefined) {
-          i = 0
-          while (i < groupingSetsMapping.length) {
-            val (in, out) = groupingSetsMapping(i)
-            output.setField(out, !intermediateGKeys.get.contains(in))
-            i += 1
-          }
-        }
-
-        out.collect(output)
-      }
+      function.accumulate(accumulators, record)
     }
+
+    // set group keys value to final output
+    function.setForwardedFields(record, output)
+
+    // set agg results to output
+    function.setAggregationResults(accumulators, output)
+
+    // set grouping set flags to output
+    function.setConstantFlags(output)
+
+    out.collect(output)
   }
 }
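
All three runtime classes in this commit follow the same open() pattern:
compile the generated source once per task, instantiate it, and let it create
the reusable output and accumulator rows. The Compiler mix-in they rely on is
not part of this diff; its assumed shape, inferred from the compile(...) call
sites (the flink-table implementation is Janino-backed):

    trait Compiler[T] {
      def compile(classLoader: ClassLoader, name: String, code: String): Class[T]
    }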

http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala
index e3db7a2..9b81992 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetFinalAggFunction.scala
@@ -19,60 +19,43 @@
 package org.apache.flink.table.runtime.aggregate
 
 import java.lang.Iterable
-import java.util.{ArrayList => JArrayList}
 
 import org.apache.flink.api.common.functions.RichGroupReduceFunction
 import org.apache.flink.configuration.Configuration
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
 import org.apache.flink.types.Row
-import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.util.Collector
+import org.slf4j.LoggerFactory
 
 /**
  * [[RichGroupReduceFunction]] to compute the final result of a pre-aggregated aggregation
   * for batch (DataSet) queries.
   *
-  * @param aggregates The aggregate functions.
-  * @param aggOutFields The positions of the aggregation results in the output
-  * @param gkeyOutFields The positions of the grouping keys in the output
-  * @param groupingSetsMapping The mapping of grouping set keys between input and output positions.
-  * @param finalRowArity The arity of the final resulting row
+  * @param genAggregations Code-generated [[GeneratedAggregations]]
   */
 class DataSetFinalAggFunction(
-    private val aggregates: Array[AggregateFunction[_ <: Any]],
-    private val aggOutFields: Array[Int],
-    private val gkeyOutFields: Array[Int],
-    private val groupingSetsMapping: Array[(Int, Int)],
-    private val finalRowArity: Int)
-  extends RichGroupReduceFunction[Row, Row] {
-
-  Preconditions.checkNotNull(aggregates)
-  Preconditions.checkNotNull(aggOutFields)
-  Preconditions.checkNotNull(gkeyOutFields)
-  Preconditions.checkNotNull(groupingSetsMapping)
+    private val genAggregations: GeneratedAggregationsFunction)
+  extends RichGroupReduceFunction[Row, Row]
+    with Compiler[GeneratedAggregations] {
 
   private var output: Row = _
+  private var accumulators: Row = _
 
-  private val intermediateGKeys: Option[Array[Int]] = if (!groupingSetsMapping.isEmpty) {
-    Some(gkeyOutFields)
-  } else {
-    None
-  }
-
-  private val numAggs = aggregates.length
-  private val numGKeys = gkeyOutFields.length
-
-  private val accumulators: Array[JArrayList[Accumulator]] =
-    Array.fill(numAggs)(new JArrayList[Accumulator](2))
+  val LOG = LoggerFactory.getLogger(this.getClass)
+  private var function: GeneratedAggregations = _
 
   override def open(config: Configuration) {
-    output = new Row(finalRowArity)
-
-    // init lists with two empty accumulators
-    for (i <- aggregates.indices) {
-      val accumulator = aggregates(i).createAccumulator()
-      accumulators(i).add(accumulator)
-      accumulators(i).add(accumulator)
-    }
+    LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+                s"Code:\n$genAggregations.code")
+    val clazz = compile(
+      getClass.getClassLoader,
+      genAggregations.name,
+      genAggregations.code)
+    LOG.debug("Instantiating AggregateHelper.")
+    function = clazz.newInstance()
+
+    output = function.createOutputRow()
+    accumulators = function.createAccumulators()
   }
 
   override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
@@ -80,56 +63,24 @@ class DataSetFinalAggFunction(
     val iterator = records.iterator()
 
     // reset first accumulator
-    var i = 0
-    while (i < aggregates.length) {
-      aggregates(i).resetAccumulator(accumulators(i).get(0))
-      i += 1
-    }
+    function.resetAccumulator(accumulators)
 
+    var record: Row = null
     while (iterator.hasNext) {
-      val record = iterator.next()
-
+      record = iterator.next()
       // accumulate
-      i = 0
-      while (i < aggregates.length) {
-        // insert received accumulator into acc list
-        val newAcc = record.getField(numGKeys + i).asInstanceOf[Accumulator]
-        accumulators(i).set(1, newAcc)
-        // merge acc list
-        val retAcc = aggregates(i).merge(accumulators(i))
-        // insert result into acc list
-        accumulators(i).set(0, retAcc)
-        i += 1
-      }
-
-      // check if this record is the last record
-      if (!iterator.hasNext) {
-        // set group keys value to final output
-        i = 0
-        while (i < gkeyOutFields.length) {
-          output.setField(gkeyOutFields(i), record.getField(i))
-          i += 1
-        }
-
-        // get final aggregate value and set to output.
-        i = 0
-        while (i < aggOutFields.length) {
-          output.setField(aggOutFields(i), aggregates(i).getValue(accumulators(i).get(0)))
-          i += 1
-        }
-
-        // set grouping set flags to output
-        if (intermediateGKeys.isDefined) {
-          i = 0
-          while (i < groupingSetsMapping.length) {
-            val (in, out) = groupingSetsMapping(i)
-            output.setField(out, !intermediateGKeys.get.contains(in))
-            i += 1
-          }
-        }
-
-        out.collect(output)
-      }
+      function.mergeAccumulatorsPair(accumulators, record)
     }
+
+    // set group keys value to final output
+    function.setForwardedFields(record, output)
+
+    // get final aggregate value and set to output.
+    function.setAggregationResults(accumulators, output)
+
+    // set grouping set flags to output
+    function.setConstantFlags(output)
+
+    out.collect(output)
   }
 }
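
Note that reduce() above may discard the return value of
mergeAccumulatorsPair: as the generated template earlier in this commit
shows, the method merges into row "a" via setField and returns that same row.
An equivalent, more explicit formulation of the loop would be:

    // Equivalent to the loop above: the merge mutates "accumulators" in place
    // and returns it, so the reassignment is purely cosmetic.
    var acc: Row = accumulators
    while (iterator.hasNext) {
      record = iterator.next()
      acc = function.mergeAccumulatorsPair(acc, record)
    }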

http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala
index db49a53..8febe3e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetPreAggFunction.scala
@@ -21,85 +21,64 @@ import java.lang.Iterable
 
 import org.apache.flink.api.common.functions._
 import org.apache.flink.configuration.Configuration
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
 import org.apache.flink.types.Row
-import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.util.Collector
+import org.slf4j.LoggerFactory
 
 /**
   * [[GroupCombineFunction]] and [[MapPartitionFunction]] to compute pre-aggregates for batch
   * (DataSet) queries.
   *
-  * @param aggregates The aggregate functions.
-  * @param aggInFields The positions of the aggregation input fields.
-  * @param groupingKeys The positions of the grouping keys in the input.
+  * @param genAggregations Code-generated [[GeneratedAggregations]]
   */
-class DataSetPreAggFunction(
-    private val aggregates: Array[AggregateFunction[_ <: Any]],
-    private val aggInFields: Array[Array[Int]],
-    private val groupingKeys: Array[Int])
+class DataSetPreAggFunction(genAggregations: GeneratedAggregationsFunction)
   extends AbstractRichFunction
   with GroupCombineFunction[Row, Row]
-  with MapPartitionFunction[Row, Row] {
-
-  Preconditions.checkNotNull(aggregates)
-  Preconditions.checkNotNull(aggInFields)
-  Preconditions.checkNotNull(groupingKeys)
+  with MapPartitionFunction[Row, Row]
+  with Compiler[GeneratedAggregations] {
 
   private var output: Row = _
-  private var accumulators: Array[Accumulator] = _
+  private var accumulators: Row = _
+
+  val LOG = LoggerFactory.getLogger(this.getClass)
+  private var function: GeneratedAggregations = _
 
   override def open(config: Configuration) {
-    accumulators = new Array(aggregates.length)
-    output = new Row(groupingKeys.length + aggregates.length)
+    LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+                s"Code:\n$genAggregations.code")
+    val clazz = compile(
+      getClass.getClassLoader,
+      genAggregations.name,
+      genAggregations.code)
+    LOG.debug("Instantiating AggregateHelper.")
+    function = clazz.newInstance()
+
+    output = function.createOutputRow()
+    accumulators = function.createAccumulators()
   }
 
   override def combine(values: Iterable[Row], out: Collector[Row]): Unit = {
-    preaggregate(values, out)
-  }
-
-  override def mapPartition(values: Iterable[Row], out: Collector[Row]): Unit = {
-    preaggregate(values, out)
-  }
-
-  def preaggregate(records: Iterable[Row], out: Collector[Row]): Unit = {
+    // reset accumulators
+    function.resetAccumulator(accumulators)
 
-    // create accumulators
-    var i = 0
-    while (i < aggregates.length) {
-      accumulators(i) = aggregates(i).createAccumulator()
-      i += 1
-    }
-
-    val iterator = records.iterator()
+    val iterator = values.iterator()
 
+    var record: Row = null
     while (iterator.hasNext) {
-      val record = iterator.next()
-
+      record = iterator.next()
       // accumulate
-      i = 0
-      while (i < aggregates.length) {
-        aggregates(i).accumulate(accumulators(i), record.getField(aggInFields(i)(0)))
-        i += 1
-      }
-      // check if this record is the last record
-      if (!iterator.hasNext) {
-        // set group keys value to output
-        i = 0
-        while (i < groupingKeys.length) {
-          output.setField(i, record.getField(groupingKeys(i)))
-          i += 1
-        }
+      function.accumulate(accumulators, record)
+    }
 
-        // set agg results to output
-        i = 0
-        while (i < accumulators.length) {
-          output.setField(groupingKeys.length + i, accumulators(i))
-          i += 1
-        }
+    // set group keys and accumulators to output
+    function.setAggregationResults(accumulators, output)
+    function.setForwardedFields(record, output)
 
-        out.collect(output)
-      }
-    }
+    out.collect(output)
   }
 
+  override def mapPartition(values: Iterable[Row], out: Collector[Row]): Unit = {
+    combine(values, out)
+  }
 }

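Every rewritten operator repeats the same open() steps: compile the
generated source, instantiate it, and create the reusable output and
accumulator rows. The Compiler trait mixed in above is not shown in this
excerpt; assuming it wraps Janino like the rest of flink-table's code
generation, the compile step amounts to roughly this sketch (compileSketch
is a hypothetical stand-in, not the real trait method):

    import org.codehaus.janino.SimpleCompiler

    // Hedged sketch of a Janino-backed compile step; assumes the janino
    // dependency that code generation already requires.
    def compileSketch[T](cl: ClassLoader, name: String, code: String): Class[T] = {
      val compiler = new SimpleCompiler()
      compiler.setParentClassLoader(cl)  // resolve Flink classes from the caller
      compiler.cook(code)                // compile the generated source string
      compiler.getClassLoader.loadClass(name).asInstanceOf[Class[T]]
    }

After compilation, open() only needs clazz.newInstance() plus
createOutputRow()/createAccumulators(), which is why all of these functions
collapse to the same few lines.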
http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala
index d108570..95699a2 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggReduceGroupFunction.scala
@@ -18,13 +18,13 @@
 package org.apache.flink.table.runtime.aggregate
 
 import java.lang.Iterable
-import java.util.{ArrayList => JArrayList}
 
 import org.apache.flink.api.common.functions.RichGroupReduceFunction
 import org.apache.flink.types.Row
 import org.apache.flink.configuration.Configuration
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
-import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
+import org.apache.flink.util.Collector
+import org.slf4j.LoggerFactory
 
 /**
   * It wraps the aggregate logic inside of
@@ -40,53 +40,45 @@ import org.apache.flink.util.{Collector, Preconditions}
   *  2. when partial aggregate is supported, the input data structure of reduce is
   * |groupKey1|groupKey2|sum1|count1|sum2|count2|windowStart|windowEnd|
   *
-  * @param aggregates             The aggregate functions.
-  * @param groupKeysMapping       The index mapping of group keys between intermediate aggregate Row
-  *                               and output Row.
-  * @param aggregateMapping       The index mapping between aggregate function list and
-  *                               aggregated value index in output Row.
-  * @param finalRowArity          The output row field count.
+  * @param genAggregations Code-generated [[GeneratedAggregations]]
+  * @param keysAndAggregatesArity The total arity of keys and aggregates
   * @param finalRowWindowStartPos The relative window-start field position.
   * @param finalRowWindowEndPos   The relative window-end field position.
   * @param gap                    Session time window gap.
   */
 class DataSetSessionWindowAggReduceGroupFunction(
-    aggregates: Array[AggregateFunction[_ <: Any]],
-    groupKeysMapping: Array[(Int, Int)],
-    aggregateMapping: Array[(Int, Int)],
-    finalRowArity: Int,
+    genAggregations: GeneratedAggregationsFunction,
+    keysAndAggregatesArity: Int,
     finalRowWindowStartPos: Option[Int],
     finalRowWindowEndPos: Option[Int],
     gap: Long,
     isInputCombined: Boolean)
-  extends RichGroupReduceFunction[Row, Row] {
+  extends RichGroupReduceFunction[Row, Row]
+    with Compiler[GeneratedAggregations] {
 
-  Preconditions.checkNotNull(aggregates)
-  Preconditions.checkNotNull(groupKeysMapping)
+  private var collector: TimeWindowPropertyCollector = _
+  private val intermediateRowWindowStartPos = keysAndAggregatesArity
+  private val intermediateRowWindowEndPos = keysAndAggregatesArity + 1
 
-  private var aggregateBuffer: Row = _
   private var output: Row = _
-  private var collector: TimeWindowPropertyCollector = _
-  private val accumStartPos: Int = groupKeysMapping.length
-  private val intermediateRowArity: Int = accumStartPos + aggregates.length + 2
-  private val intermediateRowWindowStartPos = intermediateRowArity - 2
-  private val intermediateRowWindowEndPos = intermediateRowArity - 1
+  private var accumulators: Row = _
 
-  val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) {
-    new JArrayList[Accumulator](2)
-  }
+  val LOG = LoggerFactory.getLogger(this.getClass)
+  private var function: GeneratedAggregations = _
 
   override def open(config: Configuration) {
-    aggregateBuffer = new Row(intermediateRowArity)
-    output = new Row(finalRowArity)
+    LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+                s"Code:\n$genAggregations.code")
+    val clazz = compile(
+      getClass.getClassLoader,
+      genAggregations.name,
+      genAggregations.code)
+    LOG.debug("Instantiating AggregateHelper.")
+    function = clazz.newInstance()
+
+    output = function.createOutputRow()
+    accumulators = function.createAccumulators()
     collector = new TimeWindowPropertyCollector(finalRowWindowStartPos, finalRowWindowEndPos)
-
-    // init lists with two empty accumulators
-    for (i <- aggregates.indices) {
-      val accumulator = aggregates(i).createAccumulator()
-      accumulatorList(i).add(accumulator)
-      accumulatorList(i).add(accumulator)
-    }
   }
 
   /**
@@ -105,13 +97,8 @@ class DataSetSessionWindowAggReduceGroupFunction(
     var windowEnd: java.lang.Long = null
     var currentRowTime: java.lang.Long = null
 
-
-    // reset first accumulator in merge list
-    var i = 0
-    while (i < aggregates.length) {
-      aggregates(i).resetAccumulator(accumulatorList(i).get(0))
-      i += 1
-    }
+    // reset accumulator
+    function.resetAccumulator(accumulators)
 
     val iterator = records.iterator()
 
@@ -125,38 +112,18 @@ class DataSetSessionWindowAggReduceGroupFunction(
         // calculate the current window and open a new window
         if (null != windowEnd) {
           // evaluate and emit the current window's result.
-          doEvaluateAndCollect(out, accumulatorList, windowStart, windowEnd)
-
-          // reset first accumulator in list
-          i = 0
-          while (i < aggregates.length) {
-            aggregates(i).resetAccumulator(accumulatorList(i).get(0))
-            i += 1
-          }
+          doEvaluateAndCollect(out, windowStart, windowEnd)
+          // reset accumulator
+          function.resetAccumulator(accumulators)
         } else {
-          // set group keys value to final output.
-          i = 0
-          while (i < groupKeysMapping.length) {
-            val (after, previous) = groupKeysMapping(i)
-            output.setField(after, record.getField(previous))
-            i += 1
-          }
+          // set keys to output
+          function.setForwardedFields(record, output)
         }
 
         windowStart = record.getField(intermediateRowWindowStartPos).asInstanceOf[Long]
       }
 
-      i = 0
-      while (i < aggregates.length) {
-        // insert received accumulator into acc list
-        val newAcc = record.getField(accumStartPos + i).asInstanceOf[Accumulator]
-        accumulatorList(i).set(1, newAcc)
-        // merge acc list
-        val retAcc = aggregates(i).merge(accumulatorList(i))
-        // insert result into acc list
-        accumulatorList(i).set(0, retAcc)
-        i += 1
-      }
+      function.mergeAccumulatorsPair(accumulators, record)
 
       windowEnd = if (isInputCombined) {
         // partial aggregate is supported
@@ -167,15 +134,13 @@ class DataSetSessionWindowAggReduceGroupFunction(
       }
     }
     // evaluate and emit the current window's result.
-    doEvaluateAndCollect(out, accumulatorList, windowStart, windowEnd)
+    doEvaluateAndCollect(out, windowStart, windowEnd)
   }
 
   /**
     * Evaluate and emit the data of the current window.
     *
     * @param out             the collection of the aggregate results
-    * @param accumulatorList an array (indexed by aggregate index) of the accumulator lists for
-    *                        each aggregate
     * @param windowStart     the window's start attribute value is the min (rowtime) of all rows
     *                        in the window.
     * @param windowEnd       the window's end property value is max (rowtime) + gap for all rows
@@ -183,18 +148,11 @@ class DataSetSessionWindowAggReduceGroupFunction(
     */
   def doEvaluateAndCollect(
       out: Collector[Row],
-      accumulatorList: Array[JArrayList[Accumulator]],
       windowStart: Long,
       windowEnd: Long): Unit = {
 
-    // merge the accumulators and then get value for the final output
-    var i = 0
-    while (i < aggregateMapping.length) {
-      val (after, previous) = aggregateMapping(i)
-      val agg = aggregates(previous)
-      output.setField(after, agg.getValue(accumulatorList(previous).get(0)))
-      i += 1
-    }
+    // set value for the final output
+    function.setAggregationResults(accumulators, output)
 
     // adds TimeWindow properties to output then emit output
     if (finalRowWindowStartPos.isDefined || finalRowWindowEndPos.isDefined) {

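The session semantics that survive the rewrite are the gap-based split
itself. Stripped of Rows and accumulators, the control flow of reduce()
boils down to the following sketch over rowtime-sorted input (sessionize is
a hypothetical helper using plain Longs, for illustration only):

    // Splits ascending rowtimes into (start, end) sessions; a new session
    // starts whenever a rowtime falls beyond the previous end (= max rowtime
    // + gap). Assumes input sorted by rowtime, as the operator does.
    def sessionize(rowtimes: Seq[Long], gap: Long): Seq[(Long, Long)] = {
      val sessions = Seq.newBuilder[(Long, Long)]
      var start: Option[Long] = None
      var end = 0L
      for (t <- rowtimes) {
        start match {
          case Some(_) if t <= end =>
            end = t + gap                    // still within the gap: extend
          case Some(s) =>
            sessions += ((s, end))           // gap exceeded: close and reopen
            start = Some(t); end = t + gap
          case None =>
            start = Some(t); end = t + gap   // first row opens a session
        }
      }
      start.foreach(s => sessions += ((s, end)))
      sessions.result()
    }

For example, sessionize(Seq(1L, 2L, 10L), gap = 3) yields
Seq((1L, 5L), (10L, 13L)).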
http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala
index acd9e63..22a2682 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSessionWindowAggregatePreProcessor.scala
@@ -18,55 +18,55 @@
 package org.apache.flink.table.runtime.aggregate
 
 import java.lang.Iterable
-import java.util.{ArrayList => JArrayList}
 
 import org.apache.flink.api.common.functions.{AbstractRichFunction, GroupCombineFunction, MapPartitionFunction}
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable
 import org.apache.flink.types.Row
 import org.apache.flink.configuration.Configuration
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
-import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
+import org.apache.flink.util.Collector
+import org.slf4j.LoggerFactory
 
 /**
   * This wraps the aggregate logic inside of
   * [[org.apache.flink.api.java.operators.GroupCombineOperator]].
   *
-  * @param aggregates          The aggregate functions.
-  * @param groupingKeys        The indexes of the grouping fields.
+  * @param genAggregations Code-generated [[GeneratedAggregations]]
+  * @param keysAndAggregatesArity The total arity of keys and aggregates
   * @param gap                 Session time window gap.
   * @param intermediateRowType Intermediate row data type.
   */
 class DataSetSessionWindowAggregatePreProcessor(
-    aggregates: Array[AggregateFunction[_ <: Any]],
-    groupingKeys: Array[Int],
+    genAggregations: GeneratedAggregationsFunction,
+    keysAndAggregatesArity: Int,
     gap: Long,
     @transient intermediateRowType: TypeInformation[Row])
   extends AbstractRichFunction
   with MapPartitionFunction[Row,Row]
   with GroupCombineFunction[Row,Row]
-  with ResultTypeQueryable[Row] {
+  with ResultTypeQueryable[Row]
+  with Compiler[GeneratedAggregations] {
 
-  Preconditions.checkNotNull(aggregates)
-  Preconditions.checkNotNull(groupingKeys)
+  private var output: Row = _
+  private val rowTimeFieldPos = keysAndAggregatesArity
+  private var accumulators: Row = _
 
-  private var aggregateBuffer: Row = _
-  private val accumStartPos: Int = groupingKeys.length
-  private val rowTimeFieldPos = accumStartPos + aggregates.length
-
-  val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) {
-    new JArrayList[Accumulator](2)
-  }
+  val LOG = LoggerFactory.getLogger(this.getClass)
+  private var function: GeneratedAggregations = _
 
   override def open(config: Configuration) {
-    aggregateBuffer = new Row(rowTimeFieldPos + 2)
-
-    // init lists with two empty accumulators
-    for (i <- aggregates.indices) {
-      val accumulator = aggregates(i).createAccumulator()
-      accumulatorList(i).add(accumulator)
-      accumulatorList(i).add(accumulator)
-    }
+    LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+                s"Code:\n$genAggregations.code")
+    val clazz = compile(
+      getClass.getClassLoader,
+      genAggregations.name,
+      genAggregations.code)
+    LOG.debug("Instantiating AggregateHelper.")
+    function = clazz.newInstance()
+
+    accumulators = function.createAccumulators()
+    output = function.createOutputRow()
   }
 
   /**
@@ -79,43 +79,13 @@ class DataSetSessionWindowAggregatePreProcessor(
     *
     */
   override def combine(records: Iterable[Row], out: Collector[Row]): Unit = {
-    preProcessing(records, out)
-  }
-
-  /**
-    * Divide window based on the rowtime
-    * (current'rowtime - previous’rowtime > gap), and then merge data (within a unified window)
-    * into an aggregate buffer.
-    *
-    * @param records  Intermediate aggregate Rows.
-    * @return Pre partition intermediate aggregate Row.
-    *
-    */
-  override def mapPartition(records: Iterable[Row], out: Collector[Row]): Unit = {
-    preProcessing(records, out)
-  }
-
-  /**
-    * Intermediate aggregate Rows, divide window based on the rowtime
-    * (current'rowtime - previous’rowtime > gap), and then merge data (within a unified window)
-    * into an aggregate buffer.
-    *
-    * @param records Intermediate aggregate Rows.
-    * @return PreProcessing intermediate aggregate Row.
-    *
-    */
-  private def preProcessing(records: Iterable[Row], out: Collector[Row]): Unit = {
 
     var windowStart: java.lang.Long = null
     var windowEnd: java.lang.Long = null
     var currentRowTime: java.lang.Long = null
 
-    // reset first accumulator in merge list
-    var i = 0
-    while (i < aggregates.length) {
-      aggregates(i).resetAccumulator(accumulatorList(i).get(0))
-      i += 1
-    }
+    // reset accumulator
+    function.resetAccumulator(accumulators)
 
     val iterator = records.iterator()
 
@@ -128,51 +98,44 @@ class DataSetSessionWindowAggregatePreProcessor(
         // calculate the current window and open a new window.
         if (windowEnd != null) {
           // emit the current window's merged data
-          doCollect(out, accumulatorList, windowStart, windowEnd)
-
-          // reset first value of accumulator list
-          i = 0
-          while (i < aggregates.length) {
-            aggregates(i).resetAccumulator(accumulatorList(i).get(0))
-            i += 1
-          }
+          doCollect(out, windowStart, windowEnd)
+
+          // reset accumulator
+          function.resetAccumulator(accumulators)
         } else {
-          // set group keys to aggregateBuffer.
+          // set group keys to output.
-          i = 0
-          while (i < groupingKeys.length) {
-            aggregateBuffer.setField(i, record.getField(i))
-            i += 1
-          }
+          function.setForwardedFields(record, output)
         }
 
         windowStart = record.getField(rowTimeFieldPos).asInstanceOf[Long]
       }
 
-      i = 0
-      while (i < aggregates.length) {
-        // insert received accumulator into acc list
-        val newAcc = record.getField(accumStartPos + i).asInstanceOf[Accumulator]
-        accumulatorList(i).set(1, newAcc)
-        // merge acc list
-        val retAcc = aggregates(i).merge(accumulatorList(i))
-        // insert result into acc list
-        accumulatorList(i).set(0, retAcc)
-        i += 1
-      }
+      function.mergeAccumulatorsPair(accumulators, record)
 
       // the current rowtime is the last rowtime of the next calculation.
       windowEnd = currentRowTime + gap
     }
     // emit the merged data of the current window.
-    doCollect(out, accumulatorList, windowStart, windowEnd)
+    doCollect(out, windowStart, windowEnd)
+  }
+
+  /**
+    * Divide window based on the rowtime
+    * (current row's rowtime - previous row's rowtime > gap), and then merge data (within a unified window)
+    * into an aggregate buffer.
+    *
+    * @param records  Intermediate aggregate Rows.
+    * @return Pre-partitioned intermediate aggregate Rows.
+    *
+    */
+  override def mapPartition(records: Iterable[Row], out: Collector[Row]): Unit = {
+    combine(records, out)
   }
 
   /**
     * Emit the merged data of the current window.
     *
     * @param out             the collection of the aggregate results
-    * @param accumulatorList an array (indexed by aggregate index) of the accumulator lists for
-    *                        each aggregate
     * @param windowStart     the window's start attribute value is the min (rowtime)
     *                        of all rows in the window.
     * @param windowEnd       the window's end property value is max (rowtime) + gap
@@ -180,24 +143,18 @@ class DataSetSessionWindowAggregatePreProcessor(
     */
   def doCollect(
       out: Collector[Row],
-      accumulatorList: Array[JArrayList[Accumulator]],
       windowStart: Long,
       windowEnd: Long): Unit = {
 
-    // merge the accumulators into one accumulator
-    var i = 0
-    while (i < aggregates.length) {
-      aggregateBuffer.setField(accumStartPos + i, accumulatorList(i).get(0))
-      i += 1
-    }
+    function.setAggregationResults(accumulators, output)
 
     // intermediate Row WindowStartPos is rowtime pos.
-    aggregateBuffer.setField(rowTimeFieldPos, windowStart)
+    output.setField(rowTimeFieldPos, windowStart)
 
     // intermediate Row WindowEndPos is rowtime pos + 1.
-    aggregateBuffer.setField(rowTimeFieldPos + 1, windowEnd)
+    output.setField(rowTimeFieldPos + 1, windowEnd)
 
-    out.collect(aggregateBuffer)
+    out.collect(output)
   }
 
   override def getProducedType: TypeInformation[Row] = {

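When reading doCollect above, keep the emitted row layout in mind: the
first keysAndAggregatesArity fields carry the group keys followed by the
accumulators, and the two trailing fields carry the window bounds. A
minimal sketch with assumed arities:

    import org.apache.flink.types.Row

    // Emitted-row layout: | keys | accumulators | windowStart | windowEnd |
    // with rowTimeFieldPos == keysAndAggregatesArity, as defined above.
    val keysAndAggregatesArity = 3                   // e.g. 1 key + 2 accumulators (assumed)
    val row = new Row(keysAndAggregatesArity + 2)
    row.setField(keysAndAggregatesArity, 1000L)      // window start = min rowtime
    row.setField(keysAndAggregatesArity + 1, 4000L)  // window end = max rowtime + gap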
http://git-wip-us.apache.org/repos/asf/flink/blob/3b4542b8/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala
index 422989b..b3a19a4 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/DataSetSlideTimeWindowAggReduceGroupFunction.scala
@@ -18,16 +18,16 @@
 package org.apache.flink.table.runtime.aggregate
 
 import java.lang.Iterable
-import java.util.{ArrayList => JArrayList}
 
 import org.apache.flink.api.common.functions.{CombineFunction, RichGroupReduceFunction}
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable
 import org.apache.flink.configuration.Configuration
 import org.apache.flink.streaming.api.windowing.windows.TimeWindow
-import org.apache.flink.table.functions.{Accumulator, AggregateFunction}
+import org.apache.flink.table.codegen.{Compiler, GeneratedAggregationsFunction}
 import org.apache.flink.types.Row
-import org.apache.flink.util.{Collector, Preconditions}
+import org.apache.flink.util.Collector
+import org.slf4j.LoggerFactory
 
 /**
   * It is used for sliding windows on batch for time-windows. It takes a prepared input row (with
@@ -38,106 +38,78 @@ import org.apache.flink.util.{Collector, Preconditions}
   * it does no final aggregate evaluation. It also includes the logic of
   * [[DataSetSlideTimeWindowAggFlatMapFunction]].
   *
-  * @param aggregates aggregate functions
-  * @param groupingKeysLength number of grouping keys
-  * @param timeFieldPos position of aligned time field
+  * @param genAggregations Code-generated [[GeneratedAggregations]]
+  * @param keysAndAggregatesArity The total arity of keys and aggregates
   * @param windowSize window size of the sliding window
   * @param windowSlide window slide of the sliding window
   * @param returnType return type of this function
   */
 class DataSetSlideTimeWindowAggReduceGroupFunction(
-    private val aggregates: Array[AggregateFunction[_ <: Any]],
-    private val groupingKeysLength: Int,
-    private val timeFieldPos: Int,
+    private val genAggregations: GeneratedAggregationsFunction,
+    private val keysAndAggregatesArity: Int,
     private val windowSize: Long,
     private val windowSlide: Long,
     @transient private val returnType: TypeInformation[Row])
   extends RichGroupReduceFunction[Row, Row]
   with CombineFunction[Row, Row]
-  with ResultTypeQueryable[Row] {
+  with ResultTypeQueryable[Row]
+  with Compiler[GeneratedAggregations] {
 
-  Preconditions.checkNotNull(aggregates)
+  private val timeFieldPos = returnType.getArity - 1
+  private val intermediateWindowStartPos = keysAndAggregatesArity
 
   protected var intermediateRow: Row = _
-  // add one field to store window start
-  protected val intermediateRowArity: Int = groupingKeysLength + aggregates.length + 1
-  protected val accumulatorList: Array[JArrayList[Accumulator]] = Array.fill(aggregates.length) {
-    new JArrayList[Accumulator](2)
-  }
-  private val intermediateWindowStartPos: Int = intermediateRowArity - 1
+  private var accumulators: Row = _
+
+  val LOG = LoggerFactory.getLogger(this.getClass)
+  private var function: GeneratedAggregations = _
 
   override def open(config: Configuration) {
-    intermediateRow = new Row(intermediateRowArity)
-
-    // init lists with two empty accumulators
-    var i = 0
-    while (i < aggregates.length) {
-      val accumulator = aggregates(i).createAccumulator()
-      accumulatorList(i).add(accumulator)
-      accumulatorList(i).add(accumulator)
-      i += 1
-    }
+    LOG.debug(s"Compiling AggregateHelper: $genAggregations.name \n\n " +
+                s"Code:\n$genAggregations.code")
+    val clazz = compile(
+      getClass.getClassLoader,
+      genAggregations.name,
+      genAggregations.code)
+    LOG.debug("Instantiating AggregateHelper.")
+    function = clazz.newInstance()
+
+    accumulators = function.createAccumulators()
+    intermediateRow = function.createOutputRow()
   }
 
   override def reduce(records: Iterable[Row], out: Collector[Row]): Unit = {
 
     // reset first accumulator
-    var i = 0
-    while (i < aggregates.length) {
-      val accumulator = aggregates(i).createAccumulator()
-      accumulatorList(i).set(0, accumulator)
-      i += 1
-    }
+    function.resetAccumulator(accumulators)
 
     val iterator = records.iterator()
 
+    var record: Row = null
     while (iterator.hasNext) {
-      val record = iterator.next()
+      record = iterator.next()
 
       // accumulate
-      i = 0
-      while (i < aggregates.length) {
-        // insert received accumulator into acc list
-        val newAcc = record.getField(groupingKeysLength + i).asInstanceOf[Accumulator]
-        accumulatorList(i).set(1, newAcc)
-        // merge acc list
-        val retAcc = aggregates(i).merge(accumulatorList(i))
-        // insert result into acc list
-        accumulatorList(i).set(0, retAcc)
-        i += 1
-      }
+      function.mergeAccumulatorsPair(accumulators, record)
+    }
+
+    val windowStart = record.getField(timeFieldPos).asInstanceOf[Long]
+
+    // adapted from SlidingEventTimeWindows.assignWindows
+    var start: Long = TimeWindow.getWindowStartWithOffset(windowStart, 0, windowSlide)
+
+    // skip preparing output if it is not necessary
+    if (start > windowStart - windowSize) {
+
+      // set group keys and partial accumulated result
+      function.setAggregationResults(accumulators, intermediateRow)
+      function.setForwardedFields(record, intermediateRow)
 
-      // trigger tumbling evaluation
-      if (!iterator.hasNext) {
-        val windowStart = record.getField(timeFieldPos).asInstanceOf[Long]
-
-        // adopted from SlidingEventTimeWindows.assignWindows
-        var start: Long = TimeWindow.getWindowStartWithOffset(windowStart, 0, windowSlide)
-
-        // skip preparing output if it is not necessary
-        if (start > windowStart - windowSize) {
-
-          // set group keys
-          i = 0
-          while (i < groupingKeysLength) {
-            intermediateRow.setField(i, record.getField(i))
-            i += 1
-          }
-
-          // set accumulators
-          i = 0
-          while (i < aggregates.length) {
-            intermediateRow.setField(groupingKeysLength + i, accumulatorList(i).get(0))
-            i += 1
-          }
-
-          // adopted from SlidingEventTimeWindows.assignWindows
-          while (start > windowStart - windowSize) {
-            intermediateRow.setField(intermediateWindowStartPos, start)
-            out.collect(intermediateRow)
-            start -= windowSlide
-          }
-        }
+      // adapted from SlidingEventTimeWindows.assignWindows
+      while (start > windowStart - windowSize) {
+        intermediateRow.setField(intermediateWindowStartPos, start)
+        out.collect(intermediateRow)
+        start -= windowSlide
       }
     }
   }
@@ -145,54 +117,21 @@ class DataSetSlideTimeWindowAggReduceGroupFunction(
   override def combine(records: Iterable[Row]): Row = {
 
     // reset first accumulator
-    var i = 0
-    while (i < aggregates.length) {
-      aggregates(i).resetAccumulator(accumulatorList(i).get(0))
-      i += 1
-    }
+    function.resetAccumulator(accumulators)
 
     val iterator = records.iterator()
+    var record: Row = null
     while (iterator.hasNext) {
-      val record = iterator.next()
-
-      i = 0
-      while (i < aggregates.length) {
-        // insert received accumulator into acc list
-        val newAcc = record.getField(groupingKeysLength + i).asInstanceOf[Accumulator]
-        accumulatorList(i).set(1, newAcc)
-        // merge acc list
-        val retAcc = aggregates(i).merge(accumulatorList(i))
-        // insert result into acc list
-        accumulatorList(i).set(0, retAcc)
-        i += 1
-      }
-
-      // check if this record is the last record
-      if (!iterator.hasNext) {
-
-        // set group keys
-        i = 0
-        while (i < groupingKeysLength) {
-          intermediateRow.setField(i, record.getField(i))
-          i += 1
-        }
-
-        // set accumulators
-        i = 0
-        while (i < aggregates.length) {
-          intermediateRow.setField(groupingKeysLength + i, accumulatorList(i).get(0))
-          i += 1
-        }
-
-        intermediateRow.setField(timeFieldPos, record.getField(timeFieldPos))
-
-        return intermediateRow
-      }
+      record = iterator.next()
+      function.mergeAccumulatorsPair(accumulators, record)
     }
+    // set group keys and partial accumulated result
+    function.setAggregationResults(accumulators, intermediateRow)
+    function.setForwardedFields(record, intermediateRow)
+
+    intermediateRow.setField(timeFieldPos, record.getField(timeFieldPos))
 
-    // this code path should never be reached as we return before the loop finishes
-    // we need this to prevent a compiler error
-    throw new IllegalArgumentException("Group is empty. This should never happen.")
+    intermediateRow
   }
 
   override def getProducedType: TypeInformation[Row] = {

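Both the new and the removed emit loops walk the sliding windows that cover
a timestamp, newest window first. Mirroring the formula behind
TimeWindow.getWindowStartWithOffset with offset 0, the enumeration can be
reproduced in isolation (windowStarts is a hypothetical helper for
illustration):

    // Enumerates the start timestamps of all sliding windows of size
    // `windowSize` and slide `windowSlide` that contain `timestamp`,
    // latest start first -- the sequence the while-loop above emits for.
    def windowStarts(timestamp: Long, windowSize: Long, windowSlide: Long): Seq[Long] = {
      // same as TimeWindow.getWindowStartWithOffset(timestamp, 0, windowSlide)
      var start = timestamp - (timestamp + windowSlide) % windowSlide
      val starts = Seq.newBuilder[Long]
      while (start > timestamp - windowSize) {
        starts += start
        start -= windowSlide
      }
      starts.result()
    }

For example, windowStarts(7, windowSize = 10, windowSlide = 5) yields
Seq(5, 0): the windows [5, 15) and [0, 10) both contain timestamp 7.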