Github user dbtsai commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20146#discussion_r183255152
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---
    @@ -130,21 +161,53 @@ class StringIndexer @Since("1.4.0") (
       @Since("1.4.0")
       def setOutputCol(value: String): this.type = set(outputCol, value)
     
    +  /** @group setParam */
    +  @Since("2.4.0")
    +  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
    +
    +  /** @group setParam */
    +  @Since("2.4.0")
    +  def setOutputCols(value: Array[String]): this.type = set(outputCols, 
value)
    +
       @Since("2.0.0")
       override def fit(dataset: Dataset[_]): StringIndexerModel = {
         transformSchema(dataset.schema, logging = true)
    -    val values = dataset.na.drop(Array($(inputCol)))
    -      .select(col($(inputCol)).cast(StringType))
    -      .rdd.map(_.getString(0))
    -    val labels = $(stringOrderType) match {
    -      case StringIndexer.frequencyDesc => 
values.countByValue().toSeq.sortBy(-_._2)
    -        .map(_._1).toArray
    -      case StringIndexer.frequencyAsc => 
values.countByValue().toSeq.sortBy(_._2)
    -        .map(_._1).toArray
    -      case StringIndexer.alphabetDesc => 
values.distinct.collect.sortWith(_ > _)
    -      case StringIndexer.alphabetAsc => values.distinct.collect.sortWith(_ 
< _)
    -    }
    -    copyValues(new StringIndexerModel(uid, labels).setParent(this))
    +
    +    val (inputCols, _) = getInOutCols()
    +    val zeroState = Array.fill(inputCols.length)(new OpenHashMap[String, 
Long]())
    +
    +    // Counts by the string values in the dataset.
    +    val countByValueArray = dataset.na.drop(inputCols)
    +      .select(inputCols.map(col(_).cast(StringType)): _*)
    +      .rdd.treeAggregate(zeroState)(
    +        (state: Array[OpenHashMap[String, Long]], row: Row) => {
    +          for (i <- 0 until inputCols.length) {
    +            state(i).changeValue(row.getString(i), 1L, _ + 1)
    +          }
    +          state
    +        },
    +        (state1: Array[OpenHashMap[String, Long]], state2: 
Array[OpenHashMap[String, Long]]) => {
    +          for (i <- 0 until inputCols.length) {
    +            state2(i).foreach { case (key: String, count: Long) =>
    +              state1(i).changeValue(key, count, _ + count)
    +            }
    +          }
    +          state1
    +        }
    +      )
    +
    +    // In case of equal frequency when frequencyDesc/Asc, we further sort 
the strings by alphabet.
    +    val labelsArray = countByValueArray.map { countByValue =>
    +      $(stringOrderType) match {
    +        case StringIndexer.frequencyDesc =>
    +          countByValue.toSeq.sortBy(_._1).sortBy(-_._2).map(_._1).toArray
    +        case StringIndexer.frequencyAsc =>
    +          countByValue.toSeq.sortBy(_._1).sortBy(_._2).map(_._1).toArray
    +        case StringIndexer.alphabetDesc => 
countByValue.toSeq.map(_._1).sortWith(_ > _).toArray
    --- End diff --
    
    I think we can break the code into two paths. One is sorting by frequency, 
which requires computing the counts, and the other is sorting alphabetically, 
which only requires `distinct`. We could move the `countByValueArray` code into 
`labelsArray`.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to