This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 0cbe863e77c0 [SPARK-45547][ML] Validate Vectors with built-in function
0cbe863e77c0 is described below

commit 0cbe863e77c00e8987ddb170bdac5db4508173d7
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Oct 24 07:58:11 2023 +0800

    [SPARK-45547][ML] Validate Vectors with built-in function

    ### What changes were proposed in this pull request?
    Validate Vectors with built-in functions instead of UDFs.

    ### Why are the changes needed?
    With built-in functions, the validation logic might be optimized further.

    ### Does this PR introduce _any_ user-facing change?
    No.

    ### How was this patch tested?
    CI.

    ### Was this patch authored or co-authored using generative AI tooling?
    No.

    Closes #43380 from zhengruifeng/ml_vec_validate.

    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../spark/ml/classification/NaiveBayes.scala       | 23 +++--------
 .../apache/spark/ml/feature/VectorSizeHint.scala   | 47 +++++++++-------------
 .../org/apache/spark/ml/util/DatasetUtils.scala    | 12 +-----
 3 files changed, 27 insertions(+), 55 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 16176136a7e8..b7f9f97585fc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -156,38 +156,27 @@ class NaiveBayes @Since("1.5.0") (
     val validatedWeightCol = checkNonNegativeWeights(get(weightCol))
 
+    val vecCol = col($(featuresCol))
     val validatedfeaturesCol = $(modelType) match {
       case Multinomial | Complement =>
-        val checkNonNegativeVector = udf { vector: Vector =>
-          vector match {
-            case dv: DenseVector => dv.values.forall(v => v >= 0 && !v.isInfinity)
-            case sv: SparseVector => sv.values.forall(v => v >= 0 && !v.isInfinity)
-          }
-        }
-        val vecCol = col($(featuresCol))
         when(vecCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
-          .when(!checkNonNegativeVector(vecCol),
+          .when(exists(unwrap_udt(vecCol).getField("values"),
+            v => v.isNaN || v < 0 || v === Double.PositiveInfinity),
             raise_error(concat(
               lit("Vector values MUST NOT be Negative, NaN or Infinity, but got "),
               vecCol.cast(StringType))))
           .otherwise(vecCol)
 
       case Bernoulli =>
-        val checkBinaryVector = udf { vector: Vector =>
-          vector match {
-            case dv: DenseVector => dv.values.forall(v => v == 0 || v == 1)
-            case sv: SparseVector => sv.values.forall(v => v == 0 || v == 1)
-          }
-        }
-        val vecCol = col($(featuresCol))
         when(vecCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
-          .when(!checkBinaryVector(vecCol),
+          .when(exists(unwrap_udt(vecCol).getField("values"),
+            v => v =!= 0 && v =!= 1),
             raise_error(concat(
               lit("Vector values MUST be in {0, 1}, but got "),
               vecCol.cast(StringType))))
           .otherwise(vecCol)
 
-      case _ => checkNonNanVectors($(featuresCol))
+      case _ => checkNonNanVectors(vecCol)
     }
 
     val validated = dataset.select(
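For context: `unwrap_udt` exposes the serialized `VectorUDT` struct (fields `type`, `size`, `indices`, `values`), and the higher-order `exists` scans the stored `values` array without deserializing each row into a `Vector` object. The following is a minimal sketch of the Multinomial/Complement check above, meant for spark-shell on a Spark build recent enough to provide `unwrap_udt`; the `features` column name and the sample rows are assumptions for illustration, not part of the patch:

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType
import spark.implicits._

// Hypothetical input: one valid vector, one with a negative entry.
val df = Seq(
  Tuple1(Vectors.dense(1.0, 2.0, 0.0)),
  Tuple1(Vectors.sparse(3, Array(1), Array(-5.0)))
).toDF("features")

val vecCol = col("features")
// Same shape as the patch: fail on null vectors, then scan the serialized
// `values` array for negative, NaN or infinite entries.
val validated = when(vecCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
  .when(exists(unwrap_udt(vecCol).getField("values"),
      v => v.isNaN || v < 0 || v === Double.PositiveInfinity),
    raise_error(concat(
      lit("Vector values MUST NOT be Negative, NaN or Infinity, but got "),
      vecCol.cast(StringType))))
  .otherwise(vecCol)

df.select(validated.as("features")).show()  // the sparse row raises the error
```

Note that for sparse vectors `exists` only visits the stored non-zero entries; the implicit zeros are non-negative and finite, so skipping them does not change the outcome relative to the removed UDF.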
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
index 2cf440efae85..5c96d07e0ca9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
@@ -17,17 +17,16 @@
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.linalg.VectorUDT
 import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
-import org.apache.spark.sql.{Column, DataFrame, Dataset}
-import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{StringType, StructType}
 
 /**
  * A feature transformer that adds size information to the metadata of a vector column.
@@ -104,33 +103,25 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
     if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size == localSize) {
       dataset.toDF()
     } else {
-      val newCol: Column = localHandleInvalid match {
-        case VectorSizeHint.OPTIMISTIC_INVALID => col(localInputCol)
+      val vecCol = col(localInputCol)
+      val sizeCol = coalesce(unwrap_udt(vecCol).getField("size"),
+        array_size(unwrap_udt(vecCol).getField("values")))
+      val newVecCol = localHandleInvalid match {
+        case VectorSizeHint.OPTIMISTIC_INVALID => vecCol
         case VectorSizeHint.ERROR_INVALID =>
-          val checkVectorSizeUDF = udf { vector: Vector =>
-            if (vector == null) {
-              throw new SparkException(s"Got null vector in VectorSizeHint, set `handleInvalid` " +
-                s"to 'skip' to filter invalid rows.")
-            }
-            if (vector.size != localSize) {
-              throw new SparkException(s"VectorSizeHint Expecting a vector of size $localSize but" +
-                s" got ${vector.size}")
-            }
-            vector
-          }.asNondeterministic()
-          checkVectorSizeUDF(col(localInputCol))
+          when(vecCol.isNull, raise_error(
+            lit("Got null vector in VectorSizeHint, set `handleInvalid` to 'skip' to " +
+              "filter invalid rows.")))
+            .when(sizeCol =!= localSize, raise_error(concat(
+              lit(s"VectorSizeHint Expecting a vector of size $localSize but got "),
+              sizeCol.cast(StringType))))
+            .otherwise(vecCol)
         case VectorSizeHint.SKIP_INVALID =>
-          val checkVectorSizeUDF = udf { vector: Vector =>
-            if (vector != null && vector.size == localSize) {
-              vector
-            } else {
-              null
-            }
-          }
-          checkVectorSizeUDF(col(localInputCol))
+          when(!vecCol.isNull && sizeCol === localSize, vecCol)
+            .otherwise(lit(null))
       }
-      val res = dataset.withColumn(localInputCol, newCol.as(localInputCol, newGroup.toMetadata()))
+      val res = dataset.withColumn(localInputCol, newVecCol, newGroup.toMetadata())
       if (localHandleInvalid == VectorSizeHint.SKIP_INVALID) {
         res.na.drop(Array(localInputCol))
       } else {
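The `sizeCol` expression above leans on the `VectorUDT` layout: sparse vectors carry an explicit `size` field, while dense vectors leave it null and store every element in `values`, so a `coalesce` over the two recovers the size either way. A minimal spark-shell sketch of just that trick (the `features` column and rows are illustrative assumptions):

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.functions._
import spark.implicits._

val df = Seq(
  Tuple1(Vectors.dense(1.0, 2.0, 3.0)),                      // dense: `size` field is null
  Tuple1(Vectors.sparse(5, Array(0, 3), Array(1.0, -1.0)))   // sparse: `size` field is 5
).toDF("features")

val unwrapped = unwrap_udt(col("features"))
// Prefer the explicit sparse `size`; fall back to the dense `values` length.
val sizeCol = coalesce(unwrapped.getField("size"),
  array_size(unwrapped.getField("values")))

df.select(sizeCol.as("size")).show()  // expected: 3, then 5
```

Replacing the `.asNondeterministic()` UDFs with these expressions also turns the checks into plain Catalyst expressions, which is presumably what the commit message means by the logic "might be optimized further".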
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
index 08ecdaf0196c..b3cb9c7f2dd1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
@@ -83,7 +83,8 @@ private[spark] object DatasetUtils extends Logging {
 
   private[ml] def checkNonNanVectors(vectorCol: Column): Column = {
     when(vectorCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
-      .when(!validateVector(vectorCol),
+      .when(exists(unwrap_udt(vectorCol).getField("values"),
+        v => v.isNaN || v === Double.NegativeInfinity || v === Double.PositiveInfinity),
         raise_error(concat(lit("Vector values MUST NOT be NaN or Infinity, but got "),
           vectorCol.cast(StringType))))
       .otherwise(vectorCol)
@@ -93,15 +94,6 @@ private[spark] object DatasetUtils extends Logging {
     checkNonNanVectors(col(vectorCol))
   }
 
-  private lazy val validateVector = udf { vector: Vector =>
-    vector match {
-      case dv: DenseVector =>
-        dv.values.forall(v => !v.isNaN && !v.isInfinity)
-      case sv: SparseVector =>
-        sv.values.forall(v => !v.isNaN && !v.isInfinity)
-    }
-  }
-
   private[ml] def extractInstances(
       p: PredictorParams,
       df: Dataset[_],
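Finally, to see the rewritten helper end to end: since `checkNonNanVectors` is `private[ml]`, the sketch below restates it inline rather than calling `DatasetUtils`; the data and column name are again assumptions for illustration:

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.StringType
import spark.implicits._

// Inline restatement of the patched helper, for illustration only; the real
// method lives in org.apache.spark.ml.util.DatasetUtils and is private[ml].
def checkNonNanVectors(vectorCol: Column): Column =
  when(vectorCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
    .when(exists(unwrap_udt(vectorCol).getField("values"),
        v => v.isNaN || v === Double.NegativeInfinity || v === Double.PositiveInfinity),
      raise_error(concat(lit("Vector values MUST NOT be NaN or Infinity, but got "),
        vectorCol.cast(StringType))))
    .otherwise(vectorCol)

val ok  = Seq(Tuple1(Vectors.dense(0.0, 1.5))).toDF("features")
val bad = Seq(Tuple1(Vectors.dense(Double.NaN, 1.0))).toDF("features")

ok.select(checkNonNanVectors(col("features"))).show()   // rows pass through unchanged
bad.select(checkNonNanVectors(col("features"))).show()  // query fails with the raised error
```

Here too, only the stored `values` are scanned, which is sound because a sparse vector's implicit zeros can be neither NaN nor infinite.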