GitHub user fhueske commented on a diff in the pull request:

    https://github.com/apache/flink/pull/5555#discussion_r184041605
  
    --- Diff: 
flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/AggregationCodeGenerator.scala
 ---
    @@ -207,112 +241,164 @@ class AggregationCodeGenerator(
             }
         }
     
    +
         /**
    -      * Create DataView Term, for example, acc1_map_dataview.
    -      *
    -      * @param aggIndex index of aggregate function
    -      * @param fieldName field name of DataView
    -      * @return term to access [[MapView]] or [[ListView]]
    +      * Add all data view for all distinct filters defined by aggregation 
functions.
           */
    -    def createDataViewTerm(aggIndex: Int, fieldName: String): String = {
    -      s"acc${aggIndex}_${fieldName}_dataview"
    +    def addDistinctFilterDataViews(): Unit = {
    +      val descMapping: Map[String, StateDescriptor[_, _]] = distinctAggs
    +        .flatMap(specs => specs.map(s => (s.stateId, s.toStateDescriptor)))
    +        .toMap[String, StateDescriptor[_ <: State, _]]
    +      if (isStateBackedDataViews) {
    +        for (i <- aggs.indices) yield {
    +          for (spec <- distinctAggs(i)) {
    +            // Check if stat descriptor exists.
    +            val desc: StateDescriptor[_, _] = 
descMapping.getOrElse(spec.stateId,
    +              throw new CodeGenException(
    +                s"Can not find DataView for distinct filter in accumulator 
by id: ${spec.stateId}"))
    +
    +            addReusableDataView(spec, desc, i)
    +          }
    +        }
    +      }
         }
     
         /**
    -      * Adds a reusable [[org.apache.flink.table.api.dataview.DataView]] 
to the open, cleanup,
    -      * close and member area of the generated function.
    -      *
    +      * Add all data views for all field accumulators defined by 
aggregation functions.
           */
    -    def addReusableDataViews(): Unit = {
    +    def addAccumulatorDataViews(): Unit = {
           if (accConfig.isDefined) {
             val descMapping: Map[String, StateDescriptor[_, _]] = accConfig.get
               .flatMap(specs => specs.map(s => (s.stateId, 
s.toStateDescriptor)))
               .toMap[String, StateDescriptor[_ <: State, _]]
     
             for (i <- aggs.indices) yield {
               for (spec <- accConfig.get(i)) yield {
    -            val dataViewField = spec.field
    -            val dataViewTypeTerm = dataViewField.getType.getCanonicalName
    -            val desc = descMapping.getOrElse(spec.stateId,
    +            // Check if stat descriptor exists.
    +            val desc: StateDescriptor[_, _] = 
descMapping.getOrElse(spec.stateId,
                   throw new CodeGenException(
                     s"Can not find DataView in accumulator by id: 
${spec.stateId}"))
     
    -            // define the DataView variables
    -            val serializedData = serializeStateDescriptor(desc)
    -            val dataViewFieldTerm = createDataViewTerm(i, 
dataViewField.getName)
    -            val field =
    -              s"""
    -                 |    final $dataViewTypeTerm $dataViewFieldTerm;
    -                 |""".stripMargin
    -            reusableMemberStatements.add(field)
    -
    -            // create DataViews
    -            val descFieldTerm = s"${dataViewFieldTerm}_desc"
    -            val descClassQualifier = classOf[StateDescriptor[_, 
_]].getCanonicalName
    -            val descDeserializeCode =
    -              s"""
    -                 |    $descClassQualifier $descFieldTerm = 
($descClassQualifier)
    -                 |      
org.apache.flink.util.InstantiationUtil.deserializeObject(
    -                 |      
org.apache.commons.codec.binary.Base64.decodeBase64("$serializedData"),
    -                 |      $contextTerm.getUserCodeClassLoader());
    -                 |""".stripMargin
    -            val createDataView = if (dataViewField.getType == 
classOf[MapView[_, _]]) {
    -              s"""
    -                 |    $descDeserializeCode
    -                 |    $dataViewFieldTerm = new 
org.apache.flink.table.dataview.StateMapView(
    -                 |      $contextTerm.getMapState((
    -                 |      
org.apache.flink.api.common.state.MapStateDescriptor)$descFieldTerm));
    -                 |""".stripMargin
    -            } else if (dataViewField.getType == classOf[ListView[_]]) {
    -              s"""
    -                 |    $descDeserializeCode
    -                 |    $dataViewFieldTerm = new 
org.apache.flink.table.dataview.StateListView(
    -                 |      $contextTerm.getListState((
    -                 |      
org.apache.flink.api.common.state.ListStateDescriptor)$descFieldTerm));
    -                 |""".stripMargin
    -            } else {
    -              throw new CodeGenException(s"Unsupported dataview type: 
$dataViewTypeTerm")
    -            }
    -            reusableOpenStatements.add(createDataView)
    -
    -            // cleanup DataViews
    -            val cleanup =
    -              s"""
    -                 |    $dataViewFieldTerm.clear();
    -                 |""".stripMargin
    -            reusableCleanupStatements.add(cleanup)
    +            addReusableDataView(spec, desc, i)
               }
             }
           }
         }
     
    +    /**
    +      * Create DataView Term, for example, acc1_map_dataview.
    +      *
    +      * @param aggIndex index of aggregate function
    +      * @param fieldName field name of DataView
    +      * @return term to access [[MapView]] or [[ListView]]
    +      */
    +    def createDataViewTerm(aggIndex: Int, fieldName: String): String = {
    +      s"acc${aggIndex}_${fieldName}_dataview"
    +    }
    +
    +    /**
    +      * Adds a reusable [[org.apache.flink.table.api.dataview.DataView]] 
to the open, cleanup,
    +      * close and member area of the generated function.
    +      * @param spec the [[DataViewSpec]] of the desired data view term.
    +      * @param desc the [[StateDescriptor]] of the desired data view term.
    +      * @param aggIndex the aggregation function index associate with the 
data view.
    +      */
    +    def addReusableDataView(
    +        spec: DataViewSpec[_],
    +        desc: StateDescriptor[_, _],
    +        aggIndex: Int): Unit = {
    +      val dataViewField = spec.field
    +      val dataViewTypeTerm = dataViewField.getType.getCanonicalName
    +
    +      // define the DataView variables
    +      val serializedData = serializeStateDescriptor(desc)
    +      val dataViewFieldTerm = createDataViewTerm(aggIndex, 
dataViewField.getName)
    +      val field =
    +        s"""
    +           |    final $dataViewTypeTerm $dataViewFieldTerm;
    +           |""".stripMargin
    +      reusableMemberStatements.add(field)
    +
    +      // create DataViews
    +      val descFieldTerm = s"${dataViewFieldTerm}_desc"
    +      val descClassQualifier = classOf[StateDescriptor[_, 
_]].getCanonicalName
    +      val descDeserializeCode =
    +        s"""
    +           |    $descClassQualifier $descFieldTerm = ($descClassQualifier)
    +           |      
org.apache.flink.util.InstantiationUtil.deserializeObject(
    +           |      
org.apache.commons.codec.binary.Base64.decodeBase64("$serializedData"),
    +           |      $contextTerm.getUserCodeClassLoader());
    +           |""".stripMargin
    +      val createDataView = if (dataViewField.getType == classOf[MapView[_, 
_]]) {
    +        s"""
    +           |    $descDeserializeCode
    +           |    $dataViewFieldTerm = new 
org.apache.flink.table.dataview.StateMapView(
    +           |      $contextTerm.getMapState((
    +           |      
org.apache.flink.api.common.state.MapStateDescriptor)$descFieldTerm));
    +           |""".stripMargin
    +      } else if (dataViewField.getType == classOf[ListView[_]]) {
    +        s"""
    +           |    $descDeserializeCode
    +           |    $dataViewFieldTerm = new 
org.apache.flink.table.dataview.StateListView(
    +           |      $contextTerm.getListState((
    +           |      
org.apache.flink.api.common.state.ListStateDescriptor)$descFieldTerm));
    --- End diff --
    
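    Replace the hard-coded class name `org.apache.flink.api.common.state.ListStateDescriptor` with
`classOf[ListStateDescriptor[_]].getCanonicalName`, as is already done for `descClassQualifier` (`classOf[StateDescriptor[_, _]].getCanonicalName`). This lets the compiler check the reference, so the generated code cannot silently drift if the class is renamed or moved. The same applies to the hard-coded `org.apache.flink.api.common.state.MapStateDescriptor` in the `MapView` branch above.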


---

Reply via email to