This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 60ce69df029 [SPARK-37234][PYTHON] Inline type hints for python/pyspark/mllib/stat/_statistics.py
60ce69df029 is described below

commit 60ce69df029b1e1d7cf7f7eece02e668de24cca8
Author: dch nguyen <dchvn....@gmail.com>
AuthorDate: Sun Apr 10 14:14:33 2022 +0200

    [SPARK-37234][PYTHON] Inline type hints for python/pyspark/mllib/stat/_statistics.py

    ### What changes were proposed in this pull request?
    Inline type hints for python/pyspark/mllib/stat/_statistics.py

    ### Why are the changes needed?
    We can take advantage of static type checking within the functions by inlining the type hints.

    ### Does this PR introduce _any_ user-facing change?
    No

    ### How was this patch tested?
    Existing tests

    Closes #34513 from dchvn/SPARK-37234.

    Lead-authored-by: dch nguyen <dchvn....@gmail.com>
    Co-authored-by: dch nguyen <dgd_contribu...@viettel.com.vn>
    Signed-off-by: zero323 <mszymkiew...@gmail.com>
    (cherry picked from commit c3dcdb118ca403a8fbefc3308a116d9e12a1f038)
    Signed-off-by: zero323 <mszymkiew...@gmail.com>
---
 python/pyspark/mllib/_typing.pyi          |  5 ++
 python/pyspark/mllib/stat/_statistics.py  | 94 +++++++++++++++++++++++--------
 python/pyspark/mllib/stat/_statistics.pyi | 63 ---------------------
 3 files changed, 74 insertions(+), 88 deletions(-)

diff --git a/python/pyspark/mllib/_typing.pyi b/python/pyspark/mllib/_typing.pyi
index 6a1a0f53a59..4fbaeca39be 100644
--- a/python/pyspark/mllib/_typing.pyi
+++ b/python/pyspark/mllib/_typing.pyi
@@ -17,7 +17,9 @@
 # under the License.
from typing import List, Tuple, TypeVar, Union + from typing_extensions import Literal + from pyspark.mllib.linalg import Vector from numpy import ndarray # noqa: F401 from py4j.java_gateway import JavaObject @@ -25,4 +27,7 @@ from py4j.java_gateway import JavaObject VectorLike = Union[ndarray, Vector, List[float], Tuple[float, ...]] C = TypeVar("C", bound=type) JavaObjectOrPickleDump = Union[JavaObject, bytearray, bytes] + +CorrelationMethod = Union[Literal["spearman"], Literal["pearson"]] +DistName = Literal["norm"] NormType = Union[None, float, Literal["fro"], Literal["nuc"]] diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py index 34a373d5358..25095d99dd9 100644 --- a/python/pyspark/mllib/stat/_statistics.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -16,13 +16,19 @@ # import sys +from typing import cast, overload, List, Optional, TYPE_CHECKING, Union + +from numpy import ndarray +from py4j.java_gateway import JavaObject from pyspark.rdd import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import Matrix, _convert_to_vector +from pyspark.mllib.linalg import Matrix, Vector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.stat.test import ChiSqTestResult, KolmogorovSmirnovTestResult +if TYPE_CHECKING: + from pyspark.mllib._typing import CorrelationMethod, DistName __all__ = ["MultivariateStatisticalSummary", "Statistics"] @@ -33,34 +39,34 @@ class MultivariateStatisticalSummary(JavaModelWrapper): Trait for multivariate statistical summary of a data matrix. 
""" - def mean(self): - return self.call("mean").toArray() + def mean(self) -> ndarray: + return cast(JavaObject, self.call("mean")).toArray() - def variance(self): - return self.call("variance").toArray() + def variance(self) -> ndarray: + return cast(JavaObject, self.call("variance")).toArray() - def count(self): + def count(self) -> int: return int(self.call("count")) - def numNonzeros(self): - return self.call("numNonzeros").toArray() + def numNonzeros(self) -> ndarray: + return cast(JavaObject, self.call("numNonzeros")).toArray() - def max(self): - return self.call("max").toArray() + def max(self) -> ndarray: + return cast(JavaObject, self.call("max")).toArray() - def min(self): - return self.call("min").toArray() + def min(self) -> ndarray: + return cast(JavaObject, self.call("min")).toArray() - def normL1(self): - return self.call("normL1").toArray() + def normL1(self) -> ndarray: + return cast(JavaObject, self.call("normL1")).toArray() - def normL2(self): - return self.call("normL2").toArray() + def normL2(self) -> ndarray: + return cast(JavaObject, self.call("normL2")).toArray() class Statistics: @staticmethod - def colStats(rdd): + def colStats(rdd: RDD[Vector]) -> MultivariateStatisticalSummary: """ Computes column-wise summary statistics for the input RDD[Vector]. @@ -98,8 +104,22 @@ class Statistics: cStats = callMLlibFunc("colStats", rdd.map(_convert_to_vector)) return MultivariateStatisticalSummary(cStats) + @overload + @staticmethod + def corr(x: RDD[Vector], *, method: Optional["CorrelationMethod"] = ...) -> Matrix: + ... + + @overload @staticmethod - def corr(x, y=None, method=None): + def corr(x: RDD[float], y: RDD[float], method: Optional["CorrelationMethod"] = ...) -> float: + ... + + @staticmethod + def corr( + x: Union[RDD[Vector], RDD[float]], + y: Optional[RDD[float]] = None, + method: Optional["CorrelationMethod"] = None, + ) -> Union[float, Matrix]: """ Compute the correlation (matrix) for the input RDD(s) using the specified method. 
@@ -168,12 +188,34 @@ class Statistics: raise TypeError("Use 'method=' to specify method name.") if not y: - return callMLlibFunc("corr", x.map(_convert_to_vector), method).toArray() + return cast( + JavaObject, callMLlibFunc("corr", x.map(_convert_to_vector), method) + ).toArray() else: - return callMLlibFunc("corr", x.map(float), y.map(float), method) + return cast( + float, + callMLlibFunc("corr", cast(RDD[float], x).map(float), y.map(float), method), + ) + + @overload + @staticmethod + def chiSqTest(observed: Matrix) -> ChiSqTestResult: + ... + + @overload + @staticmethod + def chiSqTest(observed: Vector, expected: Optional[Vector] = ...) -> ChiSqTestResult: + ... + + @overload + @staticmethod + def chiSqTest(observed: RDD[LabeledPoint]) -> List[ChiSqTestResult]: + ... @staticmethod - def chiSqTest(observed, expected=None): + def chiSqTest( + observed: Union[Matrix, RDD[LabeledPoint], Vector], expected: Optional[Vector] = None + ) -> Union[ChiSqTestResult, List[ChiSqTestResult]]: """ If `observed` is Vector, conduct Pearson's chi-squared goodness of fit test of the observed data against the expected distribution, @@ -270,7 +312,9 @@ class Statistics: return ChiSqTestResult(jmodel) @staticmethod - def kolmogorovSmirnovTest(data, distName="norm", *params): + def kolmogorovSmirnovTest( + data: RDD[float], distName: "DistName" = "norm", *params: float + ) -> KolmogorovSmirnovTestResult: """ Performs the Kolmogorov-Smirnov (KS) test for data sampled from a continuous distribution. It tests the null hypothesis that @@ -334,13 +378,13 @@ class Statistics: if not isinstance(distName, str): raise TypeError("distName should be a string, got %s." 
% type(distName)) - params = [float(param) for param in params] + param_list = [float(param) for param in params] return KolmogorovSmirnovTestResult( - callMLlibFunc("kolmogorovSmirnovTest", data, distName, params) + callMLlibFunc("kolmogorovSmirnovTest", data, distName, param_list) ) -def _test(): +def _test() -> None: import doctest import numpy from pyspark.sql import SparkSession diff --git a/python/pyspark/mllib/stat/_statistics.pyi b/python/pyspark/mllib/stat/_statistics.pyi deleted file mode 100644 index 1bf76dd3af0..00000000000 --- a/python/pyspark/mllib/stat/_statistics.pyi +++ /dev/null @@ -1,63 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -from typing import List, Optional, overload, Union -from typing_extensions import Literal - -from numpy import ndarray - -from pyspark.mllib.common import JavaModelWrapper -from pyspark.mllib.linalg import Vector, Matrix -from pyspark.mllib.regression import LabeledPoint -from pyspark.mllib.stat.test import ChiSqTestResult, KolmogorovSmirnovTestResult -from pyspark.rdd import RDD - -CorrelationMethod = Union[Literal["spearman"], Literal["pearson"]] - -class MultivariateStatisticalSummary(JavaModelWrapper): - def mean(self) -> ndarray: ... 
- def variance(self) -> ndarray: ... - def count(self) -> int: ... - def numNonzeros(self) -> ndarray: ... - def max(self) -> ndarray: ... - def min(self) -> ndarray: ... - def normL1(self) -> ndarray: ... - def normL2(self) -> ndarray: ... - -class Statistics: - @staticmethod - def colStats(rdd: RDD[Vector]) -> MultivariateStatisticalSummary: ... - @overload - @staticmethod - def corr(x: RDD[Vector], *, method: Optional[CorrelationMethod] = ...) -> Matrix: ... - @overload - @staticmethod - def corr(x: RDD[float], y: RDD[float], method: Optional[CorrelationMethod] = ...) -> float: ... - @overload - @staticmethod - def chiSqTest(observed: Matrix) -> ChiSqTestResult: ... - @overload - @staticmethod - def chiSqTest(observed: Vector, expected: Optional[Vector] = ...) -> ChiSqTestResult: ... - @overload - @staticmethod - def chiSqTest(observed: RDD[LabeledPoint]) -> List[ChiSqTestResult]: ... - @staticmethod - def kolmogorovSmirnovTest( - data: RDD[float], distName: Literal["norm"] = ..., *params: float - ) -> KolmogorovSmirnovTestResult: ... --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org