wbo4958 commented on code in PR #49547:
URL: https://github.com/apache/spark/pull/49547#discussion_r1920289915


##########
python/pyspark/ml/tests/test_evaluation.py:
##########
@@ -14,18 +14,368 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import tempfile
 import unittest
 
 import numpy as np
 
-from pyspark.ml.evaluation import ClusteringEvaluator, RegressionEvaluator
+from pyspark.ml.evaluation import (
+    ClusteringEvaluator,
+    RegressionEvaluator,
+    BinaryClassificationEvaluator,
+    MulticlassClassificationEvaluator,
+    MultilabelClassificationEvaluator,
+    RankingEvaluator,
+)
 from pyspark.ml.linalg import Vectors
-from pyspark.sql import Row
-from pyspark.testing.mlutils import SparkSessionTestCase
+from pyspark.sql import Row, SparkSession
+
+
+class EvaluatorTestsMixin:
+    def test_ranking_evaluator(self):
+        scoreAndLabels = [
+            ([1.0, 6.0, 2.0, 7.0, 8.0, 3.0, 9.0, 10.0, 4.0, 5.0], [1.0, 2.0, 3.0, 4.0, 5.0]),
+            ([4.0, 1.0, 5.0, 6.0, 2.0, 7.0, 3.0, 8.0, 9.0, 10.0], [1.0, 2.0, 3.0]),
+            ([1.0, 2.0, 3.0, 4.0, 5.0], []),
+        ]
+        dataset = self.spark.createDataFrame(scoreAndLabels, ["prediction", "label"])
+
+        # Initialize RankingEvaluator
+        evaluator = RankingEvaluator().setPredictionCol("prediction")
+
+        # Evaluate the dataset using the default metric (mean average precision)
+        mean_average_precision = evaluator.evaluate(dataset)
+        self.assertTrue(np.allclose(mean_average_precision, 0.3550, atol=1e-4))
+
+        # Evaluate the dataset using precisionAtK for k=2
+        precision_at_k = evaluator.evaluate(
+            dataset, {evaluator.metricName: "precisionAtK", evaluator.k: 2}
+        )
+        self.assertTrue(np.allclose(precision_at_k, 0.3333, atol=1e-4))
+
+        # read/write
+        with tempfile.TemporaryDirectory(prefix="save") as tmp_dir:
+            # Save the evaluator
+            ranke_path = tmp_dir + "/ranke"

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to