This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3310789  [SPARK-11215][ML] Add multiple columns support to 
StringIndexer
3310789 is described below

commit 33107897ada29d1ed17f091f93260dfcef11c2e7
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Tue Jan 29 09:21:25 2019 -0600

    [SPARK-11215][ML] Add multiple columns support to StringIndexer
    
    ## What changes were proposed in this pull request?
    
    This takes over #19621 to add multi-column support to StringIndexer:
    
    1. Supports encoding multiple columns.
    2. Previously, when `frequencyDesc` or `frequencyAsc` was specified as the `stringOrderType` param of `StringIndexer`, the order of strings with equal frequency was undefined. After this change, strings with equal frequency are further sorted alphabetically.
    
    ## How was this patch tested?
    
    Added tests.
    
    Closes #20146 from viirya/SPARK-11215.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Sean Owen <sean.o...@databricks.com>
---
 R/pkg/tests/fulltests/test_mllib_classification.R  |   6 +-
 R/pkg/tests/fulltests/test_mllib_regression.R      |  42 ++-
 docs/ml-features.md                                |   6 +-
 docs/ml-guide.md                                   |   9 +
 .../apache/spark/ml/feature/StringIndexer.scala    | 409 +++++++++++++++++----
 ...2980-4c42-b8a7-a5a94265c479-c000.snappy.parquet | Bin 0 -> 478 bytes
 .../test-data/strIndexerModel/metadata/part-00000  |   1 +
 .../apache/spark/ml/feature/RFormulaSuite.scala    |  28 +-
 .../spark/ml/feature/StringIndexerSuite.scala      | 139 ++++++-
 project/MimaExcludes.scala                         |   4 +
 10 files changed, 531 insertions(+), 113 deletions(-)

diff --git a/R/pkg/tests/fulltests/test_mllib_classification.R 
b/R/pkg/tests/fulltests/test_mllib_classification.R
index 023686e..9fdb0cf 100644
--- a/R/pkg/tests/fulltests/test_mllib_classification.R
+++ b/R/pkg/tests/fulltests/test_mllib_classification.R
@@ -313,7 +313,7 @@ test_that("spark.mlp", {
   # Test predict method
   mlpTestDF <- df
   mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
-  expect_equal(head(mlpPredictions$prediction, 6), c("1.0", "0.0", "0.0", 
"0.0", "0.0", "0.0"))
+  expect_equal(head(mlpPredictions$prediction, 6), c("0.0", "1.0", "1.0", 
"1.0", "1.0", "1.0"))
 
   # Test model save/load
   if (windows_with_hadoop()) {
@@ -348,12 +348,12 @@ test_that("spark.mlp", {
 
   # Test random seed
   # default seed
-  model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 
10)
+  model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 
100)
   mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
   expect_equal(head(mlpPredictions$prediction, 10),
                c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", 
"1.0", "0.0"))
   # seed equals 10
-  model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 
10, seed = 10)
+  model <- spark.mlp(df, label ~ features, layers = c(4, 5, 4, 3), maxIter = 
100, seed = 10)
   mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
   expect_equal(head(mlpPredictions$prediction, 10),
                c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", 
"1.0", "0.0"))
diff --git a/R/pkg/tests/fulltests/test_mllib_regression.R 
b/R/pkg/tests/fulltests/test_mllib_regression.R
index 23daca7..b40c4cb 100644
--- a/R/pkg/tests/fulltests/test_mllib_regression.R
+++ b/R/pkg/tests/fulltests/test_mllib_regression.R
@@ -102,10 +102,18 @@ test_that("spark.glm and predict", {
 })
 
 test_that("spark.glm summary", {
+  # prepare dataset
+  Sepal.Length <- c(2.0, 1.5, 1.8, 3.4, 5.1, 1.8, 1.0, 2.3)
+  Sepal.Width <- c(2.1, 2.3, 5.4, 4.7, 3.1, 2.1, 3.1, 5.5)
+  Petal.Length <- c(1.8, 2.1, 7.1, 2.5, 3.7, 6.3, 2.2, 7.2)
+  Species <- c("setosa", "versicolor", "versicolor", "versicolor", 
"virginica", "virginica",
+               "versicolor", "virginica")
+  dataset <- data.frame(Sepal.Length, Sepal.Width, Petal.Length, Species, 
stringsAsFactors = TRUE)
+
   # gaussian family
-  training <- suppressWarnings(createDataFrame(iris))
+  training <- suppressWarnings(createDataFrame(dataset))
   stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species))
-  rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))
+  rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = dataset))
 
   # test summary coefficients return matrix type
   expect_true(class(stats$coefficients) == "matrix")
@@ -126,15 +134,15 @@ test_that("spark.glm summary", {
 
   out <- capture.output(print(stats))
   expect_match(out[2], "Deviance Residuals:")
-  expect_true(any(grepl("AIC: 59.22", out)))
+  expect_true(any(grepl("AIC: 35.84", out)))
 
   # binomial family
-  df <- suppressWarnings(createDataFrame(iris))
+  df <- suppressWarnings(createDataFrame(dataset))
   training <- df[df$Species %in% c("versicolor", "virginica"), ]
   stats <- summary(spark.glm(training, Species ~ Sepal_Length + Sepal_Width,
                              family = binomial(link = "logit")))
 
-  rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
+  rTraining <- dataset[dataset$Species %in% c("versicolor", "virginica"), ]
   rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
                         family = binomial(link = "logit")))
 
@@ -174,17 +182,17 @@ test_that("spark.glm summary", {
   expect_equal(stats$aic, rStats$aic)
 
   # Test spark.glm works with offset
-  training <- suppressWarnings(createDataFrame(iris))
+  training <- suppressWarnings(createDataFrame(dataset))
   stats <- summary(spark.glm(training, Sepal_Width ~ Sepal_Length + Species,
                              family = poisson(), offsetCol = "Petal_Length"))
   rStats <- suppressWarnings(summary(glm(Sepal.Width ~ Sepal.Length + Species,
-                        data = iris, family = poisson(), offset = 
iris$Petal.Length)))
+                        data = dataset, family = poisson(), offset = 
dataset$Petal.Length)))
   expect_true(all(abs(rStats$coefficients - stats$coefficients) < 1e-3))
 
   # Test summary works on base GLM models
-  baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris)
+  baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = dataset)
   baseSummary <- summary(baseModel)
-  expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
+  expect_true(abs(baseSummary$deviance - 11.84013) < 1e-4)
 
   # Test spark.glm works with regularization parameter
   data <- as.data.frame(cbind(a1, a2, b))
@@ -300,11 +308,19 @@ test_that("glm and predict", {
 })
 
 test_that("glm summary", {
+  # prepare dataset
+  Sepal.Length <- c(2.0, 1.5, 1.8, 3.4, 5.1, 1.8, 1.0, 2.3)
+  Sepal.Width <- c(2.1, 2.3, 5.4, 4.7, 3.1, 2.1, 3.1, 5.5)
+  Petal.Length <- c(1.8, 2.1, 7.1, 2.5, 3.7, 6.3, 2.2, 7.2)
+  Species <- c("setosa", "versicolor", "versicolor", "versicolor", 
"virginica", "virginica",
+               "versicolor", "virginica")
+  dataset <- data.frame(Sepal.Length, Sepal.Width, Petal.Length, Species, 
stringsAsFactors = TRUE)
+
   # gaussian family
-  training <- suppressWarnings(createDataFrame(iris))
+  training <- suppressWarnings(createDataFrame(dataset))
   stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training))
 
-  rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris))
+  rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = dataset))
 
   coefs <- stats$coefficients
   rCoefs <- rStats$coefficients
@@ -320,12 +336,12 @@ test_that("glm summary", {
   expect_equal(stats$aic, rStats$aic)
 
   # binomial family
-  df <- suppressWarnings(createDataFrame(iris))
+  df <- suppressWarnings(createDataFrame(dataset))
   training <- df[df$Species %in% c("versicolor", "virginica"), ]
   stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training,
                        family = binomial(link = "logit")))
 
-  rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ]
+  rTraining <- dataset[dataset$Species %in% c("versicolor", "virginica"), ]
   rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining,
                         family = binomial(link = "logit")))
 
diff --git a/docs/ml-features.md b/docs/ml-features.md
index a140bc6..33373e0 100644
--- a/docs/ml-features.md
+++ b/docs/ml-features.md
@@ -585,11 +585,13 @@ for more details on the API.
 ## StringIndexer
 
 `StringIndexer` encodes a string column of labels to a column of label indices.
-The indices are in `[0, numLabels)`, and four ordering options are supported:
+`StringIndexer` can encode multiple columns. The indices are in `[0, 
numLabels)`, and four ordering options are supported:
 "frequencyDesc": descending order by label frequency (most frequent label 
assigned 0),
 "frequencyAsc": ascending order by label frequency (least frequent label 
assigned 0),
 "alphabetDesc": descending alphabetical order, and "alphabetAsc": ascending 
alphabetical order 
-(default = "frequencyDesc").
+(default = "frequencyDesc"). Note that under "frequencyDesc" or "frequencyAsc", strings with
+equal frequency are further sorted alphabetically.
+
 The unseen labels will be put at index numLabels if user chooses to keep them.
 If the input column is numeric, we cast it to string and index the string
 values. When downstream pipeline components such as `Estimator` or
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index 57d4e1f..cffe419 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -110,6 +110,15 @@ and the migration guide below will explain all changes 
between releases.
 
 * `OneHotEncoder` which is deprecated in 2.3, is removed in 3.0 and 
`OneHotEncoderEstimator` is now renamed to `OneHotEncoder`.
 
+### Changes of behavior
+
+* [SPARK-11215](https://issues.apache.org/jira/browse/SPARK-11215):
+ In Spark 2.4 and earlier versions, when `frequencyDesc` or `frequencyAsc` is specified as the
+ `stringOrderType` param of `StringIndexer`, the order of strings with equal frequency is
+ undefined. Since Spark 3.0, strings with equal frequency are further sorted alphabetically.
+ In addition, since Spark 3.0, `StringIndexer` supports encoding multiple
+ columns.
+
 ## From 2.2 to 2.3
 
 ### Breaking changes
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index a833d8b..f2e6012 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -26,18 +26,22 @@ import org.apache.spark.annotation.Since
 import org.apache.spark.ml.{Estimator, Model, Transformer}
 import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, 
HasOutputCol}
+import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{Column, DataFrame, Dataset, Encoder, Encoders, 
Row}
+import org.apache.spark.sql.catalyst.expressions.{If, Literal}
+import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.VersionUtils.majorMinorVersion
 import org.apache.spark.util.collection.OpenHashMap
 
 /**
  * Base trait for [[StringIndexer]] and [[StringIndexerModel]].
  */
 private[feature] trait StringIndexerBase extends Params with HasHandleInvalid 
with HasInputCol
-  with HasOutputCol {
+  with HasOutputCol with HasInputCols with HasOutputCols {
 
   /**
    * Param for how to handle invalid data (unseen labels or NULL values).
@@ -66,6 +70,9 @@ private[feature] trait StringIndexerBase extends Params with 
HasHandleInvalid wi
    *   - 'alphabetAsc': ascending alphabetical order
    * Default is 'frequencyDesc'.
    *
+   * Note: In case of equal frequency when under frequencyDesc/Asc, the 
strings are further sorted
+   *       alphabetically.
+   *
    * @group param
    */
   @Since("2.3.0")
@@ -79,26 +86,56 @@ private[feature] trait StringIndexerBase extends Params 
with HasHandleInvalid wi
   @Since("2.3.0")
   def getStringOrderType: String = $(stringOrderType)
 
-  /** Validates and transforms the input schema. */
-  protected def validateAndTransformSchema(schema: StructType): StructType = {
-    val inputColName = $(inputCol)
+  /** Returns the input and output column names corresponding in pair. */
+  private[feature] def getInOutCols(): (Array[String], Array[String]) = {
+    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), 
Seq(outputCols))
+
+    if (isSet(inputCol)) {
+      (Array($(inputCol)), Array($(outputCol)))
+    } else {
+      require($(inputCols).length == $(outputCols).length,
+        "The number of input columns does not match output columns")
+      ($(inputCols), $(outputCols))
+    }
+  }
+
+  private def validateAndTransformField(
+      schema: StructType,
+      inputColName: String,
+      outputColName: String): StructField = {
     val inputDataType = schema(inputColName).dataType
     require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
       s"The input column $inputColName must be either string type or numeric type, " +
         s"but got $inputDataType.")
-    val inputFields = schema.fields
-    val outputColName = $(outputCol)
-    require(inputFields.forall(_.name != outputColName),
+    require(schema.fields.forall(_.name != outputColName),
       s"Output column $outputColName already exists.")
-    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
-    val outputFields = inputFields :+ attr.toStructField()
-    StructType(outputFields)
+    NominalAttribute.defaultAttr.withName(outputColName).toStructField()
+  }
+
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(
+      schema: StructType,
+      skipNonExistsCol: Boolean = false): StructType = {
+    val (inputColNames, outputColNames) = getInOutCols()
+
+    require(outputColNames.distinct.length == outputColNames.length,
+      s"Output columns should not be duplicate.")
+
+    val outputFields = inputColNames.zip(outputColNames).flatMap {
+      case (inputColName, outputColName) =>
+        schema.fieldNames.contains(inputColName) match {
+          case true => Some(validateAndTransformField(schema, inputColName, 
outputColName))
+          case false if skipNonExistsCol => None
+          case _ => throw new SparkException(s"Input column $inputColName does 
not exist.")
+        }
+    }
+    StructType(schema.fields ++ outputFields)
   }
 }
 
 /**
- * A label indexer that maps a string column of labels to an ML column of 
label indices.
- * If the input column is numeric, we cast it to string and index the string 
values.
+ * A label indexer that maps string column(s) of labels to ML column(s) of 
label indices.
+ * If the input columns are numeric, we cast them to string and index the 
string values.
  * The indices are in [0, numLabels). By default, this is ordered by label 
frequencies
  * so the most frequent label gets index 0. The ordering behavior is 
controlled by
  * setting `stringOrderType`.
@@ -130,21 +167,86 @@ class StringIndexer @Since("1.4.0") (
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /** @group setParam */
+  @Since("3.0.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
+
+  /** @group setParam */
+  @Since("3.0.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  private def countByValue(
+      dataset: Dataset[_],
+      inputCols: Array[String]): Array[OpenHashMap[String, Long]] = {
+
+    val aggregator = new StringIndexerAggregator(inputCols.length)
+    implicit val encoder = Encoders.kryo[Array[OpenHashMap[String, Long]]]
+
+    val selectedCols = inputCols.map { colName =>
+      val col = dataset.col(colName)
+      if (col.expr.dataType == StringType) {
+        col
+      } else {
+        // We don't count for NaN values. Because `StringIndexerAggregator` 
only processes strings,
+        // we replace NaNs with null in advance.
+        new Column(If(col.isNaN.expr, Literal(null), 
col.expr)).cast(StringType)
+      }
+    }
+
+    dataset.select(selectedCols: _*)
+      .toDF
+      .groupBy().agg(aggregator.toColumn)
+      .as[Array[OpenHashMap[String, Long]]]
+      .collect()(0)
+  }
+
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): StringIndexerModel = {
     transformSchema(dataset.schema, logging = true)
-    val values = dataset.na.drop(Array($(inputCol)))
-      .select(col($(inputCol)).cast(StringType))
-      .rdd.map(_.getString(0))
-    val labels = $(stringOrderType) match {
-      case StringIndexer.frequencyDesc => 
values.countByValue().toSeq.sortBy(-_._2)
-        .map(_._1).toArray
-      case StringIndexer.frequencyAsc => 
values.countByValue().toSeq.sortBy(_._2)
-        .map(_._1).toArray
-      case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > 
_)
-      case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _)
-    }
-    copyValues(new StringIndexerModel(uid, labels).setParent(this))
+
+    val (inputCols, _) = getInOutCols()
+
+    // If input dataset is not originally cached, we need to unpersist it
+    // once we persist it later.
+    val needUnpersist = dataset.storageLevel == StorageLevel.NONE
+
+    // In case of equal frequency when frequencyDesc/Asc, the strings are 
further sorted
+    // alphabetically.
+    val labelsArray = $(stringOrderType) match {
+      case StringIndexer.frequencyDesc =>
+        val sortFunc = StringIndexer.getSortFunc(ascending = false)
+        countByValue(dataset, inputCols).map { counts =>
+          counts.toSeq.sortWith(sortFunc).map(_._1).toArray
+        }
+      case StringIndexer.frequencyAsc =>
+        val sortFunc = StringIndexer.getSortFunc(ascending = true)
+        countByValue(dataset, inputCols).map { counts =>
+          counts.toSeq.sortWith(sortFunc).map(_._1).toArray
+        }
+      case StringIndexer.alphabetDesc =>
+        import dataset.sparkSession.implicits._
+        dataset.persist()
+        val labels = inputCols.map { inputCol =>
+          
dataset.select(inputCol).na.drop().distinct().sort(dataset(s"$inputCol").desc)
+            .as[String].collect()
+        }
+        if (needUnpersist) {
+          dataset.unpersist()
+        }
+        labels
+      case StringIndexer.alphabetAsc =>
+        import dataset.sparkSession.implicits._
+        dataset.persist()
+        val labels = inputCols.map { inputCol =>
+          
dataset.select(inputCol).na.drop().distinct().sort(dataset(s"$inputCol").asc)
+            .as[String].collect()
+        }
+        if (needUnpersist) {
+          dataset.unpersist()
+        }
+        labels
+     }
+    copyValues(new StringIndexerModel(uid, labelsArray).setParent(this))
   }
 
   @Since("1.4.0")
@@ -172,37 +274,76 @@ object StringIndexer extends 
DefaultParamsReadable[StringIndexer] {
 
   @Since("1.6.0")
   override def load(path: String): StringIndexer = super.load(path)
+
+  // Returns a function used to sort strings by frequency (ascending or 
descending).
+  // In case of equal frequency, it sorts strings by alphabet (ascending).
+  private[feature] def getSortFunc(
+      ascending: Boolean): ((String, Long), (String, Long)) => Boolean = {
+    if (ascending) {
+      case ((strA: String, freqA: Long), (strB: String, freqB: Long)) =>
+        if (freqA == freqB) {
+          strA < strB
+        } else {
+          freqA < freqB
+        }
+    } else {
+      case ((strA: String, freqA: Long), (strB: String, freqB: Long)) =>
+        if (freqA == freqB) {
+          strA < strB
+        } else {
+          freqA > freqB
+        }
+    }
+  }
 }
 
 /**
  * Model fitted by [[StringIndexer]].
  *
- * @param labels  Ordered list of labels, corresponding to indices to be 
assigned.
+ * @param labelsArray Array of ordered list of labels, corresponding to 
indices to be assigned
+ *                    for each input column.
  *
- * @note During transformation, if the input column does not exist,
- * `StringIndexerModel.transform` would return the input dataset unmodified.
+ * @note During transformation, if any input column does not exist,
+ * `StringIndexerModel.transform` would skip the input column.
+ * If all input columns do not exist, it returns the input dataset unmodified.
  * This is a temporary fix for the case when target labels do not exist during 
prediction.
  */
 @Since("1.4.0")
 class StringIndexerModel (
     @Since("1.4.0") override val uid: String,
-    @Since("1.5.0") val labels: Array[String])
+    @Since("3.0.0") val labelsArray: Array[Array[String]])
   extends Model[StringIndexerModel] with StringIndexerBase with MLWritable {
 
   import StringIndexerModel._
 
   @Since("1.5.0")
-  def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), 
labels)
-
-  private val labelToIndex: OpenHashMap[String, Double] = {
-    val n = labels.length
-    val map = new OpenHashMap[String, Double](n)
-    var i = 0
-    while (i < n) {
-      map.update(labels(i), i)
-      i += 1
+  def this(uid: String, labels: Array[String]) = this(uid, Array(labels))
+
+  @Since("1.5.0")
+  def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), 
Array(labels))
+
+  @Since("3.0.0")
+  def this(labelsArray: Array[Array[String]]) = 
this(Identifiable.randomUID("strIdx"), labelsArray)
+
+  @deprecated("`labels` is deprecated and will be removed in 3.1.0. Use 
`labelsArray` " +
+    "instead.", "3.0.0")
+  @Since("1.5.0")
+  def labels: Array[String] = {
+    require(labelsArray.length == 1, "This StringIndexerModel is fit on 
multiple columns. " +
+      "Call `labelsArray` instead.")
+    labelsArray(0)
+  }
+
+  // Prepares the maps for string values to corresponding index values.
+  private val labelsToIndexArray: Array[OpenHashMap[String, Double]] = {
+    for (labels <- labelsArray) yield {
+      val n = labels.length
+      val map = new OpenHashMap[String, Double](n)
+      labels.zipWithIndex.foreach { case (label, idx) =>
+        map.update(label, idx)
+      }
+      map
     }
-    map
   }
 
   /** @group setParam */
@@ -217,33 +358,39 @@ class StringIndexerModel (
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
-    if (!dataset.schema.fieldNames.contains($(inputCol))) {
-      logInfo(s"Input column ${$(inputCol)} does not exist during 
transformation. " +
-        "Skip StringIndexerModel.")
-      return dataset.toDF
-    }
-    transformSchema(dataset.schema, logging = true)
+  /** @group setParam */
+  @Since("3.0.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
 
-    val filteredLabels = getHandleInvalid match {
-      case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
-      case _ => labels
+  /** @group setParam */
+  @Since("3.0.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  // Filters out rows with null values or with labels not seen during fitting. Input columns
+  // that do not exist in `dataset` are skipped, i.e. they contribute no filter condition.
+  private def filterInvalidData(dataset: Dataset[_], inputColNames: Seq[String]): Dataset[_] = {
+    val existingIdx = inputColNames.indices.filter(i => dataset.columns.contains(inputColNames(i)))
+    val conditions: Seq[Column] = existingIdx.map { i =>
+      val labelToIndex = labelsToIndexArray(i)
+      // We have this additional lookup at `labelToIndex` when `handleInvalid` is set to
+      // `StringIndexer.SKIP_INVALID`. Another idea is to do this lookup natively by SQL
+      // expression, however, lookup for a key in a map is not efficient in SparkSQL now.
+      // See `ElementAt` and `GetMapValue` expressions. If SQL's map lookup is improved,
+      // we can consider to change this.
+      val filter = udf { label: String =>
+        labelToIndex.contains(label)
+      }
+      filter(dataset(inputColNames(i)))
     }
 
-    val metadata = NominalAttribute.defaultAttr
-      .withName($(outputCol)).withValues(filteredLabels).toMetadata()
-    // If we are skipping invalid records, filter them out.
-    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
-      case StringIndexer.SKIP_INVALID =>
-        val filterer = udf { label: String =>
-          labelToIndex.contains(label)
-        }
-        (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
-      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
-    }
+    dataset.na.drop(inputColNames.filter(dataset.schema.fieldNames.contains(_)))
+      .where(conditions.foldLeft(lit(true))(_ and _))
+  }
 
-    val indexer = udf { label: String =>
+  private def getIndexer(labels: Seq[String], labelToIndex: 
OpenHashMap[String, Double]) = {
+    val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID)
+
+    udf { label: String =>
       if (label == null) {
         if (keepInvalid) {
           labels.length
@@ -257,29 +404,73 @@ class StringIndexerModel (
         } else if (keepInvalid) {
           labels.length
         } else {
-          throw new SparkException(s"Unseen label: $label.  To handle unseen 
labels, " +
+          throw new SparkException(s"Unseen label: $label. To handle unseen 
labels, " +
             s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
         }
       }
     }.asNondeterministic()
+  }
 
-    filteredDataset.select(col("*"),
-      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), 
metadata))
+  @Since("2.0.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
+
+    val (inputColNames, outputColNames) = getInOutCols()
+    val outputColumns = new Array[Column](outputColNames.length)
+
+    // Skips invalid rows if `handleInvalid` is set to `StringIndexer.SKIP_INVALID`.
+    val filteredDataset = if (getHandleInvalid == StringIndexer.SKIP_INVALID) {
+      filterInvalidData(dataset, inputColNames)
+    } else {
+      dataset
+    }
+
+    for (i <- 0 until outputColNames.length) {
+      val inputColName = inputColNames(i)
+      val outputColName = outputColNames(i)
+      val labelToIndex = labelsToIndexArray(i)
+      val labels = labelsArray(i)
+
+      if (!dataset.schema.fieldNames.contains(inputColName)) {
+        logWarning(s"Input column ${inputColName} does not exist during transformation. " +
+          "Skip StringIndexerModel for this column.")
+        // Leave outputColumns(i) null; never mutate outputColNames (it may be $(outputCols)).
+      } else {
+        val filteredLabels = getHandleInvalid match {
+          case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
+          case _ => labels
+        }
+        val metadata = NominalAttribute.defaultAttr
+          .withName(outputColName)
+          .withValues(filteredLabels)
+          .toMetadata()
+
+        val indexer = getIndexer(labels, labelToIndex)
+
+        outputColumns(i) = indexer(dataset(inputColName).cast(StringType))
+          .as(outputColName, metadata)
+      }
+    }
+
+    val (filteredOutputColNames, filteredOutputColumns) =
+      outputColNames.zip(outputColumns).filter(_._2 != null).unzip
+
+    require(filteredOutputColNames.length == filteredOutputColumns.length)
+    if (filteredOutputColNames.length > 0) {
+      filteredDataset.withColumns(filteredOutputColNames, filteredOutputColumns)
+    } else {
+      filteredDataset.toDF()
+    }
   }
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    if (schema.fieldNames.contains($(inputCol))) {
-      validateAndTransformSchema(schema)
-    } else {
-      // If the input column does not exist during transformation, we skip 
StringIndexerModel.
-      schema
-    }
+    validateAndTransformSchema(schema, skipNonExistsCol = true)
   }
 
   @Since("1.4.1")
   override def copy(extra: ParamMap): StringIndexerModel = {
-    val copied = new StringIndexerModel(uid, labels)
+    val copied = new StringIndexerModel(uid, labelsArray)
     copyValues(copied, extra).setParent(parent)
   }
 
@@ -293,11 +484,11 @@ object StringIndexerModel extends 
MLReadable[StringIndexerModel] {
   private[StringIndexerModel]
   class StringIndexModelWriter(instance: StringIndexerModel) extends MLWriter {
 
-    private case class Data(labels: Array[String])
+    private case class Data(labelsArray: Array[Array[String]])
 
     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.labels)
+      val data = Data(instance.labelsArray)
       val dataPath = new Path(path, "data").toString
       
sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
@@ -310,11 +501,25 @@ object StringIndexerModel extends 
MLReadable[StringIndexerModel] {
     override def load(path: String): StringIndexerModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.parquet(dataPath)
-        .select("labels")
-        .head()
-      val labels = data.getAs[Seq[String]](0).toArray
-      val model = new StringIndexerModel(metadata.uid, labels)
+
+      // We support loading old `StringIndexerModel` saved by previous Spark 
versions.
+      // Previous model has `labels`, but new model has `labelsArray`.
+      val (majorVersion, minorVersion) = 
majorMinorVersion(metadata.sparkVersion)
+      val labelsArray = if (majorVersion < 3) {
+        // Spark 2.4 and before.
+        val data = sparkSession.read.parquet(dataPath)
+          .select("labels")
+          .head()
+        val labels = data.getAs[Seq[String]](0).toArray
+        Array(labels)
+      } else {
+        // After Spark 3.0.
+        val data = sparkSession.read.parquet(dataPath)
+          .select("labelsArray")
+          .head()
+        data.getAs[Seq[Seq[String]]](0).map(_.toArray).toArray
+      }
+      val model = new StringIndexerModel(metadata.uid, labelsArray)
       metadata.getAndSetParams(model)
       model
     }
@@ -421,3 +626,47 @@ object IndexToString extends 
DefaultParamsReadable[IndexToString] {
   @Since("1.6.0")
   override def load(path: String): IndexToString = super.load(path)
 }
+
+/**
+ * A SQL `Aggregator` used by `StringIndexer` to count labels in string 
columns during fitting.
+ */
+private class StringIndexerAggregator(numColumns: Int)
+  extends Aggregator[Row, Array[OpenHashMap[String, Long]], 
Array[OpenHashMap[String, Long]]] {
+
+  override def zero: Array[OpenHashMap[String, Long]] =
+    Array.fill(numColumns)(new OpenHashMap[String, Long]())
+
+  def reduce(
+      array: Array[OpenHashMap[String, Long]],
+      row: Row): Array[OpenHashMap[String, Long]] = {
+    for (i <- 0 until numColumns) {
+      val stringValue = row.getString(i)
+      // We don't count for null values.
+      if (stringValue != null) {
+        array(i).changeValue(stringValue, 1L, _ + 1)
+      }
+    }
+    array
+  }
+
+  def merge(
+      array1: Array[OpenHashMap[String, Long]],
+      array2: Array[OpenHashMap[String, Long]]): Array[OpenHashMap[String, 
Long]] = {
+    for (i <- 0 until numColumns) {
+      array2(i).foreach { case (key: String, count: Long) =>
+        array1(i).changeValue(key, count, _ + count)
+      }
+    }
+    array1
+  }
+
+  def finish(array: Array[OpenHashMap[String, Long]]): 
Array[OpenHashMap[String, Long]] = array
+
+  override def bufferEncoder: Encoder[Array[OpenHashMap[String, Long]]] = {
+    Encoders.kryo[Array[OpenHashMap[String, Long]]]
+  }
+
+  override def outputEncoder: Encoder[Array[OpenHashMap[String, Long]]] = {
+    Encoders.kryo[Array[OpenHashMap[String, Long]]]
+  }
+}
diff --git 
a/mllib/src/test/resources/test-data/strIndexerModel/data/part-00000-cfefeb56-2980-4c42-b8a7-a5a94265c479-c000.snappy.parquet
 
b/mllib/src/test/resources/test-data/strIndexerModel/data/part-00000-cfefeb56-2980-4c42-b8a7-a5a94265c479-c000.snappy.parquet
new file mode 100644
index 0000000..917984c
Binary files /dev/null and 
b/mllib/src/test/resources/test-data/strIndexerModel/data/part-00000-cfefeb56-2980-4c42-b8a7-a5a94265c479-c000.snappy.parquet
 differ
diff --git 
a/mllib/src/test/resources/test-data/strIndexerModel/metadata/part-00000 
b/mllib/src/test/resources/test-data/strIndexerModel/metadata/part-00000
new file mode 100644
index 0000000..5650199
--- /dev/null
+++ b/mllib/src/test/resources/test-data/strIndexerModel/metadata/part-00000
@@ -0,0 +1 @@
+{"class":"org.apache.spark.ml.feature.StringIndexerModel","timestamp":1545536052048,"sparkVersion":"2.4.1-SNAPSHOT","uid":"strIdx_056bb5da1bf2","paramMap":{"outputCol":"index","inputCol":"str"},"defaultParamMap":{"outputCol":"strIdx_056bb5da1bf2__output","stringOrderType":"frequencyDesc","handleInvalid":"error"}}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index 0de6528..675e7b6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -137,14 +137,17 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
 
   test("encodes string terms") {
     val formula = new RFormula().setFormula("id ~ a + b")
-    val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
+    val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5),
+      (5, "bar", 6), (6, "foo", 6))
       .toDF("id", "a", "b")
     val model = formula.fit(original)
     val expected = Seq(
         (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
         (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
         (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
-        (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0)
+        (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0),
+        (5, "bar", 6, Vectors.dense(1.0, 0.0, 6.0), 5.0),
+        (6, "foo", 6, Vectors.dense(0.0, 1.0, 6.0), 6.0)
       ).toDF("id", "a", "b", "features", "label")
     testRFormulaTransform[(Int, String, Int)](original, model, expected)
   }
@@ -303,7 +306,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
   test("index string label") {
     val formula = new RFormula().setFormula("id ~ a + b")
     val original =
-      Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5))
+      Seq(("male", "foo", 4), ("female", "bar", 4), ("female", "bar", 5), ("male", "baz", 5),
+        ("female", "bar", 6), ("female", "foo", 6))
         .toDF("id", "a", "b")
     val model = formula.fit(original)
     val attr = NominalAttribute.defaultAttr
@@ -311,7 +315,9 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
         ("male", "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
         ("female", "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
         ("female", "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 0.0),
-        ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0)
+        ("male", "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 1.0),
+        ("female", "bar", 6, Vectors.dense(1.0, 0.0, 6.0), 0.0),
+        ("female", "foo", 6, Vectors.dense(0.0, 1.0, 6.0), 0.0)
     ).toDF("id", "a", "b", "features", "label")
       .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata()))
     testRFormulaTransform[(String, String, Int)](original, model, expected)
@@ -320,7 +326,8 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
   test("force to index label even it is numeric type") {
     val formula = new RFormula().setFormula("id ~ a + b").setForceIndexLabel(true)
     val original = spark.createDataFrame(
-      Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5))
+      Seq((1.0, "foo", 4), (1.0, "bar", 4), (0.0, "bar", 5), (1.0, "baz", 5),
+      (1.0, "bar", 6), (0.0, "foo", 6))
     ).toDF("id", "a", "b")
     val model = formula.fit(original)
     val attr = NominalAttribute.defaultAttr
@@ -328,7 +335,9 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
         (1.0, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 0.0),
         (1.0, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 0.0),
         (0.0, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 1.0),
-        (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0))
+        (1.0, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 0.0),
+        (1.0, "bar", 6, Vectors.dense(1.0, 0.0, 6.0), 0.0),
+        (0.0, "foo", 6, Vectors.dense(0.0, 1.0, 6.0), 1.0))
       .toDF("id", "a", "b", "features", "label")
       .select($"id", $"a", $"b", $"features", $"label".as("label", attr.toMetadata()))
     testRFormulaTransform[(Double, String, Int)](original, model, expected)
@@ -336,14 +345,17 @@ class RFormulaSuite extends MLTest with DefaultReadWriteTest {
 
   test("attribute generation") {
     val formula = new RFormula().setFormula("id ~ a + b")
-    val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5))
+    val original = Seq((1, "foo", 4), (2, "bar", 4), (3, "bar", 5), (4, "baz", 5),
+      (1, "bar", 6), (0, "foo", 6))
       .toDF("id", "a", "b")
     val model = formula.fit(original)
     val expected = Seq(
       (1, "foo", 4, Vectors.dense(0.0, 1.0, 4.0), 1.0),
       (2, "bar", 4, Vectors.dense(1.0, 0.0, 4.0), 2.0),
       (3, "bar", 5, Vectors.dense(1.0, 0.0, 5.0), 3.0),
-      (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0))
+      (4, "baz", 5, Vectors.dense(0.0, 0.0, 5.0), 4.0),
+      (1, "bar", 6, Vectors.dense(1.0, 0.0, 6.0), 1.0),
+      (0, "foo", 6, Vectors.dense(0.0, 1.0, 6.0), 0.0))
       .toDF("id", "a", "b", "features", "label")
     val expectedAttrs = new AttributeGroup(
       "features",
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index df24367..f542e34 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -30,12 +30,46 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
 
   test("params") {
     ParamsSuite.checkParams(new StringIndexer)
-    val model = new StringIndexerModel("indexer", Array("a", "b"))
+    val model = new StringIndexerModel("indexer", Array(Array("a", "b")))
     val modelWithoutUid = new StringIndexerModel(Array("a", "b"))
     ParamsSuite.checkParams(model)
     ParamsSuite.checkParams(modelWithoutUid)
   }
 
+  test("params: input/output columns") {
+    val stringIndexerSingleCol = new StringIndexer()
+      .setInputCol("in").setOutputCol("out")
+    val inOutCols1 = stringIndexerSingleCol.getInOutCols()
+    assert(inOutCols1._1 === Array("in"))
+    assert(inOutCols1._2 === Array("out"))
+
+    val stringIndexerMultiCol = new StringIndexer()
+      .setInputCols(Array("in1", "in2")).setOutputCols(Array("out1", "out2"))
+    val inOutCols2 = stringIndexerMultiCol.getInOutCols()
+    assert(inOutCols2._1 === Array("in1", "in2"))
+    assert(inOutCols2._2 === Array("out1", "out2"))
+
+
+    val df = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")).toDF("id", "label")
+
+    intercept[IllegalArgumentException] {
+      new StringIndexer().setInputCol("in").setOutputCols(Array("out1", "out2")).fit(df)
+    }
+    intercept[IllegalArgumentException] {
+      new StringIndexer().setInputCols(Array("in1", "in2")).setOutputCol("out1").fit(df)
+    }
+    intercept[IllegalArgumentException] {
+      new StringIndexer().setInputCols(Array("in1", "in2"))
+        .setOutputCols(Array("out1", "out2", "out3"))
+        .fit(df)
+    }
+    intercept[IllegalArgumentException] {
+      new StringIndexer().setInputCols(Array("in1", "in2"))
+        .setOutputCols(Array("out1", "out1"))
+        .fit(df)
+    }
+  }
+
   test("StringIndexer") {
     val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
     val df = data.toDF("id", "label")
@@ -51,7 +85,7 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
       (2, 1.0),
       (3, 0.0),
       (4, 0.0),
-       (5, 1.0)
+      (5, 1.0)
     ).toDF("id", "labelIndex")
 
     testTransformerByGlobalCheckFunc[(Int, String)](df, indexerModel, "id", "labelIndex") { rows =>
@@ -167,7 +201,7 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
   }
 
   test("StringIndexerModel should keep silent if the input column does not exist.") {
-    val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
+    val indexerModel = new StringIndexerModel("indexer", Array(Array("a", "b", "c")))
       .setInputCol("label")
       .setOutputCol("labelIndex")
     val df = spark.range(0L, 10L).toDF()
@@ -207,7 +241,7 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
   }
 
   test("StringIndexerModel read/write") {
-    val instance = new StringIndexerModel("myStringIndexerModel", Array("a", "b", "c"))
+    val instance = new StringIndexerModel("myStringIndexerModel", Array(Array("a", "b", "c")))
       .setInputCol("myInputCol")
       .setOutputCol("myOutputCol")
       .setHandleInvalid("skip")
@@ -323,11 +357,32 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
     }
   }
 
+  test("StringIndexer order types: secondary sort by alphabets when frequency equal") {
+    val data = Seq((0, "a"), (1, "a"), (2, "b"), (3, "b"), (4, "c"), (5, "d"))
+    val df = data.toDF("id", "label")
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+
+    val expected = Seq(Set((0, 0.0), (1, 0.0), (2, 1.0), (3, 1.0), (4, 2.0), (5, 3.0)),
+      Set((0, 2.0), (1, 2.0), (2, 3.0), (3, 3.0), (4, 0.0), (5, 1.0)))
+
+    var idx = 0
+    for (orderType <- Seq("frequencyDesc", "frequencyAsc")) {
+      val transformed = indexer.setStringOrderType(orderType).fit(df).transform(df)
+      val output = transformed.select("id", "labelIndex").rdd.map { r =>
+        (r.getInt(0), r.getDouble(1))
+      }.collect().toSet
+      assert(output === expected(idx))
+      idx += 1
+    }
+  }
+
   test("SPARK-22446: StringIndexerModel's indexer UDF should not apply on filtered data") {
     val df = List(
-         ("A", "London", "StrA"),
-         ("B", "Bristol", null),
-         ("C", "New York", "StrC")).toDF("ID", "CITY", "CONTENT")
+      ("A", "London", "StrA"),
+      ("B", "Bristol", null),
+      ("C", "New York", "StrC")).toDF("ID", "CITY", "CONTENT")
 
     val dfNoBristol = df.filter($"CONTENT".isNotNull)
 
@@ -343,4 +398,74 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
       assert(rows.toList.count(_.getDouble(0) == 1.0) === 1)
     }
   }
+
+  test("StringIndexer multiple input columns") {
+    val data = Seq(
+      Row("a", 0.0, "e", 1.0),
+      Row("b", 2.0, "f", 0.0),
+      Row("c", 1.0, "e", 1.0),
+      Row("a", 0.0, "f", 0.0),
+      Row("a", 0.0, "f", 0.0),
+      Row("c", 1.0, "f", 0.0))
+
+    val schema = StructType(Array(
+      StructField("label1", StringType),
+      StructField("expected1", DoubleType),
+      StructField("label2", StringType),
+      StructField("expected2", DoubleType)))
+
+    val df = spark.createDataFrame(sc.parallelize(data), schema)
+
+    val indexer = new StringIndexer()
+      .setInputCols(Array("label1", "label2"))
+      .setOutputCols(Array("labelIndex1", "labelIndex2"))
+    val indexerModel = indexer.fit(df)
+
+    MLTestingUtils.checkCopyAndUids(indexer, indexerModel)
+
+    val transformed = indexerModel.transform(df)
+
+    // Checks output attribute correctness.
+    val attr1 = Attribute.fromStructField(transformed.schema("labelIndex1"))
+      .asInstanceOf[NominalAttribute]
+    assert(attr1.values.get === Array("a", "c", "b"))
+    val attr2 = Attribute.fromStructField(transformed.schema("labelIndex2"))
+      .asInstanceOf[NominalAttribute]
+    assert(attr2.values.get === Array("f", "e"))
+
+    transformed.select("labelIndex1", "expected1").rdd.map { r =>
+      (r.getDouble(0), r.getDouble(1))
+    }.collect().foreach { case (index, expected) =>
+      assert(index == expected)
+    }
+
+    transformed.select("labelIndex2", "expected2").rdd.map { r =>
+      (r.getDouble(0), r.getDouble(1))
+    }.collect().foreach { case (index, expected) =>
+      assert(index == expected)
+    }
+  }
+
+  test("Correctly skipping NULL and NaN values") {
+    val df = Seq(("a", Double.NaN), (null, 1.0), ("b", 2.0), (null, 3.0)).toDF("str", "double")
+
+    val indexer = new StringIndexer()
+      .setInputCols(Array("str", "double"))
+      .setOutputCols(Array("strIndex", "doubleIndex"))
+
+    val model = indexer.fit(df)
+    assert(model.labelsArray(0) === Array("a", "b"))
+    assert(model.labelsArray(1) === Array("1.0", "2.0", "3.0"))
+  }
+
+  test("Load StringIndexderModel prior to Spark 3.0") {
+    val modelPath = testFile("test-data/strIndexerModel")
+
+    val loadedModel = StringIndexerModel.load(modelPath)
+    assert(loadedModel.labelsArray === Array(Array("b", "c", "a")))
+
+    val metadata = spark.read.json(s"$modelPath/metadata")
+    val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
+    assert(sparkVersionStr == "2.4.1-SNAPSHOT")
+  }
 }
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 0cdef00..a13ee51 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -278,6 +278,10 @@ object MimaExcludes {
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productPrefix"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$3"),
 
+    // [SPARK-11215][ML] Add multiple columns support to StringIndexer
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.validateAndTransformSchema"),
+
     // [SPARK-26616][MLlib] Expose document frequency in IDFModel
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.IDFModel.this"),
     ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.feature.IDF#DocumentFrequencyAggregator.idf")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to