Repository: spark
Updated Branches:
  refs/heads/master 340f09d10 -> b89d0556d
[SPARK-18210][ML] Pipeline.copy does not create an instance with the same UID ## What changes were proposed in this pull request? Motivation: `org.apache.spark.ml.Pipeline.copy(extra: ParamMap)` does not create an instance with the same UID. It does not conform to the method specification from its base class `org.apache.spark.ml.param.Params.copy(extra: ParamMap)` Solution: - fix for Pipeline UID - introduced new tests for `org.apache.spark.ml.Pipeline.copy` - minor improvements in test for `org.apache.spark.ml.PipelineModel.copy` ## How was this patch tested? Introduced new unit test: `org.apache.spark.ml.PipelineSuite."Pipeline.copy"` Improved existing unit test: `org.apache.spark.ml.PipelineSuite."PipelineModel.copy"` Author: Wojciech Szymanski <wk.szyman...@gmail.com> Closes #15759 from wojtek-szymanski/SPARK-18210. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b89d0556 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b89d0556 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b89d0556 Branch: refs/heads/master Commit: b89d0556dff0520ab35882382242fbfa7d9478eb Parents: 340f09d Author: Wojciech Szymanski <wk.szyman...@gmail.com> Authored: Sun Nov 6 07:43:13 2016 -0800 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Sun Nov 6 07:43:13 2016 -0800 ---------------------------------------------------------------------- .../scala/org/apache/spark/ml/Pipeline.scala | 2 +- .../org/apache/spark/ml/PipelineSuite.scala | 22 ++++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b89d0556/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 
195a93e..f406f8c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -169,7 +169,7 @@ class Pipeline @Since("1.4.0") ( override def copy(extra: ParamMap): Pipeline = { val map = extractParamMap(extra) val newStages = map(stages).map(_.copy(extra)) - new Pipeline().setStages(newStages) + new Pipeline(uid).setStages(newStages) } @Since("1.2.0") http://git-wip-us.apache.org/repos/asf/spark/blob/b89d0556/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 6413ca1..dafc6c2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -101,13 +101,31 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } + test("Pipeline.copy") { + val hashingTF = new HashingTF() + .setNumFeatures(100) + val pipeline = new Pipeline("pipeline").setStages(Array[Transformer](hashingTF)) + val copied = pipeline.copy(ParamMap(hashingTF.numFeatures -> 10)) + + assert(copied.uid === pipeline.uid, + "copy should create an instance with the same UID") + assert(copied.getStages(0).asInstanceOf[HashingTF].getNumFeatures === 10, + "copy should handle extra stage params") + } + test("PipelineModel.copy") { val hashingTF = new HashingTF() .setNumFeatures(100) - val model = new PipelineModel("pipeline", Array[Transformer](hashingTF)) + val model = new PipelineModel("pipelineModel", Array[Transformer](hashingTF)) + .setParent(new Pipeline()) val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10)) - require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10, + + assert(copied.uid === model.uid, + "copy should create an instance with the same UID") + 
assert(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
       "copy should handle extra stage params")
+    assert(copied.parent === model.parent,
+      "copy should create an instance with the same parent")
   }
 
   test("pipeline model constructors") {

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org