This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1dfb4399334 [SPARK-38943][PYTHON] EWM support ignore_na
1dfb4399334 is described below

commit 1dfb4399334a02bf2e54faeb214c4a387753ddce
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Apr 20 10:46:58 2022 +0900

    [SPARK-38943][PYTHON] EWM support ignore_na
    
    ### What changes were proposed in this pull request?
    Support `ignore_na` in EWM.
    
    ### Why are the changes needed?
    `ignore_na` is supported in pandas.
    After adding this param, EWM can handle datasets containing NaN/Null.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, a new param is added.
    
    ### How was this patch tested?
    Added test suites.
    
    Closes #36257 from zhengruifeng/ewm_support_ingnore_na.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 python/pyspark/pandas/generic.py                   | 26 +++++-
 python/pyspark/pandas/missing/window.py            |  1 -
 python/pyspark/pandas/tests/test_ewm.py            | 92 ++++++++++++++++++++++
 python/pyspark/pandas/window.py                    | 27 +++++--
 .../catalyst/expressions/windowExpressions.scala   | 34 +++++---
 .../spark/sql/api/python/PythonSQLUtils.scala      |  3 +-
 6 files changed, 159 insertions(+), 24 deletions(-)

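For context, a minimal usage sketch of the new parameter (the data values and Spark session are illustrative assumptions, not taken from the patch):

```python
# Hypothetical illustration of the new parameter; the values are made up.
import pyspark.pandas as ps

psdf = ps.DataFrame({"s": [1.0, None, 3.0]})

# Default: a null row still decays the weights of earlier observations.
psdf.ewm(alpha=0.5).mean()

# With this patch: nulls are skipped when computing relative weights.
psdf.ewm(alpha=0.5, ignore_na=True).mean()
```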
diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index bb27c633a2b..21c880373ad 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -2619,7 +2619,7 @@ class Frame(object, metaclass=ABCMeta):
 
         return Expanding(self, min_periods=min_periods)
 
-    # TODO: 'adjust', 'ignore_na', 'axis', 'method' parameter should be implemented.
+    # TODO: 'adjust', 'axis', 'method' parameter should be implemented.
     def ewm(
         self: FrameLike,
         com: Optional[float] = None,
@@ -2627,6 +2627,7 @@ class Frame(object, metaclass=ABCMeta):
         halflife: Optional[float] = None,
         alpha: Optional[float] = None,
         min_periods: Optional[int] = None,
+        ignore_na: bool_type = False,
     ) -> "ExponentialMoving[FrameLike]":
         """
         Provide exponentially weighted window transformations.
@@ -2659,6 +2660,21 @@ class Frame(object, metaclass=ABCMeta):
             Minimum number of observations in window required to have a value
             (otherwise result is NA).
 
+        ignore_na : bool, default False
+            Ignore missing values when calculating weights.
+
+            - When ``ignore_na=False`` (default), weights are based on absolute positions.
+              For example, the weights of :math:`x_0` and :math:`x_2` used in calculating
+              the final weighted average of [:math:`x_0`, None, :math:`x_2`] are
+              :math:`(1-\alpha)^2` and :math:`1` if ``adjust=True``, and
+              :math:`(1-\alpha)^2` and :math:`\alpha` if ``adjust=False``.
+
+            - When ``ignore_na=True``, weights are based
+              on relative positions. For example, the weights of :math:`x_0` and :math:`x_2`
+              used in calculating the final weighted average of
+              [:math:`x_0`, None, :math:`x_2`] are :math:`1-\alpha` and :math:`1` if
+              ``adjust=True``, and :math:`1-\alpha` and :math:`\alpha` if ``adjust=False``.
+
         Returns
         -------
         a Window sub-classed for the particular operation
@@ -2666,7 +2682,13 @@ class Frame(object, metaclass=ABCMeta):
         from pyspark.pandas.window import ExponentialMoving
 
         return ExponentialMoving(
-            self, com=com, span=span, halflife=halflife, alpha=alpha, min_periods=min_periods
+            self,
+            com=com,
+            span=span,
+            halflife=halflife,
+            alpha=alpha,
+            min_periods=min_periods,
+            ignore_na=ignore_na,
         )
 
     def get(self, key: Any, default: Optional[Any] = None) -> Any:
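The weight formulas documented in the docstring above can be sanity-checked against plain pandas; a small sketch (alpha and the data values are arbitrary choices):

```python
# Numeric check of the documented weights for [x0, None, x2];
# pandas' default adjust=True applies throughout.
import pandas as pd

alpha, x0, x2 = 0.5, 1.0, 3.0
s = pd.Series([x0, None, x2])

# ignore_na=False: absolute positions, weights (1 - alpha)**2 and 1.
w = (1 - alpha) ** 2
assert abs(s.ewm(alpha=alpha).mean().iloc[-1] - (w * x0 + x2) / (w + 1)) < 1e-12

# ignore_na=True: relative positions, weights (1 - alpha) and 1.
w = 1 - alpha
assert abs(s.ewm(alpha=alpha, ignore_na=True).mean().iloc[-1] - (w * x0 + x2) / (w + 1)) < 1e-12
```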
diff --git a/python/pyspark/pandas/missing/window.py b/python/pyspark/pandas/missing/window.py
index e6ac39901ff..237dc85c82c 100644
--- a/python/pyspark/pandas/missing/window.py
+++ b/python/pyspark/pandas/missing/window.py
@@ -152,6 +152,5 @@ class MissingPandasLikeExponentialMoving:
     corr = _unsupported_function_exponential_moving("corr")
 
     adjust = _unsupported_property_exponential_moving("adjust")
-    ignore_na = _unsupported_property_exponential_moving("ignore_na")
     axis = _unsupported_property_exponential_moving("axis")
     method = _unsupported_property_exponential_moving("method")
diff --git a/python/pyspark/pandas/tests/test_ewm.py b/python/pyspark/pandas/tests/test_ewm.py
index 7306aad44ff..d4c1e1ba06a 100644
--- a/python/pyspark/pandas/tests/test_ewm.py
+++ b/python/pyspark/pandas/tests/test_ewm.py
@@ -109,6 +109,98 @@ class EWMTest(PandasOnSparkTestCase, TestUtils):
             getattr(pdf.ewm(alpha=0.7, min_periods=2), f)().sum(),
         )
 
+        pdf = pd.DataFrame(
+            {
+                "s1": [None, 2, 3, 4],
+                "s2": [1, None, 3, 4],
+                "s3": [1, 3, 4, 5],
+                "s4": [1, 0, 3, 4],
+                "s5": [None, None, 1, None],
+                "s6": [None, None, None, None],
+            }
+        )
+        psdf = ps.from_pandas(pdf)
+        self.assert_eq(
+            getattr(psdf.ewm(com=0.2, ignore_na=True), f)(),
+            getattr(pdf.ewm(com=0.2, ignore_na=True), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(com=0.2, ignore_na=True), f)().sum(),
+            getattr(pdf.ewm(com=0.2, ignore_na=True), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(com=0.2, ignore_na=False), f)(),
+            getattr(pdf.ewm(com=0.2, ignore_na=False), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(com=0.2, ignore_na=False), f)().sum(),
+            getattr(pdf.ewm(com=0.2, ignore_na=False), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(span=1.7, ignore_na=True), f)(),
+            getattr(pdf.ewm(span=1.7, ignore_na=True), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(span=1.7, ignore_na=True), f)().sum(),
+            getattr(pdf.ewm(span=1.7, ignore_na=True), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(span=1.7, ignore_na=False), f)(),
+            getattr(pdf.ewm(span=1.7, ignore_na=False), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(span=1.7, ignore_na=False), f)().sum(),
+            getattr(pdf.ewm(span=1.7, ignore_na=False), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)(),
+            getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(halflife=0.5, ignore_na=True), f)().sum(),
+            getattr(pdf.ewm(halflife=0.5, ignore_na=True), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)(),
+            getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(halflife=0.5, ignore_na=False), f)().sum(),
+            getattr(pdf.ewm(halflife=0.5, ignore_na=False), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=True), f)().sum(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=True), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=False), f)().sum(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=False), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)().sum(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=True, min_periods=2), f)().sum(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)(),
+        )
+        self.assert_eq(
+            getattr(psdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)().sum(),
+            getattr(pdf.ewm(alpha=0.7, ignore_na=False, min_periods=2), f)().sum(),
+        )
+
     def test_ewm_mean(self):
         self._test_ewm_func("mean")
 
diff --git a/python/pyspark/pandas/window.py b/python/pyspark/pandas/window.py
index 4c5ababf0c8..9203a40cb47 100644
--- a/python/pyspark/pandas/window.py
+++ b/python/pyspark/pandas/window.py
@@ -1762,12 +1762,14 @@ class ExponentialMovingLike(Generic[FrameLike], metaclass=ABCMeta):
         halflife: Optional[float] = None,
         alpha: Optional[float] = None,
         min_periods: Optional[int] = None,
+        ignore_na: bool = False,
     ):
         if (min_periods is not None) and (min_periods < 0):
             raise ValueError("min_periods must be >= 0")
         if min_periods is None:
             min_periods = 0
         self._min_periods = min_periods
+        self._ignore_na = ignore_na
 
         self._window = window
         # This unbounded Window is later used to handle 'min_periods' for now.
@@ -1828,10 +1830,11 @@ class ExponentialMovingLike(Generic[FrameLike], metaclass=ABCMeta):
         unified_alpha = self._compute_unified_alpha()
 
         def mean(scol: Column) -> Column:
-            jf = SparkContext._active_spark_context._jvm.PythonSQLUtils.ewm
+            sql_utils = SparkContext._active_spark_context._jvm.PythonSQLUtils
             return F.when(
-                F.row_number().over(self._unbounded_window) >= self._min_periods,
-                Column(jf(scol._jc, unified_alpha)).over(self._window),
+                F.count(F.when(~scol.isNull(), 1).otherwise(None)).over(self._unbounded_window)
+                >= self._min_periods,
+                Column(sql_utils.ewm(scol._jc, unified_alpha, self._ignore_na)).over(self._window),
             ).otherwise(SF.lit(None))
 
         return self._apply_as_series_or_frame(mean)
@@ -1846,6 +1849,7 @@ class ExponentialMoving(ExponentialMovingLike[FrameLike]):
         halflife: Optional[float] = None,
         alpha: Optional[float] = None,
         min_periods: Optional[int] = None,
+        ignore_na: bool = False,
     ):
         from pyspark.pandas.frame import DataFrame
         from pyspark.pandas.series import Series
@@ -1861,7 +1865,7 @@ class ExponentialMoving(ExponentialMovingLike[FrameLike]):
             Window.unboundedPreceding, Window.currentRow
         )
 
-        super().__init__(window_spec, com, span, halflife, alpha, min_periods)
+        super().__init__(window_spec, com, span, halflife, alpha, min_periods, ignore_na)
 
     def __getattr__(self, item: str) -> Any:
         if hasattr(MissingPandasLikeExponentialMoving, item):
@@ -1882,7 +1886,6 @@ class ExponentialMoving(ExponentialMovingLike[FrameLike]):
         -----
         There are behavior differences between pandas-on-Spark and pandas.
 
-        * the data should not contain NaNs. pandas-on-Spark will return an error.
         * the current implementation of this API uses Spark's Window without
           specifying partition specification. This leads to move all data into
           single partition in single machine and could cause serious
@@ -1928,10 +1931,18 @@ class ExponentialMoving(ExponentialMovingLike[FrameLike]):
         """
         return super().mean()
 
-    # TODO: when add 'adjust' and 'ignore_na' parameter, should add to here too.
+    # TODO: when add 'adjust' parameter, should add to here too.
     def __repr__(self) -> str:
-        return "ExponentialMoving [com={}, span={}, halflife={}, alpha={}, 
min_periods={}]".format(
-            self._com, self._span, self._halflife, self._alpha, 
self._min_periods
+        return (
+            "ExponentialMoving [com={}, span={}, halflife={}, alpha={}, "
+            "min_periods={}, ignore_na={}]".format(
+                self._com,
+                self._span,
+                self._halflife,
+                self._alpha,
+                self._min_periods,
+                self._ignore_na,
+            )
         )
 
 
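Note the gating change in `mean` above: `min_periods` is now satisfied by counting non-null observations rather than rows, matching pandas. A quick pandas illustration (the values are made up):

```python
# With min_periods=2, the null row does not count toward the threshold,
# so only the third row produces a value.
import pandas as pd

print(pd.Series([1.0, None, 3.0]).ewm(alpha=0.5, min_periods=2).mean())
# 0    NaN
# 1    NaN
# 2    2.6
# dtype: float64
```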
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
index 458c2274432..c11a26ecd4f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala
@@ -1025,7 +1025,7 @@ case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBase
  *    y_t = \frac{\sum_{i=0}^t w_i x_{t-i}}{\sum_{i=0}^t w_i},
  * where x_t is the input, y_t is the result and the w_i are the weights.
  */
-case class EWM(input: Expression, alpha: Double)
+case class EWM(input: Expression, alpha: Double, ignoreNA: Boolean)
   extends AggregateWindowFunction with UnaryLike[Expression] {
   assert(0 < alpha && alpha <= 1)
 
@@ -1033,27 +1033,37 @@ case class EWM(input: Expression, alpha: Double)
 
   private val numerator = AttributeReference("numerator", DoubleType, nullable = false)()
   private val denominator = AttributeReference("denominator", DoubleType, nullable = false)()
-  override def aggBufferAttributes: Seq[AttributeReference] = numerator :: denominator :: Nil
+  private val result = AttributeReference("result", DoubleType, nullable = true)()
 
-  override val initialValues: Seq[Expression] = Seq(Literal(0.0), Literal(0.0))
+  override def aggBufferAttributes: Seq[AttributeReference] =
+    numerator :: denominator :: result :: Nil
+
+  override val initialValues: Seq[Expression] =
+    Literal(0.0) :: Literal(0.0) :: Literal.create(null, DoubleType) :: Nil
 
   override val updateExpressions: Seq[Expression] = {
     val beta = Literal(1.0 - alpha)
     val casted = input.cast(DoubleType)
-    // TODO: after adding param ignore_na, we can remove this check
-    val error = RaiseError(Literal("Input values must not be null or NaN")).cast(DoubleType)
-    val validated = If(IsNull(casted) || IsNaN(casted), error, casted)
-    Seq(
-      /* numerator = */ numerator * beta + validated,
-      /* denominator = */ denominator * beta + Literal(1.0)
-    )
+    val isNA = IsNull(casted)
+    val newNumerator = numerator * beta + casted
+    val newDenominator = denominator * beta + Literal(1.0)
+
+    if (ignoreNA) {
+      /* numerator = */ If(isNA, numerator, newNumerator) ::
+      /* denominator = */ If(isNA, denominator, newDenominator) ::
+      /* result = */ If(isNA, result, newNumerator / newDenominator) :: Nil
+    } else {
+      /* numerator = */ If(isNA, numerator * beta, newNumerator) ::
+      /* denominator = */ If(isNA, denominator * beta, newDenominator) ::
+      /* result = */ If(isNA, result, newNumerator / newDenominator) :: Nil
+    }
   }
 
-  override val evaluateExpression: Expression = numerator / denominator
+  override val evaluateExpression: Expression = result
 
   override def prettyName: String = "ewm"
 
-  override def sql: String = s"$prettyName(${input.sql})"
+  override def sql: String = s"$prettyName(${input.sql}, $alpha, $ignoreNA)"
 
   override def child: Expression = input
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 95a42331308..2cc595ed2bf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -95,7 +95,8 @@ private[sql] object PythonSQLUtils extends Logging {
 
   def castTimestampNTZToLong(c: Column): Column = Column(CastTimestampNTZToLong(c.expr))
 
-  def ewm(e: Column, alpha: Double): Column = Column(EWM(e.expr, alpha))
+  def ewm(e: Column, alpha: Double, ignoreNA: Boolean): Column =
+    Column(EWM(e.expr, alpha, ignoreNA))
 
   def lastNonNull(e: Column): Column = Column(LastNonNull(e.expr))
 

