Repository: spark
Updated Branches:
  refs/heads/branch-1.4 8d6684986 -> 5f64269c5


[SPARK-7762] [MLLIB] set default value for outputCol

Set a default value for `outputCol` instead of forcing users to name it. This 
is useful for intermediate transformers in the pipeline. jkbradley
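
For illustration (not part of this commit), a minimal Scala sketch of the
intended usage, assuming the existing Tokenizer/HashingTF/LogisticRegression
stages: the intermediate Tokenizer is left without an explicit outputCol and
the next stage simply reads the generated default (uid + "__output").

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.classification.LogisticRegression
    import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

    // outputCol is not set on the tokenizer, so it falls back to the new
    // default value uid + "__output".
    val tokenizer = new Tokenizer().setInputCol("text")

    // The downstream stage reads the generated name back instead of
    // relying on a hand-picked intermediate column name.
    val hashingTF = new HashingTF()
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")

    val lr = new LogisticRegression().setFeaturesCol("features")
    val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))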

Author: Xiangrui Meng <m...@databricks.com>

Closes #6289 from mengxr/SPARK-7762 and squashes the following commits:

54edebc [Xiangrui Meng] merge master
bff8667 [Xiangrui Meng] update unit test
171246b [Xiangrui Meng] add unit test for outputCol
a4321bd [Xiangrui Meng] set default value for outputCol

(cherry picked from commit c330e52dae6a3ec7e67ca82e2c2f4ea873976458)
Signed-off-by: Joseph K. Bradley <jos...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5f64269c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5f64269c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5f64269c

Branch: refs/heads/branch-1.4
Commit: 5f64269c5251d15feec3ed000b8ad34b1d18502b
Parents: 8d66849
Author: Xiangrui Meng <m...@databricks.com>
Authored: Wed May 20 17:26:26 2015 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Wed May 20 17:26:44 2015 -0700

----------------------------------------------------------------------
 .../ml/param/shared/SharedParamsCodeGen.scala   |  2 +-
 .../spark/ml/param/shared/sharedParams.scala    |  4 ++-
 .../ml/param/shared/SharedParamsSuite.scala     | 35 ++++++++++++++++++++
 .../pyspark/ml/param/_shared_params_code_gen.py |  2 +-
 python/pyspark/ml/param/shared.py               |  3 +-
 5 files changed, 42 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5f64269c/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 8b8cb81..1ffb5ed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -49,7 +49,7 @@ private[shared] object SharedParamsCodeGen {
         isValid = "ParamValidators.inRange(0, 1)"),
       ParamDesc[String]("inputCol", "input column name"),
       ParamDesc[Array[String]]("inputCols", "input column names"),
-      ParamDesc[String]("outputCol", "output column name"),
+      ParamDesc[String]("outputCol", "output column name", Some("uid + 
\"__output\"")),
       ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
         isValid = "ParamValidators.gtEq(1)"),
       ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),

http://git-wip-us.apache.org/repos/asf/spark/blob/5f64269c/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 3a4976d..ed08417 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -185,7 +185,7 @@ private[ml] trait HasInputCols extends Params {
 }
 
 /**
- * (private[ml]) Trait for shared param outputCol.
+ * (private[ml]) Trait for shared param outputCol (default: uid + "__output").
  */
 private[ml] trait HasOutputCol extends Params {
 
@@ -195,6 +195,8 @@ private[ml] trait HasOutputCol extends Params {
    */
   final val outputCol: Param[String] = new Param[String](this, "outputCol", "output column name")
 
+  setDefault(outputCol, uid + "__output")
+
   /** @group getParam */
   final def getOutputCol: String = $(outputCol)
 }
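
As an aside (not part of the diff), the setDefault above only supplies a
fallback; an explicitly set value still takes precedence. A small sketch,
assuming a concrete transformer such as Tokenizer:

    import org.apache.spark.ml.feature.Tokenizer

    val tok = new Tokenizer()   // uid looks like "tok_<random suffix>"
    tok.getOutputCol            // returns the default, tok.uid + "__output"
    tok.setOutputCol("words")   // an explicit value overrides the default
    tok.getOutputCol            // now returns "words"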

http://git-wip-us.apache.org/repos/asf/spark/blob/5f64269c/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
new file mode 100644
index 0000000..ca18fa1
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/shared/SharedParamsSuite.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.param.shared
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.ml.param.Params
+
+class SharedParamsSuite extends FunSuite {
+
+  test("outputCol") {
+
+    class Obj(override val uid: String) extends Params with HasOutputCol
+
+    val obj = new Obj("obj")
+
+    assert(obj.hasDefault(obj.outputCol))
+    assert(obj.getOrDefault(obj.outputCol) === "obj__output")
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/5f64269c/python/pyspark/ml/param/_shared_params_code_gen.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index ccb929a..69efc42 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -116,7 +116,7 @@ if __name__ == "__main__":
         ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name", 
"'rawPrediction'"),
         ("inputCol", "input column name", None),
         ("inputCols", "input column names", None),
-        ("outputCol", "output column name", None),
+        ("outputCol", "output column name", "self.uid + '__output'"),
         ("numFeatures", "number of features", None),
         ("checkpointInterval", "checkpoint interval (>= 1)", None),
         ("seed", "random seed", "hash(type(self).__name__)"),

http://git-wip-us.apache.org/repos/asf/spark/blob/5f64269c/python/pyspark/ml/param/shared.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 0b93788..bc088e4 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -280,6 +280,7 @@ class HasOutputCol(Params):
         super(HasOutputCol, self).__init__()
         #: param for output column name
         self.outputCol = Param(self, "outputCol", "output column name")
+        self._setDefault(outputCol=self.uid + '__output')
 
     def setOutputCol(self, value):
         """
@@ -459,7 +460,7 @@ class DecisionTreeParams(Params):
         self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
         #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.
         self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")
-
+        
     def setMaxDepth(self, value):
         """
         Sets the value of :py:attr:`maxDepth`.

