This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 9227cc8d1fa0267843c940685b6be6b47f2f04aa
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Sat Jul 27 19:44:54 2024 +0200

    [MINOR] Fix sum_sq via partial reversion of matrix block cleanup
---
 .../sysds/runtime/matrix/data/MatrixBlock.java     | 90 +++++++++++++++++++---
 1 file changed, 78 insertions(+), 12 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 7c8efec94e..cbf9485c34 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -3280,10 +3280,53 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>,
                }
                else if(aggOp.correction==CorrectionLocationType.NONE) {
                        //e.g., ak+ kahan plus as used in sum, mapmult, mmcj 
and tsmm
-                       if(!(aggOp.increOp.fn instanceof KahanPlus))
-                               throw new DMLRuntimeException("Unsupported 
incremental aggregation: "+aggOp.increOp.fn);
-                       
-                       LibMatrixAgg.aggregateBinaryMatrix(newWithCor, this, 
cor, deep);
+                       if(aggOp.increOp.fn instanceof KahanPlus) {
+                               LibMatrixAgg.aggregateBinaryMatrix(newWithCor, 
this, cor, deep);
+                       }
+                       else
+                       {
+                               if( newWithCor.isInSparseFormat() && 
aggOp.sparseSafe ) //SPARSE
+                               {
+                                       SparseBlock b = 
newWithCor.getSparseBlock();
+                                       if( b==null ) //early abort on empty 
block
+                                               return;
+                                       for( int r=0; r<Math.min(rlen, 
b.numRows()); r++ )
+                                       {
+                                               if( !b.isEmpty(r) ) 
+                                               {
+                                                       int bpos = b.pos(r);
+                                                       int blen = b.size(r);
+                                                       int[] bix = 
b.indexes(r);
+                                                       double[] bvals = 
b.values(r);
+                                                       for( int j=bpos; 
j<bpos+blen; j++)
+                                                       {
+                                                               int c = bix[j];
+                                                               buffer._sum = 
this.get(r, c);
+                                                               
buffer._correction = cor.get(r, c);
+                                                               buffer = 
(KahanObject) aggOp.increOp.fn.execute(buffer, bvals[j]);
+                                                               set(r, c, 
buffer._sum);
+                                                               cor.set(r, c, 
buffer._correction);
+                                                       }
+                                               }
+                                       }
+
+                               }
+                               else //DENSE or SPARSE (!sparsesafe)
+                               {
+                                       for(int r=0; r<rlen; r++)
+                                               for(int c=0; c<clen; c++) {
+                                                       buffer._sum=this.get(r, 
c);
+                                                       
buffer._correction=cor.get(r, c);
+                                                       buffer=(KahanObject) 
aggOp.increOp.fn.execute(buffer, newWithCor.get(r, c));
+                                                       set(r, c, buffer._sum);
+                                                       cor.set(r, c, 
buffer._correction);
+                                               }
+                               }
+
+                               //change representation if required
+                               //(note since ak+ on blocks is currently only 
applied in MR, hence no need to account for this in mem estimates)
+                               examSparsity(); 
+                       }
                }
                else if(aggOp.correction==CorrectionLocationType.LASTTWOROWS) {
                        double n, n2, mu2;
@@ -3409,10 +3452,23 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>,
                
                if(aggOp.correction==CorrectionLocationType.LASTROW)
                {
-                       if( !(aggOp.increOp.fn instanceof KahanPlus) )
-                               throw new DMLRuntimeException("Unsupported 
incremental aggregation: "+aggOp.increOp.fn);
-                       
-                       LibMatrixAgg.aggregateBinaryMatrix(newWithCor, this, 
aggOp);
+                       if( aggOp.increOp.fn instanceof KahanPlus )
+                       {
+                               LibMatrixAgg.aggregateBinaryMatrix(newWithCor, 
this, aggOp);
+                       }
+                       else
+                       {
+                               for(int r=0; r<rlen-1; r++)
+                                       for(int c=0; c<clen; c++)
+                                       {
+                                               buffer._sum=this.get(r, c);
+                                               
buffer._correction=this.get(r+1, c);
+                                               buffer=(KahanObject) 
aggOp.increOp.fn.execute(buffer, newWithCor.get(r, c), 
+                                                               
newWithCor.get(r+1, c));
+                                               set(r, c, buffer._sum);
+                                               set(r+1, c, buffer._correction);
+                                       }
+                       }
                }
                else if(aggOp.correction==CorrectionLocationType.LASTCOLUMN)
                {
@@ -3455,10 +3511,20 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>,
                        }
                        else
                        {
-                               if(!(aggOp.increOp.fn instanceof KahanPlus)) 
-                                       throw new 
DMLRuntimeException("Unsupported incremental aggregation: "+aggOp.increOp.fn);
-                               
-                               LibMatrixAgg.aggregateBinaryMatrix(newWithCor, 
this, aggOp);
+                               if(aggOp.increOp.fn instanceof KahanPlus) {
+                                       
LibMatrixAgg.aggregateBinaryMatrix(newWithCor, this, aggOp);
+                               }
+                               else {
+                                       for(int r=0; r<rlen; r++)
+                                               for(int c=0; c<clen-1; c++)
+                                               {
+                                                       buffer._sum=this.get(r, 
c);
+                                                       
buffer._correction=this.get(r, c+1);
+                                                       buffer=(KahanObject) 
aggOp.increOp.fn.execute(buffer, newWithCor.get(r, c), newWithCor.get(r, c+1));
+                                                       set(r, c, buffer._sum);
+                                                       set(r, c+1, 
buffer._correction);
+                                               }
+                               }
                        }
                }
                else if(aggOp.correction==CorrectionLocationType.LASTTWOROWS)

Reply via email to