Github user viirya commented on a diff in the pull request:
    --- Diff: 
    @@ -0,0 +1,439 @@
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +import org.apache.hadoop.fs.Path
    +import org.apache.spark.SparkException
    +import org.apache.spark.annotation.Since
    +import{Estimator, Model, Transformer}
    +import{HasHandleInvalid, HasInputCol, 
HasInputCols, HasOutputCol, HasOutputCols}
    +import org.apache.spark.sql.{DataFrame, Dataset}
    +import org.apache.spark.sql.expressions.UserDefinedFunction
    +import org.apache.spark.sql.functions.{col, udf}
    +import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, 
    +/** Private trait for params for OneHotEncoderEstimator and 
OneHotEncoderModel */
    +private[ml] trait OneHotEncoderParams extends Params with HasHandleInvalid
    +    with HasInputCols with HasOutputCols {
    +  /**
    +   * Param for how to handle invalid data.
    +   * Options are 'skip' (filter out rows with invalid data) or 'error' 
(throw an error).
    +   * Default: "error"
    +   * @group param
    +   */
    +  @Since("2.3.0")
    +  override val handleInvalid: Param[String] = new Param[String](this, 
    +    "How to handle invalid data " +
    +    "Options are 'skip' (filter out rows with invalid data) or error 
(throw an error).",
    +  setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID)
    +  /**
    +   * Whether to drop the last category in the encoded vector (default: true)
    +   * @group param
    +   */
    +  @Since("2.3.0")
    +  final val dropLast: BooleanParam =
    +    new BooleanParam(this, "dropLast", "whether to drop the last category")
    +  setDefault(dropLast -> true)
    +  /** @group getParam */
    +  @Since("2.3.0")
    +  def getDropLast: Boolean = $(dropLast)
    + * A one-hot encoder that maps a column of category indices to a column of 
binary vectors, with
    + * at most a single one-value per row that indicates the input category index.
    + * For example with 5 categories, an input value of 2.0 would map to an 
output vector of
    + * `[0.0, 0.0, 1.0, 0.0]`.
    + * The last category is not included by default (configurable via `dropLast`),
    + * because it makes the vector entries sum up to one, and hence linearly dependent.
    + * So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
    + *
    + * @note This is different from scikit-learn's OneHotEncoder, which keeps 
all categories.
    + * The output vectors are sparse.
    + *
    + * @see `StringIndexer` for converting categorical values into category indices
    + */
    +class OneHotEncoderEstimator @Since("2.3.0") (@Since("2.3.0") override val 
uid: String)
    +    extends Estimator[OneHotEncoderModel] with OneHotEncoderParams with 
DefaultParamsWritable {
    +  @Since("2.3.0")
    +  def this() = this(Identifiable.randomUID("oneHotEncoder"))
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setInputCols(values: Array[String]): this.type = set(inputCols, 
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setOutputCols(values: Array[String]): this.type = set(outputCols, 
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setDropLast(value: Boolean): this.type = set(dropLast, value)
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setHandleInvalid(value: String): this.type = set(handleInvalid, 
    +  @Since("2.3.0")
    +  override def transformSchema(schema: StructType): StructType = {
    +    val inputColNames = $(inputCols)
    +    val outputColNames = $(outputCols)
    +    val inputFields = schema.fields
    +    require(inputColNames.length == outputColNames.length,
    +      s"The number of input columns ${inputColNames.length} must be the 
same as the number of " +
    +        s"output columns ${outputColNames.length}.")
    +    val outputFields = { case 
(inputColName, outputColName) =>
    +      require(schema(inputColName).dataType.isInstanceOf[NumericType],
    +        s"Input column must be of type NumericType but got 
    +      require(!inputFields.exists( == outputColName),
    +        s"Output column $outputColName already exists.")
    +      OneHotEncoderCommon.transformOutputColumnSchema(
    +        schema(inputColName), $(dropLast), outputColName)
    +    }
    +    StructType(inputFields ++ outputFields)
    +  }
    +  @Since("2.3.0")
    +  override def fit(dataset: Dataset[_]): OneHotEncoderModel = {
    +    val transformedSchema = transformSchema(dataset.schema)
    +    val categorySizes = $(outputCols) { case 
(outputColName, idx) =>
    +      val outputAttrGroupFromSchema = AttributeGroup.fromStructField(
    +        transformedSchema(outputColName))
    +      val outputAttrGroup = if (outputAttrGroupFromSchema.size < 0) {
    +        OneHotEncoderCommon.getOutputAttrGroupFromData(
    +          dataset, $(dropLast), $(inputCols)(idx), outputColName)
    +      } else {
    +        outputAttrGroupFromSchema
    +      }
    +      outputAttrGroup.size
    +    }
    +    val model = new OneHotEncoderModel(uid, categorySizes).setParent(this)
    +    copyValues(model)
    +  }
    +  @Since("2.3.0")
    +  override def copy(extra: ParamMap): OneHotEncoderEstimator = 
    +object OneHotEncoderEstimator extends 
DefaultParamsReadable[OneHotEncoderEstimator] {
    +  private[feature] val SKIP_INVALID: String = "skip"
    +  private[feature] val ERROR_INVALID: String = "error"
    +  private[feature] val supportedHandleInvalids: Array[String] = 
    +  @Since("2.3.0")
    +  override def load(path: String): OneHotEncoderEstimator = 
    +class OneHotEncoderModel private[ml] (
    +    @Since("2.3.0") override val uid: String,
    +    @Since("2.3.0") val categorySizes: Array[Int])
    +  extends Model[OneHotEncoderModel] with OneHotEncoderParams with 
MLWritable {
    +  import OneHotEncoderModel._
    +  private def encoders: Array[UserDefinedFunction] = {
    +    val oneValue = Array(1.0)
    +    val emptyValues = Array.empty[Double]
    +    val emptyIndices = Array.empty[Int]
    +    val dropLast = getDropLast
    +    val handleInvalid = getHandleInvalid
    + { size =>
    +      udf { label: Double =>
    +        if (label < size) {
    +          Vectors.sparse(size, Array(label.toInt), oneValue)
    +        } else if (label == size && dropLast) {
    +          Vectors.sparse(size, emptyIndices, emptyValues)
    +        } else {
    +          if (handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) {
    +            throw new SparkException(s"Unseen value: $label. To handle 
unseen values, " +
    +              s"set Param handleInvalid to 
    +          } else {
    +            Vectors.sparse(size, emptyIndices, emptyValues)
    +          }
    +        }
    +      }
    +    }
    +  }
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setInputCols(values: Array[String]): this.type = set(inputCols, 
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setOutputCols(values: Array[String]): this.type = set(outputCols, 
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setDropLast(value: Boolean): this.type = set(dropLast, value)
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setHandleInvalid(value: String): this.type = set(handleInvalid, 
    +  @Since("2.3.0")
    +  override def transformSchema(schema: StructType): StructType = {
    +    val inputColNames = $(inputCols)
    +    val outputColNames = $(outputCols)
    +    val inputFields = schema.fields
    +    require(inputColNames.length == outputColNames.length,
    +      s"The number of input columns ${inputColNames.length} must be the 
same as the number of " +
    +        s"output columns ${outputColNames.length}.")
    +    require(inputColNames.length == categorySizes.length,
    +      s"The number of input columns ${inputColNames.length} must be the 
same as the number of " +
    +        s"features ${categorySizes.length} during fitting.")
    +    val inputOutputPairs =
    +    val outputFields = { case (inputColName, 
outputColName) =>
    +      require(schema(inputColName).dataType.isInstanceOf[NumericType],
    +        s"Input column must be of type NumericType but got 
    +      require(!inputFields.exists( == outputColName),
    +        s"Output column $outputColName already exists.")
    +      OneHotEncoderCommon.transformOutputColumnSchema(
    +        schema(inputColName), $(dropLast), outputColName)
    +    }
    +    verifyNumOfValues(StructType(inputFields ++ outputFields))
    +  }
    +  private def verifyNumOfValues(schema: StructType): StructType = {
    +    $(outputCols).zipWithIndex.foreach { case (outputColName, idx) =>
    +      val inputColName = $(inputCols)(idx)
    +      val attrGroup = AttributeGroup.fromStructField(schema(outputColName))
    +      // If the input metadata specifies number of category,
    +      // compare with expected category number.
    +      if (attrGroup.attributes.nonEmpty) {
    +        require(attrGroup.size == categorySizes(idx), "OneHotEncoderModel 
expected " +
    +          s"${categorySizes(idx)} categorical values for input column 
${inputColName}, but " +
    +            s"the input column had metadata specifying ${attrGroup.size} 
    +      }
    +    }
    +    schema
    +  }
    +  @Since("2.3.0")
    +  override def transform(dataset: Dataset[_]): DataFrame = {
    +    if (getDropLast && getHandleInvalid == 
OneHotEncoderEstimator.SKIP_INVALID) {
    +      throw new IllegalArgumentException("When Param handleInvalid is set 
to " +
    +        s"${OneHotEncoderEstimator.SKIP_INVALID}, Param dropLast can't be 
true, " +
    +        "because last category and invalid values will conflict in encoded 
    +    }
    +    val transformedSchema = transformSchema(dataset.schema, logging = true)
    +    val encodedColumns = { case (encoder, idx) =>
    +      val inputColName = $(inputCols)(idx)
    +      val outputColName = $(outputCols)(idx)
    +      val outputAttrGroupFromSchema =
    +        AttributeGroup.fromStructField(transformedSchema(outputColName))
    +      val metadata = if (outputAttrGroupFromSchema.size < 0) {
    +        OneHotEncoderCommon.createAttrGroupForAttrNames(outputColName, 
    +          categorySizes(idx)).toMetadata()
    +      } else {
    +        outputAttrGroupFromSchema.toMetadata()
    +      }
    +      encoder(col(inputColName).cast(DoubleType)).as(outputColName, 
    +    }
    +    val allCols = Seq(col("*")) ++ encodedColumns
    + _*)
    +  }
    +  @Since("2.3.0")
    +  override def copy(extra: ParamMap): OneHotEncoderModel = {
    +    val copied = new OneHotEncoderModel(uid, categorySizes)
    +    copyValues(copied, extra).setParent(parent)
    +  }
    +  @Since("2.3.0")
    +  override def write: MLWriter = new OneHotEncoderModelWriter(this)
    +object OneHotEncoderModel extends MLReadable[OneHotEncoderModel] {
    +  private[OneHotEncoderModel]
    +  class OneHotEncoderModelWriter(instance: OneHotEncoderModel) extends 
MLWriter {
    +    private case class Data(categorySizes: Array[Int])
    +    override protected def saveImpl(path: String): Unit = {
    +      DefaultParamsWriter.saveMetadata(instance, path, sc)
    +      val data = Data(instance.categorySizes)
    +      val dataPath = new Path(path, "data").toString
    +    }
    +  }
    +  private class OneHotEncoderModelReader extends 
MLReader[OneHotEncoderModel] {
    +    private val className = classOf[OneHotEncoderModel].getName
    +    override def load(path: String): OneHotEncoderModel = {
    +      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
    +      val dataPath = new Path(path, "data").toString
    +      val data =
    +        .select("categorySizes")
    +        .head()
    +      val categorySizes = data.getAs[Seq[Int]](0).toArray
    +      val model = new OneHotEncoderModel(metadata.uid, categorySizes)
    +      DefaultParamsReader.getAndSetParams(model, metadata)
    +      model
    +    }
    +  }
    +  @Since("2.3.0")
    +  override def read: MLReader[OneHotEncoderModel] = new 
    +  @Since("2.3.0")
    +  override def load(path: String): OneHotEncoderModel = super.load(path)
    + * Provides some helper methods used by both `OneHotEncoder` and `OneHotEncoderEstimator`.
    + */
    +private[feature] object OneHotEncoderCommon {
    +  private def genOutputAttrNames(
    +      inputCol: StructField,
    +      outputColName: String): Option[Array[String]] = {
    +    val inputAttr = Attribute.fromStructField(inputCol)
    +    inputAttr match {
    +      case nominal: NominalAttribute =>
    +        if (nominal.values.isDefined) {
    +          nominal.values
    +        } else if (nominal.numValues.isDefined) {
    + => Array.tabulate(n)(_.toString))
    +        } else {
    +          None
    +        }
    +      case binary: BinaryAttribute =>
    +        if (binary.values.isDefined) {
    +          binary.values
    +        } else {
    +          Some(Array.tabulate(2)(_.toString))
    +        }
    +      case _: NumericAttribute =>
    +        throw new RuntimeException(
    +          s"The input column ${} cannot be numeric.")
    +      case _ =>
    +        None // optimistic about unknown attributes
    +    }
    +  }
    +  /** Creates an `AttributeGroup` filled by the `BinaryAttribute` named as 
required. */
    +  private def genOutputAttrGroup(
    +      outputAttrNames: Option[Array[String]],
    +      outputColName: String): AttributeGroup = {
    + { attrNames =>
    +      val attrs: Array[Attribute] = { name =>
    +        BinaryAttribute.defaultAttr.withName(name)
    +      }
    +      new AttributeGroup(outputColName, attrs)
    +    }.getOrElse{
    +      new AttributeGroup(outputColName)
    +    }
    +  }
    +  /**
    +   * Prepares the `StructField` with proper metadata for `OneHotEncoder`'s 
output column.
    +   */
    +  def transformOutputColumnSchema(
    +      inputCol: StructField,
    +      dropLast: Boolean,
    +      outputColName: String): StructField = {
    +    val outputAttrNames = genOutputAttrNames(inputCol, outputColName)
    +    val filteredOutputAttrNames = { names =>
    +      if (dropLast) {
    +        require(names.length > 1,
    +          s"The input column ${} should have at least two 
distinct values.")
    +        names.dropRight(1)
    +      } else {
    +        names
    +      }
    +    }
    +    genOutputAttrGroup(filteredOutputAttrNames, 
    +  }
    +  /**
    +   * This method is called when we want to generate `AttributeGroup` from 
actual data for
    +   * one-hot encoder.
    +   */
    +  def getOutputAttrGroupFromData(
    --- End diff --
    The RDD approach has the advantage of stopping early as soon as any value is
invalid. DataFrame operations don't seem to offer an equivalent.
    With DataFrames we can only check whether the max values exceed Int.MaxValue
after the aggregation has finished. We would also need to compute a `min`
aggregation for each column so that we can check whether any values are negative.
    So for now I plan to change this to a multi-column version but keep the RDD
approach. Does that sound good to you?


To unsubscribe, e-mail:
For additional commands, e-mail:

Reply via email to