Github user MLnick commented on a diff in the pull request: https://github.com/apache/spark/pull/19715#discussion_r155759105 --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala --- @@ -129,34 +156,106 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui @Since("2.1.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + /** @group setParam */ + @Since("2.3.0") + def setNumBucketsArray(value: Array[Int]): this.type = set(numBucketsArray, value) + + /** @group setParam */ + @Since("2.3.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + private[feature] def isQuantileDiscretizeMultipleColumns(): Boolean = { + if (isSet(inputCols) && isSet(inputCol)) { + logWarning("Both `inputCol` and `inputCols` are set, we ignore `inputCols` and this " + + "`QuantileDiscretizer` will only map one column specified by `inputCol`") + false + } else if (isSet(inputCols)) { + true + } else { + false + } + } + + private[feature] def getInOutCols: (Array[String], Array[String]) = { + if (!isQuantileDiscretizeMultipleColumns) { + (Array($(inputCol)), Array($(outputCol))) + } else { + require($(inputCols).length == $(outputCols).length, + "inputCols number do not match outputCols") + ($(inputCols), $(outputCols)) + } + } + @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkNumericType(schema, $(inputCol)) - val inputFields = schema.fields - require(inputFields.forall(_.name != $(outputCol)), - s"Output column ${$(outputCol)} already exists.") - val attr = NominalAttribute.defaultAttr.withName($(outputCol)) - val outputFields = inputFields :+ attr.toStructField() + val (inputColNames, outputColNames) = getInOutCols + val existingFields = schema.fields + var outputFields = existingFields + 
inputColNames.zip(outputColNames).map { case (inputColName, outputColName) => + SchemaUtils.checkNumericType(schema, inputColName) + require(existingFields.forall(_.name != outputColName), + s"Output column ${outputColName} already exists.") + val attr = NominalAttribute.defaultAttr.withName(outputColName) + outputFields :+= attr.toStructField() + } StructType(outputFields) } @Since("2.0.0") override def fit(dataset: Dataset[_]): Bucketizer = { transformSchema(dataset.schema, logging = true) - val splits = dataset.stat.approxQuantile($(inputCol), - (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError)) + val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid)) --- End diff -- Looking at this now, the `Array.fill` approach probably adds needless complexity. But the multi-buckets case can perhaps still be cleaned up. How about something like this: ```scala override def fit(dataset: Dataset[_]): Bucketizer = { transformSchema(dataset.schema, logging = true) val bucketizer = new Bucketizer(uid).setHandleInvalid($(handleInvalid)) if (isQuantileDiscretizeMultipleColumns) { val splitsArray = if (isSet(numBucketsArray)) { val probArrayPerCol = $(numBucketsArray).map { numOfBuckets => (0.0 to 1.0 by 1.0 / numOfBuckets).toArray } val probabilityArray = probArrayPerCol.flatten.sorted.distinct val splitsArrayRaw = dataset.stat.approxQuantile($(inputCols), probabilityArray, $(relativeError)) splitsArrayRaw.zip(probArrayPerCol).map { case (splits, probs) => val probSet = probs.toSet val idxSet = probabilityArray.zipWithIndex.collect { case (p, idx) if probSet(p) => idx }.toSet splits.zipWithIndex.collect { case (s, idx) if idxSet(idx) => s } } } else { dataset.stat.approxQuantile($(inputCols), (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError)) } bucketizer.setSplitsArray(splitsArray.map(getDistinctSplits)) } else { val splits = dataset.stat.approxQuantile($(inputCol), (0.0 to 1.0 by 1.0 / $(numBuckets)).toArray, $(relativeError)) 
bucketizer.setSplits(getDistinctSplits(splits)) } copyValues(bucketizer.setParent(this)) } ``` Then we don't need `getSplitsForEachColumn` method (or part of the above could be factored out into a private method if it makes sense).
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org