This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch branch-3.2
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.2 by this push:
     new 2b54b48cd85 [SPARK-40079] Add Imputer inputCols validation for empty input case
2b54b48cd85 is described below

commit 2b54b48cd852f93e8cf24397df6f3ec5b755233e
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Mon Aug 15 18:03:08 2022 +0800

    [SPARK-40079] Add Imputer inputCols validation for empty input case
    
    Signed-off-by: Weichen Xu <weichen.xu@databricks.com>
    
    ### What changes were proposed in this pull request?
    Add Imputer inputCols validation for empty input case
    
    ### Why are the changes needed?
    If Imputer inputCols is empty, the `fit` works fine but when saving model, error will be raised:
    
    >
    AnalysisException:
    Datasource does not support writing empty or nested empty schemas.
    Please make sure the data schema has at least one or more column(s).
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Unit test.
    
    Closes #37518 from WeichenXu123/imputer-param-validation.
    
    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
    (cherry picked from commit 87094f89655b7df09cdecb47c653461ae855b0ac)
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
---
 mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala |  1 +
 .../test/scala/org/apache/spark/ml/feature/ImputerSuite.scala  | 10 ++++++++++
 2 files changed, 11 insertions(+)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
index 71403acc91b..5998887923f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala
@@ -81,6 +81,7 @@ private[feature] trait ImputerParams extends Params with HasInputCol with HasInputCols
   protected def validateAndTransformSchema(schema: StructType): StructType = {
     ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), Seq(outputCols))
     val (inputColNames, outputColNames) = getInOutCols()
+    require(inputColNames.length > 0, "inputCols cannot be empty")
     require(inputColNames.length == inputColNames.distinct.length, s"inputCols contains" +
       s" duplicates: (${inputColNames.mkString(", ")})")
     require(outputColNames.length == outputColNames.distinct.length, s"outputCols contains" +
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
index 30887f55638..5ef22a282c3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala
@@ -268,6 +268,16 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest {
         }
         assert(e.getMessage.contains("outputCols contains duplicates"))
       }
+
+      withClue("Imputer should fail if inputCols param is empty.") {
+        val e: IllegalArgumentException = intercept[IllegalArgumentException] {
+          val imputer = new Imputer().setStrategy(strategy)
+            .setInputCols(Array[String]())
+            .setOutputCols(Array[String]())
+          val model = imputer.fit(df)
+        }
+        assert(e.getMessage.contains("inputCols cannot be empty"))
+      }
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to