Repository: spark
Updated Branches:
  refs/heads/master 6a6adb167 -> d15b4f90e


[SPARK-17507][ML][MLLIB] check weight vector size in ANN

## What changes were proposed in this pull request?

As described by the TODO in `FeedForwardModel.apply`, check that the weight vector size equals the sum of the layers' weight sizes, and throw an exception if it does not.
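A minimal sketch of the check, assuming simplified stand-in types. `SimpleLayer`, `SimpleAffineLayer`, `SimpleSigmoidLayer` and `WeightSizeCheck` are hypothetical names, not Spark's `private[ann]` classes, and the affine weight count `numIn * numOut + numOut` (weight matrix plus bias) is an assumption made for illustration. It shows how a mismatched vector is rejected with an `IllegalArgumentException` raised by Scala's `require`.

```scala
// Hypothetical, simplified stand-ins for Spark's private[ann] layer types;
// only weightSize matters for this check.
trait SimpleLayer {
  def weightSize: Int
}

// Assumed affine layer: a numOut x numIn weight matrix plus a bias of length numOut.
final case class SimpleAffineLayer(numIn: Int, numOut: Int) extends SimpleLayer {
  override def weightSize: Int = numIn * numOut + numOut
}

// Assumed functional layer (e.g. sigmoid) that carries no weights.
case object SimpleSigmoidLayer extends SimpleLayer {
  override def weightSize: Int = 0
}

object WeightSizeCheck {
  // Mirrors the require(...) added to FeedForwardModel.apply: the flat weight
  // vector must hold exactly the sum of every layer's weight count.
  def validate(layers: Seq[SimpleLayer], weights: Array[Double]): Unit = {
    val expectedWeightSize = layers.map(_.weightSize).sum
    require(weights.length == expectedWeightSize,
      s"Expected weight vector of size $expectedWeightSize but got size ${weights.length}.")
  }
}
```

For a hypothetical 4 -> 5 -> 3 network with sigmoid activations, the expected size is (4 * 5 + 5) + (5 * 3 + 3) = 43, so a 43-element vector passes, while a 42-element one fails with `requirement failed: Expected weight vector of size 43 but got size 42.`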

## How was this patch tested?

Existing tests.

Author: WeichenXu <weichenxu...@outlook.com>

Closes #15060 from WeichenXu123/check_input_weight_size_of_ann.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d15b4f90
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d15b4f90
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d15b4f90

Branch: refs/heads/master
Commit: d15b4f90e64f7ec5cf14c7c57d2cb4234c3ce677
Parents: 6a6adb1
Author: WeichenXu <weichenxu...@outlook.com>
Authored: Thu Sep 15 09:30:15 2016 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Thu Sep 15 09:30:15 2016 +0100

----------------------------------------------------------------------
 mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d15b4f90/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
index 88909a9..e7e0dae 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
@@ -545,7 +545,9 @@ private[ann] object FeedForwardModel {
    * @return model
    */
   def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = {
-    // TODO: check that weights size is equal to sum of layers sizes
+    val expectedWeightSize = topology.layers.map(_.weightSize).sum
+    require(weights.size == expectedWeightSize,
+      s"Expected weight vector of size ${expectedWeightSize} but got size ${weights.size}.")
     new FeedForwardModel(weights, topology)
   }
 
@@ -559,11 +561,7 @@ private[ann] object FeedForwardModel {
   def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = {
     val layers = topology.layers
     val layerModels = new Array[LayerModel](layers.length)
-    var totalSize = 0
-    for (i <- 0 until topology.layers.length) {
-      totalSize += topology.layers(i).weightSize
-    }
-    val weights = BDV.zeros[Double](totalSize)
+    val weights = BDV.zeros[Double](topology.layers.map(_.weightSize).sum)
     var offset = 0
     val random = new XORShiftRandom(seed)
     for (i <- 0 until layers.length) {

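The second hunk only refactors how the total weight count is computed. A minimal sketch, using a hypothetical `LayerSpec` case class and `TotalWeightSize` object in place of Spark's layer types, and a plain `Array[Double]` instead of Breeze's `BDV.zeros`, showing that the removed index loop and the new `map(_.weightSize).sum` expression yield the same total.

```scala
// Hypothetical stand-in for a layer; only its weight count matters here.
final case class LayerSpec(weightSize: Int)

object TotalWeightSize {
  // Pre-patch style: accumulate into a mutable counter over indices.
  def viaLoop(layers: IndexedSeq[LayerSpec]): Int = {
    var totalSize = 0
    for (i <- 0 until layers.length) {
      totalSize += layers(i).weightSize
    }
    totalSize
  }

  // Post-patch style: map each layer to its weight count and sum.
  def viaMapSum(layers: IndexedSeq[LayerSpec]): Int =
    layers.map(_.weightSize).sum

  // Allocate a zero-filled flat weight buffer of the total size
  // (a plain Array here, in place of Breeze's BDV.zeros).
  def zeroWeights(layers: IndexedSeq[LayerSpec]): Array[Double] =
    new Array[Double](viaMapSum(layers))
}
```

Using the expression form in both `apply` overloads means the expected size is computed the same way wherever the model is built.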
