GitHub user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20132#discussion_r159157608
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala ---
    @@ -205,60 +210,58 @@ class OneHotEncoderModel private[ml] (
     
       import OneHotEncoderModel._
     
    -  // Returns the category size for a given index with `dropLast` and `handleInvalid`
    +  // Returns the category size for each index with `dropLast` and `handleInvalid`
       // taken into account.
    -  private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = {
    +  private def getConfigedCategorySizes: Array[Int] = {
         val dropLast = getDropLast
         val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
     
         if (!dropLast && keepInvalid) {
           // When `handleInvalid` is "keep", an extra category is added as last category
           // for invalid data.
    -      orgCategorySize + 1
    +      categorySizes.map(_ + 1)
         } else if (dropLast && !keepInvalid) {
           // When `dropLast` is true, the last category is removed.
    -      orgCategorySize - 1
    +      categorySizes.map(_ - 1)
         } else {
           // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid
           // data is removed. Thus, it is the same as the plain number of categories.
    -      orgCategorySize
    +      categorySizes
         }
       }
     
       private def encoder: UserDefinedFunction = {
    -    val oneValue = Array(1.0)
    -    val emptyValues = Array.empty[Double]
    -    val emptyIndices = Array.empty[Int]
    -    val dropLast = getDropLast
    -    val handleInvalid = getHandleInvalid
    -    val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID
    +    val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID
    +    val configedSizes = getConfigedCategorySizes
    +    val localCategorySizes = categorySizes
     
         // The udf performed on input data. The first parameter is the input value. The second
    -    // parameter is the index of input.
    -    udf { (label: Double, idx: Int) =>
    -      val plainNumCategories = categorySizes(idx)
    -      val size = configedCategorySize(plainNumCategories, idx)
    -
    -      if (label < 0) {
    -        throw new SparkException(s"Negative value: $label. Input can't be negative.")
    -      } else if (label == size && dropLast && !keepInvalid) {
    -        // When `dropLast` is true and `handleInvalid` is not "keep",
    -        // the last category is removed.
    -        Vectors.sparse(size, emptyIndices, emptyValues)
    -      } else if (label >= plainNumCategories && keepInvalid) {
    -        // When `handleInvalid` is "keep", encodes invalid data to last category (and removed
    -        // if `dropLast` is true)
    -        if (dropLast) {
    -          Vectors.sparse(size, emptyIndices, emptyValues)
    +    // parameter is the index in inputCols of the column being encoded.
    +    udf { (label: Double, colIdx: Int) =>
    +      val origCategorySize = localCategorySizes(colIdx)
    +      // idx: index in vector of the single 1-valued element
    +      val idx = if (label >= 0 && label < origCategorySize) {
    +        label
    +      } else {
    +        if (keepInvalid) {
    +          origCategorySize
             } else {
    -          Vectors.sparse(size, Array(size - 1), oneValue)
    +          if (label < 0) {
    +            throw new SparkException(s"Negative value: $label. Input can't be negative. " +
    --- End diff --
    
    I have a question. Since we don't allow negative values when fitting, should we allow them when transforming, even when handleInvalid is KEEP_INVALID?
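
    To make the concern concrete, here is a minimal standalone sketch of the branching above (the helper name `encode` and the plain `RuntimeException` are mine, for illustration only, and not part of this patch). With `keepInvalid = true`, a negative label silently falls into the extra "invalid" slot instead of raising an error:

    ```scala
    import org.apache.spark.ml.linalg.{Vector, Vectors}

    // Hypothetical mirror of the udf's logic for a single column, for discussion only.
    def encode(label: Double, origCategorySize: Int,
               dropLast: Boolean, keepInvalid: Boolean): Vector = {
      // Size arithmetic from getConfigedCategorySizes.
      val size = (dropLast, keepInvalid) match {
        case (false, true)  => origCategorySize + 1  // extra slot for invalid data
        case (true,  false) => origCategorySize - 1  // last category dropped
        case _              => origCategorySize      // the two effects cancel out
      }
      // Index selection mirroring the new udf body.
      val idx = if (label >= 0 && label < origCategorySize) {
        label.toInt
      } else if (keepInvalid) {
        origCategorySize  // note: label = -1.0 also lands here
      } else {
        throw new RuntimeException(s"Invalid label: $label")
      }
      if (idx < size) Vectors.sparse(size, Array(idx), Array(1.0))
      else Vectors.sparse(size, Array.empty[Int], Array.empty[Double])  // dropped category
    }

    encode(-1.0, 3, dropLast = false, keepInvalid = true)  // (4,[3],[1.0]), no error
    ```

    So a negative value that is rejected at fit time would be encoded like any other invalid category at transform time.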

