Github user WeichenXu123 commented on a diff in the pull request: https://github.com/apache/spark/pull/19621#discussion_r152252491 --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala --- @@ -130,21 +160,49 @@ class StringIndexer @Since("1.4.0") ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.3.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) - val values = dataset.na.drop(Array($(inputCol))) - .select(col($(inputCol)).cast(StringType)) - .rdd.map(_.getString(0)) - val labels = $(stringOrderType) match { - case StringIndexer.frequencyDesc => values.countByValue().toSeq.sortBy(-_._2) - .map(_._1).toArray - case StringIndexer.frequencyAsc => values.countByValue().toSeq.sortBy(_._2) - .map(_._1).toArray - case StringIndexer.alphabetDesc => values.distinct.collect.sortWith(_ > _) - case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ < _) + + val inputCols = getInOutCols._1 + + val zeroState = Array.fill(inputCols.length)(new OpenHashMap[String, Long]()) + + val countByValueArray = dataset.na.drop(inputCols) + .select(inputCols.map(col(_).cast(StringType)): _*) + .rdd.aggregate(zeroState)( + (state: Array[OpenHashMap[String, Long]], row: Row) => { + for (i <- 0 until inputCols.length) { + state(i).changeValue(row.getString(i), 1L, _ + 1) + } + state + }, + (state1: Array[OpenHashMap[String, Long]], state2: Array[OpenHashMap[String, Long]]) => { + for (i <- 0 until inputCols.length) { + state2(i).foreach { case (key: String, count: Long) => + state1(i).changeValue(key, count, _ + count) + } + } + state1 + } + ) + val labelsArray = countByValueArray.map { countByValue => + 
$(stringOrderType) match { + case StringIndexer.frequencyDesc => countByValue.toSeq.sortBy(-_._2).map(_._1).toArray + case StringIndexer.frequencyAsc => countByValue.toSeq.sortBy(_._2).map(_._1).toArray + case StringIndexer.alphabetDesc => countByValue.toSeq.map(_._1).sortWith(_ > _).toArray + case StringIndexer.alphabetAsc => countByValue.toSeq.map(_._1).sortWith(_ < _).toArray --- End diff -- Yes, but would aggregating the counts introduce noticeable overhead? I don't want the code to include too many `if...else` branches.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org