Repository: spark
Updated Branches:
  refs/heads/master 0ba3fdd59 -> 971b95b0c


http://git-wip-us.apache.org/repos/asf/spark/blob/971b95b0/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
new file mode 100644
index 0000000..0383bf0
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+
+/**
+ * :: DeveloperApi ::
+ * Utils for handling schemas.
+ */
+@DeveloperApi
+object SchemaUtils {
+
+  // TODO: Move the utility methods to SQL.
+
+  /**
+   * Check whether the given schema contains a column of the required data type.
+   * @param colName  column name
+   * @param dataType  required column data type
+   */
+  def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = {
+    val actualDataType = schema(colName).dataType
+    require(actualDataType.equals(dataType),
+      s"Column $colName must be of type $dataType but was actually 
$actualDataType.")
+  }
+
+  /**
+   * Appends a new column to the input schema. This fails if the given output column already exists.
+   * @param schema input schema
+   * @param colName new column name. If this column name is an empty string "", this method returns
+   *                the input schema unchanged. This allows users to disable output columns.
+   * @param dataType new column data type
+   * @return new schema with the input column appended
+   */
+  def appendColumn(
+      schema: StructType,
+      colName: String,
+      dataType: DataType): StructType = {
+    if (colName.isEmpty) return schema
+    val fieldNames = schema.fieldNames
+    require(!fieldNames.contains(colName), s"Column $colName already exists.")
+    val outputFields = schema.fields :+ StructField(colName, dataType, nullable = false)
+    StructType(outputFields)
+  }
+}
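
For context, a minimal sketch of how these helpers might be called from an ML pipeline stage's schema-validation path. The schema and column names below are illustrative only, not part of this commit:

    import org.apache.spark.ml.util.SchemaUtils
    import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

    // Hypothetical input schema for illustration.
    val schema = StructType(Seq(StructField("features", DoubleType, nullable = false)))

    // Passes: "features" exists and has the required type; a mismatch
    // would fail the require() with a descriptive message.
    SchemaUtils.checkColumnType(schema, "features", DoubleType)

    // Appends a non-nullable "prediction" column. Passing an empty column
    // name ("") would return the input schema unchanged (output disabled).
    val withOutput = SchemaUtils.appendColumn(schema, "prediction", DoubleType)
    assert(withOutput.fieldNames.contains("prediction"))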

http://git-wip-us.apache.org/repos/asf/spark/blob/971b95b0/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 1ce2987..88ea679 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -21,19 +21,25 @@ import org.scalatest.FunSuite
 
 class ParamsSuite extends FunSuite {
 
-  val solver = new TestParams()
-  import solver.{inputCol, maxIter}
-
   test("param") {
+    val solver = new TestParams()
+    import solver.{maxIter, inputCol}
+
     assert(maxIter.name === "maxIter")
     assert(maxIter.doc === "max number of iterations")
-    assert(maxIter.defaultValue.get === 100)
     assert(maxIter.parent.eq(solver))
-    assert(maxIter.toString === "maxIter: max number of iterations (default: 100)")
-    assert(inputCol.defaultValue === None)
+    assert(maxIter.toString === "maxIter: max number of iterations (default: 10)")
+
+    solver.setMaxIter(5)
+    assert(maxIter.toString === "maxIter: max number of iterations (default: 10, current: 5)")
+
+    assert(inputCol.toString === "inputCol: input column name (undefined)")
   }
 
   test("param pair") {
+    val solver = new TestParams()
+    import solver.maxIter
+
     val pair0 = maxIter -> 5
     val pair1 = maxIter.w(5)
     val pair2 = ParamPair(maxIter, 5)
@@ -44,10 +50,12 @@ class ParamsSuite extends FunSuite {
   }
 
   test("param map") {
+    val solver = new TestParams()
+    import solver.{maxIter, inputCol}
+
     val map0 = ParamMap.empty
 
     assert(!map0.contains(maxIter))
-    assert(map0(maxIter) === maxIter.defaultValue.get)
     map0.put(maxIter, 10)
     assert(map0.contains(maxIter))
     assert(map0(maxIter) === 10)
@@ -78,23 +86,39 @@ class ParamsSuite extends FunSuite {
   }
 
   test("params") {
+    val solver = new TestParams()
+    import solver.{maxIter, inputCol}
+
     val params = solver.params
-    assert(params.size === 2)
+    assert(params.length === 2)
     assert(params(0).eq(inputCol), "params must be ordered by name")
     assert(params(1).eq(maxIter))
+
+    assert(!solver.isSet(maxIter))
+    assert(solver.isDefined(maxIter))
+    assert(solver.getMaxIter === 10)
+    solver.setMaxIter(100)
+    assert(solver.isSet(maxIter))
+    assert(solver.getMaxIter === 100)
+    assert(!solver.isSet(inputCol))
+    assert(!solver.isDefined(inputCol))
+    intercept[NoSuchElementException](solver.getInputCol)
+
     assert(solver.explainParams() === Seq(inputCol, maxIter).mkString("\n"))
+
     assert(solver.getParam("inputCol").eq(inputCol))
     assert(solver.getParam("maxIter").eq(maxIter))
-    intercept[NoSuchMethodException] {
+    intercept[NoSuchElementException] {
       solver.getParam("abc")
     }
-    assert(!solver.isSet(inputCol))
+
     intercept[IllegalArgumentException] {
       solver.validate()
     }
     solver.validate(ParamMap(inputCol -> "input"))
     solver.setInputCol("input")
     assert(solver.isSet(inputCol))
+    assert(solver.isDefined(inputCol))
     assert(solver.getInputCol === "input")
     solver.validate()
     intercept[IllegalArgumentException] {
@@ -104,5 +128,8 @@ class ParamsSuite extends FunSuite {
     intercept[IllegalArgumentException] {
       solver.validate()
     }
+
+    solver.clearMaxIter()
+    assert(!solver.isSet(maxIter))
   }
 }
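
To summarize the semantics these updated tests pin down, here is a minimal sketch of the default-versus-set distinction, using the TestParams class from the next diff (the values mirror the assertions above):

    import org.apache.spark.ml.param.TestParams

    val solver = new TestParams()

    // A default value makes a param "defined" but not "set".
    assert(solver.isDefined(solver.maxIter))
    assert(!solver.isSet(solver.maxIter))
    assert(solver.getMaxIter == 10)   // getOrDefault falls back to the default

    solver.setMaxIter(5)
    assert(solver.isSet(solver.maxIter))
    assert(solver.getMaxIter == 5)    // an explicit value shadows the default

    solver.clearMaxIter()
    assert(!solver.isSet(solver.maxIter))
    assert(solver.getMaxIter == 10)   // clearing restores the default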

http://git-wip-us.apache.org/repos/asf/spark/blob/971b95b0/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
index ce52f2f..8f9ab68 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/TestParams.scala
@@ -20,17 +20,21 @@ package org.apache.spark.ml.param
 /** A subclass of Params for testing. */
 class TestParams extends Params {
 
-  val maxIter = new IntParam(this, "maxIter", "max number of iterations", Some(100))
+  val maxIter = new IntParam(this, "maxIter", "max number of iterations")
   def setMaxIter(value: Int): this.type = { set(maxIter, value); this }
-  def getMaxIter: Int = get(maxIter)
+  def getMaxIter: Int = getOrDefault(maxIter)
 
   val inputCol = new Param[String](this, "inputCol", "input column name")
   def setInputCol(value: String): this.type = { set(inputCol, value); this }
-  def getInputCol: String = get(inputCol)
+  def getInputCol: String = getOrDefault(inputCol)
+
+  setDefault(maxIter -> 10)
 
   override def validate(paramMap: ParamMap): Unit = {
-    val m = this.paramMap ++ paramMap
+    val m = extractParamMap(paramMap)
     require(m(maxIter) >= 0)
     require(m.contains(inputCol))
   }
+
+  def clearMaxIter(): this.type = clear(maxIter)
 }
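
As a usage note, a short sketch of how validate's ParamMap argument can supply missing values per call without mutating the instance (the value "input" is illustrative):

    import org.apache.spark.ml.param.{ParamMap, TestParams}

    val solver = new TestParams()

    // inputCol has no default, so validating the bare instance fails:
    // solver.validate()   // throws IllegalArgumentException

    // An extra ParamMap satisfies validation for this call only; the
    // instance itself is left untouched.
    solver.validate(ParamMap(solver.inputCol -> "input"))
    assert(!solver.isSet(solver.inputCol))

    // Setting the value on the instance makes the no-arg overload pass.
    solver.setInputCol("input")
    solver.validate()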

