Repository: mahout
Updated Branches:
  refs/heads/master ae1808be0 -> 310534319


http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/math-scala/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsTestBase.scala
----------------------------------------------------------------------
diff --git 
a/math-scala/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsTestBase.scala
 
b/math-scala/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsTestBase.scala
new file mode 100644
index 0000000..eafde11
--- /dev/null
+++ 
b/math-scala/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsTestBase.scala
@@ -0,0 +1,257 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.stats
+
+import java.lang.Double
+import java.util.Random
+import java.util.Arrays
+
+import org.apache.mahout.common.RandomUtils
+import org.apache.mahout.math.Matrix
+import org.apache.mahout.test.DistributedMahoutSuite
+import org.scalatest.{FunSuite, Matchers}
+
+
+
+trait ClassifierStatsTestBase extends DistributedMahoutSuite with Matchers { 
this: FunSuite =>
+
+  // Absolute tolerance for comparisons against exactly known expected values.
+  val epsilon = 1E-6
+
+  // Looser absolute tolerance (1.0) used by the randomized large-sample test,
+  // whose statistics only approximate their theoretical values. Despite the
+  // name, this is the larger of the two tolerances.
+  val smallEpsilon = 1.0
+
+  // FullRunningAverageAndStdDev tests
+  test("testFullRunningAverageAndStdDev") {
+    // Steps one accumulator through additions and removals, checking the
+    // count, mean and standard deviation after each change.
+    val stats: RunningAverageAndStdDev = new FullRunningAverageAndStdDev
+
+    // Empty accumulator: nothing counted, mean and deviation are NaN.
+    assert(stats.getCount == 0)
+    assert(Double.isNaN(stats.getAverage))
+    assert(Double.isNaN(stats.getStandardDeviation))
+
+    // One datum: the mean is that datum, the deviation is still NaN.
+    stats.addDatum(6.0)
+    assert(stats.getCount == 1)
+    assert((6.0 - stats.getAverage).abs < epsilon)
+    assert(Double.isNaN(stats.getStandardDeviation))
+
+    // Two equal data: zero deviation.
+    stats.addDatum(6.0)
+    assert(stats.getCount == 2)
+    assert((6.0 - stats.getAverage).abs < epsilon)
+    assert((0.0 - stats.getStandardDeviation).abs < epsilon)
+
+    // Back down to a single datum.
+    stats.removeDatum(6.0)
+    assert(stats.getCount == 1)
+    assert((6.0 - stats.getAverage).abs < epsilon)
+    assert(Double.isNaN(stats.getStandardDeviation))
+
+    // 6.0 and -4.0: mean 1.0, deviation 5 * sqrt(2).
+    stats.addDatum(-4.0)
+    assert(stats.getCount == 2)
+    assert((1.0 - stats.getAverage).abs < epsilon)
+    val expectedStdDev = 5.0 * 1.4142135623730951
+    assert((expectedStdDev - stats.getStandardDeviation).abs < epsilon)
+
+    // Removing 4.0 (not the -4.0 added above) is expected to leave -2.0.
+    stats.removeDatum(4.0)
+    assert(stats.getCount == 1)
+    assert((2.0 + stats.getAverage).abs < epsilon)
+    assert(Double.isNaN(stats.getStandardDeviation))
+  }
+
+  test("testBigFullRunningAverageAndStdDev") {
+    // Feeds 100000 values of nextDouble() * 1000.0 from the test-seeded
+    // generator and compares the resulting mean and deviation against
+    // 500.0 and 1000.0 / sqrt(12.0), within smallEpsilon.
+    val stats: RunningAverageAndStdDev = new FullRunningAverageAndStdDev
+    RandomUtils.useTestSeed()
+    val rng: Random = RandomUtils.getRandom
+
+    (0 until 100000).foreach { _ =>
+      stats.addDatum(rng.nextDouble() * 1000.0)
+    }
+
+    val expectedMean = 500.0
+    val expectedStdDev = 1000.0 / Math.sqrt(12.0)
+    assert((expectedMean - stats.getAverage).abs < smallEpsilon)
+    assert((expectedStdDev - stats.getStandardDeviation).abs < smallEpsilon)
+  }
+
+  test("testStddevFullRunningAverageAndStdDev") {
+    // Adds 1.0, 1.0, 7.0 and 5.0 in turn and checks the count, mean and
+    // standard deviation after each addition. Every numeric comparison uses
+    // the absolute difference so that deviations in either direction fail.
+    val runningAverage: RunningAverageAndStdDev =
+      new FullRunningAverageAndStdDev
+    assert(0 == runningAverage.getCount)
+    assert(Double.isNaN(runningAverage.getAverage))
+
+    runningAverage.addDatum(1.0)
+    assert(1 == runningAverage.getCount)
+    assert((1.0 - runningAverage.getAverage).abs < epsilon)
+    assert(Double.isNaN(runningAverage.getStandardDeviation))
+
+    runningAverage.addDatum(1.0)
+    assert(2 == runningAverage.getCount)
+    assert((1.0 - runningAverage.getAverage).abs < epsilon)
+    assert((0.0 - runningAverage.getStandardDeviation).abs < epsilon)
+
+    runningAverage.addDatum(7.0)
+    assert(3 == runningAverage.getCount)
+    assert((3.0 - runningAverage.getAverage).abs < epsilon)
+    assert(
+      (3.464101552963257 - runningAverage.getStandardDeviation).abs < epsilon)
+
+    runningAverage.addDatum(5.0)
+    assert(4 == runningAverage.getCount)
+    // Previously compared the signed difference, so any average above 3.5
+    // passed; the absolute difference is what is intended.
+    assert((3.5 - runningAverage.getAverage).abs < epsilon)
+    assert((3.0 - runningAverage.getStandardDeviation).abs < epsilon)
+  }
+
+
+
+  // FullRunningAverage tests
+  test("testFullRunningAverage") {
+    // Exercises addDatum, removeDatum and changeDatum on a plain running
+    // average, checking the count and mean after each operation. Every
+    // numeric comparison uses the absolute difference.
+    val runningAverage: RunningAverage = new FullRunningAverage
+    assert(0 == runningAverage.getCount)
+    assert(Double.isNaN(runningAverage.getAverage))
+
+    runningAverage.addDatum(1.0)
+    assert(1 == runningAverage.getCount)
+    assert((1.0 - runningAverage.getAverage).abs < epsilon)
+
+    runningAverage.addDatum(1.0)
+    assert(2 == runningAverage.getCount)
+    assert((1.0 - runningAverage.getAverage).abs < epsilon)
+
+    runningAverage.addDatum(4.0)
+    assert(3 == runningAverage.getCount)
+    // Previously compared the signed difference, so any average above 2.0
+    // passed; the absolute difference is what is intended.
+    assert((2.0 - runningAverage.getAverage).abs < epsilon)
+
+    runningAverage.addDatum(-4.0)
+    assert(4 == runningAverage.getCount)
+    assert((0.5 - runningAverage.getAverage).abs < epsilon)
+
+    runningAverage.removeDatum(-4.0)
+    assert(3 == runningAverage.getCount)
+    assert((2.0 - runningAverage.getAverage).abs < epsilon)
+
+    runningAverage.removeDatum(4.0)
+    assert(2 == runningAverage.getCount)
+    assert((1.0 - runningAverage.getAverage).abs < epsilon)
+
+    // changeDatum leaves the count alone; the expected means below show a
+    // delta of 0.0 keeping the mean and a delta of 2.0 raising it by 1.0.
+    runningAverage.changeDatum(0.0)
+    assert(2 == runningAverage.getCount)
+    assert((1.0 - runningAverage.getAverage).abs < epsilon)
+
+    runningAverage.changeDatum(2.0)
+    assert(2 == runningAverage.getCount)
+    assert((2.0 - runningAverage.getAverage).abs < epsilon)
+  }
+
+
+  test("testFullRunningAveragCopyConstructor") {
+    // Builds an average from two data, then constructs a second instance
+    // from the first one's count and mean and checks that both agree.
+    val runningAverage: RunningAverage = new FullRunningAverage
+    runningAverage.addDatum(1.0)
+    runningAverage.addDatum(1.0)
+    assert(2 == runningAverage.getCount)
+    // The original compared the signed difference, which accepted any
+    // average greater than 1.0; compare the absolute difference instead.
+    assert((1.0 - runningAverage.getAverage).abs < epsilon)
+
+    val copy: RunningAverage =
+      new FullRunningAverage(runningAverage.getCount, runningAverage.getAverage)
+    assert(2 == copy.getCount)
+    assert((1.0 - copy.getAverage).abs < epsilon)
+  }
+
+
+
+  // Inverted Running Average tests
+  test("testInvertedRunningAverage") {
+    // The inverted view wraps `source`; data is only ever added to `source`
+    // and the view is expected to report the same count and a negated mean.
+    val source: RunningAverage = new FullRunningAverage
+    val negated: RunningAverage = new InvertedRunningAverage(source)
+    assert(negated.getCount == 0)
+
+    source.addDatum(1.0)
+    assert(negated.getCount == 1)
+    // negated.getAverage == -1.0
+    assert((1.0 + negated.getAverage).abs < epsilon)
+
+    source.addDatum(2.0)
+    assert(negated.getCount == 2)
+    // negated.getAverage == -1.5
+    assert((1.5 + negated.getAverage).abs < epsilon)
+  }
+
+  test ("testInvertedRunningAverageAndStdDev") {
+    // Same as the plain inverted test, plus a check that the standard
+    // deviation reported through the inverted view is sqrt(2) / 2.
+    val source: RunningAverageAndStdDev = new FullRunningAverageAndStdDev
+    val negated: RunningAverageAndStdDev =
+      new InvertedRunningAverageAndStdDev(source)
+    assert(negated.getCount == 0)
+
+    source.addDatum(1.0)
+    assert(negated.getCount == 1)
+    // negated.getAverage == -1.0
+    assert((1.0 + negated.getAverage).abs < epsilon)
+
+    source.addDatum(2.0)
+    assert(negated.getCount == 2)
+    // negated.getAverage == -1.5
+    assert((1.5 + negated.getAverage).abs < epsilon)
+    val expectedStdDev = Math.sqrt(2.0) / 2.0
+    assert((expectedStdDev - negated.getStandardDeviation).abs < epsilon)
+  }
+
+
+  // confusion Matrix tests
+  // Counts for the two real labels: VALUES(i)(j) is stored for the label
+  // pair (LABELS(i), LABELS(j)).
+  val VALUES: Array[Array[Int]] = Array(Array(2, 3), Array(10, 20))
+  val LABELS: Array[String] = Array("Label1", "Label2")
+  // Counts stored for each real label against the default label.
+  val OTHER: Array[Int] = Array(3, 6)
+  val DEFAULT_LABEL: String = "other"
+
+  /**
+   * Builds a two-label confusion matrix and fills it with test counts.
+   *
+   * The label names and the default label are taken from the arguments
+   * rather than from hard-coded literals, so the method honours whatever
+   * the caller passes. The counts stored against the default label always
+   * come from `OTHER`.
+   *
+   * @param values       2x2 counts; `values(i)(j)` is stored for the label
+   *                     pair `(labels(i), labels(j))`
+   * @param labels       label names; only the first two entries are used
+   * @param defaultLabel label handed to the `ConfusionMatrix` constructor
+   *                     and used as the second label of the `OTHER` counts
+   * @return the populated confusion matrix
+   */
+  def fillConfusionMatrix(values: Array[Array[Int]],
+                          labels: Array[String],
+                          defaultLabel: String): ConfusionMatrix = {
+    val first = labels(0)
+    val second = labels(1)
+    val confusionMatrix: ConfusionMatrix =
+      new ConfusionMatrix(Arrays.asList(first, second), defaultLabel)
+
+    confusionMatrix.putCount(first, first, values(0)(0))
+    confusionMatrix.putCount(first, second, values(0)(1))
+    confusionMatrix.putCount(second, first, values(1)(0))
+    confusionMatrix.putCount(second, second, values(1)(1))
+    confusionMatrix.putCount(first, defaultLabel, OTHER(0))
+    confusionMatrix.putCount(second, defaultLabel, OTHER(1))
+
+    confusionMatrix
+  }
+
+  /**
+   * Asserts the label count and the per-label accuracy of the matrix built
+   * by `fillConfusionMatrix`: 25.0 for "Label1", about 55.5555555 for
+   * "Label2" and NaN for "other".
+   */
+  private def checkAccuracy(cm: ConfusionMatrix): Unit = {
+    assert(cm.getLabels.size == 3)
+    assert((25.0 - cm.getAccuracy("Label1")).abs < epsilon)
+    assert((55.5555555 - cm.getAccuracy("Label2")).abs < epsilon)
+    assert(Double.isNaN(cm.getAccuracy("other")))
+  }
+
+  /**
+   * Asserts that the raw count array of `cm` is 3x3, holds `VALUES` in its
+   * upper-left 2x2 corner, `OTHER` in its third column, zeros in its third
+   * row, and that all three labels are present.
+   */
+  private def checkValues(cm: ConfusionMatrix): Unit = {
+    val counts: Array[Array[Int]] = cm.getConfusionMatrix
+    // The result is discarded; the call only exercises toString.
+    cm.toString
+    assert(counts.length == counts(0).length)
+    assert(counts.length == 3)
+
+    for (row <- 0 to 1; col <- 0 to 1) {
+      assert(counts(row)(col) == VALUES(row)(col))
+    }
+
+    assert(Arrays.equals(new Array[Int](3), counts(2)))
+    assert(counts(0)(2) == OTHER(0))
+    assert(counts(1)(2) == OTHER(1))
+
+    assert(cm.getLabels.size == 3)
+    assert(cm.getLabels.contains(LABELS(0)))
+    assert(cm.getLabels.contains(LABELS(1)))
+    assert(cm.getLabels.contains(DEFAULT_LABEL))
+  }
+
+  test("testBuild"){
+    // Build the fixture matrix, then verify its raw counts and accuracies.
+    val cm: ConfusionMatrix =
+      fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL)
+    checkValues(cm)
+    checkAccuracy(cm)
+  }
+
+  test("GetMatrix") {
+    // Checks the Matrix view of the fixture: its column count, its row
+    // label bindings, and the correct-count for each label.
+    val cm: ConfusionMatrix =
+      fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL)
+    val matrix: Matrix = cm.getMatrix
+    val boundRowLabels = matrix.getRowLabelBindings.keySet
+
+    assert(cm.getLabels.size == matrix.numCols)
+    assert(boundRowLabels.contains(LABELS(0)))
+    assert(boundRowLabels.contains(LABELS(1)))
+    assert(boundRowLabels.contains(DEFAULT_LABEL))
+
+    assert(cm.getCorrect(LABELS(0)) == 2)
+    assert(cm.getCorrect(LABELS(1)) == 20)
+    assert(cm.getCorrect(DEFAULT_LABEL) == 0)
+  }
+
+  /**
+   * Checks weighted precision, recall and F1 against the values published
+   * in the scikit-learn example at
+   * http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
+   */
+  test("testPrecisionRecallAndF1ScoreAsScikitLearn") {
+    val labelList = Arrays.asList("0", "1", "2")
+    val confusionMatrix: ConfusionMatrix =
+      new ConfusionMatrix(labelList, "DEFAULT")
+    confusionMatrix.putCount("0", "0", 2)
+    confusionMatrix.putCount("1", "0", 1)
+    confusionMatrix.putCount("1", "2", 1)
+    confusionMatrix.putCount("2", "1", 2)
+
+    // No type annotation on purpose: this file imports java.lang.Double, so
+    // annotating with `Double` would box the tolerance and every comparison
+    // below would go through an implicit unboxing conversion.
+    val delta = 0.001
+    assert((0.222 - confusionMatrix.getWeightedPrecision).abs < delta)
+    assert((0.333 - confusionMatrix.getWeightedRecall).abs < delta)
+    assert((0.266 - confusionMatrix.getWeightedF1score).abs < delta)
+  }
+
+
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
----------------------------------------------------------------------
diff --git 
a/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java 
b/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
index ebbed92..3ffff85 100644
--- 
a/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
+++ 
b/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java
@@ -115,5 +115,5 @@ public final class ConfusionMatrixTest extends 
MahoutTestCase {
     confusionMatrix.putCount("Label2", DEFAULT_LABEL, OTHER[1]);
     return confusionMatrix;
   }
-  
+
 }

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala
----------------------------------------------------------------------
diff --git 
a/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala
 
b/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala
index 0df42a3..a957786 100644
--- 
a/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala
+++ 
b/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala
@@ -45,6 +45,12 @@ class MahoutSparkILoop extends SparkILoop {
       conf.set("spark.executor.uri", execUri)
     }
 
+    // temporarily hard code spark.kryoserializer.buffer.mb
+    // to allow for seq2sparse data
+    //TODO: remove this before pushing to apache/master
+    conf.set("spark.kryoserializer.buffer.mb","100")
+    conf.set("spark.akka.frameSize","100")
+
     sparkContext = mahoutSparkContext(
       masterUrl = master,
       appName = "Mahout Spark Shell",

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark/src/main/scala/org/apache/mahout/classifier/naivebayes/SparkNaiveBayes.scala
----------------------------------------------------------------------
diff --git 
a/spark/src/main/scala/org/apache/mahout/classifier/naivebayes/SparkNaiveBayes.scala
 
b/spark/src/main/scala/org/apache/mahout/classifier/naivebayes/SparkNaiveBayes.scala
new file mode 100644
index 0000000..fd7116e
--- /dev/null
+++ 
b/spark/src/main/scala/org/apache/mahout/classifier/naivebayes/SparkNaiveBayes.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.naivebayes
+
+import org.apache.mahout.math._
+
+import org.apache.mahout.sparkbindings.drm.CheckpointedDrmSpark
+
+import scalabindings._
+import scalabindings.RLikeOps._
+import drm.RLikeDrmOps._
+import drm._
+import scala.reflect.ClassTag
+import scala.language.asInstanceOf
+import collection._
+import JavaConversions._
+import org.apache.spark.SparkContext._
+
+import org.apache.mahout.classifier.naivebayes._
+import org.apache.mahout.sparkbindings._
+
+/**
+ * Distributed training of a Naive Bayes model. Follows the approach presented
+ * in Rennie et al.: "Tackling the Poor Assumptions of Naive Bayes Text
+ * Classifiers", ICML 2003,
+ * http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf
+ */
+object SparkNaiveBayes extends NaiveBayes{
+
+  /**
+   * Math-Scala Naive Bayes optimized for Spark.
+   *
+   * Extracts label keys from a raw TF or TF-IDF matrix generated by
+   * seqdirectory/seq2sparse and aggregates the TF or TF-IDF values by label.
+   *
+   * @param stringKeyedObservations DrmLike matrix; output from seq2sparse
+   *   in the form K = e.g. /Category/document_title
+   *               V = TF or TF-IDF values per term
+   * @param cParser a String => String function used to extract categories
+   *   from the keys of the stringKeyedObservations DRM. The default
+   *   CategoryParser will extract "Category" from '/Category/document_id'
+   * @return (labelIndexMap, aggregatedByLabelObservationDrm)
+   *   labelIndexMap is a HashMap with K = label
+   *                                   V = label row index
+   *   aggregatedByLabelObservationDrm is a DrmLike[Int] of aggregated
+   *   TF or TF-IDF counts per label
+   */
+  override def extractLabelsAndAggregateObservations[K: ClassTag](
+      stringKeyedObservations: DrmLike[K],
+      cParser: CategoryParser = seq2SparseCategoryParser)
+      (implicit ctx: DistributedContext):
+      (mutable.HashMap[String, Integer], DrmLike[Int]) = {
+
+    // The row keys are treated as Strings here; the cast is unchecked.
+    val stringKeyedRdd = stringKeyedObservations
+      .checkpoint()
+      .asInstanceOf[CheckpointedDrmSpark[String]]
+      .rdd
+
+    // Replace each row key with its category and sum the rows per category.
+    val aggregatedRdd = stringKeyedRdd
+      .map(x => (cParser(x._1), x._2))
+      .reduceByKey(_ + _)
+
+    // NOTE(review): this runs before any action has evaluated aggregatedRdd;
+    // confirm that releasing the cache at this point is intended.
+    stringKeyedObservations.uncache()
+
+    // Categories are distinct after reduceByKey. Collecting them once and
+    // sorting locally gives the same ascending order as
+    // takeOrdered(count.toInt), without a separate count job and without
+    // narrowing a Long count to Int.
+    val sortedCategories = aggregatedRdd.keys.collect().sorted
+
+    // Row index of a label = its position in the sorted category list.
+    val labelIndexMap = new mutable.HashMap[String, Integer]
+    sortedCategories.zipWithIndex.foreach { case (category, index) =>
+      labelIndexMap.put(category, index)
+    }
+
+    val intKeyedRdd =
+      aggregatedRdd.map(x => (labelIndexMap(x._1).toInt, x._2))
+
+    val aggregatedObservationByLabelDrm = drmWrap(intKeyedRdd)
+
+    (labelIndexMap, aggregatedObservationByLabelDrm)
+  }
+
+
+}
+

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark/src/main/scala/org/apache/mahout/drivers/TestNBDriver.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/drivers/TestNBDriver.scala 
b/spark/src/main/scala/org/apache/mahout/drivers/TestNBDriver.scala
new file mode 100644
index 0000000..7d0738c
--- /dev/null
+++ b/spark/src/main/scala/org/apache/mahout/drivers/TestNBDriver.scala
@@ -0,0 +1,131 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements.  See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package org.apache.mahout.drivers
+
+import org.apache.mahout.classifier.naivebayes.{NBModel, NaiveBayes}
+import org.apache.mahout.classifier.stats.ConfusionMatrix
+import org.apache.mahout.math.drm
+import org.apache.mahout.math.drm.DrmLike
+import scala.collection.immutable.HashMap
+
+
+/**
+ * Command-line driver that reads a test set and a trained Naive Bayes model,
+ * runs `NaiveBayes.test` on them and prints the result.
+ */
+object TestNBDriver extends MahoutSparkDriver {
+  // Options specific to TestNB only.
+  private final val testNBOptions = HashMap[String, Any](
+    "appName" -> "TestNBDriver")
+
+  /**
+   * Builds the option parser, parses `args` and, if parsing succeeds,
+   * stores the parsed options and runs `process`.
+   *
+   * @param args  Command line args, if empty a help message is printed.
+   */
+  override def main(args: Array[String]): Unit = {
+
+    parser = new MahoutSparkOptionParser(programName = "spark-testnb") {
+      head("spark-testnb", "Mahout 1.0")
+
+      // Input/output options, not driver specific.
+      parseIOOptions(numInputs = 1)
+
+      // Algorithm control options, driver specific.
+      opts = opts ++ testNBOptions
+      note("\nAlgorithm control options:")
+
+      // testComplementary defaults to false; the flag switches it on.
+      opts = opts + ("testComplementary" -> false)
+      opt[Unit]("testComplementary")
+        .abbr("c")
+        .action { (_, options) => options + ("testComplementary" -> true) }
+        .text("Test a complementary model, Default: false.")
+
+      opt[String]("pathToModel")
+        .abbr("m")
+        .action { (x, options) => options + ("pathToModel" -> x) }
+        .text("Path to the Trained Model")
+
+      // How to search for input.
+      parseFileDiscoveryOptions
+
+      // Drm output schema: drm specific, not driver specific.
+      parseDrmFormatOptions
+
+      // Spark config options, not driver specific.
+      parseSparkOptions
+
+      // Jar inclusion; this option can be set when executing the driver
+      // from compiled code, not when from CLI.
+      parseGenericOptions
+
+      help("help").abbr("h").text("prints this usage text\n")
+
+    }
+
+    // Nothing runs when parsing fails (parse yields no options then).
+    parser.parse(args, parser.opts).foreach { parsed =>
+      parser.opts = parsed
+      process
+    }
+  }
+
+  /**
+   * Applies the job-specific Spark settings, then delegates to the parent
+   * driver's `start`.
+   */
+  override def start(
+      masterUrl: String = parser.opts("master").asInstanceOf[String],
+      appName: String = parser.opts("appName").asInstanceOf[String]):
+    Unit = {
+
+    // Settings below are specific to this job.
+    // Note: set a large spark.kryoserializer.buffer.mb if using DSL MapBlock,
+    // else leave as default.
+    val executorMem = parser.opts("sparkExecutorMem").asInstanceOf[String]
+    if (executorMem != "") {
+      sparkConf.set("spark.executor.memory", executorMem)
+    }
+
+    // Note: set a large akka frame size for DSL NB (20)
+    // sparkConf.set("spark.akka.frameSize","20")
+    // Not needed for the Spark optimized NaiveBayes; otherwise leave as set
+    // in the Spark config.
+
+    super.start(masterUrl, appName)
+
+  }
+
+  /**
+   * Read the test set from inputPath/part-x-00000 sequence file of form
+   * <Text,VectorWritable>.
+   */
+  private def readTestSet: DrmLike[_] =
+    drm.drmDfsRead(parser.opts("input").asInstanceOf[String])
+
+  /** Read the model from pathToModel using NBModel.dfsRead(...). */
+  private def readModel: NBModel =
+    NBModel.dfsRead(parser.opts("pathToModel").asInstanceOf[String])
+
+  /** Starts the context, tests the model on the test set, prints, stops. */
+  override def process: Unit = {
+    start()
+
+    val useComplementary =
+      parser.opts("testComplementary").asInstanceOf[Boolean]
+    // Read but not used yet; see the todo below.
+    val outputPath = parser.opts("output").asInstanceOf[String]
+
+    // todo:  get the -ow option in to check for a model in the path and
+    // overwrite if flagged.
+
+    val testSet = readTestSet
+    val model = readModel
+    val analyzer = NaiveBayes.test(model, testSet, useComplementary)
+
+    println(analyzer)
+
+    stop
+  }
+
+}
+

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark/src/main/scala/org/apache/mahout/drivers/TrainNBDriver.scala
----------------------------------------------------------------------
diff --git a/spark/src/main/scala/org/apache/mahout/drivers/TrainNBDriver.scala 
b/spark/src/main/scala/org/apache/mahout/drivers/TrainNBDriver.scala
new file mode 100644
index 0000000..35ff90b
--- /dev/null
+++ b/spark/src/main/scala/org/apache/mahout/drivers/TrainNBDriver.scala
@@ -0,0 +1,115 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements.  See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License.  You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package org.apache.mahout.drivers
+
+import org.apache.mahout.classifier.naivebayes._
+import org.apache.mahout.classifier.naivebayes.SparkNaiveBayes
+import org.apache.mahout.math.drm
+import org.apache.mahout.math.drm.DrmLike
+import scala.collection.immutable.HashMap
+
+
+/**
+ * Command-line driver that reads a training set, aggregates it by label,
+ * trains a Naive Bayes model and writes the model to the output path.
+ */
+object TrainNBDriver extends MahoutSparkDriver {
+  // define only the options specific to TrainNB
+  private final val trainNBOptions = HashMap[String, Any](
+    "appName" -> "TrainNBDriver")
+
+  /**
+   * Builds the option parser, parses `args` and, if parsing succeeds,
+   * stores the parsed options and runs `process`.
+   *
+   * @param args  Command line args, if empty a help message is printed.
+   */
+  override def main(args: Array[String]): Unit = {
+
+    parser = new MahoutSparkOptionParser(programName = "spark-trainnb") {
+      head("spark-trainnb", "Mahout 1.0")
+
+      //Input output options, non-driver specific
+      parseIOOptions(numInputs = 1)
+
+      //Algorithm control options--driver specific
+      opts = opts ++ trainNBOptions
+      note("\nAlgorithm control options:")
+
+      //default trainComplementary is false
+      opts = opts + ("trainComplementary" -> false)
+      opt[Unit]("trainComplementary") abbr ("c") action { (_, options) =>
+        options + ("trainComplementary" -> true)
+      } text ("Train a complementary model, Default: false.")
+
+      //How to search for input
+      parseFileDiscoveryOptions
+
+      //Drm output schema--not driver specific, drm specific
+      parseDrmFormatOptions
+
+      //Spark config options--not driver specific
+      parseSparkOptions
+
+      //Jar inclusion, this option can be set when executing the driver from
+      //compiled code, not when from CLI
+      parseGenericOptions
+
+      help("help") abbr ("h") text ("prints this usage text\n")
+
+    }
+
+    // Run only when parsing succeeded; foreach makes the side effect
+    // explicit instead of discarding the result of map.
+    parser.parse(args, parser.opts) foreach { opts =>
+      parser.opts = opts
+      process
+    }
+  }
+
+  /**
+   * Applies the job-specific Spark settings, then delegates to the parent
+   * driver's `start`.
+   */
+  override def start(
+      masterUrl: String = parser.opts("master").asInstanceOf[String],
+      appName: String = parser.opts("appName").asInstanceOf[String]):
+    Unit = {
+
+    // will be only specific to this job.
+    // Note: set a large spark.kryoserializer.buffer.mb if using DSL MapBlock
+    // else leave as default
+
+    val executorMem = parser.opts("sparkExecutorMem").asInstanceOf[String]
+    if (executorMem != "")
+      sparkConf.set("spark.executor.memory", executorMem)
+
+    // Note: set a large akka frame size for DSL NB (20)
+    // sparkConf.set("spark.akka.frameSize","20")
+    // don't need this for Spark optimized NaiveBayes..
+    // else leave as set in Spark config
+
+    super.start(masterUrl, appName)
+
+  }
+
+  /**
+   * Read the training set from inputPath/part-x-00000 sequence file of form
+   * <Text,VectorWritable>
+   */
+  private def readTrainingSet: DrmLike[_] = {
+    val inputPath = parser.opts("input").asInstanceOf[String]
+    drm.drmDfsRead(inputPath)
+  }
+
+  /**
+   * Starts the context, aggregates the training set by label, trains the
+   * model, writes it to the "output" path and stops the context.
+   */
+  override def process: Unit = {
+    start()
+
+    val complementary =
+      parser.opts("trainComplementary").asInstanceOf[Boolean]
+    val outputPath = parser.opts("output").asInstanceOf[String]
+
+    val trainingSet = readTrainingSet
+    val (labelIndex, aggregatedObservations) =
+      SparkNaiveBayes.extractLabelsAndAggregateObservations(trainingSet)
+
+    // Forward the command-line choice to the trainer. The flag used to be
+    // parsed and then dropped, so --trainComplementary had no effect.
+    // NOTE(review): assumes NaiveBayes.train takes the complementary flag as
+    // its third parameter, mirroring NaiveBayes.test in TestNBDriver; confirm.
+    val model =
+      NaiveBayes.train(aggregatedObservations, labelIndex, complementary)
+
+    model.dfsWrite(outputPath)
+
+    stop
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala
----------------------------------------------------------------------
diff --git 
a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala 
b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala
index 2a2a4a9..c04b306 100644
--- a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala
+++ b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala
@@ -110,6 +110,4 @@ package object drm {
             key -> v
         }
     }
-
-
 }

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark/src/test/scala/org/apache/mahout/classifier/naivebayes/NBSparkTestSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/src/test/scala/org/apache/mahout/classifier/naivebayes/NBSparkTestSuite.scala
 
b/spark/src/test/scala/org/apache/mahout/classifier/naivebayes/NBSparkTestSuite.scala
new file mode 100644
index 0000000..d999f9b
--- /dev/null
+++ 
b/spark/src/test/scala/org/apache/mahout/classifier/naivebayes/NBSparkTestSuite.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.classifier.naivebayes
+
+import org.apache.mahout.math._
+import org.apache.mahout.math.scalabindings.RLikeOps._
+import org.apache.mahout.math.scalabindings._
+import org.apache.mahout.sparkbindings.test.DistributedSparkSuite
+import org.apache.mahout.test.MahoutSuite
+import org.scalatest.FunSuite
+
+/**
+ * Runs the shared Naive Bayes tests on Spark and checks that the
+ * Spark-specific label aggregation agrees with the DSL implementation.
+ */
+class NBSparkTestSuite extends FunSuite with MahoutSuite
+  with DistributedSparkSuite with NBTestBase {
+
+  test("Spark NB Aggregator") {
+
+    // Five documents: three keyed under Cat1 and two under Cat2.
+    val rowBindings = new java.util.HashMap[String,Integer]()
+    rowBindings.put("/Cat1/doc_a/", 0)
+    rowBindings.put("/Cat2/doc_b/", 1)
+    rowBindings.put("/Cat1/doc_c/", 2)
+    rowBindings.put("/Cat2/doc_d/", 3)
+    rowBindings.put("/Cat1/doc_e/", 4)
+
+    val matrixSetup = sparse(
+      (0, 0.1) ::(1, 0.0) ::(2, 0.1) ::(3, 0.0) :: Nil,
+      (0, 0.0) ::(1, 0.1) ::(2, 0.0) ::(3, 0.1) :: Nil,
+      (0, 0.1) ::(1, 0.0) ::(2, 0.1) ::(3, 0.0) :: Nil,
+      (0, 0.0) ::(1, 0.1) ::(2, 0.0) ::(3, 0.1) :: Nil,
+      (0, 0.1) ::(1, 0.0) ::(2, 0.1) ::(3, 0.0) :: Nil
+    )
+
+    matrixSetup.setRowLabelBindings(rowBindings)
+
+    val TFIDFDrm =
+      drm.drmParallelizeWithRowLabels(m = matrixSetup, numPartitions = 2)
+
+    val (dslLabelIndex, dslAggregatedTFIDFDrm) =
+      NaiveBayes.extractLabelsAndAggregateObservations(TFIDFDrm)
+    val (sparkLabelIndex, sparkAggregatedTFIDFDrm) =
+      SparkNaiveBayes.extractLabelsAndAggregateObservations(TFIDFDrm)
+
+    dslLabelIndex.size should be (2)
+    sparkLabelIndex.size should be (2)
+
+    val dslCat1 = dslLabelIndex("Cat1")
+    val dslCat2 = dslLabelIndex("Cat2")
+
+    val sparkCat1 = sparkLabelIndex("Cat1")
+    val sparkCat2 = sparkLabelIndex("Cat2")
+
+    dslCat1 should be (0)
+    dslCat2 should be (1)
+
+    sparkCat1 should be (0)
+    sparkCat2 should be (1)
+
+    val dslAggInCore = dslAggregatedTFIDFDrm.collect
+    val sparkAggInCore = sparkAggregatedTFIDFDrm.collect
+
+    dslAggInCore.numCols should be (4)
+    dslAggInCore.numRows should be (2)
+
+    // The Spark result must have the same shape before cells are compared.
+    sparkAggInCore.numCols should be (4)
+    sparkAggInCore.numRows should be (2)
+
+    // Compare cell by cell. Each matrix is indexed with its own label index,
+    // and the absolute difference is used so that a Spark value that is too
+    // large fails just like one that is too small.
+    // Expected per-column sums: Cat1 = 0.3, 0.0, 0.3, 0.0
+    //                           Cat2 = 0.0, 0.2, 0.0, 0.2
+    for (col <- 0 until dslAggInCore.numCols) {
+      val cat1Diff =
+        dslAggInCore(dslCat1, col) - sparkAggInCore(sparkCat1, col)
+      val cat2Diff =
+        dslAggInCore(dslCat2, col) - sparkAggInCore(sparkCat2, col)
+
+      cat1Diff.abs should be < epsilon
+      cat2Diff.abs should be < epsilon
+    }
+
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsSparkTestSuite.scala
----------------------------------------------------------------------
diff --git 
a/spark/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsSparkTestSuite.scala
 
b/spark/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsSparkTestSuite.scala
new file mode 100644
index 0000000..3af7649
--- /dev/null
+++ 
b/spark/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsSparkTestSuite.scala
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.stats
+
+import org.apache.mahout.sparkbindings.test.DistributedSparkSuite
+import org.apache.mahout.test.MahoutSuite
+import org.scalatest.FunSuite
+
+/**
+ * Concrete suite that runs the tests declared in ClassifierStatsTestBase
+ * with MahoutSuite and DistributedSparkSuite mixed in. It adds no tests of
+ * its own.
+ */
+class ClassifierStatsSparkTestSuite extends FunSuite with MahoutSuite with 
DistributedSparkSuite with ClassifierStatsTestBase
+
+

Reply via email to