Revert "MAHOUT-1681: Renamed mahout-math-scala to mahout-samsara"
This reverts commit f7b69fabf1253b5e735e269c9410459d91816cdd. Project: http://git-wip-us.apache.org/repos/asf/mahout/repo Commit: http://git-wip-us.apache.org/repos/asf/mahout/commit/ef6d93a3 Tree: http://git-wip-us.apache.org/repos/asf/mahout/tree/ef6d93a3 Diff: http://git-wip-us.apache.org/repos/asf/mahout/diff/ef6d93a3 Branch: refs/heads/mahout-0.10.x Commit: ef6d93a34035c848335af57289c2c0acb58a2e4a Parents: f7b69fa Author: Stevo Slavic <[email protected]> Authored: Tue Apr 14 23:34:28 2015 +0200 Committer: Stevo Slavic <[email protected]> Committed: Tue Apr 14 23:34:28 2015 +0200 ---------------------------------------------------------------------- CHANGELOG | 2 - distribution/pom.xml | 2 +- distribution/src/main/assembly/bin.xml | 8 +- h2o/pom.xml | 4 +- math-scala/pom.xml | 197 +++++++ .../classifier/naivebayes/NBClassifier.scala | 119 ++++ .../mahout/classifier/naivebayes/NBModel.scala | 217 ++++++++ .../classifier/naivebayes/NaiveBayes.scala | 380 +++++++++++++ .../classifier/stats/ClassifierStats.scala | 467 ++++++++++++++++ .../classifier/stats/ConfusionMatrix.scala | 460 ++++++++++++++++ .../apache/mahout/drivers/MahoutDriver.scala | 44 ++ .../mahout/drivers/MahoutOptionParser.scala | 220 ++++++++ .../mahout/math/cf/SimilarityAnalysis.scala | 308 +++++++++++ .../apache/mahout/math/decompositions/ALS.scala | 140 +++++ .../apache/mahout/math/decompositions/DQR.scala | 74 +++ .../mahout/math/decompositions/DSPCA.scala | 153 ++++++ .../mahout/math/decompositions/DSSVD.scala | 82 +++ .../mahout/math/decompositions/SSVD.scala | 165 ++++++ .../mahout/math/decompositions/package.scala | 141 +++++ .../org/apache/mahout/math/drm/BCast.scala | 23 + .../org/apache/mahout/math/drm/CacheHint.scala | 19 + .../mahout/math/drm/CheckpointedDrm.scala | 47 ++ .../mahout/math/drm/CheckpointedOps.scala | 43 ++ .../mahout/math/drm/DistributedContext.scala | 27 + .../mahout/math/drm/DistributedEngine.scala | 215 ++++++++ .../mahout/math/drm/DrmDoubleScalarOps.scala | 33 ++ 
.../org/apache/mahout/math/drm/DrmLike.scala | 55 ++ .../org/apache/mahout/math/drm/DrmLikeOps.scala | 118 ++++ .../apache/mahout/math/drm/RLikeDrmOps.scala | 146 +++++ .../math/drm/logical/AbstractBinaryOp.scala | 54 ++ .../math/drm/logical/AbstractUnaryOp.scala | 37 ++ .../math/drm/logical/CheckpointAction.scala | 47 ++ .../apache/mahout/math/drm/logical/OpAB.scala | 41 ++ .../mahout/math/drm/logical/OpABAnyKey.scala | 41 ++ .../apache/mahout/math/drm/logical/OpABt.scala | 42 ++ .../apache/mahout/math/drm/logical/OpAewB.scala | 46 ++ .../mahout/math/drm/logical/OpAewScalar.scala | 45 ++ .../apache/mahout/math/drm/logical/OpAt.scala | 35 ++ .../apache/mahout/math/drm/logical/OpAtA.scala | 36 ++ .../mahout/math/drm/logical/OpAtAnyKey.scala | 34 ++ .../apache/mahout/math/drm/logical/OpAtB.scala | 42 ++ .../apache/mahout/math/drm/logical/OpAtx.scala | 41 ++ .../apache/mahout/math/drm/logical/OpAx.scala | 42 ++ .../mahout/math/drm/logical/OpCbind.scala | 42 ++ .../mahout/math/drm/logical/OpMapBlock.scala | 43 ++ .../apache/mahout/math/drm/logical/OpPar.scala | 18 + .../mahout/math/drm/logical/OpRbind.scala | 40 ++ .../mahout/math/drm/logical/OpRowRange.scala | 36 ++ .../math/drm/logical/OpTimesLeftMatrix.scala | 43 ++ .../math/drm/logical/OpTimesRightMatrix.scala | 46 ++ .../org/apache/mahout/math/drm/package.scala | 136 +++++ .../math/indexeddataset/IndexedDataset.scala | 63 +++ .../math/indexeddataset/ReaderWriter.scala | 117 ++++ .../mahout/math/indexeddataset/Schema.scala | 104 ++++ .../math/scalabindings/DoubleScalarOps.scala | 42 ++ .../scalabindings/MatlabLikeMatrixOps.scala | 66 +++ .../math/scalabindings/MatlabLikeOps.scala | 35 ++ .../math/scalabindings/MatlabLikeTimesOps.scala | 28 + .../scalabindings/MatlabLikeVectorOps.scala | 73 +++ .../mahout/math/scalabindings/MatrixOps.scala | 215 ++++++++ .../math/scalabindings/RLikeMatrixOps.scala | 94 ++++ .../mahout/math/scalabindings/RLikeOps.scala | 38 ++ .../math/scalabindings/RLikeTimesOps.scala | 28 + 
.../math/scalabindings/RLikeVectorOps.scala | 71 +++ .../mahout/math/scalabindings/VectorOps.scala | 141 +++++ .../mahout/math/scalabindings/package.scala | 297 ++++++++++ .../org/apache/mahout/nlp/tfidf/TFIDF.scala | 112 ++++ .../classifier/naivebayes/NBTestBase.scala | 291 ++++++++++ .../stats/ClassifierStatsTestBase.scala | 257 +++++++++ .../decompositions/DecompositionsSuite.scala | 113 ++++ .../DistributedDecompositionsSuiteBase.scala | 219 ++++++++ .../mahout/math/drm/DrmLikeOpsSuiteBase.scala | 93 ++++ .../mahout/math/drm/DrmLikeSuiteBase.scala | 76 +++ .../mahout/math/drm/RLikeDrmOpsSuiteBase.scala | 550 +++++++++++++++++++ .../mahout/math/scalabindings/MathSuite.scala | 214 ++++++++ .../MatlabLikeMatrixOpsSuite.scala | 67 +++ .../math/scalabindings/MatrixOpsSuite.scala | 185 +++++++ .../scalabindings/RLikeMatrixOpsSuite.scala | 80 +++ .../scalabindings/RLikeVectorOpsSuite.scala | 36 ++ .../math/scalabindings/VectorOpsSuite.scala | 82 +++ .../apache/mahout/nlp/tfidf/TFIDFtestBase.scala | 184 +++++++ .../mahout/test/DistributedMahoutSuite.scala | 28 + .../mahout/test/LoggerConfiguration.scala | 16 + .../org/apache/mahout/test/MahoutSuite.scala | 54 ++ pom.xml | 6 +- samsara/pom.xml | 194 ------- .../classifier/naivebayes/NBClassifier.scala | 119 ---- .../mahout/classifier/naivebayes/NBModel.scala | 217 -------- .../classifier/naivebayes/NaiveBayes.scala | 380 ------------- .../classifier/stats/ClassifierStats.scala | 467 ---------------- .../classifier/stats/ConfusionMatrix.scala | 460 ---------------- .../apache/mahout/drivers/MahoutDriver.scala | 44 -- .../mahout/drivers/MahoutOptionParser.scala | 220 -------- .../mahout/math/cf/SimilarityAnalysis.scala | 308 ----------- .../apache/mahout/math/decompositions/ALS.scala | 140 ----- .../apache/mahout/math/decompositions/DQR.scala | 74 --- .../mahout/math/decompositions/DSPCA.scala | 153 ------ .../mahout/math/decompositions/DSSVD.scala | 82 --- .../mahout/math/decompositions/SSVD.scala | 165 ------ 
.../mahout/math/decompositions/package.scala | 141 ----- .../org/apache/mahout/math/drm/BCast.scala | 23 - .../org/apache/mahout/math/drm/CacheHint.scala | 19 - .../mahout/math/drm/CheckpointedDrm.scala | 47 -- .../mahout/math/drm/CheckpointedOps.scala | 43 -- .../mahout/math/drm/DistributedContext.scala | 27 - .../mahout/math/drm/DistributedEngine.scala | 215 -------- .../mahout/math/drm/DrmDoubleScalarOps.scala | 33 -- .../org/apache/mahout/math/drm/DrmLike.scala | 55 -- .../org/apache/mahout/math/drm/DrmLikeOps.scala | 118 ---- .../apache/mahout/math/drm/RLikeDrmOps.scala | 146 ----- .../math/drm/logical/AbstractBinaryOp.scala | 54 -- .../math/drm/logical/AbstractUnaryOp.scala | 37 -- .../math/drm/logical/CheckpointAction.scala | 47 -- .../apache/mahout/math/drm/logical/OpAB.scala | 41 -- .../mahout/math/drm/logical/OpABAnyKey.scala | 41 -- .../apache/mahout/math/drm/logical/OpABt.scala | 42 -- .../apache/mahout/math/drm/logical/OpAewB.scala | 46 -- .../mahout/math/drm/logical/OpAewScalar.scala | 45 -- .../apache/mahout/math/drm/logical/OpAt.scala | 35 -- .../apache/mahout/math/drm/logical/OpAtA.scala | 36 -- .../mahout/math/drm/logical/OpAtAnyKey.scala | 34 -- .../apache/mahout/math/drm/logical/OpAtB.scala | 42 -- .../apache/mahout/math/drm/logical/OpAtx.scala | 41 -- .../apache/mahout/math/drm/logical/OpAx.scala | 42 -- .../mahout/math/drm/logical/OpCbind.scala | 42 -- .../mahout/math/drm/logical/OpMapBlock.scala | 43 -- .../apache/mahout/math/drm/logical/OpPar.scala | 18 - .../mahout/math/drm/logical/OpRbind.scala | 40 -- .../mahout/math/drm/logical/OpRowRange.scala | 36 -- .../math/drm/logical/OpTimesLeftMatrix.scala | 43 -- .../math/drm/logical/OpTimesRightMatrix.scala | 46 -- .../org/apache/mahout/math/drm/package.scala | 136 ----- .../math/indexeddataset/IndexedDataset.scala | 63 --- .../math/indexeddataset/ReaderWriter.scala | 117 ---- .../mahout/math/indexeddataset/Schema.scala | 104 ---- .../math/scalabindings/DoubleScalarOps.scala | 42 -- 
.../scalabindings/MatlabLikeMatrixOps.scala | 66 --- .../math/scalabindings/MatlabLikeOps.scala | 35 -- .../math/scalabindings/MatlabLikeTimesOps.scala | 28 - .../scalabindings/MatlabLikeVectorOps.scala | 73 --- .../mahout/math/scalabindings/MatrixOps.scala | 215 -------- .../math/scalabindings/RLikeMatrixOps.scala | 94 ---- .../mahout/math/scalabindings/RLikeOps.scala | 38 -- .../math/scalabindings/RLikeTimesOps.scala | 28 - .../math/scalabindings/RLikeVectorOps.scala | 71 --- .../mahout/math/scalabindings/VectorOps.scala | 141 ----- .../mahout/math/scalabindings/package.scala | 297 ---------- .../org/apache/mahout/nlp/tfidf/TFIDF.scala | 112 ---- .../classifier/naivebayes/NBTestBase.scala | 291 ---------- .../stats/ClassifierStatsTestBase.scala | 257 --------- .../decompositions/DecompositionsSuite.scala | 113 ---- .../DistributedDecompositionsSuiteBase.scala | 219 -------- .../mahout/math/drm/DrmLikeOpsSuiteBase.scala | 93 ---- .../mahout/math/drm/DrmLikeSuiteBase.scala | 76 --- .../mahout/math/drm/RLikeDrmOpsSuiteBase.scala | 550 ------------------- .../mahout/math/scalabindings/MathSuite.scala | 214 -------- .../MatlabLikeMatrixOpsSuite.scala | 67 --- .../math/scalabindings/MatrixOpsSuite.scala | 185 ------- .../scalabindings/RLikeMatrixOpsSuite.scala | 80 --- .../scalabindings/RLikeVectorOpsSuite.scala | 36 -- .../math/scalabindings/VectorOpsSuite.scala | 82 --- .../apache/mahout/nlp/tfidf/TFIDFtestBase.scala | 184 ------- .../mahout/test/DistributedMahoutSuite.scala | 28 - .../mahout/test/LoggerConfiguration.scala | 16 - .../org/apache/mahout/test/MahoutSuite.scala | 54 -- spark-shell/pom.xml | 2 +- spark/pom.xml | 4 +- 167 files changed, 8962 insertions(+), 8961 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/CHANGELOG ---------------------------------------------------------------------- diff --git a/CHANGELOG b/CHANGELOG index a3e39ac..777963a 100644 --- 
a/CHANGELOG +++ b/CHANGELOG @@ -2,8 +2,6 @@ Mahout Change Log Release 0.11.0 - unreleased - MAHOUT-1681: Renamed mahout-math-scala to mahout-samsara - MAHOUT-1680: Renamed mahout-distribution to apache-mahout-distribution Release 0.10.0 - 2015-04-11 http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/distribution/pom.xml ---------------------------------------------------------------------- diff --git a/distribution/pom.xml b/distribution/pom.xml index 3a47e08..bc17a08 100644 --- a/distribution/pom.xml +++ b/distribution/pom.xml @@ -115,7 +115,7 @@ </dependency> <dependency> <groupId>org.apache.mahout</groupId> - <artifactId>mahout-samsara_${scala.compat.version}</artifactId> + <artifactId>mahout-math-scala_${scala.compat.version}</artifactId> </dependency> </dependencies> </project> http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/distribution/src/main/assembly/bin.xml ---------------------------------------------------------------------- diff --git a/distribution/src/main/assembly/bin.xml b/distribution/src/main/assembly/bin.xml index 5dd014c..c49ddc2 100644 --- a/distribution/src/main/assembly/bin.xml +++ b/distribution/src/main/assembly/bin.xml @@ -117,7 +117,7 @@ <outputDirectory/> </fileSet> <fileSet> - <directory>${project.basedir}/../samsara/target</directory> + <directory>${project.basedir}/../math-scala/target</directory> <includes> <include>mahout-*.jar</include> <include>mahout-*.job</include> @@ -193,12 +193,12 @@ <outputDirectory>docs/mahout-examples</outputDirectory> </fileSet> <fileSet> - <directory>${project.basedir}/../samsara/target/site/scaladocs</directory> - <outputDirectory>docs/mahout-samsara</outputDirectory> + <directory>${project.basedir}/../math-scala/target/site/scaladocs</directory> + <outputDirectory>docs/mahout-examples</outputDirectory> </fileSet> <fileSet> <directory>${project.basedir}/../spark/target/site/scaladocs</directory> - <outputDirectory>docs/mahout-spark</outputDirectory> + 
<outputDirectory>docs/mahout-examples</outputDirectory> </fileSet> <fileSet> <directory>${project.basedir}/../spark-shell/target/site/scaladocs</directory> http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/h2o/pom.xml ---------------------------------------------------------------------- diff --git a/h2o/pom.xml b/h2o/pom.xml index c0ccdcc..b9d101a 100644 --- a/h2o/pom.xml +++ b/h2o/pom.xml @@ -127,7 +127,7 @@ <dependency> <groupId>org.apache.mahout</groupId> - <artifactId>mahout-samsara_${scala.compat.version}</artifactId> + <artifactId>mahout-math-scala_${scala.compat.version}</artifactId> <version>${project.version}</version> </dependency> @@ -140,7 +140,7 @@ <dependency> <groupId>org.apache.mahout</groupId> - <artifactId>mahout-samsara_${scala.compat.version}</artifactId> + <artifactId>mahout-math-scala_${scala.compat.version}</artifactId> <classifier>tests</classifier> <scope>test</scope> </dependency> http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/pom.xml ---------------------------------------------------------------------- diff --git a/math-scala/pom.xml b/math-scala/pom.xml new file mode 100644 index 0000000..78331dd --- /dev/null +++ b/math-scala/pom.xml @@ -0,0 +1,197 @@ +<?xml version="1.0" encoding="UTF-8"?> + +<!-- + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. +--> + +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> + <modelVersion>4.0.0</modelVersion> + + <parent> + <groupId>org.apache.mahout</groupId> + <artifactId>mahout</artifactId> + <version>0.11.0-SNAPSHOT</version> + <relativePath>../pom.xml</relativePath> + </parent> + + <artifactId>mahout-math-scala_${scala.compat.version}</artifactId> + <name>Mahout Math Scala bindings</name> + <description>High performance scientific and technical computing data structures and methods, + mostly based on CERN's + Colt Java API + </description> + + <packaging>jar</packaging> + + <build> + <plugins> + <!-- create test jar so other modules can reuse the math-scala test utility classes. --> + <plugin> + <groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-jar-plugin</artifactId> + <executions> + <execution> + <goals> + <goal>test-jar</goal> + </goals> + <phase>package</phase> + </execution> + </executions> + </plugin> + + <plugin> + <artifactId>maven-javadoc-plugin</artifactId> + </plugin> + + <plugin> + <artifactId>maven-source-plugin</artifactId> + </plugin> + + <plugin> + <groupId>net.alchim31.maven</groupId> + <artifactId>scala-maven-plugin</artifactId> + <executions> + <execution> + <id>add-scala-sources</id> + <phase>initialize</phase> + <goals> + <goal>add-source</goal> + </goals> + </execution> + <execution> + <id>scala-compile</id> + <phase>process-resources</phase> + <goals> + <goal>compile</goal> + </goals> + </execution> + <execution> + <id>scala-test-compile</id> + <phase>process-test-resources</phase> + <goals> + <goal>testCompile</goal> + </goals> + </execution> + </executions> + </plugin> + + <!--this is what scalatest recommends to do to enable scala tests --> + + <!-- disable surefire --> + <plugin> + 
<groupId>org.apache.maven.plugins</groupId> + <artifactId>maven-surefire-plugin</artifactId> + <configuration> + <skipTests>true</skipTests> + </configuration> + </plugin> + <!-- enable scalatest --> + <plugin> + <groupId>org.scalatest</groupId> + <artifactId>scalatest-maven-plugin</artifactId> + <executions> + <execution> + <id>test</id> + <goals> + <goal>test</goal> + </goals> + </execution> + </executions> + </plugin> + + </plugins> + </build> + + <dependencies> + + <dependency> + <groupId>org.apache.mahout</groupId> + <artifactId>mahout-math</artifactId> + </dependency> + + <!-- 3rd-party --> + <dependency> + <groupId>log4j</groupId> + <artifactId>log4j</artifactId> + </dependency> + + <dependency> + <groupId>com.github.scopt</groupId> + <artifactId>scopt_${scala.compat.version}</artifactId> + <version>3.3.0</version> + </dependency> + + <!-- scala stuff --> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scala-compiler</artifactId> + <version>${scala.version}</version> + </dependency> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scala-reflect</artifactId> + <version>${scala.version}</version> + </dependency> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scala-library</artifactId> + <version>${scala.version}</version> + </dependency> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scala-actors</artifactId> + <version>${scala.version}</version> + </dependency> + <dependency> + <groupId>org.scala-lang</groupId> + <artifactId>scalap</artifactId> + <version>${scala.version}</version> + </dependency> + <dependency> + <groupId>org.scalatest</groupId> + <artifactId>scalatest_${scala.compat.version}</artifactId> + </dependency> + + </dependencies> + + <profiles> + <profile> + <id>mahout-release</id> + <build> + <plugins> + <plugin> + <groupId>net.alchim31.maven</groupId> + <artifactId>scala-maven-plugin</artifactId> + <executions> + <execution> + <id>generate-scaladoc</id> + <goals> + 
<goal>doc</goal> + </goals> + </execution> + <execution> + <id>attach-scaladoc-jar</id> + <goals> + <goal>doc-jar</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> + </profile> + </profiles> +</project> http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala b/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala new file mode 100644 index 0000000..5de0733 --- /dev/null +++ b/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala @@ -0,0 +1,119 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ +package org.apache.mahout.classifier.naivebayes + +import org.apache.mahout.math.Vector +import scala.collection.JavaConversions._ + +/** + * Abstract Classifier base for Complentary and Standard Classifiers + * @param nbModel a trained NBModel + */ +abstract class AbstractNBClassifier(nbModel: NBModel) extends java.io.Serializable { + + // Trained Naive Bayes Model + val model = nbModel + + /** scoring method for standard and complementary classifiers */ + protected def getScoreForLabelFeature(label: Int, feature: Int): Double + + /** getter for model */ + protected def getModel: NBModel= { + model + } + + /** + * Compute the score for a Vector of weighted TF-IDF featured + * @param label Label to be scored + * @param instance Vector of weights to be calculate score + * @return score for this Label + */ + protected def getScoreForLabelInstance(label: Int, instance: Vector): Double = { + var result: Double = 0.0 + for (e <- instance.nonZeroes) { + result += e.get * getScoreForLabelFeature(label, e.index) + } + result + } + + /** number of categories the model has been trained on */ + def numCategories: Int = { + model.numLabels + } + + /** + * get a scoring vector for a vector of TF of TF-IDF weights + * @param instance vector of TF of TF-IDF weights to be classified + * @return a vector of scores. 
+ */ + def classifyFull(instance: Vector): Vector = { + classifyFull(model.createScoringVector, instance) + } + + /** helper method for classifyFull(Vector) */ + def classifyFull(r: Vector, instance: Vector): Vector = { + var label: Int = 0 + for (label <- 0 until model.numLabels) { + r.setQuick(label, getScoreForLabelInstance(label, instance)) + } + r + } +} + +/** + * Standard Multinomial Naive Bayes Classifier + * @param nbModel a trained NBModel + */ +class StandardNBClassifier(nbModel: NBModel) extends AbstractNBClassifier(nbModel: NBModel) with java.io.Serializable{ + override def getScoreForLabelFeature(label: Int, feature: Int): Double = { + val model: NBModel = getModel + StandardNBClassifier.computeWeight(model.weight(label, feature), model.labelWeight(label), model.alphaI, model.numFeatures) + } +} + +/** helper object for StandardNBClassifier */ +object StandardNBClassifier extends java.io.Serializable { + /** Compute Standard Multinomial Naive Bayes Weights See Rennie et. al. 
Section 2.1 */ + def computeWeight(featureLabelWeight: Double, labelWeight: Double, alphaI: Double, numFeatures: Double): Double = { + val numerator: Double = featureLabelWeight + alphaI + val denominator: Double = labelWeight + alphaI * numFeatures + return Math.log(numerator / denominator) + } +} + +/** + * Complementary Naive Bayes Classifier + * @param nbModel a trained NBModel + */ +class ComplementaryNBClassifier(nbModel: NBModel) extends AbstractNBClassifier(nbModel: NBModel) with java.io.Serializable { + override def getScoreForLabelFeature(label: Int, feature: Int): Double = { + val model: NBModel = getModel + val weight: Double = ComplementaryNBClassifier.computeWeight(model.featureWeight(feature), model.weight(label, feature), model.totalWeightSum, model.labelWeight(label), model.alphaI, model.numFeatures) + return weight / model.thetaNormalizer(label) + } +} + +/** helper object for ComplementaryNBClassifier */ +object ComplementaryNBClassifier extends java.io.Serializable { + + /** Compute Complementary weights See Rennie et. al. 
Section 3.1 */ + def computeWeight(featureWeight: Double, featureLabelWeight: Double, totalWeight: Double, labelWeight: Double, alphaI: Double, numFeatures: Double): Double = { + val numerator: Double = featureWeight - featureLabelWeight + alphaI + val denominator: Double = totalWeight - labelWeight + alphaI * numFeatures + return -Math.log(numerator / denominator) + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala b/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala new file mode 100644 index 0000000..3ceae96 --- /dev/null +++ b/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.classifier.naivebayes + +import org.apache.mahout.math._ + +import org.apache.mahout.math.{drm, scalabindings} + +import scalabindings._ +import scalabindings.RLikeOps._ +import drm.RLikeDrmOps._ +import drm._ +import scala.collection.JavaConverters._ +import scala.language.asInstanceOf +import scala.collection._ +import JavaConversions._ + +/** + * + * @param weightsPerLabelAndFeature Aggregated matrix of weights of labels x features + * @param weightsPerFeature Vector of summation of all feature weights. + * @param weightsPerLabel Vector of summation of all label weights. + * @param perlabelThetaNormalizer Vector of weight normalizers per label (used only for complemtary models) + * @param labelIndex HashMap of labels and their corresponding row in the weightMatrix + * @param alphaI Laplace smoothing factor. + * @param isComplementary Whether or not this is a complementary model. + */ +class NBModel(val weightsPerLabelAndFeature: Matrix = null, + val weightsPerFeature: Vector = null, + val weightsPerLabel: Vector = null, + val perlabelThetaNormalizer: Vector = null, + val labelIndex: Map[String, Integer] = null, + val alphaI: Float = 1.0f, + val isComplementary: Boolean= false) extends java.io.Serializable { + + + val numFeatures: Double = weightsPerFeature.getNumNondefaultElements + val totalWeightSum: Double = weightsPerLabel.zSum + val alphaVector: Vector = null + + validate() + + // todo: Maybe it is a good idea to move the dfsWrite and dfsRead out + // todo: of the model and into a helper + + // TODO: weightsPerLabelAndFeature, a sparse (numFeatures x numLabels) matrix should fit + // TODO: upfront in memory and should not require a DRM decide if we want this to scale out. + + + /** getter for summed label weights. Used by legacy classifier */ + def labelWeight(label: Int): Double = { + weightsPerLabel.getQuick(label) + } + + /** getter for weight normalizers. 
Used by legacy classifier */ + def thetaNormalizer(label: Int): Double = { + perlabelThetaNormalizer.get(label) + } + + /** getter for summed feature weights. Used by legacy classifier */ + def featureWeight(feature: Int): Double = { + weightsPerFeature.getQuick(feature) + } + + /** getter for individual aggregated weights. Used by legacy classifier */ + def weight(label: Int, feature: Int): Double = { + weightsPerLabelAndFeature.getQuick(label, feature) + } + + /** getter for a single empty vector of weights */ + def createScoringVector: Vector = { + weightsPerLabel.like + } + + /** getter for a the number of labels to consider */ + def numLabels: Int = { + weightsPerLabel.size + } + + /** + * Write a trained model to the filesystem as a series of DRMs + * @param pathToModel Directory to which the model will be written + */ + def dfsWrite(pathToModel: String)(implicit ctx: DistributedContext): Unit = { + //todo: write out as smaller partitions or possibly use reader and writers to + //todo: write something other than a DRM for label Index, is Complementary, alphaI. 
+ + // add a directory to put all of the DRMs in + val fullPathToModel = pathToModel + NBModel.modelBaseDirectory + + drmParallelize(weightsPerLabelAndFeature).dfsWrite(fullPathToModel + "/weightsPerLabelAndFeatureDrm.drm") + drmParallelize(sparse(weightsPerFeature)).dfsWrite(fullPathToModel + "/weightsPerFeatureDrm.drm") + drmParallelize(sparse(weightsPerLabel)).dfsWrite(fullPathToModel + "/weightsPerLabelDrm.drm") + drmParallelize(sparse(perlabelThetaNormalizer)).dfsWrite(fullPathToModel + "/perlabelThetaNormalizerDrm.drm") + drmParallelize(sparse(svec((0,alphaI)::Nil))).dfsWrite(fullPathToModel + "/alphaIDrm.drm") + + // isComplementry is true if isComplementaryDrm(0,0) == 1 else false + val isComplementaryDrm = sparse(0 to 1, 0 to 1) + if(isComplementary){ + isComplementaryDrm(0,0) = 1.0 + } else { + isComplementaryDrm(0,0) = 0.0 + } + drmParallelize(isComplementaryDrm).dfsWrite(fullPathToModel + "/isComplementaryDrm.drm") + + // write the label index as a String-Keyed DRM. + val labelIndexDummyDrm = weightsPerLabelAndFeature.like() + labelIndexDummyDrm.setRowLabelBindings(labelIndex) + // get a reverse map of [Integer, String] and set the value of firsr column of the drm + // to the corresponding row number for it's Label (the rows may not be read back in the same order) + val revMap = labelIndex.map(x => x._2 -> x._1) + for(i <- 0 until labelIndexDummyDrm.numRows() ){ + labelIndexDummyDrm.set(labelIndex(revMap(i)), 0, i.toDouble) + } + + drmParallelizeWithRowLabels(labelIndexDummyDrm).dfsWrite(fullPathToModel + "/labelIndex.drm") + } + + /** Model Validation */ + def validate() { + assert(alphaI > 0, "alphaI has to be greater than 0!") + assert(numFeatures > 0, "the vocab count has to be greater than 0!") + assert(totalWeightSum > 0, "the totalWeightSum has to be greater than 0!") + assert(weightsPerLabel != null, "the number of labels has to be defined!") + assert(weightsPerLabel.getNumNondefaultElements > 0, "the number of labels has to be greater than 0!") 
+ assert(weightsPerFeature != null, "the feature sums have to be defined") + assert(weightsPerFeature.getNumNondefaultElements > 0, "the feature sums have to be greater than 0!") + if (isComplementary) { + assert(perlabelThetaNormalizer != null, "the theta normalizers have to be defined") + assert(perlabelThetaNormalizer.getNumNondefaultElements > 0, "the number of theta normalizers has to be greater than 0!") + assert(Math.signum(perlabelThetaNormalizer.minValue) == Math.signum(perlabelThetaNormalizer.maxValue), "Theta normalizers do not all have the same sign") + assert(perlabelThetaNormalizer.getNumNonZeroElements == perlabelThetaNormalizer.size, "Weight normalizers can not have zero value.") + } + assert(labelIndex.size == weightsPerLabel.getNumNondefaultElements, "label index must have entries for all labels") + } +} + +object NBModel extends java.io.Serializable { + + val modelBaseDirectory = "/naiveBayesModel" + + /** + * Read a trained model in from from the filesystem. + * @param pathToModel directory from which to read individual model components + * @return a valid NBModel + */ + def dfsRead(pathToModel: String)(implicit ctx: DistributedContext): NBModel = { + //todo: Takes forever to read we need a more practical method of writing models. Readers/Writers? 
+ + // read from a base directory for all drms + val fullPathToModel = pathToModel + modelBaseDirectory + + val weightsPerFeatureDrm = drmDfsRead(fullPathToModel + "/weightsPerFeatureDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) + val weightsPerFeature = weightsPerFeatureDrm.collect(0, ::) + weightsPerFeatureDrm.uncache() + + val weightsPerLabelDrm = drmDfsRead(fullPathToModel + "/weightsPerLabelDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) + val weightsPerLabel = weightsPerLabelDrm.collect(0, ::) + weightsPerLabelDrm.uncache() + + val alphaIDrm = drmDfsRead(fullPathToModel + "/alphaIDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) + val alphaI: Float = alphaIDrm.collect(0, 0).toFloat + alphaIDrm.uncache() + + // isComplementry is true if isComplementaryDrm(0,0) == 1 else false + val isComplementaryDrm = drmDfsRead(fullPathToModel + "/isComplementaryDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) + val isComplementary = isComplementaryDrm.collect(0, 0).toInt == 1 + isComplementaryDrm.uncache() + + var perLabelThetaNormalizer= weightsPerFeature.like() + if (isComplementary) { + val perLabelThetaNormalizerDrm = drm.drmDfsRead(fullPathToModel + "/perlabelThetaNormalizerDrm.drm") + .checkpoint(CacheHint.MEMORY_ONLY) + perLabelThetaNormalizer = perLabelThetaNormalizerDrm.collect(0, ::) + } + + val dummyLabelDrm= drmDfsRead(fullPathToModel + "/labelIndex.drm") + .checkpoint(CacheHint.MEMORY_ONLY) + val labelIndexMap:java.util.Map[String, Integer] = dummyLabelDrm.getRowLabelBindings + dummyLabelDrm.uncache() + + // map the labels to the corresponding row numbers of weightsPerFeatureDrm (values in dummyLabelDrm) + val scalaLabelIndexMap: mutable.Map[String, Integer] = + labelIndexMap.map(x => x._1 -> dummyLabelDrm.get(labelIndexMap(x._1), 0) + .toInt + .asInstanceOf[Integer]) + + val weightsPerLabelAndFeatureDrm = drmDfsRead(fullPathToModel + "/weightsPerLabelAndFeatureDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) + val weightsPerLabelAndFeature = weightsPerLabelAndFeatureDrm.collect + 
weightsPerLabelAndFeatureDrm.uncache() + + // model validation is triggered automatically by constructor + val model: NBModel = new NBModel(weightsPerLabelAndFeature, + weightsPerFeature, + weightsPerLabel, + perLabelThetaNormalizer, + scalaLabelIndexMap, + alphaI, + isComplementary) + + model + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala b/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala new file mode 100644 index 0000000..a15ca09 --- /dev/null +++ b/math-scala/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.naivebayes + +import org.apache.mahout.classifier.stats.{ResultAnalyzer, ClassifierResult} +import org.apache.mahout.math._ +import scalabindings._ +import scalabindings.RLikeOps._ +import drm.RLikeDrmOps._ +import drm._ +import scala.reflect.ClassTag +import scala.language.asInstanceOf +import collection._ +import scala.collection.JavaConversions._ + +/** + * Distributed training of a Naive Bayes model. Follows the approach presented in Rennie et.al.: Tackling the poor + * assumptions of Naive Bayes Text classifiers, ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf + */ +trait NaiveBayes extends java.io.Serializable{ + + /** default value for the Laplacian smoothing parameter */ + def defaultAlphaI = 1.0f + + // function to extract categories from string keys + type CategoryParser = String => String + + /** Default: seqdirectory/seq2Sparse Categories are Stored in Drm Keys as: /Category/document_id */ + def seq2SparseCategoryParser: CategoryParser = x => x.split("/")(1) + + + /** + * Distributed training of a Naive Bayes model. Follows the approach presented in Rennie et.al.: Tackling the poor + * assumptions of Naive Bayes Text classifiers, ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf + * + * @param observationsPerLabel a DrmLike[Int] matrix containing term frequency counts for each label. 
+ * @param trainComplementary whether or not to train a complementary Naive Bayes model + * @param alphaI Laplace smoothing parameter + * @return trained naive bayes model + */ + def train(observationsPerLabel: DrmLike[Int], + labelIndex: Map[String, Integer], + trainComplementary: Boolean = true, + alphaI: Float = defaultAlphaI): NBModel = { + + // Summation of all weights per feature + val weightsPerFeature = observationsPerLabel.colSums + + // Distributed summation of all weights per label + val weightsPerLabel = observationsPerLabel.rowSums + + // Collect a matrix to pass to the NaiveBayesModel + val inCoreTFIDF = observationsPerLabel.collect + + // perLabelThetaNormalizer Vector is expected by NaiveBayesModel. We can pass a null value + // or Vector of zeroes in the case of a standard NB model. + var thetaNormalizer = weightsPerFeature.like() + + // Instantiate a trainer and retrieve the perLabelThetaNormalizer Vector from it in the case of + // a complementary NB model + if (trainComplementary) { + val thetaTrainer = new ComplementaryNBThetaTrainer(weightsPerFeature, + weightsPerLabel, + alphaI) + // local training of the theta normalization + for (labelIndex <- 0 until inCoreTFIDF.nrow) { + thetaTrainer.train(labelIndex, inCoreTFIDF(labelIndex, ::)) + } + thetaNormalizer = thetaTrainer.retrievePerLabelThetaNormalizer + } + + new NBModel(inCoreTFIDF, + weightsPerFeature, + weightsPerLabel, + thetaNormalizer, + labelIndex, + alphaI, + trainComplementary) + } + + /** + * Extract label Keys from raw TF or TF-IDF Matrix generated by seqdirectory/seq2sparse + * and aggregate TF or TF-IDF values by their label + * Override this method in engine specific modules to optimize + * + * @param stringKeyedObservations DrmLike matrix; Output from seq2sparse + * in form K = eg./Category/document_title + * V = TF or TF-IDF values per term + * @param cParser a String => String function used to extract categories from + * Keys of the stringKeyedObservations DRM. 
The default + * CategoryParser will extract "Category" from: '/Category/document_id' + * @return (labelIndexMap,aggregatedByLabelObservationDrm) + * labelIndexMap is a HashMap [String, Integer] K = label row index + * V = label + * aggregatedByLabelObservationDrm is a DrmLike[Int] of aggregated + * TF or TF-IDF counts per label + */ + def extractLabelsAndAggregateObservations[K: ClassTag](stringKeyedObservations: DrmLike[K], + cParser: CategoryParser = seq2SparseCategoryParser) + (implicit ctx: DistributedContext): + (mutable.HashMap[String, Integer], DrmLike[Int])= { + + stringKeyedObservations.checkpoint() + + val numDocs=stringKeyedObservations.nrow + val numFeatures=stringKeyedObservations.ncol + + // Extract categories from labels assigned by seq2sparse + // Categories are Stored in Drm Keys as eg.: /Category/document_id + + // Get a new DRM with a single column so that we don't have to collect the + // DRM into memory upfront. + val strippedObeservations= stringKeyedObservations.mapBlock(ncol=1){ + case(keys, block) => + val blockB = block.like(keys.size, 1) + keys -> blockB + } + + // Extract the row label bindings (the String keys) from the slim Drm + // strip the document_id from the row keys keeping only the category. + // Sort the bindings alphabetically into a Vector + val labelVectorByRowIndex = strippedObeservations + .getRowLabelBindings + .map(x => x._2 -> cParser(x._1)) + .toVector.sortWith(_._1 < _._1) + + //TODO: add a .toIntKeyed(...) method to DrmLike? 
+ + // Copy stringKeyedObservations to an Int-Keyed Drm so that we can compute transpose + // Copy the Collected Matrices up front for now until we hav a distributed way of converting + val inCoreStringKeyedObservations = stringKeyedObservations.collect + val inCoreIntKeyedObservations = new SparseMatrix( + stringKeyedObservations.nrow.toInt, + stringKeyedObservations.ncol) + for (i <- 0 until inCoreStringKeyedObservations.nrow.toInt) { + inCoreIntKeyedObservations(i, ::) = inCoreStringKeyedObservations(i, ::) + } + + val intKeyedObservations= drmParallelize(inCoreIntKeyedObservations) + + stringKeyedObservations.uncache() + + var labelIndex = 0 + val labelIndexMap = new mutable.HashMap[String, Integer] + val encodedLabelByRowIndexVector = new DenseVector(labelVectorByRowIndex.size) + + // Encode Categories as an Integer (Double) so we can broadcast as a vector + // where each element is an Int-encoded category whose index corresponds + // to its row in the Drm + for (i <- 0 until labelVectorByRowIndex.size) { + if (!(labelIndexMap.contains(labelVectorByRowIndex(i)._2))) { + encodedLabelByRowIndexVector(i) = labelIndex.toDouble + labelIndexMap.put(labelVectorByRowIndex(i)._2, labelIndex) + labelIndex += 1 + } + // don't like this casting but need to use a java.lang.Integer when setting rowLabelBindings + encodedLabelByRowIndexVector(i) = labelIndexMap + .getOrElse(labelVectorByRowIndex(i)._2, -1) + .asInstanceOf[Int].toDouble + } + + // "Combiner": Map and aggregate by Category. Do this by broadcasting the encoded + // category vector and mapping a transposed IntKeyed Drm out so that all categories + // will be present on all nodes as columns and can be referenced by + // BCastEncodedCategoryByRowVector. Iteratively sum all categories. 
+ val nLabels = labelIndex + + val bcastEncodedCategoryByRowVector = drmBroadcast(encodedLabelByRowIndexVector) + + val aggregetedObservationByLabelDrm = intKeyedObservations.t.mapBlock(ncol = nLabels) { + case (keys, blockA) => + val blockB = blockA.like(keys.size, nLabels) + var label : Int = 0 + for (i <- 0 until keys.size) { + blockA(i, ::).nonZeroes().foreach { elem => + label = bcastEncodedCategoryByRowVector.get(elem.index).toInt + blockB(i, label) = blockB(i, label) + blockA(i, elem.index) + } + } + keys -> blockB + }.t + + (labelIndexMap, aggregetedObservationByLabelDrm) + } + + /** + * Test a trained model with a labeled dataset sequentially + * @param model a trained NBModel + * @param testSet a labeled testing set + * @param testComplementary test using a complementary or a standard NB classifier + * @param cParser a String => String function used to extract categories from + * Keys of the testing set DRM. The default + * CategoryParser will extract "Category" from: '/Category/document_id' + * + * *Note*: this method brings the entire test set into upfront memory, + * This method is optimized and parallelized in SparkNaiveBayes + * + * @tparam K implicitly determined Key type of test set DRM: String + * @return a result analyzer with confusion matrix and accuracy statistics + */ + def test[K: ClassTag](model: NBModel, + testSet: DrmLike[K], + testComplementary: Boolean = false, + cParser: CategoryParser = seq2SparseCategoryParser) + (implicit ctx: DistributedContext): ResultAnalyzer = { + + val labelMap = model.labelIndex + + val numLabels = model.numLabels + + testSet.checkpoint() + + val numTestInstances = testSet.nrow.toInt + + // instantiate the correct type of classifier + val classifier = testComplementary match { + case true => new ComplementaryNBClassifier(model) with Serializable + case _ => new StandardNBClassifier(model) with Serializable + } + + if (testComplementary) { + assert(testComplementary == model.isComplementary, + "Complementary 
Label Assignment requires Complementary Training") + } + + + // Sequentially assign labels to the test set: + // *Note* this brings the entire test set into memory upfront: + + // Since we cant broadcast the model as is do it sequentially up front for now + val inCoreTestSet = testSet.collect + + // get the labels of the test set and extract the keys + val testSetLabelMap = testSet.getRowLabelBindings + + // empty Matrix in which we'll set the classification scores + val inCoreScoredTestSet = testSet.like(numTestInstances, numLabels) + + testSet.uncache() + + for (i <- 0 until numTestInstances) { + inCoreScoredTestSet(i, ::) := classifier.classifyFull(inCoreTestSet(i, ::)) + } + + // todo: reverse the labelMaps in training and through the model? + + // reverse the label map and extract the labels + val reverseTestSetLabelMap = testSetLabelMap.map(x => x._2 -> cParser(x._1)) + + val reverseLabelMap = labelMap.map(x => x._2 -> x._1) + + val analyzer = new ResultAnalyzer(labelMap.keys.toList.sorted, "DEFAULT") + + // assign labels- winner takes all + for (i <- 0 until numTestInstances) { + val (bestIdx, bestScore) = argmax(inCoreScoredTestSet(i, ::)) + val classifierResult = new ClassifierResult(reverseLabelMap(bestIdx), bestScore) + analyzer.addInstance(reverseTestSetLabelMap(i), classifierResult) + } + + analyzer + } + + /** + * argmax with values as well + * returns a tuple of index of the max score and the score itself. + * @param v Vector of of scores + * @return (bestIndex, bestScore) + */ + def argmax(v: Vector): (Int, Double) = { + var bestIdx: Int = Integer.MIN_VALUE + var bestScore: Double = Integer.MIN_VALUE.asInstanceOf[Int].toDouble + for(i <- 0 until v.size) { + if(v(i) > bestScore){ + bestScore = v(i) + bestIdx = i + } + } + (bestIdx, bestScore) + } + +} + +object NaiveBayes extends NaiveBayes with java.io.Serializable + +/** + * Trainer for the weight normalization vector used by Transform Weight Normalized Complement + * Naive Bayes. 
See: Rennie et.al.: Tackling the poor assumptions of Naive Bayes Text classifiers, + * ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf Sec. 3.2. + * + * @param weightsPerFeature a Vector of summed TF or TF-IDF weights for each word in dictionary. + * @param weightsPerLabel a Vector of summed TF or TF-IDF weights for each label. + * @param alphaI Laplace smoothing factor. Defaut value of 1. + */ +class ComplementaryNBThetaTrainer(private val weightsPerFeature: Vector, + private val weightsPerLabel: Vector, + private val alphaI: Double = 1.0) { + + private val perLabelThetaNormalizer: Vector = weightsPerLabel.like() + private val totalWeightSum: Double = weightsPerLabel.zSum + private var numFeatures: Double = weightsPerFeature.getNumNondefaultElements + + assert(weightsPerFeature != null, "weightsPerFeature vector can not be null") + assert(weightsPerLabel != null, "weightsPerLabel vector can not be null") + + /** + * Train the weight normalization vector for each label + * @param label + * @param featurePerLabelWeight + */ + def train(label: Int, featurePerLabelWeight: Vector) { + val currentLabelWeight = labelWeight(label) + // sum weights for each label including those with zero word counts + for (i <- 0 until featurePerLabelWeight.size) { + val currentFeaturePerLabelWeight = featurePerLabelWeight(i) + updatePerLabelThetaNormalizer(label, + ComplementaryNBClassifier.computeWeight(featureWeight(i), + currentFeaturePerLabelWeight, + totalWeightSum, + currentLabelWeight, + alphaI, + numFeatures) + ) + } + } + + /** + * getter for summed TF or TF-IDF weights by label + * @param label index of label + * @return sum of word TF or TF-IDF weights for label + */ + def labelWeight(label: Int): Double = { + weightsPerLabel(label) + } + + /** + * getter for summed TF or TF-IDF weights by word. + * @param feature index of word. + * @return sum of TF or TF-IDF weights for word. 
+ */ + def featureWeight(feature: Int): Double = { + weightsPerFeature(feature) + } + + /** + * add the magnitude of the current weight to the current + * label's corresponding Vector element. + * @param label index of label to update. + * @param weight weight to add. + */ + def updatePerLabelThetaNormalizer(label: Int, weight: Double) { + perLabelThetaNormalizer(label) = perLabelThetaNormalizer(label) + Math.abs(weight) + } + + /** + * Getter for the weight normalizer vector as indexed by label + * @return a copy of the weight normalizer vector. + */ + def retrievePerLabelThetaNormalizer: Vector = { + perLabelThetaNormalizer.cloned + } + + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala b/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala new file mode 100644 index 0000000..8f1413a --- /dev/null +++ b/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala @@ -0,0 +1,467 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +package org.apache.mahout.classifier.stats + +import java.text.{DecimalFormat, NumberFormat} +import java.util +import org.apache.mahout.math.stats.OnlineSummarizer + + +/** + * Result of a document classification. The label and the associated score (usually probabilty) + */ +class ClassifierResult (private var label: String = null, + private var score: Double = 0.0, + private var logLikelihood: Double = Integer.MAX_VALUE.toDouble) { + + def getLogLikelihood: Double = logLikelihood + + def setLogLikelihood(llh: Double) { + logLikelihood = llh + } + + def getLabel: String = label + + def getScore: Double = score + + def setLabel(lbl: String) { + label = lbl + } + + def setScore(sc: Double) { + score = sc + } + + override def toString: String = { + "ClassifierResult{" + "category='" + label + '\'' + ", score=" + score + '}' + } + +} + +/** + * ResultAnalyzer captures the classification statistics and displays in a tabular manner + * @param labelSet Set of labels to be considered in classification + * @param defaultLabel the default label for an unknown classification + */ +class ResultAnalyzer(private val labelSet: util.Collection[String], defaultLabel: String) { + + val confusionMatrix = new ConfusionMatrix(labelSet, defaultLabel) + val summarizer = new OnlineSummarizer + + private var hasLL: Boolean = false + private var correctlyClassified: Int = 0 + private var incorrectlyClassified: Int = 0 + + + def getConfusionMatrix: ConfusionMatrix = confusionMatrix + + /** + * + * @param correctLabel + * The correct label + * @param classifiedResult + * The classified result + * @return whether the instance was correct or not + */ + def addInstance(correctLabel: String, classifiedResult: ClassifierResult): Boolean = { + val result: Boolean = correctLabel == classifiedResult.getLabel + if (result) { + correctlyClassified += 1 + } + else { + incorrectlyClassified += 1 + } + confusionMatrix.addInstance(correctLabel, classifiedResult) + if 
(classifiedResult.getLogLikelihood != Integer.MAX_VALUE.toDouble) { + summarizer.add(classifiedResult.getLogLikelihood) + hasLL = true + } + + result + } + + /** Dump the resulting statistics to a string */ + override def toString: String = { + val returnString: StringBuilder = new StringBuilder + returnString.append('\n') + returnString.append("=======================================================\n") + returnString.append("Summary\n") + returnString.append("-------------------------------------------------------\n") + val totalClassified: Int = correctlyClassified + incorrectlyClassified + val percentageCorrect: Double = 100.asInstanceOf[Double] * correctlyClassified / totalClassified + val percentageIncorrect: Double = 100.asInstanceOf[Double] * incorrectlyClassified / totalClassified + val decimalFormatter: NumberFormat = new DecimalFormat("0.####") + returnString.append("Correctly Classified Instances") + .append(": ") + .append(Integer.toString(correctlyClassified)) + .append('\t') + .append(decimalFormatter.format(percentageCorrect)) + .append("%\n") + returnString.append("Incorrectly Classified Instances") + .append(": ") + .append(Integer.toString(incorrectlyClassified)) + .append('\t') + .append(decimalFormatter.format(percentageIncorrect)) + .append("%\n") + returnString.append("Total Classified Instances") + .append(": ") + .append(Integer.toString(totalClassified)) + .append('\n') + returnString.append('\n') + returnString.append(confusionMatrix) + returnString.append("=======================================================\n") + returnString.append("Statistics\n") + returnString.append("-------------------------------------------------------\n") + val normStats: RunningAverageAndStdDev = confusionMatrix.getNormalizedStats + returnString.append("Kappa: \t") + .append(decimalFormatter.format(confusionMatrix.getKappa)) + .append('\n') + returnString.append("Accuracy: \t") + .append(decimalFormatter.format(confusionMatrix.getAccuracy)) + .append("%\n") 
+ returnString.append("Reliability: \t") + .append(decimalFormatter.format(normStats.getAverage * 100.00000001)) + .append("%\n") + returnString.append("Reliability (std dev): \t") + .append(decimalFormatter.format(normStats.getStandardDeviation)) + .append('\n') + returnString.append("Weighted precision: \t") + .append(decimalFormatter.format(confusionMatrix.getWeightedPrecision)) + .append('\n') + returnString.append("Weighted recall: \t") + .append(decimalFormatter.format(confusionMatrix.getWeightedRecall)) + .append('\n') + returnString.append("Weighted F1 score: \t") + .append(decimalFormatter.format(confusionMatrix.getWeightedF1score)) + .append('\n') + if (hasLL) { + returnString.append("Log-likelihood: \t") + .append("mean : \t") + .append(decimalFormatter.format(summarizer.getMean)) + .append('\n') + returnString.append("25%-ile : \t") + .append(decimalFormatter.format(summarizer.getQuartile(1))) + .append('\n') + returnString.append("75%-ile : \t") + .append(decimalFormatter.format(summarizer.getQuartile(3))) + .append('\n') + } + + returnString.toString() + } + + +} + +/** + * + * Interface for classes that can keep track of a running average of a series of numbers. One can add to or + * remove from the series, as well as update a datum in the series. The class does not actually keep track of + * the series of values, just its running average, so it doesn't even matter if you remove/change a value that + * wasn't added. 
+ * + * Ported from org.apache.mahout.cf.taste.impl.common.RunningAverage.java + */ +trait RunningAverage { + + /** + * @param datum + * new item to add to the running average + * @throws IllegalArgumentException + * if datum is { @link Double#NaN} + */ + def addDatum(datum: Double) + + /** + * @param datum + * item to remove from the running average + * @throws IllegalArgumentException + * if datum is { @link Double#NaN} + * @throws IllegalStateException + * if count is 0 + */ + def removeDatum(datum: Double) + + /** + * @param delta + * amount by which to change a datum in the running average + * @throws IllegalArgumentException + * if delta is { @link Double#NaN} + * @throws IllegalStateException + * if count is 0 + */ + def changeDatum(delta: Double) + + def getCount: Int + + def getAverage: Double + + /** + * @return a (possibly immutable) object whose average is the negative of this object's + */ + def inverse: RunningAverage +} + +/** + * + * Extends {@link RunningAverage} by adding standard deviation too.
+ * + * Ported from org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev.java + */ +trait RunningAverageAndStdDev extends RunningAverage { + + /** @return standard deviation of data */ + def getStandardDeviation: Double + + /** + * @return a (possibly immutable) object whose average is the negative of this object's + */ + def inverse: RunningAverageAndStdDev +} + + +class InvertedRunningAverage(private val delegate: RunningAverage) extends RunningAverage { + + override def addDatum(datum: Double) { + throw new UnsupportedOperationException + } + + override def removeDatum(datum: Double) { + throw new UnsupportedOperationException + } + + override def changeDatum(delta: Double) { + throw new UnsupportedOperationException + } + + override def getCount: Int = { + delegate.getCount + } + + override def getAverage: Double = { + -delegate.getAverage + } + + override def inverse: RunningAverage = { + delegate + } +} + + +/** + * + * A simple class that can keep track of a running average of a series of numbers. One can add to or remove + * from the series, as well as update a datum in the series. The class does not actually keep track of the + * series of values, just its running average, so it doesn't even matter if you remove/change a value that + * wasn't added. 
+ * + * Ported from org.apache.mahout.cf.taste.impl.common.FullRunningAverage.java + */ +class FullRunningAverage(private var count: Int = 0, + private var average: Double = Double.NaN ) extends RunningAverage { + + /** + * @param datum + * new item to add to the running average + */ + override def addDatum(datum: Double) { + count += 1 + if (count == 1) { + average = datum + } + else { + average = average * (count - 1) / count + datum / count + } + } + + /** + * @param datum + * item to remove from the running average + * @throws IllegalStateException + * if count is 0 + */ + override def removeDatum(datum: Double) { + if (count == 0) { + throw new IllegalStateException + } + count -= 1 + if (count == 0) { + average = Double.NaN + } + else { + average = average * (count + 1) / count - datum / count + } + } + + /** + * @param delta + * amount by which to change a datum in the running average + * @throws IllegalStateException + * if count is 0 + */ + override def changeDatum(delta: Double) { + if (count == 0) { + throw new IllegalStateException + } + average += delta / count + } + + override def getCount: Int = { + count + } + + override def getAverage: Double = { + average + } + + override def inverse: RunningAverage = { + new InvertedRunningAverage(this) + } + + override def toString: String = { + String.valueOf(average) + } +} + + +/** + * + * Extends {@link FullRunningAverage} to add a running standard deviation computation. 
+ * Uses Welford's method, as described at http://www.johndcook.com/standard_deviation.html + * + * Ported from org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev.java + */ +class FullRunningAverageAndStdDev(private var count: Int = 0, + private var average: Double = 0.0, + private var mk: Double = 0.0, + private var sk: Double = 0.0) extends FullRunningAverage with RunningAverageAndStdDev { + + var stdDev: Double = 0.0 + + recomputeStdDev + + def getMk: Double = { + mk + } + + def getSk: Double = { + sk + } + + override def getStandardDeviation: Double = { + stdDev + } + + override def addDatum(datum: Double) { + super.addDatum(datum) + val count: Int = getCount + if (count == 1) { + mk = datum + sk = 0.0 + } + else { + val oldmk: Double = mk + val diff: Double = datum - oldmk + mk += diff / count + sk += diff * (datum - mk) + } + recomputeStdDev + } + + override def removeDatum(datum: Double) { + val oldCount: Int = getCount + super.removeDatum(datum) + val oldmk: Double = mk + mk = (oldCount * oldmk - datum) / (oldCount - 1) + sk -= (datum - mk) * (datum - oldmk) + recomputeStdDev + } + + /** + * @throws UnsupportedOperationException + */ + override def changeDatum(delta: Double) { + throw new UnsupportedOperationException + } + + private def recomputeStdDev { + val count: Int = getCount + stdDev = if (count > 1) Math.sqrt(sk / (count - 1)) else Double.NaN + } + + override def inverse: RunningAverageAndStdDev = { + new InvertedRunningAverageAndStdDev(this) + } + + override def toString: String = { + String.valueOf(String.valueOf(getAverage) + ',' + stdDev) + } + +} + + +/** + * + * @param delegate RunningAverageAndStdDev instance + * + * Ported from org.apache.mahout.cf.taste.impl.common.InvertedRunningAverageAndStdDev.java + */ +class InvertedRunningAverageAndStdDev(private val delegate: RunningAverageAndStdDev) extends RunningAverageAndStdDev { + + /** + * @throws UnsupportedOperationException + */ + override def addDatum(datum: Double) { + 
throw new UnsupportedOperationException + } + + /** + * @throws UnsupportedOperationException + */ + + override def removeDatum(datum: Double) { + throw new UnsupportedOperationException + } + + /** + * @throws UnsupportedOperationException + */ + override def changeDatum(delta: Double) { + throw new UnsupportedOperationException + } + + override def getCount: Int = { + delegate.getCount + } + + override def getAverage: Double = { + -delegate.getAverage + } + + override def getStandardDeviation: Double = { + delegate.getStandardDeviation + } + + override def inverse: RunningAverageAndStdDev = { + delegate + } +} + + + + http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ConfusionMatrix.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ConfusionMatrix.scala b/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ConfusionMatrix.scala new file mode 100644 index 0000000..328d27b --- /dev/null +++ b/math-scala/src/main/scala/org/apache/mahout/classifier/stats/ConfusionMatrix.scala @@ -0,0 +1,460 @@ +/* + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/

package org.apache.mahout.classifier.stats

import java.util
import org.apache.commons.math3.stat.descriptive.moment.Mean // This is brought in by mahout-math
import org.apache.mahout.math.{DenseMatrix, Matrix}
import scala.collection.mutable
import scala.collection.JavaConverters._

/**
 * Ported from org.apache.mahout.classifier.ConfusionMatrix.java
 *
 * The ConfusionMatrix Class stores the result of Classification of a Test Dataset.
 *
 * The fact of whether there is a default is not stored. A row of zeros is the only indicator that there is no default.
 *
 * See http://en.wikipedia.org/wiki/Confusion_matrix for background
 *
 * @param labels The labels to consider for classification
 * @param defaultLabel default unknown label
 */
class ConfusionMatrix(private var labels: util.Collection[String] = null,
                      private var defaultLabel: String = "unknown") {

  // Square count matrix indexed [correctLabel][classifiedLabel]; one extra
  // row/column holds the default ("unknown") label.
  var confusionMatrix = Array.ofDim[Int](labels.size + 1, labels.size + 1)

  // Maps each label (including the default label) to its matrix row/column index.
  val labelMap = new mutable.HashMap[String, Integer]()

  // Total number of samples recorded via addInstance/putCount.
  var samples: Int = 0

  // Assign indices in iteration order of the supplied collection; the default
  // label always gets the last index.
  var i: Integer = 0
  for (label <- labels.asScala) {
    labelMap.put(label, i)
    i += 1
  }
  labelMap.put(defaultLabel, i)

  /** @return the raw count matrix, indexed [correct][classified] */
  def getConfusionMatrix: Array[Array[Int]] = confusionMatrix

  /** @return all known labels, including the default label */
  def getLabels = labelMap.keys.toList

  /** @return the number of labels tracked, including the default label */
  def numLabels: Int = labelMap.size

  /**
   * Percentage of test examples with true class `label` that were classified correctly.
   * NOTE: returns NaN when `label` never occurred as a true class (row total is zero).
   *
   * @param label label whose per-class accuracy is requested
   */
  def getAccuracy(label: String): Double = {
    val labelId: Int = labelMap(label)
    var labelTotal: Int = 0
    var correct: Int = 0
    for (i <- 0 until numLabels) {
      labelTotal += confusionMatrix(labelId)(i)
      if (i == labelId) {
        correct += confusionMatrix(labelId)(i)
      }
    }
    100.0 * correct / labelTotal
  }

  /**
   * Overall percentage of correctly classified examples (sum of the diagonal
   * over the sum of all cells). NaN when no samples have been recorded.
   */
  def getAccuracy: Double = {
    var total: Int = 0
    var correct: Int = 0
    for (i <- 0 until numLabels) {
      for (j <- 0 until numLabels) {
        total += confusionMatrix(i)(j)
        if (i == j) {
          correct += confusionMatrix(i)(j)
        }
      }
    }
    100.0 * correct / total
  }

  /** Sum of true positives and false negatives (row total for `label`). */
  private def getActualNumberOfTestExamplesForClass(label: String): Int = {
    val labelId: Int = labelMap(label)
    var sum: Int = 0
    for (i <- 0 until numLabels) {
      sum += confusionMatrix(labelId)(i)
    }
    sum
  }

  /**
   * Precision for `label`: TP / (TP + FP), where FP is the column total
   * excluding the diagonal. Returns 0 when the label was never predicted.
   */
  def getPrecision(label: String): Double = {
    val labelId: Int = labelMap(label)
    val truePositives: Int = confusionMatrix(labelId)(labelId)
    var falsePositives: Int = 0
    for (i <- 0 until numLabels) {
      if (i != labelId) {
        falsePositives += confusionMatrix(i)(labelId)
      }
    }
    if (truePositives + falsePositives == 0) {
      0
    } else {
      truePositives.toDouble / (truePositives + falsePositives)
    }
  }

  /** Precision averaged over all labels, weighted by each label's true-class count. */
  def getWeightedPrecision: Double = {
    val precisions: Array[Double] = new Array[Double](numLabels)
    val weights: Array[Double] = new Array[Double](numLabels)
    var index: Int = 0
    for (label <- labelMap.keys) {
      precisions(index) = getPrecision(label)
      weights(index) = getActualNumberOfTestExamplesForClass(label)
      index += 1
    }
    new Mean().evaluate(precisions, weights)
  }

  /**
   * Recall for `label`: TP / (TP + FN), where FN is the row total excluding
   * the diagonal. Returns 0 when the label never occurred as a true class.
   */
  def getRecall(label: String): Double = {
    val labelId: Int = labelMap(label)
    val truePositives: Int = confusionMatrix(labelId)(labelId)
    var falseNegatives: Int = 0
    for (i <- 0 until numLabels) {
      if (i != labelId) {
        falseNegatives += confusionMatrix(labelId)(i)
      }
    }
    if (truePositives + falseNegatives == 0) {
      0
    } else {
      truePositives.toDouble / (truePositives + falseNegatives)
    }
  }

  /** Recall averaged over all labels, weighted by each label's true-class count. */
  def getWeightedRecall: Double = {
    val recalls: Array[Double] = new Array[Double](numLabels)
    val weights: Array[Double] = new Array[Double](numLabels)
    var index: Int = 0
    for (label <- labelMap.keys) {
      recalls(index) = getRecall(label)
      weights(index) = getActualNumberOfTestExamplesForClass(label)
      index += 1
    }
    new Mean().evaluate(recalls, weights)
  }

  /**
   * F1 score for `label`: harmonic mean of precision and recall.
   * Returns 0 when both precision and recall are 0.
   */
  def getF1score(label: String): Double = {
    val precision: Double = getPrecision(label)
    val recall: Double = getRecall(label)
    if (precision + recall == 0) {
      0
    } else {
      2 * precision * recall / (precision + recall)
    }
  }

  /** F1 score averaged over all labels, weighted by each label's true-class count. */
  def getWeightedF1score: Double = {
    val f1Scores: Array[Double] = new Array[Double](numLabels)
    val weights: Array[Double] = new Array[Double](numLabels)
    var index: Int = 0
    for (label <- labelMap.keys) {
      f1Scores(index) = getF1score(label)
      weights(index) = getActualNumberOfTestExamplesForClass(label)
      index += 1
    }
    new Mean().evaluate(f1Scores, weights)
  }

  /**
   * Mean per-class accuracy over the non-default labels.
   * NOTE: the divisor counts every label (including the default one), mirroring
   * the original Java implementation.
   */
  def getReliability: Double = {
    var count: Int = 0
    var accuracy: Double = 0
    for (label <- labelMap.keys) {
      if (label != defaultLabel) {
        accuracy += getAccuracy(label)
      }
      count += 1
    }
    accuracy / count
  }

  /**
   * Accuracy v.s. randomly classifying all samples.
   * kappa() = (totalAccuracy() - randomAccuracy()) / (1 - randomAccuracy())
   * Cohen, Jacob. 1960. A coefficient of agreement for nominal scales.
   * Educational And Psychological Measurement 20:37-46.
   *
   * Formula and variable names from:
   * http://www.yale.edu/ceo/OEFS/Accuracy.pdf
   *
   * @return double
   */
  def getKappa: Double = {
    var a: Double = 0.0
    var b: Double = 0.0
    for (i <- confusionMatrix.indices) {
      a += confusionMatrix(i)(i)
      // br: row total for label i; bc: column total for label i.
      var br: Int = 0
      for (j <- confusionMatrix.indices) {
        br += confusionMatrix(i)(j)
      }
      var bc: Int = 0
      for (vec <- confusionMatrix) {
        bc += vec(i)
      }
      b += br * bc
    }
    (samples * a - b) / (samples * samples - b)
  }

  /** Number of correctly classified examples for `label` (the diagonal cell). */
  def getCorrect(label: String): Int = {
    val labelId: Int = labelMap(label)
    confusionMatrix(labelId)(labelId)
  }

  /** Total number of test examples whose true class is `label` (row total). */
  def getTotal(label: String): Int = {
    val labelId: Int = labelMap(label)
    var labelTotal: Int = 0
    for (i <- 0 until numLabels) {
      labelTotal += confusionMatrix(labelId)(i)
    }
    labelTotal
  }

  /**
   * Standard deviation of normalized producer accuracy
   * Not a standard score
   * @return double
   */
  def getNormalizedStats: RunningAverageAndStdDev = {
    val summer = new FullRunningAverageAndStdDev()
    for (d <- confusionMatrix.indices) {
      var total: Double = 0.0
      for (j <- confusionMatrix.indices) {
        total += confusionMatrix(d)(j)
      }
      // Small epsilon avoids division by zero for labels with no examples.
      summer.addDatum(confusionMatrix(d)(d) / (total + 0.000001))
    }
    summer
  }

  /** Record a classification outcome taken from a [[ClassifierResult]]. */
  def addInstance(correctLabel: String, classifiedResult: ClassifierResult): Unit = {
    samples += 1
    incrementCount(correctLabel, classifiedResult.getLabel)
  }

  /** Record a single classification outcome. */
  def addInstance(correctLabel: String, classifiedLabel: String): Unit = {
    samples += 1
    incrementCount(correctLabel, classifiedLabel)
  }

  /**
   * Count of examples with true class `correctLabel` classified as `classifiedLabel`.
   * Returns 0 (with no assertion) when `correctLabel` was never seen in training.
   */
  def getCount(correctLabel: String, classifiedLabel: String): Int = {
    if (!labelMap.contains(correctLabel)) {
      // LOG.warn("Label {} did not appear in the training examples", correctLabel)
      return 0
    }
    assert(labelMap.contains(classifiedLabel), "Label not found: " + classifiedLabel)
    val correctId: Int = labelMap(correctLabel)
    val classifiedId: Int = labelMap(classifiedLabel)
    confusionMatrix(correctId)(classifiedId)
  }

  /**
   * Set the cell (`correctLabel`, `classifiedLabel`) to `count`.
   * Silently ignored when `correctLabel` is unknown; `samples` is bumped only
   * when a previously-zero cell becomes non-zero.
   */
  def putCount(correctLabel: String, classifiedLabel: String, count: Int): Unit = {
    if (!labelMap.contains(correctLabel)) {
      // LOG.warn("Label {} did not appear in the training examples", correctLabel)
      return
    }
    assert(labelMap.contains(classifiedLabel), "Label not found: " + classifiedLabel)
    val correctId: Int = labelMap(correctLabel)
    val classifiedId: Int = labelMap(classifiedLabel)
    if (confusionMatrix(correctId)(classifiedId) == 0 && count != 0) {
      samples += 1
    }
    confusionMatrix(correctId)(classifiedId) = count
  }

  /** Add `count` to the cell (`correctLabel`, `classifiedLabel`). */
  def incrementCount(correctLabel: String, classifiedLabel: String, count: Int): Unit = {
    putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel))
  }

  /** Add 1 to the cell (`correctLabel`, `classifiedLabel`). */
  def incrementCount(correctLabel: String, classifiedLabel: String): Unit = {
    incrementCount(correctLabel, classifiedLabel, 1)
  }

  /** @return the default (unknown) label */
  def getDefaultLabel: String = {
    defaultLabel
  }

  /**
   * Fold the counts of another ConfusionMatrix into this one (label sets must match).
   * @return this, for chaining
   */
  def merge(b: ConfusionMatrix): ConfusionMatrix = {
    assert(labelMap.size == b.getLabels.size, "The label sizes do not match")
    for (correctLabel <- this.labelMap.keys) {
      for (classifiedLabel <- this.labelMap.keys) {
        incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel))
      }
    }
    this
  }

  /** Export the counts as a DenseMatrix with row/column label bindings attached. */
  def getMatrix: Matrix = {
    val length: Int = confusionMatrix.length
    val m: Matrix = new DenseMatrix(length, length)

    // Renamed from `labels` to avoid shadowing the constructor parameter.
    val bindings: java.util.HashMap[String, Integer] = new java.util.HashMap()

    for (r <- 0 until length) {
      for (c <- 0 until length) {
        m.set(r, c, confusionMatrix(r)(c))
      }
    }

    for ((label, index) <- labelMap) {
      bindings.put(label, index)
    }
    m.setRowLabelBindings(bindings)
    m.setColumnLabelBindings(bindings)

    m
  }

  /**
   * Load counts (and, when present, label bindings) from a square matrix.
   * Assumes `m` is at least as large as the current confusion matrix.
   */
  def setMatrix(m: Matrix): Unit = {
    val length: Int = confusionMatrix.length
    if (m.numRows != m.numCols) {
      throw new IllegalArgumentException("ConfusionMatrix: matrix(" + m.numRows
        + ',' + m.numCols + ") must be square")
    }

    for (r <- 0 until length) {
      for (c <- 0 until length) {
        confusionMatrix(r)(c) = Math.round(m.get(r, c)).toInt
      }
    }

    // Prefer row bindings; fall back to column bindings.
    var bindings = m.getRowLabelBindings
    if (bindings == null) {
      bindings = m.getColumnLabelBindings
    }

    if (bindings != null) {
      val sorted: Array[String] = sortLabels(bindings)
      verifyLabels(length, sorted)
      labelMap.clear()
      for (i <- 0 until length) {
        labelMap.put(sorted(i), i)
      }
    }
  }

  /** Assert that every one of `length` rows has a non-null label. */
  def verifyLabels(length: Int, sorted: Array[String]): Unit = {
    assert(sorted.length == length, "One label, one row")
    for (i <- 0 until length) {
      if (sorted(i) == null) {
        assert(false, "One label, one row")
      }
    }
  }

  /** Arrange labels into an array ordered by their bound matrix index. */
  def sortLabels(labels: java.util.Map[String, Integer]): Array[String] = {
    val sorted: Array[String] = new Array[String](labels.size)
    for ((label, index) <- labels.asScala) {
      sorted(index) = label
    }
    sorted
  }

  /**
   * This is overloaded. toString() is not a formatted report you print for a manager :)
   * Assume that if there are no default assignments, the default feature was not used
   */
  override def toString: String = {

    val returnString: StringBuilder = new StringBuilder(200)

    returnString.append("=======================================================").append('\n')
    returnString.append("Confusion Matrix\n")
    returnString.append("-------------------------------------------------------").append('\n')

    val unclassified: Int = getTotal(defaultLabel)

    // Header row of short column labels; the default column is hidden when unused.
    for ((label, index) <- this.labelMap) {
      if (!((label == defaultLabel) && unclassified == 0)) {
        returnString.append(getSmallLabel(index) + " ").append('\t')
      }
    }

    returnString.append("<--Classified as").append('\n')

    for ((correctLabel, index) <- this.labelMap) {
      if (!((correctLabel == defaultLabel) && unclassified == 0)) {
        var labelTotal: Int = 0

        for (classifiedLabel <- this.labelMap.keySet) {
          if (!((classifiedLabel == defaultLabel) && unclassified == 0)) {
            returnString.append(Integer.toString(getCount(correctLabel, classifiedLabel)) + " ")
              .append('\t')
            labelTotal += getCount(correctLabel, classifiedLabel)
          }
        }
        returnString.append(" | ").append(String.valueOf(labelTotal) + " ")
          .append('\t')
          .append(getSmallLabel(index) + " ")
          .append(" = ")
          .append(correctLabel)
          .append('\n')
      }
    }

    if (unclassified > 0) {
      returnString.append("Default Category: ")
        .append(defaultLabel)
        .append(": ")
        .append(unclassified)
        .append('\n')
    }
    returnString.append('\n')

    returnString.toString()
  }

  /** Encode a column index as a short base-26 alphabetic label ("a", "b", ..., "aa", ...). */
  def getSmallLabel(i: Int): String = {
    var value: Int = i
    val returnString: StringBuilder = new StringBuilder
    do {
      val n: Int = value % 26
      returnString.insert(0, ('a' + n).toChar)
      value /= 26
    } while (value > 0)

    returnString.toString()
  }

}

http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/src/main/scala/org/apache/mahout/drivers/MahoutDriver.scala
----------------------------------------------------------------------
diff --git a/math-scala/src/main/scala/org/apache/mahout/drivers/MahoutDriver.scala b/math-scala/src/main/scala/org/apache/mahout/drivers/MahoutDriver.scala
new file mode 100644
index 0000000..32515f1
--- /dev/null
+++ b/math-scala/src/main/scala/org/apache/mahout/drivers/MahoutDriver.scala
@@ -0,0 +1,44 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.drivers + +import org.apache.mahout.math.drm.DistributedContext + +/** Extended by a platform specific version of this class to create a Mahout CLI driver. */ +abstract class MahoutDriver { + + implicit protected var mc: DistributedContext = _ + implicit protected var parser: MahoutOptionParser = _ + + var _useExistingContext: Boolean = false // used in the test suite to reuse one context per suite + + /** must be overriden to setup the DistributedContext mc*/ + protected def start() : Unit + + /** Override (optionally) for special cleanup */ + protected def stop(): Unit = { + if (!_useExistingContext) mc.close + } + + /** This is where you do the work, call start first, then before exiting call stop */ + protected def process(): Unit + + /** Parse command line and call process */ + def main(args: Array[String]): Unit + +}
