Github user jackylk commented on a diff in the pull request: https://github.com/apache/carbondata/pull/1728#discussion_r159609322 --- Diff: integration/spark2/src/main/scala/org/apache/spark/sql/hive/CarbonPreAggregateRules.scala --- @@ -1218,112 +1180,125 @@ case class CarbonPreAggregateQueryRules(sparkSession: SparkSession) extends Rule * parent column name * @param carbonTable * parent carbon table - * @param tableName - * parent table name - * @param aggFunction - * aggregate function applied - * @param dataType - * data type of the column - * @param isChangedDataType - * is cast is applied on column * @param isFilterColumn * is filter is applied on column * @return query column */ def getQueryColumn(columnName: String, carbonTable: CarbonTable, - tableName: String, - aggFunction: String = "", - dataType: String = "", - isChangedDataType: Boolean = false, isFilterColumn: Boolean = false, timeseriesFunction: String = ""): QueryColumn = { - val columnSchema = carbonTable.getColumnByName(tableName, columnName.toLowerCase) + val columnSchema = carbonTable.getColumnByName(carbonTable.getTableName, columnName.toLowerCase) if(null == columnSchema) { null } else { - if (isChangedDataType) { new QueryColumn(columnSchema.getColumnSchema, - columnSchema.getDataType.getName, - aggFunction.toLowerCase, isFilterColumn, timeseriesFunction.toLowerCase) - } else { - new QueryColumn(columnSchema.getColumnSchema, - CarbonScalaUtil.convertSparkToCarbonSchemaDataType(dataType), - aggFunction.toLowerCase, - isFilterColumn, - timeseriesFunction.toLowerCase) - } } } } -object CarbonPreAggregateDataLoadingRules extends Rule[LogicalPlan] { - +/** + * Data loading rule class to validate and update the data loading query plan + * Validation rule: + * 1. update the avg aggregate expression with two columns sum and count + * 2. 
Remove duplicate sum and count expression if already there in plan + * @param sparkSession + * spark session + */ +case class CarbonPreAggregateDataLoadingRules(sparkSession: SparkSession) + extends Rule[LogicalPlan] { + lazy val parser = new CarbonSpark2SqlParser override def apply(plan: LogicalPlan): LogicalPlan = { - val validExpressionsMap = scala.collection.mutable.LinkedHashMap.empty[String, NamedExpression] + val validExpressionsMap = scala.collection.mutable.HashSet.empty[AggExpToColumnMappingModel] + val namedExpressionList = scala.collection.mutable.ListBuffer.empty[NamedExpression] plan transform { - case aggregate@Aggregate(_, aExp, _) if validateAggregateExpressions(aExp) => + case aggregate@Aggregate(_, + aExp, + CarbonSubqueryAlias(_, logicalRelation: LogicalRelation)) + if validateAggregateExpressions(aExp) && + logicalRelation.relation.isInstanceOf[CarbonDatasourceHadoopRelation] => + val carbonTable = logicalRelation.relation.asInstanceOf[CarbonDatasourceHadoopRelation] + .carbonTable aExp.foreach { - case alias: Alias => - validExpressionsMap ++= validateAggregateFunctionAndGetAlias(alias) - case _: UnresolvedAlias => - case namedExpr: NamedExpression => validExpressionsMap.put(namedExpr.name, namedExpr) + case attr: AttributeReference => + namedExpressionList += attr + case alias@Alias(_: AttributeReference, _) => + namedExpressionList += alias + case alias@Alias(aggExp: AggregateExpression, name) => + // get the updated expression for avg convert it to two expression + // sum and count + val expressions = PreAggregateUtil.getUpdateAggregateExpressions(aggExp) + // if size is more than one then it was for average + if(expressions.size > 1) { + // get the logical plan for sum expression + val logicalPlan_sum = PreAggregateUtil.getLogicalPlanFromAggExp( + expressions.head, + carbonTable.getTableName, + carbonTable.getDatabaseName, + logicalRelation, + sparkSession, + parser) + // get the logical plan fro count expression + val logicalPlan_count = 
PreAggregateUtil.getLogicalPlanFromAggExp( + expressions.last, + carbonTable.getTableName, + carbonTable.getDatabaseName, + logicalRelation, + sparkSession, + parser) + // check with same expression already sum is present then do not add to + // named expression list otherwise update the list and add it to set + if (!validExpressionsMap.contains(AggExpToColumnMappingModel(logicalPlan_sum))) { + namedExpressionList += + Alias(expressions.head, name + " _ sum")(NamedExpression.newExprId, + alias.qualifier, + Some(alias.metadata), + alias.isGenerated) + validExpressionsMap += AggExpToColumnMappingModel(logicalPlan_sum) + } + // check with same expression already count is present then do not add to + // named expression list otherwise update the list and add it to set + if (!validExpressionsMap.contains(AggExpToColumnMappingModel(logicalPlan_count))) { + namedExpressionList += + Alias(expressions.last, name + " _ count")(NamedExpression.newExprId, + alias.qualifier, + Some(alias.metadata), + alias.isGenerated) + validExpressionsMap += AggExpToColumnMappingModel(logicalPlan_count) + } + } else { + // get the logical plan for expression + val logicalPlan = PreAggregateUtil.getLogicalPlanFromAggExp( + expressions.head, + carbonTable.getTableName, + carbonTable.getDatabaseName, + logicalRelation, + sparkSession, + parser) + // check with same expression already present then do not add to + // named expression list otherwise update the list and add it to set + if (!validExpressionsMap.contains(AggExpToColumnMappingModel(logicalPlan))) { + namedExpressionList+=alias + validExpressionsMap += AggExpToColumnMappingModel(logicalPlan) + } + } + case alias@Alias(_: Expression, _) => + namedExpressionList += alias } - aggregate.copy(aggregateExpressions = validExpressionsMap.values.toSeq) + aggregate.copy(aggregateExpressions = namedExpressionList.toSeq) case plan: LogicalPlan => plan } } - /** - * This method will split the avg column into sum and count and will return a sequence 
of tuple - * of unique name, alias - * - */ - private def validateAggregateFunctionAndGetAlias(alias: Alias): Seq[(String, - NamedExpression)] = { - alias match { - case udf@Alias(_: ScalaUDF, name) => - Seq((name, udf)) - case alias@Alias(attrExpression: AggregateExpression, _) => - attrExpression.aggregateFunction match { - case Sum(attr: AttributeReference) => - (attr.name + "_sum", alias) :: Nil - case Sum(MatchCastExpression(attr: AttributeReference, _)) => - (attr.name + "_sum", alias) :: Nil - case Count(Seq(attr: AttributeReference)) => - (attr.name + "_count", alias) :: Nil - case Count(Seq(MatchCastExpression(attr: AttributeReference, _))) => - (attr.name + "_count", alias) :: Nil - case Average(attr: AttributeReference) => - Seq((attr.name + "_sum", Alias(attrExpression. - copy(aggregateFunction = Sum(attr), - resultId = NamedExpression.newExprId), attr.name + "_sum")()), - (attr.name, Alias(attrExpression. - copy(aggregateFunction = Count(attr), - resultId = NamedExpression.newExprId), attr.name + "_count")())) - case Average(cast@MatchCastExpression(attr: AttributeReference, _)) => - Seq((attr.name + "_sum", Alias(attrExpression. - copy(aggregateFunction = Sum(cast), - resultId = NamedExpression.newExprId), - attr.name + "_sum")()), - (attr.name, Alias(attrExpression. - copy(aggregateFunction = Count(cast), resultId = - NamedExpression.newExprId), attr.name + "_count")())) - case _ => Seq(("", alias)) - } - - } - } - /** * Called by PreAggregateLoadingRules to validate if plan is valid for applying rules or not. * If the plan has PreAggLoad i.e Loading UDF and does not have PreAgg i.e Query UDF then it is * valid. - * * @param namedExpression - * @return + * named expressions --- End diff -- move it up. The comment should look like `@param namedExpression named expressions`
---