This is an automated email from the ASF dual-hosted git repository. weichenxu123 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 9bff2c8bc505 [SPARK-48463][ML] Make StringIndexer supporting nested input columns 9bff2c8bc505 is described below commit 9bff2c8bc5059f5be0dc6e8105c11403942a0b9f Author: Weichen Xu <weichen...@databricks.com> AuthorDate: Mon Jul 15 15:19:59 2024 +0800 [SPARK-48463][ML] Make StringIndexer supporting nested input columns ### What changes were proposed in this pull request? Make StringIndexer support nested input columns, i.e. fields inside struct columns referenced by dotted paths such as `c1.label` or `input1.a.f1`. ### Why are the changes needed? User demand. ### Does this PR introduce _any_ user-facing change? Yes. ### How was this patch tested? Unit tests. ### Was this patch authored or co-authored using generative AI tooling? Closes #47283 from WeichenXu123/SPARK-48463. Lead-authored-by: Weichen Xu <weichen...@databricks.com> Co-authored-by: WeichenXu <weichen...@databricks.com> Signed-off-by: Weichen Xu <weichen...@databricks.com> --- .../apache/spark/ml/feature/StringIndexer.scala | 37 +++++++++++------ .../spark/ml/feature/StringIndexerSuite.scala | 47 +++++++++++++++++++++- 2 files changed, 71 insertions(+), 13 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 60dc4d024071..34f77f029395 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.feature +import java.util.ArrayList + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -27,7 +29,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.sql.{Column, DataFrame, Dataset, Encoder, Encoders, Row} +import org.apache.spark.sql.{AnalysisException, Column, DataFrame, Dataset, Encoder, Encoders, 
Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{If, Literal} import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.functions._ @@ -103,8 +105,8 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi private def validateAndTransformField( schema: StructType, inputColName: String, + inputDataType: DataType, outputColName: String): StructField = { - val inputDataType = schema(inputColName).dataType require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], s"The input column $inputColName must be either string type or numeric type, " + s"but got $inputDataType.") @@ -122,12 +124,22 @@ private[feature] trait StringIndexerBase extends Params with HasHandleInvalid wi require(outputColNames.distinct.length == outputColNames.length, s"Output columns should not be duplicate.") + val sparkSession = SparkSession.getActiveSession.get + val transformDataset = sparkSession.createDataFrame(new ArrayList[Row](), schema = schema) val outputFields = inputColNames.zip(outputColNames).flatMap { case (inputColName, outputColName) => - schema.fieldNames.contains(inputColName) match { - case true => Some(validateAndTransformField(schema, inputColName, outputColName)) - case false if skipNonExistsCol => None - case _ => throw new SparkException(s"Input column $inputColName does not exist.") + try { + val dtype = transformDataset.col(inputColName).expr.dataType + Some( + validateAndTransformField(schema, inputColName, dtype, outputColName) + ) + } catch { + case _: AnalysisException => + if (skipNonExistsCol) { + None + } else { + throw new SparkException(s"Input column $inputColName does not exist.") + } } } StructType(schema.fields ++ outputFields) @@ -431,11 +443,8 @@ class StringIndexerModel ( val labelToIndex = labelsToIndexArray(i) val labels = labelsArray(i) - if (!dataset.schema.fieldNames.contains(inputColName)) { - logWarning(log"Input column ${MDC(LogKeys.COLUMN_NAME, inputColName)} 
does not exist " + - log"during transformation. Skip StringIndexerModel for this column.") - outputColNames(i) = null - } else { + try { + dataset.col(inputColName) val filteredLabels = getHandleInvalid match { case StringIndexer.KEEP_INVALID => labels :+ "__unknown" case _ => labels @@ -449,9 +458,13 @@ class StringIndexerModel ( outputColumns(i) = indexer(dataset(inputColName).cast(StringType)) .as(outputColName, metadata) + } catch { + case _: AnalysisException => + logWarning(log"Input column ${MDC(LogKeys.COLUMN_NAME, inputColName)} does not exist " + + log"during transformation. Skip StringIndexerModel for this column.") + outputColNames(i) = null } } - val filteredOutputColNames = outputColNames.filter(_ != null) val filteredOutputColumns = outputColumns.filter(_ != null) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 99f12eab7d69..8f3750959d2b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -21,7 +21,8 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} import org.apache.spark.sql.Row -import org.apache.spark.sql.functions.col +import org.apache.spark.sql.catalyst.parser.DataTypeParser +import org.apache.spark.sql.functions.{col, struct} import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} class StringIndexerSuite extends MLTest with DefaultReadWriteTest { @@ -113,6 +114,50 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest { assert(outSchema("output2").dataType === DoubleType) } + test("StringIndexer.transformSchema nested col") { + val outputCols = Array("output", "output2", "output3", "output4", "output5") + val idxToStr = new 
StringIndexer().setInputCols( + Array("input1.a.f1", "input1.a.f2", "input2.b1", "input2.b2", "input3") + ).setOutputCols(outputCols) + + val inSchema = DataTypeParser.parseTableSchema( + "input1 struct<a struct<f1 string, f2 string>>, " + + "input2 struct<b1 string, b2 string>, input3 string" + ) + val outSchema = idxToStr.transformSchema(inSchema) + + for (outputCol <- outputCols) { + assert(outSchema(outputCol).dataType === DoubleType) + } + } + + test("StringIndexer nested input cols") { + val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")) + val df = data.toDF("id", "label") + .select(col("id"), struct(col("label")).alias("c1")) + val indexer = new StringIndexer() + .setInputCol("c1.label") + .setOutputCol("labelIndex") + val indexerModel = indexer.fit(df) + MLTestingUtils.checkCopyAndUids(indexer, indexerModel) + // a -> 0, b -> 2, c -> 1 + val expected = Seq( + (0, 0.0), + (1, 2.0), + (2, 1.0), + (3, 0.0), + (4, 0.0), + (5, 1.0) + ).toDF("id", "labelIndex") + + val dfOutput = indexerModel.transform(df) + val outputs = dfOutput.select("id", "labelIndex").collect().toSeq + val attr = Attribute.fromStructField(outputs.head.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attr.values.get === Array("a", "c", "b")) + assert(outputs === expected.collect().toSeq) + } + test("StringIndexerUnseen") { val data = Seq((0, "a"), (1, "b"), (4, "b")) val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d")) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org