This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 0cbe863e77c0 [SPARK-45547][ML] Validate Vectors with built-in function
0cbe863e77c0 is described below

commit 0cbe863e77c00e8987ddb170bdac5db4508173d7
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Tue Oct 24 07:58:11 2023 +0800

    [SPARK-45547][ML] Validate Vectors with built-in function
    
    ### What changes were proposed in this pull request?
    Validate Vectors with the built-in SQL functions `unwrap_udt` and `exists` instead of Scala UDFs.
    
    ### Why are the changes needed?
    With built-in functions, the validation logic might be optimized further by Catalyst, and the per-row Scala UDF call (with its serialization overhead) goes away; see the sketch after the commit message.
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #43380 from zhengruifeng/ml_vec_validate.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
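
For context, the pattern this change adopts, shown as a minimal sketch rather
than code from the commit: the DataFrame and the `features` column name are
hypothetical, and `unwrap_udt` (which exposes a VectorUDT column as its
underlying struct of `type`, `size`, `indices`, and `values` fields) may be
package-private depending on the Spark version, so this may only compile
inside the Spark source tree.

    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.types.StringType

    // Old style: a Scala UDF deserializes every Vector on the JVM side,
    // opaque to the optimizer.
    // New style: validate the UDT's underlying `values` array with built-in
    // higher-order functions, visible to Catalyst end to end.
    val vecCol = col("features")  // hypothetical column name
    val validated =
      when(vecCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
        .when(exists(unwrap_udt(vecCol).getField("values"),
          v => v.isNaN || v === Double.NegativeInfinity || v === Double.PositiveInfinity),
          raise_error(concat(lit("Vector values MUST NOT be NaN or Infinity, but got "),
            vecCol.cast(StringType))))
        .otherwise(vecCol)

Note that `values` holds every element of a dense vector but only the stored
(non-zero) entries of a sparse vector; implicit zeros trivially satisfy all
three checks in this commit, so scanning `values` alone is sufficient.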
---
 .../spark/ml/classification/NaiveBayes.scala       | 23 +++--------
 .../apache/spark/ml/feature/VectorSizeHint.scala   | 47 +++++++++-------------
 .../org/apache/spark/ml/util/DatasetUtils.scala    | 12 +-----
 3 files changed, 27 insertions(+), 55 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 16176136a7e8..b7f9f97585fc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -156,38 +156,27 @@ class NaiveBayes @Since("1.5.0") (
 
     val validatedWeightCol = checkNonNegativeWeights(get(weightCol))
 
+    val vecCol = col($(featuresCol))
     val validatedfeaturesCol = $(modelType) match {
       case Multinomial | Complement =>
-        val checkNonNegativeVector = udf { vector: Vector =>
-          vector match {
-            case dv: DenseVector => dv.values.forall(v => v >= 0 && !v.isInfinity)
-            case sv: SparseVector => sv.values.forall(v => v >= 0 && !v.isInfinity)
-          }
-        }
-        val vecCol = col($(featuresCol))
         when(vecCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
-          .when(!checkNonNegativeVector(vecCol),
+          .when(exists(unwrap_udt(vecCol).getField("values"),
+            v => v.isNaN || v < 0 || v === Double.PositiveInfinity),
             raise_error(concat(
               lit("Vector values MUST NOT be Negative, NaN or Infinity, but 
got "),
               vecCol.cast(StringType))))
           .otherwise(vecCol)
 
       case Bernoulli =>
-        val checkBinaryVector = udf { vector: Vector =>
-          vector match {
-            case dv: DenseVector => dv.values.forall(v => v == 0 || v == 1)
-            case sv: SparseVector => sv.values.forall(v => v == 0 || v == 1)
-          }
-        }
-        val vecCol = col($(featuresCol))
         when(vecCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
-          .when(!checkBinaryVector(vecCol),
+          .when(exists(unwrap_udt(vecCol).getField("values"),
+            v => v =!= 0 && v =!= 1),
             raise_error(concat(
               lit("Vector values MUST be in {0, 1}, but got "),
               vecCol.cast(StringType))))
           .otherwise(vecCol)
 
-      case _ => checkNonNanVectors($(featuresCol))
+      case _ => checkNonNanVectors(vecCol)
     }
 
     val validated = dataset.select(
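
As a standalone illustration of the two checks above (helper names are
hypothetical, not part of this commit):

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions._

    // Multinomial/Complement: values must be finite and non-negative.
    // Negative infinity needs no separate test; it is already caught by v < 0.
    def hasNegNanOrInf(vecCol: Column): Column =
      exists(unwrap_udt(vecCol).getField("values"),
        v => v.isNaN || v < 0 || v === Double.PositiveInfinity)

    // Bernoulli: every stored value must be exactly 0 or 1. A sparse
    // vector's implicit zeros are absent from `values`, and 0 is valid anyway.
    def hasNonBinary(vecCol: Column): Column =
      exists(unwrap_udt(vecCol).getField("values"), v => v =!= 0 && v =!= 1)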
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
index 2cf440efae85..5c96d07e0ca9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
@@ -17,17 +17,16 @@
 
 package org.apache.spark.ml.feature
 
-import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.linalg.VectorUDT
 import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
 import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol}
 import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
-import org.apache.spark.sql.{Column, DataFrame, Dataset}
-import org.apache.spark.sql.functions.{col, udf}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{StringType, StructType}
 
 /**
 * A feature transformer that adds size information to the metadata of a vector column.
@@ -104,33 +103,25 @@ class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
     if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size == localSize) {
       dataset.toDF()
     } else {
-      val newCol: Column = localHandleInvalid match {
-        case VectorSizeHint.OPTIMISTIC_INVALID => col(localInputCol)
+      val vecCol = col(localInputCol)
+      val sizeCol = coalesce(unwrap_udt(vecCol).getField("size"),
+        array_size(unwrap_udt(vecCol).getField("values")))
+      val newVecCol = localHandleInvalid match {
+        case VectorSizeHint.OPTIMISTIC_INVALID => vecCol
         case VectorSizeHint.ERROR_INVALID =>
-          val checkVectorSizeUDF = udf { vector: Vector =>
-            if (vector == null) {
-              throw new SparkException(s"Got null vector in VectorSizeHint, 
set `handleInvalid` " +
-                s"to 'skip' to filter invalid rows.")
-            }
-            if (vector.size != localSize) {
-              throw new SparkException(s"VectorSizeHint Expecting a vector of 
size $localSize but" +
-                s" got ${vector.size}")
-            }
-            vector
-          }.asNondeterministic()
-          checkVectorSizeUDF(col(localInputCol))
+          when(vecCol.isNull, raise_error(
+            lit("Got null vector in VectorSizeHint, set `handleInvalid` to 
'skip' to " +
+              "filter invalid rows.")))
+            .when(sizeCol =!= localSize, raise_error(concat(
+              lit(s"VectorSizeHint Expecting a vector of size $localSize but 
got "),
+              sizeCol.cast(StringType))))
+            .otherwise(vecCol)
         case VectorSizeHint.SKIP_INVALID =>
-          val checkVectorSizeUDF = udf { vector: Vector =>
-            if (vector != null && vector.size == localSize) {
-              vector
-            } else {
-              null
-            }
-          }
-          checkVectorSizeUDF(col(localInputCol))
+          when(!vecCol.isNull && sizeCol === localSize, vecCol)
+            .otherwise(lit(null))
       }
 
-      val res = dataset.withColumn(localInputCol, newCol.as(localInputCol, newGroup.toMetadata()))
+      val res = dataset.withColumn(localInputCol, newVecCol, newGroup.toMetadata())
       if (localHandleInvalid == VectorSizeHint.SKIP_INVALID) {
         res.na.drop(Array(localInputCol))
       } else {
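
The `sizeCol` expression above leans on VectorUDT's storage layout: a sparse
vector populates the struct's `size` field, while a dense vector leaves `size`
null and its dimension is simply the length of `values`. A minimal sketch of
the same computation (assuming Spark 3.3+ for `array_size`; names hypothetical):

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.functions._

    // Recover a vector's dimension without deserializing the Vector object.
    def vectorSize(vecCol: Column): Column = {
      val s = unwrap_udt(vecCol)
      coalesce(s.getField("size"), array_size(s.getField("values")))
    }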
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
index 08ecdaf0196c..b3cb9c7f2dd1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/DatasetUtils.scala
@@ -83,7 +83,8 @@ private[spark] object DatasetUtils extends Logging {
 
   private[ml] def checkNonNanVectors(vectorCol: Column): Column = {
     when(vectorCol.isNull, raise_error(lit("Vectors MUST NOT be Null")))
-      .when(!validateVector(vectorCol),
+      .when(exists(unwrap_udt(vectorCol).getField("values"),
+        v => v.isNaN || v === Double.NegativeInfinity || v === Double.PositiveInfinity),
         raise_error(concat(lit("Vector values MUST NOT be NaN or Infinity, but 
got "),
           vectorCol.cast(StringType))))
       .otherwise(vectorCol)
@@ -93,15 +94,6 @@ private[spark] object DatasetUtils extends Logging {
     checkNonNanVectors(col(vectorCol))
   }
 
-  private lazy val validateVector = udf { vector: Vector =>
-    vector match {
-      case dv: DenseVector =>
-        dv.values.forall(v => !v.isNaN && !v.isInfinity)
-      case sv: SparseVector =>
-        sv.values.forall(v => !v.isNaN && !v.isInfinity)
-    }
-  }
-
   private[ml] def extractInstances(
       p: PredictorParams,
       df: Dataset[_],
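
To exercise the new check end to end, a hedged usage sketch (it assumes an
active SparkSession `spark` and, because checkNonNanVectors is private[ml],
a caller inside the org.apache.spark.ml package, e.g. a test in the Spark
tree; data values are illustrative):

    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.util.DatasetUtils.checkNonNanVectors
    import org.apache.spark.sql.functions.col

    val df = spark.createDataFrame(Seq(
      Tuple1(Vectors.dense(1.0, 2.0)),
      Tuple1(Vectors.dense(Double.NaN, 0.0))
    )).toDF("features")

    // The first row passes through unchanged; evaluating the second row
    // raises an error whose message embeds the offending vector, e.g.
    // "Vector values MUST NOT be NaN or Infinity, but got [NaN,0.0]".
    df.select(checkNonNanVectors(col("features")).as("features")).show()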

