This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new edfce10724 [SYSTEMDS-3729] New roll reorg operations in CP, incl tests edfce10724 is described below commit edfce107242213c1e9199c270ecd07d41fe2ac67 Author: min-guk <koreacho...@gmail.com> AuthorDate: Mon Sep 23 09:50:43 2024 +0200 [SYSTEMDS-3729] New roll reorg operations in CP, incl tests Closes #2103. --- .github/workflows/javaTests.yml | 1 + .../java/org/apache/sysds/common/Builtins.java | 1 + src/main/java/org/apache/sysds/common/Types.java | 2 +- src/main/java/org/apache/sysds/hops/ReorgOp.java | 22 +++- src/main/java/org/apache/sysds/lops/Transform.java | 11 +- .../sysds/parser/BuiltinFunctionExpression.java | 12 +- .../org/apache/sysds/parser/DMLTranslator.java | 8 ++ .../sysds/runtime/functionobjects/RollIndex.java | 65 +++++++++++ .../runtime/instructions/CPInstructionParser.java | 1 + .../instructions/cp/ReorgCPInstruction.java | 80 +++++++++----- .../sysds/runtime/lineage/LineageCacheConfig.java | 2 +- .../sysds/runtime/matrix/data/LibMatrixReorg.java | 108 +++++++++++++++++- .../sysds/runtime/matrix/data/MatrixBlock.java | 4 +- src/main/python/systemds/operator/nodes/matrix.py | 7 ++ src/main/python/tests/matrix/test_roll.py | 65 +++++++++++ .../component/matrix/libMatrixReorg/RollTest.java | 123 +++++++++++++++++++++ 16 files changed, 473 insertions(+), 39 deletions(-) diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml index 366e20398e..595298b8f4 100644 --- a/.github/workflows/javaTests.yml +++ b/.github/workflows/javaTests.yml @@ -177,3 +177,4 @@ jobs: name: Java Code Coverage (Jacoco) path: target/site/jacoco retention-days: 3 + diff --git a/src/main/java/org/apache/sysds/common/Builtins.java b/src/main/java/org/apache/sysds/common/Builtins.java index d021597d7f..98f92ae55e 100644 --- a/src/main/java/org/apache/sysds/common/Builtins.java +++ b/src/main/java/org/apache/sysds/common/Builtins.java @@ -282,6 +282,7 @@ public enum Builtins { 
RCM("rowClassMeet", "rcm", false, false, ReturnType.MULTI_RETURN), REMOVE("remove", false, ReturnType.MULTI_RETURN), REV("rev", false), + ROLL("roll", false), ROUND("round", false), ROW_COUNT_DISTINCT("rowCountDistinct",false), ROWINDEXMAX("rowIndexMax", false), diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index 6d64cbae54..e7274b25c4 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -749,7 +749,7 @@ public interface Types { /** Operations that perform internal reorganization of an allocation */ public enum ReOrgOp { DIAG, //DIAG_V2M and DIAG_M2V could not be distinguished if sizes unknown - RESHAPE, REV, SORT, TRANS; + RESHAPE, REV, ROLL, SORT, TRANS; @Override public String toString() { diff --git a/src/main/java/org/apache/sysds/hops/ReorgOp.java b/src/main/java/org/apache/sysds/hops/ReorgOp.java index 0dce06964c..d7629eab10 100644 --- a/src/main/java/org/apache/sysds/hops/ReorgOp.java +++ b/src/main/java/org/apache/sysds/hops/ReorgOp.java @@ -90,6 +90,9 @@ public class ReorgOp extends MultiThreadedHop case REV: HopsException.check(sz == 1, this, "should have arity 1 for op %s but has arity %d", _op, sz); break; + case ROLL: + HopsException.check(sz == 2, this, "should have arity 2 for op %s but has arity %d", _op, sz); + break; case RESHAPE: case SORT: HopsException.check(sz == 5, this, "should have arity 5 for op %s but has arity %d", _op, sz); @@ -125,6 +128,7 @@ public class ReorgOp extends MultiThreadedHop } case DIAG: case REV: + case ROLL: case SORT: return false; default: @@ -175,6 +179,18 @@ public class ReorgOp extends MultiThreadedHop setLops(transform1); break; } + case ROLL: { + Lop[] linputs = new Lop[2]; //input, shift + for (int i = 0; i < 2; i++) + linputs[i] = getInput().get(i).constructLops(); + + Transform transform1 = new Transform(linputs, _op, getDataType(), getValueType(), et, 1); + + 
setOutputDimensions(transform1); + setLineNumbers(transform1); + setLops(transform1); + break; + } case RESHAPE: { Lop[] linputs = new Lop[5]; //main, rows, cols, dims, byrow for (int i = 0; i < 5; i++) @@ -279,9 +295,10 @@ public class ReorgOp extends MultiThreadedHop ret = new MatrixCharacteristics(dc.getCols(), dc.getRows(), -1, dc.getNonZeros()); break; } - case REV: { + case REV: + case ROLL: { // dims and nnz are exactly the same as in input - if( dc.dimsKnown() ) + if (dc.dimsKnown()) ret = new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, dc.getNonZeros()); break; } @@ -397,6 +414,7 @@ public class ReorgOp extends MultiThreadedHop break; } case REV: + case ROLL: { // dims and nnz are exactly the same as in input setDim1(input1.getDim1()); diff --git a/src/main/java/org/apache/sysds/lops/Transform.java b/src/main/java/org/apache/sysds/lops/Transform.java index 0fcdc09fbf..8ef5925572 100644 --- a/src/main/java/org/apache/sysds/lops/Transform.java +++ b/src/main/java/org/apache/sysds/lops/Transform.java @@ -111,7 +111,10 @@ public class Transform extends Lop case REV: // Transpose a matrix return "rev"; - + + case ROLL: + return "roll"; + case DIAG: // Transform a vector into a diagonal matrix return "rdiag"; @@ -138,6 +141,12 @@ public class Transform extends Lop return getInstructions(input1, 1, output); } + @Override + public String getInstructions(String input1, String input2, String output) { + //opcodes: roll + return getInstructions(input1, 2, output); + } + @Override public String getInstructions(String input1, String input2, String input3, String input4, String output) { //opcodes: rsort diff --git a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java index de10e090f5..1de3442dd9 100644 --- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java +++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java @@ -1266,7 +1266,17 @@ 
public class BuiltinFunctionExpression extends DataIdentifier { output.setBlocksize (id.getBlocksize()); output.setValueType(id.getValueType()); break; - + + case ROLL: + checkNumParameters(2); + checkMatrixParam(getFirstExpr()); + checkScalarParam(getSecondExpr()); + output.setDataType(DataType.MATRIX); + output.setDimensions(id.getDim1(), id.getDim2()); + output.setBlocksize(id.getBlocksize()); + output.setValueType(id.getValueType()); + break; + case DIAG: checkNumParameters(1); checkMatrixParam(getFirstExpr()); diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java b/src/main/java/org/apache/sysds/parser/DMLTranslator.java index 77ed904821..b76425668f 100644 --- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java +++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java @@ -2481,6 +2481,14 @@ public class DMLTranslator target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), expr); break; + case ROLL: + ArrayList<Hop> inputs = new ArrayList<>(); + inputs.add(expr); + inputs.add(expr2); + currBuiltinOp = new ReorgOp(target.getName(), DataType.MATRIX, + target.getValueType(), ReOrgOp.valueOf(source.getOpCode().name()), inputs); + break; + case CBIND: case RBIND: OpOp2 appendOp2 = (source.getOpCode()==Builtins.CBIND) ? OpOp2.CBIND : OpOp2.RBIND; diff --git a/src/main/java/org/apache/sysds/runtime/functionobjects/RollIndex.java b/src/main/java/org/apache/sysds/runtime/functionobjects/RollIndex.java new file mode 100644 index 0000000000..5bd78bb703 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/functionobjects/RollIndex.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.functionobjects; + +import org.apache.commons.lang3.NotImplementedException; +import org.apache.sysds.runtime.matrix.data.MatrixIndexes; +import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex; +import org.apache.sysds.runtime.meta.DataCharacteristics; + +/** + * This index function is NOT used for actual index computation but only as a marker + * in ReorgOperator to identify roll operations and to carry the shift amount.
+ */ +public class RollIndex extends IndexFunction { + private static final long serialVersionUID = -8446389232078905200L; + + private final int _shift; + + public RollIndex(int shift) { + _shift = shift; + } + + public int getShift() { + return _shift; + } + + @Override + public boolean computeDimension(int row, int col, CellIndex retDim) { + retDim.set(row, col); + return false; + } + + @Override + public boolean computeDimension(DataCharacteristics in, DataCharacteristics out) { + out.set(in.getRows(), in.getCols(), in.getBlocksize(), in.getNonZeros()); + return false; + } + + @Override + public void execute(MatrixIndexes in, MatrixIndexes out) { + throw new NotImplementedException(); + } + + @Override + public void execute(CellIndex in, CellIndex out) { + throw new NotImplementedException(); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java index 7a5ed3524c..a0270f6b20 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java @@ -271,6 +271,7 @@ public class CPInstructionParser extends InstructionParser { // Reorg Instruction Opcodes (repositioning of existing values) String2CPInstructionType.put( "r'" , CPType.Reorg); String2CPInstructionType.put( "rev" , CPType.Reorg); + String2CPInstructionType.put( "roll" , CPType.Reorg); String2CPInstructionType.put( "rdiag" , CPType.Reorg); String2CPInstructionType.put( "rshape" , CPType.Reshape); String2CPInstructionType.put( "rsort" , CPType.Reorg); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java index 18f5613c8b..e7b3000d52 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java +++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java @@ -25,6 +25,7 @@ import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.functionobjects.DiagIndex; import org.apache.sysds.runtime.functionobjects.RevIndex; +import org.apache.sysds.runtime.functionobjects.RollIndex; import org.apache.sysds.runtime.functionobjects.SortIndex; import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.instructions.InstructionUtils; @@ -38,20 +39,16 @@ public class ReorgCPInstruction extends UnaryCPInstruction { private final CPOperand _col; private final CPOperand _desc; private final CPOperand _ixret; + private final CPOperand _shift; /** * for opcodes r' and rdiag - * - * @param op - * operator - * @param in - * cp input operand - * @param out - * cp output operand - * @param opcode - * the opcode - * @param istr - * ? + * + * @param op operator + * @param in cp input operand + * @param out cp output operand + * @param opcode the opcode + * @param istr ? */ private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) { this(op, in, out, null, null, null, opcode, istr); @@ -59,30 +56,41 @@ public class ReorgCPInstruction extends UnaryCPInstruction { /** * for opcode rsort - * - * @param op - * operator - * @param in - * cp input operand - * @param col - * ? - * @param desc - * ? - * @param ixret - * ? - * @param out - * cp output operand - * @param opcode - * the opcode - * @param istr - * ? + * + * @param op operator + * @param in cp input operand + * @param col ? + * @param desc ? + * @param ixret ? + * @param out cp output operand + * @param opcode the opcode + * @param istr ? 
*/ private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand col, CPOperand desc, CPOperand ixret, - String opcode, String istr) { + String opcode, String istr) { super(CPType.Reorg, op, in, out, opcode, istr); _col = col; _desc = desc; _ixret = ixret; + _shift = null; + } + + /** + * for opcode roll + * + * @param op operator + * @param in cp input operand + * @param shift ? + * @param out cp output operand + * @param opcode the opcode + * @param istr ? + */ + private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) { + super(CPType.Reorg, op, in, out, opcode, istr); + _col = null; + _desc = null; + _ixret = null; + _shift = shift; } public static ReorgCPInstruction parseInstruction ( String str ) { @@ -103,6 +111,13 @@ public class ReorgCPInstruction extends UnaryCPInstruction { parseUnaryInstruction(str, in, out); //max 2 operands return new ReorgCPInstruction(new ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str); } + else if (opcode.equalsIgnoreCase("roll")) { + InstructionUtils.checkNumFields(str, 3); + in.split(parts[1]); + out.split(parts[3]); + CPOperand shift = new CPOperand(parts[2]); + return new ReorgCPInstruction(new ReorgOperator(new RollIndex(0)), in, out, shift, opcode, str); + } else if ( opcode.equalsIgnoreCase("rdiag") ) { parseUnaryInstruction(str, in, out); //max 2 operands return new ReorgCPInstruction(new ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str); @@ -136,7 +151,12 @@ public class ReorgCPInstruction extends UnaryCPInstruction { boolean ixret = ec.getScalarInput(_ixret).getBooleanValue(); r_op = r_op.setFn(new SortIndex(cols, desc, ixret)); } - + + if (r_op.fn instanceof RollIndex) { + int shift = (int) ec.getScalarInput(_shift).getLongValue(); + r_op = r_op.setFn(new RollIndex(shift)); + } + //execute operation MatrixBlock soresBlock = matBlock.reorgOperations(r_op, new MatrixBlock(), 0, 0, 0); diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java index 53f517c058..b700410f62 100644 --- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java +++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java @@ -63,7 +63,7 @@ public class LineageCacheConfig // Relatively expensive instructions. Most include shuffles. private static final String[] PERSIST_OPCODES1 = new String[] { - "cpmm", "rmm", "pmm", "zipmm", "rev", "rshape", "rsort", "-", "*", "+", + "cpmm", "rmm", "pmm", "zipmm", "rev", "roll", "rshape", "rsort", "-", "*", "+", "/", "%%", "%/%", "1-*", "^", "^2", "*2", "==", "!=", "<", ">", "<=", ">=", "&&", "||", "xor", "max", "min", "rmempty", "rappend", "gappend", "galignedappend", "rbind", "cbind", "nmin", "nmax", diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java index 9e1e1cac9e..53e8a88832 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java @@ -50,6 +50,7 @@ import org.apache.sysds.runtime.data.SparseRow; import org.apache.sysds.runtime.data.SparseRowVector; import org.apache.sysds.runtime.functionobjects.DiagIndex; import org.apache.sysds.runtime.functionobjects.RevIndex; +import org.apache.sysds.runtime.functionobjects.RollIndex; import org.apache.sysds.runtime.functionobjects.SortIndex; import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue; @@ -94,6 +95,7 @@ public class LibMatrixReorg { private enum ReorgType { TRANSPOSE, REV, + ROLL, DIAG, RESHAPE, SORT, @@ -123,6 +125,9 @@ public class LibMatrixReorg { return transpose(in, out); case REV: return rev(in, out); + case ROLL: + RollIndex rix = (RollIndex) op.fn; + return roll(in, out, 
rix.getShift()); case DIAG: return diag(in, out); case SORT: @@ -142,6 +147,7 @@ public class LibMatrixReorg { case TRANSPOSE: return transposeInPlace(in, op.getNumThreads()); case REV: + case ROLL: case SORT: throw new DMLRuntimeException("Not implemented inplace: " + op.fn.getClass().getSimpleName()); default: @@ -420,6 +426,25 @@ public class LibMatrixReorg { } } + public static MatrixBlock roll(MatrixBlock in, MatrixBlock out, int shift) { + //sparse-safe operation + if (in.isEmptyBlock(false)) + return out; + + //special case: row vector + if (in.rlen == 1) { + out.copy(in); + return out; + } + + if (in.sparse) + rollSparse(in, out, shift); + else + rollDense(in, out, shift); + + return out; + } + public static MatrixBlock diag( MatrixBlock in, MatrixBlock out ) { //Timing time = new Timing(true); @@ -957,7 +982,10 @@ public class LibMatrixReorg { else if( op.fn instanceof RevIndex ) //rev return ReorgType.REV; - + + else if( op.fn instanceof RollIndex ) //roll + return ReorgType.ROLL; + else if( op.fn instanceof DiagIndex ) //diag return ReorgType.DIAG; @@ -2243,7 +2271,83 @@ public class LibMatrixReorg { if( !a.isEmpty(i) ) c.set(m-1-i, a.get(i), true); } - + + private static void rollDense(MatrixBlock in, MatrixBlock out, int shift) { + final int m = in.rlen; + final int n = in.clen; + + //set basic meta data and allocate output + out.sparse = false; + out.nonZeros = in.nonZeros; + out.allocateDenseBlock(false); + + //copy all rows into target positions + if (n == 1) { //column vector + double[] a = in.getDenseBlockValues(); + double[] c = out.getDenseBlockValues(); + + // roll matrix with axis=none + shift %= (m != 0 ? m : 1); + + System.arraycopy(a, 0, c, shift, m - shift); + System.arraycopy(a, m - shift, c, 0, shift); + } else { //general matrix case + DenseBlock a = in.getDenseBlock(); + DenseBlock c = out.getDenseBlock(); + + // roll matrix with axis=0 + shift %= (m != 0 ? 
m : 1); + + for (int i = 0; i < m - shift; i++) { + System.arraycopy(a.values(i), a.pos(i), c.values(i + shift), c.pos(i + shift), n); + } + + for (int i = m - shift; i < m; i++) { + System.arraycopy(a.values(i), a.pos(i), c.values(i + shift - m), c.pos(i + shift - m), n); + } + } + } + + private static void rollSparse(MatrixBlock in, MatrixBlock out, int shift) { + final int m = in.rlen; + + //set basic meta data and allocate output + out.sparse = true; + out.nonZeros = in.nonZeros; + out.allocateSparseRowsBlock(false); + + //copy all rows into target positions + SparseBlock a = in.getSparseBlock(); + SparseBlock c = out.getSparseBlock(); + + // roll matrix with axis=0 + shift %= (m != 0 ? m : 1); + + for (int i = 0; i < m - shift; i++) { + if (a.isEmpty(i)) continue; // skip empty rows + + rollSparseRow(a, c, i, i + shift); + } + + for (int i = m - shift; i < m; i++) { + if (a.isEmpty(i)) continue; // skip empty rows + + rollSparseRow(a, c, i, i + shift - m); + } + } + + private static void rollSparseRow(SparseBlock a, SparseBlock c, int oriIdx, int shiftIdx) { + final int apos = a.pos(oriIdx); + final int alen = a.size(oriIdx) + apos; + final int[] aix = a.indexes(oriIdx); + final double[] avals = a.values(oriIdx); + + // copy only non-zero elements + for (int k = apos; k < alen; k++) { + c.set(shiftIdx, aix[k], avals[k]); + } + } + /** * Generic implementation diagV2M * (in most-likely DENSE, out most likely SPARSE) diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java index 62f9a1febb..7da67d267b 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java @@ -87,6 +87,7 @@ import org.apache.sysds.runtime.functionobjects.ReduceAll; import org.apache.sysds.runtime.functionobjects.ReduceCol; import org.apache.sysds.runtime.functionobjects.ReduceRow; import 
org.apache.sysds.runtime.functionobjects.RevIndex; +import org.apache.sysds.runtime.functionobjects.RollIndex; import org.apache.sysds.runtime.functionobjects.SortIndex; import org.apache.sysds.runtime.functionobjects.SwapIndex; import org.apache.sysds.runtime.instructions.InstructionUtils; @@ -3565,7 +3566,8 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>, public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, int startRow, int startColumn, int length) { if ( !( op.fn instanceof SwapIndex || op.fn instanceof DiagIndex - || op.fn instanceof SortIndex || op.fn instanceof RevIndex ) ) + || op.fn instanceof SortIndex || op.fn instanceof RevIndex + || op.fn instanceof RollIndex) ) throw new DMLRuntimeException("the current reorgOperations cannot support: "+op.fn.getClass()+"."); MatrixBlock result = checkType(ret); diff --git a/src/main/python/systemds/operator/nodes/matrix.py b/src/main/python/systemds/operator/nodes/matrix.py index a96f8d884c..41bb481da5 100644 --- a/src/main/python/systemds/operator/nodes/matrix.py +++ b/src/main/python/systemds/operator/nodes/matrix.py @@ -664,6 +664,13 @@ class Matrix(OperationNode): """ return Matrix(self.sds_context, "rev", [self]) + def roll(self, shift: int) -> "Matrix": + """Rolls (circularly shifts) the rows down by shift positions + + :return: the OperationNode representing this operation + """ + return Matrix(self.sds_context, "roll", [self, shift]) + def round(self) -> "Matrix": """round all values to nearest natural number diff --git a/src/main/python/tests/matrix/test_roll.py b/src/main/python/tests/matrix/test_roll.py new file mode 100644 index 0000000000..1355f24082 --- /dev/null +++ b/src/main/python/tests/matrix/test_roll.py @@ -0,0 +1,65 @@ +# ------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements.
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +# ------------------------------------------------------------- + +import unittest +import numpy as np +import random +from scipy import sparse +from systemds.context import SystemDSContext + +np.random.seed(7) +shape = (random.randrange(1, 25), random.randrange(1, 25)) + +m = np.random.rand(shape[0], shape[1]) +my = np.random.rand(shape[0], 1) +m_empty = np.asarray([[]]) +m_sparse = sparse.random(shape[0], shape[1], density=0.1, format="csr").toarray() +m_sparse = np.around(m_sparse, decimals=22) + +class TestRoll(unittest.TestCase): + sds: SystemDSContext = None + + @classmethod + def setUpClass(cls): + cls.sds = SystemDSContext() + + @classmethod + def tearDownClass(cls): + cls.sds.close() + + def test_empty(self): + r = self.sds.from_numpy(np.asarray(m_empty)).roll(1).compute() + self.assertTrue(np.allclose(r, m_empty)) + + def test_col_vec(self): + r = self.sds.from_numpy(my).roll(1).compute() + self.assertTrue(np.allclose(r, np.roll(my, axis=None, shift=1))) + + def test_basic(self): + r = self.sds.from_numpy(m).roll(1).compute() + self.assertTrue(np.allclose(r, np.roll(m, axis=0, shift=1))) + + def test_sparse_matrix(self): + r = self.sds.from_numpy(m_sparse).roll(1).compute() + self.assertTrue(np.allclose(r, np.roll(m_sparse, axis=0, shift=1))) + +if __name__ == "__main__": + 
unittest.main(exit=False) diff --git a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/RollTest.java b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/RollTest.java new file mode 100644 index 0000000000..17ba8fb068 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/RollTest.java @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.matrix.libMatrixReorg; + +import org.apache.sysds.runtime.functionobjects.IndexFunction; +import org.apache.sysds.runtime.functionobjects.RollIndex; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.matrix.operators.ReorgOperator; +import org.apache.sysds.test.TestUtils; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import static org.junit.Assert.fail; + +/** + * Test class for the roll function in MatrixBlock. 
+ * <p> + * This test verifies that the roll function produces identical results + * when applied to both sparse and dense representations of the same matrix. + */ +@RunWith(Parameterized.class) +public class RollTest { + private final int shift; + + // Input matrices + private MatrixBlock inputSparse; + private MatrixBlock inputDense; + + /** + * Constructor for parameterized test cases. + * + * @param rows Number of rows in the test matrix. + * @param cols Number of columns in the test matrix. + * @param sparsity Sparsity level of the test matrix (0.0 to 1.0). + * @param shift Shift value for the roll operation. + */ + public RollTest(int rows, int cols, double sparsity, int shift) { + this.shift = shift; + + // Generate a MatrixBlock with the given parameters + inputSparse = TestUtils.generateTestMatrixBlock(rows, cols, 0, 10, sparsity, 1); + inputSparse.recomputeNonZeros(); + + inputDense = new MatrixBlock(rows, cols, false); // false indicates dense + inputDense.copy(inputSparse, false); // Copy without maintaining sparsity + inputDense.recomputeNonZeros(); + } + + /** + * Defines the parameters for the test cases. + * Each Object[] contains {rows, cols, sparsity, shift}. + * + * @return Collection of test parameters. 
+ */ + @Parameters(name = "Rows: {0}, Cols: {1}, Sparsity: {2}, Shift: {3}") + public static Collection<Object[]> data() { + List<Object[]> tests = new ArrayList<>(); + + // Define various sizes, sparsity levels, and shift values + int[] rows = {1, 19, 1001, 2017}; + int[] cols = {1, 17, 1001, 2017}; + double[] sparsities = {0.01, 0.1, 0.7, 1.0}; + int[] shifts = {0, 1, 5, 10, 15}; + + // Generate all combinations of sizes, sparsities, and shifts + for (int row : rows) { + for (int col : cols) { + for (double sparsity : sparsities) { + for (int shift : shifts) { + tests.add(new Object[]{row, col, sparsity, shift}); + } + } + } + } + return tests; + } + + /** + * The actual test method that performs the roll operation on both + * sparse and dense matrices and compares the results. + */ + @Test + public void test() { + try { + IndexFunction op = new RollIndex(shift); + MatrixBlock outputDense = inputDense.reorgOperations( + new ReorgOperator(op), new MatrixBlock(), 0, 0, 0); + MatrixBlock outputSparse = inputSparse.reorgOperations( + new ReorgOperator(op), new MatrixBlock(), 0, 0, 0); + outputSparse.sparseToDense(); + + // Compare the dense representations of both outputs + TestUtils.compareMatrices(outputSparse, outputDense, 1e-9, + "Compare Sparse and Dense Roll Results"); + + } catch (Exception e) { + e.printStackTrace(); + fail("Exception occurred during roll function test: " + e.getMessage()); + } + } +}