This is an automated email from the ASF dual-hosted git repository.

zero323 pushed a commit to branch branch-3.3
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.3 by this push:
     new 60ce69df029 [SPARK-37234][PYTHON] Inline type hints for python/pyspark/mllib/stat/_statistics.py
60ce69df029 is described below

commit 60ce69df029b1e1d7cf7f7eece02e668de24cca8
Author: dch nguyen <dchvn....@gmail.com>
AuthorDate: Sun Apr 10 14:14:33 2022 +0200

    [SPARK-37234][PYTHON] Inline type hints for python/pyspark/mllib/stat/_statistics.py
    
    ### What changes were proposed in this pull request?
    Inline type hints for python/pyspark/mllib/stat/_statistics.py
    ### Why are the changes needed?
    We can take advantage of static type checking within the functions by inlining the type hints.
    ### Does this PR introduce _any_ user-facing change?
    No
    ### How was this patch tested?
    Existing tests
    
    Closes #34513 from dchvn/SPARK-37234.
    
    Lead-authored-by: dch nguyen <dchvn....@gmail.com>
    Co-authored-by: dch nguyen <dgd_contribu...@viettel.com.vn>
    Signed-off-by: zero323 <mszymkiew...@gmail.com>
    (cherry picked from commit c3dcdb118ca403a8fbefc3308a116d9e12a1f038)
    Signed-off-by: zero323 <mszymkiew...@gmail.com>
---
 python/pyspark/mllib/_typing.pyi          |  5 ++
 python/pyspark/mllib/stat/_statistics.py  | 94 +++++++++++++++++++++++--------
 python/pyspark/mllib/stat/_statistics.pyi | 63 ---------------------
 3 files changed, 74 insertions(+), 88 deletions(-)

diff --git a/python/pyspark/mllib/_typing.pyi b/python/pyspark/mllib/_typing.pyi
index 6a1a0f53a59..4fbaeca39be 100644
--- a/python/pyspark/mllib/_typing.pyi
+++ b/python/pyspark/mllib/_typing.pyi
@@ -17,7 +17,9 @@
 # under the License.
 
 from typing import List, Tuple, TypeVar, Union
+
 from typing_extensions import Literal
+
 from pyspark.mllib.linalg import Vector
 from numpy import ndarray  # noqa: F401
 from py4j.java_gateway import JavaObject
@@ -25,4 +27,7 @@ from py4j.java_gateway import JavaObject
 VectorLike = Union[ndarray, Vector, List[float], Tuple[float, ...]]
 C = TypeVar("C", bound=type)
 JavaObjectOrPickleDump = Union[JavaObject, bytearray, bytes]
+
+CorrelationMethod = Union[Literal["spearman"], Literal["pearson"]]
+DistName = Literal["norm"]
 NormType = Union[None, float, Literal["fro"], Literal["nuc"]]
diff --git a/python/pyspark/mllib/stat/_statistics.py b/python/pyspark/mllib/stat/_statistics.py
index 34a373d5358..25095d99dd9 100644
--- a/python/pyspark/mllib/stat/_statistics.py
+++ b/python/pyspark/mllib/stat/_statistics.py
@@ -16,13 +16,19 @@
 #
 
 import sys
+from typing import cast, overload, List, Optional, TYPE_CHECKING, Union
+
+from numpy import ndarray
+from py4j.java_gateway import JavaObject
 
 from pyspark.rdd import RDD
 from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import Matrix, _convert_to_vector
+from pyspark.mllib.linalg import Matrix, Vector, _convert_to_vector
 from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.stat.test import ChiSqTestResult, KolmogorovSmirnovTestResult
 
+if TYPE_CHECKING:
+    from pyspark.mllib._typing import CorrelationMethod, DistName
 
 __all__ = ["MultivariateStatisticalSummary", "Statistics"]
 
@@ -33,34 +39,34 @@ class MultivariateStatisticalSummary(JavaModelWrapper):
     Trait for multivariate statistical summary of a data matrix.
     """
 
-    def mean(self):
-        return self.call("mean").toArray()
+    def mean(self) -> ndarray:
+        return cast(JavaObject, self.call("mean")).toArray()
 
-    def variance(self):
-        return self.call("variance").toArray()
+    def variance(self) -> ndarray:
+        return cast(JavaObject, self.call("variance")).toArray()
 
-    def count(self):
+    def count(self) -> int:
         return int(self.call("count"))
 
-    def numNonzeros(self):
-        return self.call("numNonzeros").toArray()
+    def numNonzeros(self) -> ndarray:
+        return cast(JavaObject, self.call("numNonzeros")).toArray()
 
-    def max(self):
-        return self.call("max").toArray()
+    def max(self) -> ndarray:
+        return cast(JavaObject, self.call("max")).toArray()
 
-    def min(self):
-        return self.call("min").toArray()
+    def min(self) -> ndarray:
+        return cast(JavaObject, self.call("min")).toArray()
 
-    def normL1(self):
-        return self.call("normL1").toArray()
+    def normL1(self) -> ndarray:
+        return cast(JavaObject, self.call("normL1")).toArray()
 
-    def normL2(self):
-        return self.call("normL2").toArray()
+    def normL2(self) -> ndarray:
+        return cast(JavaObject, self.call("normL2")).toArray()
 
 
 class Statistics:
     @staticmethod
-    def colStats(rdd):
+    def colStats(rdd: RDD[Vector]) -> MultivariateStatisticalSummary:
         """
         Computes column-wise summary statistics for the input RDD[Vector].
 
@@ -98,8 +104,22 @@ class Statistics:
         cStats = callMLlibFunc("colStats", rdd.map(_convert_to_vector))
         return MultivariateStatisticalSummary(cStats)
 
+    @overload
+    @staticmethod
+    def corr(x: RDD[Vector], *, method: Optional["CorrelationMethod"] = ...) -> Matrix:
+        ...
+
+    @overload
     @staticmethod
-    def corr(x, y=None, method=None):
+    def corr(x: RDD[float], y: RDD[float], method: Optional["CorrelationMethod"] = ...) -> float:
+        ...
+
+    @staticmethod
+    def corr(
+        x: Union[RDD[Vector], RDD[float]],
+        y: Optional[RDD[float]] = None,
+        method: Optional["CorrelationMethod"] = None,
+    ) -> Union[float, Matrix]:
         """
         Compute the correlation (matrix) for the input RDD(s) using the
         specified method.
@@ -168,12 +188,34 @@ class Statistics:
             raise TypeError("Use 'method=' to specify method name.")
 
         if not y:
-            return callMLlibFunc("corr", x.map(_convert_to_vector), method).toArray()
+            return cast(
+                JavaObject, callMLlibFunc("corr", x.map(_convert_to_vector), method)
+            ).toArray()
         else:
-            return callMLlibFunc("corr", x.map(float), y.map(float), method)
+            return cast(
+                float,
+                callMLlibFunc("corr", cast(RDD[float], x).map(float), y.map(float), method),
+            )
+
+    @overload
+    @staticmethod
+    def chiSqTest(observed: Matrix) -> ChiSqTestResult:
+        ...
+
+    @overload
+    @staticmethod
+    def chiSqTest(observed: Vector, expected: Optional[Vector] = ...) -> ChiSqTestResult:
+        ...
+
+    @overload
+    @staticmethod
+    def chiSqTest(observed: RDD[LabeledPoint]) -> List[ChiSqTestResult]:
+        ...
 
     @staticmethod
-    def chiSqTest(observed, expected=None):
+    def chiSqTest(
+        observed: Union[Matrix, RDD[LabeledPoint], Vector], expected: Optional[Vector] = None
+    ) -> Union[ChiSqTestResult, List[ChiSqTestResult]]:
         """
         If `observed` is Vector, conduct Pearson's chi-squared goodness
         of fit test of the observed data against the expected distribution,
@@ -270,7 +312,9 @@ class Statistics:
         return ChiSqTestResult(jmodel)
 
     @staticmethod
-    def kolmogorovSmirnovTest(data, distName="norm", *params):
+    def kolmogorovSmirnovTest(
+        data: RDD[float], distName: "DistName" = "norm", *params: float
+    ) -> KolmogorovSmirnovTestResult:
         """
         Performs the Kolmogorov-Smirnov (KS) test for data sampled from
         a continuous distribution. It tests the null hypothesis that
@@ -334,13 +378,13 @@ class Statistics:
         if not isinstance(distName, str):
             raise TypeError("distName should be a string, got %s." % type(distName))
 
-        params = [float(param) for param in params]
+        param_list = [float(param) for param in params]
         return KolmogorovSmirnovTestResult(
-            callMLlibFunc("kolmogorovSmirnovTest", data, distName, params)
+            callMLlibFunc("kolmogorovSmirnovTest", data, distName, param_list)
         )
 
 
-def _test():
+def _test() -> None:
     import doctest
     import numpy
     from pyspark.sql import SparkSession
diff --git a/python/pyspark/mllib/stat/_statistics.pyi b/python/pyspark/mllib/stat/_statistics.pyi
deleted file mode 100644
index 1bf76dd3af0..00000000000
--- a/python/pyspark/mllib/stat/_statistics.pyi
+++ /dev/null
@@ -1,63 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from typing import List, Optional, overload, Union
-from typing_extensions import Literal
-
-from numpy import ndarray
-
-from pyspark.mllib.common import JavaModelWrapper
-from pyspark.mllib.linalg import Vector, Matrix
-from pyspark.mllib.regression import LabeledPoint
-from pyspark.mllib.stat.test import ChiSqTestResult, KolmogorovSmirnovTestResult
-from pyspark.rdd import RDD
-
-CorrelationMethod = Union[Literal["spearman"], Literal["pearson"]]
-
-class MultivariateStatisticalSummary(JavaModelWrapper):
-    def mean(self) -> ndarray: ...
-    def variance(self) -> ndarray: ...
-    def count(self) -> int: ...
-    def numNonzeros(self) -> ndarray: ...
-    def max(self) -> ndarray: ...
-    def min(self) -> ndarray: ...
-    def normL1(self) -> ndarray: ...
-    def normL2(self) -> ndarray: ...
-
-class Statistics:
-    @staticmethod
-    def colStats(rdd: RDD[Vector]) -> MultivariateStatisticalSummary: ...
-    @overload
-    @staticmethod
-    def corr(x: RDD[Vector], *, method: Optional[CorrelationMethod] = ...) -> Matrix: ...
-    @overload
-    @staticmethod
-    def corr(x: RDD[float], y: RDD[float], method: Optional[CorrelationMethod] = ...) -> float: ...
-    @overload
-    @staticmethod
-    def chiSqTest(observed: Matrix) -> ChiSqTestResult: ...
-    @overload
-    @staticmethod
-    def chiSqTest(observed: Vector, expected: Optional[Vector] = ...) -> ChiSqTestResult: ...
-    @overload
-    @staticmethod
-    def chiSqTest(observed: RDD[LabeledPoint]) -> List[ChiSqTestResult]: ...
-    @staticmethod
-    def kolmogorovSmirnovTest(
-        data: RDD[float], distName: Literal["norm"] = ..., *params: float
-    ) -> KolmogorovSmirnovTestResult: ...


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to