maropu commented on a change in pull request #32494:
URL: https://github.com/apache/spark/pull/32494#discussion_r634128432



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala
##########
@@ -70,8 +68,20 @@ object UnionEstimation {
       None
     }
 
-    val unionOutput = union.output
+    val newMinMaxStats = computeMinMaxStats(union)
+    val newNullCountStats = computeNullCountStats(union)
+    val newAttrStats = combineStats(newMinMaxStats, newNullCountStats)

Review comment:
       It seems `combineStats` just merges the two stats maps, so how about
inlining it here?

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala
##########
@@ -81,40 +91,76 @@ object UnionEstimation {
             attrStats.get(attr).isDefined && attrStats(attr).hasMinMaxStats
         }
     }
-
-    val newAttrStats = if (attrToComputeMinMaxStats.nonEmpty) {
-      val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
-      attrToComputeMinMaxStats.foreach {
+    val outputAttrStats = attrToComputeMinMaxStats.map {
         case (attrs, outputIndex) =>
           val dataType = unionOutput(outputIndex).dataType
           val statComparator = createStatComparator(dataType)
           val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], 
Option[Any])]((None, None)) {
-              case ((minVal, maxVal), (attr, childIndex)) =>
-                val colStat = 
union.children(childIndex).stats.attributeStats(attr)
-                val min = if (minVal.isEmpty || 
statComparator(colStat.min.get, minVal.get)) {
-                  colStat.min
-                } else {
-                  minVal
-                }
-                val max = if (maxVal.isEmpty || statComparator(maxVal.get, 
colStat.max.get)) {
-                  colStat.max
-                } else {
-                  maxVal
-                }
-                (min, max)
-            }
+            case ((minVal, maxVal), (attr, childIndex)) =>
+              val colStat = 
union.children(childIndex).stats.attributeStats(attr)
+              val min = if (minVal.isEmpty || statComparator(colStat.min.get, 
minVal.get)) {
+                colStat.min
+              } else {
+                minVal
+              }
+              val max = if (maxVal.isEmpty || statComparator(maxVal.get, 
colStat.max.get)) {
+                colStat.max
+              } else {
+                maxVal
+              }
+              (min, max)
+          }
           val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2)
-          outputAttrStats += unionOutput(outputIndex) -> newStat
+          unionOutput(outputIndex) -> newStat
       }
+    if (outputAttrStats.nonEmpty) {
       AttributeMap(outputAttrStats.toSeq)
     } else {
       AttributeMap.empty[ColumnStat]
     }
+  }
 
-    Some(
-      Statistics(
-        sizeInBytes = sizeInBytes,
-        rowCount = outputRows,
-        attributeStats = newAttrStats))
+  /** This method computes the null count statistics and return the attribute 
stats Map. */
+  private def computeNullCountStats(union: Union) = {
+    val unionOutput = union.output
+    val attrToComputeNullCount = 
union.children.map(_.output).transpose.zipWithIndex.filter {
+      case (attrs, _) => attrs.zipWithIndex.forall {
+        case (attr, childIndex) =>
+          val attrStats = union.children(childIndex).stats.attributeStats
+          attrStats.get(attr).isDefined && attrStats(attr).nullCount.isDefined
+      }
+    }
+    val outputAttrStats = attrToComputeNullCount.map {
+      case (attrs, outputIndex) =>
+        val firstStat = union.children.head.stats.attributeStats(attrs.head)
+        val firstNullCount = firstStat.nullCount.get
+        val colWithNullStatValues = 
attrs.zipWithIndex.tail.foldLeft[BigInt](firstNullCount) {
+          case (totalNullCount, (attr, childIndex)) =>
+            val colStat = union.children(childIndex).stats.attributeStats(attr)
+            totalNullCount + colStat.nullCount.get
+        }
+        val newStat = ColumnStat(nullCount = Some(colWithNullStatValues))
+        unionOutput(outputIndex) -> newStat
+    }
+    if (outputAttrStats.nonEmpty) {
+      AttributeMap(outputAttrStats.toSeq)
+    } else {
+      AttributeMap.empty[ColumnStat]
+    }
+  }
+
+  // Combine the two Maps by updating the min-max stats Map with null count 
stats.
+  private def combineStats(
+      minMaxStats: AttributeMap[ColumnStat],
+      nullCountStats: AttributeMap[ColumnStat]) = {
+    val updatedNullCountStats = nullCountStats.keys.map { key =>
+      if (minMaxStats.get(key).isDefined) {
+        val updatedColsStats = minMaxStats(key).copy(nullCount = 
nullCountStats(key).nullCount)
+        key -> updatedColsStats
+      } else {
+        key -> nullCountStats(key)
+      }
+    }
+    AttributeMap(minMaxStats.toSeq ++ updatedNullCountStats)

Review comment:
       Out of curiosity (I'm not familiar with this): does Scala explicitly
define this last-wins behavior for duplicate keys when building a `Map`?

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala
##########
@@ -81,40 +91,76 @@ object UnionEstimation {
             attrStats.get(attr).isDefined && attrStats(attr).hasMinMaxStats
         }
     }
-
-    val newAttrStats = if (attrToComputeMinMaxStats.nonEmpty) {
-      val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
-      attrToComputeMinMaxStats.foreach {
+    val outputAttrStats = attrToComputeMinMaxStats.map {
         case (attrs, outputIndex) =>
           val dataType = unionOutput(outputIndex).dataType
           val statComparator = createStatComparator(dataType)
           val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], 
Option[Any])]((None, None)) {
-              case ((minVal, maxVal), (attr, childIndex)) =>
-                val colStat = 
union.children(childIndex).stats.attributeStats(attr)
-                val min = if (minVal.isEmpty || 
statComparator(colStat.min.get, minVal.get)) {
-                  colStat.min
-                } else {
-                  minVal
-                }
-                val max = if (maxVal.isEmpty || statComparator(maxVal.get, 
colStat.max.get)) {
-                  colStat.max
-                } else {
-                  maxVal
-                }
-                (min, max)
-            }
+            case ((minVal, maxVal), (attr, childIndex)) =>
+              val colStat = 
union.children(childIndex).stats.attributeStats(attr)
+              val min = if (minVal.isEmpty || statComparator(colStat.min.get, 
minVal.get)) {
+                colStat.min
+              } else {
+                minVal
+              }
+              val max = if (maxVal.isEmpty || statComparator(maxVal.get, 
colStat.max.get)) {
+                colStat.max
+              } else {
+                maxVal
+              }
+              (min, max)
+          }
           val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2)
-          outputAttrStats += unionOutput(outputIndex) -> newStat
+          unionOutput(outputIndex) -> newStat
       }
+    if (outputAttrStats.nonEmpty) {
       AttributeMap(outputAttrStats.toSeq)
     } else {
       AttributeMap.empty[ColumnStat]
     }
+  }
 
-    Some(
-      Statistics(
-        sizeInBytes = sizeInBytes,
-        rowCount = outputRows,
-        attributeStats = newAttrStats))
+  /** This method computes the null count statistics and return the attribute 
stats Map. */
+  private def computeNullCountStats(union: Union) = {
+    val unionOutput = union.output
+    val attrToComputeNullCount = 
union.children.map(_.output).transpose.zipWithIndex.filter {
+      case (attrs, _) => attrs.zipWithIndex.forall {
+        case (attr, childIndex) =>
+          val attrStats = union.children(childIndex).stats.attributeStats
+          attrStats.get(attr).isDefined && attrStats(attr).nullCount.isDefined
+      }
+    }
+    val outputAttrStats = attrToComputeNullCount.map {
+      case (attrs, outputIndex) =>
+        val firstStat = union.children.head.stats.attributeStats(attrs.head)
+        val firstNullCount = firstStat.nullCount.get
+        val colWithNullStatValues = 
attrs.zipWithIndex.tail.foldLeft[BigInt](firstNullCount) {
+          case (totalNullCount, (attr, childIndex)) =>
+            val colStat = union.children(childIndex).stats.attributeStats(attr)
+            totalNullCount + colStat.nullCount.get
+        }
+        val newStat = ColumnStat(nullCount = Some(colWithNullStatValues))
+        unionOutput(outputIndex) -> newStat
+    }
+    if (outputAttrStats.nonEmpty) {
+      AttributeMap(outputAttrStats.toSeq)

Review comment:
       Ditto: is `toSeq` needed here?

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala
##########
@@ -70,8 +68,20 @@ object UnionEstimation {
       None
     }
 
-    val unionOutput = union.output
+    val newMinMaxStats = computeMinMaxStats(union)
+    val newNullCountStats = computeNullCountStats(union)
+    val newAttrStats = combineStats(newMinMaxStats, newNullCountStats)
 
+    Some(
+      Statistics(
+        sizeInBytes = sizeInBytes,
+        rowCount = outputRows,
+        attributeStats = newAttrStats))
+  }
+
+  // This method computes the min-max statistics and return the attribute 
stats Map.

Review comment:
       nit: this comment adds little, since the method name is already
self-explanatory.

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala
##########
@@ -81,40 +91,76 @@ object UnionEstimation {
             attrStats.get(attr).isDefined && attrStats(attr).hasMinMaxStats
         }
     }
-
-    val newAttrStats = if (attrToComputeMinMaxStats.nonEmpty) {
-      val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
-      attrToComputeMinMaxStats.foreach {
+    val outputAttrStats = attrToComputeMinMaxStats.map {
         case (attrs, outputIndex) =>
           val dataType = unionOutput(outputIndex).dataType
           val statComparator = createStatComparator(dataType)
           val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], 
Option[Any])]((None, None)) {
-              case ((minVal, maxVal), (attr, childIndex)) =>
-                val colStat = 
union.children(childIndex).stats.attributeStats(attr)
-                val min = if (minVal.isEmpty || 
statComparator(colStat.min.get, minVal.get)) {
-                  colStat.min
-                } else {
-                  minVal
-                }
-                val max = if (maxVal.isEmpty || statComparator(maxVal.get, 
colStat.max.get)) {
-                  colStat.max
-                } else {
-                  maxVal
-                }
-                (min, max)
-            }
+            case ((minVal, maxVal), (attr, childIndex)) =>
+              val colStat = 
union.children(childIndex).stats.attributeStats(attr)
+              val min = if (minVal.isEmpty || statComparator(colStat.min.get, 
minVal.get)) {
+                colStat.min
+              } else {
+                minVal
+              }
+              val max = if (maxVal.isEmpty || statComparator(maxVal.get, 
colStat.max.get)) {
+                colStat.max
+              } else {
+                maxVal
+              }
+              (min, max)
+          }
           val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2)
-          outputAttrStats += unionOutput(outputIndex) -> newStat
+          unionOutput(outputIndex) -> newStat
       }
+    if (outputAttrStats.nonEmpty) {
       AttributeMap(outputAttrStats.toSeq)
     } else {
       AttributeMap.empty[ColumnStat]
     }
+  }
 
-    Some(
-      Statistics(
-        sizeInBytes = sizeInBytes,
-        rowCount = outputRows,
-        attributeStats = newAttrStats))
+  /** This method computes the null count statistics and return the attribute 
stats Map. */
+  private def computeNullCountStats(union: Union) = {
+    val unionOutput = union.output
+    val attrToComputeNullCount = 
union.children.map(_.output).transpose.zipWithIndex.filter {
+      case (attrs, _) => attrs.zipWithIndex.forall {
+        case (attr, childIndex) =>
+          val attrStats = union.children(childIndex).stats.attributeStats
+          attrStats.get(attr).isDefined && attrStats(attr).nullCount.isDefined
+      }
+    }
+    val outputAttrStats = attrToComputeNullCount.map {
+      case (attrs, outputIndex) =>
+        val firstStat = union.children.head.stats.attributeStats(attrs.head)
+        val firstNullCount = firstStat.nullCount.get
+        val colWithNullStatValues = 
attrs.zipWithIndex.tail.foldLeft[BigInt](firstNullCount) {
+          case (totalNullCount, (attr, childIndex)) =>
+            val colStat = union.children(childIndex).stats.attributeStats(attr)
+            totalNullCount + colStat.nullCount.get
+        }
+        val newStat = ColumnStat(nullCount = Some(colWithNullStatValues))
+        unionOutput(outputIndex) -> newStat
+    }
+    if (outputAttrStats.nonEmpty) {
+      AttributeMap(outputAttrStats.toSeq)
+    } else {
+      AttributeMap.empty[ColumnStat]
+    }
+  }
+
+  // Combine the two Maps by updating the min-max stats Map with null count 
stats.

Review comment:
       ditto

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/UnionEstimation.scala
##########
@@ -81,40 +91,76 @@ object UnionEstimation {
             attrStats.get(attr).isDefined && attrStats(attr).hasMinMaxStats
         }
     }
-
-    val newAttrStats = if (attrToComputeMinMaxStats.nonEmpty) {
-      val outputAttrStats = new ArrayBuffer[(Attribute, ColumnStat)]()
-      attrToComputeMinMaxStats.foreach {
+    val outputAttrStats = attrToComputeMinMaxStats.map {
         case (attrs, outputIndex) =>
           val dataType = unionOutput(outputIndex).dataType
           val statComparator = createStatComparator(dataType)
           val minMaxValue = attrs.zipWithIndex.foldLeft[(Option[Any], 
Option[Any])]((None, None)) {
-              case ((minVal, maxVal), (attr, childIndex)) =>
-                val colStat = 
union.children(childIndex).stats.attributeStats(attr)
-                val min = if (minVal.isEmpty || 
statComparator(colStat.min.get, minVal.get)) {
-                  colStat.min
-                } else {
-                  minVal
-                }
-                val max = if (maxVal.isEmpty || statComparator(maxVal.get, 
colStat.max.get)) {
-                  colStat.max
-                } else {
-                  maxVal
-                }
-                (min, max)
-            }
+            case ((minVal, maxVal), (attr, childIndex)) =>
+              val colStat = 
union.children(childIndex).stats.attributeStats(attr)
+              val min = if (minVal.isEmpty || statComparator(colStat.min.get, 
minVal.get)) {
+                colStat.min
+              } else {
+                minVal
+              }
+              val max = if (maxVal.isEmpty || statComparator(maxVal.get, 
colStat.max.get)) {
+                colStat.max
+              } else {
+                maxVal
+              }
+              (min, max)
+          }
           val newStat = ColumnStat(min = minMaxValue._1, max = minMaxValue._2)
-          outputAttrStats += unionOutput(outputIndex) -> newStat
+          unionOutput(outputIndex) -> newStat
       }
+    if (outputAttrStats.nonEmpty) {
       AttributeMap(outputAttrStats.toSeq)

Review comment:
       nit: is `toSeq` needed here?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to