This is an automated email from the ASF dual-hosted git repository.

weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9bff2c8bc505 [SPARK-48463][ML] Make StringIndexer support nested input columns
9bff2c8bc505 is described below

commit 9bff2c8bc5059f5be0dc6e8105c11403942a0b9f
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Mon Jul 15 15:19:59 2024 +0800

    [SPARK-48463][ML] Make StringIndexer support nested input columns
    
    ### What changes were proposed in this pull request?
    
    Make StringIndexer support nested input columns, i.e. allow input column names that reference fields inside struct columns (such as c1.label).
    
    ### Why are the changes needed?
    
    User demand: users want to index string fields nested inside struct columns without having to flatten the DataFrame first.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. StringIndexer.setInputCol / setInputCols now accept nested field references such as input1.a.f1 or c1.label, and StringIndexerModel resolves them during transform.
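
    A minimal usage sketch, adapted from the new unit test in this patch
    (assumes a SparkSession with its implicits in scope):

        import org.apache.spark.ml.feature.StringIndexer
        import org.apache.spark.sql.functions.{col, struct}

        // Wrap the string column in a struct so "label" becomes a nested field.
        val df = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
          .toDF("id", "label")
          .select(col("id"), struct(col("label")).alias("c1"))

        // Reference the nested field directly as the input column.
        val indexer = new StringIndexer()
          .setInputCol("c1.label")
          .setOutputCol("labelIndex")

        indexer.fit(df).transform(df).select("id", "labelIndex").show()

    Labels are ordered by descending frequency by default, so with this data
    the fitted model maps a -> 0.0, c -> 1.0, b -> 2.0.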
    
    ### How was this patch tested?
    
    Unit tests added in StringIndexerSuite.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Closes #47283 from WeichenXu123/SPARK-48463.
    
    Lead-authored-by: Weichen Xu <weichen...@databricks.com>
    Co-authored-by: WeichenXu <weichen...@databricks.com>
    Signed-off-by: Weichen Xu <weichen...@databricks.com>
---
 .../apache/spark/ml/feature/StringIndexer.scala    | 37 +++++++++++------
 .../spark/ml/feature/StringIndexerSuite.scala      | 47 +++++++++++++++++++++-
 2 files changed, 71 insertions(+), 13 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 60dc4d024071..34f77f029395 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml.feature
 
+import java.util.ArrayList
+
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark.SparkException
@@ -27,7 +29,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared._
 import org.apache.spark.ml.util._
-import org.apache.spark.sql.{Column, DataFrame, Dataset, Encoder, Encoders, Row}
+import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Encoder, Encoders, Row, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.{If, Literal}
 import org.apache.spark.sql.expressions.Aggregator
 import org.apache.spark.sql.functions._
@@ -103,8 +105,8 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
   private def validateAndTransformField(
       schema: StructType,
       inputColName: String,
+      inputDataType: DataType,
       outputColName: String): StructField = {
-    val inputDataType = schema(inputColName).dataType
     require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
       s"The input column $inputColName must be either string type or numeric type, " +
         s"but got $inputDataType.")
@@ -122,12 +124,22 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi
     require(outputColNames.distinct.length == outputColNames.length,
       s"Output columns should not be duplicate.")
 
+    val sparkSession = SparkSession.getActiveSession.get
+    val transformDataset = sparkSession.createDataFrame(new ArrayList[Row](), schema = schema)
     val outputFields = inputColNames.zip(outputColNames).flatMap {
       case (inputColName, outputColName) =>
-        schema.fieldNames.contains(inputColName) match {
-          case true => Some(validateAndTransformField(schema, inputColName, outputColName))
-          case false if skipNonExistsCol => None
-          case _ => throw new SparkException(s"Input column $inputColName does not exist.")
+        try {
+          val dtype = transformDataset.col(inputColName).expr.dataType
+          Some(
+            validateAndTransformField(schema, inputColName, dtype, outputColName)
+          )
+        } catch {
+          case _: AnalysisException =>
+            if (skipNonExistsCol) {
+              None
+            } else {
+              throw new SparkException(s"Input column $inputColName does not exist.")
+            }
         }
     }
     StructType(schema.fields ++ outputFields)
@@ -431,11 +443,8 @@ class StringIndexerModel (
       val labelToIndex = labelsToIndexArray(i)
       val labels = labelsArray(i)
 
-      if (!dataset.schema.fieldNames.contains(inputColName)) {
-        logWarning(log"Input column ${MDC(LogKeys.COLUMN_NAME, inputColName)} 
does not exist " +
-          log"during transformation. Skip StringIndexerModel for this column.")
-        outputColNames(i) = null
-      } else {
+      try {
+        dataset.col(inputColName)
         val filteredLabels = getHandleInvalid match {
           case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
           case _ => labels
@@ -449,9 +458,13 @@ class StringIndexerModel (
 
         outputColumns(i) = indexer(dataset(inputColName).cast(StringType))
           .as(outputColName, metadata)
+      } catch {
+        case _: AnalysisException =>
+          logWarning(log"Input column ${MDC(LogKeys.COLUMN_NAME, 
inputColName)} does not exist " +
+            log"during transformation. Skip StringIndexerModel for this 
column.")
+          outputColNames(i) = null
       }
     }
-
     val filteredOutputColNames = outputColNames.filter(_ != null)
     val filteredOutputColumns = outputColumns.filter(_ != null)
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 99f12eab7d69..8f3750959d2b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -21,7 +21,8 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.catalyst.parser.DataTypeParser
+import org.apache.spark.sql.functions.{col, struct}
 import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType}
 
 class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
@@ -113,6 +114,50 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {
     assert(outSchema("output2").dataType === DoubleType)
   }
 
+  test("StringIndexer.transformSchema nested col") {
+    val outputCols = Array("output", "output2", "output3", "output4", "output5")
+    val idxToStr = new StringIndexer().setInputCols(
+      Array("input1.a.f1", "input1.a.f2", "input2.b1", "input2.b2", "input3")
+    ).setOutputCols(outputCols)
+
+    val inSchema = DataTypeParser.parseTableSchema(
+      "input1 struct<a struct<f1 string, f2 string>>, " +
+      "input2 struct<b1 string, b2 string>, input3 string"
+    )
+    val outSchema = idxToStr.transformSchema(inSchema)
+
+    for (outputCol <- outputCols) {
+      assert(outSchema(outputCol).dataType === DoubleType)
+    }
+  }
+
+  test("StringIndexer nested input cols") {
+    val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c"))
+    val df = data.toDF("id", "label")
+      .select(col("id"), struct(col("label")).alias("c1"))
+    val indexer = new StringIndexer()
+      .setInputCol("c1.label")
+      .setOutputCol("labelIndex")
+    val indexerModel = indexer.fit(df)
+    MLTestingUtils.checkCopyAndUids(indexer, indexerModel)
+    // a -> 0, b -> 2, c -> 1
+    val expected = Seq(
+      (0, 0.0),
+      (1, 2.0),
+      (2, 1.0),
+      (3, 0.0),
+      (4, 0.0),
+      (5, 1.0)
+    ).toDF("id", "labelIndex")
+
+    val dfOutput = indexerModel.transform(df)
+    val outputs = dfOutput.select("id", "labelIndex").collect().toSeq
+    val attr = Attribute.fromStructField(outputs.head.schema("labelIndex"))
+      .asInstanceOf[NominalAttribute]
+    assert(attr.values.get === Array("a", "c", "b"))
+    assert(outputs === expected.collect().toSeq)
+  }
+
   test("StringIndexerUnseen") {
     val data = Seq((0, "a"), (1, "b"), (4, "b"))
     val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"))

