This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit ccb589056c3b6fa29332d9aa0fe33747b0774250
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Sun Jan 7 19:30:21 2024 +0100

    [MINOR] MatrixBlock improved generic Unary Agg
---
 .../spark/AggregateUnarySPInstruction.java         |   2 +-
 .../sysds/runtime/matrix/data/CM_N_COVCell.java    |   6 -
 .../sysds/runtime/matrix/data/LibMatrixAgg.java    |  60 +++++++
 .../data/LibMatrixAggUnarySpecialization.java      | 152 ++++++++++++++++
 .../sysds/runtime/matrix/data/MatrixBlock.java     | 191 ++-------------------
 .../sysds/runtime/matrix/data/MatrixCell.java      |  39 ++---
 .../sysds/runtime/matrix/data/MatrixValue.java     |   6 +-
 .../sysds/runtime/matrix/data/WeightedCell.java    |   9 +-
 8 files changed, 246 insertions(+), 219 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
index 32b80a2360..ba7237ee35 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java
@@ -279,7 +279,7 @@ public class AggregateUnarySPInstruction extends 
UnarySPInstruction {
                        throws Exception 
                {
                        //unary aggregate operation (always keep the correction)
-                       return arg0._2.aggregateUnaryOperations(
+                       return (MatrixBlock) arg0._2.aggregateUnaryOperations(
                                        _op, new MatrixBlock(), _blen, 
arg0._1());
                }
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
index 8e58630abe..a367af4f7b 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java
@@ -45,12 +45,6 @@ public class CM_N_COVCell extends MatrixValue
        public String toString() {
                return cm.toString();
        }
-       
-       @Override
-       public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op,
-                       MatrixValue result, int blen, MatrixIndexes indexesIn) {
-               throw new DMLRuntimeException("operation not supported for 
CM_N_COVCell");
-       }
 
        @Override
        public MatrixValue binaryOperations(BinaryOperator op,
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
index 0891d7f1ae..5d5cbc14e8 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
@@ -61,6 +61,7 @@ import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
 import org.apache.sysds.runtime.instructions.cp.KahanObject;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
 import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
 import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;
 import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
@@ -206,6 +207,24 @@ public class LibMatrixAgg {
 
        }
 
+       public static MatrixBlock aggregateUnaryMatrix(AggregateUnaryOperator 
op,MatrixBlock in, MatrixValue result,
+       int blen, MatrixIndexes indexesIn, boolean inCP){
+
+               MatrixBlock ret = LibMatrixAgg.prepareAggregateUnaryOutput(in, 
op, result, blen);
+               
+               if( LibMatrixAgg.isSupportedUnaryAggregateOperator(op) ) {
+                       LibMatrixAgg.aggregateUnaryMatrix(in, ret, op, 
op.getNumThreads());
+                       LibMatrixAgg.recomputeIndexes(ret, op, blen, indexesIn);
+               }
+               else
+                       LibMatrixAggUnarySpecialization.aggregateUnary(in, op, 
ret, blen, indexesIn);
+               
+               if(op.aggOp.existsCorrection() && inCP)
+                       ret.dropLastRowsOrColumns(op.aggOp.correction);
+               
+               return ret;
+       }
+
        public static void aggregateUnaryMatrix(MatrixBlock in, MatrixBlock 
out, AggregateUnaryOperator uaop) {
 
                AggType aggtype = getAggType(uaop);
@@ -3672,6 +3691,47 @@ public class LibMatrixAgg {
        }
 
 
+       public static MatrixBlock prepareAggregateUnaryOutput(MatrixBlock in, 
AggregateUnaryOperator op, MatrixValue result, int blen){
+               CellIndex tempCellIndex = new CellIndex(-1,-1);
+               final int rlen = in.getNumRows();
+               final int clen = in.getNumColumns();
+               op.indexFn.computeDimension(rlen, clen, tempCellIndex);
+               if(op.aggOp.existsCorrection())
+               {
+                       switch(op.aggOp.correction)
+                       {
+                               case LASTROW: 
+                                       tempCellIndex.row++;
+                                       break;
+                               case LASTCOLUMN: 
+                                       tempCellIndex.column++;
+                                       break;
+                               case LASTTWOROWS: 
+                                       tempCellIndex.row+=2;
+                                       break;
+                               case LASTTWOCOLUMNS: 
+                                       tempCellIndex.column+=2;
+                                       break;
+                               case LASTFOURROWS:
+                                       tempCellIndex.row+=4;
+                                       break;
+                               case LASTFOURCOLUMNS:
+                                       tempCellIndex.column+=4;
+                                       break;
+                               default:
+                                       throw new 
DMLRuntimeException("unrecognized correctionLocation: "+op.aggOp.correction);
+                       }
+               }
+               
+               //prepare result matrix block
+               if(result==null)
+                       result=new MatrixBlock(tempCellIndex.row, 
tempCellIndex.column, false);
+               else
+                       result.reset(tempCellIndex.row, tempCellIndex.column, 
false);
+               return (MatrixBlock)result;
+       }
+
+
        /////////////////////////////////////////////////////////
        // Task Implementations for Multi-Threaded Operations  //
        /////////////////////////////////////////////////////////
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAggUnarySpecialization.java
 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAggUnarySpecialization.java
new file mode 100644
index 0000000000..78f6a9a7bb
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAggUnarySpecialization.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.matrix.data;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.CorrectionLocationType;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.instructions.cp.KahanObject;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+
+/**
+ * This class handles the generic case of aggregate unary operations; it is 
only used in cases not optimized in
+ * LibMatrixAgg.
+ */
+public class LibMatrixAggUnarySpecialization {
+       protected static final Log LOG = 
LogFactory.getLog(LibMatrixAggUnarySpecialization.class.getName());
+
+       public static void aggregateUnary(final MatrixBlock mb, 
AggregateUnaryOperator op, MatrixBlock result, int blen,
+               MatrixIndexes indexesIn) {
+               if(op.sparseSafe)
+                       sparseAggregateUnaryHelp(mb, op, result, blen, 
indexesIn);
+               else
+                       denseAggregateUnaryHelp(mb, op, result, blen, 
indexesIn);
+       }
+
+       private static void sparseAggregateUnaryHelp(final MatrixBlock mb, 
AggregateUnaryOperator op, MatrixBlock result,
+               int blen, MatrixIndexes indexesIn) {
+               // initialize result
+               if(op.aggOp.initialValue != 0)
+                       result.reset(result.rlen, result.clen, 
op.aggOp.initialValue);
+               CellIndex tempCellIndex = new CellIndex(-1, -1);
+               KahanObject buffer = new KahanObject(0, 0);
+
+               if(mb.sparse && mb.sparseBlock != null) {
+                       SparseBlock a = mb.sparseBlock;
+                       for(int r = 0; r < Math.min(mb.rlen, a.numRows()); r++) 
{
+                               if(a.isEmpty(r))
+                                       continue;
+                               int apos = a.pos(r);
+                               int alen = a.size(r);
+                               int[] aix = a.indexes(r);
+                               double[] aval = a.values(r);
+                               for(int i = apos; i < apos + alen; i++) {
+                                       tempCellIndex.set(r, aix[i]);
+                                       op.indexFn.execute(tempCellIndex, 
tempCellIndex);
+                                       incrementalAggregateUnaryHelp(op.aggOp, 
result, tempCellIndex.row, tempCellIndex.column, aval[i],
+                                               buffer);
+                               }
+                       }
+               }
+               else if(!mb.sparse && mb.denseBlock != null) {
+                       DenseBlock a = mb.getDenseBlock();
+                       for(int i = 0; i < mb.rlen; i++)
+                               for(int j = 0; j < mb.clen; j++) {
+                                       tempCellIndex.set(i, j);
+                                       op.indexFn.execute(tempCellIndex, 
tempCellIndex);
+                                       incrementalAggregateUnaryHelp(op.aggOp, 
result, tempCellIndex.row, tempCellIndex.column, a.get(i, j),
+                                               buffer);
+                               }
+               }
+       }
+
+       private static void denseAggregateUnaryHelp(MatrixBlock mb, 
AggregateUnaryOperator op, MatrixBlock result, int blen,
+               MatrixIndexes indexesIn) {
+               if(op.aggOp.initialValue != 0)
+                       result.reset(result.rlen, result.clen, 
op.aggOp.initialValue);
+               CellIndex tempCellIndex = new CellIndex(-1, -1);
+               KahanObject buffer = new KahanObject(0, 0);
+               for(int i = 0; i < mb.rlen; i++)
+                       for(int j = 0; j < mb.clen; j++) {
+                               tempCellIndex.set(i, j);
+                               op.indexFn.execute(tempCellIndex, 
tempCellIndex);
+                               incrementalAggregateUnaryHelp(op.aggOp, result, 
tempCellIndex.row, tempCellIndex.column,
+                                       mb.quickGetValue(i, j), buffer);
+                       }
+       }
+
+       private static void incrementalAggregateUnaryHelp(AggregateOperator 
aggOp, MatrixBlock result, int row, int column,
+               double newvalue, KahanObject buffer) {
+               if(aggOp.existsCorrection()) {
+                       if(aggOp.correction == CorrectionLocationType.LASTROW ||
+                               aggOp.correction == 
CorrectionLocationType.LASTCOLUMN) {
+                               int corRow = row, corCol = column;
+                               if(aggOp.correction == 
CorrectionLocationType.LASTROW)// extra row
+                                       corRow++;
+                               else if(aggOp.correction == 
CorrectionLocationType.LASTCOLUMN)
+                                       corCol++;
+                               else
+                                       throw new 
DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction);
+
+                               buffer._sum = result.quickGetValue(row, column);
+                               buffer._correction = 
result.quickGetValue(corRow, corCol);
+                               buffer = (KahanObject) 
aggOp.increOp.fn.execute(buffer, newvalue);
+                               result.quickSetValue(row, column, buffer._sum);
+                               result.quickSetValue(corRow, corCol, 
buffer._correction);
+                       }
+                       else if(aggOp.correction == 
CorrectionLocationType.NONE) {
+                               throw new DMLRuntimeException("unrecognized 
correctionLocation: " + aggOp.correction);
+                       }
+                       else// for mean
+                       {
+                               int corRow = row, corCol = column;
+                               int countRow = row, countCol = column;
+                               if(aggOp.correction == 
CorrectionLocationType.LASTTWOROWS) {
+                                       countRow++;
+                                       corRow += 2;
+                               }
+                               else if(aggOp.correction == 
CorrectionLocationType.LASTTWOCOLUMNS) {
+                                       countCol++;
+                                       corCol += 2;
+                               }
+                               else
+                                       throw new 
DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction);
+                               buffer._sum = result.quickGetValue(row, column);
+                               buffer._correction = 
result.quickGetValue(corRow, corCol);
+                               double count = result.quickGetValue(countRow, 
countCol) + 1.0;
+                               buffer = (KahanObject) 
aggOp.increOp.fn.execute(buffer, newvalue, count);
+                               result.quickSetValue(row, column, buffer._sum);
+                               result.quickSetValue(corRow, corCol, 
buffer._correction);
+                               result.quickSetValue(countRow, countCol, count);
+                       }
+
+               }
+               else {
+                       newvalue = 
aggOp.increOp.fn.execute(result.quickGetValue(row, column), newvalue);
+                       result.quickSetValue(row, column, newvalue);
+               }
+       }
+
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 276f1aacee..085b6a5c52 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -587,19 +587,15 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                return isEmptyBlock(true);
        }
        
-       /**
-        * Get if this MatrixBlock is an empty block. The call can potentially 
tricker a recomputation of non zeros if the
-        * non-zero count is unknown.
-        * 
-        * @param safe True if we want to ensure the count non zeros if the nnz 
is unknown.
-        * @return If the block is empty.
-        */
-       public boolean isEmptyBlock(boolean safe) {
-               boolean ret = (sparse && sparseBlock == null) || (!sparse && 
denseBlock == null);
-               if(nonZeros <= 0) { // estimate non zeros if unknown or 0.
-                       if(safe) // only allow the recompute if safe flag is 
false.
+       public boolean isEmptyBlock(boolean safe)
+       {
+               boolean ret = ( sparse && sparseBlock==null ) || ( !sparse && 
denseBlock==null );
+               if( nonZeros==0 )
+               {
+                       //prevent under-estimation
+                       if(safe)
                                recomputeNonZeros();
-                       ret = (nonZeros == 0);
+                       ret = (nonZeros==0);
                }
                return ret;
        }
@@ -4670,180 +4666,13 @@ public class MatrixBlock extends MatrixValue 
implements CacheBlock<MatrixBlock>,
                return (MatrixBlock)result;
        }
        
-       @Override
-       public final MatrixBlock 
aggregateUnaryOperations(AggregateUnaryOperator op,
-                       MatrixValue result, int blen, MatrixIndexes indexesIn) {
-               return aggregateUnaryOperations(op, result, blen, indexesIn, 
false);
-       }
-       
        @Override
        public MatrixBlock aggregateUnaryOperations(AggregateUnaryOperator op, 
MatrixValue result,
                        int blen, MatrixIndexes indexesIn, boolean inCP)  {
-
-               MatrixBlock ret = prepareAggregateUnaryOutput(op, result, blen);
-               
-               if( LibMatrixAgg.isSupportedUnaryAggregateOperator(op) ) {
-                       LibMatrixAgg.aggregateUnaryMatrix(this, ret, op, 
op.getNumThreads());
-                       LibMatrixAgg.recomputeIndexes(ret, op, blen, indexesIn);
-               }
-               else if(op.sparseSafe)
-                       sparseAggregateUnaryHelp(op, ret, blen, indexesIn);
-               else
-                       denseAggregateUnaryHelp(op, ret, blen, indexesIn);
-               
-               if(op.aggOp.existsCorrection() && inCP)
-                       ret.dropLastRowsOrColumns(op.aggOp.correction);
-               
-               return ret;
+               return LibMatrixAgg.aggregateUnaryMatrix(op, this, result, 
blen, indexesIn, inCP);
        }
 
-       public MatrixBlock prepareAggregateUnaryOutput(AggregateUnaryOperator 
op, MatrixValue result, int blen){
-               CellIndex tempCellIndex = new CellIndex(-1,-1);
-               op.indexFn.computeDimension(rlen, clen, tempCellIndex);
-               if(op.aggOp.existsCorrection())
-               {
-                       switch(op.aggOp.correction)
-                       {
-                               case LASTROW: 
-                                       tempCellIndex.row++;
-                                       break;
-                               case LASTCOLUMN: 
-                                       tempCellIndex.column++;
-                                       break;
-                               case LASTTWOROWS: 
-                                       tempCellIndex.row+=2;
-                                       break;
-                               case LASTTWOCOLUMNS: 
-                                       tempCellIndex.column+=2;
-                                       break;
-                               case LASTFOURROWS:
-                                       tempCellIndex.row+=4;
-                                       break;
-                               case LASTFOURCOLUMNS:
-                                       tempCellIndex.column+=4;
-                                       break;
-                               default:
-                                       throw new 
DMLRuntimeException("unrecognized correctionLocation: "+op.aggOp.correction);
-                       }
-               }
-               
-               //prepare result matrix block
-               if(result==null)
-                       result=new MatrixBlock(tempCellIndex.row, 
tempCellIndex.column, false);
-               else
-                       result.reset(tempCellIndex.row, tempCellIndex.column, 
false);
-               return (MatrixBlock)result;
-       }
-       
-       private void sparseAggregateUnaryHelp(AggregateUnaryOperator op, 
MatrixBlock result,
-                       int blen, MatrixIndexes indexesIn)
-       {
-               //initialize result
-               if(op.aggOp.initialValue!=0)
-                       result.reset(result.rlen, result.clen, 
op.aggOp.initialValue);
-               CellIndex tempCellIndex = new CellIndex(-1,-1);
-               KahanObject buffer = new KahanObject(0,0);
-               
-               if( sparse && sparseBlock!=null ) {
-                       SparseBlock a = sparseBlock;
-                       for(int r=0; r<Math.min(rlen, a.numRows()); r++) {
-                               if(a.isEmpty(r)) continue;
-                               int apos = a.pos(r);
-                               int alen = a.size(r);
-                               int[] aix = a.indexes(r);
-                               double[] aval = a.values(r);
-                               for(int i=apos; i<apos+alen; i++) {
-                                       tempCellIndex.set(r, aix[i]);
-                                       op.indexFn.execute(tempCellIndex, 
tempCellIndex);
-                                       incrementalAggregateUnaryHelp(op.aggOp, 
result, 
-                                                       tempCellIndex.row, 
tempCellIndex.column, aval[i], buffer);
-                               }
-                       }
-               }
-               else if( !sparse && denseBlock!=null ) {
-                       DenseBlock a = getDenseBlock();
-                       for(int i=0; i<rlen; i++)
-                               for(int j=0; j<clen; j++) {
-                                       tempCellIndex.set(i, j);
-                                       op.indexFn.execute(tempCellIndex, 
tempCellIndex);
-                                       incrementalAggregateUnaryHelp(op.aggOp, 
result,
-                                               tempCellIndex.row, 
tempCellIndex.column, a.get(i, j), buffer);
-                               }
-               }
-       }
-       
-       private void denseAggregateUnaryHelp(AggregateUnaryOperator op, 
MatrixBlock result,
-                       int blen, MatrixIndexes indexesIn)
-       {
-               if(op.aggOp.initialValue!=0)
-                       result.reset(result.rlen, result.clen, 
op.aggOp.initialValue);
-               CellIndex tempCellIndex = new CellIndex(-1,-1);
-               KahanObject buffer=new KahanObject(0,0);
-               for(int i=0; i<rlen; i++)
-                       for(int j=0; j<clen; j++) {
-                               tempCellIndex.set(i, j);
-                               op.indexFn.execute(tempCellIndex, 
tempCellIndex);
-                               incrementalAggregateUnaryHelp(op.aggOp, result, 
tempCellIndex.row, tempCellIndex.column, quickGetValue(i,j), buffer);
-                       }
-       }
-       
-       private static void incrementalAggregateUnaryHelp(AggregateOperator 
aggOp, MatrixBlock result, int row, int column, 
-                       double newvalue, KahanObject buffer)
-       {
-               if(aggOp.existsCorrection())
-               {
-                       if(aggOp.correction==CorrectionLocationType.LASTROW || 
aggOp.correction==CorrectionLocationType.LASTCOLUMN)
-                       {
-                               int corRow=row, corCol=column;
-                               
if(aggOp.correction==CorrectionLocationType.LASTROW)//extra row
-                                       corRow++;
-                               else 
if(aggOp.correction==CorrectionLocationType.LASTCOLUMN)
-                                       corCol++;
-                               else
-                                       throw new 
DMLRuntimeException("unrecognized correctionLocation: "+aggOp.correction);
-                               
-                               buffer._sum=result.quickGetValue(row, column);
-                               buffer._correction=result.quickGetValue(corRow, 
corCol);
-                               buffer=(KahanObject) 
aggOp.increOp.fn.execute(buffer, newvalue);
-                               result.quickSetValue(row, column, buffer._sum);
-                               result.quickSetValue(corRow, corCol, 
buffer._correction);
-                       }else if(aggOp.correction==CorrectionLocationType.NONE)
-                       {
-                               throw new DMLRuntimeException("unrecognized 
correctionLocation: "+aggOp.correction);
-                       }else// for mean
-                       {
-                               int corRow=row, corCol=column;
-                               int countRow=row, countCol=column;
-                               
if(aggOp.correction==CorrectionLocationType.LASTTWOROWS)
-                               {
-                                       countRow++;
-                                       corRow+=2;
-                               }
-                               else 
if(aggOp.correction==CorrectionLocationType.LASTTWOCOLUMNS)
-                               {
-                                       countCol++;
-                                       corCol+=2;
-                               }
-                               else
-                                       throw new 
DMLRuntimeException("unrecognized correctionLocation: "+aggOp.correction);
-                               buffer._sum=result.quickGetValue(row, column);
-                               buffer._correction=result.quickGetValue(corRow, 
corCol);
-                               double count=result.quickGetValue(countRow, 
countCol)+1.0;
-                               buffer=(KahanObject) 
aggOp.increOp.fn.execute(buffer, newvalue, count);
-                               result.quickSetValue(row, column, buffer._sum);
-                               result.quickSetValue(corRow, corCol, 
buffer._correction);
-                               result.quickSetValue(countRow, countCol, count);
-                       }
-                       
-               }else
-               {
-                       
newvalue=aggOp.increOp.fn.execute(result.quickGetValue(row, column), newvalue);
-                       result.quickSetValue(row, column, newvalue);
-               }
-       }
-       
-       public void dropLastRowsOrColumns(CorrectionLocationType 
correctionLocation) 
-       {
+       public void dropLastRowsOrColumns(CorrectionLocationType 
correctionLocation) {
                //determine number of rows/cols to be removed
                int step = correctionLocation.getNumRemovedRowsColumns();
                if( step <= 0)
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
index fa2190dfc5..5c9685d158 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java
@@ -163,26 +163,6 @@ public class MatrixCell extends MatrixValue implements 
Serializable
                out.writeDouble(value);
        }
 
-       @Override
-       public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op,
-                       MatrixValue result, int blen,
-                       MatrixIndexes indexesIn) {
-               
-               MatrixCell c3=checkType(result);
-               if(c3==null)
-                       c3=new MatrixCell();
-               
-               if(op.indexFn instanceof ReduceDiag)
-               {
-                       if(indexesIn.getRowIndex()==indexesIn.getColumnIndex())
-                               c3.setValue(getValue());
-                       else
-                               c3.setValue(0);
-               }
-               else
-                       c3.setValue(getValue());
-               return c3;
-       }
 
        @Override
        public MatrixValue binaryOperations(BinaryOperator op,
@@ -364,10 +344,21 @@ public class MatrixCell extends MatrixValue implements 
Serializable
        }
 
        @Override
-       public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op,
-                       MatrixValue result, int blen,
-                       MatrixIndexes indexesIn, boolean inCP) {
-               return aggregateUnaryOperations(op,     result, blen,indexesIn);
+       public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, 
MatrixValue result, int blen,
+               MatrixIndexes indexesIn, boolean inCP) {
+               MatrixCell c3 = checkType(result);
+               if(c3 == null)
+                       c3 = new MatrixCell();
+
+               if(op.indexFn instanceof ReduceDiag) {
+                       if(indexesIn.getRowIndex() == 
indexesIn.getColumnIndex())
+                               c3.setValue(getValue());
+                       else
+                               c3.setValue(0);
+               }
+               else
+                       c3.setValue(getValue());
+               return c3;
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
index 98a5536011..a1e567547c 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java
@@ -133,8 +133,10 @@ public abstract class MatrixValue implements 
WritableComparable
 
        public abstract void ctableOperations(Operator op, double scalarThat, 
MatrixValue that2, CTableMap ctableResult, MatrixBlock ctableResultBlock);
        
-       public abstract MatrixValue 
aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, 
-               int blen, MatrixIndexes indexesIn);
+       public final  MatrixValue 
aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, 
+               int blen, MatrixIndexes indexesIn){
+               return aggregateUnaryOperations(op, result, blen, indexesIn, 
false);
+       }
        
        public abstract MatrixValue 
aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, 
                int blen, MatrixIndexes indexesIn, boolean inCP);
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/WeightedCell.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/WeightedCell.java
index 97ac52e7a6..aa331a6b2f 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/WeightedCell.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/WeightedCell.java
@@ -116,11 +116,10 @@ public class WeightedCell extends MatrixCell
        }
        
        @Override
-       public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op,
-                       MatrixValue result, int blen,
-                       MatrixIndexes indexesIn) {
-               super.aggregateUnaryOperations(op, result, blen, indexesIn);
-               WeightedCell c3=checkType(result);
+       public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, 
MatrixValue result, int blen,
+               MatrixIndexes indexesIn, boolean inCP) {
+               super.aggregateUnaryOperations(op, result, blen, indexesIn, 
inCP);
+               WeightedCell c3 = checkType(result);
                c3.setWeight(weight);
                return c3;
        }

Reply via email to