maropu commented on a change in pull request #32494: URL: https://github.com/apache/spark/pull/32494#discussion_r634128432
########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala ########## @@ -70,8 +68,20 @@ object UnionEstimation { None } - val unionOutput = union.output + val newMinMaxStats = computeMinMaxStats(union) + val newNullCountStats = computeNullCountStats(union) + val newAttrStats = combineStats(newMinMaxStats, newNullCountStats) Review comment: It seems this method just merges the two stats, so how about inlining it here? ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala ########## @@ -81,40 +91,76 @@ object UnionEstimation { attrStats.get(attr).isDefined && attrStats(attr).hasMinMaxStats } } - - val newAttrStats = if (attrToComputeMinMaxStats.nonEmpty) { - val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() - attrToComputeMinMaxStats.foreach { + val outputAttrStats = attrToComputeMinMaxStats.map { case (attrs, outputIndex) => val dataType = unionOutput(outputIndex).dataType val statComparator = createStatComparator(dataType) val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) { - case ((minVal, maxVal), (attr, childIndex)) => - val colStat = union.children(childIndex).stats.attributeStats(attr) - val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) { - colStat.min - } else { - minVal - } - val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) { - colStat.max - } else { - maxVal - } - (min, max) - } + case ((minVal, maxVal), (attr, childIndex)) => + val colStat = union.children(childIndex).stats.attributeStats(attr) + val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) { + colStat.min + } else { + minVal + } + val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) { + colStat.max + } else { + maxVal + } + (min, max) + } val newStat = ColumnStat(min = minMaxValue._1, max = 
minMaxValue._2) - outputAttrStats += unionOutput(outputIndex) -> newStat + unionOutput(outputIndex) -> newStat } + if (outputAttrStats.nonEmpty) { AttributeMap(outputAttrStats.toSeq) } else { AttributeMap.empty[ColumnStat] } + } - Some( - Statistics( - sizeInBytes = sizeInBytes, - rowCount = outputRows, - attributeStats = newAttrStats)) + /** This method computes the null count statistics and return the attribute stats Map. */ + private def computeNullCountStats(union: Union) = { + val unionOutput = union.output + val attrToComputeNullCount = union.children.map(_.output).transpose.zipWithIndex.filter { + case (attrs, _) => attrs.zipWithIndex.forall { + case (attr, childIndex) => + val attrStats = union.children(childIndex).stats.attributeStats + attrStats.get(attr).isDefined && attrStats(attr).nullCount.isDefined + } + } + val outputAttrStats = attrToComputeNullCount.map { + case (attrs, outputIndex) => + val firstStat = union.children.head.stats.attributeStats(attrs.head) + val firstNullCount = firstStat.nullCount.get + val colWithNullStatValues = attrs.zipWithIndex.tail.foldLeft[BigInt](firstNullCount) { + case (totalNullCount, (attr, childIndex)) => + val colStat = union.children(childIndex).stats.attributeStats(attr) + totalNullCount + colStat.nullCount.get + } + val newStat = ColumnStat(nullCount = Some(colWithNullStatValues)) + unionOutput(outputIndex) -> newStat + } + if (outputAttrStats.nonEmpty) { + AttributeMap(outputAttrStats.toSeq) + } else { + AttributeMap.empty[ColumnStat] + } + } + + // Combine the two Maps by updating the min-max stats Map with null count stats. 
+ private def combineStats( + minMaxStats: AttributeMap[ColumnStat], + nullCountStats: AttributeMap[ColumnStat]) = { + val updatedNullCountStats = nullCountStats.keys.map { key => + if (minMaxStats.get(key).isDefined) { + val updatedColsStats = minMaxStats(key).copy(nullCount = nullCountStats(key).nullCount) + key -> updatedColsStats + } else { + key -> nullCountStats(key) + } + } + AttributeMap(minMaxStats.toSeq ++ updatedNullCountStats) Review comment: Out of curiosity (I'm not familiar with this): does Scala explicitly define this last-wins behavior for `Map`, i.e. when the same key appears more than once in the concatenated sequence, is it guaranteed that the later entry overwrites the earlier one? ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala ########## @@ -81,40 +91,76 @@ object UnionEstimation { attrStats.get(attr).isDefined && attrStats(attr).hasMinMaxStats } } - - val newAttrStats = if (attrToComputeMinMaxStats.nonEmpty) { - val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() - attrToComputeMinMaxStats.foreach { + val outputAttrStats = attrToComputeMinMaxStats.map { case (attrs, outputIndex) => val dataType = unionOutput(outputIndex).dataType val statComparator = createStatComparator(dataType) val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) { - case ((minVal, maxVal), (attr, childIndex)) => - val colStat = union.children(childIndex).stats.attributeStats(attr) - val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) { - colStat.min - } else { - minVal - } - val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) { - colStat.max - } else { - maxVal - } - (min, max) - } + case ((minVal, maxVal), (attr, childIndex)) => + val colStat = union.children(childIndex).stats.attributeStats(attr) + val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) { + colStat.min + } else { + minVal + } + val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) { + colStat.max + } else { + maxVal + } + 
(min, max) + } val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2) - outputAttrStats += unionOutput(outputIndex) -> newStat + unionOutput(outputIndex) -> newStat } + if (outputAttrStats.nonEmpty) { AttributeMap(outputAttrStats.toSeq) } else { AttributeMap.empty[ColumnStat] } + } - Some( - Statistics( - sizeInBytes = sizeInBytes, - rowCount = outputRows, - attributeStats = newAttrStats)) + /** This method computes the null count statistics and return the attribute stats Map. */ + private def computeNullCountStats(union: Union) = { + val unionOutput = union.output + val attrToComputeNullCount = union.children.map(_.output).transpose.zipWithIndex.filter { + case (attrs, _) => attrs.zipWithIndex.forall { + case (attr, childIndex) => + val attrStats = union.children(childIndex).stats.attributeStats + attrStats.get(attr).isDefined && attrStats(attr).nullCount.isDefined + } + } + val outputAttrStats = attrToComputeNullCount.map { + case (attrs, outputIndex) => + val firstStat = union.children.head.stats.attributeStats(attrs.head) + val firstNullCount = firstStat.nullCount.get + val colWithNullStatValues = attrs.zipWithIndex.tail.foldLeft[BigInt](firstNullCount) { + case (totalNullCount, (attr, childIndex)) => + val colStat = union.children(childIndex).stats.attributeStats(attr) + totalNullCount + colStat.nullCount.get + } + val newStat = ColumnStat(nullCount = Some(colWithNullStatValues)) + unionOutput(outputIndex) -> newStat + } + if (outputAttrStats.nonEmpty) { + AttributeMap(outputAttrStats.toSeq) Review comment: ditto: we don't need `toSeq`? 
########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala ########## @@ -70,8 +68,20 @@ object UnionEstimation { None } - val unionOutput = union.output + val newMinMaxStats = computeMinMaxStats(union) + val newNullCountStats = computeNullCountStats(union) + val newAttrStats = combineStats(newMinMaxStats, newNullCountStats) + Some( + Statistics( + sizeInBytes = sizeInBytes, + rowCount = outputRows, + attributeStats = newAttrStats)) + } + + // This method computes the min-max statistics and return the attribute stats Map. Review comment: nit: this comment does not look meaningful because the method name is clear enough. ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala ########## @@ -81,40 +91,76 @@ object UnionEstimation { attrStats.get(attr).isDefined && attrStats(attr).hasMinMaxStats } } - - val newAttrStats = if (attrToComputeMinMaxStats.nonEmpty) { - val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() - attrToComputeMinMaxStats.foreach { + val outputAttrStats = attrToComputeMinMaxStats.map { case (attrs, outputIndex) => val dataType = unionOutput(outputIndex).dataType val statComparator = createStatComparator(dataType) val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) { - case ((minVal, maxVal), (attr, childIndex)) => - val colStat = union.children(childIndex).stats.attributeStats(attr) - val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) { - colStat.min - } else { - minVal - } - val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) { - colStat.max - } else { - maxVal - } - (min, max) - } + case ((minVal, maxVal), (attr, childIndex)) => + val colStat = union.children(childIndex).stats.attributeStats(attr) + val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) { + colStat.min + } else { 
+ minVal + } + val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) { + colStat.max + } else { + maxVal + } + (min, max) + } val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2) - outputAttrStats += unionOutput(outputIndex) -> newStat + unionOutput(outputIndex) -> newStat } + if (outputAttrStats.nonEmpty) { AttributeMap(outputAttrStats.toSeq) } else { AttributeMap.empty[ColumnStat] } + } - Some( - Statistics( - sizeInBytes = sizeInBytes, - rowCount = outputRows, - attributeStats = newAttrStats)) + /** This method computes the null count statistics and return the attribute stats Map. */ + private def computeNullCountStats(union: Union) = { + val unionOutput = union.output + val attrToComputeNullCount = union.children.map(_.output).transpose.zipWithIndex.filter { + case (attrs, _) => attrs.zipWithIndex.forall { + case (attr, childIndex) => + val attrStats = union.children(childIndex).stats.attributeStats + attrStats.get(attr).isDefined && attrStats(attr).nullCount.isDefined + } + } + val outputAttrStats = attrToComputeNullCount.map { + case (attrs, outputIndex) => + val firstStat = union.children.head.stats.attributeStats(attrs.head) + val firstNullCount = firstStat.nullCount.get + val colWithNullStatValues = attrs.zipWithIndex.tail.foldLeft[BigInt](firstNullCount) { + case (totalNullCount, (attr, childIndex)) => + val colStat = union.children(childIndex).stats.attributeStats(attr) + totalNullCount + colStat.nullCount.get + } + val newStat = ColumnStat(nullCount = Some(colWithNullStatValues)) + unionOutput(outputIndex) -> newStat + } + if (outputAttrStats.nonEmpty) { + AttributeMap(outputAttrStats.toSeq) + } else { + AttributeMap.empty[ColumnStat] + } + } + + // Combine the two Maps by updating the min-max stats Map with null count stats. 
Review comment: ditto ########## File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala ########## @@ -81,40 +91,76 @@ object UnionEstimation { attrStats.get(attr).isDefined && attrStats(attr).hasMinMaxStats } } - - val newAttrStats = if (attrToComputeMinMaxStats.nonEmpty) { - val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]() - attrToComputeMinMaxStats.foreach { + val outputAttrStats = attrToComputeMinMaxStats.map { case (attrs, outputIndex) => val dataType = unionOutput(outputIndex).dataType val statComparator = createStatComparator(dataType) val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], Option[Any])]((None, None)) { - case ((minVal, maxVal), (attr, childIndex)) => - val colStat = union.children(childIndex).stats.attributeStats(attr) - val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) { - colStat.min - } else { - minVal - } - val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) { - colStat.max - } else { - maxVal - } - (min, max) - } + case ((minVal, maxVal), (attr, childIndex)) => + val colStat = union.children(childIndex).stats.attributeStats(attr) + val min = if (minVal.isEmpty || statComparator(colStat.min.get, minVal.get)) { + colStat.min + } else { + minVal + } + val max = if (maxVal.isEmpty || statComparator(maxVal.get, colStat.max.get)) { + colStat.max + } else { + maxVal + } + (min, max) + } val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2) - outputAttrStats += unionOutput(outputIndex) -> newStat + unionOutput(outputIndex) -> newStat } + if (outputAttrStats.nonEmpty) { AttributeMap(outputAttrStats.toSeq) Review comment: nit: we don't need `toSeq`? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. 
For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org