This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dde94a68b85 [SPARK-40305][PS] Implement Groupby.sem
dde94a68b85 is described below

commit dde94a68b850b23df8fca3531350a7b5643b3cd1
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Sun Sep 4 14:25:47 2022 +0800

    [SPARK-40305][PS] Implement Groupby.sem
    
    ### What changes were proposed in this pull request?
    Implement Groupby.sem
    
    ### Why are the changes needed?
    to increase API coverage
    
    ### Does this PR introduce _any_ user-facing change?
    yes, new API
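    
    For example, an illustrative sketch of the new API (data and expected
    output taken from the docstring added in this commit):
    
        >>> import pyspark.pandas as ps
        >>> df = ps.DataFrame({"A": [1, 2, 1, 1], "B": [True, False, False, True],
        ...                    "C": [3, None, 3, 4], "D": ["a", "b", "b", "a"]})
        >>> df.groupby("A").sem()
                  B         C
        A
        1  0.333333  0.333333
        2       NaN       NaN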
    
    ### How was this patch tested?
    added UT
    
    Closes #37756 from zhengruifeng/ps_groupby_sem.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../source/reference/pyspark.pandas/groupby.rst    |  1 +
 python/pyspark/pandas/generic.py                   |  2 +-
 python/pyspark/pandas/groupby.py                   | 80 ++++++++++++++++++++--
 python/pyspark/pandas/missing/groupby.py           |  2 -
 python/pyspark/pandas/tests/test_groupby.py        | 21 ++++++
 5 files changed, 97 insertions(+), 9 deletions(-)

diff --git a/python/docs/source/reference/pyspark.pandas/groupby.rst b/python/docs/source/reference/pyspark.pandas/groupby.rst
index 6d8eed8e684..b331a49b683 100644
--- a/python/docs/source/reference/pyspark.pandas/groupby.rst
+++ b/python/docs/source/reference/pyspark.pandas/groupby.rst
@@ -74,6 +74,7 @@ Computations / Descriptive Stats
    GroupBy.median
    GroupBy.min
    GroupBy.rank
+   GroupBy.sem
    GroupBy.std
    GroupBy.sum
    GroupBy.var
diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py
index bd2b68da51f..8ce3061b7cd 100644
--- a/python/pyspark/pandas/generic.py
+++ b/python/pyspark/pandas/generic.py
@@ -2189,7 +2189,7 @@ class Frame(object, metaclass=ABCMeta):
                 return F.stddev_samp(spark_column)
 
         def sem(psser: "Series") -> Column:
-            return std(psser) / pow(Frame._count_expr(psser), 0.5)
+            return std(psser) / F.sqrt(Frame._count_expr(psser))
 
         return self._reduce_for_stat_function(
             sem,
diff --git a/python/pyspark/pandas/groupby.py b/python/pyspark/pandas/groupby.py
index 4377ad6a5c9..84a5a3377f3 100644
--- a/python/pyspark/pandas/groupby.py
+++ b/python/pyspark/pandas/groupby.py
@@ -650,12 +650,11 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
         assert ddof in (0, 1)
 
         # Raise the TypeError when all aggregation columns are of unaccepted data types
-        all_unaccepted = True
-        for _agg_col in self._agg_columns:
-            if isinstance(_agg_col.spark.data_type, (NumericType, BooleanType)):
-                all_unaccepted = False
-                break
-        if all_unaccepted:
+        any_accepted = any(
+            isinstance(_agg_col.spark.data_type, (NumericType, BooleanType))
+            for _agg_col in self._agg_columns
+        )
+        if not any_accepted:
             raise TypeError(
                 "Unaccepted data types of aggregation columns; numeric or bool 
expected."
             )
@@ -827,6 +826,75 @@ class GroupBy(Generic[FrameLike], metaclass=ABCMeta):
 
         return self._prepare_return(DataFrame(internal))
 
+    def sem(self, ddof: int = 1) -> FrameLike:
+        """
+        Compute standard error of the mean of groups, excluding missing values.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        ddof : int, default 1
+            Delta Degrees of Freedom. The divisor used in calculations is N - ddof,
+            where N represents the number of elements.
+
+        Examples
+        --------
+        >>> df = ps.DataFrame({"A": [1, 2, 1, 1], "B": [True, False, False, True],
+        ...                    "C": [3, None, 3, 4], "D": ["a", "b", "b", "a"]})
+
+        >>> df.groupby("A").sem()
+                  B         C
+        A
+        1  0.333333  0.333333
+        2       NaN       NaN
+
+        >>> df.groupby("D").sem(ddof=1)
+             A    B    C
+        D
+        a  0.0  0.0  0.5
+        b  0.5  0.0  NaN
+
+        >>> df.B.groupby(df.A).sem()
+        A
+        1    0.333333
+        2         NaN
+        Name: B, dtype: float64
+
+        See Also
+        --------
+        pyspark.pandas.Series.sem
+        pyspark.pandas.DataFrame.sem
+        """
+        if ddof not in [0, 1]:
+            raise TypeError("ddof must be 0 or 1")
+
+        # Raise the TypeError when all aggregation columns are of unaccepted data types
+        any_accepted = any(
+            isinstance(_agg_col.spark.data_type, (NumericType, BooleanType))
+            for _agg_col in self._agg_columns
+        )
+        if not any_accepted:
+            raise TypeError(
+                "Unaccepted data types of aggregation columns; numeric or bool 
expected."
+            )
+
+        if ddof == 0:
+
+            def sem(col: Column) -> Column:
+                return F.stddev_pop(col) / F.sqrt(F.count(col))
+
+        else:
+
+            def sem(col: Column) -> Column:
+                return F.stddev_samp(col) / F.sqrt(F.count(col))
+
+        return self._reduce_for_stat_function(
+            sem,
+            accepted_spark_types=(NumericType, BooleanType),
+            bool_to_numeric=True,
+        )
+
     def all(self, skipna: bool = True) -> FrameLike:
         """
         Returns True if all values in the group are truthful, else False.
diff --git a/python/pyspark/pandas/missing/groupby.py b/python/pyspark/pandas/missing/groupby.py
index ce61b1df1e1..8ae8a68b5fe 100644
--- a/python/pyspark/pandas/missing/groupby.py
+++ b/python/pyspark/pandas/missing/groupby.py
@@ -65,7 +65,6 @@ class MissingPandasLikeDataFrameGroupBy:
     pipe = _unsupported_function("pipe")
     prod = _unsupported_function("prod")
     resample = _unsupported_function("resample")
-    sem = _unsupported_function("sem")
 
 
 class MissingPandasLikeSeriesGroupBy:
@@ -100,4 +99,3 @@ class MissingPandasLikeSeriesGroupBy:
     pipe = _unsupported_function("pipe")
     prod = _unsupported_function("prod")
     resample = _unsupported_function("resample")
-    sem = _unsupported_function("sem")
diff --git a/python/pyspark/pandas/tests/test_groupby.py b/python/pyspark/pandas/tests/test_groupby.py
index cff2ce706d8..d1a1c9afcc7 100644
--- a/python/pyspark/pandas/tests/test_groupby.py
+++ b/python/pyspark/pandas/tests/test_groupby.py
@@ -1321,11 +1321,21 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
         ):
             psdf.groupby("A")[["C"]].std()
 
+        with self.assertRaisesRegex(
+            TypeError, "Unaccepted data types of aggregation columns; numeric 
or bool expected."
+        ):
+            psdf.groupby("A")[["C"]].sem()
+
         self.assert_eq(
             psdf.groupby("A").std().sort_index(),
             pdf.groupby("A").std().sort_index(),
             check_exact=False,
         )
+        self.assert_eq(
+            psdf.groupby("A").sem().sort_index(),
+            pdf.groupby("A").sem().sort_index(),
+            check_exact=False,
+        )
 
         # TODO: fix bug of `sum` and re-enable the test below
         # self._test_stat_func(lambda groupby_obj: groupby_obj.sum(), check_exact=False)
@@ -3055,6 +3065,17 @@ class GroupByTest(PandasOnSparkTestCase, TestUtils):
                 psdf.groupby("a")["b"].var(ddof=ddof).sort_index(),
                 check_exact=False,
             )
+            # sem
+            self.assert_eq(
+                pdf.groupby("a").sem(ddof=ddof).sort_index(),
+                psdf.groupby("a").sem(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
+            self.assert_eq(
+                pdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
+                psdf.groupby("a")["b"].sem(ddof=ddof).sort_index(),
+                check_exact=False,
+            )
 
     def test_getitem(self):
         psdf = ps.DataFrame(
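
For reference, the standard error of the mean added above follows the usual
definition, sem = std(group, ddof) / sqrt(count(group)), matching the fix in
generic.py. A quick standalone cross-check of the docstring values against
plain pandas (illustrative only, not part of this commit):

    import pandas as pd

    df = pd.DataFrame({"A": [1, 2, 1, 1], "C": [3.0, None, 3.0, 4.0]})
    # Group A=1 holds C values [3, 3, 4]: std(ddof=1) ~ 0.57735, divided by
    # sqrt(3) gives ~0.333333; group A=2 has no non-null C values, so NaN.
    print(df.groupby("A")["C"].sem())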


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
