This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 69b8655d529 [SPARK-39192][PS][SQL] Make pandas-on-spark's kurt consistent with pandas
69b8655d529 is described below

commit 69b8655d529a36d9b041a9e03d3349f5f2c4cdfc
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed May 18 08:39:16 2022 +0900

    [SPARK-39192][PS][SQL] Make pandas-on-spark's kurt consistent with pandas
    
    ### What changes were proposed in this pull request?
    Make pandas-on-spark's `kurt` consistent with pandas.
    
    ### Why are the changes needed?
    1. Spark SQL and pandas use different formulas for kurtosis;
    2. pandas zeros out a small `numerator` and `denominator` for better numerical stability.
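    
    A minimal sketch of the mismatch this PR fixes (illustrative data; both
    `pyspark.sql.functions.kurtosis` and `pandas.Series.kurt` are existing APIs):
    
        import pandas as pd
        import pyspark.sql.functions as F
        from pyspark.sql import SparkSession
    
        spark = SparkSession.builder.getOrCreate()
        data = [1.0, 2.0, 3.0, 6.0]
    
        # Spark SQL's kurtosis uses the population formula: -1.0 for this data
        spark.createDataFrame([(v,) for v in data], ["v"]).select(F.kurtosis("v")).show()
    
        # pandas' kurt uses the sample-adjusted (G2) formula: 1.5 for this data
        print(pd.Series(data).kurt())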
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, the results of `kurt` change: they now follow pandas' formula.
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #36560 from zhengruifeng/impl_pandas_kurt.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/pandas/generic.py                   | 12 ++++----
 .../pyspark/pandas/tests/test_generic_functions.py |  4 +--
 python/pyspark/pandas/tests/test_stats.py          |  9 +++++-
 .../expressions/aggregate/CentralMomentAgg.scala   | 33 ++++++++++++++++++++++
 .../spark/sql/api/python/PythonSQLUtils.scala      |  6 +++-
 5 files changed, 55 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index f5073315164..ec38935ced8 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -1556,20 +1556,20 @@ class Frame(object, metaclass=ABCMeta):
         Examples
         --------
 
-        >>> df = ps.DataFrame({'a': [1, 2, 3, np.nan], 'b': [0.1, 0.2, 0.3, np.nan]},
+        >>> df = ps.DataFrame({'a': [1, 2, 3, np.nan, 6], 'b': [0.1, 0.2, 0.3, np.nan, 0.8]},
         ...                   columns=['a', 'b'])
 
         On a DataFrame:
 
         >>> df.kurtosis()
-        a   -1.5
-        b   -1.5
+        a    1.500000
+        b    2.703924
         dtype: float64
 
         On a Series:
 
         >>> df['a'].kurtosis()
-        -1.5
+        1.5
         """
         axis = validate_axis(axis)
 
@@ -1587,7 +1587,9 @@ class Frame(object, metaclass=ABCMeta):
                         spark_type_to_pandas_dtype(spark_type), spark_type.simpleString()
                     )
                 )
-            return F.kurtosis(spark_column)
+
+            sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
+            return Column(sql_utils.pandasKurtosis(spark_column._jc))
 
         return self._reduce_for_stat_function(
             kurtosis,
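
The Python change above routes `kurt` through the new JVM helper instead of
`F.kurtosis`. A minimal standalone sketch of the same Py4J call pattern
(assuming an active Spark session with this patch applied; the `sdf` DataFrame
is our own example):

    from pyspark import SparkContext
    from pyspark.sql import Column, SparkSession

    spark = SparkSession.builder.getOrCreate()
    sdf = spark.range(1, 7).selectExpr("CAST(id AS DOUBLE) AS v")

    # Fetch the JVM-side helper and wrap its result back into a Python Column
    sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
    sdf.select(Column(sql_utils.pandasKurtosis(sdf["v"]._jc))).show()
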
diff --git a/python/pyspark/pandas/tests/test_generic_functions.py b/python/pyspark/pandas/tests/test_generic_functions.py
index 5062daa77e2..2a83a038713 100644
--- a/python/pyspark/pandas/tests/test_generic_functions.py
+++ b/python/pyspark/pandas/tests/test_generic_functions.py
@@ -150,8 +150,8 @@ class GenericFunctionsTest(PandasOnSparkTestCase, TestUtils):
         self.assert_eq(pdf.a.kurtosis(skipna=False), psdf.a.kurtosis(skipna=False))
         self.assert_eq(pdf.a.kurtosis(), psdf.a.kurtosis())
         self.assert_eq(pdf.b.kurtosis(skipna=False), psdf.b.kurtosis(skipna=False))
-        # self.assert_eq(pdf.b.kurtosis(), psdf.b.kurtosis())  AssertionError: nan != -2.0
-        self.assert_eq(-1.5, psdf.c.kurtosis())
+        self.assert_eq(pdf.b.kurtosis(), psdf.b.kurtosis())
+        self.assert_eq(pdf.c.kurtosis(), psdf.c.kurtosis())
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py
index ccce140a4ac..e8f5048033b 100644
--- a/python/pyspark/pandas/tests/test_stats.py
+++ b/python/pyspark/pandas/tests/test_stats.py
@@ -180,6 +180,7 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils):
             self.assert_eq(psdf.min(axis=1), pdf.min(axis=1))
             self.assert_eq(psdf.sum(axis=1), pdf.sum(axis=1))
             self.assert_eq(psdf.product(axis=1), pdf.product(axis=1))
+            self.assert_eq(psdf.kurtosis(axis=0), pdf.kurtosis(axis=0), almost=True)
             self.assert_eq(psdf.kurtosis(axis=1), pdf.kurtosis(axis=1))
             self.assert_eq(psdf.skew(axis=0), pdf.skew(axis=0), almost=True)
             self.assert_eq(psdf.skew(axis=1), pdf.skew(axis=1))
@@ -216,6 +217,11 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils):
                 psdf.product(axis=1, numeric_only=True),
                 pdf.product(axis=1, numeric_only=True).astype(float),
             )
+            self.assert_eq(
+                psdf.kurtosis(axis=0, numeric_only=True),
+                pdf.kurtosis(axis=0, numeric_only=True),
+                almost=True,
+            )
             self.assert_eq(
                 psdf.kurtosis(axis=1, numeric_only=True), pdf.kurtosis(axis=1, numeric_only=True)
             )
@@ -236,7 +242,7 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils):
                 pdf.sem(axis=1, ddof=0, numeric_only=True),
             )
 
-    def test_skew_numerical_stability(self):
+    def test_skew_kurt_numerical_stability(self):
         pdf = pd.DataFrame(
             {
                 "A": [1, 1, 1, 1, 1],
@@ -249,6 +255,7 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils):
         )
         psdf = ps.from_pandas(pdf)
         self.assert_eq(psdf.skew(), pdf.skew(), almost=True)
+        self.assert_eq(psdf.kurt(), pdf.kurt(), almost=True)
 
     def test_corr(self):
         # Disable arrow execution since corr() is using UDT internally which is not supported.
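
The renamed stability test matters because for a near-constant column the
deviations from the mean are mathematically zero, but floating-point rounding
can leave tiny residuals that a naive moment ratio blows up. A simplified
plain-Python illustration of the cutoff (it omits the n-dependent factors of
the full formula):

    import numpy as np

    x = np.full(6, 0.1)                      # a constant column
    d = x - x.mean()                         # mathematically all zeros
    m2, m4 = (d ** 2).sum(), (d ** 4).sum()
    denominator = m2 * m2                    # may be a tiny residual, not 0.0
    # Zeroing |denominator| < 1e-14, as pandas does, yields 0.0 instead of an
    # arbitrarily large ratio driven purely by rounding error.
    print(0.0 if abs(denominator) < 1e-14 else m4 / denominator)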
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 2371cb5143a..782e7c71425 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -367,3 +367,36 @@ case class PandasSkewness(child: Expression)
   override protected def withNewChildInternal(newChild: Expression): PandasSkewness =
     copy(child = newChild)
 }
+
+/**
+ * Kurtosis in Pandas' fashion. This expression is dedicated only for Pandas API on Spark.
+ * Refer to pandas.core.nanops.nankurt.
+ */
+case class PandasKurtosis(child: Expression)
+  extends CentralMomentAgg(child, true) {
+
+  override protected def momentOrder = 4
+
+  override val evaluateExpression: Expression = {
+    val adj = ((n - 1) / (n - 2)) * ((n - 1) / (n - 3)) * 3
+    val numerator = n * (n + 1) * (n - 1) * m4
+    val denominator = (n - 2) * (n - 3) * m2 * m2
+
+    // Floating-point error can leave tiny nonzero residuals in the moments.
+    //
+    // pandas (GH#18044, _libs/windows.pyx, calc_kurt) fixes this fperr by
+    // treating a denominator below 1e-14 as zero.
+    //
+    // See https://github.com/pandas-dev/pandas/issues/18044 for details.
+    val _numerator = If(abs(numerator) < 1e-14, Literal(0.0), numerator)
+    val _denominator = If(abs(denominator) < 1e-14, Literal(0.0), denominator)
+
+    If(n < 4, Literal.create(null, DoubleType),
+      If(_denominator === 0.0, Literal(0.0), _numerator / _denominator - adj))
+  }
+
+  override def prettyName: String = "pandas_kurtosis"
+
+  override protected def withNewChildInternal(newChild: Expression): PandasKurtosis =
+    copy(child = newChild)
+}
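
For reference, the expression above mirrors pandas.core.nanops.nankurt. A
plain-Python sketch of the same arithmetic (m2 and m4 are sums of squared and
fourth-power deviations, matching CentralMomentAgg's buffers; the helper name
is ours, for illustration only):

    import numpy as np

    def pandas_kurt(values):
        x = np.asarray(values, dtype=float)
        x = x[~np.isnan(x)]                  # skip NaNs, as pandas does
        n = len(x)
        if n < 4:
            return None                      # the `n < 4` null branch above
        d = x - x.mean()
        m2 = (d ** 2).sum()
        m4 = (d ** 4).sum()
        adj = ((n - 1) / (n - 2)) * ((n - 1) / (n - 3)) * 3
        numerator = n * (n + 1) * (n - 1) * m4
        denominator = (n - 2) * (n - 3) * m2 * m2
        # zero out tiny values to absorb floating-point error (pandas GH#18044)
        numerator = 0.0 if abs(numerator) < 1e-14 else numerator
        denominator = 0.0 if abs(denominator) < 1e-14 else denominator
        return 0.0 if denominator == 0.0 else numerator / denominator - adj

    print(pandas_kurt([1, 2, 3, float("nan"), 6]))   # 1.5, as in the doctest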
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index a6307922dee..5b12ab8adce 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
 import org.apache.spark.sql.catalyst.CatalystTypeConverters
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.PandasSkewness
+import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.execution.{ExplainMode, QueryExecution}
 import org.apache.spark.sql.execution.arrow.ArrowConverters
@@ -127,6 +127,10 @@ private[sql] object PythonSQLUtils extends Logging {
   def pandasSkewness(e: Column): Column = {
     Column(PandasSkewness(e.expr).toAggregateExpression(false))
   }
+
+  def pandasKurtosis(e: Column): Column = {
+    Column(PandasKurtosis(e.expr).toAggregateExpression(false))
+  }
 }
 
 /**

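With the helper registered, the new aggregate is reachable end to end through
the pandas-on-Spark API. A quick sanity check mirroring the updated doctest
(assuming a Spark session with this patch applied):

    import numpy as np
    import pyspark.pandas as ps

    psdf = ps.DataFrame({'a': [1, 2, 3, np.nan, 6], 'b': [0.1, 0.2, 0.3, np.nan, 0.8]})
    print(psdf.kurtosis())    # a: 1.500000, b: 2.703924 -- matching pandas
    print(psdf['a'].kurt())   # 1.5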
