itholic commented on a change in pull request #34931:
URL: https://github.com/apache/spark/pull/34931#discussion_r773645465
##########
File path: python/pyspark/pandas/frame.py
##########
@@ -8828,22 +8847,138 @@ def describe(self, percentiles: Optional[List[float]] = None) -> "DataFrame":
         else:
             percentiles = [0.25, 0.5, 0.75]
-        formatted_perc = ["{:.0%}".format(p) for p in sorted(percentiles)]
-        stats = ["count", "mean", "stddev", "min", *formatted_perc, "max"]
+        # Identify the cases
+        only_string_cols = (
+            len(psser_numeric) == 0 and len(psser_timestamp) == 0 and len(psser_string) > 0
+        )
+        only_numeric_cols = len(psser_numeric) > 0 and len(psser_timestamp) == 0
+        all_timestamp_cols = len(psser_numeric) == 0 and len(psser_timestamp) > 0
+        any_timestamp_cols = len(psser_numeric) > 0 and len(psser_timestamp) > 0
+
+        if only_string_cols:
+            # Handling string type columns
+            # We will retrieve the `count`, `unique`, `top` and `freq`.
+            exprs_string = [psser.spark.column for psser in psser_string]
+            sdf = self._internal.spark_frame.select(*exprs_string)
+
+            # Get `count` & `unique` for each column
+            counts, uniques = map(lambda x: x[1:], sdf.summary("count", "count_distinct").take(2))
+
+            # Get `top` & `freq` for each column
+            tops = []
+            freqs = []
+            # TODO: We should do it in a single pass since invoking a Spark job for every column
+            # is too expensive.
+            for column in exprs_string:
+                top, freq = sdf.groupby(column).count().sort("count", ascending=False).first()
+                tops.append(str(top))
+                freqs.append(str(freq))
+
+            stats = [counts, uniques, tops, freqs]
+            stats_names = ["count", "unique", "top", "freq"]
+
+            result: DataFrame = DataFrame(
+                data=stats,
+                index=stats_names,
+                columns=column_names,
+            )
+        elif only_numeric_cols:
+            # Handling numeric columns
+            exprs_numeric = [
+                psser._dtype_op.nan_to_null(psser).spark.column for psser in psser_numeric
+            ]
+            formatted_perc = ["{:.0%}".format(p) for p in sorted(percentiles)]
+            stats = ["count", "mean", "stddev", "min", *formatted_perc, "max"]
-        sdf = self._internal.spark_frame.select(*exprs).summary(*stats)
-        sdf = sdf.replace("stddev", "std", subset=["summary"])
+            # In this case, we can simply use `summary` to calculate the stats.
+            sdf = self._internal.spark_frame.select(*exprs_numeric).summary(*stats)
+            sdf = sdf.replace("stddev", "std", subset=["summary"])
-        internal = InternalFrame(
-            spark_frame=sdf,
-            index_spark_columns=[scol_for(sdf, "summary")],
-            column_labels=column_labels,
-            data_spark_columns=[
-                scol_for(sdf, self._internal.spark_column_name_for(label))
-                for label in column_labels
-            ],
-        )
-        return DataFrame(internal).astype("float64")
+            internal = InternalFrame(
+                spark_frame=sdf,
+                index_spark_columns=[scol_for(sdf, "summary")],
+                column_labels=column_labels,
+                data_spark_columns=[
+                    scol_for(sdf, self._internal.spark_column_name_for(label))
+                    for label in column_labels
+                ],
+            )
+            result = DataFrame(internal).astype("float64")
+        elif all_timestamp_cols or any_timestamp_cols:
+            column_names = [
+                self._internal.spark_column_name_for(column_label) for column_label in column_labels
+            ]
+            column_length = len(column_labels)
+
+            # Apply stat functions for each column.
+            count_exprs = map(F.count, column_names)
+            min_exprs = map(F.min, column_names)
+            # Here we flatten the multiple maps into a single list that contains each calculated
+            # percentile, using `chain`.
+            # e.g. flatten the `[<map object at 0x7fc1907dc280>, <map object at 0x7fc1907dcc70>]`
+            # to `[Column<'percentile_approx(A, 0.2, 10000)'>, Column<'percentile_approx(B, 0.2, 10000)'>,
+            #      Column<'percentile_approx(A, 0.5, 10000)'>, Column<'percentile_approx(B, 0.5, 10000)'>]`
+            perc_exprs = chain(
+                *[
+                    map(F.percentile_approx, column_names, [percentile] * column_length)
+                    for percentile in percentiles
+                ]
+            )
+            max_exprs = map(F.max, column_names)
+            mean_exprs = []
+            for column_name, spark_data_type in zip(column_names, spark_data_types):
+                mean_exprs.append(F.mean(column_name).astype(spark_data_type))
+            exprs = [*count_exprs, *mean_exprs, *min_exprs, *perc_exprs, *max_exprs]
+
+            formatted_perc = ["{:.0%}".format(p) for p in sorted(percentiles)]
+            stats_names = ["count", "mean", "min", *formatted_perc, "max"]
+
+            # If not all columns are timestamp type,
+            # we also need to calculate the `std` for numeric columns
+            if any_timestamp_cols:

Review comment:
   `any_timestamp_cols` is a boolean now, but let me address it. Thanks!
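   For reference, a minimal standalone sketch of the `chain`-of-`map`s flattening described in the comment above. The toy DataFrame, the column names `A`/`B`, and the percentile values are made up for illustration; only `F.percentile_approx` and `itertools.chain` come from the patch itself:

   ```python
   from itertools import chain

   from pyspark.sql import SparkSession, functions as F

   spark = SparkSession.builder.getOrCreate()
   sdf = spark.createDataFrame([(1, 10), (2, 20), (3, 30)], ["A", "B"])

   column_names = ["A", "B"]
   percentiles = [0.25, 0.5, 0.75]

   # Each inner map yields one percentile_approx Column per input column;
   # chain(*...) flattens the per-percentile maps into a single flat list of Columns.
   perc_exprs = chain(
       *[
           map(F.percentile_approx, column_names, [p] * len(column_names))
           for p in percentiles
       ]
   )
   sdf.select(*perc_exprs).show()
   ```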
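   Likewise, a small sketch of how the string-column branch derives `top` and `freq` per column via a groupby/count. The example data and the `name` column are invented; the `groupby(...).count().sort(...).first()` chain mirrors the loop in the patch:

   ```python
   from pyspark.sql import SparkSession

   spark = SparkSession.builder.getOrCreate()
   sdf = spark.createDataFrame([("a",), ("b",), ("a",), ("a",)], ["name"])

   # Most frequent value ("top") and its frequency ("freq") for one string column.
   top, freq = sdf.groupby("name").count().sort("count", ascending=False).first()
   print(top, freq)  # a 3
   ```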