This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemml.git
The following commit(s) were added to refs/heads/master by this push: new 14a79af [SYSTEMML-2468] Improved matrix histogram estimator for left-deep trees 14a79af is described below commit 14a79af677979f80f10328e67767822f6d43d2ff Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Mon Jan 14 21:47:27 2019 +0100 [SYSTEMML-2468] Improved matrix histogram estimator for left-deep trees This patch improves the matrix histogram sparsity estimator for combinations of derived and exact sketches as they appear for example in left-deep trees of matrix product chains. Specifically, we now use a generalized code path that exploits extension vectors if they are available and otherwise simply uses zero instead. --- .../apache/sysml/hops/estim/EstimatorMatrixHistogram.java | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java index 57fc97e..a82feed 100644 --- a/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java +++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorMatrixHistogram.java @@ -168,8 +168,8 @@ public class EstimatorMatrixHistogram extends SparsityEstimator nnz += (long)h1.cNnz[j] * h2.rNnz[j]; } //special case, with hybrid exact and approximate output - else if(h1.cNnz1e!=null && h2.rNnz1e != null) { - //note: normally h1.getRows()*h2.getCols() would define mnOut + else if(h1.cNnz1e!=null || h2.rNnz1e != null) { + //NOTE: normally h1.getRows()*h2.getCols() would define mnOut //but by leveraging the knowledge of rows/cols w/ <=1 nnz, we account //that exact and approximate fractions touch different areas long mnOut = _useExtended ? @@ -177,12 +177,15 @@ public class EstimatorMatrixHistogram extends SparsityEstimator (long)(h1.getRows()-h1.rN1) * (h2.getCols()-h2.cN1); double spOutRest = 0; for( int j=0; j<h1.getCols(); j++ ) { + //zero for non-existing extension vectors + int h1c1ej = (h1.cNnz1e != null) ? h1.cNnz1e[j] : 0; + int h2r1ej = (h2.rNnz1e != null) ? h2.rNnz1e[j] : 0; //exact fractions, w/o double counting - nnz += (long)h1.cNnz1e[j] * h2.rNnz[j]; - nnz += (long)(h1.cNnz[j]-h1.cNnz1e[j]) * h2.rNnz1e[j]; + nnz += (long)h1c1ej * h2.rNnz[j]; + nnz += (long)(h1.cNnz[j]-h1c1ej) * h2r1ej; //approximate fraction, w/o double counting - double lsp = (double)(h1.cNnz[j]-h1.cNnz1e[j]) - * (h2.rNnz[j]-h2.rNnz1e[j]) / mnOut; + double lsp = (double)(h1.cNnz[j]-h1c1ej) + * (h2.rNnz[j]-h2r1ej) / mnOut; spOutRest = spOutRest + lsp - spOutRest*lsp; } nnz += (long)(spOutRest * mnOut);