This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 9227cc8d1fa0267843c940685b6be6b47f2f04aa Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Sat Jul 27 19:44:54 2024 +0200 [MINOR] Fix sum_sq via partial reversion of matrix block cleanup --- .../sysds/runtime/matrix/data/MatrixBlock.java | 90 +++++++++++++++++++--- 1 file changed, 78 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 7c8efec94e..cbf9485c34 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -3280,10 +3280,53 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>, } else if(aggOp.correction==CorrectionLocationType.NONE) { //e.g., ak+ kahan plus as used in sum, mapmult, mmcj and tsmm - if(!(aggOp.increOp.fn instanceof KahanPlus)) - throw new DMLRuntimeException("Unsupported incremental aggregation: "+aggOp.increOp.fn); - - LibMatrixAgg.aggregateBinaryMatrix(newWithCor, this, cor, deep); + if(aggOp.increOp.fn instanceof KahanPlus) { + LibMatrixAgg.aggregateBinaryMatrix(newWithCor, this, cor, deep); + } + else + { + if( newWithCor.isInSparseFormat() && aggOp.sparseSafe ) //SPARSE + { + SparseBlock b = newWithCor.getSparseBlock(); + if( b==null ) //early abort on empty block + return; + for( int r=0; r<Math.min(rlen, b.numRows()); r++ ) + { + if( !b.isEmpty(r) ) + { + int bpos = b.pos(r); + int blen = b.size(r); + int[] bix = b.indexes(r); + double[] bvals = b.values(r); + for( int j=bpos; j<bpos+blen; j++) + { + int c = bix[j]; + buffer._sum = this.get(r, c); + buffer._correction = cor.get(r, c); + buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, bvals[j]); + set(r, c, buffer._sum); + cor.set(r, c, buffer._correction); + } + } + } + + } + else //DENSE or SPARSE (!sparsesafe) + { + for(int r=0; r<rlen; r++) + for(int c=0; c<clen; c++) { + buffer._sum=this.get(r, c); + buffer._correction=cor.get(r, c); + buffer=(KahanObject) aggOp.increOp.fn.execute(buffer, newWithCor.get(r, c)); + set(r, c, buffer._sum); + cor.set(r, c, buffer._correction); + } + } + + //change representation if required + //(note since ak+ on blocks is currently only applied in MR, hence no need to account for this in mem estimates) + examSparsity(); + } } else if(aggOp.correction==CorrectionLocationType.LASTTWOROWS) { double n, n2, mu2; @@ -3409,10 +3452,23 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>, if(aggOp.correction==CorrectionLocationType.LASTROW) { - if( !(aggOp.increOp.fn instanceof KahanPlus) ) - throw new DMLRuntimeException("Unsupported incremental aggregation: "+aggOp.increOp.fn); - - LibMatrixAgg.aggregateBinaryMatrix(newWithCor, this, aggOp); + if( aggOp.increOp.fn instanceof KahanPlus ) + { + LibMatrixAgg.aggregateBinaryMatrix(newWithCor, this, aggOp); + } + else + { + for(int r=0; r<rlen-1; r++) + for(int c=0; c<clen; c++) + { + buffer._sum=this.get(r, c); + buffer._correction=this.get(r+1, c); + buffer=(KahanObject) aggOp.increOp.fn.execute(buffer, newWithCor.get(r, c), + newWithCor.get(r+1, c)); + set(r, c, buffer._sum); + set(r+1, c, buffer._correction); + } + } } else if(aggOp.correction==CorrectionLocationType.LASTCOLUMN) { @@ -3455,10 +3511,20 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>, } else { - if(!(aggOp.increOp.fn instanceof KahanPlus)) - throw new DMLRuntimeException("Unsupported incremental aggregation: "+aggOp.increOp.fn); - - LibMatrixAgg.aggregateBinaryMatrix(newWithCor, this, aggOp); + if(aggOp.increOp.fn instanceof KahanPlus) { + LibMatrixAgg.aggregateBinaryMatrix(newWithCor, this, aggOp); + } + else { + for(int r=0; r<rlen; r++) + for(int c=0; c<clen-1; c++) + { + buffer._sum=this.get(r, c); + buffer._correction=this.get(r, c+1); + buffer=(KahanObject) aggOp.increOp.fn.execute(buffer, newWithCor.get(r, c), newWithCor.get(r, c+1)); + set(r, c, buffer._sum); + set(r, c+1, buffer._correction); + } + } } } else if(aggOp.correction==CorrectionLocationType.LASTTWOROWS)