This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new ae08787f5c5 [SPARK-40399][PS] Make `pearson` correlation in `DataFrame.corr` support missing values and `min_periods`
ae08787f5c5 is described below

commit ae08787f5c50e485ef4432a0c2da8b3b7290d725
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Sep 13 14:44:18 2022 +0800

    [SPARK-40399][PS] Make `pearson` correlation in `DataFrame.corr` support missing values and `min_periods`

    ### What changes were proposed in this pull request?
    Refactor the `pearson` correlation in `DataFrame.corr` to:

    1. support missing values;
    2. add the parameter `min_periods`;
    3. enable Arrow execution, since it no longer depends on `VectorUDT`;
    4. support lazy evaluation.

    Before:
    ```
    In [1]: import pyspark.pandas as ps

    In [2]: df = ps.DataFrame([[1,2], [3,None]])

    In [3]: df
       0    1
    0  1  2.0
    1  3  NaN

    In [4]: df.corr()
    22/09/09 16:53:18 ERROR Executor: Exception in task 9.0 in stage 5.0 (TID 24)
    org.apache.spark.SparkException: [FAILED_EXECUTE_UDF] Failed to execute user defined function (VectorAssembler$$Lambda$2660/0x0000000801215840: (struct<0_double_VectorAssembler_0915f96ec689:double,1:double>) => struct<type:tinyint,size:int,indices:array<int>,values:array<double>>)
    ```

    After:
    ```
    In [1]: import pyspark.pandas as ps

    In [2]: df = ps.DataFrame([[1,2], [3,None]])

    In [3]: df.corr()
         0   1
    0  1.0 NaN
    1  NaN NaN

    In [4]: df.to_pandas().corr()
    /Users/ruifeng.zheng/Dev/spark/python/pyspark/pandas/utils.py:976: PandasAPIOnSparkAdviceWarning: `to_pandas` loads all data into the driver's memory. It should only be used if the resulting pandas DataFrame is expected to be small.
      warnings.warn(message, PandasAPIOnSparkAdviceWarning)
    Out[4]:
         0   1
    0  1.0 NaN
    1  NaN NaN
    ```

    ### Why are the changes needed?
    For API coverage, and to support common cases containing missing values.

    ### Does this PR introduce _any_ user-facing change?
    Yes, it is an API change: a new parameter, `min_periods`, is supported.

    ### How was this patch tested?
added UT Closes #37845 from zhengruifeng/ps_df_corr_missing_value. Authored-by: Ruifeng Zheng <ruife...@apache.org> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/pandas/frame.py | 209 +++++++++++++++++++++++++++++- python/pyspark/pandas/tests/test_stats.py | 34 +++++ 2 files changed, 238 insertions(+), 5 deletions(-) diff --git a/python/pyspark/pandas/frame.py b/python/pyspark/pandas/frame.py index 3438d07896e..cf14a548266 100644 --- a/python/pyspark/pandas/frame.py +++ b/python/pyspark/pandas/frame.py @@ -1417,15 +1417,23 @@ class DataFrame(Frame, Generic[T]): agg = aggregate - def corr(self, method: str = "pearson") -> "DataFrame": + def corr(self, method: str = "pearson", min_periods: Optional[int] = None) -> "DataFrame": """ Compute pairwise correlation of columns, excluding NA/null values. + .. versionadded:: 3.3.0 + Parameters ---------- method : {'pearson', 'spearman'} * pearson : standard correlation coefficient * spearman : Spearman rank correlation + min_periods : int, optional + Minimum number of observations required per pair of columns + to have a valid result. Currently only available for Pearson + correlation. + + .. versionadded:: 3.4.0 Returns ------- @@ -1454,11 +1462,202 @@ class DataFrame(Frame, Generic[T]): There are behavior differences between pandas-on-Spark and pandas. * the `method` argument only accepts 'pearson', 'spearman' - * the data should not contain NaNs. pandas-on-Spark will return an error. - * pandas-on-Spark doesn't support the following argument(s). + * if the `method` is `spearman`, the data should not contain NaNs. + * if the `method` is `spearman`, `min_periods` argument is not supported. 
+ """ + if method not in ["pearson", "spearman", "kendall"]: + raise ValueError(f"Invalid method {method}") + if method == "kendall": + raise NotImplementedError("method doesn't support kendall for now") + if min_periods is not None and not isinstance(min_periods, int): + raise TypeError(f"Invalid min_periods type {type(min_periods).__name__}") + if min_periods is not None and method == "spearman": + raise NotImplementedError("min_periods doesn't support spearman for now") + + if method == "pearson": + min_periods = 1 if min_periods is None else min_periods + internal = self._internal.resolved_copy + numeric_labels = [ + label + for label in internal.column_labels + if isinstance(internal.spark_type_for(label), (NumericType, BooleanType)) + ] + numeric_scols: List[Column] = [ + internal.spark_column_for(label).cast("double") for label in numeric_labels + ] + numeric_col_names: List[str] = [name_like_string(label) for label in numeric_labels] + num_scols = len(numeric_scols) + + sdf = internal.spark_frame + tmp_index_1_col_name = verify_temp_column_name(sdf, "__tmp_index_1_col__") + tmp_index_2_col_name = verify_temp_column_name(sdf, "__tmp_index_2_col__") + tmp_value_1_col_name = verify_temp_column_name(sdf, "__tmp_value_1_col__") + tmp_value_2_col_name = verify_temp_column_name(sdf, "__tmp_value_2_col__") + + # simple dataset + # +---+---+----+ + # | A| B| C| + # +---+---+----+ + # | 1| 2| 3.0| + # | 4| 1|null| + # +---+---+----+ + + pair_scols: List[Column] = [] + for i in range(0, num_scols): + for j in range(i, num_scols): + pair_scols.append( + F.struct( + F.lit(i).alias(tmp_index_1_col_name), + F.lit(j).alias(tmp_index_2_col_name), + numeric_scols[i].alias(tmp_value_1_col_name), + numeric_scols[j].alias(tmp_value_2_col_name), + ) + ) + + # +-------------------+-------------------+-------------------+-------------------+ + # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_value_1_col__|__tmp_value_2_col__| + # 
+-------------------+-------------------+-------------------+-------------------+ + # | 0| 0| 1.0| 1.0| + # | 0| 1| 1.0| 2.0| + # | 0| 2| 1.0| 3.0| + # | 1| 1| 2.0| 2.0| + # | 1| 2| 2.0| 3.0| + # | 2| 2| 3.0| 3.0| + # | 0| 0| 4.0| 4.0| + # | 0| 1| 4.0| 1.0| + # | 0| 2| 4.0| null| + # | 1| 1| 1.0| 1.0| + # | 1| 2| 1.0| null| + # | 2| 2| null| null| + # +-------------------+-------------------+-------------------+-------------------+ + tmp_tuple_col_name = verify_temp_column_name(sdf, "__tmp_tuple_col__") + sdf = sdf.select(F.explode(F.array(*pair_scols)).alias(tmp_tuple_col_name)).select( + F.col(f"{tmp_tuple_col_name}.{tmp_index_1_col_name}").alias(tmp_index_1_col_name), + F.col(f"{tmp_tuple_col_name}.{tmp_index_2_col_name}").alias(tmp_index_2_col_name), + F.col(f"{tmp_tuple_col_name}.{tmp_value_1_col_name}").alias(tmp_value_1_col_name), + F.col(f"{tmp_tuple_col_name}.{tmp_value_2_col_name}").alias(tmp_value_2_col_name), + ) + + # +-------------------+-------------------+------------------------+-----------------+ + # |__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__|__tmp_count_col__| + # +-------------------+-------------------+------------------------+-----------------+ + # | 2| 2| null| 1| + # | 1| 2| null| 1| + # | 1| 1| 1.0| 2| + # | 0| 0| 1.0| 2| + # | 0| 1| -1.0| 2| + # | 0| 2| null| 1| + # +-------------------+-------------------+------------------------+-----------------+ + tmp_corr_col_name = verify_temp_column_name(sdf, "__tmp_pearson_corr_col__") + tmp_count_col_name = verify_temp_column_name(sdf, "__tmp_count_col__") + sdf = sdf.groupby(tmp_index_1_col_name, tmp_index_2_col_name).agg( + F.corr(tmp_value_1_col_name, tmp_value_2_col_name).alias(tmp_corr_col_name), + F.count( + F.when( + F.col(tmp_value_1_col_name).isNotNull() + & F.col(tmp_value_2_col_name).isNotNull(), + 1, + ) + ).alias(tmp_count_col_name), + ) + + # +-------------------+-------------------+------------------------+ + # 
|__tmp_index_1_col__|__tmp_index_2_col__|__tmp_pearson_corr_col__| + # +-------------------+-------------------+------------------------+ + # | 2| 2| null| + # | 1| 2| null| + # | 2| 1| null| + # | 1| 1| 1.0| + # | 0| 0| 1.0| + # | 0| 1| -1.0| + # | 1| 0| -1.0| + # | 0| 2| null| + # | 2| 0| null| + # +-------------------+-------------------+------------------------+ + sdf = ( + sdf.withColumn( + tmp_corr_col_name, + F.when( + F.col(tmp_count_col_name) >= min_periods, F.col(tmp_corr_col_name) + ).otherwise(F.lit(None)), + ) + .withColumn( + tmp_tuple_col_name, + F.explode( + F.when( + F.col(tmp_index_1_col_name) == F.col(tmp_index_2_col_name), + F.lit([0]), + ).otherwise(F.lit([0, 1])) + ), + ) + .select( + F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_1_col_name)) + .otherwise(F.col(tmp_index_2_col_name)) + .alias(tmp_index_1_col_name), + F.when(F.col(tmp_tuple_col_name) == 0, F.col(tmp_index_2_col_name)) + .otherwise(F.col(tmp_index_1_col_name)) + .alias(tmp_index_2_col_name), + F.col(tmp_corr_col_name), + ) + ) + + # +-------------------+--------------------+ + # |__tmp_index_1_col__| __tmp_array_col__| + # +-------------------+--------------------+ + # | 0|[{0, 1.0}, {1, -1...| + # | 1|[{0, -1.0}, {1, 1...| + # | 2|[{0, null}, {1, n...| + # +-------------------+--------------------+ + tmp_array_col_name = verify_temp_column_name(sdf, "__tmp_array_col__") + sdf = ( + sdf.groupby(tmp_index_1_col_name) + .agg( + F.array_sort( + F.collect_list( + F.struct(F.col(tmp_index_2_col_name), F.col(tmp_corr_col_name)) + ) + ).alias(tmp_array_col_name) + ) + .orderBy(tmp_index_1_col_name) + ) + + for i in range(0, num_scols): + sdf = sdf.withColumn( + tmp_tuple_col_name, F.get(F.col(tmp_array_col_name), i) + ).withColumn( + numeric_col_names[i], + F.col(f"{tmp_tuple_col_name}.{tmp_corr_col_name}"), + ) + + index_col_names: List[str] = [] + if internal.column_labels_level > 1: + for level in range(0, internal.column_labels_level): + index_col_name = 
SPARK_INDEX_NAME_FORMAT(level) + indices = [label[level] for label in numeric_labels] + sdf = sdf.withColumn( + index_col_name, F.get(F.lit(indices), F.col(tmp_index_1_col_name)) + ) + index_col_names.append(index_col_name) + else: + sdf = sdf.withColumn( + SPARK_DEFAULT_INDEX_NAME, + F.get(F.lit(numeric_col_names), F.col(tmp_index_1_col_name)), + ) + index_col_names = [SPARK_DEFAULT_INDEX_NAME] + + sdf = sdf.select(*index_col_names, *numeric_col_names) + + return DataFrame( + InternalFrame( + spark_frame=sdf, + index_spark_columns=[ + scol_for(sdf, index_col_name) for index_col_name in index_col_names + ], + column_labels=numeric_labels, + column_label_names=internal.column_label_names, + ) + ) - * `min_periods` argument is not supported - """ return cast(DataFrame, ps.from_pandas(corr(self, method))) # TODO: add axis parameter and support more methods diff --git a/python/pyspark/pandas/tests/test_stats.py b/python/pyspark/pandas/tests/test_stats.py index e8f5048033b..7e2ca96e60f 100644 --- a/python/pyspark/pandas/tests/test_stats.py +++ b/python/pyspark/pandas/tests/test_stats.py @@ -257,6 +257,40 @@ class StatsTest(PandasOnSparkTestCase, SQLTestUtils): self.assert_eq(psdf.skew(), pdf.skew(), almost=True) self.assert_eq(psdf.kurt(), pdf.kurt(), almost=True) + def test_dataframe_corr(self): + # existing 'test_corr' is mixed by df.corr and ser.corr, will delete 'test_corr' + # when we have separate tests for df.corr and ser.corr + pdf = makeMissingDataframe(0.3, 42) + psdf = ps.from_pandas(pdf) + + with self.assertRaisesRegex(ValueError, "Invalid method"): + psdf.corr("std") + with self.assertRaisesRegex(NotImplementedError, "kendall for now"): + psdf.corr("kendall") + with self.assertRaisesRegex(TypeError, "Invalid min_periods type"): + psdf.corr(min_periods="3") + with self.assertRaisesRegex(NotImplementedError, "spearman for now"): + psdf.corr(method="spearman", min_periods=3) + + self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False) + 
self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False) + self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False) + self.assert_eq( + (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False + ) + + # multi-index columns + columns = pd.MultiIndex.from_tuples([("X", "A"), ("X", "B"), ("Y", "C"), ("Z", "D")]) + pdf.columns = columns + psdf.columns = columns + + self.assert_eq(psdf.corr(), pdf.corr(), check_exact=False) + self.assert_eq(psdf.corr(min_periods=1), pdf.corr(min_periods=1), check_exact=False) + self.assert_eq(psdf.corr(min_periods=3), pdf.corr(min_periods=3), check_exact=False) + self.assert_eq( + (psdf + 1).corr(min_periods=2), (pdf + 1).corr(min_periods=2), check_exact=False + ) + def test_corr(self): # Disable arrow execution since corr() is using UDT internally which is not supported. with self.sql_conf({SPARK_CONF_ARROW_ENABLED: False}): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org