Github user huaxingao commented on a diff in the pull request: https://github.com/apache/spark/pull/19715#discussion_r150450222 --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala --- @@ -129,34 +152,95 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("2.1.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + /** @group setParam */ + @Since("2.3.0") + def setNumBucketsArray(value: Array[Int]): this.type = set(numBucketsArray, value) + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + private[feature] def isQuantileDiscretizeMultipleColumns(): Boolean = { + if (isSet(inputCols) && isSet(inputCol)) { + logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + + "`QuantileDiscretize` only map one column specified by `inputCol`") + false + } else if (isSet(inputCols)) { + true + } else { + false + } + } + + private[feature] def getInOutCols: (Array[String], Array[String]) = { + if (!isQuantileDiscretizeMultipleColumns) { + (Array($(inputCol)), Array($(outputCol))) + } else { + require($(inputCols).length == $(outputCols).length, + "inputCols number do not match outputCols") + ($(inputCols), $(outputCols)) + } + } + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkNumericType(schema, $(inputCol)) - val inputFields = schema.fields - require(inputFields.forall(_.name != $(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val attr = NominalAttribute.defaultAttr.withName($(outputCol)) - val outputFields = inputFields :+ attr.toStructField() + val (inputColNames, outputColNames) = getInOutCols + val existingFields = schema.fields + var outputFields = existingFields + 
inputColNames.zip(outputColNames).map { case (inputColName, outputColName) => + SchemaUtils.checkNumericType(schema, inputColName) + require(existingFields.forall(_.name != outputColName), + s"Output column ${outputColName} already exists.") + val attr = NominalAttribute.defaultAttr.withName(outputColName) + outputFields :+= attr.toStructField() + } StructType(outputFields) } @Since("2.0.0") override def fit(dataset: Dataset[_]): Bucketizer = { transformSchema(dataset.schema, logging = true) - val splits = dataset.stat.approxQuantile($(inputCol), - (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError)) + val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid)) + if (isQuantileDiscretizeMultipleColumns) { + var bucketArray = Array.empty[Int] + if (isSet(numBucketsArray)) { + bucketArray = $(numBucketsArray) + } + else { + bucketArray = Array($(numBuckets)) + } + val probabilityArray = bucketArray.toSeq.flatMap { numOfBucket => + (0.0 to 1.0 by 1.0 / numOfBucket) + } + val splitsArray = dataset.stat.approxQuantile($(inputCols), + probabilityArray.sorted.toArray.distinct, $(relativeError)) + val distinctSplitsArray = splitsArray.toSeq.map { splits => + getDistinctSplits(splits) + } + bucketizer.setSplitsArray(distinctSplitsArray.toArray) + copyValues(bucketizer.setParent(this)) + } + else { --- End diff -- Will fix this. And fix the same problem in another place.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org