Repository: spark
Updated Branches:
  refs/heads/master 13190a4f6 -> d23dc5b8e


[SPARK-22346][ML] VectorSizeHint Transformer for using VectorAssembler in Structured Streaming

## What changes were proposed in this pull request?

A new VectorSizeHint transformer was added. It is meant to be used as a pipeline stage
ahead of VectorAssembler, on vector columns, so that VectorAssembler can assemble vectors
in a streaming context where the sizes of the input vectors are not otherwise known.
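
A minimal usage sketch, modeled on the streaming test in this patch (the column names `a`, `b` and the `streamingDF` variable are illustrative):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{VectorAssembler, VectorSizeHint}

// Each vector input column gets its own size hint so that VectorAssembler
// can compute its output schema without scanning the streaming data.
val sizeHintA = new VectorSizeHint().setInputCol("a").setSize(3)
val sizeHintB = new VectorSizeHint().setInputCol("b").setSize(4)

val assembler = new VectorAssembler()
  .setInputCols(Array("a", "b"))
  .setOutputCol("assembled")

val pipeline = new Pipeline().setStages(Array(sizeHintA, sizeHintB, assembler))
// pipeline.fit(streamingDF).transform(streamingDF) now works on a streaming DataFrame.
```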

## How was this patch tested?

Unit tests.

Author: Bago Amirbekian <b...@databricks.com>

Closes #19746 from MrBago/vector-size-hint.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d23dc5b8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d23dc5b8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d23dc5b8

Branch: refs/heads/master
Commit: d23dc5b8ef6c6aee0a31a304eefeb6ddb1c26c0f
Parents: 13190a4
Author: Bago Amirbekian <b...@databricks.com>
Authored: Fri Dec 22 14:05:57 2017 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Fri Dec 22 14:05:57 2017 -0800

----------------------------------------------------------------------
 .../spark/ml/feature/VectorSizeHint.scala       | 195 +++++++++++++++++++
 .../spark/ml/feature/VectorSizeHintSuite.scala  | 189 ++++++++++++++++++
 2 files changed, 384 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d23dc5b8/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
new file mode 100644
index 0000000..1fe3cfc
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSizeHint.scala
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.Transformer
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol}
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
+import org.apache.spark.sql.{Column, DataFrame, Dataset}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * :: Experimental ::
+ * A feature transformer that adds size information to the metadata of a vector column.
+ * VectorAssembler needs size information for its input columns and cannot be used on streaming
+ * dataframes without this metadata.
+ */
+@Experimental
+@Since("2.3.0")
+class VectorSizeHint @Since("2.3.0") (@Since("2.3.0") override val uid: String)
+  extends Transformer with HasInputCol with HasHandleInvalid with DefaultParamsWritable {
+
+  @Since("2.3.0")
+  def this() = this(Identifiable.randomUID("vectSizeHint"))
+
+  /**
+   * The size of Vectors in `inputCol`.
+   * @group param
+   */
+  @Since("2.3.0")
+  val size: IntParam = new IntParam(
+    this,
+    "size",
+    "Size of vectors in column.",
+    {s: Int => s >= 0})
+
+  /** @group getParam */
+  @Since("2.3.0")
+  def getSize: Int = getOrDefault(size)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setSize(value: Int): this.type = set(size, value)
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /**
+   * Param for how to handle invalid entries. Invalid vectors include nulls and vectors with the
+   * wrong size. The options are `skip` (filter out rows with invalid vectors), `error` (throw an
+   * error) and `optimistic` (do not check the vector size, and keep all rows). `error` by default.
+   *
+   * Note: Users should take care when setting this param to `optimistic`. The use of the
+   * `optimistic` option will prevent the transformer from validating the sizes of vectors in
+   * `inputCol`. A mismatch between the metadata of a column and its contents could result in
+   * unexpected behaviour or errors when using that column.
+   *
+   * @group param
+   */
+  @Since("2.3.0")
+  override val handleInvalid: Param[String] = new Param[String](
+    this,
+    "handleInvalid",
+    "How to handle invalid vectors in inputCol. Invalid vectors include nulls 
and vectors with " +
+      "the wrong size. The options are `skip` (filter out rows with invalid 
vectors), `error` " +
+      "(throw an error) and `optimistic` (do not check the vector size, and 
keep all rows). " +
+      "`error` by default.",
+    ParamValidators.inArray(VectorSizeHint.supportedHandleInvalids))
+
+  /** @group setParam */
+  @Since("2.3.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+  setDefault(handleInvalid, VectorSizeHint.ERROR_INVALID)
+
+  @Since("2.3.0")
+  override def transform(dataset: Dataset[_]): DataFrame = {
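+    // Copy params into local vals so the UDF closures below capture only these values,
+    // not the transformer itself.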
+    val localInputCol = getInputCol
+    val localSize = getSize
+    val localHandleInvalid = getHandleInvalid
+
+    val group = AttributeGroup.fromStructField(dataset.schema(localInputCol))
+    val newGroup = validateSchemaAndSize(dataset.schema, group)
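+    // Fast path: under `optimistic` handling, if the metadata already records the expected
+    // size there is nothing to validate or update.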
+    if (localHandleInvalid == VectorSizeHint.OPTIMISTIC_INVALID && group.size == localSize) {
+      dataset.toDF()
+    } else {
+      val newCol: Column = localHandleInvalid match {
+        case VectorSizeHint.OPTIMISTIC_INVALID => col(localInputCol)
+        case VectorSizeHint.ERROR_INVALID =>
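+          // asNondeterministic() below keeps the optimizer from reordering or eliding this check.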
+          val checkVectorSizeUDF = udf { vector: Vector =>
+            if (vector == null) {
+              throw new SparkException(s"Got null vector in VectorSizeHint, 
set `handleInvalid` " +
+                s"to 'skip' to filter invalid rows.")
+            }
+            if (vector.size != localSize) {
+              throw new SparkException(s"VectorSizeHint Expecting a vector of 
size $localSize but" +
+                s" got ${vector.size}")
+            }
+            vector
+          }.asNondeterministic()
+          checkVectorSizeUDF(col(localInputCol))
+        case VectorSizeHint.SKIP_INVALID =>
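+          // Map invalid vectors to null; the resulting nulls are dropped below via na.drop.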
+          val checkVectorSizeUDF = udf { vector: Vector =>
+            if (vector != null && vector.size == localSize) {
+              vector
+            } else {
+              null
+            }
+          }
+          checkVectorSizeUDF(col(localInputCol))
+      }
+
+      val res = dataset.withColumn(localInputCol, newCol.as(localInputCol, newGroup.toMetadata()))
+      if (localHandleInvalid == VectorSizeHint.SKIP_INVALID) {
+        res.na.drop(Array(localInputCol))
+      } else {
+        res
+      }
+    }
+  }
+
+  /**
+   * Checks that schema can be updated with new size and returns a new attribute group with
+   * updated size.
+   */
+  private def validateSchemaAndSize(schema: StructType, group: AttributeGroup): AttributeGroup = {
+    // This will throw a NoSuchElementException if params are not set.
+    val localSize = getSize
+    val localInputCol = getInputCol
+
+    val inputColType = schema(getInputCol).dataType
+    require(
+      inputColType.isInstanceOf[VectorUDT],
+      s"Input column, $getInputCol must be of Vector type, got $inputColType"
+    )
+    group.size match {
+      case `localSize` => group
+      case -1 => new AttributeGroup(localInputCol, localSize)
+      case _ =>
+        val msg = s"Trying to set size of vectors in `$localInputCol` to 
$localSize but size " +
+          s"already set to ${group.size}."
+        throw new IllegalArgumentException(msg)
+    }
+  }
+
+  @Since("2.3.0")
+  override def transformSchema(schema: StructType): StructType = {
+    val fieldIndex = schema.fieldIndex(getInputCol)
+    val fields = schema.fields.clone()
+    val inputField = fields(fieldIndex)
+    val group = AttributeGroup.fromStructField(inputField)
+    val newGroup = validateSchemaAndSize(schema, group)
+    fields(fieldIndex) = inputField.copy(metadata = newGroup.toMetadata())
+    StructType(fields)
+  }
+
+  @Since("2.3.0")
+  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+}
+
+/** :: Experimental :: */
+@Experimental
+@Since("2.3.0")
+object VectorSizeHint extends DefaultParamsReadable[VectorSizeHint] {
+
+  private[feature] val OPTIMISTIC_INVALID = "optimistic"
+  private[feature] val ERROR_INVALID = "error"
+  private[feature] val SKIP_INVALID = "skip"
+  private[feature] val supportedHandleInvalids: Array[String] =
+    Array(OPTIMISTIC_INVALID, ERROR_INVALID, SKIP_INVALID)
+
+  @Since("2.3.0")
+  override def load(path: String): VectorSizeHint = super.load(path)
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/d23dc5b8/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala
new file mode 100644
index 0000000..f6c9a76
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.StreamTest
+
+class VectorSizeHintSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  import testImplicits._
+
+  test("Test Param Validators") {
+    intercept[IllegalArgumentException] (new VectorSizeHint().setHandleInvalid("invalidValue"))
+    intercept[IllegalArgumentException] (new VectorSizeHint().setSize(-3))
+  }
+
+  test("Required params must be set before transform.") {
+    val data = Seq((Vectors.dense(1, 2), 0)).toDF("vector", "intValue")
+
+    val noSizeTransformer = new VectorSizeHint().setInputCol("vector")
+    intercept[NoSuchElementException] (noSizeTransformer.transform(data))
+    intercept[NoSuchElementException] (noSizeTransformer.transformSchema(data.schema))
+
+    val noInputColTransformer = new VectorSizeHint().setSize(2)
+    intercept[NoSuchElementException] (noInputColTransformer.transform(data))
+    intercept[NoSuchElementException] (noInputColTransformer.transformSchema(data.schema))
+  }
+
+  test("Adding size to column of vectors.") {
+
+    val size = 3
+    val vectorColName = "vector"
+    val denseVector = Vectors.dense(1, 2, 3)
+    val sparseVector = Vectors.sparse(size, Array(), Array())
+
+    val data = Seq(denseVector, denseVector, sparseVector).map(Tuple1.apply)
+    val dataFrame = data.toDF(vectorColName)
+    assert(
+      AttributeGroup.fromStructField(dataFrame.schema(vectorColName)).size == -1,
+      s"This test requires that column '$vectorColName' not have size metadata.")
+
+    for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
+      val transformer = new VectorSizeHint()
+        .setInputCol(vectorColName)
+        .setSize(size)
+        .setHandleInvalid(handleInvalid)
+      val withSize = transformer.transform(dataFrame)
+      assert(
+        AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size,
+        "Transformer did not add expected size data.")
+      val numRows = withSize.collect().length
+      assert(numRows === data.length, s"Expecting ${data.length} rows, got $numRows.")
+    }
+  }
+
+  test("Size hint preserves attributes.") {
+
+    val size = 3
+    val vectorColName = "vector"
+    val data = Seq((1, 2, 3), (2, 3, 3))
+    val dataFrame = data.toDF("x", "y", "z")
+
+    val assembler = new VectorAssembler()
+      .setInputCols(Array("x", "y", "z"))
+      .setOutputCol(vectorColName)
+    val dataFrameWithMetadata = assembler.transform(dataFrame)
+    val group = AttributeGroup.fromStructField(dataFrameWithMetadata.schema(vectorColName))
+
+    for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
+      val transformer = new VectorSizeHint()
+        .setInputCol(vectorColName)
+        .setSize(size)
+        .setHandleInvalid(handleInvalid)
+      val withSize = transformer.transform(dataFrameWithMetadata)
+
+      val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName))
+      assert(newGroup.size === size, "Column has incorrect size metadata.")
+      assert(
+        newGroup.attributes.get === group.attributes.get,
+        "VectorSizeHint did not preserve attributes.")
+      withSize.collect()
+    }
+  }
+
+  test("Size mismatch between current and target size raises an error.") {
+    val size = 4
+    val vectorColName = "vector"
+    val data = Seq((1, 2, 3), (2, 3, 3))
+    val dataFrame = data.toDF("x", "y", "z")
+
+    val assembler = new VectorAssembler()
+      .setInputCols(Array("x", "y", "z"))
+      .setOutputCol(vectorColName)
+    val dataFrameWithMetadata = assembler.transform(dataFrame)
+
+    for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
+      val transformer = new VectorSizeHint()
+        .setInputCol(vectorColName)
+        .setSize(size)
+        .setHandleInvalid(handleInvalid)
+      intercept[IllegalArgumentException](transformer.transform(dataFrameWithMetadata))
+    }
+  }
+
+  test("Handle invalid does the right thing.") {
+
+    val vector = Vectors.dense(1, 2, 3)
+    val short = Vectors.dense(2)
+    val dataWithNull = Seq(vector, null).map(Tuple1.apply).toDF("vector")
+    val dataWithShort = Seq(vector, short).map(Tuple1.apply).toDF("vector")
+
+    val sizeHint = new VectorSizeHint()
+      .setInputCol("vector")
+      .setHandleInvalid("error")
+      .setSize(3)
+
+    intercept[SparkException](sizeHint.transform(dataWithNull).collect())
+    intercept[SparkException](sizeHint.transform(dataWithShort).collect())
+
+    sizeHint.setHandleInvalid("skip")
+    assert(sizeHint.transform(dataWithNull).count() === 1)
+    assert(sizeHint.transform(dataWithShort).count() === 1)
+
+    sizeHint.setHandleInvalid("optimistic")
+    assert(sizeHint.transform(dataWithNull).count() === 2)
+    assert(sizeHint.transform(dataWithShort).count() === 2)
+  }
+
+  test("read/write") {
+    val sizeHint = new VectorSizeHint()
+      .setInputCol("myInputCol")
+      .setSize(11)
+      .setHandleInvalid("skip")
+    testDefaultReadWrite(sizeHint)
+  }
+}
+
+class VectorSizeHintStreamingSuite extends StreamTest {
+
+  import testImplicits._
+
+  test("Test assemble vectors with size hint in streaming.") {
+    val a = Vectors.dense(0, 1, 2)
+    val b = Vectors.sparse(4, Array(0, 3), Array(3, 6))
+
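+    // MemoryStream provides an in-memory source for feeding rows to a streaming query in tests.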
+    val stream = MemoryStream[(Vector, Vector)]
+    val streamingDF = stream.toDS.toDF("a", "b")
+    val sizeHintA = new VectorSizeHint()
+      .setSize(3)
+      .setInputCol("a")
+    val sizeHintB = new VectorSizeHint()
+      .setSize(4)
+      .setInputCol("b")
+    val vectorAssembler = new VectorAssembler()
+      .setInputCols(Array("a", "b"))
+      .setOutputCol("assembled")
+    val pipeline = new Pipeline().setStages(Array(sizeHintA, sizeHintB, vectorAssembler))
+    val output = pipeline.fit(streamingDF).transform(streamingDF).select("assembled")
+
+    val expected = Vectors.dense(0, 1, 2, 3, 0, 0, 6)
+
+    testStream(output)(
+      AddData(stream, (a, b), (a, b)),
+      CheckAnswer(Tuple1(expected), Tuple1(expected))
+    )
+  }
+}

