http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala new file mode 100644 index 0000000..a943c5f --- /dev/null +++ b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeMatrixOpsSuite.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.scalabindings + +import org.scalatest.FunSuite +import RLikeOps._ +import org.apache.mahout.test.MahoutSuite + +class RLikeMatrixOpsSuite extends FunSuite with MahoutSuite { + + test("multiplication") { + + val a = dense((1, 2, 3), (3, 4, 5)) + val b = dense(1, 4, 5) + val m = a %*% b + + assert(m(0, 0) == 24) + assert(m(1, 0) == 44) + println(m.toString) + } + + test("Hadamard") { + val a = dense( + (1, 2, 3), + (3, 4, 5) + ) + val b = dense( + (1, 1, 2), + (2, 1, 1) + ) + + val c = a * b + + printf("C=\n%s\n", c) + + assert(c(0, 0) == 1) + assert(c(1, 2) == 5) + println(c.toString) + + val d = a * 5.0 + assert(d(0, 0) == 5) + assert(d(1, 1) == 20) + + a *= b + assert(a(0, 0) == 1) + assert(a(1, 2) == 5) + println(a.toString) + + } + + /** Test dsl overloads over scala operations over matrices */ + test ("scalarOps") { + val a = dense( + (1, 2, 3), + (3, 4, 5) + ) + + (10 * a - (10 *: a)).norm shouldBe 0 + (10 + a - (10 +: a)).norm shouldBe 0 + (10 - a - (10 -: a)).norm shouldBe 0 + (10 / a - (10 /: a)).norm shouldBe 0 + + } + +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeVectorOpsSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeVectorOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeVectorOpsSuite.scala new file mode 100644 index 0000000..832937b --- /dev/null +++ b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/RLikeVectorOpsSuite.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.scalabindings + +import org.scalatest.FunSuite +import org.apache.mahout.math.Vector +import RLikeOps._ +import org.apache.mahout.test.MahoutSuite + +class RLikeVectorOpsSuite extends FunSuite with MahoutSuite { + + test("Hadamard") { + val a: Vector = (1, 2, 3) + val b = (3, 4, 5) + + val c = a * b + println(c) + assert(c ===(3, 8, 15)) + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala new file mode 100644 index 0000000..037f562 --- /dev/null +++ b/math-scala/src/test/scala/org/apache/mahout/math/scalabindings/VectorOpsSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.math.scalabindings + +import org.scalatest.FunSuite +import org.apache.mahout.math.{RandomAccessSparseVector, Vector} +import RLikeOps._ +import org.apache.mahout.test.MahoutSuite + +/** VectorOps Suite */ +class VectorOpsSuite extends FunSuite with MahoutSuite { + + test("inline create") { + + val sparseVec = svec((5 -> 1) :: (10 -> 2.0) :: Nil) + println(sparseVec) + + val sparseVec2: Vector = (5 -> 1.0) :: (10 -> 2.0) :: Nil + println(sparseVec2) + + val sparseVec3: Vector = new RandomAccessSparseVector(100) := (5 -> 1.0) :: Nil + println(sparseVec3) + + val denseVec1: Vector = (1.0, 1.1, 1.2) + println(denseVec1) + + val denseVec2 = dvec(1, 0, 1.1, 1.2) + println(denseVec2) + } + + test("plus minus") { + + val a: Vector = (1, 2, 3) + val b: Vector = (0 -> 3) :: (1 -> 4) :: (2 -> 5) :: Nil + + val c = a + b + val d = b - a + val e = -b - a + + assert(c ===(4, 6, 8)) + assert(d ===(2, 2, 2)) + assert(e ===(-4, -6, -8)) + + } + + test("dot") { + + val a: Vector = (1, 2, 3) + val b = (3, 4, 5) + + val c = a dot b + println(c) + assert(c == 26) + + } + + test ("scalarOps") { + val a = dvec(1 to 5):Vector + + 10 * a shouldBe 10 *: a + 10 + a shouldBe 10 +: a + 10 - a shouldBe 10 -: a + 10 / a shouldBe 10 /: a + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/src/test/scala/org/apache/mahout/nlp/tfidf/TFIDFtestBase.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/nlp/tfidf/TFIDFtestBase.scala b/math-scala/src/test/scala/org/apache/mahout/nlp/tfidf/TFIDFtestBase.scala new file mode 100644 index 0000000..3ec5ec1 --- /dev/null +++ b/math-scala/src/test/scala/org/apache/mahout/nlp/tfidf/TFIDFtestBase.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.nlp.tfidf + +import org.apache.mahout.math._ +import org.apache.mahout.math.scalabindings._ +import org.apache.mahout.test.DistributedMahoutSuite +import org.scalatest.{FunSuite, Matchers} +import scala.collection._ +import RLikeOps._ +import scala.math._ + + +trait TFIDFtestBase extends DistributedMahoutSuite with Matchers { + this: FunSuite => + + val epsilon = 1E-6 + + val documents: List[(Int, String)] = List( + (1, "the first document contains 5 terms"), + (2, "document two document contains 4 terms"), + (3, "document three three terms"), + (4, "each document including this document contain the term document")) + + def createDictionaryAndDfMaps(documents: List[(Int, String)]): (Map[String, Int], Map[Int, Int]) = { + + // get a tf count for the entire dictionary + val dictMap = documents.unzip._2.mkString(" ").toLowerCase.split(" ").groupBy(identity).mapValues(_.length) + + // create a dictionary with an index for each term + val dictIndex = dictMap.zipWithIndex.map(x => x._1._1 -> x._2).toMap + + val docFrequencyCount = new Array[Int](dictMap.size) + + for (token <- dictMap) { + for (doc <- documents) { + // parse the string and get a word then increment the df count for that word + if (doc._2.toLowerCase.split(" ").contains(token._1)) { + 
docFrequencyCount(dictIndex(token._1)) += 1 + } + } + } + + val docFrequencyMap = docFrequencyCount.zipWithIndex.map(x => x._2 -> x._1).toMap + + (dictIndex, docFrequencyMap) + } + + def vectorizeDocument(document: String, + dictionaryMap: Map[String, Int], + dfMap: Map[Int, Int], weight: TermWeight = new TFIDF): Vector = { + + val wordCounts = document.toLowerCase.split(" ").groupBy(identity).mapValues(_.length) + + val vec = new RandomAccessSparseVector(dictionaryMap.size) + + val totalDFSize = dictionaryMap.size + val docSize = wordCounts.size + + for (word <- wordCounts) { + val term = word._1 + if (dictionaryMap.contains(term)) { + val termFreq = word._2 + val dictIndex = dictionaryMap(term) + val docFreq = dfMap(dictIndex) + val currentWeight = weight.calculate(termFreq, docFreq.toInt, docSize, totalDFSize.toInt) + vec(dictIndex)= currentWeight + } + } + vec + } + + test("TF test") { + + val (dictionary, dfMap) = createDictionaryAndDfMaps(documents) + + val tf: TermWeight = new TF() + + val vectorizedDocuments: Matrix = new SparseMatrix(documents.size, dictionary.size) + + for (doc <- documents) { + vectorizedDocuments(doc._1 - 1, ::) := vectorizeDocument(doc._2, dictionary, dfMap, tf) + } + + // corpus: + // (1, "the first document contains 5 terms"), + // (2, "document two document contains 4 terms"), + // (3, "document three three terms"), + // (4, "each document including this document contain the term document") + + // dictonary: + // (this -> 0, 4 -> 1, three -> 2, document -> 3, two -> 4, term -> 5, 5 -> 6, contain -> 7, + // each -> 8, first -> 9, terms -> 10, contains -> 11, including -> 12, the -> 13) + + // dfMap: + // (0 -> 1, 5 -> 1, 10 -> 3, 1 -> 1, 6 -> 1, 9 -> 1, 13 -> 2, 2 -> 1, 12 -> 1, 7 -> 1, 3 -> 4, + // 11 -> 2, 8 -> 1, 4 -> 1) + + vectorizedDocuments(0, 0).toInt should be (0) + vectorizedDocuments(0, 13).toInt should be (1) + vectorizedDocuments(1, 3).toInt should be (2) + vectorizedDocuments(3, 3).toInt should be (3) + + } + + + 
test("TFIDF test") { + val (dictionary, dfMap) = createDictionaryAndDfMaps(documents) + + val tfidf: TermWeight = new TFIDF() + + val vectorizedDocuments: Matrix = new SparseMatrix(documents.size, dictionary.size) + + for (doc <- documents) { + vectorizedDocuments(doc._1 - 1, ::) := vectorizeDocument(doc._2, dictionary, dfMap, tfidf) + } + + // corpus: + // (1, "the first document contains 5 terms"), + // (2, "document two document contains 4 terms"), + // (3, "document three three terms"), + // (4, "each document including this document contain the term document") + + // dictonary: + // (this -> 0, 4 -> 1, three -> 2, document -> 3, two -> 4, term -> 5, 5 -> 6, contain -> 7, + // each -> 8, first -> 9, terms -> 10, contains -> 11, including -> 12, the -> 13) + + // dfMap: + // (0 -> 1, 5 -> 1, 10 -> 3, 1 -> 1, 6 -> 1, 9 -> 1, 13 -> 2, 2 -> 1, 12 -> 1, 7 -> 1, 3 -> 4, + // 11 -> 2, 8 -> 1, 4 -> 1) + + abs(vectorizedDocuments(0, 0) - 0.0) should be < epsilon + abs(vectorizedDocuments(0, 13) - 2.540445) should be < epsilon + abs(vectorizedDocuments(1, 3) - 2.870315) should be < epsilon + abs(vectorizedDocuments(3, 3) - 3.515403) should be < epsilon + } + + test("MLlib TFIDF test") { + val (dictionary, dfMap) = createDictionaryAndDfMaps(documents) + + val tfidf: TermWeight = new MLlibTFIDF() + + val vectorizedDocuments: Matrix = new SparseMatrix(documents.size, dictionary.size) + + for (doc <- documents) { + vectorizedDocuments(doc._1 - 1, ::) := vectorizeDocument(doc._2, dictionary, dfMap, tfidf) + } + + // corpus: + // (1, "the first document contains 5 terms"), + // (2, "document two document contains 4 terms"), + // (3, "document three three terms"), + // (4, "each document including this document contain the term document") + + // dictonary: + // (this -> 0, 4 -> 1, three -> 2, document -> 3, two -> 4, term -> 5, 5 -> 6, contain -> 7, + // each -> 8, first -> 9, terms -> 10, contains -> 11, including -> 12, the -> 13) + + // dfMap: + // (0 -> 1, 5 -> 1, 10 -> 3, 
1 -> 1, 6 -> 1, 9 -> 1, 13 -> 2, 2 -> 1, 12 -> 1, 7 -> 1, 3 -> 4, + // 11 -> 2, 8 -> 1, 4 -> 1) + + abs(vectorizedDocuments(0, 0) - 0.0) should be < epsilon + abs(vectorizedDocuments(0, 13) - 1.609437) should be < epsilon + abs(vectorizedDocuments(1, 3) - 2.197224) should be < epsilon + abs(vectorizedDocuments(3, 3) - 3.295836) should be < epsilon + } + +} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/src/test/scala/org/apache/mahout/test/DistributedMahoutSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/test/DistributedMahoutSuite.scala b/math-scala/src/test/scala/org/apache/mahout/test/DistributedMahoutSuite.scala new file mode 100644 index 0000000..3538991 --- /dev/null +++ b/math-scala/src/test/scala/org/apache/mahout/test/DistributedMahoutSuite.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.test + +import org.apache.mahout.math.drm.DistributedContext +import org.scalatest.{Suite, FunSuite, Matchers} + +/** + * Unit tests that use a distributed context to run + */ +trait DistributedMahoutSuite extends MahoutSuite { this: Suite => + protected implicit var mahoutCtx: DistributedContext +} http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/src/test/scala/org/apache/mahout/test/LoggerConfiguration.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/test/LoggerConfiguration.scala b/math-scala/src/test/scala/org/apache/mahout/test/LoggerConfiguration.scala new file mode 100644 index 0000000..7a34aa2 --- /dev/null +++ b/math-scala/src/test/scala/org/apache/mahout/test/LoggerConfiguration.scala @@ -0,0 +1,16 @@ +package org.apache.mahout.test + +import org.scalatest._ +import org.apache.log4j.{Level, Logger, BasicConfigurator} + +trait LoggerConfiguration extends BeforeAndAfterAllConfigMap { + this: Suite => + + override protected def beforeAll(configMap: ConfigMap): Unit = { + super.beforeAll(configMap) + BasicConfigurator.resetConfiguration() + BasicConfigurator.configure() + Logger.getRootLogger.setLevel(Level.ERROR) + Logger.getLogger("org.apache.mahout.math.scalabindings").setLevel(Level.DEBUG) + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/math-scala/src/test/scala/org/apache/mahout/test/MahoutSuite.scala ---------------------------------------------------------------------- diff --git a/math-scala/src/test/scala/org/apache/mahout/test/MahoutSuite.scala b/math-scala/src/test/scala/org/apache/mahout/test/MahoutSuite.scala new file mode 100644 index 0000000..d3b8a38 --- /dev/null +++ b/math-scala/src/test/scala/org/apache/mahout/test/MahoutSuite.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.test + +import java.io.File +import org.scalatest._ +import org.apache.mahout.common.RandomUtils + +trait MahoutSuite extends BeforeAndAfterEach with LoggerConfiguration with Matchers { + this: Suite => + + final val TmpDir = "tmp/" + + override protected def beforeEach() { + super.beforeEach() + RandomUtils.useTestSeed() + } + + override protected def beforeAll(configMap: ConfigMap) { + super.beforeAll(configMap) + + // just in case there is an existing tmp dir clean it before every suite + deleteDirectory(new File(TmpDir)) + } + + override protected def afterEach() { + + // clean the tmp dir after every test + deleteDirectory(new File(TmpDir)) + + super.afterEach() + } + + /** Delete directory no symlink checking and exceptions are not caught */ + private def deleteDirectory(path: File): Unit = { + if (path.isDirectory) + for (files <- path.listFiles) deleteDirectory(files) + path.delete + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/pom.xml ---------------------------------------------------------------------- diff --git a/pom.xml b/pom.xml index 7151414..74da44e 100644 --- a/pom.xml +++ b/pom.xml @@ -212,12 +212,12 @@ </dependency> <dependency> - <artifactId>mahout-samsara_${scala.compat.version}</artifactId> + 
<artifactId>mahout-math-scala_${scala.compat.version}</artifactId> <groupId>${project.groupId}</groupId> <version>${project.version}</version> </dependency> <dependency> - <artifactId>mahout-samsara_${scala.compat.version}</artifactId> + <artifactId>mahout-math-scala_${scala.compat.version}</artifactId> <groupId>${project.groupId}</groupId> <version>${project.version}</version> <classifier>tests</classifier> @@ -772,7 +772,7 @@ <module>integration</module> <module>examples</module> <module>distribution</module> - <module>samsara</module> + <module>math-scala</module> <module>spark</module> <module>spark-shell</module> <module>h2o</module> http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/samsara/pom.xml ---------------------------------------------------------------------- diff --git a/samsara/pom.xml b/samsara/pom.xml deleted file mode 100644 index 5f80b68..0000000 --- a/samsara/pom.xml +++ /dev/null @@ -1,194 +0,0 @@ -<?xml version="1.0" encoding="UTF-8"?> - -<!-- - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
---> - -<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> - <modelVersion>4.0.0</modelVersion> - - <parent> - <groupId>org.apache.mahout</groupId> - <artifactId>mahout</artifactId> - <version>0.11.0-SNAPSHOT</version> - <relativePath>../pom.xml</relativePath> - </parent> - - <artifactId>mahout-samsara_${scala.compat.version}</artifactId> - <name>Mahout Samsara</name> - <description>Mahout Math Scala bindings</description> - - <packaging>jar</packaging> - - <build> - <plugins> - <!-- create test jar so other modules can reuse the samsara test utility classes. --> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-jar-plugin</artifactId> - <executions> - <execution> - <goals> - <goal>test-jar</goal> - </goals> - <phase>package</phase> - </execution> - </executions> - </plugin> - - <plugin> - <artifactId>maven-javadoc-plugin</artifactId> - </plugin> - - <plugin> - <artifactId>maven-source-plugin</artifactId> - </plugin> - - <plugin> - <groupId>net.alchim31.maven</groupId> - <artifactId>scala-maven-plugin</artifactId> - <executions> - <execution> - <id>add-scala-sources</id> - <phase>initialize</phase> - <goals> - <goal>add-source</goal> - </goals> - </execution> - <execution> - <id>scala-compile</id> - <phase>process-resources</phase> - <goals> - <goal>compile</goal> - </goals> - </execution> - <execution> - <id>scala-test-compile</id> - <phase>process-test-resources</phase> - <goals> - <goal>testCompile</goal> - </goals> - </execution> - </executions> - </plugin> - - <!--this is what scalatest recommends to do to enable scala tests --> - - <!-- disable surefire --> - <plugin> - <groupId>org.apache.maven.plugins</groupId> - <artifactId>maven-surefire-plugin</artifactId> - <configuration> - <skipTests>true</skipTests> - </configuration> - </plugin> - <!-- enable scalatest --> - <plugin> - 
<groupId>org.scalatest</groupId> - <artifactId>scalatest-maven-plugin</artifactId> - <executions> - <execution> - <id>test</id> - <goals> - <goal>test</goal> - </goals> - </execution> - </executions> - </plugin> - - </plugins> - </build> - - <dependencies> - - <dependency> - <groupId>org.apache.mahout</groupId> - <artifactId>mahout-math</artifactId> - </dependency> - - <!-- 3rd-party --> - <dependency> - <groupId>log4j</groupId> - <artifactId>log4j</artifactId> - </dependency> - - <dependency> - <groupId>com.github.scopt</groupId> - <artifactId>scopt_${scala.compat.version}</artifactId> - <version>3.3.0</version> - </dependency> - - <!-- scala stuff --> - <dependency> - <groupId>org.scala-lang</groupId> - <artifactId>scala-compiler</artifactId> - <version>${scala.version}</version> - </dependency> - <dependency> - <groupId>org.scala-lang</groupId> - <artifactId>scala-reflect</artifactId> - <version>${scala.version}</version> - </dependency> - <dependency> - <groupId>org.scala-lang</groupId> - <artifactId>scala-library</artifactId> - <version>${scala.version}</version> - </dependency> - <dependency> - <groupId>org.scala-lang</groupId> - <artifactId>scala-actors</artifactId> - <version>${scala.version}</version> - </dependency> - <dependency> - <groupId>org.scala-lang</groupId> - <artifactId>scalap</artifactId> - <version>${scala.version}</version> - </dependency> - <dependency> - <groupId>org.scalatest</groupId> - <artifactId>scalatest_${scala.compat.version}</artifactId> - </dependency> - - </dependencies> - - <profiles> - <profile> - <id>mahout-release</id> - <build> - <plugins> - <plugin> - <groupId>net.alchim31.maven</groupId> - <artifactId>scala-maven-plugin</artifactId> - <executions> - <execution> - <id>generate-scaladoc</id> - <goals> - <goal>doc</goal> - </goals> - </execution> - <execution> - <id>attach-scaladoc-jar</id> - <goals> - <goal>doc-jar</goal> - </goals> - </execution> - </executions> - </plugin> - </plugins> - </build> - </profile> - </profiles> 
-</project> http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala ---------------------------------------------------------------------- diff --git a/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala b/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala deleted file mode 100644 index 5de0733..0000000 --- a/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBClassifier.scala +++ /dev/null @@ -1,119 +0,0 @@ -/* - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ -package org.apache.mahout.classifier.naivebayes - -import org.apache.mahout.math.Vector -import scala.collection.JavaConversions._ - -/** - * Abstract Classifier base for Complentary and Standard Classifiers - * @param nbModel a trained NBModel - */ -abstract class AbstractNBClassifier(nbModel: NBModel) extends java.io.Serializable { - - // Trained Naive Bayes Model - val model = nbModel - - /** scoring method for standard and complementary classifiers */ - protected def getScoreForLabelFeature(label: Int, feature: Int): Double - - /** getter for model */ - protected def getModel: NBModel= { - model - } - - /** - * Compute the score for a Vector of weighted TF-IDF featured - * @param label Label to be scored - * @param instance Vector of weights to be calculate score - * @return score for this Label - */ - protected def getScoreForLabelInstance(label: Int, instance: Vector): Double = { - var result: Double = 0.0 - for (e <- instance.nonZeroes) { - result += e.get * getScoreForLabelFeature(label, e.index) - } - result - } - - /** number of categories the model has been trained on */ - def numCategories: Int = { - model.numLabels - } - - /** - * get a scoring vector for a vector of TF of TF-IDF weights - * @param instance vector of TF of TF-IDF weights to be classified - * @return a vector of scores. 
- */ - def classifyFull(instance: Vector): Vector = { - classifyFull(model.createScoringVector, instance) - } - - /** helper method for classifyFull(Vector) */ - def classifyFull(r: Vector, instance: Vector): Vector = { - var label: Int = 0 - for (label <- 0 until model.numLabels) { - r.setQuick(label, getScoreForLabelInstance(label, instance)) - } - r - } -} - -/** - * Standard Multinomial Naive Bayes Classifier - * @param nbModel a trained NBModel - */ -class StandardNBClassifier(nbModel: NBModel) extends AbstractNBClassifier(nbModel: NBModel) with java.io.Serializable{ - override def getScoreForLabelFeature(label: Int, feature: Int): Double = { - val model: NBModel = getModel - StandardNBClassifier.computeWeight(model.weight(label, feature), model.labelWeight(label), model.alphaI, model.numFeatures) - } -} - -/** helper object for StandardNBClassifier */ -object StandardNBClassifier extends java.io.Serializable { - /** Compute Standard Multinomial Naive Bayes Weights See Rennie et. al. 
Section 2.1 */ - def computeWeight(featureLabelWeight: Double, labelWeight: Double, alphaI: Double, numFeatures: Double): Double = { - val numerator: Double = featureLabelWeight + alphaI - val denominator: Double = labelWeight + alphaI * numFeatures - return Math.log(numerator / denominator) - } -} - -/** - * Complementary Naive Bayes Classifier - * @param nbModel a trained NBModel - */ -class ComplementaryNBClassifier(nbModel: NBModel) extends AbstractNBClassifier(nbModel: NBModel) with java.io.Serializable { - override def getScoreForLabelFeature(label: Int, feature: Int): Double = { - val model: NBModel = getModel - val weight: Double = ComplementaryNBClassifier.computeWeight(model.featureWeight(feature), model.weight(label, feature), model.totalWeightSum, model.labelWeight(label), model.alphaI, model.numFeatures) - return weight / model.thetaNormalizer(label) - } -} - -/** helper object for ComplementaryNBClassifier */ -object ComplementaryNBClassifier extends java.io.Serializable { - - /** Compute Complementary weights See Rennie et. al. 
Section 3.1 */ - def computeWeight(featureWeight: Double, featureLabelWeight: Double, totalWeight: Double, labelWeight: Double, alphaI: Double, numFeatures: Double): Double = { - val numerator: Double = featureWeight - featureLabelWeight + alphaI - val denominator: Double = totalWeight - labelWeight + alphaI * numFeatures - return -Math.log(numerator / denominator) - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala ---------------------------------------------------------------------- diff --git a/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala b/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala deleted file mode 100644 index 3ceae96..0000000 --- a/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NBModel.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.mahout.classifier.naivebayes - -import org.apache.mahout.math._ - -import org.apache.mahout.math.{drm, scalabindings} - -import scalabindings._ -import scalabindings.RLikeOps._ -import drm.RLikeDrmOps._ -import drm._ -import scala.collection.JavaConverters._ -import scala.language.asInstanceOf -import scala.collection._ -import JavaConversions._ - -/** - * - * @param weightsPerLabelAndFeature Aggregated matrix of weights of labels x features - * @param weightsPerFeature Vector of summation of all feature weights. - * @param weightsPerLabel Vector of summation of all label weights. - * @param perlabelThetaNormalizer Vector of weight normalizers per label (used only for complemtary models) - * @param labelIndex HashMap of labels and their corresponding row in the weightMatrix - * @param alphaI Laplace smoothing factor. - * @param isComplementary Whether or not this is a complementary model. - */ -class NBModel(val weightsPerLabelAndFeature: Matrix = null, - val weightsPerFeature: Vector = null, - val weightsPerLabel: Vector = null, - val perlabelThetaNormalizer: Vector = null, - val labelIndex: Map[String, Integer] = null, - val alphaI: Float = 1.0f, - val isComplementary: Boolean= false) extends java.io.Serializable { - - - val numFeatures: Double = weightsPerFeature.getNumNondefaultElements - val totalWeightSum: Double = weightsPerLabel.zSum - val alphaVector: Vector = null - - validate() - - // todo: Maybe it is a good idea to move the dfsWrite and dfsRead out - // todo: of the model and into a helper - - // TODO: weightsPerLabelAndFeature, a sparse (numFeatures x numLabels) matrix should fit - // TODO: upfront in memory and should not require a DRM decide if we want this to scale out. - - - /** getter for summed label weights. Used by legacy classifier */ - def labelWeight(label: Int): Double = { - weightsPerLabel.getQuick(label) - } - - /** getter for weight normalizers. 
Used by legacy classifier */ - def thetaNormalizer(label: Int): Double = { - perlabelThetaNormalizer.get(label) - } - - /** getter for summed feature weights. Used by legacy classifier */ - def featureWeight(feature: Int): Double = { - weightsPerFeature.getQuick(feature) - } - - /** getter for individual aggregated weights. Used by legacy classifier */ - def weight(label: Int, feature: Int): Double = { - weightsPerLabelAndFeature.getQuick(label, feature) - } - - /** getter for a single empty vector of weights */ - def createScoringVector: Vector = { - weightsPerLabel.like - } - - /** getter for a the number of labels to consider */ - def numLabels: Int = { - weightsPerLabel.size - } - - /** - * Write a trained model to the filesystem as a series of DRMs - * @param pathToModel Directory to which the model will be written - */ - def dfsWrite(pathToModel: String)(implicit ctx: DistributedContext): Unit = { - //todo: write out as smaller partitions or possibly use reader and writers to - //todo: write something other than a DRM for label Index, is Complementary, alphaI. 
- - // add a directory to put all of the DRMs in - val fullPathToModel = pathToModel + NBModel.modelBaseDirectory - - drmParallelize(weightsPerLabelAndFeature).dfsWrite(fullPathToModel + "/weightsPerLabelAndFeatureDrm.drm") - drmParallelize(sparse(weightsPerFeature)).dfsWrite(fullPathToModel + "/weightsPerFeatureDrm.drm") - drmParallelize(sparse(weightsPerLabel)).dfsWrite(fullPathToModel + "/weightsPerLabelDrm.drm") - drmParallelize(sparse(perlabelThetaNormalizer)).dfsWrite(fullPathToModel + "/perlabelThetaNormalizerDrm.drm") - drmParallelize(sparse(svec((0,alphaI)::Nil))).dfsWrite(fullPathToModel + "/alphaIDrm.drm") - - // isComplementry is true if isComplementaryDrm(0,0) == 1 else false - val isComplementaryDrm = sparse(0 to 1, 0 to 1) - if(isComplementary){ - isComplementaryDrm(0,0) = 1.0 - } else { - isComplementaryDrm(0,0) = 0.0 - } - drmParallelize(isComplementaryDrm).dfsWrite(fullPathToModel + "/isComplementaryDrm.drm") - - // write the label index as a String-Keyed DRM. - val labelIndexDummyDrm = weightsPerLabelAndFeature.like() - labelIndexDummyDrm.setRowLabelBindings(labelIndex) - // get a reverse map of [Integer, String] and set the value of firsr column of the drm - // to the corresponding row number for it's Label (the rows may not be read back in the same order) - val revMap = labelIndex.map(x => x._2 -> x._1) - for(i <- 0 until labelIndexDummyDrm.numRows() ){ - labelIndexDummyDrm.set(labelIndex(revMap(i)), 0, i.toDouble) - } - - drmParallelizeWithRowLabels(labelIndexDummyDrm).dfsWrite(fullPathToModel + "/labelIndex.drm") - } - - /** Model Validation */ - def validate() { - assert(alphaI > 0, "alphaI has to be greater than 0!") - assert(numFeatures > 0, "the vocab count has to be greater than 0!") - assert(totalWeightSum > 0, "the totalWeightSum has to be greater than 0!") - assert(weightsPerLabel != null, "the number of labels has to be defined!") - assert(weightsPerLabel.getNumNondefaultElements > 0, "the number of labels has to be greater than 0!") 
- assert(weightsPerFeature != null, "the feature sums have to be defined") - assert(weightsPerFeature.getNumNondefaultElements > 0, "the feature sums have to be greater than 0!") - if (isComplementary) { - assert(perlabelThetaNormalizer != null, "the theta normalizers have to be defined") - assert(perlabelThetaNormalizer.getNumNondefaultElements > 0, "the number of theta normalizers has to be greater than 0!") - assert(Math.signum(perlabelThetaNormalizer.minValue) == Math.signum(perlabelThetaNormalizer.maxValue), "Theta normalizers do not all have the same sign") - assert(perlabelThetaNormalizer.getNumNonZeroElements == perlabelThetaNormalizer.size, "Weight normalizers can not have zero value.") - } - assert(labelIndex.size == weightsPerLabel.getNumNondefaultElements, "label index must have entries for all labels") - } -} - -object NBModel extends java.io.Serializable { - - val modelBaseDirectory = "/naiveBayesModel" - - /** - * Read a trained model in from from the filesystem. - * @param pathToModel directory from which to read individual model components - * @return a valid NBModel - */ - def dfsRead(pathToModel: String)(implicit ctx: DistributedContext): NBModel = { - //todo: Takes forever to read we need a more practical method of writing models. Readers/Writers? 
- - // read from a base directory for all drms - val fullPathToModel = pathToModel + modelBaseDirectory - - val weightsPerFeatureDrm = drmDfsRead(fullPathToModel + "/weightsPerFeatureDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) - val weightsPerFeature = weightsPerFeatureDrm.collect(0, ::) - weightsPerFeatureDrm.uncache() - - val weightsPerLabelDrm = drmDfsRead(fullPathToModel + "/weightsPerLabelDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) - val weightsPerLabel = weightsPerLabelDrm.collect(0, ::) - weightsPerLabelDrm.uncache() - - val alphaIDrm = drmDfsRead(fullPathToModel + "/alphaIDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) - val alphaI: Float = alphaIDrm.collect(0, 0).toFloat - alphaIDrm.uncache() - - // isComplementry is true if isComplementaryDrm(0,0) == 1 else false - val isComplementaryDrm = drmDfsRead(fullPathToModel + "/isComplementaryDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) - val isComplementary = isComplementaryDrm.collect(0, 0).toInt == 1 - isComplementaryDrm.uncache() - - var perLabelThetaNormalizer= weightsPerFeature.like() - if (isComplementary) { - val perLabelThetaNormalizerDrm = drm.drmDfsRead(fullPathToModel + "/perlabelThetaNormalizerDrm.drm") - .checkpoint(CacheHint.MEMORY_ONLY) - perLabelThetaNormalizer = perLabelThetaNormalizerDrm.collect(0, ::) - } - - val dummyLabelDrm= drmDfsRead(fullPathToModel + "/labelIndex.drm") - .checkpoint(CacheHint.MEMORY_ONLY) - val labelIndexMap:java.util.Map[String, Integer] = dummyLabelDrm.getRowLabelBindings - dummyLabelDrm.uncache() - - // map the labels to the corresponding row numbers of weightsPerFeatureDrm (values in dummyLabelDrm) - val scalaLabelIndexMap: mutable.Map[String, Integer] = - labelIndexMap.map(x => x._1 -> dummyLabelDrm.get(labelIndexMap(x._1), 0) - .toInt - .asInstanceOf[Integer]) - - val weightsPerLabelAndFeatureDrm = drmDfsRead(fullPathToModel + "/weightsPerLabelAndFeatureDrm.drm").checkpoint(CacheHint.MEMORY_ONLY) - val weightsPerLabelAndFeature = weightsPerLabelAndFeatureDrm.collect - 
weightsPerLabelAndFeatureDrm.uncache() - - // model validation is triggered automatically by constructor - val model: NBModel = new NBModel(weightsPerLabelAndFeature, - weightsPerFeature, - weightsPerLabel, - perLabelThetaNormalizer, - scalaLabelIndexMap, - alphaI, - isComplementary) - - model - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala ---------------------------------------------------------------------- diff --git a/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala b/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala deleted file mode 100644 index a15ca09..0000000 --- a/samsara/src/main/scala/org/apache/mahout/classifier/naivebayes/NaiveBayes.scala +++ /dev/null @@ -1,380 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.classifier.naivebayes - -import org.apache.mahout.classifier.stats.{ResultAnalyzer, ClassifierResult} -import org.apache.mahout.math._ -import scalabindings._ -import scalabindings.RLikeOps._ -import drm.RLikeDrmOps._ -import drm._ -import scala.reflect.ClassTag -import scala.language.asInstanceOf -import collection._ -import scala.collection.JavaConversions._ - -/** - * Distributed training of a Naive Bayes model. Follows the approach presented in Rennie et.al.: Tackling the poor - * assumptions of Naive Bayes Text classifiers, ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - */ -trait NaiveBayes extends java.io.Serializable{ - - /** default value for the Laplacian smoothing parameter */ - def defaultAlphaI = 1.0f - - // function to extract categories from string keys - type CategoryParser = String => String - - /** Default: seqdirectory/seq2Sparse Categories are Stored in Drm Keys as: /Category/document_id */ - def seq2SparseCategoryParser: CategoryParser = x => x.split("/")(1) - - - /** - * Distributed training of a Naive Bayes model. Follows the approach presented in Rennie et.al.: Tackling the poor - * assumptions of Naive Bayes Text classifiers, ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf - * - * @param observationsPerLabel a DrmLike[Int] matrix containing term frequency counts for each label. 
- * @param trainComplementary whether or not to train a complementary Naive Bayes model - * @param alphaI Laplace smoothing parameter - * @return trained naive bayes model - */ - def train(observationsPerLabel: DrmLike[Int], - labelIndex: Map[String, Integer], - trainComplementary: Boolean = true, - alphaI: Float = defaultAlphaI): NBModel = { - - // Summation of all weights per feature - val weightsPerFeature = observationsPerLabel.colSums - - // Distributed summation of all weights per label - val weightsPerLabel = observationsPerLabel.rowSums - - // Collect a matrix to pass to the NaiveBayesModel - val inCoreTFIDF = observationsPerLabel.collect - - // perLabelThetaNormalizer Vector is expected by NaiveBayesModel. We can pass a null value - // or Vector of zeroes in the case of a standard NB model. - var thetaNormalizer = weightsPerFeature.like() - - // Instantiate a trainer and retrieve the perLabelThetaNormalizer Vector from it in the case of - // a complementary NB model - if (trainComplementary) { - val thetaTrainer = new ComplementaryNBThetaTrainer(weightsPerFeature, - weightsPerLabel, - alphaI) - // local training of the theta normalization - for (labelIndex <- 0 until inCoreTFIDF.nrow) { - thetaTrainer.train(labelIndex, inCoreTFIDF(labelIndex, ::)) - } - thetaNormalizer = thetaTrainer.retrievePerLabelThetaNormalizer - } - - new NBModel(inCoreTFIDF, - weightsPerFeature, - weightsPerLabel, - thetaNormalizer, - labelIndex, - alphaI, - trainComplementary) - } - - /** - * Extract label Keys from raw TF or TF-IDF Matrix generated by seqdirectory/seq2sparse - * and aggregate TF or TF-IDF values by their label - * Override this method in engine specific modules to optimize - * - * @param stringKeyedObservations DrmLike matrix; Output from seq2sparse - * in form K = eg./Category/document_title - * V = TF or TF-IDF values per term - * @param cParser a String => String function used to extract categories from - * Keys of the stringKeyedObservations DRM. 
The default - * CategoryParser will extract "Category" from: '/Category/document_id' - * @return (labelIndexMap,aggregatedByLabelObservationDrm) - * labelIndexMap is a HashMap [String, Integer] K = label row index - * V = label - * aggregatedByLabelObservationDrm is a DrmLike[Int] of aggregated - * TF or TF-IDF counts per label - */ - def extractLabelsAndAggregateObservations[K: ClassTag](stringKeyedObservations: DrmLike[K], - cParser: CategoryParser = seq2SparseCategoryParser) - (implicit ctx: DistributedContext): - (mutable.HashMap[String, Integer], DrmLike[Int])= { - - stringKeyedObservations.checkpoint() - - val numDocs=stringKeyedObservations.nrow - val numFeatures=stringKeyedObservations.ncol - - // Extract categories from labels assigned by seq2sparse - // Categories are Stored in Drm Keys as eg.: /Category/document_id - - // Get a new DRM with a single column so that we don't have to collect the - // DRM into memory upfront. - val strippedObeservations= stringKeyedObservations.mapBlock(ncol=1){ - case(keys, block) => - val blockB = block.like(keys.size, 1) - keys -> blockB - } - - // Extract the row label bindings (the String keys) from the slim Drm - // strip the document_id from the row keys keeping only the category. - // Sort the bindings alphabetically into a Vector - val labelVectorByRowIndex = strippedObeservations - .getRowLabelBindings - .map(x => x._2 -> cParser(x._1)) - .toVector.sortWith(_._1 < _._1) - - //TODO: add a .toIntKeyed(...) method to DrmLike? 
- - // Copy stringKeyedObservations to an Int-Keyed Drm so that we can compute transpose - // Copy the Collected Matrices up front for now until we hav a distributed way of converting - val inCoreStringKeyedObservations = stringKeyedObservations.collect - val inCoreIntKeyedObservations = new SparseMatrix( - stringKeyedObservations.nrow.toInt, - stringKeyedObservations.ncol) - for (i <- 0 until inCoreStringKeyedObservations.nrow.toInt) { - inCoreIntKeyedObservations(i, ::) = inCoreStringKeyedObservations(i, ::) - } - - val intKeyedObservations= drmParallelize(inCoreIntKeyedObservations) - - stringKeyedObservations.uncache() - - var labelIndex = 0 - val labelIndexMap = new mutable.HashMap[String, Integer] - val encodedLabelByRowIndexVector = new DenseVector(labelVectorByRowIndex.size) - - // Encode Categories as an Integer (Double) so we can broadcast as a vector - // where each element is an Int-encoded category whose index corresponds - // to its row in the Drm - for (i <- 0 until labelVectorByRowIndex.size) { - if (!(labelIndexMap.contains(labelVectorByRowIndex(i)._2))) { - encodedLabelByRowIndexVector(i) = labelIndex.toDouble - labelIndexMap.put(labelVectorByRowIndex(i)._2, labelIndex) - labelIndex += 1 - } - // don't like this casting but need to use a java.lang.Integer when setting rowLabelBindings - encodedLabelByRowIndexVector(i) = labelIndexMap - .getOrElse(labelVectorByRowIndex(i)._2, -1) - .asInstanceOf[Int].toDouble - } - - // "Combiner": Map and aggregate by Category. Do this by broadcasting the encoded - // category vector and mapping a transposed IntKeyed Drm out so that all categories - // will be present on all nodes as columns and can be referenced by - // BCastEncodedCategoryByRowVector. Iteratively sum all categories. 
- val nLabels = labelIndex - - val bcastEncodedCategoryByRowVector = drmBroadcast(encodedLabelByRowIndexVector) - - val aggregetedObservationByLabelDrm = intKeyedObservations.t.mapBlock(ncol = nLabels) { - case (keys, blockA) => - val blockB = blockA.like(keys.size, nLabels) - var label : Int = 0 - for (i <- 0 until keys.size) { - blockA(i, ::).nonZeroes().foreach { elem => - label = bcastEncodedCategoryByRowVector.get(elem.index).toInt - blockB(i, label) = blockB(i, label) + blockA(i, elem.index) - } - } - keys -> blockB - }.t - - (labelIndexMap, aggregetedObservationByLabelDrm) - } - - /** - * Test a trained model with a labeled dataset sequentially - * @param model a trained NBModel - * @param testSet a labeled testing set - * @param testComplementary test using a complementary or a standard NB classifier - * @param cParser a String => String function used to extract categories from - * Keys of the testing set DRM. The default - * CategoryParser will extract "Category" from: '/Category/document_id' - * - * *Note*: this method brings the entire test set into upfront memory, - * This method is optimized and parallelized in SparkNaiveBayes - * - * @tparam K implicitly determined Key type of test set DRM: String - * @return a result analyzer with confusion matrix and accuracy statistics - */ - def test[K: ClassTag](model: NBModel, - testSet: DrmLike[K], - testComplementary: Boolean = false, - cParser: CategoryParser = seq2SparseCategoryParser) - (implicit ctx: DistributedContext): ResultAnalyzer = { - - val labelMap = model.labelIndex - - val numLabels = model.numLabels - - testSet.checkpoint() - - val numTestInstances = testSet.nrow.toInt - - // instantiate the correct type of classifier - val classifier = testComplementary match { - case true => new ComplementaryNBClassifier(model) with Serializable - case _ => new StandardNBClassifier(model) with Serializable - } - - if (testComplementary) { - assert(testComplementary == model.isComplementary, - "Complementary 
Label Assignment requires Complementary Training") - } - - - // Sequentially assign labels to the test set: - // *Note* this brings the entire test set into memory upfront: - - // Since we cant broadcast the model as is do it sequentially up front for now - val inCoreTestSet = testSet.collect - - // get the labels of the test set and extract the keys - val testSetLabelMap = testSet.getRowLabelBindings - - // empty Matrix in which we'll set the classification scores - val inCoreScoredTestSet = testSet.like(numTestInstances, numLabels) - - testSet.uncache() - - for (i <- 0 until numTestInstances) { - inCoreScoredTestSet(i, ::) := classifier.classifyFull(inCoreTestSet(i, ::)) - } - - // todo: reverse the labelMaps in training and through the model? - - // reverse the label map and extract the labels - val reverseTestSetLabelMap = testSetLabelMap.map(x => x._2 -> cParser(x._1)) - - val reverseLabelMap = labelMap.map(x => x._2 -> x._1) - - val analyzer = new ResultAnalyzer(labelMap.keys.toList.sorted, "DEFAULT") - - // assign labels- winner takes all - for (i <- 0 until numTestInstances) { - val (bestIdx, bestScore) = argmax(inCoreScoredTestSet(i, ::)) - val classifierResult = new ClassifierResult(reverseLabelMap(bestIdx), bestScore) - analyzer.addInstance(reverseTestSetLabelMap(i), classifierResult) - } - - analyzer - } - - /** - * argmax with values as well - * returns a tuple of index of the max score and the score itself. - * @param v Vector of of scores - * @return (bestIndex, bestScore) - */ - def argmax(v: Vector): (Int, Double) = { - var bestIdx: Int = Integer.MIN_VALUE - var bestScore: Double = Integer.MIN_VALUE.asInstanceOf[Int].toDouble - for(i <- 0 until v.size) { - if(v(i) > bestScore){ - bestScore = v(i) - bestIdx = i - } - } - (bestIdx, bestScore) - } - -} - -object NaiveBayes extends NaiveBayes with java.io.Serializable - -/** - * Trainer for the weight normalization vector used by Transform Weight Normalized Complement - * Naive Bayes. 
See: Rennie et.al.: Tackling the poor assumptions of Naive Bayes Text classifiers, - * ICML 2003, http://people.csail.mit.edu/jrennie/papers/icml03-nb.pdf Sec. 3.2. - * - * @param weightsPerFeature a Vector of summed TF or TF-IDF weights for each word in dictionary. - * @param weightsPerLabel a Vector of summed TF or TF-IDF weights for each label. - * @param alphaI Laplace smoothing factor. Defaut value of 1. - */ -class ComplementaryNBThetaTrainer(private val weightsPerFeature: Vector, - private val weightsPerLabel: Vector, - private val alphaI: Double = 1.0) { - - private val perLabelThetaNormalizer: Vector = weightsPerLabel.like() - private val totalWeightSum: Double = weightsPerLabel.zSum - private var numFeatures: Double = weightsPerFeature.getNumNondefaultElements - - assert(weightsPerFeature != null, "weightsPerFeature vector can not be null") - assert(weightsPerLabel != null, "weightsPerLabel vector can not be null") - - /** - * Train the weight normalization vector for each label - * @param label - * @param featurePerLabelWeight - */ - def train(label: Int, featurePerLabelWeight: Vector) { - val currentLabelWeight = labelWeight(label) - // sum weights for each label including those with zero word counts - for (i <- 0 until featurePerLabelWeight.size) { - val currentFeaturePerLabelWeight = featurePerLabelWeight(i) - updatePerLabelThetaNormalizer(label, - ComplementaryNBClassifier.computeWeight(featureWeight(i), - currentFeaturePerLabelWeight, - totalWeightSum, - currentLabelWeight, - alphaI, - numFeatures) - ) - } - } - - /** - * getter for summed TF or TF-IDF weights by label - * @param label index of label - * @return sum of word TF or TF-IDF weights for label - */ - def labelWeight(label: Int): Double = { - weightsPerLabel(label) - } - - /** - * getter for summed TF or TF-IDF weights by word. - * @param feature index of word. - * @return sum of TF or TF-IDF weights for word. 
- */ - def featureWeight(feature: Int): Double = { - weightsPerFeature(feature) - } - - /** - * add the magnitude of the current weight to the current - * label's corresponding Vector element. - * @param label index of label to update. - * @param weight weight to add. - */ - def updatePerLabelThetaNormalizer(label: Int, weight: Double) { - perLabelThetaNormalizer(label) = perLabelThetaNormalizer(label) + Math.abs(weight) - } - - /** - * Getter for the weight normalizer vector as indexed by label - * @return a copy of the weight normalizer vector. - */ - def retrievePerLabelThetaNormalizer: Vector = { - perLabelThetaNormalizer.cloned - } - - - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/ef6d93a3/samsara/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala ---------------------------------------------------------------------- diff --git a/samsara/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala b/samsara/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala deleted file mode 100644 index 8f1413a..0000000 --- a/samsara/src/main/scala/org/apache/mahout/classifier/stats/ClassifierStats.scala +++ /dev/null @@ -1,467 +0,0 @@ -/* - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
-*/ - -package org.apache.mahout.classifier.stats - -import java.text.{DecimalFormat, NumberFormat} -import java.util -import org.apache.mahout.math.stats.OnlineSummarizer - - -/** - * Result of a document classification. The label and the associated score (usually probabilty) - */ -class ClassifierResult (private var label: String = null, - private var score: Double = 0.0, - private var logLikelihood: Double = Integer.MAX_VALUE.toDouble) { - - def getLogLikelihood: Double = logLikelihood - - def setLogLikelihood(llh: Double) { - logLikelihood = llh - } - - def getLabel: String = label - - def getScore: Double = score - - def setLabel(lbl: String) { - label = lbl - } - - def setScore(sc: Double) { - score = sc - } - - override def toString: String = { - "ClassifierResult{" + "category='" + label + '\'' + ", score=" + score + '}' - } - -} - -/** - * ResultAnalyzer captures the classification statistics and displays in a tabular manner - * @param labelSet Set of labels to be considered in classification - * @param defaultLabel the default label for an unknown classification - */ -class ResultAnalyzer(private val labelSet: util.Collection[String], defaultLabel: String) { - - val confusionMatrix = new ConfusionMatrix(labelSet, defaultLabel) - val summarizer = new OnlineSummarizer - - private var hasLL: Boolean = false - private var correctlyClassified: Int = 0 - private var incorrectlyClassified: Int = 0 - - - def getConfusionMatrix: ConfusionMatrix = confusionMatrix - - /** - * - * @param correctLabel - * The correct label - * @param classifiedResult - * The classified result - * @return whether the instance was correct or not - */ - def addInstance(correctLabel: String, classifiedResult: ClassifierResult): Boolean = { - val result: Boolean = correctLabel == classifiedResult.getLabel - if (result) { - correctlyClassified += 1 - } - else { - incorrectlyClassified += 1 - } - confusionMatrix.addInstance(correctLabel, classifiedResult) - if 
(classifiedResult.getLogLikelihood != Integer.MAX_VALUE.toDouble) { - summarizer.add(classifiedResult.getLogLikelihood) - hasLL = true - } - - result - } - - /** Dump the resulting statistics to a string */ - override def toString: String = { - val returnString: StringBuilder = new StringBuilder - returnString.append('\n') - returnString.append("=======================================================\n") - returnString.append("Summary\n") - returnString.append("-------------------------------------------------------\n") - val totalClassified: Int = correctlyClassified + incorrectlyClassified - val percentageCorrect: Double = 100.asInstanceOf[Double] * correctlyClassified / totalClassified - val percentageIncorrect: Double = 100.asInstanceOf[Double] * incorrectlyClassified / totalClassified - val decimalFormatter: NumberFormat = new DecimalFormat("0.####") - returnString.append("Correctly Classified Instances") - .append(": ") - .append(Integer.toString(correctlyClassified)) - .append('\t') - .append(decimalFormatter.format(percentageCorrect)) - .append("%\n") - returnString.append("Incorrectly Classified Instances") - .append(": ") - .append(Integer.toString(incorrectlyClassified)) - .append('\t') - .append(decimalFormatter.format(percentageIncorrect)) - .append("%\n") - returnString.append("Total Classified Instances") - .append(": ") - .append(Integer.toString(totalClassified)) - .append('\n') - returnString.append('\n') - returnString.append(confusionMatrix) - returnString.append("=======================================================\n") - returnString.append("Statistics\n") - returnString.append("-------------------------------------------------------\n") - val normStats: RunningAverageAndStdDev = confusionMatrix.getNormalizedStats - returnString.append("Kappa: \t") - .append(decimalFormatter.format(confusionMatrix.getKappa)) - .append('\n') - returnString.append("Accuracy: \t") - .append(decimalFormatter.format(confusionMatrix.getAccuracy)) - .append("%\n") 
- returnString.append("Reliability: \t") - .append(decimalFormatter.format(normStats.getAverage * 100.00000001)) - .append("%\n") - returnString.append("Reliability (std dev): \t") - .append(decimalFormatter.format(normStats.getStandardDeviation)) - .append('\n') - returnString.append("Weighted precision: \t") - .append(decimalFormatter.format(confusionMatrix.getWeightedPrecision)) - .append('\n') - returnString.append("Weighted recall: \t") - .append(decimalFormatter.format(confusionMatrix.getWeightedRecall)) - .append('\n') - returnString.append("Weighted F1 score: \t") - .append(decimalFormatter.format(confusionMatrix.getWeightedF1score)) - .append('\n') - if (hasLL) { - returnString.append("Log-likelihood: \t") - .append("mean : \t") - .append(decimalFormatter.format(summarizer.getMean)) - .append('\n') - returnString.append("25%-ile : \t") - .append(decimalFormatter.format(summarizer.getQuartile(1))) - .append('\n') - returnString.append("75%-ile : \t") - .append(decimalFormatter.format(summarizer.getQuartile(3))) - .append('\n') - } - - returnString.toString() - } - - -} - -/** - * - * Interface for classes that can keep track of a running average of a series of numbers. One can add to or - * remove from the series, as well as update a datum in the series. The class does not actually keep track of - * the series of values, just its running average, so it doesn't even matter if you remove/change a value that - * wasn't added. 
- * - * Ported from org.apache.mahout.cf.taste.impl.common.RunningAverage.java - */ -trait RunningAverage { - - /** - * @param datum - * new item to add to the running average - * @throws IllegalArgumentException - * if datum is { @link Double#NaN} - */ - def addDatum(datum: Double) - - /** - * @param datum - * item to remove to the running average - * @throws IllegalArgumentException - * if datum is { @link Double#NaN} - * @throws IllegalStateException - * if count is 0 - */ - def removeDatum(datum: Double) - - /** - * @param delta - * amount by which to change a datum in the running average - * @throws IllegalArgumentException - * if delta is { @link Double#NaN} - * @throws IllegalStateException - * if count is 0 - */ - def changeDatum(delta: Double) - - def getCount: Int - - def getAverage: Double - - /** - * @return a (possibly immutable) object whose average is the negative of this object's - */ - def inverse: RunningAverage -} - -/** - * - * Extends {@link RunningAverage} by adding standard deviation too. 
- * - * Ported from org.apache.mahout.cf.taste.impl.common.RunningAverageAndStdDev.java - */ -trait RunningAverageAndStdDev extends RunningAverage { - - /** @return standard deviation of data */ - def getStandardDeviation: Double - - /** - * @return a (possibly immutable) object whose average is the negative of this object's - */ - def inverse: RunningAverageAndStdDev -} - - -class InvertedRunningAverage(private val delegate: RunningAverage) extends RunningAverage { - - override def addDatum(datum: Double) { - throw new UnsupportedOperationException - } - - override def removeDatum(datum: Double) { - throw new UnsupportedOperationException - } - - override def changeDatum(delta: Double) { - throw new UnsupportedOperationException - } - - override def getCount: Int = { - delegate.getCount - } - - override def getAverage: Double = { - -delegate.getAverage - } - - override def inverse: RunningAverage = { - delegate - } -} - - -/** - * - * A simple class that can keep track of a running average of a series of numbers. One can add to or remove - * from the series, as well as update a datum in the series. The class does not actually keep track of the - * series of values, just its running average, so it doesn't even matter if you remove/change a value that - * wasn't added. 
- * - * Ported from org.apache.mahout.cf.taste.impl.common.FullRunningAverage.java - */ -class FullRunningAverage(private var count: Int = 0, - private var average: Double = Double.NaN ) extends RunningAverage { - - /** - * @param datum - * new item to add to the running average - */ - override def addDatum(datum: Double) { - count += 1 - if (count == 1) { - average = datum - } - else { - average = average * (count - 1) / count + datum / count - } - } - - /** - * @param datum - * item to remove from the running average - * @throws IllegalStateException - * if count is 0 - */ - override def removeDatum(datum: Double) { - if (count == 0) { - throw new IllegalStateException - } - count -= 1 - if (count == 0) { - average = Double.NaN - } - else { - average = average * (count + 1) / count - datum / count - } - } - - /** - * @param delta - * amount by which to change a datum in the running average - * @throws IllegalStateException - * if count is 0 - */ - override def changeDatum(delta: Double) { - if (count == 0) { - throw new IllegalStateException - } - average += delta / count - } - - override def getCount: Int = { - count - } - - override def getAverage: Double = { - average - } - - override def inverse: RunningAverage = { - new InvertedRunningAverage(this) - } - - override def toString: String = { - String.valueOf(average) - } -} - - -/** - * - * Extends {@link FullRunningAverage} to add a running standard deviation computation. 
/**
 * Extends [[FullRunningAverage]] to add a running standard deviation computation.
 * Uses Welford's method, as described at http://www.johndcook.com/standard_deviation.html
 *
 * Ported from org.apache.mahout.cf.taste.impl.common.FullRunningAverageAndStdDev.java
 *
 * Fix over the original port: `count` and `average` are forwarded to the
 * FullRunningAverage superclass constructor (as the Java original does via
 * `super(count, average)`) instead of being captured in shadowing private
 * fields that were never read. Previously an instance reconstructed from a
 * saved (count, average, mk, sk) state reported getCount == 0 and
 * getAverage == NaN. Note: a no-arg instance now starts with average 0.0
 * rather than NaN, matching the forwarded default; the first addDatum
 * overwrites it either way.
 *
 * @param count   number of data points already folded in
 * @param average mean of those data points
 * @param mk      Welford running-mean term
 * @param sk      Welford running sum-of-squared-deviations term
 */
class FullRunningAverageAndStdDev(count: Int = 0,
                                  average: Double = 0.0,
                                  private var mk: Double = 0.0,
                                  private var sk: Double = 0.0)
    extends FullRunningAverage(count, average) with RunningAverageAndStdDev {

  // Cached sample standard deviation; recomputed after every mutation.
  var stdDev: Double = 0.0

  recomputeStdDev()

  /** Welford running-mean term (exposed so state can be captured). */
  def getMk: Double = mk

  /** Welford sum-of-squares term (exposed so state can be captured). */
  def getSk: Double = sk

  override def getStandardDeviation: Double = stdDev

  override def addDatum(datum: Double) {
    super.addDatum(datum)
    val n: Int = getCount
    if (n == 1) {
      mk = datum
      sk = 0.0
    }
    else {
      // Welford update: diff is measured against the mean BEFORE the update,
      // (datum - mk) against the mean AFTER it.
      val diff: Double = datum - mk
      mk += diff / n
      sk += diff * (datum - mk)
    }
    recomputeStdDev()
  }

  override def removeDatum(datum: Double) {
    val oldCount: Int = getCount
    super.removeDatum(datum)
    // Reverse the Welford update. NOTE: oldCount == 1 leaves mk/sk NaN,
    // mirroring the undefined mean of an empty average.
    val oldmk: Double = mk
    mk = (oldCount * oldmk - datum) / (oldCount - 1)
    sk -= (datum - mk) * (datum - oldmk)
    recomputeStdDev()
  }

  /**
   * Not supported: a single datum cannot be shifted without invalidating sk.
   *
   * @throws UnsupportedOperationException always
   */
  override def changeDatum(delta: Double) {
    throw new UnsupportedOperationException
  }

  // Sample (n-1 denominator) standard deviation; undefined for < 2 points.
  private def recomputeStdDev(): Unit = {
    val n: Int = getCount
    stdDev = if (n > 1) Math.sqrt(sk / (n - 1)) else Double.NaN
  }

  override def inverse: RunningAverageAndStdDev =
    new InvertedRunningAverageAndStdDev(this)

  override def toString: String = String.valueOf(getAverage) + ',' + stdDev
}


/**
 * Read-only sign-negated view over a delegate: the average is negated, the
 * count and standard deviation pass through unchanged, and all mutators
 * throw.
 *
 * @param delegate RunningAverageAndStdDev instance
 *
 * Ported from org.apache.mahout.cf.taste.impl.common.InvertedRunningAverageAndStdDev.java
 */
class InvertedRunningAverageAndStdDev(private val delegate: RunningAverageAndStdDev) extends RunningAverageAndStdDev {

  /**
   * @throws UnsupportedOperationException always (view is read-only)
   */
  override def addDatum(datum: Double) {
    throw new UnsupportedOperationException
  }

  /**
   * @throws UnsupportedOperationException always (view is read-only)
   */
  override def removeDatum(datum: Double) {
    throw new UnsupportedOperationException
  }

  /**
   * @throws UnsupportedOperationException always (view is read-only)
   */
  override def changeDatum(delta: Double) {
    throw new UnsupportedOperationException
  }

  override def getCount: Int = delegate.getCount

  override def getAverage: Double = -delegate.getAverage

  override def getStandardDeviation: Double = delegate.getStandardDeviation

  /** Inverting the inverted view yields the original delegate. */
  override def inverse: RunningAverageAndStdDev = delegate
}
/*
 * Licensed to the Apache Software Foundation (ASF) under the Apache License,
 * Version 2.0; see the NOTICE file and http://www.apache.org/licenses/LICENSE-2.0
 */

package org.apache.mahout.classifier.stats

import java.util
import org.apache.commons.math3.stat.descriptive.moment.Mean // This is brought in by mahout-math
import org.apache.mahout.math.{DenseMatrix, Matrix}
import scala.collection.mutable
import scala.collection.JavaConversions._ // supplies containsKey/entrySet/keySet on the mutable map below

/**
 *
 * Ported from org.apache.mahout.classifier.ConfusionMatrix.java
 *
 * The ConfusionMatrix Class stores the result of Classification of a Test Dataset.
 *
 * Rows are the correct (actual) labels, columns the classified (predicted)
 * labels; cell (r, c) counts instances of label r classified as label c.
 *
 * The fact of whether there is a default is not stored. A row of zeros is the only indicator that there is no default.
 *
 * See http://en.wikipedia.org/wiki/Confusion_matrix for background
 *
 * @param labels The labels to consider for classification
 * @param defaultLabel default unknown label
 */
class ConfusionMatrix(private var labels: util.Collection[String] = null,
                      private var defaultLabel: String = "unknown") {
  /**
   * Matrix Constructor
   * @param m a DenseMatrix with RowLabelBindings
   */
//  def this(m: Matrix) {
//    this()
//    confusionMatrix = Array.ofDim[Int](m.numRows, m.numRows)
//    setMatrix(m)
//  }

  // val LOG: Logger = LoggerFactory.getLogger(classOf[ConfusionMatrix])

  // Square count matrix; one extra row/column beyond `labels` for defaultLabel.
  var confusionMatrix = Array.ofDim[Int](labels.size + 1, labels.size + 1)

  // Maps each label to its row/column index in confusionMatrix.
  val labelMap = new mutable.HashMap[String,Integer]()

  // Total number of classified instances folded in via addInstance/putCount.
  var samples: Int = 0

  // Assign consecutive indices to the supplied labels; defaultLabel gets the
  // last index.
  var i: Integer = 0
  for (label <- labels) {
    labelMap.put(label, i)
    i+=1
  }
  labelMap.put(defaultLabel, i)


  /** Raw counts: row = correct label id, column = classified label id. */
  def getConfusionMatrix: Array[Array[Int]] = confusionMatrix

  /** All known labels (including defaultLabel), in labelMap iteration order. */
  def getLabels = labelMap.keys.toList

  /** Number of labels, counting defaultLabel. */
  def numLabels: Int = labelMap.size

  /**
   * Per-class accuracy as a percentage: correct / row total * 100.
   * NOTE(review): yields NaN when the row total is 0 (label never seen).
   */
  def getAccuracy(label: String): Double = {
    val labelId: Int = labelMap(label)
    var labelTotal: Int = 0
    var correct: Int = 0
    for (i <- 0 until numLabels) {
      labelTotal += confusionMatrix(labelId)(i)
      if (i == labelId) {
        correct += confusionMatrix(labelId)(i)
      }
    }

    100.0 * correct / labelTotal
  }

  /** Overall accuracy as a percentage: trace / total count * 100. */
  def getAccuracy: Double = {
    var total: Int = 0
    var correct: Int = 0
    for (i <- 0 until numLabels) {
      for (j <- 0 until numLabels) {
        total += confusionMatrix(i)(j)
        if (i == j) {
          correct += confusionMatrix(i)(j)
        }
      }
    }

    100.0 * correct / total
  }

  /** Sum of true positives and false negatives */
  private def getActualNumberOfTestExamplesForClass(label: String): Int = {
    val labelId: Int = labelMap(label)
    var sum: Int = 0
    for (i <- 0 until numLabels) {
      sum += confusionMatrix(labelId)(i)
    }
    sum
  }

  /** Precision for one class: TP / (TP + FP); 0 when the column is empty. */
  def getPrecision(label: String): Double = {
    val labelId: Int = labelMap(label)
    val truePositives: Int = confusionMatrix(labelId)(labelId)
    var falsePositives: Int = 0

    // False positives live in this label's COLUMN (other rows).
    for (i <- 0 until numLabels) {
      if (i != labelId) {
        falsePositives += confusionMatrix(i)(labelId)
      }
    }

    if (truePositives + falsePositives == 0) {
      0
    } else {
      (truePositives.asInstanceOf[Double]) / (truePositives + falsePositives)
    }
  }


  /**
   * Mean of per-class precisions weighted by each class's actual example
   * count (commons-math Mean.evaluate(values, weights)).
   */
  def getWeightedPrecision: Double = {
    val precisions: Array[Double] = new Array[Double](numLabels)
    val weights: Array[Double] = new Array[Double](numLabels)
    var index: Int = 0
    for (label <- labelMap.keys) {
      precisions(index) = getPrecision(label)
      weights(index) = getActualNumberOfTestExamplesForClass(label)
      index += 1
    }
    new Mean().evaluate(precisions, weights)
  }

  /** Recall for one class: TP / (TP + FN); 0 when the row is empty. */
  def getRecall(label: String): Double = {
    val labelId: Int = labelMap(label)
    val truePositives: Int = confusionMatrix(labelId)(labelId)
    var falseNegatives: Int = 0
    // False negatives live in this label's ROW (other columns).
    for (i <- 0 until numLabels) {
      if (i != labelId) {
        falseNegatives += confusionMatrix(labelId)(i)
      }
    }

    if (truePositives + falseNegatives == 0) {
      0
    } else {
      (truePositives.asInstanceOf[Double]) / (truePositives + falseNegatives)
    }
  }

  /** Mean of per-class recalls weighted by actual example counts. */
  def getWeightedRecall: Double = {
    val recalls: Array[Double] = new Array[Double](numLabels)
    val weights: Array[Double] = new Array[Double](numLabels)
    var index: Int = 0
    for (label <- labelMap.keys) {
      recalls(index) = getRecall(label)
      weights(index) = getActualNumberOfTestExamplesForClass(label)
      index += 1
    }
    new Mean().evaluate(recalls, weights)
  }

  /** F1 for one class: harmonic mean of precision and recall; 0 if both 0. */
  def getF1score(label: String): Double = {
    val precision: Double = getPrecision(label)
    val recall: Double = getRecall(label)
    if (precision + recall == 0) {
      0
    } else {
      2 * precision * recall / (precision + recall)
    }
  }

  /** Mean of per-class F1 scores weighted by actual example counts. */
  def getWeightedF1score: Double = {
    val f1Scores: Array[Double] = new Array[Double](numLabels)
    val weights: Array[Double] = new Array[Double](numLabels)
    var index: Int = 0
    for (label <- labelMap.keys) {
      f1Scores(index) = getF1score(label)
      weights(index) = getActualNumberOfTestExamplesForClass(label)
      index += 1
    }
    new Mean().evaluate(f1Scores, weights)
  }

  /**
   * Mean per-class accuracy over the non-default labels.
   * NOTE(review): `count` is incremented for EVERY label including
   * defaultLabel, while the default's accuracy is excluded from the sum, so
   * the divisor includes one label that contributed nothing — matches the
   * Java original; confirm whether that is intended.
   */
  def getReliability: Double = {
    var count: Int = 0
    var accuracy: Double = 0
    for (label <- labelMap.keys) {
      if (!(label == defaultLabel)) {
        accuracy += getAccuracy(label)
      }
      count += 1
    }
    accuracy / count
  }

  /**
   * Accuracy v.s. randomly classifying all samples.
   * kappa() = (totalAccuracy() - randomAccuracy()) / (1 - randomAccuracy())
   * Cohen, Jacob. 1960. A coefficient of agreement for nominal scales.
   * Educational And Psychological Measurement 20:37-46.
   *
   * Formula and variable names from:
   * http://www.yale.edu/ceo/OEFS/Accuracy.pdf
   *
   * @return double
   */
  def getKappa: Double = {
    var a: Double = 0.0 // sum of diagonal (agreements)
    var b: Double = 0.0 // sum over i of rowTotal(i) * colTotal(i)
    for (i <- 0 until confusionMatrix.length) {
      a += confusionMatrix(i)(i)
      var br: Int = 0
      for (j <- 0 until confusionMatrix.length) {
        br += confusionMatrix(i)(j)
      }
      var bc: Int = 0
      //TODO: verify this as an iterator
      for (vec <- confusionMatrix) {
        bc += vec(i)
      }
      b += br * bc
    }
    (samples * a - b) / (samples * samples - b)
  }

  /** Diagonal count for one label (true positives). */
  def getCorrect(label: String): Int = {
    val labelId: Int = labelMap(label)
    confusionMatrix(labelId)(labelId)
  }

  /** Row total for one label (actual instances of that class). */
  def getTotal(label: String): Int = {
    val labelId: Int = labelMap(label)
    var labelTotal: Int = 0
    for (i <- 0 until numLabels) {
      labelTotal += confusionMatrix(labelId)(i)
    }
    labelTotal
  }

  /**
   * Standard deviation of normalized producer accuracy
   * Not a standard score
   * @return double
   */
  def getNormalizedStats: RunningAverageAndStdDev = {
    val summer = new FullRunningAverageAndStdDev()
    for (d <- 0 until confusionMatrix.length) {
      var total: Double = 0.0
      for (j <- 0 until confusionMatrix.length) {
        total += confusionMatrix(d)(j)
      }
      // epsilon guards against division by zero on an empty row
      summer.addDatum(confusionMatrix(d)(d) / (total + 0.000001))
    }
    summer
  }

  /** Record one classified instance using the result object's label. */
  def addInstance(correctLabel: String, classifiedResult: ClassifierResult): Unit = {
    samples += 1
    incrementCount(correctLabel, classifiedResult.getLabel)
  }

  /** Record one classified instance. */
  def addInstance(correctLabel: String, classifiedLabel: String): Unit = {
    samples += 1
    incrementCount(correctLabel, classifiedLabel)
  }

  /**
   * Count of (correct, classified) pairs. Returns 0 (with a dropped warning)
   * for an unknown correct label; asserts on an unknown classified label.
   */
  def getCount(correctLabel: String, classifiedLabel: String): Int = {
    if (!labelMap.containsKey(correctLabel)) {
      // LOG.warn("Label {} did not appear in the training examples", correctLabel)
      return 0
    }
    assert(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel)
    val correctId: Int = labelMap(correctLabel)
    val classifiedId: Int = labelMap(classifiedLabel)
    confusionMatrix(correctId)(classifiedId)
  }

  /**
   * Overwrite the count for a (correct, classified) pair. Bumps `samples`
   * when a previously-zero cell becomes non-zero.
   */
  def putCount(correctLabel: String, classifiedLabel: String, count: Int): Unit = {
    if (!labelMap.containsKey(correctLabel)) {
      // LOG.warn("Label {} did not appear in the training examples", correctLabel)
      return
    }
    assert(labelMap.containsKey(classifiedLabel), "Label not found: " + classifiedLabel)
    val correctId: Int = labelMap(correctLabel)
    val classifiedId: Int = labelMap(classifiedLabel)
    if (confusionMatrix(correctId)(classifiedId) == 0.0 && count != 0) {
      samples += 1
    }
    confusionMatrix(correctId)(classifiedId) = count
  }

  /** Add `count` to the cell for a (correct, classified) pair. */
  def incrementCount(correctLabel: String, classifiedLabel: String, count: Int): Unit = {
    putCount(correctLabel, classifiedLabel, count + getCount(correctLabel, classifiedLabel))
  }

  /** Add 1 to the cell for a (correct, classified) pair. */
  def incrementCount(correctLabel: String, classifiedLabel: String): Unit = {
    incrementCount(correctLabel, classifiedLabel, 1)
  }

  /** The label used for unclassified/unknown instances. */
  def getDefaultLabel: String = {
    defaultLabel
  }

  /**
   * Fold another matrix's counts into this one (label sets must be the same
   * size). Returns this matrix for chaining.
   */
  def merge(b: ConfusionMatrix): ConfusionMatrix = {
    assert(labelMap.size == b.getLabels.size, "The label sizes do not match")
    for (correctLabel <- this.labelMap.keys) {
      for (classifiedLabel <- this.labelMap.keys) {
        incrementCount(correctLabel, classifiedLabel, b.getCount(correctLabel, classifiedLabel))
      }
    }
    this
  }

  /** Export the counts as a DenseMatrix with row/column label bindings. */
  def getMatrix: Matrix = {
    val length: Int = confusionMatrix.length
    val m: Matrix = new DenseMatrix(length, length)

    val labels: java.util.HashMap[String, Integer] = new java.util.HashMap()

    for (r <- 0 until length) {
      for (c <- 0 until length) {
        m.set(r, c, confusionMatrix(r)(c))
      }
    }

    for (entry <- labelMap.entrySet) {
      labels.put(entry.getKey, entry.getValue)
    }
    m.setRowLabelBindings(labels)
    m.setColumnLabelBindings(labels)

    m
  }

  /**
   * Load counts (rounded to Int) and, when present, label bindings from a
   * square matrix. Row bindings are preferred over column bindings.
   *
   * @throws IllegalArgumentException if the matrix is not square
   */
  def setMatrix(m: Matrix) : Unit = {
    val length: Int = confusionMatrix.length
    if (m.numRows != m.numCols) {
      throw new IllegalArgumentException("ConfusionMatrix: matrix(" + m.numRows + ',' + m.numCols + ") must be square")
    }

    for (r <- 0 until length) {
      for (c <- 0 until length) {
        confusionMatrix(r)(c) = Math.round(m.get(r, c)).toInt
      }
    }

    var labels = m.getRowLabelBindings
    if (labels == null) {
      labels = m.getColumnLabelBindings
    }

    if (labels != null) {
      val sorted: Array[String] = sortLabels(labels)
      verifyLabels(length, sorted)
      labelMap.clear
      for (i <- 0 until length) {
        labelMap.put(sorted(i), i)
      }
    }
  }

  /** Assert there is exactly one non-null label per matrix row. */
  def verifyLabels(length: Int, sorted: Array[String]): Unit = {
    assert(sorted.length == length, "One label, one row")
    for (i <- 0 until length) {
      if (sorted(i) == null) {
        assert(false, "One label, one row")
      }
    }
  }

  /** Invert a label->index map into an index-ordered array of labels. */
  def sortLabels(labels: java.util.Map[String, Integer]): Array[String] = {
    val sorted: Array[String] = new Array[String](labels.size)
    for (entry <- labels.entrySet) {
      sorted(entry.getValue) = entry.getKey
    }

    sorted
  }

  /**
   * This is overloaded. toString() is not a formatted report you print for a manager :)
   * Assume that if there are no default assignments, the default feature was not used
   */
  override def toString: String = {

    val returnString: StringBuilder = new StringBuilder(200)

    returnString.append("=======================================================").append('\n')
    returnString.append("Confusion Matrix\n")
    returnString.append("-------------------------------------------------------").append('\n')

    val unclassified: Int = getTotal(defaultLabel)

    // Header row: one short alias per label; the default column is skipped
    // entirely when nothing was left unclassified.
    for (entry <- this.labelMap.entrySet) {
      if (!((entry.getKey == defaultLabel) && unclassified == 0)) {
        returnString.append(getSmallLabel(entry.getValue) + " ").append('\t')
      }
    }

    returnString.append("<--Classified as").append('\n')

    // One body row per (non-skipped) correct label, with its row total and
    // alias-to-label legend at the end.
    for (entry <- this.labelMap.entrySet) {
      if (!((entry.getKey == defaultLabel) && unclassified == 0)) {
        val correctLabel: String = entry.getKey
        var labelTotal: Int = 0

        for (classifiedLabel <- this.labelMap.keySet) {
          if (!((classifiedLabel == defaultLabel) && unclassified == 0)) {
            returnString.append(Integer.toString(getCount(correctLabel, classifiedLabel)) + " ")
              .append('\t')
            labelTotal += getCount(correctLabel, classifiedLabel)
          }
        }
        returnString.append(" | ").append(String.valueOf(labelTotal) + " ")
          .append('\t')
          .append(getSmallLabel(entry.getValue) + " ")
          .append(" = ")
          .append(correctLabel)
          .append('\n')
      }
    }

    if (unclassified > 0) {
      returnString.append("Default Category: ")
        .append(defaultLabel)
        .append(": ")
        .append(unclassified)
        .append('\n')
    }
    returnString.append('\n')

    returnString.toString()
  }


  /**
   * Short alphabetic alias for a label index: 0->"a", 1->"b", ...
   * NOTE(review): for indices >= 26 this is plain base-26 with 'a' as zero
   * (26 -> "ba", not "aa"-style spreadsheet columns) — confirm that is the
   * intended scheme; matches the Java original.
   */
  def getSmallLabel(i: Int): String = {
    var value: Int = i
    val returnString: StringBuilder = new StringBuilder
    do {
      val n: Int = value % 26
      returnString.insert(0, ('a' + n).asInstanceOf[Char])
      value /= 26
    } while (value > 0)

    returnString.toString()
  }


}
