Github user MrBago commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20695#discussion_r175971741
  
    --- Diff: python/pyspark/ml/stat.py ---
    @@ -132,6 +134,172 @@ def corr(dataset, column, method="pearson"):
             return _java2py(sc, javaCorrObj.corr(*args))
     
     
    +class Summarizer(object):
    +    """
    +    .. note:: Experimental
    +
    +    Tools for vectorized statistics on MLlib Vectors.
    +    The methods in this package provide various statistics for Vectors 
contained inside DataFrames.
    +    This class lets users pick the statistics they would like to extract 
for a given column.
    +
    +    >>> from pyspark.ml.stat import Summarizer
    +    >>> from pyspark.sql import Row
    +    >>> from pyspark.ml.linalg import Vectors
    +    >>> summarizer = Summarizer.metrics("mean", "count")
    +    >>> df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 
1.0, 1.0)),
    +    ...                      Row(weight=0.0, features=Vectors.dense(1.0, 
2.0, 3.0))]).toDF()
    +    >>> df.select(summarizer.summary(df.features, 
df.weight)).show(truncate=False)
    +    +-----------------------------------+
    +    |aggregate_metrics(features, weight)|
    +    +-----------------------------------+
    +    |[[1.0,1.0,1.0], 1]                 |
    +    +-----------------------------------+
    +    <BLANKLINE>
    +    >>> df.select(summarizer.summary(df.features)).show(truncate=False)
    +    +--------------------------------+
    +    |aggregate_metrics(features, 1.0)|
    +    +--------------------------------+
    +    |[[1.0,1.5,2.0], 2]              |
    +    +--------------------------------+
    +    <BLANKLINE>
    +    >>> df.select(Summarizer.mean(df.features, 
df.weight)).show(truncate=False)
    +    +--------------+
    +    |mean(features)|
    +    +--------------+
    +    |[1.0,1.0,1.0] |
    +    +--------------+
    +    <BLANKLINE>
    +    >>> df.select(Summarizer.mean(df.features)).show(truncate=False)
    +    +--------------+
    +    |mean(features)|
    +    +--------------+
    +    |[1.0,1.5,2.0] |
    +    +--------------+
    +    <BLANKLINE>
    +
    +    .. versionadded:: 2.4.0
    +
    +    """
    +    def __init__(self, js):
    +        self._js = js
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def mean(col, weightCol=None):
    +        """
    +        return a column of mean summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "mean")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def variance(col, weightCol=None):
    +        """
    +        return a column of variance summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "variance")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def count(col, weightCol=None):
    +        """
    +        return a column of count summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "count")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def numNonZeros(col, weightCol=None):
    +        """
    +        return a column of numNonZero summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "numNonZeros")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def max(col, weightCol=None):
    +        """
    +        return a column of max summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "max")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def min(col, weightCol=None):
    +        """
    +        return a column of min summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "min")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def normL1(col, weightCol=None):
    +        """
    +        return a column of normL1 summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "normL1")
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def normL2(col, weightCol=None):
    +        """
    +        return a column of normL2 summary
    +        """
    +        return Summarizer._get_single_metric(col, weightCol, "normL2")
    +
    +    @staticmethod
    +    def _check_param(featureCol, weightCol):
    +        if weightCol is None:
    +            weightCol = lit(1.0)
    +        if not isinstance(featureCol, Column) or not isinstance(weightCol, 
Column):
    +            raise TypeError("featureCol and weightCol should be a Column")
    +        return featureCol, weightCol
    +
    +    @staticmethod
    +    def _get_single_metric(col, weightCol, metric):
    +        col, weightCol = Summarizer._check_param(col, weightCol)
    +        return 
Column(JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer." + 
metric,
    +                                                col._jc, weightCol._jc))
    +
    +    @staticmethod
    +    @since("2.4.0")
    +    def metrics(*metrics):
    +        """
    +        Given a list of metrics, provides a builder that it turns computes 
metrics from a column.
    +
    +        See the documentation of [[Summarizer]] for an example.
    +
    +        The following metrics are accepted (case sensitive):
    +         - mean: a vector that contains the coefficient-wise mean.
    +         - variance: a vector tha contains the coefficient-wise variance.
    +         - count: the count of all vectors seen.
    +         - numNonzeros: a vector with the number of non-zeros for each 
coefficients
    +         - max: the maximum for each coefficient.
    +         - min: the minimum for each coefficient.
    +         - normL2: the Euclidian norm for each coefficient.
    +         - normL1: the L1 norm of each coefficient (sum of the absolute 
values).
    +
    +        :param metrics metrics that can be provided.
    +        :return a Summarizer
    +
    +        Note: Currently, the performance of this interface is about 2x~3x 
slower then using the RDD
    +        interface.
    +        """
    +        sc = SparkContext._active_spark_context
    +        js = 
JavaWrapper._new_java_obj("org.apache.spark.ml.stat.Summarizer.metrics",
    +                                       _to_seq(sc, metrics))
    +        return Summarizer(js)
    +
    +    @since("2.4.0")
    +    def summary(self, featureCol, weightCol=None):
    --- End diff --
    
    We might want to move the "summary" method into another class, and have 
Summarizer only contain static methods. That will help with autocomplete so that 
it's clear that you're not meant to do `Summarizer.metrics("min").mean(features)`.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to