viirya commented on a change in pull request #26247: [SPARK-29566][ML] Imputer should support single-column input/output
URL: https://github.com/apache/spark/pull/26247#discussion_r338879058
 
 

 ##########
 File path: mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
 ##########
 @@ -228,34 +453,42 @@ object ImputerSuite {
    * Imputation strategy. Available options are ["mean", "median"].
    * @param df DataFrame with columns "id", "value", "expected_mean", 
"expected_median"
    */
-  def iterateStrategyTest(imputer: Imputer, df: DataFrame): Unit = {
+  def iterateStrategyTest(isMultiCol: Boolean, imputer: Imputer, df: 
DataFrame): Unit = {
     Seq("mean", "median").foreach { strategy =>
       imputer.setStrategy(strategy)
       val model = imputer.fit(df)
       val resultDF = model.transform(df)
-      imputer.getInputCols.zip(imputer.getOutputCols).foreach { case 
(inputCol, outputCol) =>
-
-        // check dataType is consistent between input and output
-        val inputType = resultDF.schema(inputCol).dataType
-        val outputType = resultDF.schema(outputCol).dataType
-        assert(inputType == outputType, "Output type is not the same as input 
type.")
-
-        // check value
-        resultDF.select(s"expected_${strategy}_$inputCol", 
outputCol).collect().foreach {
-          case Row(exp: Float, out: Float) =>
-            assert((exp.isNaN && out.isNaN) || (exp == out),
-              s"Imputed values differ. Expected: $exp, actual: $out")
-          case Row(exp: Double, out: Double) =>
-            assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5),
-              s"Imputed values differ. Expected: $exp, actual: $out")
-          case Row(exp: Integer, out: Integer) =>
-            assert(exp == out,
-              s"Imputed values differ. Expected: $exp, actual: $out")
-          case Row(exp: Long, out: Long) =>
-            assert(exp == out,
-              s"Imputed values differ. Expected: $exp, actual: $out")
+      if (isMultiCol) {
+        imputer.getInputCols.zip(imputer.getOutputCols).foreach { case 
(inputCol, outputCol) =>
+          verifyTransformResult(strategy, inputCol, outputCol, resultDF)
         }
+      } else {
+          verifyTransformResult(strategy, imputer.getInputCol, 
imputer.getOutputCol, resultDF)
       }
     }
   }
+
+  def verifyTransformResult(strategy: String, inputCol: String, outputCol: 
String,
+                            resultDF: DataFrame): Unit = {
 
 Review comment:
  The indentation of the wrapped parameter list here seems unusual for this codebase.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to