This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit ccb589056c3b6fa29332d9aa0fe33747b0774250 Author: Sebastian Baunsgaard <[email protected]> AuthorDate: Sun Jan 7 19:30:21 2024 +0100 [MINOR] MatrixBlock improved generic Unary Agg --- .../spark/AggregateUnarySPInstruction.java | 2 +- .../sysds/runtime/matrix/data/CM_N_COVCell.java | 6 - .../sysds/runtime/matrix/data/LibMatrixAgg.java | 60 +++++++ .../data/LibMatrixAggUnarySpecialization.java | 152 ++++++++++++++++ .../sysds/runtime/matrix/data/MatrixBlock.java | 191 ++------------------- .../sysds/runtime/matrix/data/MatrixCell.java | 39 ++--- .../sysds/runtime/matrix/data/MatrixValue.java | 6 +- .../sysds/runtime/matrix/data/WeightedCell.java | 9 +- 8 files changed, 246 insertions(+), 219 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java index 32b80a2360..ba7237ee35 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/AggregateUnarySPInstruction.java @@ -279,7 +279,7 @@ public class AggregateUnarySPInstruction extends UnarySPInstruction { throws Exception { //unary aggregate operation (always keep the correction) - return arg0._2.aggregateUnaryOperations( + return (MatrixBlock) arg0._2.aggregateUnaryOperations( _op, new MatrixBlock(), _blen, arg0._1()); } } diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java index 8e58630abe..a367af4f7b 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/CM_N_COVCell.java @@ -45,12 +45,6 @@ public class CM_N_COVCell extends MatrixValue public String toString() { return cm.toString(); } - - @Override - public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, 
- MatrixValue result, int blen, MatrixIndexes indexesIn) { - throw new DMLRuntimeException("operation not supported for CM_N_COVCell"); - } @Override public MatrixValue binaryOperations(BinaryOperator op, diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java index 0891d7f1ae..5d5cbc14e8 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java @@ -61,6 +61,7 @@ import org.apache.sysds.runtime.functionobjects.ValueFunction; import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.instructions.cp.KahanObject; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; import org.apache.sysds.runtime.matrix.operators.AggregateOperator; import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; @@ -206,6 +207,24 @@ public class LibMatrixAgg { } + public static MatrixBlock aggregateUnaryMatrix(AggregateUnaryOperator op,MatrixBlock in, MatrixValue result, + int blen, MatrixIndexes indexesIn, boolean inCP){ + + MatrixBlock ret = LibMatrixAgg.prepareAggregateUnaryOutput(in, op, result, blen); + + if( LibMatrixAgg.isSupportedUnaryAggregateOperator(op) ) { + LibMatrixAgg.aggregateUnaryMatrix(in, ret, op, op.getNumThreads()); + LibMatrixAgg.recomputeIndexes(ret, op, blen, indexesIn); + } + else + LibMatrixAggUnarySpecialization.aggregateUnary(in, op, ret, blen, indexesIn); + + if(op.aggOp.existsCorrection() && inCP) + ret.dropLastRowsOrColumns(op.aggOp.correction); + + return ret; + } + public static void aggregateUnaryMatrix(MatrixBlock in, MatrixBlock out, AggregateUnaryOperator uaop) { AggType aggtype = getAggType(uaop); @@ -3672,6 +3691,47 @@ public class LibMatrixAgg { } + public 
static MatrixBlock prepareAggregateUnaryOutput(MatrixBlock in, AggregateUnaryOperator op, MatrixValue result, int blen){ + CellIndex tempCellIndex = new CellIndex(-1,-1); + final int rlen = in.getNumRows(); + final int clen = in.getNumColumns(); + op.indexFn.computeDimension(rlen, clen, tempCellIndex); + if(op.aggOp.existsCorrection()) + { + switch(op.aggOp.correction) + { + case LASTROW: + tempCellIndex.row++; + break; + case LASTCOLUMN: + tempCellIndex.column++; + break; + case LASTTWOROWS: + tempCellIndex.row+=2; + break; + case LASTTWOCOLUMNS: + tempCellIndex.column+=2; + break; + case LASTFOURROWS: + tempCellIndex.row+=4; + break; + case LASTFOURCOLUMNS: + tempCellIndex.column+=4; + break; + default: + throw new DMLRuntimeException("unrecognized correctionLocation: "+op.aggOp.correction); + } + } + + //prepare result matrix block + if(result==null) + result=new MatrixBlock(tempCellIndex.row, tempCellIndex.column, false); + else + result.reset(tempCellIndex.row, tempCellIndex.column, false); + return (MatrixBlock)result; + } + + ///////////////////////////////////////////////////////// // Task Implementations for Multi-Threaded Operations // ///////////////////////////////////////////////////////// diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAggUnarySpecialization.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAggUnarySpecialization.java new file mode 100644 index 0000000000..78f6a9a7bb --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAggUnarySpecialization.java @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
 You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.matrix.data; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.common.Types.CorrectionLocationType; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.data.DenseBlock; +import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.instructions.cp.KahanObject; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.runtime.matrix.operators.AggregateOperator; +import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; + +/** + * This class handles the generic case of aggregate unary operations; it is only used in cases not optimized in + * LibMatrixAgg.
+ */ +public class LibMatrixAggUnarySpecialization { + protected static final Log LOG = LogFactory.getLog(LibMatrixAggUnarySpecialization.class.getName()); + + public static void aggregateUnary(final MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, int blen, + MatrixIndexes indexesIn) { + if(op.sparseSafe) + sparseAggregateUnaryHelp(mb, op, result, blen, indexesIn); + else + denseAggregateUnaryHelp(mb, op, result, blen, indexesIn); + } + + private static void sparseAggregateUnaryHelp(final MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, + int blen, MatrixIndexes indexesIn) { + // initialize result + if(op.aggOp.initialValue != 0) + result.reset(result.rlen, result.clen, op.aggOp.initialValue); + CellIndex tempCellIndex = new CellIndex(-1, -1); + KahanObject buffer = new KahanObject(0, 0); + + if(mb.sparse && mb.sparseBlock != null) { + SparseBlock a = mb.sparseBlock; + for(int r = 0; r < Math.min(mb.rlen, a.numRows()); r++) { + if(a.isEmpty(r)) + continue; + int apos = a.pos(r); + int alen = a.size(r); + int[] aix = a.indexes(r); + double[] aval = a.values(r); + for(int i = apos; i < apos + alen; i++) { + tempCellIndex.set(r, aix[i]); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, aval[i], + buffer); + } + } + } + else if(!mb.sparse && mb.denseBlock != null) { + DenseBlock a = mb.getDenseBlock(); + for(int i = 0; i < mb.rlen; i++) + for(int j = 0; j < mb.clen; j++) { + tempCellIndex.set(i, j); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, a.get(i, j), + buffer); + } + } + } + + private static void denseAggregateUnaryHelp(MatrixBlock mb, AggregateUnaryOperator op, MatrixBlock result, int blen, + MatrixIndexes indexesIn) { + if(op.aggOp.initialValue != 0) + result.reset(result.rlen, result.clen, op.aggOp.initialValue); + CellIndex tempCellIndex = 
new CellIndex(-1, -1); + KahanObject buffer = new KahanObject(0, 0); + for(int i = 0; i < mb.rlen; i++) + for(int j = 0; j < mb.clen; j++) { + tempCellIndex.set(i, j); + op.indexFn.execute(tempCellIndex, tempCellIndex); + incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, + mb.quickGetValue(i, j), buffer); + } + } + + private static void incrementalAggregateUnaryHelp(AggregateOperator aggOp, MatrixBlock result, int row, int column, + double newvalue, KahanObject buffer) { + if(aggOp.existsCorrection()) { + if(aggOp.correction == CorrectionLocationType.LASTROW || + aggOp.correction == CorrectionLocationType.LASTCOLUMN) { + int corRow = row, corCol = column; + if(aggOp.correction == CorrectionLocationType.LASTROW)// extra row + corRow++; + else if(aggOp.correction == CorrectionLocationType.LASTCOLUMN) + corCol++; + else + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + + buffer._sum = result.quickGetValue(row, column); + buffer._correction = result.quickGetValue(corRow, corCol); + buffer = (KahanObject) aggOp.increOp.fn.execute(buffer, newvalue); + result.quickSetValue(row, column, buffer._sum); + result.quickSetValue(corRow, corCol, buffer._correction); + } + else if(aggOp.correction == CorrectionLocationType.NONE) { + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + } + else// for mean + { + int corRow = row, corCol = column; + int countRow = row, countCol = column; + if(aggOp.correction == CorrectionLocationType.LASTTWOROWS) { + countRow++; + corRow += 2; + } + else if(aggOp.correction == CorrectionLocationType.LASTTWOCOLUMNS) { + countCol++; + corCol += 2; + } + else + throw new DMLRuntimeException("unrecognized correctionLocation: " + aggOp.correction); + buffer._sum = result.quickGetValue(row, column); + buffer._correction = result.quickGetValue(corRow, corCol); + double count = result.quickGetValue(countRow, countCol) + 1.0; + buffer = 
(KahanObject) aggOp.increOp.fn.execute(buffer, newvalue, count); + result.quickSetValue(row, column, buffer._sum); + result.quickSetValue(corRow, corCol, buffer._correction); + result.quickSetValue(countRow, countCol, count); + } + + } + else { + newvalue = aggOp.increOp.fn.execute(result.quickGetValue(row, column), newvalue); + result.quickSetValue(row, column, newvalue); + } + } + +} diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 276f1aacee..085b6a5c52 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -587,19 +587,15 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>, return isEmptyBlock(true); } - /** - * Get if this MatrixBlock is an empty block. The call can potentially tricker a recomputation of non zeros if the - * non-zero count is unknown. - * - * @param safe True if we want to ensure the count non zeros if the nnz is unknown. - * @return If the block is empty. - */ - public boolean isEmptyBlock(boolean safe) { - boolean ret = (sparse && sparseBlock == null) || (!sparse && denseBlock == null); - if(nonZeros <= 0) { // estimate non zeros if unknown or 0. - if(safe) // only allow the recompute if safe flag is false. 
+ public boolean isEmptyBlock(boolean safe) + { + boolean ret = ( sparse && sparseBlock==null ) || ( !sparse && denseBlock==null ); + if( nonZeros==0 ) + { + //prevent under-estimation + if(safe) recomputeNonZeros(); - ret = (nonZeros == 0); + ret = (nonZeros==0); } return ret; } @@ -4670,180 +4666,13 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>, return (MatrixBlock)result; } - @Override - public final MatrixBlock aggregateUnaryOperations(AggregateUnaryOperator op, - MatrixValue result, int blen, MatrixIndexes indexesIn) { - return aggregateUnaryOperations(op, result, blen, indexesIn, false); - } - @Override public MatrixBlock aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, int blen, MatrixIndexes indexesIn, boolean inCP) { - - MatrixBlock ret = prepareAggregateUnaryOutput(op, result, blen); - - if( LibMatrixAgg.isSupportedUnaryAggregateOperator(op) ) { - LibMatrixAgg.aggregateUnaryMatrix(this, ret, op, op.getNumThreads()); - LibMatrixAgg.recomputeIndexes(ret, op, blen, indexesIn); - } - else if(op.sparseSafe) - sparseAggregateUnaryHelp(op, ret, blen, indexesIn); - else - denseAggregateUnaryHelp(op, ret, blen, indexesIn); - - if(op.aggOp.existsCorrection() && inCP) - ret.dropLastRowsOrColumns(op.aggOp.correction); - - return ret; + return LibMatrixAgg.aggregateUnaryMatrix(op, this, result, blen, indexesIn, inCP); } - public MatrixBlock prepareAggregateUnaryOutput(AggregateUnaryOperator op, MatrixValue result, int blen){ - CellIndex tempCellIndex = new CellIndex(-1,-1); - op.indexFn.computeDimension(rlen, clen, tempCellIndex); - if(op.aggOp.existsCorrection()) - { - switch(op.aggOp.correction) - { - case LASTROW: - tempCellIndex.row++; - break; - case LASTCOLUMN: - tempCellIndex.column++; - break; - case LASTTWOROWS: - tempCellIndex.row+=2; - break; - case LASTTWOCOLUMNS: - tempCellIndex.column+=2; - break; - case LASTFOURROWS: - tempCellIndex.row+=4; - break; - case LASTFOURCOLUMNS: - 
tempCellIndex.column+=4; - break; - default: - throw new DMLRuntimeException("unrecognized correctionLocation: "+op.aggOp.correction); - } - } - - //prepare result matrix block - if(result==null) - result=new MatrixBlock(tempCellIndex.row, tempCellIndex.column, false); - else - result.reset(tempCellIndex.row, tempCellIndex.column, false); - return (MatrixBlock)result; - } - - private void sparseAggregateUnaryHelp(AggregateUnaryOperator op, MatrixBlock result, - int blen, MatrixIndexes indexesIn) - { - //initialize result - if(op.aggOp.initialValue!=0) - result.reset(result.rlen, result.clen, op.aggOp.initialValue); - CellIndex tempCellIndex = new CellIndex(-1,-1); - KahanObject buffer = new KahanObject(0,0); - - if( sparse && sparseBlock!=null ) { - SparseBlock a = sparseBlock; - for(int r=0; r<Math.min(rlen, a.numRows()); r++) { - if(a.isEmpty(r)) continue; - int apos = a.pos(r); - int alen = a.size(r); - int[] aix = a.indexes(r); - double[] aval = a.values(r); - for(int i=apos; i<apos+alen; i++) { - tempCellIndex.set(r, aix[i]); - op.indexFn.execute(tempCellIndex, tempCellIndex); - incrementalAggregateUnaryHelp(op.aggOp, result, - tempCellIndex.row, tempCellIndex.column, aval[i], buffer); - } - } - } - else if( !sparse && denseBlock!=null ) { - DenseBlock a = getDenseBlock(); - for(int i=0; i<rlen; i++) - for(int j=0; j<clen; j++) { - tempCellIndex.set(i, j); - op.indexFn.execute(tempCellIndex, tempCellIndex); - incrementalAggregateUnaryHelp(op.aggOp, result, - tempCellIndex.row, tempCellIndex.column, a.get(i, j), buffer); - } - } - } - - private void denseAggregateUnaryHelp(AggregateUnaryOperator op, MatrixBlock result, - int blen, MatrixIndexes indexesIn) - { - if(op.aggOp.initialValue!=0) - result.reset(result.rlen, result.clen, op.aggOp.initialValue); - CellIndex tempCellIndex = new CellIndex(-1,-1); - KahanObject buffer=new KahanObject(0,0); - for(int i=0; i<rlen; i++) - for(int j=0; j<clen; j++) { - tempCellIndex.set(i, j); - 
op.indexFn.execute(tempCellIndex, tempCellIndex); - incrementalAggregateUnaryHelp(op.aggOp, result, tempCellIndex.row, tempCellIndex.column, quickGetValue(i,j), buffer); - } - } - - private static void incrementalAggregateUnaryHelp(AggregateOperator aggOp, MatrixBlock result, int row, int column, - double newvalue, KahanObject buffer) - { - if(aggOp.existsCorrection()) - { - if(aggOp.correction==CorrectionLocationType.LASTROW || aggOp.correction==CorrectionLocationType.LASTCOLUMN) - { - int corRow=row, corCol=column; - if(aggOp.correction==CorrectionLocationType.LASTROW)//extra row - corRow++; - else if(aggOp.correction==CorrectionLocationType.LASTCOLUMN) - corCol++; - else - throw new DMLRuntimeException("unrecognized correctionLocation: "+aggOp.correction); - - buffer._sum=result.quickGetValue(row, column); - buffer._correction=result.quickGetValue(corRow, corCol); - buffer=(KahanObject) aggOp.increOp.fn.execute(buffer, newvalue); - result.quickSetValue(row, column, buffer._sum); - result.quickSetValue(corRow, corCol, buffer._correction); - }else if(aggOp.correction==CorrectionLocationType.NONE) - { - throw new DMLRuntimeException("unrecognized correctionLocation: "+aggOp.correction); - }else// for mean - { - int corRow=row, corCol=column; - int countRow=row, countCol=column; - if(aggOp.correction==CorrectionLocationType.LASTTWOROWS) - { - countRow++; - corRow+=2; - } - else if(aggOp.correction==CorrectionLocationType.LASTTWOCOLUMNS) - { - countCol++; - corCol+=2; - } - else - throw new DMLRuntimeException("unrecognized correctionLocation: "+aggOp.correction); - buffer._sum=result.quickGetValue(row, column); - buffer._correction=result.quickGetValue(corRow, corCol); - double count=result.quickGetValue(countRow, countCol)+1.0; - buffer=(KahanObject) aggOp.increOp.fn.execute(buffer, newvalue, count); - result.quickSetValue(row, column, buffer._sum); - result.quickSetValue(corRow, corCol, buffer._correction); - result.quickSetValue(countRow, countCol, count); - } - 
- }else - { - newvalue=aggOp.increOp.fn.execute(result.quickGetValue(row, column), newvalue); - result.quickSetValue(row, column, newvalue); - } - } - - public void dropLastRowsOrColumns(CorrectionLocationType correctionLocation) - { + public void dropLastRowsOrColumns(CorrectionLocationType correctionLocation) { //determine number of rows/cols to be removed int step = correctionLocation.getNumRemovedRowsColumns(); if( step <= 0) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java index fa2190dfc5..5c9685d158 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixCell.java @@ -163,26 +163,6 @@ public class MatrixCell extends MatrixValue implements Serializable out.writeDouble(value); } - @Override - public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, - MatrixValue result, int blen, - MatrixIndexes indexesIn) { - - MatrixCell c3=checkType(result); - if(c3==null) - c3=new MatrixCell(); - - if(op.indexFn instanceof ReduceDiag) - { - if(indexesIn.getRowIndex()==indexesIn.getColumnIndex()) - c3.setValue(getValue()); - else - c3.setValue(0); - } - else - c3.setValue(getValue()); - return c3; - } @Override public MatrixValue binaryOperations(BinaryOperator op, @@ -364,10 +344,21 @@ public class MatrixCell extends MatrixValue implements Serializable } @Override - public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, - MatrixValue result, int blen, - MatrixIndexes indexesIn, boolean inCP) { - return aggregateUnaryOperations(op, result, blen,indexesIn); + public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, int blen, + MatrixIndexes indexesIn, boolean inCP) { + MatrixCell c3 = checkType(result); + if(c3 == null) + c3 = new MatrixCell(); + + if(op.indexFn instanceof ReduceDiag) { + if(indexesIn.getRowIndex() == 
indexesIn.getColumnIndex()) + c3.setValue(getValue()); + else + c3.setValue(0); + } + else + c3.setValue(getValue()); + return c3; } @Override diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java index 98a5536011..a1e567547c 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixValue.java @@ -133,8 +133,10 @@ public abstract class MatrixValue implements WritableComparable public abstract void ctableOperations(Operator op, double scalarThat, MatrixValue that2, CTableMap ctableResult, MatrixBlock ctableResultBlock); - public abstract MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, - int blen, MatrixIndexes indexesIn); + public final MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, + int blen, MatrixIndexes indexesIn){ + return aggregateUnaryOperations(op, result, blen, indexesIn, false); + } public abstract MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, int blen, MatrixIndexes indexesIn, boolean inCP); diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/WeightedCell.java b/src/main/java/org/apache/sysds/runtime/matrix/data/WeightedCell.java index 97ac52e7a6..aa331a6b2f 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/WeightedCell.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/WeightedCell.java @@ -116,11 +116,10 @@ public class WeightedCell extends MatrixCell } @Override - public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, - MatrixValue result, int blen, - MatrixIndexes indexesIn) { - super.aggregateUnaryOperations(op, result, blen, indexesIn); - WeightedCell c3=checkType(result); + public MatrixValue aggregateUnaryOperations(AggregateUnaryOperator op, MatrixValue result, int blen, + MatrixIndexes indexesIn, boolean 
inCP) { + super.aggregateUnaryOperations(op, result, blen, indexesIn, inCP); + WeightedCell c3 = checkType(result); c3.setWeight(weight); return c3; }
