Repository: spark Updated Branches: refs/heads/branch-2.3 f891ee324 -> 2ec302658
[SPARK-23046][ML][SPARKR] Have RFormula include VectorSizeHint in pipeline ## What changes were proposed in this pull request? Including VectorSizeHint in RFormula pipelines will allow them to be applied to streaming dataframes. ## How was this patch tested? Unit tests. Author: Bago Amirbekian <b...@databricks.com> Closes #20238 from MrBago/rFormulaVectorSize. (cherry picked from commit 186bf8fb2e9ff8a80f3f6bcb5f2a0327fa79a1c9) Signed-off-by: Joseph K. Bradley <jos...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2ec30265 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2ec30265 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2ec30265 Branch: refs/heads/branch-2.3 Commit: 2ec302658c98038962c9b7a90fd2cff751a35ffa Parents: f891ee3 Author: Bago Amirbekian <b...@databricks.com> Authored: Thu Jan 11 13:57:15 2018 -0800 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Thu Jan 11 13:57:27 2018 -0800 ---------------------------------------------------------------------- R/pkg/R/mllib_utils.R | 1 + .../org/apache/spark/ml/feature/RFormula.scala | 18 ++++++++-- .../apache/spark/ml/feature/RFormulaSuite.scala | 37 +++++++++++++++++--- 3 files changed, 48 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2ec30265/R/pkg/R/mllib_utils.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/mllib_utils.R b/R/pkg/R/mllib_utils.R index a53c92c..23dda42 100644 --- a/R/pkg/R/mllib_utils.R +++ b/R/pkg/R/mllib_utils.R @@ -130,3 +130,4 @@ read.ml <- function(path) { stop("Unsupported model: ", jobj) } } + http://git-wip-us.apache.org/repos/asf/spark/blob/2ec30265/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala ---------------------------------------------------------------------- diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 7da3339..f384ffb 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.linalg.VectorUDT +import org.apache.spark.ml.linalg.{Vector, VectorUDT} import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasHandleInvalid, HasLabelCol} import org.apache.spark.ml.util._ @@ -210,8 +210,8 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) // First we index each string column referenced by the input terms. 
val indexed: Map[String, String] = resolvedFormula.terms.flatten.distinct.map { term => - dataset.schema(term) match { - case column if column.dataType == StringType => + dataset.schema(term).dataType match { + case _: StringType => val indexCol = tmpColumn("stridx") encoderStages += new StringIndexer() .setInputCol(term) @@ -220,6 +220,18 @@ class RFormula @Since("1.5.0") (@Since("1.5.0") override val uid: String) .setHandleInvalid($(handleInvalid)) prefixesToRewrite(indexCol + "_") = term + "_" (term, indexCol) + case _: VectorUDT => + val group = AttributeGroup.fromStructField(dataset.schema(term)) + val size = if (group.size < 0) { + dataset.select(term).first().getAs[Vector](0).size + } else { + group.size + } + encoderStages += new VectorSizeHint(uid) + .setHandleInvalid("optimistic") + .setInputCol(term) + .setSize(size) + (term, term) case _ => (term, term) } http://git-wip-us.apache.org/repos/asf/spark/blob/2ec30265/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index 5d09c90..f3f4b5a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -17,15 +17,15 @@ package org.apache.spark.ml.feature -import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.SparkException import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.linalg.Vectors +import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils} +import org.apache.spark.sql.{DataFrame, 
Encoder, Row} import org.apache.spark.sql.types.DoubleType -class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { +class RFormulaSuite extends MLTest with DefaultReadWriteTest { import testImplicits._ @@ -548,4 +548,31 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul assert(result3.collect() === expected3.collect()) assert(result4.collect() === expected4.collect()) } + + test("Use Vectors as inputs to formula.") { + val original = Seq( + (1, 4, Vectors.dense(0.0, 0.0, 4.0)), + (2, 4, Vectors.dense(1.0, 0.0, 4.0)), + (3, 5, Vectors.dense(1.0, 0.0, 5.0)), + (4, 5, Vectors.dense(0.0, 1.0, 5.0)) + ).toDF("id", "a", "b") + val formula = new RFormula().setFormula("id ~ a + b") + val (first +: rest) = Seq("id", "a", "b", "features", "label") + testTransformer[(Int, Int, Vector)](original, formula.fit(original), first, rest: _*) { + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => + assert(label === id) + assert(features.toArray === a +: b.toArray) + } + + val group = new AttributeGroup("b", 3) + val vectorColWithMetadata = original("b").as("b", group.toMetadata()) + val dfWithMetadata = original.withColumn("b", vectorColWithMetadata) + val model = formula.fit(dfWithMetadata) + // model should work even when applied to dataframe without metadata. + testTransformer[(Int, Int, Vector)](original, model, first, rest: _*) { + case Row(id: Int, a: Int, b: Vector, features: Vector, label: Double) => + assert(label === id) + assert(features.toArray === a +: b.toArray) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org