[ https://issues.apache.org/jira/browse/MAHOUT-1885?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=15564584#comment-15564584 ]
ASF GitHub Bot commented on MAHOUT-1885: ---------------------------------------- Github user sscdotopen commented on a diff in the pull request: https://github.com/apache/mahout/pull/261#discussion_r82728207 --- Diff: viennacl-omp/src/main/scala/org/apache/mahout/viennacl/openmp/OMPMMul.scala --- @@ -0,0 +1,448 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.viennacl.openmp + +import org.apache.mahout.logging._ +import org.apache.mahout.math +import org.apache.mahout.math._ +import org.apache.mahout.math.flavor.{BackEnum, TraversingStructureEnum} +import org.apache.mahout.math.function.Functions +import org.apache.mahout.math.scalabindings.RLikeOps._ +import org.apache.mahout.math.scalabindings._ +import org.apache.mahout.viennacl.openmp.javacpp.Functions._ +import org.apache.mahout.viennacl.openmp.javacpp.LinalgFunctions._ +import org.apache.mahout.viennacl.openmp.javacpp.{CompressedMatrix, Context, DenseRowMatrix} + +import scala.collection.JavaConversions._ + +object OMPMMul extends MMBinaryFunc { + + private final implicit val log = getLog(OMPMMul.getClass) + + override def apply(a: Matrix, b: Matrix, r: Option[Matrix]): Matrix = { + + require(a.ncol == b.nrow, "Incompatible matrix sizes in matrix multiplication.") + + val (af, bf) = (a.getFlavor, b.getFlavor) + val backs = (af.getBacking, bf.getBacking) + val sd = (af.getStructure, math.scalabindings.densityAnalysis(a), bf.getStructure, densityAnalysis(b)) + + + try { + + val alg: MMulAlg = backs match { + + // Both operands are jvm memory backs. + case (BackEnum.JVMMEM, BackEnum.JVMMEM) ⇒ + + sd match { + + // Multiplication cases by a diagonal matrix. 
+ case (TraversingStructureEnum.VECTORBACKED, _, TraversingStructureEnum.COLWISE, _) + if a.isInstanceOf[DiagonalMatrix] ⇒ jvmDiagCW + case (TraversingStructureEnum.VECTORBACKED, _, TraversingStructureEnum.SPARSECOLWISE, _) + if a.isInstanceOf[DiagonalMatrix] ⇒ jvmDiagCW + case (TraversingStructureEnum.VECTORBACKED, _, TraversingStructureEnum.ROWWISE, _) + if a.isInstanceOf[DiagonalMatrix] ⇒ jvmDiagRW + case (TraversingStructureEnum.VECTORBACKED, _, TraversingStructureEnum.SPARSEROWWISE, _) + if a.isInstanceOf[DiagonalMatrix] ⇒ jvmDiagRW + + case (TraversingStructureEnum.COLWISE, _, TraversingStructureEnum.VECTORBACKED, _) + if b.isInstanceOf[DiagonalMatrix] ⇒ jvmCWDiag + case (TraversingStructureEnum.SPARSECOLWISE, _, TraversingStructureEnum.VECTORBACKED, _) + if b.isInstanceOf[DiagonalMatrix] ⇒ jvmCWDiag + case (TraversingStructureEnum.ROWWISE, _, TraversingStructureEnum.VECTORBACKED, _) + if b.isInstanceOf[DiagonalMatrix] ⇒ jvmRWDiag + case (TraversingStructureEnum.SPARSEROWWISE, _, TraversingStructureEnum.VECTORBACKED, _) + if b.isInstanceOf[DiagonalMatrix] ⇒ jvmRWDiag + + // Dense-dense cases + case (TraversingStructureEnum.ROWWISE, true, TraversingStructureEnum.COLWISE, true) if a eq b.t ⇒ ompDRWAAt + case (TraversingStructureEnum.ROWWISE, true, TraversingStructureEnum.COLWISE, true) if a.t eq b ⇒ ompDRWAAt + case (TraversingStructureEnum.ROWWISE, true, TraversingStructureEnum.COLWISE, true) ⇒ ompRWCW + case (TraversingStructureEnum.ROWWISE, true, TraversingStructureEnum.ROWWISE, true) ⇒ jvmRWRW + case (TraversingStructureEnum.COLWISE, true, TraversingStructureEnum.COLWISE, true) ⇒ jvmCWCW + case (TraversingStructureEnum.COLWISE, true, TraversingStructureEnum.ROWWISE, true) if a eq b.t ⇒ jvmDCWAAt + case (TraversingStructureEnum.COLWISE, true, TraversingStructureEnum.ROWWISE, true) if a.t eq b ⇒ jvmDCWAAt + case (TraversingStructureEnum.COLWISE, true, TraversingStructureEnum.ROWWISE, true) ⇒ jvmCWRW + + // Sparse row matrix x sparse row matrix (array of 
vectors) + case (TraversingStructureEnum.ROWWISE, false, TraversingStructureEnum.ROWWISE, false) ⇒ ompSparseRWRW + case (TraversingStructureEnum.ROWWISE, false, TraversingStructureEnum.COLWISE, false) ⇒ jvmSparseRWCW + case (TraversingStructureEnum.COLWISE, false, TraversingStructureEnum.ROWWISE, false) ⇒ jvmSparseCWRW + case (TraversingStructureEnum.COLWISE, false, TraversingStructureEnum.COLWISE, false) ⇒ jvmSparseCWCW + + // Sparse matrix x sparse matrix (hashtable of vectors) + case (TraversingStructureEnum.SPARSEROWWISE, false, TraversingStructureEnum.SPARSEROWWISE, false) ⇒ + ompSparseRowRWRW + case (TraversingStructureEnum.SPARSEROWWISE, false, TraversingStructureEnum.SPARSECOLWISE, false) ⇒ + jvmSparseRowRWCW + case (TraversingStructureEnum.SPARSECOLWISE, false, TraversingStructureEnum.SPARSEROWWISE, false) ⇒ + jvmSparseRowCWRW + case (TraversingStructureEnum.SPARSECOLWISE, false, TraversingStructureEnum.SPARSECOLWISE, false) ⇒ + jvmSparseRowCWCW + + // Sparse matrix x non-like + case (TraversingStructureEnum.SPARSEROWWISE, false, TraversingStructureEnum.ROWWISE, _) ⇒ ompSparseRowRWRW + case (TraversingStructureEnum.SPARSEROWWISE, false, TraversingStructureEnum.COLWISE, _) ⇒ jvmSparseRowRWCW + case (TraversingStructureEnum.SPARSECOLWISE, false, TraversingStructureEnum.ROWWISE, _) ⇒ jvmSparseRowCWRW + case (TraversingStructureEnum.SPARSECOLWISE, false, TraversingStructureEnum.COLWISE, _) ⇒ jvmSparseCWCW + case (TraversingStructureEnum.ROWWISE, _, TraversingStructureEnum.SPARSEROWWISE, false) ⇒ ompSparseRWRW + case (TraversingStructureEnum.ROWWISE, _, TraversingStructureEnum.SPARSECOLWISE, false) ⇒ jvmSparseRWCW + case (TraversingStructureEnum.COLWISE, _, TraversingStructureEnum.SPARSEROWWISE, false) ⇒ jvmSparseCWRW + case (TraversingStructureEnum.COLWISE, _, TraversingStructureEnum.SPARSECOLWISE, false) ⇒ jvmSparseRowCWCW + + // Everything else including at least one sparse LHS or RHS argument + case (TraversingStructureEnum.ROWWISE, false, 
TraversingStructureEnum.ROWWISE, _) ⇒ ompSparseRWRW + case (TraversingStructureEnum.ROWWISE, false, TraversingStructureEnum.COLWISE, _) ⇒ jvmSparseRWCW + case (TraversingStructureEnum.COLWISE, false, TraversingStructureEnum.ROWWISE, _) ⇒ jvmSparseCWRW + case (TraversingStructureEnum.COLWISE, false, TraversingStructureEnum.COLWISE, _) ⇒ jvmSparseCWCW2flips + + // Sparse methods are only effective if the first argument is sparse, so we need to do a swap. + case (_, _, _, false) ⇒ (a, b, r) ⇒ apply(b.t, a.t, r.map { + _.t + }).t + + // Default jvm-jvm case. + // for some reason a SrarseRowMatrix DRM %*% SrarseRowMatrix DRM was dumping off to here + case _ ⇒ ompRWCW + } + } + + alg(a, b, r) + } catch { + // TODO FASTHACK: just revert to JVM if there is an exception.. + // eg. java.lang.nullPointerException if more openCL contexts + // have been created than number of GPU cards. + // better option wuold be to fall back to OpenCl First. + case ex: Exception => + println(ex.getMessage + "falling back to JVM MMUL") + return MMul(a, b, r) + } + } + + type MMulAlg = MMBinaryFunc + + @inline + private def ompRWCW(a: Matrix, b: Matrix, r: Option[Matrix] = None): Matrix = { + println("ompRWCW") + // + // require(r.forall(mxR ⇒ mxR.nrow == a.nrow && mxR.ncol == b.ncol)) + // val (m, n) = (a.nrow, b.ncol) + // + // val mxR = r.getOrElse(if (densityAnalysis(a)) a.like(m, n) else b.like(m, n)) + // + // for (row ← 0 until mxR.nrow; col ← 0 until mxR.ncol) { + // // this vector-vector should be sort of optimized, right? + // mxR(row, col) = a(row, ::) dot b(::, col) + // } + // mxR + + val hasElementsA = a.zSum() > 0.0 + val hasElementsB = b.zSum() > 0.0 + + // A has a sparse matrix structure of unknown size. We do not want to + // simply convert it to a Dense Matrix which may result in an OOM error. + + // If it is empty use JVM MMul, since we can not convert it to a VCL CSR Matrix. 
+ if (!hasElementsA) { + println("Matrix a has zero elements can not convert to CSR") + return MMul(a, b, r) + } + + // CSR matrices are efficient up to 50% non-zero + if(b.getFlavor.isDense) { --- End diff -- space after if > Initial Implementation of VCL Bindings > -------------------------------------- > > Key: MAHOUT-1885 > URL: https://issues.apache.org/jira/browse/MAHOUT-1885 > Project: Mahout > Issue Type: Improvement > Components: Math > Affects Versions: 0.12.2 > Reporter: Andrew Palumbo > Assignee: Andrew Palumbo > Fix For: 0.13.0 > > > Push a working experimental branch of VCL bindings into master. There is > still a lot of work to be done. All tests are passing. At the moment I > am opening this JIRA mostly to get a number for PR and to test profiles > against on travis. -- This message was sent by Atlassian JIRA (v6.3.4#6332)