Repository: mahout Updated Branches: refs/heads/master ae1808be0 -> 310534319
http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/math-scala/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsTestBase.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsTestBase.scala b/math-scala/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsTestBase.scala new file mode 100644 index 0000000..eafde11 --- /dev/null +++ b/math-scala/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsTestBase.scala @@ -0,0 +1,257 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.stats

import java.lang.Double
import java.util.Random
import java.util.Arrays

import org.apache.mahout.common.RandomUtils
import org.apache.mahout.math.Matrix
import org.apache.mahout.test.DistributedMahoutSuite
import org.scalatest.{FunSuite, Matchers}

/**
 * Tests for the running-average helpers and for ConfusionMatrix, written as a trait
 * so that a concrete FunSuite can mix them in.
 *
 * Every floating point comparison is of the form `(expected - actual).abs < tolerance`;
 * the absolute value is required, otherwise an actual value that is too large would
 * yield a negative difference and pass unnoticed.
 */
trait ClassifierStatsTestBase extends DistributedMahoutSuite with Matchers { this: FunSuite =>

  /** Tolerance for comparisons against exactly known values. */
  val epsilon = 1E-6

  /** Tolerance for the randomized test; despite its name it is the looser of the two. */
  val smallEpsilon = 1.0

  // FullRunningAverageAndStdDev tests
  test("testFullRunningAverageAndStdDev") {
    val average: RunningAverageAndStdDev = new FullRunningAverageAndStdDev
    assert(0 == average.getCount)
    assert(Double.isNaN(average.getAverage))
    assert(Double.isNaN(average.getStandardDeviation))
    average.addDatum(6.0)
    assert(1 == average.getCount)
    assert((6.0 - average.getAverage).abs < epsilon)
    assert(Double.isNaN(average.getStandardDeviation))
    average.addDatum(6.0)
    assert(2 == average.getCount)
    assert((6.0 - average.getAverage).abs < epsilon)
    assert((0.0 - average.getStandardDeviation).abs < epsilon)
    average.removeDatum(6.0)
    assert(1 == average.getCount)
    assert((6.0 - average.getAverage).abs < epsilon)
    assert(Double.isNaN(average.getStandardDeviation))
    average.addDatum(-4.0)
    assert(2 == average.getCount)
    assert((1.0 - average.getAverage).abs < epsilon)
    assert(((5.0 * 1.4142135623730951) - average.getStandardDeviation).abs < epsilon)
    average.removeDatum(4.0)
    assert(1 == average.getCount)
    assert((2.0 + average.getAverage).abs < epsilon)
    assert(Double.isNaN(average.getStandardDeviation))
  }

  test("testBigFullRunningAverageAndStdDev") {
    val average: RunningAverageAndStdDev = new FullRunningAverageAndStdDev
    RandomUtils.useTestSeed()
    val r: Random = RandomUtils.getRandom

    for (i <- 0 until 100000) {
      average.addDatum(r.nextDouble() * 1000.0)
    }

    // Uniform samples on [0, 1000): mean 500, standard deviation 1000 / sqrt(12).
    assert((500.0 - average.getAverage).abs < smallEpsilon)
    assert(((1000.0 / Math.sqrt(12.0)) - average.getStandardDeviation).abs < smallEpsilon)
  }

  test("testStddevFullRunningAverageAndStdDev") {
    val runningAverage: RunningAverageAndStdDev = new FullRunningAverageAndStdDev
    assert(0 == runningAverage.getCount)
    assert(Double.isNaN(runningAverage.getAverage))
    runningAverage.addDatum(1.0)
    assert(1 == runningAverage.getCount)
    assert((1.0 - runningAverage.getAverage).abs < epsilon)
    assert(Double.isNaN(runningAverage.getStandardDeviation))
    runningAverage.addDatum(1.0)
    assert(2 == runningAverage.getCount)
    assert((1.0 - runningAverage.getAverage).abs < epsilon)
    assert((0.0 - runningAverage.getStandardDeviation).abs < epsilon)
    runningAverage.addDatum(7.0)
    assert(3 == runningAverage.getCount)
    assert((3.0 - runningAverage.getAverage).abs < epsilon)
    assert((3.464101552963257 - runningAverage.getStandardDeviation).abs < epsilon)
    runningAverage.addDatum(5.0)
    assert(4 == runningAverage.getCount)
    // Fixed: this comparison used to omit .abs.
    assert((3.5 - runningAverage.getAverage).abs < epsilon)
    assert((3.0 - runningAverage.getStandardDeviation).abs < epsilon)
  }

  // FullRunningAverage tests
  test("testFullRunningAverage") {
    val runningAverage: RunningAverage = new FullRunningAverage
    assert(0 == runningAverage.getCount)
    assert(Double.isNaN(runningAverage.getAverage))
    runningAverage.addDatum(1.0)
    assert(1 == runningAverage.getCount)
    assert((1.0 - runningAverage.getAverage).abs < epsilon)
    runningAverage.addDatum(1.0)
    assert(2 == runningAverage.getCount)
    assert((1.0 - runningAverage.getAverage).abs < epsilon)
    runningAverage.addDatum(4.0)
    assert(3 == runningAverage.getCount)
    // Fixed: this comparison used to omit .abs.
    assert((2.0 - runningAverage.getAverage).abs < epsilon)
    runningAverage.addDatum(-4.0)
    assert(4 == runningAverage.getCount)
    assert((0.5 - runningAverage.getAverage).abs < epsilon)
    runningAverage.removeDatum(-4.0)
    assert(3 == runningAverage.getCount)
    assert((2.0 - runningAverage.getAverage).abs < epsilon)
    runningAverage.removeDatum(4.0)
    assert(2 == runningAverage.getCount)
    assert((1.0 - runningAverage.getAverage).abs < epsilon)
    runningAverage.changeDatum(0.0)
    assert(2 == runningAverage.getCount)
    assert((1.0 - runningAverage.getAverage).abs < epsilon)
    runningAverage.changeDatum(2.0)
    assert(2 == runningAverage.getCount)
    assert((2.0 - runningAverage.getAverage).abs < epsilon)
  }

  // The test name (including its spelling) is kept as-is so existing test filters still match.
  test("testFullRunningAveragCopyConstructor") {
    val runningAverage: RunningAverage = new FullRunningAverage
    runningAverage.addDatum(1.0)
    runningAverage.addDatum(1.0)
    assert(2 == runningAverage.getCount)
    // Fixed: both comparisons in this test used to omit .abs.
    assert((1.0 - runningAverage.getAverage).abs < epsilon)
    val copy: RunningAverage = new FullRunningAverage(runningAverage.getCount, runningAverage.getAverage)
    assert(2 == copy.getCount)
    assert((1.0 - copy.getAverage).abs < epsilon)
  }

  // Inverted Running Average tests
  test("testInvertedRunningAverage") {
    val avg: RunningAverage = new FullRunningAverage
    val inverted: RunningAverage = new InvertedRunningAverage(avg)
    assert(0 == inverted.getCount)
    avg.addDatum(1.0)
    assert(1 == inverted.getCount)
    assert((1.0 + inverted.getAverage).abs < epsilon) // inverted.getAverage == -1.0
    avg.addDatum(2.0)
    assert(2 == inverted.getCount)
    assert((1.5 + inverted.getAverage).abs < epsilon) // inverted.getAverage == -1.5
  }

  test("testInvertedRunningAverageAndStdDev") {
    val avg: RunningAverageAndStdDev = new FullRunningAverageAndStdDev
    val inverted: RunningAverageAndStdDev = new InvertedRunningAverageAndStdDev(avg)
    assert(0 == inverted.getCount)
    avg.addDatum(1.0)
    assert(1 == inverted.getCount)
    assert((1.0 + inverted.getAverage).abs < epsilon) // inverted.getAverage == -1.0
    avg.addDatum(2.0)
    assert(2 == inverted.getCount)
    assert((1.5 + inverted.getAverage).abs < epsilon) // inverted.getAverage == -1.5
    assert(((Math.sqrt(2.0) / 2.0) - inverted.getStandardDeviation).abs < epsilon)
  }

  // confusion Matrix tests
  val VALUES: Array[Array[Int]] = Array(Array(2, 3), Array(10, 20))
  val LABELS: Array[String] = Array("Label1", "Label2")
  val OTHER: Array[Int] = Array(3, 6)
  val DEFAULT_LABEL: String = "other"

  /**
   * Builds a two-label ConfusionMatrix.
   *
   * @param values       2x2 counts passed to putCount as (labels(i), labels(j), values(i)(j))
   * @param labels       the two labels; only labels(0) and labels(1) are read
   * @param defaultLabel default label handed to the ConfusionMatrix constructor; OTHER(i)
   *                     is recorded as the count for (labels(i), defaultLabel)
   * @return the populated matrix
   */
  def fillConfusionMatrix(values: Array[Array[Int]], labels: Array[String], defaultLabel: String): ConfusionMatrix = {
    val labelList = Arrays.asList(labels(0), labels(1))
    val confusionMatrix: ConfusionMatrix = new ConfusionMatrix(labelList, defaultLabel)
    // Fixed: use the labels/defaultLabel arguments instead of hard-coded "Label1"/"Label2"/DEFAULT_LABEL.
    confusionMatrix.putCount(labels(0), labels(0), values(0)(0))
    confusionMatrix.putCount(labels(0), labels(1), values(0)(1))
    confusionMatrix.putCount(labels(1), labels(0), values(1)(0))
    confusionMatrix.putCount(labels(1), labels(1), values(1)(1))
    confusionMatrix.putCount(labels(0), defaultLabel, OTHER(0))
    confusionMatrix.putCount(labels(1), defaultLabel, OTHER(1))

    confusionMatrix
  }

  /** Asserts the per-label accuracies of a matrix built from VALUES / OTHER. */
  private def checkAccuracy(cm: ConfusionMatrix): Unit = {
    val labelstrs = cm.getLabels
    assert(3 == labelstrs.size)
    assert((25.0 - cm.getAccuracy("Label1")).abs < epsilon)
    assert((55.5555555 - cm.getAccuracy("Label2")).abs < epsilon)
    assert(Double.isNaN(cm.getAccuracy("other")))
  }

  /** Asserts that the raw counts and the label set match VALUES, OTHER, LABELS and DEFAULT_LABEL. */
  private def checkValues(cm: ConfusionMatrix): Unit = {
    val counts: Array[Array[Int]] = cm.getConfusionMatrix
    // Result deliberately discarded; presumably a smoke check that toString does not throw.
    cm.toString
    assert(counts.length == counts(0).length)
    assert(3 == counts.length)
    assert(VALUES(0)(0) == counts(0)(0))
    assert(VALUES(0)(1) == counts(0)(1))
    assert(VALUES(1)(0) == counts(1)(0))
    assert(VALUES(1)(1) == counts(1)(1))
    assert(Arrays.equals(new Array[Int](3), counts(2)))
    assert(OTHER(0) == counts(0)(2))
    assert(OTHER(1) == counts(1)(2))
    assert(3 == cm.getLabels.size)
    assert(cm.getLabels.contains(LABELS(0)))
    assert(cm.getLabels.contains(LABELS(1)))
    assert(cm.getLabels.contains(DEFAULT_LABEL))
  }

  test("testBuild") {
    val confusionMatrix: ConfusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL)
    checkValues(confusionMatrix)
    checkAccuracy(confusionMatrix)
  }

  test("GetMatrix") {
    val confusionMatrix: ConfusionMatrix = fillConfusionMatrix(VALUES, LABELS, DEFAULT_LABEL)
    val m: Matrix = confusionMatrix.getMatrix
    val rowLabels = m.getRowLabelBindings
    assert(confusionMatrix.getLabels.size == m.numCols)
    assert(rowLabels.keySet.contains(LABELS(0)))
    assert(rowLabels.keySet.contains(LABELS(1)))
    assert(rowLabels.keySet.contains(DEFAULT_LABEL))
    assert(2 == confusionMatrix.getCorrect(LABELS(0)))
    assert(20 == confusionMatrix.getCorrect(LABELS(1)))
    assert(0 == confusionMatrix.getCorrect(DEFAULT_LABEL))
  }

  /**
   * Example taken from
   * http://scikit-learn.org/stable/modules/generated/sklearn.metrics.precision_recall_fscore_support.html
   */
  test("testPrecisionRecallAndF1ScoreAsScikitLearn") {
    val labelList = Arrays.asList("0", "1", "2")
    val confusionMatrix: ConfusionMatrix = new ConfusionMatrix(labelList, "DEFAULT")
    confusionMatrix.putCount("0", "0", 2)
    confusionMatrix.putCount("1", "0", 1)
    confusionMatrix.putCount("1", "2", 1)
    confusionMatrix.putCount("2", "1", 2)
    val delta = 0.001
    assert((0.222 - confusionMatrix.getWeightedPrecision).abs < delta)
    assert((0.333 - confusionMatrix.getWeightedRecall).abs < delta)
    assert((0.266 - confusionMatrix.getWeightedF1score).abs < delta)
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java ---------------------------------------------------------------------- diff --git a/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java b/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java index ebbed92..3ffff85 100644 --- a/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java +++ b/mrlegacy/src/test/java/org/apache/mahout/classifier/ConfusionMatrixTest.java @@ -115,5 +115,5 @@ public final class ConfusionMatrixTest extends MahoutTestCase {
confusionMatrix.putCount("Label2", DEFAULT_LABEL, OTHER[1]); return confusionMatrix; } - + }
http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala ---------------------------------------------------------------------- diff --git a/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala b/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala index 0df42a3..a957786 100644 --- a/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala +++ b/spark-shell/src/main/scala/org/apache/mahout/sparkbindings/shell/MahoutSparkILoop.scala @@ -45,6 +45,12 @@ class MahoutSparkILoop extends SparkILoop { conf.set("spark.executor.uri", execUri) } + // temporarily hard code spark.kryoserializer.buffer.mb + // to allow for seq2sparse data + //TODO: remove this before pushing to apache/master + conf.set("spark.kryoserializer.buffer.mb","100") + conf.set("spark.akka.frameSize","100") + sparkContext = mahoutSparkContext( masterUrl = master, appName = "Mahout Spark Shell",
http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark/src/main/scala/org/apache/mahout/classifier/naivebayes/SparkNaiveBayes.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/classifier/naivebayes/SparkNaiveBayes.scala b/spark/src/main/scala/org/apache/mahout/classifier/naivebayes/SparkNaiveBayes.scala new file mode 100644 index 0000000..fd7116e --- /dev/null +++ b/spark/src/main/scala/org/apache/mahout/classifier/naivebayes/SparkNaiveBayes.scala @@ -0,0 +1,99 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mahout.classifier.naivebayes

import org.apache.mahout.math._

import org.apache.mahout.sparkbindings.drm.CheckpointedDrmSpark

import scalabindings._
import scalabindings.RLikeOps._
import drm.RLikeDrmOps._
import drm._
import scala.reflect.ClassTag
import scala.language.asInstanceOf
import collection._
import JavaConversions._
import org.apache.spark.SparkContext._

import org.apache.mahout.classifier.naivebayes._
import org.apache.mahout.sparkbindings._

/**
 * Distributed training of a Naive Bayes model. Follows the approach presented in Rennie et.al.: Tackling the poor
 * assumptions of Naive Bayes Text classifiers, ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf
 */
object SparkNaiveBayes extends NaiveBayes {

  /**
   * Math-Scala Naive Bayes optimized for Spark.
   *
   * Extract label keys from the raw TF or TF-IDF matrix generated by seqdirectory/seq2sparse
   * and aggregate the TF or TF-IDF values by their label.
   *
   * @param stringKeyedObservations DrmLike matrix; output from seq2sparse
   *                                in form K = e.g. /Category/document_title
   *                                        V = TF or TF-IDF values per term.
   *                                The DRM is cast to a String-keyed Spark DRM without a check.
   * @param cParser a String => String function used to extract categories from
   *                keys of the stringKeyedObservations DRM. The default
   *                CategoryParser will extract "Category" from: '/Category/document_id'
   * @return (labelIndexMap, aggregatedByLabelObservationDrm)
   *         labelIndexMap is a HashMap K = label
   *                                    V = label row index (position of the label in sorted order)
   *         aggregatedByLabelObservationDrm is a DrmLike[Int] of aggregated
   *         TF or TF-IDF counts per label, keyed by the label row index
   */
  override def extractLabelsAndAggregateObservations[K: ClassTag](stringKeyedObservations: DrmLike[K],
                                                                  cParser: CategoryParser = seq2SparseCategoryParser)
                                                                 (implicit ctx: DistributedContext):
                                                                 (mutable.HashMap[String, Integer], DrmLike[Int]) = {

    val stringKeyedRdd = stringKeyedObservations
      .checkpoint()
      .asInstanceOf[CheckpointedDrmSpark[String]]
      .rdd

    // Sum the observation vectors of all documents that parse to the same category.
    val aggregatedRdd = stringKeyedRdd
      .map(x => (cParser(x._1), x._2))
      .reduceByKey(_ + _)

    // NOTE(review): map/reduceByKey above are lazy, so no job has run when this uncache
    // executes; if checkpoint() only marks the RDD for caching, the cache is never used.
    // Placement kept unchanged -- confirm the intended caching strategy.
    stringKeyedObservations.uncache()

    // One Spark job fetches the (relatively few) distinct labels, which are then sorted
    // locally. Previously this was count() followed by takeOrdered(count.toInt): two jobs
    // plus a Long -> Int narrowing. The resulting label order is the same.
    val sortedLabels = aggregatedRdd.keys.collect().sorted

    // label -> row index, in sorted label order
    val labelIndexMap = new mutable.HashMap[String, Integer]
    for ((label, rowIndex) <- sortedLabels.zipWithIndex) {
      labelIndexMap.put(label, rowIndex)
    }

    // Re-key the aggregated vectors by their row index.
    val intKeyedRdd = aggregatedRdd.map(x => (labelIndexMap(x._1).toInt, x._2))

    val aggregatedObservationByLabelDrm = drmWrap(intKeyedRdd)

    (labelIndexMap, aggregatedObservationByLabelDrm)
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark/src/main/scala/org/apache/mahout/drivers/TestNBDriver.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/drivers/TestNBDriver.scala b/spark/src/main/scala/org/apache/mahout/drivers/TestNBDriver.scala new file mode 100644 index 0000000..7d0738c --- /dev/null +++ b/spark/src/main/scala/org/apache/mahout/drivers/TestNBDriver.scala @@ -0,0
+1,131 @@
/*
  Licensed to the Apache Software Foundation (ASF) under one or more
  contributor license agreements. See the NOTICE file distributed with
  this work for additional information regarding copyright ownership.
  The ASF licenses this file to You under the Apache License, Version 2.0
  (the "License"); you may not use this file except in compliance with
  the License. You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
*/

package org.apache.mahout.drivers

import org.apache.mahout.classifier.naivebayes.{NBModel, NaiveBayes}
import org.apache.mahout.classifier.stats.ConfusionMatrix
import org.apache.mahout.math.drm
import org.apache.mahout.math.drm.DrmLike
import scala.collection.immutable.HashMap

/**
 * Command line driver ("spark-testnb") that reads a test set and a trained NBModel,
 * runs NaiveBayes.test on them and prints the result.
 */
object TestNBDriver extends MahoutSparkDriver {
  // define only the options specific to TestNB
  private final val testNBOptions = HashMap[String, Any](
    "appName" -> "TestNBDriver")

  /**
   * Builds the option parser, parses the arguments and, when parsing succeeds, runs process.
   *
   * @param args Command line args, if empty a help message is printed.
   */
  override def main(args: Array[String]): Unit = {

    parser = new MahoutSparkOptionParser(programName = "spark-testnb") {
      head("spark-testnb", "Mahout 1.0")

      //Input output options, non-driver specific
      parseIOOptions(numInputs = 1)

      //Algorithm control options--driver specific
      opts = opts ++ testNBOptions
      note("\nAlgorithm control options:")

      //default testComplementary is false
      opts = opts + ("testComplementary" -> false)
      opt[Unit]("testComplementary") abbr ("c") action { (_, options) =>
        options + ("testComplementary" -> true)
      } text ("Test a complementary model, Default: false.")

      // No default is registered for pathToModel; readModel reads the key unconditionally.
      opt[String]("pathToModel") abbr ("m") action { (x, options) =>
        options + ("pathToModel" -> x)
      } text ("Path to the Trained Model")

      //How to search for input
      parseFileDiscoveryOptions

      //Drm output schema--not driver specific, drm specific
      parseDrmFormatOptions

      //Spark config options--not driver specific
      parseSparkOptions

      //Jar inclusion, this option can be set when executing the driver from compiled code, not when from CLI
      parseGenericOptions

      help("help") abbr ("h") text ("prints this usage text\n")

    }

    // The parse result is only used for its side effect, hence foreach rather than map.
    parser.parse(args, parser.opts) foreach { parsedOpts =>
      parser.opts = parsedOpts
      process
    }
  }

  /**
   * Applies the "sparkExecutorMem" option (when non-empty) to the Spark configuration and
   * then delegates to the parent driver's start.
   *
   * @param masterUrl defaults to the parsed "master" option
   * @param appName   defaults to the parsed "appName" option
   */
  override def start(masterUrl: String = parser.opts("master").asInstanceOf[String],
                     appName: String = parser.opts("appName").asInstanceOf[String]):
  Unit = {

    // will be only specific to this job.
    // Note: set a large spark.kryoserializer.buffer.mb if using DSL MapBlock else leave as default

    if (parser.opts("sparkExecutorMem").asInstanceOf[String] != "")
      sparkConf.set("spark.executor.memory", parser.opts("sparkExecutorMem").asInstanceOf[String])

    // Note: set a large akka frame size for DSL NB (20)
    //sparkConf.set("spark.akka.frameSize","20") // don't need this for Spark optimized NaiveBayes..
    //else leave as set in Spark config

    super.start(masterUrl, appName)

  }

  /** Read the test set from inputPath/part-x-00000 sequence file of form <Text,VectorWritable> */
  private def readTestSet: DrmLike[_] = {
    val inputPath = parser.opts("input").asInstanceOf[String]
    drm.drmDfsRead(inputPath)
  }

  /** read the model from pathToModel using NBModel.DfsRead(...) */
  private def readModel: NBModel = {
    val modelPath = parser.opts("pathToModel").asInstanceOf[String]
    NBModel.dfsRead(modelPath)
  }

  /** Starts the driver, tests the model against the test set, prints the analyzer and stops. */
  override def process: Unit = {
    start()

    // stop runs in finally so the driver is shut down even when reading or testing throws.
    try {
      val testComplementary = parser.opts("testComplementary").asInstanceOf[Boolean]

      // todo: get the -ow option in to check for a model in the path and overwrite if flagged.

      val testSet = readTestSet
      val model = readModel
      val analyzer = NaiveBayes.test(model, testSet, testComplementary)

      println(analyzer)
    } finally {
      stop
    }
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark/src/main/scala/org/apache/mahout/drivers/TrainNBDriver.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/drivers/TrainNBDriver.scala b/spark/src/main/scala/org/apache/mahout/drivers/TrainNBDriver.scala new file mode 100644 index 0000000..35ff90b --- /dev/null +++ b/spark/src/main/scala/org/apache/mahout/drivers/TrainNBDriver.scala @@ -0,0 +1,115 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License.
You may obtain a copy of the License at

  http://www.apache.org/licenses/LICENSE-2.0

  Unless required by applicable law or agreed to in writing, software
  distributed under the License is distributed on an "AS IS" BASIS,
  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  See the License for the specific language governing permissions and
  limitations under the License.
*/

package org.apache.mahout.drivers

import org.apache.mahout.classifier.naivebayes._
import org.apache.mahout.classifier.naivebayes.SparkNaiveBayes
import org.apache.mahout.math.drm
import org.apache.mahout.math.drm.DrmLike
import scala.collection.immutable.HashMap

/**
 * Command line driver ("spark-trainnb") that reads a training set, aggregates it by label
 * with SparkNaiveBayes, trains a model with NaiveBayes.train and writes it to the output path.
 */
object TrainNBDriver extends MahoutSparkDriver {
  // define only the options specific to TrainNB
  private final val trainNBOptions = HashMap[String, Any](
    "appName" -> "TrainNBDriver")

  /**
   * Builds the option parser, parses the arguments and, when parsing succeeds, runs process.
   *
   * @param args Command line args, if empty a help message is printed.
   */
  override def main(args: Array[String]): Unit = {

    parser = new MahoutSparkOptionParser(programName = "spark-trainnb") {
      head("spark-trainnb", "Mahout 1.0")

      //Input output options, non-driver specific
      parseIOOptions(numInputs = 1)

      //Algorithm control options--driver specific
      opts = opts ++ trainNBOptions
      note("\nAlgorithm control options:")

      //default trainComplementary is false
      opts = opts + ("trainComplementary" -> false)
      opt[Unit]("trainComplementary") abbr ("c") action { (_, options) =>
        options + ("trainComplementary" -> true)
      } text ("Train a complementary model, Default: false.")

      //How to search for input
      parseFileDiscoveryOptions

      //Drm output schema--not driver specific, drm specific
      parseDrmFormatOptions

      //Spark config options--not driver specific
      parseSparkOptions

      //Jar inclusion, this option can be set when executing the driver from compiled code, not when from CLI
      parseGenericOptions

      help("help") abbr ("h") text ("prints this usage text\n")

    }

    // The parse result is only used for its side effect, hence foreach rather than map.
    parser.parse(args, parser.opts) foreach { parsedOpts =>
      parser.opts = parsedOpts
      process
    }
  }

  /**
   * Applies the "sparkExecutorMem" option (when non-empty) to the Spark configuration and
   * then delegates to the parent driver's start.
   *
   * @param masterUrl defaults to the parsed "master" option
   * @param appName   defaults to the parsed "appName" option
   */
  override def start(masterUrl: String = parser.opts("master").asInstanceOf[String],
                     appName: String = parser.opts("appName").asInstanceOf[String]):
  Unit = {

    // will be only specific to this job.
    // Note: set a large spark.kryoserializer.buffer.mb if using DSL MapBlock else leave as default

    if (parser.opts("sparkExecutorMem").asInstanceOf[String] != "")
      sparkConf.set("spark.executor.memory", parser.opts("sparkExecutorMem").asInstanceOf[String])

    // Note: set a large akka frame size for DSL NB (20)
    // sparkConf.set("spark.akka.frameSize","20") // don't need this for Spark optimized NaiveBayes..
    // else leave as set in Spark config

    super.start(masterUrl, appName)

  }

  /** Read the training set from inputPath/part-x-00000 sequence file of form <Text,VectorWritable> */
  private def readTrainingSet: DrmLike[_] = {
    val inputPath = parser.opts("input").asInstanceOf[String]
    drm.drmDfsRead(inputPath)
  }

  /** Starts the driver, trains a model on the training set, writes it to "output" and stops. */
  override def process: Unit = {
    start()

    // stop runs in finally so the driver is shut down even when training or writing throws.
    try {
      val complementary = parser.opts("trainComplementary").asInstanceOf[Boolean]
      val outputPath = parser.opts("output").asInstanceOf[String]

      val trainingSet = readTrainingSet
      val (labelIndex, aggregatedObservations) = SparkNaiveBayes.extractLabelsAndAggregateObservations(trainingSet)

      // Fixed: the parsed trainComplementary flag was read but never handed to the trainer,
      // so the command line option had no effect.
      // NOTE(review): assumes NaiveBayes.train takes the complementary flag as its third
      // parameter -- confirm against the NaiveBayes trait.
      val model = NaiveBayes.train(aggregatedObservations, labelIndex, complementary)

      model.dfsWrite(outputPath)
    } finally {
      stop
    }
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala ---------------------------------------------------------------------- diff --git a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala index 2a2a4a9..c04b306 100644 --- a/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala +++
b/spark/src/main/scala/org/apache/mahout/sparkbindings/drm/package.scala @@ -110,6 +110,4 @@ package object drm { key -> v } } - - }
http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark/src/test/scala/org/apache/mahout/classifier/naivebayes/NBSparkTestSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/classifier/naivebayes/NBSparkTestSuite.scala b/spark/src/test/scala/org/apache/mahout/classifier/naivebayes/NBSparkTestSuite.scala new file mode 100644 index 0000000..d999f9b --- /dev/null +++ b/spark/src/test/scala/org/apache/mahout/classifier/naivebayes/NBSparkTestSuite.scala @@ -0,0 +1,86 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.mahout.classifier.naivebayes

import org.apache.mahout.math._
import org.apache.mahout.math.scalabindings.RLikeOps._
import org.apache.mahout.math.scalabindings._
import org.apache.mahout.sparkbindings.test.DistributedSparkSuite
import org.apache.mahout.test.MahoutSuite
import org.scalatest.FunSuite

/**
 * Spark-backed Naive Bayes suite: runs the tests mixed in from NBTestBase and adds a
 * check that the Spark label aggregator agrees with the NaiveBayes (DSL) one.
 */
class NBSparkTestSuite extends FunSuite with MahoutSuite with DistributedSparkSuite with NBTestBase {

  test("Spark NB Aggregator") {

    // Five documents, alternating between categories Cat1 and Cat2.
    val rowBindings = new java.util.HashMap[String, Integer]()
    rowBindings.put("/Cat1/doc_a/", 0)
    rowBindings.put("/Cat2/doc_b/", 1)
    rowBindings.put("/Cat1/doc_c/", 2)
    rowBindings.put("/Cat2/doc_d/", 3)
    rowBindings.put("/Cat1/doc_e/", 4)

    val matrixSetup = sparse(
      (0, 0.1) :: (1, 0.0) :: (2, 0.1) :: (3, 0.0) :: Nil,
      (0, 0.0) :: (1, 0.1) :: (2, 0.0) :: (3, 0.1) :: Nil,
      (0, 0.1) :: (1, 0.0) :: (2, 0.1) :: (3, 0.0) :: Nil,
      (0, 0.0) :: (1, 0.1) :: (2, 0.0) :: (3, 0.1) :: Nil,
      (0, 0.1) :: (1, 0.0) :: (2, 0.1) :: (3, 0.0) :: Nil
    )

    matrixSetup.setRowLabelBindings(rowBindings)

    val TFIDFDrm = drm.drmParallelizeWithRowLabels(m = matrixSetup, numPartitions = 2)

    val (dslLabelIndex, dslAggregatedTFIDFDrm) = NaiveBayes.extractLabelsAndAggregateObservations(TFIDFDrm)
    val (sparkLabelIndex, sparkAggregatedTFIDFDrm) = SparkNaiveBayes.extractLabelsAndAggregateObservations(TFIDFDrm)

    dslLabelIndex.size should be (2)
    sparkLabelIndex.size should be (2)

    val dslCat1 = dslLabelIndex("Cat1")
    val dslCat2 = dslLabelIndex("Cat2")

    val sparkCat1 = sparkLabelIndex("Cat1")
    val sparkCat2 = sparkLabelIndex("Cat2")

    dslCat1 should be (0)
    dslCat2 should be (1)

    sparkCat1 should be (0)
    sparkCat2 should be (1)

    val dslAggInCore = dslAggregatedTFIDFDrm.collect
    val sparkAggInCore = sparkAggregatedTFIDFDrm.collect

    dslAggInCore.numCols should be (4) //4
    dslAggInCore.numRows should be (2) //2

    // Added: the Spark result must have the same shape before cells are compared.
    sparkAggInCore.numCols should be (4)
    sparkAggInCore.numRows should be (2)

    // Fixed: compare absolute differences. The signed difference used before is negative
    // whenever the Spark value is larger, so such a mismatch passed unnoticed.
    // Both label maps were asserted above to assign the same row indices, so the DSL
    // indices address the matching rows of both matrices.
    (dslAggInCore(dslCat1, 0) - sparkAggInCore(dslCat1, 0)).abs should be < epsilon //0.3
    (dslAggInCore(dslCat1, 1) - sparkAggInCore(dslCat1, 1)).abs should be < epsilon //0.0
    (dslAggInCore(dslCat1, 2) - sparkAggInCore(dslCat1, 2)).abs should be < epsilon //0.3
    (dslAggInCore(dslCat1, 3) - sparkAggInCore(dslCat1, 3)).abs should be < epsilon //0.0
    (dslAggInCore(dslCat2, 0) - sparkAggInCore(dslCat2, 0)).abs should be < epsilon //0.0
    (dslAggInCore(dslCat2, 1) - sparkAggInCore(dslCat2, 1)).abs should be < epsilon //0.2
    (dslAggInCore(dslCat2, 2) - sparkAggInCore(dslCat2, 2)).abs should be < epsilon //0.0
    (dslAggInCore(dslCat2, 3) - sparkAggInCore(dslCat2, 3)).abs should be < epsilon //0.2

  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/31053431/spark/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsSparkTestSuite.scala ---------------------------------------------------------------------- diff --git a/spark/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsSparkTestSuite.scala b/spark/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsSparkTestSuite.scala new file mode 100644 index 0000000..3af7649 --- /dev/null +++ b/spark/src/test/scala/org/apache/mahout/classifier/stats/ClassifierStatsSparkTestSuite.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.classifier.stats + +import org.apache.mahout.sparkbindings.test.DistributedSparkSuite +import org.apache.mahout.test.MahoutSuite +import org.scalatest.FunSuite + +/** Concrete suite that mixes ClassifierStatsTestBase into a FunSuite together with MahoutSuite and DistributedSparkSuite; it declares no tests of its own. */ +class ClassifierStatsSparkTestSuite extends FunSuite with MahoutSuite with DistributedSparkSuite with ClassifierStatsTestBase + +
