Repository: spark
Updated Branches:
  refs/heads/master d1d5069aa -> 0557a4545


[SPARK-16750][ML] Fix GaussianMixture training failure caused by a feature 
column type mistake

## What changes were proposed in this pull request?
ML ```GaussianMixture``` training failed because of a feature column type mistake: 
the feature column type should be ```ml.linalg.VectorUDT```, but the schema check 
mistakenly expected ```mllib.linalg.VectorUDT```.
See [SPARK-16750](https://issues.apache.org/jira/browse/SPARK-16750) for how to 
reproduce this bug.
Why didn't the unit tests catch this error? Because some 
estimators/transformers did not call ```transformSchema(dataset.schema)``` 
at the start of ```fit``` or ```transform```. This PR also adds that call to 
every estimator/transformer that was missing it. In addition, 
```QuantileDiscretizer.transformSchema``` now accepts any numeric input column 
instead of requiring ```DoubleType```.

## How was this patch tested?
No new tests were added; the existing tests should pass.

Author: Yanbo Liang <[email protected]>

Closes #14378 from yanboliang/spark-16750.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0557a454
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0557a454
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0557a454

Branch: refs/heads/master
Commit: 0557a45452f6e73877e5ec972110825ce8f3fbc5
Parents: d1d5069
Author: Yanbo Liang <[email protected]>
Authored: Fri Jul 29 04:40:20 2016 -0700
Committer: Sean Owen <[email protected]>
Committed: Fri Jul 29 04:40:20 2016 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/clustering/BisectingKMeans.scala     | 2 ++
 .../org/apache/spark/ml/clustering/GaussianMixture.scala     | 8 +++++---
 .../main/scala/org/apache/spark/ml/clustering/KMeans.scala   | 2 ++
 .../main/scala/org/apache/spark/ml/feature/Interaction.scala | 1 +
 .../scala/org/apache/spark/ml/feature/MinMaxScaler.scala     | 1 +
 .../org/apache/spark/ml/feature/QuantileDiscretizer.scala    | 3 ++-
 .../main/scala/org/apache/spark/ml/feature/RFormula.scala    | 1 +
 .../scala/org/apache/spark/ml/feature/SQLTransformer.scala   | 1 +
 .../apache/spark/ml/regression/AFTSurvivalRegression.scala   | 4 ++--
 .../org/apache/spark/ml/regression/IsotonicRegression.scala  | 3 ++-
 10 files changed, 19 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0557a454/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index afb1080..a97bd0f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -99,6 +99,7 @@ class BisectingKMeansModel private[ml] (
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
     val predictUDF = udf((vector: Vector) => predict(vector))
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
@@ -222,6 +223,7 @@ class BisectingKMeans @Since("2.0.0") (
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): BisectingKMeansModel = {
+    transformSchema(dataset.schema, logging = true)
     val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/0557a454/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 8174905..69f060a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -30,7 +30,7 @@ import 
org.apache.spark.ml.stat.distribution.MultivariateGaussian
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM}
 import org.apache.spark.mllib.linalg.{Matrices => OldMatrices, Matrix => 
OldMatrix,
-  Vector => OldVector, Vectors => OldVectors, VectorUDT => OldVectorUDT}
+  Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions.{col, udf}
@@ -61,9 +61,9 @@ private[clustering] trait GaussianMixtureParams extends 
Params with HasMaxIter w
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(featuresCol), new OldVectorUDT)
+    SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
-    SchemaUtils.appendColumn(schema, $(probabilityCol), new OldVectorUDT)
+    SchemaUtils.appendColumn(schema, $(probabilityCol), new VectorUDT)
   }
 }
 
@@ -95,6 +95,7 @@ class GaussianMixtureModel private[ml] (
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
     val predUDF = udf((vector: Vector) => predict(vector))
     val probUDF = udf((vector: Vector) => predictProbability(vector))
     dataset.withColumn($(predictionCol), predUDF(col($(featuresCol))))
@@ -317,6 +318,7 @@ class GaussianMixture @Since("2.0.0") (
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): GaussianMixtureModel = {
+    transformSchema(dataset.schema, logging = true)
     val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/0557a454/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala 
b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 9fb7d6a..6c46be7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -120,6 +120,7 @@ class KMeansModel private[ml] (
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
     val predictUDF = udf((vector: Vector) => predict(vector))
     dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
   }
@@ -304,6 +305,7 @@ class KMeans @Since("1.5.0") (
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): KMeansModel = {
+    transformSchema(dataset.schema, logging = true)
     val rdd: RDD[OldVector] = dataset.select(col($(featuresCol))).rdd.map {
       case Row(point: Vector) => OldVectors.fromML(point)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/0557a454/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index 7b11f86..96d0bde 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -68,6 +68,7 @@ class Interaction @Since("1.6.0") (@Since("1.6.0") override 
val uid: String) ext
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
     val inputFeatures = $(inputCols).map(c => dataset.schema(c))
     val featureEncoders = getFeatureEncoders(inputFeatures)
     val featureAttrs = getFeatureAttrs(inputFeatures)

http://git-wip-us.apache.org/repos/asf/spark/blob/0557a454/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 9ed8d83..068f11a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -170,6 +170,7 @@ class MinMaxScalerModel private[ml] (
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
     val originalRange = (originalMax.asBreeze - originalMin.asBreeze).toArray
     val minArray = originalMin.toArray
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0557a454/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 9a636bd..558a7bb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -97,7 +97,7 @@ final class QuantileDiscretizer @Since("1.6.0") 
(@Since("1.6.0") override val ui
 
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+    SchemaUtils.checkNumericType(schema, $(inputCol))
     val inputFields = schema.fields
     require(inputFields.forall(_.name != $(outputCol)),
       s"Output column ${$(outputCol)} already exists.")
@@ -108,6 +108,7 @@ final class QuantileDiscretizer @Since("1.6.0") 
(@Since("1.6.0") override val ui
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): Bucketizer = {
+    transformSchema(dataset.schema, logging = true)
     val splits = dataset.stat.approxQuantile($(inputCol),
       (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
     splits(0) = Double.NegativeInfinity

http://git-wip-us.apache.org/repos/asf/spark/blob/0557a454/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index c95dacf..2ee899b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -112,6 +112,7 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override 
val uid: String)
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): RFormulaModel = {
+    transformSchema(dataset.schema, logging = true)
     require(isDefined(formula), "Formula must be defined first.")
     val parsedFormula = RFormulaParser.parse($(formula))
     val resolvedFormula = parsedFormula.resolve(dataset.schema)

http://git-wip-us.apache.org/repos/asf/spark/blob/0557a454/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index 2890376..259be26 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -63,6 +63,7 @@ class SQLTransformer @Since("1.6.0") (@Since("1.6.0") 
override val uid: String)
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
     val tableName = Identifiable.randomUID(uid)
     dataset.createOrReplaceTempView(tableName)
     val realStatement = $(statement).replace(tableIdentifier, tableName)

http://git-wip-us.apache.org/repos/asf/spark/blob/0557a454/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 2b99126..d4ae59d 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -196,7 +196,7 @@ class AFTSurvivalRegression @Since("1.6.0") 
(@Since("1.6.0") override val uid: S
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
-    validateAndTransformSchema(dataset.schema, fitting = true)
+    transformSchema(dataset.schema, logging = true)
     val instances = extractAFTPoints(dataset)
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
     if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
@@ -326,7 +326,7 @@ class AFTSurvivalRegressionModel private[ml] (
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema)
+    transformSchema(dataset.schema, logging = true)
     val predictUDF = udf { features: Vector => predict(features) }
     val predictQuantilesUDF = udf { features: Vector => 
predictQuantiles(features)}
     if (hasQuantilesCol) {

http://git-wip-us.apache.org/repos/asf/spark/blob/0557a454/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala 
b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 3539644..cd7b4f2 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -164,7 +164,7 @@ class IsotonicRegression @Since("1.5.0") (@Since("1.5.0") 
override val uid: Stri
 
   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): IsotonicRegressionModel = {
-    validateAndTransformSchema(dataset.schema, fitting = true)
+    transformSchema(dataset.schema, logging = true)
     // Extract columns from data.  If dataset is persisted, do not persist 
oldDataset.
     val instances = extractWeightedLabeledPoints(dataset)
     val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
@@ -234,6 +234,7 @@ class IsotonicRegressionModel private[ml] (
 
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
     val predict = dataset.schema($(featuresCol)).dataType match {
       case DoubleType =>
         udf { feature: Double => oldModel.predict(feature) }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to