This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 46d8e19ace5bb7d6a7c2636488a553ee7845356d Author: Matthias Boehm <[email protected]> AuthorDate: Wed Aug 11 00:16:42 2021 +0200 [SYSTEMDS-3090] Fix nnz mismatch in shuffle (invalid aggregation) A recent change in the aggregation logic forced the accumulator block to dense for performance, but also changed the nnz metadata to invalid values (likely to preserve the dense representation), which, however, can lead to severe correctness issues and — as surfaced in DBSCAN test failures (after other script modifications) — can cause crashes during shuffle. This patch corrects the sources of this metadata corruption and adds additional utils to simplify debugging of invalid block metadata in distributed RDDs (to quickly narrow down the operation that introduces the violation). --- .../sysds/runtime/controlprogram/ProgramBlock.java | 13 +++++- .../instructions/spark/utils/SparkUtils.java | 19 ++++++++ .../sysds/runtime/matrix/data/LibMatrixAgg.java | 50 ++++++++++------------ .../test/functions/builtin/BuiltinDBSCANTest.java | 3 +- 4 files changed, 54 insertions(+), 31 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java index 19b769f..ff33f6d 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java @@ -24,6 +24,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; import org.apache.sysds.api.jmlc.JMLCUtils; +import org.apache.sysds.common.Types.FileFormat; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.Hop; @@ -46,9 +47,12 @@ import org.apache.sysds.runtime.instructions.cp.DoubleObject; import org.apache.sysds.runtime.instructions.cp.IntObject; import 
org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.instructions.cp.StringObject; +import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils; import org.apache.sysds.runtime.lineage.LineageCache; import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.meta.MetaData; +import org.apache.sysds.runtime.meta.MetaDataFormat; import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator; import org.apache.sysds.utils.Statistics; @@ -274,7 +278,7 @@ public abstract class ProgramBlock implements ParseInfo { // optional check for correct nnz and sparse/dense representation of all // variables in symbol table (for tracking source of wrong representation) if(CHECK_MATRIX_PROPERTIES) { - checkSparsity(tmp, ec.getVariables()); + checkSparsity(tmp, ec.getVariables(), ec); checkFederated(tmp, ec.getVariables()); } } @@ -333,7 +337,7 @@ public abstract class ProgramBlock implements ParseInfo { } } - private static void checkSparsity(Instruction lastInst, LocalVariableMap vars) { + private static void checkSparsity(Instruction lastInst, LocalVariableMap vars, ExecutionContext ec) { for(String varname : vars.keySet()) { Data dat = vars.get(varname); if(dat instanceof MatrixObject) { @@ -364,6 +368,11 @@ public abstract class ProgramBlock implements ParseInfo { + ", actual=" + sparse1 + ", expected=" + sparse2 + ", nrow=" + mb.getNumRows() + ", ncol=" + mb.getNumColumns() + ", nnz=" + nnz1 + ", inst=" + lastInst + ")"); } + MetaData meta = mo.getMetaData(); + if( mo.getRDDHandle() != null && !(meta instanceof MetaDataFormat + && ((MetaDataFormat)meta).getFileFormat() != FileFormat.BINARY) ) { + SparkUtils.checkSparsity(varname, ec); + } } } } diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java index 
5e51977..2c15b91 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/SparkUtils.java @@ -25,10 +25,12 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFlatMapFunction; +import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.storage.StorageLevel; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.lops.Checkpoint; import org.apache.sysds.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; import org.apache.sysds.runtime.data.IndexedTensorBlock; @@ -200,6 +202,12 @@ public class SparkUtils else //requires key access, so use mappartitions return in.mapPartitionsToPair(new CopyTensorBlockPairFunction(deep), true); } + + public static void checkSparsity(String varname, ExecutionContext ec) { + SparkExecutionContext sec = (SparkExecutionContext) ec; + sec.getBinaryMatrixBlockRDDHandleForVariable(varname) + .foreach(new CheckSparsityFunction()); + } // This returns RDD with identifier as well as location public static String getStartLineFromSparkDebugInfo(String line) { @@ -288,6 +296,17 @@ public class SparkUtils mo.acquireReadAndRelease(); } + private static class CheckSparsityFunction implements VoidFunction<Tuple2<MatrixIndexes,MatrixBlock>> + { + private static final long serialVersionUID = 4150132775681848807L; + + @Override + public void call(Tuple2<MatrixIndexes, MatrixBlock> arg) throws Exception { + arg._2.checkNonZeros(); + arg._2.checkSparseRows(); + } + } + private static class AnalyzeCellDataCharacteristics implements 
Function<Tuple2<MatrixIndexes,MatrixCell>, DataCharacteristics> { private static final long serialVersionUID = 8899395272683723008L; diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java index 722eeca..608d79c 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java @@ -1153,16 +1153,14 @@ public class LibMatrixAgg double[] a = in.getDenseBlockValues(); - if(aggVal.isEmpty()){ + if(aggVal.isEmpty()) { aggVal.allocateDenseBlock(); - aggVal.setNonZeros(in.getNonZeros()); } else if(aggVal.isInSparseFormat()){ // If for some reason the agg Val is sparse then force it to dence, // since the values that are going to be added // will make it dense anyway. aggVal.sparseToDense(); - aggVal.setNonZeros(in.getNonZeros()); if(aggVal.denseBlock == null) aggVal.allocateDenseBlock(); } @@ -1171,8 +1169,6 @@ public class LibMatrixAgg KahanObject buffer = new KahanObject(0, 0); KahanPlus akplus = KahanPlus.getKahanPlusFnObject(); - // Don't include nnz maintenence since this function most likely aggregate more than one matrixblock. - // j is the pointer to column. // c is the pointer to correction. 
for(int j=0, c = n; j<n; j++, c++){ @@ -1182,6 +1178,8 @@ public class LibMatrixAgg t[j] = buffer._sum; t[c] = buffer._correction; } + + aggVal.recomputeNonZeros(); } private static void aggregateBinaryMatrixLastRowSparseGeneric(MatrixBlock in, MatrixBlock aggVal) { @@ -1197,30 +1195,26 @@ public class LibMatrixAgg final int m = in.rlen; final int rlen = Math.min(a.numRows(), m); - if(aggVal.isEmpty()){ + if(aggVal.isEmpty()) aggVal.allocateSparseRowsBlock(); - aggVal.setNonZeros(in.getNonZeros()); - } - - for( int i=0; i<rlen-1; i++ ) - { - if( !a.isEmpty(i) ) - { - int apos = a.pos(i); - int alen = a.size(i); - int[] aix = a.indexes(i); - double[] avals = a.values(i); - - for( int j=apos; j<apos+alen; j++ ) - { - int jix = aix[j]; - double corr = in.quickGetValue(m-1, jix); - buffer1._sum = aggVal.quickGetValue(i, jix); - buffer1._correction = aggVal.quickGetValue(m-1, jix); - akplus.execute(buffer1, avals[j], corr); - aggVal.quickSetValue(i, jix, buffer1._sum); - aggVal.quickSetValue(m-1, jix, buffer1._correction); - } + + // add to aggVal with implicit nnz maintenance + for( int i=0; i<rlen-1; i++ ) { + if( a.isEmpty(i) ) + continue; + int apos = a.pos(i); + int alen = a.size(i); + int[] aix = a.indexes(i); + double[] avals = a.values(i); + + for( int j=apos; j<apos+alen; j++ ) { + int jix = aix[j]; + double corr = in.quickGetValue(m-1, jix); + buffer1._sum = aggVal.quickGetValue(i, jix); + buffer1._correction = aggVal.quickGetValue(m-1, jix); + akplus.execute(buffer1, avals[j], corr); + aggVal.quickSetValue(i, jix, buffer1._sum); + aggVal.quickSetValue(m-1, jix, buffer1._correction); } } } diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java index 41ca5d0..dfb79cf 100644 --- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java +++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinDBSCANTest.java @@ -67,7 
+67,8 @@ public class BuiltinDBSCANTest extends AutomatedTestBase String HOME = SCRIPT_DIR + TEST_DIR; fullDMLScriptName = HOME + TEST_NAME + ".dml"; - programArgs = new String[]{"-nvargs", "X=" + input("A"), "Y=" + output("B"), "eps=" + epsDBSCAN, "minPts=" + minPts}; + programArgs = new String[]{"-explain","-nvargs", + "X=" + input("A"), "Y=" + output("B"), "eps=" + epsDBSCAN, "minPts=" + minPts}; fullRScriptName = HOME + TEST_NAME + ".R"; rCmd = getRCmd(inputDir(), Double.toString(epsDBSCAN), Integer.toString(minPts), expectedDir());
