Github user jackylk commented on a diff in the pull request: https://github.com/apache/carbondata/pull/1694#discussion_r158933818 --- Diff: integration/spark2/src/main/scala/org/apache/spark/sql/execution/command/preaaggregate/PreAggregateUtil.scala --- @@ -166,127 +208,160 @@ object PreAggregateUtil { aggFunctions: AggregateFunction, parentTableName: String, parentDatabaseName: String, - parentTableId: String) : scala.collection.mutable.ListBuffer[(Field, DataMapField)] = { + parentTableId: String, + newColumnName: String) : scala.collection.mutable.ListBuffer[(Field, DataMapField)] = { val list = scala.collection.mutable.ListBuffer.empty[(Field, DataMapField)] aggFunctions match { - case sum@Sum(attr: AttributeReference) => - list += getField(attr.name, - attr.dataType, - sum.prettyName, - carbonTable.getColumnByName(parentTableName, attr.name).getColumnId, - parentTableName, - parentDatabaseName, parentTableId = parentTableId) - case sum@Sum(Cast(attr: AttributeReference, changeDataType: DataType)) => - list += getField(attr.name, + case sum@Sum(MatchCastExpression(exp: Expression, changeDataType: DataType)) => + list += getFieldForAggregateExpression(exp, changeDataType, - sum.prettyName, - carbonTable.getColumnByName(parentTableName, attr.name).getColumnId, - parentTableName, - parentDatabaseName, parentTableId = parentTableId) - case count@Count(Seq(attr: AttributeReference)) => - list += getField(attr.name, - attr.dataType, - count.prettyName, - carbonTable.getColumnByName(parentTableName, attr.name).getColumnId, - parentTableName, - parentDatabaseName, parentTableId = parentTableId) - case count@Count(Seq(Cast(attr: AttributeReference, _))) => - list += getField(attr.name, - attr.dataType, - count.prettyName, - carbonTable.getColumnByName(parentTableName, attr.name).getColumnId, - parentTableName, - parentDatabaseName, parentTableId = parentTableId) - case min@Min(attr: AttributeReference) => - list += getField(attr.name, - attr.dataType, - min.prettyName, 
- carbonTable.getColumnByName(parentTableName, attr.name).getColumnId, - parentTableName, - parentDatabaseName, parentTableId = parentTableId) - case min@Min(Cast(attr: AttributeReference, changeDataType: DataType)) => - list += getField(attr.name, + carbonTable, + newColumnName, + sum.prettyName) + case sum@Sum(exp: Expression) => + list += getFieldForAggregateExpression(exp, + sum.dataType, + carbonTable, + newColumnName, + sum.prettyName) + case count@Count(Seq(MatchCastExpression(exp: Expression, changeDataType: DataType))) => + list += getFieldForAggregateExpression(exp, changeDataType, - min.prettyName, - carbonTable.getColumnByName(parentTableName, attr.name).getColumnId, - parentTableName, - parentDatabaseName, parentTableId = parentTableId) - case max@Max(attr: AttributeReference) => - list += getField(attr.name, - attr.dataType, - max.prettyName, - carbonTable.getColumnByName(parentTableName, attr.name).getColumnId, - parentTableName, - parentDatabaseName, parentTableId = parentTableId) - case max@Max(Cast(attr: AttributeReference, changeDataType: DataType)) => - list += getField(attr.name, + carbonTable, + newColumnName, + count.prettyName) + case count@Count(Seq(expression: Expression)) => + list += getFieldForAggregateExpression(expression, + count.dataType, + carbonTable, + newColumnName, + count.prettyName) + case min@Min(MatchCastExpression(exp: Expression, changeDataType: DataType)) => + list += getFieldForAggregateExpression(exp, changeDataType, - max.prettyName, - carbonTable.getColumnByName(parentTableName, attr.name).getColumnId, - parentTableName, - parentDatabaseName, parentTableId = parentTableId) - case Average(attr: AttributeReference) => - list += getField(attr.name, - attr.dataType, - "sum", - carbonTable.getColumnByName(parentTableName, attr.name).getColumnId, - parentTableName, - parentDatabaseName, parentTableId = parentTableId) - list += getField(attr.name, - attr.dataType, - "count", - carbonTable.getColumnByName(parentTableName, 
attr.name).getColumnId, - parentTableName, - parentDatabaseName, parentTableId = parentTableId) - case Average(Cast(attr: AttributeReference, changeDataType: DataType)) => - list += getField(attr.name, + carbonTable, + newColumnName, + min.prettyName) + case min@Min(expression: Expression) => + list += getFieldForAggregateExpression(expression, + min.dataType, + carbonTable, + newColumnName, + min.prettyName) + case max@Max(MatchCastExpression(exp: Expression, changeDataType: DataType)) => + list += getFieldForAggregateExpression(exp, changeDataType, - "sum", - carbonTable.getColumnByName(parentTableName, attr.name).getColumnId, - parentTableName, - parentDatabaseName, parentTableId = parentTableId) - list += getField(attr.name, + carbonTable, + newColumnName, + max.prettyName) + case max@Max(expression: Expression) => + list += getFieldForAggregateExpression(expression, + max.dataType, + carbonTable, + newColumnName, + max.prettyName) + case Average(MatchCastExpression(exp: Expression, changeDataType: DataType)) => + list += getFieldForAggregateExpression(exp, changeDataType, - "count", - carbonTable.getColumnByName(parentTableName, attr.name).getColumnId, - parentTableName, - parentDatabaseName, parentTableId = parentTableId) + carbonTable, + newColumnName, + "sum") + list += getFieldForAggregateExpression(exp, + changeDataType, + carbonTable, + newColumnName, + "count") + case avg@Average(exp: Expression) => + list += getFieldForAggregateExpression(exp, + avg.dataType, + carbonTable, + newColumnName, + "sum") + list += getFieldForAggregateExpression(exp, + avg.dataType, + carbonTable, + newColumnName, + "count") case others@_ => throw new MalformedCarbonCommandException(s"Un-Supported Aggregation Type: ${ others.prettyName}") } } + /** + * Below method will be used to get the field and its data map field object + * for aggregate expression + * @param expression + * expression in aggregate function + * @param dataType + * data type + * @param carbonTable + * 
parent carbon table + * @param newColumnName + * column name of aggregate table + * @param aggregationName + * aggregate function name + * @return field and its metadata tuple + */ + def getFieldForAggregateExpression(expression: Expression, --- End diff -- Move the parameter to the next line; please follow this convention in the future.
---