Github user jackylk commented on a diff in the pull request:

    https://github.com/apache/carbondata/pull/1694#discussion_r158933818
  
    --- Diff: 
integration/spark2/src/main/scala/org/apache/spark/sql/execution/command/preaaggregate/PreAggregateUtil.scala
 ---
    @@ -166,127 +208,160 @@ object PreAggregateUtil {
           aggFunctions: AggregateFunction,
           parentTableName: String,
           parentDatabaseName: String,
    -      parentTableId: String) : scala.collection.mutable.ListBuffer[(Field, 
DataMapField)] = {
    +      parentTableId: String,
    +      newColumnName: String) : scala.collection.mutable.ListBuffer[(Field, 
DataMapField)] = {
         val list = scala.collection.mutable.ListBuffer.empty[(Field, 
DataMapField)]
         aggFunctions match {
    -      case sum@Sum(attr: AttributeReference) =>
    -        list += getField(attr.name,
    -          attr.dataType,
    -          sum.prettyName,
    -          carbonTable.getColumnByName(parentTableName, 
attr.name).getColumnId,
    -          parentTableName,
    -          parentDatabaseName, parentTableId = parentTableId)
    -      case sum@Sum(Cast(attr: AttributeReference, changeDataType: 
DataType)) =>
    -        list += getField(attr.name,
    +      case sum@Sum(MatchCastExpression(exp: Expression, changeDataType: 
DataType)) =>
    +        list += getFieldForAggregateExpression(exp,
               changeDataType,
    -          sum.prettyName,
    -          carbonTable.getColumnByName(parentTableName, 
attr.name).getColumnId,
    -          parentTableName,
    -          parentDatabaseName, parentTableId = parentTableId)
    -      case count@Count(Seq(attr: AttributeReference)) =>
    -        list += getField(attr.name,
    -          attr.dataType,
    -          count.prettyName,
    -          carbonTable.getColumnByName(parentTableName, 
attr.name).getColumnId,
    -          parentTableName,
    -          parentDatabaseName, parentTableId = parentTableId)
    -      case count@Count(Seq(Cast(attr: AttributeReference, _))) =>
    -        list += getField(attr.name,
    -          attr.dataType,
    -          count.prettyName,
    -          carbonTable.getColumnByName(parentTableName, 
attr.name).getColumnId,
    -          parentTableName,
    -          parentDatabaseName, parentTableId = parentTableId)
    -      case min@Min(attr: AttributeReference) =>
    -        list += getField(attr.name,
    -          attr.dataType,
    -          min.prettyName,
    -          carbonTable.getColumnByName(parentTableName, 
attr.name).getColumnId,
    -          parentTableName,
    -          parentDatabaseName, parentTableId = parentTableId)
    -      case min@Min(Cast(attr: AttributeReference, changeDataType: 
DataType)) =>
    -        list += getField(attr.name,
    +          carbonTable,
    +          newColumnName,
    +          sum.prettyName)
    +      case sum@Sum(exp: Expression) =>
    +        list += getFieldForAggregateExpression(exp,
    +          sum.dataType,
    +          carbonTable,
    +          newColumnName,
    +          sum.prettyName)
    +      case count@Count(Seq(MatchCastExpression(exp: Expression, 
changeDataType: DataType))) =>
    +        list += getFieldForAggregateExpression(exp,
               changeDataType,
    -          min.prettyName,
    -          carbonTable.getColumnByName(parentTableName, 
attr.name).getColumnId,
    -          parentTableName,
    -          parentDatabaseName, parentTableId = parentTableId)
    -      case max@Max(attr: AttributeReference) =>
    -        list += getField(attr.name,
    -          attr.dataType,
    -          max.prettyName,
    -          carbonTable.getColumnByName(parentTableName, 
attr.name).getColumnId,
    -          parentTableName,
    -          parentDatabaseName, parentTableId = parentTableId)
    -      case max@Max(Cast(attr: AttributeReference, changeDataType: 
DataType)) =>
    -        list += getField(attr.name,
    +          carbonTable,
    +          newColumnName,
    +          count.prettyName)
    +      case count@Count(Seq(expression: Expression)) =>
    +        list += getFieldForAggregateExpression(expression,
    +          count.dataType,
    +          carbonTable,
    +          newColumnName,
    +          count.prettyName)
    +      case min@Min(MatchCastExpression(exp: Expression, changeDataType: 
DataType)) =>
    +        list += getFieldForAggregateExpression(exp,
               changeDataType,
    -          max.prettyName,
    -          carbonTable.getColumnByName(parentTableName, 
attr.name).getColumnId,
    -          parentTableName,
    -          parentDatabaseName, parentTableId = parentTableId)
    -      case Average(attr: AttributeReference) =>
    -        list += getField(attr.name,
    -          attr.dataType,
    -          "sum",
    -          carbonTable.getColumnByName(parentTableName, 
attr.name).getColumnId,
    -          parentTableName,
    -          parentDatabaseName, parentTableId = parentTableId)
    -        list += getField(attr.name,
    -          attr.dataType,
    -          "count",
    -          carbonTable.getColumnByName(parentTableName, 
attr.name).getColumnId,
    -          parentTableName,
    -          parentDatabaseName, parentTableId = parentTableId)
    -      case Average(Cast(attr: AttributeReference, changeDataType: 
DataType)) =>
    -        list += getField(attr.name,
    +          carbonTable,
    +          newColumnName,
    +          min.prettyName)
    +      case min@Min(expression: Expression) =>
    +        list += getFieldForAggregateExpression(expression,
    +          min.dataType,
    +          carbonTable,
    +          newColumnName,
    +          min.prettyName)
    +      case max@Max(MatchCastExpression(exp: Expression, changeDataType: 
DataType)) =>
    +        list += getFieldForAggregateExpression(exp,
               changeDataType,
    -          "sum",
    -          carbonTable.getColumnByName(parentTableName, 
attr.name).getColumnId,
    -          parentTableName,
    -          parentDatabaseName, parentTableId = parentTableId)
    -        list += getField(attr.name,
    +          carbonTable,
    +          newColumnName,
    +          max.prettyName)
    +      case max@Max(expression: Expression) =>
    +        list += getFieldForAggregateExpression(expression,
    +          max.dataType,
    +          carbonTable,
    +          newColumnName,
    +          max.prettyName)
    +      case Average(MatchCastExpression(exp: Expression, changeDataType: 
DataType)) =>
    +        list += getFieldForAggregateExpression(exp,
               changeDataType,
    -          "count",
    -          carbonTable.getColumnByName(parentTableName, 
attr.name).getColumnId,
    -          parentTableName,
    -          parentDatabaseName, parentTableId = parentTableId)
    +          carbonTable,
    +          newColumnName,
    +          "sum")
    +        list += getFieldForAggregateExpression(exp,
    +          changeDataType,
    +          carbonTable,
    +          newColumnName,
    +          "count")
    +      case avg@Average(exp: Expression) =>
    +        list += getFieldForAggregateExpression(exp,
    +          avg.dataType,
    +          carbonTable,
    +          newColumnName,
    +          "sum")
    +        list += getFieldForAggregateExpression(exp,
    +          avg.dataType,
    +          carbonTable,
    +          newColumnName,
    +          "count")
           case others@_ =>
             throw new MalformedCarbonCommandException(s"Un-Supported 
Aggregation Type: ${
               others.prettyName}")
         }
       }
     
    +  /**
    +   * Below method will be used to get the field and its data map field 
object
    +   * for aggregate expression
    +   * @param expression
    +   *                   expression in aggregate function
    +   * @param dataType
    +   *                 data type
    +   * @param carbonTable
    +   *                    parent carbon table
    +   * @param newColumnName
    +   *                      column name of aggregate table
    +   * @param aggregationName
    +   *                        aggregate function name
    +   * @return field and its metadata tuple
    +   */
    +  def getFieldForAggregateExpression(expression: Expression,
    --- End diff --
    
    Please move the first parameter (`expression: Expression`) to the next line, after the opening parenthesis, and follow this formatting in the future.


---

Reply via email to