This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.1 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push: new 56f93e5 [SPARK-34080][ML][PYTHON][FOLLOWUP] Add UnivariateFeatureSelector - make methods private 56f93e5 is described below commit 56f93e56ab731be27a05a299fcbe0ef529f280ba Author: Ruifeng Zheng <ruife...@foxmail.com> AuthorDate: Mon Jan 18 13:19:59 2021 +0900 [SPARK-34080][ML][PYTHON][FOLLOWUP] Add UnivariateFeatureSelector - make methods private ### What changes were proposed in this pull request? 1. make `getTopIndices`/`selectIndicesFromPValues` private; 2. avoid setting `selectionThreshold` in `fit`; 3. move param checking to `transformSchema`. ### Why are the changes needed? `getTopIndices`/`selectIndicesFromPValues` should not be exposed to end users. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? existing test suites Closes #31222 from zhengruifeng/selector_clean_up. Authored-by: Ruifeng Zheng <ruife...@foxmail.com> Signed-off-by: HyukjinKwon <gurwls...@apache.org> (cherry picked from commit ac322a1ac3be79b5e514f0119275f53b3a40c923) Signed-off-by: HyukjinKwon <gurwls...@apache.org> --- .../ml/feature/UnivariateFeatureSelector.scala | 74 ++++++++-------------- 1 file changed, 27 insertions(+), 47 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala index 6d5f09e..bfe1d5f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/UnivariateFeatureSelector.scala @@ -76,8 +76,7 @@ private[feature] trait UnivariateFeatureSelectorParams extends Params @Since("3.1.1") final val selectionMode = new Param[String](this, "selectionMode", "The selection mode. 
Supported options: numTopFeatures, percentile, fpr, fdr, fwe", - ParamValidators.inArray(Array("numTopFeatures", "percentile", "fpr", "fdr", - "fwe"))) + ParamValidators.inArray(Array("numTopFeatures", "percentile", "fpr", "fdr", "fwe"))) /** @group getParam */ @Since("3.1.1") @@ -161,48 +160,17 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v transformSchema(dataset.schema, logging = true) val numFeatures = MetadataUtils.getNumFeatures(dataset, $(featuresCol)) - $(selectionMode) match { - case ("numTopFeatures") => - if (!isSet(selectionThreshold)) { - set(selectionThreshold, 50.0) - } else { - require($(selectionThreshold) > 0 && $(selectionThreshold).toInt == $(selectionThreshold), - "selectionThreshold needs to be a positive Integer for selection mode numTopFeatures") - } - case ("percentile") => - if (!isSet(selectionThreshold)) { - set(selectionThreshold, 0.1) - } else { - require($(selectionThreshold) >= 0 && $(selectionThreshold) <= 1, - "selectionThreshold needs to be in the range of 0 to 1 for selection mode percentile") - } - case ("fpr") => - if (!isSet(selectionThreshold)) { - set(selectionThreshold, 0.05) - } else { - require($(selectionThreshold) >= 0 && $(selectionThreshold) <= 1, - "selectionThreshold needs to be in the range of 0 to 1 for selection mode fpr") - } - case ("fdr") => - if (!isSet(selectionThreshold)) { - set(selectionThreshold, 0.05) - } else { - require($(selectionThreshold) >= 0 && $(selectionThreshold) <= 1, - "selectionThreshold needs to be in the range of 0 to 1 for selection mode fdr") - } - case ("fwe") => - if (!isSet(selectionThreshold)) { - set(selectionThreshold, 0.05) - } else { - require($(selectionThreshold) >= 0 && $(selectionThreshold) <= 1, - "selectionThreshold needs to be in the range of 0 to 1 for selection mode fwe") - } - case _ => - throw new IllegalArgumentException(s"Unsupported selection mode:" + - s" selectionMode=${$(selectionMode)}") + var threshold = Double.NaN + if 
(isSet(selectionThreshold)) { + threshold = $(selectionThreshold) + } else { + $(selectionMode) match { + case "numTopFeatures" => threshold = 50 + case "percentile" => threshold = 0.1 + case "fpr" | "fdr" | "fwe" => threshold = 0.05 + } } - require(isSet(featureType) && isSet(labelType), "featureType and labelType need to be set") val resultDF = ($(featureType), $(labelType)) match { case ("categorical", "categorical") => ChiSquareTest.test(dataset.toDF, getFeaturesCol, getLabelCol, true) @@ -215,14 +183,12 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v s" featureType=${$(featureType)}, labelType=${$(labelType)}") } - val indices = - selectIndicesFromPValues(numFeatures, resultDF, $(selectionMode), $(selectionThreshold)) - + val indices = selectIndicesFromPValues(numFeatures, resultDF, $(selectionMode), threshold) copyValues(new UnivariateFeatureSelectorModel(uid, indices) .setParent(this)) } - def getTopIndices(df: DataFrame, k: Int): Array[Int] = { + private def getTopIndices(df: DataFrame, k: Int): Array[Int] = { val spark = SparkSession.builder().getOrCreate() import spark.implicits._ df.sort("pValue", "featureIndex") @@ -232,7 +198,7 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v .collect() } - def selectIndicesFromPValues( + private[feature] def selectIndicesFromPValues( numFeatures: Int, resultDF: DataFrame, selectionMode: String, @@ -276,6 +242,20 @@ final class UnivariateFeatureSelector @Since("3.1.1")(@Since("3.1.1") override v @Since("3.1.1") override def transformSchema(schema: StructType): StructType = { + if (isSet(selectionThreshold)) { + val threshold = $(selectionThreshold) + $(selectionMode) match { + case "numTopFeatures" => + require(threshold >= 1 && threshold.toInt == threshold, + s"selectionThreshold needs to be a positive Integer for selection mode " + + s"numTopFeatures, but got $threshold") + case "percentile" | "fpr" | "fdr" | "fwe" => + require(0 <= threshold && 
threshold <= 1, + s"selectionThreshold needs to be in the range [0, 1] for selection mode " + + s"${$(selectionMode)}, but got $threshold") + } + } + require(isSet(featureType) && isSet(labelType), "featureType and labelType need to be set") SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.checkNumericType(schema, $(labelCol)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org