This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git


The following commit(s) were added to refs/heads/master by this push:
     new e443eff  [SYSTEMML-2289] Additional sampling-based sparsity estimator 
baseline
e443eff is described below

commit e443eff949b48f45d1453a6cbf483b87a612c307
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Sat Feb 9 15:38:21 2019 +0100

    [SYSTEMML-2289] Additional sampling-based sparsity estimator baseline
    
    This patch adds an additional baseline sparsity estimator based on
    sampling and hashing, which implements the approach described in
    
    Rasmus Resen Amossen, Andrea Campagna, Rasmus Pagh: Better Size
    Estimation for Sparse Matrix Products. Algorithmica 69(3): 741-757
    (2014)
    
    Credit: We're grateful to the authors who shared their code. This
    implementation improves upon it by fitting the SparsityEstimator API,
    supporting binary matrix products, avoiding unnecessary file access,
    using Well1024a for seeding local RNGs, and generally improving performance.
---
 .../apache/sysml/hops/estim/EstimatorSample.java   |   2 +-
 .../apache/sysml/hops/estim/EstimatorSampleRa.java | 268 +++++++++++++++++++++
 .../functions/estim/OuterProductTest.java          |  26 +-
 .../functions/estim/SelfProductTest.java           |  21 ++
 .../functions/estim/SquaredProductTest.java        |  85 ++++++-
 5 files changed, 398 insertions(+), 4 deletions(-)

diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java 
b/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
index ec624f0..821aa73 100644
--- a/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorSample.java
@@ -56,7 +56,7 @@ public class EstimatorSample extends SparsityEstimator
        }
        
        public EstimatorSample(double sampleFrac, boolean extended) {
-               if( sampleFrac < 0 || sampleFrac > 1.0 )
+               if( sampleFrac <= 0 || sampleFrac > 1.0 )
                        throw new DMLRuntimeException("Invalid sample fraction: 
"+sampleFrac);
                _frac = sampleFrac;
                _extended = extended;
diff --git a/src/main/java/org/apache/sysml/hops/estim/EstimatorSampleRa.java 
b/src/main/java/org/apache/sysml/hops/estim/EstimatorSampleRa.java
new file mode 100644
index 0000000..2e39d02
--- /dev/null
+++ b/src/main/java/org/apache/sysml/hops/estim/EstimatorSampleRa.java
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysml.hops.estim;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.math3.random.Well1024a;
+import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.DMLRuntimeException;
+import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
+import org.apache.sysml.runtime.matrix.data.LibMatrixDatagen;
+import org.apache.sysml.runtime.matrix.data.MatrixBlock;
+import org.apache.sysml.runtime.matrix.data.SparseBlock;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.Random;
+
+/**
+ * This estimator implements the row/column sampling and hashing approach of
+ * 
+ * Rasmus Resen Amossen, Andrea Campagna, Rasmus Pagh:
+ * Better Size Estimation for Sparse Matrix Products. Algorithmica 69(3): 741-757 (2014)
+ * 
+ * Credit: This code is based on the original implementation provided by the authors,
+ * modified to fit the SparsityEstimator API, support binary matrix products, avoid 
+ * unnecessary file access, use Well1024a for seeding local RNGs, and generally 
+ * improve performance.
+ * 
+ * For a product A %*% B, every contributing pair of non-zeros (A[r,i], B[i,c]) is
+ * mapped to the hash h(r,c) of its output cell; the k smallest distinct hash values
+ * over the sampled rows/columns yield the nnz estimate. The median over several
+ * independent runs is returned.
+ * 
+ * Note: instances keep the hash arrays and the seed generator as mutable fields
+ * without synchronization, so a single instance should not be used concurrently.
+ */
+public class EstimatorSampleRa extends SparsityEstimator 
+{
+       //default configuration; negative RUNS/K mean "derive from EPSILON and DELTA"
+       private static final int RUNS = -1;
+       private static final double SAMPLE_FRACTION = 0.1; //10%
+       private static final double EPSILON = 0.05; // Multiplicative error
+       private static final double DELTA = 0.1; // Probability of error
+       private static final int K = -1;
+       
+       private final int _runs; //number of independent runs (median is taken)
+       private final double _sampleFrac; //sample fraction (0,1]
+       private final double _eps; //target error
+       private final double _delta; //probability of error
+       private final int _k; //k-minimum hash values
+       
+       private final Well1024a _bigrand; //generates seeds for the per-run local RNGs
+       
+       private double[] h1; // hash "function" rows A
+       private double[] h2; // hash "function" cols B
+       private double[] h3; // sampling hash cols B (indexed by column of B)
+       private double[] h4; // sampling hash rows A (indexed by row of A)
+       
+       public EstimatorSampleRa() {
+               this(RUNS, SAMPLE_FRACTION, EPSILON, DELTA, K);
+       }
+       
+       public EstimatorSampleRa(double sampleFrac) {
+               this(RUNS, sampleFrac, EPSILON, DELTA, K);
+       }
+       
+       /**
+        * @param runs number of independent runs, or a negative value to derive it
+        *        from delta as floor(log2(1/delta)), but at least 1
+        * @param sampleFrac row/column sample fraction in (0,1]
+        * @param eps target multiplicative error (only used if k is negative)
+        * @param delta target probability of error (only used if runs is negative)
+        * @param k sketch size, or a negative value to derive it as ceil(1/eps^2)
+        * @throws DMLRuntimeException if sampleFrac, runs, k, or the eps/delta
+        *         values needed for derivation are invalid
+        */
+       public EstimatorSampleRa(int runs, double sampleFrac, double eps, double delta, int k) {
+               if( sampleFrac <= 0 || sampleFrac > 1.0 )
+                       throw new DMLRuntimeException("Invalid sample fraction: "+sampleFrac);
+               if( runs == 0 )
+                       throw new DMLRuntimeException("Invalid number of runs: "+runs);
+               if( k == 0 )
+                       throw new DMLRuntimeException("Invalid sketch size k: "+k);
+               if( runs < 0 && !(delta > 0 && delta < 1) )
+                       throw new DMLRuntimeException("Invalid error probability delta: "+delta);
+               if( k < 0 && !(eps > 0) )
+                       throw new DMLRuntimeException("Invalid error epsilon: "+eps);
+               _sampleFrac = sampleFrac;
+               _eps = eps;
+               _delta = delta;
+               
+               //if runs/k not specified compute from epsilon and delta; for delta >= 0.5
+               //the formula yields 0 runs, but at least one run is needed for a median
+               _runs = (runs < 0) ? Math.max(1, (int) (Math.log(1/_delta) / Math.log(2))) : runs;
+               _k = (k < 0) ? (int) Math.ceil(1 / (_eps * _eps)) : k;
+               
+               //construct Well1024a generator for good random numbers
+               //NOTE(review): seeded with _k, presumably making estimates reproducible
+               _bigrand = LibMatrixDatagen.setupSeedsForRand(_k);
+       }
+       
+       @Override
+       public MatrixCharacteristics estim(MMNode root) {
+               LOG.warn("Recursive estimates not supported by EstimatorSampleRa,"
+                       + " falling back to EstimatorBasicAvg.");
+               return new EstimatorBasicAvg().estim(root);
+       }
+
+       @Override
+       public double estim(MatrixBlock m1, MatrixBlock m2, OpCode op) {
+               if( op == OpCode.MM )
+                       return estim(m1, m2);
+               throw new NotImplementedException();
+       }
+
+       @Override
+       public double estim(MatrixBlock m, OpCode op) {
+               throw new NotImplementedException();
+       }
+       
+       /**
+        * Estimates the sparsity of the matrix product m1 %*% m2 as the median
+        * of _runs independent sketch-based nnz estimates.
+        * 
+        * @param m1 left input
+        * @param m2 right input
+        * @return estimated output sparsity
+        * @throws DMLRuntimeException if the inner dimensions do not match
+        */
+       @Override
+       public double estim(MatrixBlock m1, MatrixBlock m2) {
+               if( m1.getNumColumns() != m2.getNumRows() )
+                       throw new DMLRuntimeException("Incompatible matrix dimensions: "
+                               + m1.getNumRows()+"x"+m1.getNumColumns()+" %*% "
+                               + m2.getNumRows()+"x"+m2.getNumColumns());
+               
+               //adjacency lists do not depend on the hash values, build them only once
+               AdjacencyLists A = new AdjacencyLists(m1, false);
+               AdjacencyLists C = new AdjacencyLists(m2, true);
+               
+               // perform runs to obtain desired precision (Chernoff bound)
+               double[] results = new double[_runs];
+               for(int i=0; i<_runs; i++) {
+                       initHashArrays(m1.getNumRows(), m2.getNumColumns());
+                       results[i] = estimateSize(A, C, m1.getNumColumns());
+               }
+               
+               //compute estimate as median of all results
+               //error bound: nnz*(1-10/sqrt(k)), nnz*(1+10/sqrt(k)));
+               Arrays.sort(results);
+               long nnz = (long) results[_runs/2];
+               
+               //convert from nnz to sparsity
+               return OptimizerUtils.getSparsity(
+                       m1.getNumRows(), m2.getNumColumns(), nnz);
+       }
+       
+       /**
+        * Draws fresh random hash values for all rows of A and all columns of B.
+        * The arrays are reused across calls but reallocated whenever the input
+        * dimensions differ from the previous call.
+        * 
+        * @param m number of rows of the left input A
+        * @param l number of columns of the right input B
+        */
+       private void initHashArrays(int m, int l) {
+               if( h1 == null || h1.length != m || h2.length != l ) {
+                       h1 = new double[m];
+                       h2 = new double[l];
+                       h3 = new double[l];
+                       h4 = new double[m];
+               }
+               
+               //create local random number generator
+               Random rand = new Random(_bigrand.nextLong());
+               for(int t=0; t < h1.length; t++)
+                       h1[t] = rand.nextDouble();
+               for(int t=0; t < h2.length; t++)
+                       h2[t] = rand.nextDouble();
+               for(int t=0; t < h3.length; t++)
+                       h3[t] = rand.nextDouble();
+               for(int t=0; t < h4.length; t++)
+                       h4[t] = rand.nextDouble();
+       }
+       
+       /**
+        * Runs one pass of the sketching algorithm with the current hash arrays.
+        * For every inner index i, the pairs (r,c) with r in A[:,i] and c in B[i,:]
+        * are visited in ascending order of h(r,c), skipping unsampled rows/columns,
+        * until the hash reaches the current threshold p (the k-th smallest value).
+        * 
+        * @param A column-wise adjacency lists of the left input (row indexes per column)
+        * @param C row-wise adjacency lists of the right input (column indexes per row)
+        * @param ncol number of columns of the left input (= rows of the right input)
+        * @return estimated number of non-zeros of the product
+        */
+       private double estimateSize(AdjacencyLists A, AdjacencyLists C, int ncol) {
+               ArrayList<Double> sketch = new ArrayList<>();
+               
+               //pick a large p, it will soon be decreased anyway
+               double p = 1;
+               int bufferSize = 0;
+               
+               for( int i=0; i<ncol; i++ ) {
+                       ArrayList<Integer> Ai = A.getList(i);
+                       ArrayList<Integer> Ci = C.getList(i);
+                       if( Ai.isEmpty() || Ci.isEmpty() )
+                               continue;
+                       
+                       //get Ai and Ci sorted by hash values h1, h2
+                       Integer[] x = Ai.stream().sorted(Comparator.comparingDouble(a -> h1[a]))
+                               .toArray(Integer[]::new);
+                       Integer[] y = Ci.stream().sorted(Comparator.comparingDouble(a -> h2[a]))
+                               .toArray(Integer[]::new);
+
+                       int s = 0;
+                       int sHat = 0;
+                       for(int t=0; t<y.length; t++) {
+                               //advance sHat (cyclically) to the position of the smallest
+                               //h(x[.],y[t]); since y is sorted by h2, the search can
+                               //continue from the position found for the previous y
+                               int xIdx = (sHat > 0) ? sHat-1 : x.length-1;
+                               while( h(x[sHat], y[t]) > h(x[xIdx], y[t]))
+                                       sHat = (sHat + 1) % x.length;
+                               s = sHat;
+                               //column sampling
+                               if(h3[y[t]] > _sampleFrac)
+                                       continue;
+                               int num = 0;
+                               while(h(x[s], y[t]) < p && num < x.length) {
+                                       //row sampling
+                                       if(h4[x[s]] > _sampleFrac) {
+                                               s = (s + 1) % x.length;
+                                               num++;
+                                               continue;
+                                       }
+                                       //add hash to sketch
+                                       sketch.add(h(x[s], y[t]));
+                                       bufferSize++;
+                                       //truncate to size k if necessary
+                                       if(bufferSize > _k) {
+                                               sortAndTruncate(sketch);
+                                               if (sketch.size()==_k)
+                                                       p = sketch.get(sketch.size()-1);
+                                               bufferSize = 0;
+                                       }
+                                       s = (s + 1) % x.length;
+                                       num++;
+                               }
+                       }
+               }
+
+               //all pairs generated, truncate and finally estimate size
+               sortAndTruncate(sketch);
+               if(sketch.size() == _k) {
+                       //k'th smallest element is the last one of the sorted sketch
+                       double v = sketch.get(sketch.size()-1);
+                       return _k/(v*_sampleFrac*_sampleFrac);
+               }
+               else {
+                       return sketch.size()/(_sampleFrac*_sampleFrac);
+               }
+       }
+       
+       /**
+        * Sorts the sketch ascending, drops values within a relative precision of
+        * 1e-10 of the previously kept value (the same output cell always yields
+        * the same hash), and keeps only the k smallest remaining values.
+        * 
+        * @param sketch list of hash values, modified in place
+        */
+       public void sortAndTruncate(ArrayList<Double> sketch) {
+               Collections.sort(sketch);
+               //remove duplicates (within some epsilon precision) by in-place
+               //compaction, avoiding the quadratic cost of repeated remove(t)
+               int len = 0;
+               for(int t=0; t < sketch.size(); t++) {
+                       double v = sketch.get(t);
+                       if( len > 0 && v/sketch.get(len-1) < (1+1.0E-10) )
+                               continue;
+                       sketch.set(len++, v);
+               }
+               //truncate after the first k distinct elements
+               sketch.subList(Math.min(len, _k), sketch.size()).clear();
+       }
+       
+       /**
+        * Hash of output cell (x,y), computed as (h1[x] - h2[y]) mod 1.
+        * 
+        * @param x row index of the left input
+        * @param y column index of the right input
+        * @return hash value of the output cell
+        */
+       public double h(int x, int y) {
+               double a = (h1[x] - h2[y]);
+               return (a < 0) ? a + 1 : a;
+       }
+       
+       /**
+        * Per-row (row=true) or per-column (row=false) lists of non-zero indexes,
+        * i.e., for each row the column indexes of its non-zeros, or for each
+        * column the row indexes of its non-zeros.
+        */
+       private static class AdjacencyLists {
+               private final ArrayList<Integer>[] indexes;
+               
+               @SuppressWarnings("unchecked")
+               public AdjacencyLists(MatrixBlock mb, boolean row) {
+                       int len = row ? mb.getNumRows() : mb.getNumColumns();
+                       indexes = new ArrayList[len];
+                       for(int i=0; i<len; i++)
+                               indexes[i] = new ArrayList<>();
+                       if( mb.isEmptyBlock(false) )
+                               return; //early abort
+                       if( mb.isInSparseFormat() ) {
+                               SparseBlock sblock = mb.getSparseBlock();
+                               for(int i=0; i<sblock.numRows(); i++)  {
+                                       if( sblock.isEmpty(i) ) continue;
+                                       int apos = sblock.pos(i);
+                                       int alen = sblock.size(i);
+                                       int[] aix = sblock.indexes(i);
+                                       for(int k=apos; k<apos+alen; k++)
+                                               indexes[row?i:aix[k]].add(row?aix[k]:i);
+                               }
+                       }
+                       else {
+                               for(int i=0; i<mb.getNumRows(); i++)
+                                       for(int j=0; j<mb.getNumColumns(); j++)
+                                               if( mb.quickGetValue(i, j) != 0 )
+                                                       indexes[row?i:j].add(row?j:i);
+                       }
+               }
+               
+               /** @return the index list of row/column i (backed by this object) */
+               public ArrayList<Integer> getList(int i) {
+                       return indexes[i];
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/estim/OuterProductTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/estim/OuterProductTest.java
index 70ea63b..8de5a41 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/estim/OuterProductTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/estim/OuterProductTest.java
@@ -26,6 +26,7 @@ import org.apache.sysml.hops.estim.EstimatorBitsetMM;
 import org.apache.sysml.hops.estim.EstimatorDensityMap;
 import org.apache.sysml.hops.estim.EstimatorMatrixHistogram;
 import org.apache.sysml.hops.estim.EstimatorSample;
+import org.apache.sysml.hops.estim.EstimatorSampleRa;
 import org.apache.sysml.hops.estim.SparsityEstimator;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -139,14 +140,35 @@ public class OuterProductTest extends AutomatedTestBase
                runSparsityEstimateTest(new EstimatorSample(0.2), m, k, n, 
case2);
        }
        
+       //EstimatorSampleRa with the default (10%) and a 20% sample fraction
+       @Test
+       public void testSamplingRaDefCase1() {
+               runSparsityEstimateTest(new EstimatorSampleRa(), m, k, n, 
case1);
+       }
+       
+       @Test
+       public void testSamplingRaDefCase2() {
+               runSparsityEstimateTest(new EstimatorSampleRa(), m, k, n, 
case2);
+       }
+       
+       @Test
+       public void testSamplingRa20Case1() {
+               runSparsityEstimateTest(new EstimatorSampleRa(0.2), m, k, n, 
case1);
+       }
+       
+       @Test
+       public void testSamplingRa20Case2() {
+               runSparsityEstimateTest(new EstimatorSampleRa(0.2), m, k, n, 
case2);
+       }
+       
+       //compares the estimated sparsity of m1 %*% m2 with the sparsity of the exact product
        private void runSparsityEstimateTest(SparsityEstimator estim, int m, 
int k, int n, double[] sp) {
                MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, 
"uniform", 3);
-               MatrixBlock m2 = MatrixBlock.randOperations(k, n, sp[1], 1, 1, 
"uniform", 3);
+               MatrixBlock m2 = MatrixBlock.randOperations(k, n, sp[1], 1, 1, 
"uniform", 7);
                MatrixBlock m3 = m1.aggregateBinaryOperations(m1, m2, 
                        new MatrixBlock(), 
InstructionUtils.getMatMultOperator(1));
                
                //compare estimated and real sparsity
                double est = estim.estim(m1, m2);
-               TestUtils.compareScalars(est, m3.getSparsity(), 1e-16);
+               //EstimatorSampleRa is randomized, hence the looser tolerance
+               TestUtils.compareScalars(est, m3.getSparsity(),
+                       (estim instanceof EstimatorSampleRa)?5e-2:1e-16);
        }
 }
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/estim/SelfProductTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/estim/SelfProductTest.java
index 702514e..3e9e249 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/estim/SelfProductTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/estim/SelfProductTest.java
@@ -29,6 +29,7 @@ import org.apache.sysml.hops.estim.EstimatorDensityMap;
 import org.apache.sysml.hops.estim.EstimatorLayeredGraph;
 import org.apache.sysml.hops.estim.EstimatorMatrixHistogram;
 import org.apache.sysml.hops.estim.EstimatorSample;
+import org.apache.sysml.hops.estim.EstimatorSampleRa;
 import org.apache.sysml.hops.estim.SparsityEstimator;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -131,6 +132,26 @@ public class SelfProductTest extends AutomatedTestBase
        }
        
+       //EstimatorSampleRa with the default (10%) and a 20% sample fraction
        @Test
+       public void testSamplingRaDefCase1() {
+               runSparsityEstimateTest(new EstimatorSampleRa(), m, sparsity1);
+       }
+       
+       @Test
+       public void testSamplingRaDefCase2() {
+               runSparsityEstimateTest(new EstimatorSampleRa(), m, sparsity2);
+       }
+       
+       @Test
+       public void testSamplingRa20Case1() {
+               runSparsityEstimateTest(new EstimatorSampleRa(0.2), m, 
sparsity1);
+       }
+       
+       @Test
+       public void testSamplingRa20Case2() {
+               runSparsityEstimateTest(new EstimatorSampleRa(0.2), m, 
sparsity2);
+       }
+       
+       @Test
        public void testLayeredGraphCase1() {
                runSparsityEstimateTest(new EstimatorLayeredGraph(), m, 
sparsity1);
        }
diff --git 
a/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductTest.java
 
b/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductTest.java
index 51eb5d6..c3c6f7a 100644
--- 
a/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductTest.java
+++ 
b/src/test/java/org/apache/sysml/test/integration/functions/estim/SquaredProductTest.java
@@ -27,6 +27,7 @@ import org.apache.sysml.hops.estim.EstimatorDensityMap;
 import org.apache.sysml.hops.estim.EstimatorLayeredGraph;
 import org.apache.sysml.hops.estim.EstimatorMatrixHistogram;
 import org.apache.sysml.hops.estim.EstimatorSample;
+import org.apache.sysml.hops.estim.EstimatorSampleRa;
 import org.apache.sysml.hops.estim.SparsityEstimator;
 import org.apache.sysml.runtime.instructions.InstructionUtils;
 import org.apache.sysml.runtime.matrix.data.MatrixBlock;
@@ -44,11 +45,12 @@ public class SquaredProductTest extends AutomatedTestBase
        private final static int n = 1000;
        private final static double[] case1 = new double[]{0.0001, 0.00007};
        private final static double[] case2 = new double[]{0.0006, 0.00007};
+       //case3: fully dense left input (sparsity 1.0) times a 10% dense right input
+       private final static double[] case3 = new double[]{1.0, 0.1};
 
        private final static double eps1 = 0.05;
        private final static double eps2 = 1e-4;
        private final static double eps3 = 0;
-       
+       //tolerance for the randomized EstimatorSampleRa
+       private final static double eps4 = 0.07;
        
        @Override
        public void setUp() {
@@ -66,6 +68,11 @@ public class SquaredProductTest extends AutomatedTestBase
        }
        
+       //case3: fully dense left input (sparsity 1.0)
        @Test
+       public void testBasicAvgCase3() {
+               runSparsityEstimateTest(new EstimatorBasicAvg(), m, k, n, 
case3);
+       }
+       
+       @Test
        public void testBasicWorstCase1() {
                runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n, 
case1);
        }
@@ -76,6 +83,11 @@ public class SquaredProductTest extends AutomatedTestBase
        }
        
+       //case3: fully dense left input (sparsity 1.0)
        @Test
+       public void testBasicWorstCase3() {
+               runSparsityEstimateTest(new EstimatorBasicWorst(), m, k, n, 
case3);
+       }
+       
+       @Test
        public void testDensityMapCase1() {
                runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n, 
case1);
        }
@@ -85,6 +97,11 @@ public class SquaredProductTest extends AutomatedTestBase
                runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n, 
case2);
        }
        
+//     @Test
+//     public void testDensityMapCase3() {
+//             runSparsityEstimateTest(new EstimatorDensityMap(), m, k, n, 
case3);
+//     }
+       
        @Test
        public void testDensityMap8Case1() {
                runSparsityEstimateTest(new EstimatorDensityMap(8), m, k, n, 
case1);
@@ -95,6 +112,11 @@ public class SquaredProductTest extends AutomatedTestBase
                runSparsityEstimateTest(new EstimatorDensityMap(8), m, k, n, 
case2);
        }
        
+//     @Test
+//     public void testDensityMap8Case3() {
+//             runSparsityEstimateTest(new EstimatorDensityMap(8), m, k, n, 
case3);
+//     }
+       
        @Test
        public void testBitsetMatrixCase1() {
                runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n, 
case1);
@@ -106,6 +128,11 @@ public class SquaredProductTest extends AutomatedTestBase
        }
        
+       //case3: fully dense left input (sparsity 1.0)
        @Test
+       public void testBitsetMatrixCase3() {
+               runSparsityEstimateTest(new EstimatorBitsetMM(), m, k, n, 
case3);
+       }
+       
+       @Test
        public void testMatrixHistogramCase1() {
                runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m, 
k, n, case1);
        }
@@ -116,6 +143,11 @@ public class SquaredProductTest extends AutomatedTestBase
        }
        
+       //case3: fully dense left input (sparsity 1.0)
        @Test
+       public void testMatrixHistogramCase3() {
+               runSparsityEstimateTest(new EstimatorMatrixHistogram(false), m, 
k, n, case3);
+       }
+       
+       @Test
        public void testMatrixHistogramExceptCase1() {
                runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, 
k, n, case1);
        }
@@ -126,6 +158,11 @@ public class SquaredProductTest extends AutomatedTestBase
        }
        
+       //case3: fully dense left input (sparsity 1.0)
        @Test
+       public void testMatrixHistogramExceptCase3() {
+               runSparsityEstimateTest(new EstimatorMatrixHistogram(true), m, 
k, n, case3);
+       }
+       
+       @Test
        public void testSamplingDefCase1() {
                runSparsityEstimateTest(new EstimatorSample(), m, k, n, case1);
        }
@@ -135,6 +172,11 @@ public class SquaredProductTest extends AutomatedTestBase
                runSparsityEstimateTest(new EstimatorSample(), m, k, n, case2);
        }
        
+//     @Test
+//     public void testSamplingDefCase3() {
+//             runSparsityEstimateTest(new EstimatorSample(), m, k, n, case3);
+//     }
+       
        @Test
        public void testSampling20Case1() {
                runSparsityEstimateTest(new EstimatorSample(0.2), m, k, n, 
case1);
@@ -145,6 +187,41 @@ public class SquaredProductTest extends AutomatedTestBase
                runSparsityEstimateTest(new EstimatorSample(0.2), m, k, n, 
case2);
        }
        
+//     @Test
+//     public void testSampling20Case3() {
+//             runSparsityEstimateTest(new EstimatorSample(0.2), m, k, n, 
case3);
+//     }
+       
+       //EstimatorSampleRa with the default (10%) and a 20% sample fraction
+       @Test
+       public void testSamplingRaDefCase1() {
+               runSparsityEstimateTest(new EstimatorSampleRa(), m, k, n, 
case1);
+       }
+       
+       @Test
+       public void testSamplingRaDefCase2() {
+               runSparsityEstimateTest(new EstimatorSampleRa(), m, k, n, 
case2);
+       }
+       
+       @Test
+       public void testSamplingRaDefCase3() {
+               runSparsityEstimateTest(new EstimatorSampleRa(), m, k, n, 
case3);
+       }
+       
+       @Test
+       public void testSamplingRa20Case1() {
+               runSparsityEstimateTest(new EstimatorSampleRa(0.2), m, k, n, 
case1);
+       }
+       
+       @Test
+       public void testSamplingRa20Case2() {
+               runSparsityEstimateTest(new EstimatorSampleRa(0.2), m, k, n, 
case2);
+       }
+       
+       @Test
+       public void testSamplingRa20Case3() {
+               runSparsityEstimateTest(new EstimatorSampleRa(0.2), m, k, n, 
case3);
+       }
+       
        @Test
        public void testLayeredGraphCase1() {
                runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, 
case1);
@@ -155,6 +232,11 @@ public class SquaredProductTest extends AutomatedTestBase
                runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, 
case2);
        }
        
+//     @Test
+//     public void testLayeredGraphCase3() {
+//             runSparsityEstimateTest(new EstimatorLayeredGraph(), m, k, n, 
case3);
+//     }
+       
        private void runSparsityEstimateTest(SparsityEstimator estim, int m, 
int k, int n, double[] sp) {
                MatrixBlock m1 = MatrixBlock.randOperations(m, k, sp[0], 1, 1, 
"uniform", 3);
                MatrixBlock m2 = MatrixBlock.randOperations(k, n, sp[1], 1, 1, 
"uniform", 7);
@@ -165,6 +247,7 @@ public class SquaredProductTest extends AutomatedTestBase
                double est = estim.estim(m1, m2);
                TestUtils.compareScalars(est, m3.getSparsity(),
                        (estim instanceof EstimatorBitsetMM) ? eps3 : //exact
+                       (estim instanceof EstimatorSampleRa) ? eps4 : //sample 
ra
                        (estim instanceof EstimatorBasicWorst) ? eps1 : eps2);
        }
 }

Reply via email to