Repository: spark Updated Branches: refs/heads/master 6a6adb167 -> d15b4f90e
[SPARK-17507][ML][MLLIB] check weight vector size in ANN ## What changes were proposed in this pull request? As the TODO described, check the weight vector size and throw an exception if it is wrong. ## How was this patch tested? Existing tests. Author: WeichenXu <weichenxu...@outlook.com> Closes #15060 from WeichenXu123/check_input_weight_size_of_ann. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d15b4f90 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d15b4f90 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d15b4f90 Branch: refs/heads/master Commit: d15b4f90e64f7ec5cf14c7c57d2cb4234c3ce677 Parents: 6a6adb1 Author: WeichenXu <weichenxu...@outlook.com> Authored: Thu Sep 15 09:30:15 2016 +0100 Committer: Sean Owen <so...@cloudera.com> Committed: Thu Sep 15 09:30:15 2016 +0100 ---------------------------------------------------------------------- mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d15b4f90/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index 88909a9..e7e0dae 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -545,7 +545,9 @@ private[ann] object FeedForwardModel { * @return model */ def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = { - // TODO: check that weights size is equal to sum of layers sizes + val expectedWeightSize = topology.layers.map(_.weightSize).sum + require(weights.size == expectedWeightSize, + s"Expected weight vector of size ${expectedWeightSize} but got 
size ${weights.size}.") new FeedForwardModel(weights, topology) } @@ -559,11 +561,7 @@ private[ann] object FeedForwardModel { def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = { val layers = topology.layers val layerModels = new Array[LayerModel](layers.length) - var totalSize = 0 - for (i <- 0 until topology.layers.length) { - totalSize += topology.layers(i).weightSize - } - val weights = BDV.zeros[Double](totalSize) + val weights = BDV.zeros[Double](topology.layers.map(_.weightSize).sum) var offset = 0 val random = new XORShiftRandom(seed) for (i <- 0 until layers.length) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org