This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new edfce10724 [SYSTEMDS-3729] New roll reorg operations in CP, incl tests
edfce10724 is described below
commit edfce107242213c1e9199c270ecd07d41fe2ac67
Author: min-guk <[email protected]>
AuthorDate: Mon Sep 23 09:50:43 2024 +0200
[SYSTEMDS-3729] New roll reorg operations in CP, incl tests
Closes #2103.
---
.github/workflows/javaTests.yml | 1 +
.../java/org/apache/sysds/common/Builtins.java | 1 +
src/main/java/org/apache/sysds/common/Types.java | 2 +-
src/main/java/org/apache/sysds/hops/ReorgOp.java | 22 +++-
src/main/java/org/apache/sysds/lops/Transform.java | 11 +-
.../sysds/parser/BuiltinFunctionExpression.java | 12 +-
.../org/apache/sysds/parser/DMLTranslator.java | 8 ++
.../sysds/runtime/functionobjects/RollIndex.java | 65 +++++++++++
.../runtime/instructions/CPInstructionParser.java | 1 +
.../instructions/cp/ReorgCPInstruction.java | 80 +++++++++-----
.../sysds/runtime/lineage/LineageCacheConfig.java | 2 +-
.../sysds/runtime/matrix/data/LibMatrixReorg.java | 108 +++++++++++++++++-
.../sysds/runtime/matrix/data/MatrixBlock.java | 4 +-
src/main/python/systemds/operator/nodes/matrix.py | 7 ++
src/main/python/tests/matrix/test_roll.py | 65 +++++++++++
.../component/matrix/libMatrixReorg/RollTest.java | 123 +++++++++++++++++++++
16 files changed, 473 insertions(+), 39 deletions(-)
diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml
index 366e20398e..595298b8f4 100644
--- a/.github/workflows/javaTests.yml
+++ b/.github/workflows/javaTests.yml
@@ -177,3 +177,4 @@ jobs:
name: Java Code Coverage (Jacoco)
path: target/site/jacoco
retention-days: 3
+
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index d021597d7f..98f92ae55e 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -282,6 +282,7 @@ public enum Builtins {
RCM("rowClassMeet", "rcm", false, false, ReturnType.MULTI_RETURN),
REMOVE("remove", false, ReturnType.MULTI_RETURN),
REV("rev", false),
+ ROLL("roll", false),
ROUND("round", false),
ROW_COUNT_DISTINCT("rowCountDistinct",false),
ROWINDEXMAX("rowIndexMax", false),
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index 6d64cbae54..e7274b25c4 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -749,7 +749,7 @@ public interface Types {
/** Operations that perform internal reorganization of an allocation */
public enum ReOrgOp {
DIAG, //DIAG_V2M and DIAG_M2V could not be distinguished if
sizes unknown
- RESHAPE, REV, SORT, TRANS;
+ RESHAPE, REV, ROLL, SORT, TRANS;
@Override
public String toString() {
diff --git a/src/main/java/org/apache/sysds/hops/ReorgOp.java
b/src/main/java/org/apache/sysds/hops/ReorgOp.java
index 0dce06964c..d7629eab10 100644
--- a/src/main/java/org/apache/sysds/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysds/hops/ReorgOp.java
@@ -90,6 +90,9 @@ public class ReorgOp extends MultiThreadedHop
case REV:
HopsException.check(sz == 1, this, "should have
arity 1 for op %s but has arity %d", _op, sz);
break;
+ case ROLL:
+ HopsException.check(sz == 2, this, "should have
arity 2 for op %s but has arity %d", _op, sz);
+ break;
case RESHAPE:
case SORT:
HopsException.check(sz == 5, this, "should have
arity 5 for op %s but has arity %d", _op, sz);
@@ -125,6 +128,7 @@ public class ReorgOp extends MultiThreadedHop
}
case DIAG:
case REV:
+ case ROLL:
case SORT:
return false;
default:
@@ -175,6 +179,18 @@ public class ReorgOp extends MultiThreadedHop
setLops(transform1);
break;
}
+ case ROLL: {
+ Lop[] linputs = new Lop[2]; //input, shift
+ for (int i = 0; i < 2; i++)
+ linputs[i] =
getInput().get(i).constructLops();
+
+ Transform transform1 = new Transform(linputs,
_op, getDataType(), getValueType(), et, 1);
+
+ setOutputDimensions(transform1);
+ setLineNumbers(transform1);
+ setLops(transform1);
+ break;
+ }
case RESHAPE: {
Lop[] linputs = new Lop[5]; //main, rows, cols,
dims, byrow
for (int i = 0; i < 5; i++)
@@ -279,9 +295,10 @@ public class ReorgOp extends MultiThreadedHop
ret = new
MatrixCharacteristics(dc.getCols(), dc.getRows(), -1, dc.getNonZeros());
break;
}
- case REV: {
+ case REV:
+ case ROLL: {
// dims and nnz are exactly the same as in input
- if( dc.dimsKnown() )
+ if (dc.dimsKnown())
ret = new
MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, dc.getNonZeros());
break;
}
@@ -397,6 +414,7 @@ public class ReorgOp extends MultiThreadedHop
break;
}
case REV:
+ case ROLL:
{
// dims and nnz are exactly the same as in input
setDim1(input1.getDim1());
diff --git a/src/main/java/org/apache/sysds/lops/Transform.java
b/src/main/java/org/apache/sysds/lops/Transform.java
index 0fcdc09fbf..8ef5925572 100644
--- a/src/main/java/org/apache/sysds/lops/Transform.java
+++ b/src/main/java/org/apache/sysds/lops/Transform.java
@@ -111,7 +111,10 @@ public class Transform extends Lop
case REV:
// Transpose a matrix
return "rev";
-
+
+ case ROLL:
+ return "roll";
+
case DIAG:
// Transform a vector into a diagonal matrix
return "rdiag";
@@ -138,6 +141,12 @@ public class Transform extends Lop
return getInstructions(input1, 1, output);
}
+	@Override
+	public String getInstructions(String input1, String input2, String output) {
+		// opcode roll: two operands (matrix input and scalar shift); like the
+		// single-input variant, only input1 is passed explicitly.
+		// NOTE(review): presumably the generic builder reads the remaining
+		// operands from the lop inputs - confirm.
+		final int numInputs = 2;
+		return getInstructions(input1, numInputs, output);
+	}
+
@Override
public String getInstructions(String input1, String input2, String
input3, String input4, String output) {
//opcodes: rsort
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index de10e090f5..1de3442dd9 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -1266,7 +1266,17 @@ public class BuiltinFunctionExpression extends
DataIdentifier {
output.setBlocksize (id.getBlocksize());
output.setValueType(id.getValueType());
break;
-
+
+ case ROLL:
+ checkNumParameters(2);
+ checkMatrixParam(getFirstExpr());
+ checkScalarParam(getSecondExpr());
+ output.setDataType(DataType.MATRIX);
+ output.setDimensions(id.getDim1(), id.getDim2());
+ output.setBlocksize(id.getBlocksize());
+ output.setValueType(id.getValueType());
+ break;
+
case DIAG:
checkNumParameters(1);
checkMatrixParam(getFirstExpr());
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 77ed904821..b76425668f 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2481,6 +2481,14 @@ public class DMLTranslator
target.getValueType(),
ReOrgOp.valueOf(source.getOpCode().name()), expr);
break;
+ case ROLL:
+ ArrayList<Hop> inputs = new ArrayList<>();
+ inputs.add(expr);
+ inputs.add(expr2);
+ currBuiltinOp = new ReorgOp(target.getName(),
DataType.MATRIX,
+ target.getValueType(),
ReOrgOp.valueOf(source.getOpCode().name()), inputs);
+ break;
+
case CBIND:
case RBIND:
OpOp2 appendOp2 = (source.getOpCode()==Builtins.CBIND)
? OpOp2.CBIND : OpOp2.RBIND;
diff --git
a/src/main/java/org/apache/sysds/runtime/functionobjects/RollIndex.java
b/src/main/java/org/apache/sysds/runtime/functionobjects/RollIndex.java
new file mode 100644
index 0000000000..5bd78bb703
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/RollIndex.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.functionobjects;
+
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+/**
+ * This index function is NOT used for actual rolling but just as a reference
+ * in ReorgOperator in order to identify roll operations. It additionally
+ * carries the row shift that is handed to the roll kernel in LibMatrixReorg.
+ */
+public class RollIndex extends IndexFunction {
+	private static final long serialVersionUID = -8446389232078905200L;
+
+	/** number of row positions to roll by */
+	private final int _shift;
+
+	/**
+	 * @param shift number of row positions to roll by
+	 */
+	public RollIndex(int shift) {
+		_shift = shift;
+	}
+
+	/**
+	 * @return the number of row positions to roll by
+	 */
+	public int getShift() {
+		return _shift;
+	}
+
+	// roll preserves the dimensions of its input
+	@Override
+	public boolean computeDimension(int row, int col, CellIndex retDim) {
+		retDim.set(row, col);
+		return false;
+	}
+
+	// roll preserves dimensions, blocksize, and number of non-zeros
+	@Override
+	public boolean computeDimension(DataCharacteristics in, DataCharacteristics out) {
+		out.set(in.getRows(), in.getCols(), in.getBlocksize(), in.getNonZeros());
+		return false;
+	}
+
+	// per-index mapping is not supported; roll runs as a whole-block kernel
+	@Override
+	public void execute(MatrixIndexes in, MatrixIndexes out) {
+		throw new NotImplementedException();
+	}
+
+	// per-cell mapping is not supported; roll runs as a whole-block kernel
+	@Override
+	public void execute(CellIndex in, CellIndex out) {
+		throw new NotImplementedException();
+	}
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 7a5ed3524c..a0270f6b20 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -271,6 +271,7 @@ public class CPInstructionParser extends InstructionParser {
// Reorg Instruction Opcodes (repositioning of existing values)
String2CPInstructionType.put( "r'" , CPType.Reorg);
String2CPInstructionType.put( "rev" , CPType.Reorg);
+ String2CPInstructionType.put( "roll" , CPType.Reorg);
String2CPInstructionType.put( "rdiag" , CPType.Reorg);
String2CPInstructionType.put( "rshape" , CPType.Reshape);
String2CPInstructionType.put( "rsort" , CPType.Reorg);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
index 18f5613c8b..e7b3000d52 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
@@ -25,6 +25,7 @@ import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
+import org.apache.sysds.runtime.functionobjects.RollIndex;
import org.apache.sysds.runtime.functionobjects.SortIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -38,20 +39,16 @@ public class ReorgCPInstruction extends UnaryCPInstruction {
private final CPOperand _col;
private final CPOperand _desc;
private final CPOperand _ixret;
+ private final CPOperand _shift;
/**
* for opcodes r' and rdiag
- *
- * @param op
- * operator
- * @param in
- * cp input operand
- * @param out
- * cp output operand
- * @param opcode
- * the opcode
- * @param istr
- * ?
+ *
+ * @param op operator
+ * @param in cp input operand
+ * @param out cp output operand
+ * @param opcode the opcode
+ * @param istr ?
*/
private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out,
String opcode, String istr) {
this(op, in, out, null, null, null, opcode, istr);
@@ -59,30 +56,41 @@ public class ReorgCPInstruction extends UnaryCPInstruction {
/**
* for opcode rsort
- *
- * @param op
- * operator
- * @param in
- * cp input operand
- * @param col
- * ?
- * @param desc
- * ?
- * @param ixret
- * ?
- * @param out
- * cp output operand
- * @param opcode
- * the opcode
- * @param istr
- * ?
+ *
+ * @param op operator
+ * @param in cp input operand
+ * @param col ?
+ * @param desc ?
+ * @param ixret ?
+ * @param out cp output operand
+ * @param opcode the opcode
+ * @param istr ?
*/
private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out,
CPOperand col, CPOperand desc, CPOperand ixret,
- String opcode, String istr) {
+ String opcode,
String istr) {
super(CPType.Reorg, op, in, out, opcode, istr);
_col = col;
_desc = desc;
_ixret = ixret;
+ _shift = null;
+ }
+
+	/**
+	 * Creates a reorg instruction for opcode roll.
+	 *
+	 * @param op     reorg operator (wrapping a RollIndex placeholder)
+	 * @param in     cp input operand (matrix to roll)
+	 * @param out    cp output operand
+	 * @param shift  cp operand of the scalar shift, resolved at execution time
+	 * @param opcode the opcode
+	 * @param istr   the instruction string
+	 */
+	private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
+		super(CPType.Reorg, op, in, out, opcode, istr);
+		// only the shift operand is relevant for roll; sort operands stay unset
+		this._shift = shift;
+		this._col = null;
+		this._desc = null;
+		this._ixret = null;
+	}
public static ReorgCPInstruction parseInstruction ( String str ) {
@@ -103,6 +111,13 @@ public class ReorgCPInstruction extends UnaryCPInstruction
{
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgCPInstruction(new
ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
}
+ else if (opcode.equalsIgnoreCase("roll")) {
+ InstructionUtils.checkNumFields(str, 3);
+ in.split(parts[1]);
+ out.split(parts[3]);
+ CPOperand shift = new CPOperand(parts[2]);
+ return new ReorgCPInstruction(new ReorgOperator(new
RollIndex(0)), in, out, shift, opcode, str);
+ }
else if ( opcode.equalsIgnoreCase("rdiag") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgCPInstruction(new
ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
@@ -136,7 +151,12 @@ public class ReorgCPInstruction extends UnaryCPInstruction
{
boolean ixret =
ec.getScalarInput(_ixret).getBooleanValue();
r_op = r_op.setFn(new SortIndex(cols, desc, ixret));
}
-
+
+ if (r_op.fn instanceof RollIndex) {
+ int shift = (int)
ec.getScalarInput(_shift).getLongValue();
+ r_op = r_op.setFn(new RollIndex(shift));
+ }
+
//execute operation
MatrixBlock soresBlock = matBlock.reorgOperations(r_op, new
MatrixBlock(), 0, 0, 0);
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 53f517c058..b700410f62 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -63,7 +63,7 @@ public class LineageCacheConfig
// Relatively expensive instructions. Most include shuffles.
private static final String[] PERSIST_OPCODES1 = new String[] {
- "cpmm", "rmm", "pmm", "zipmm", "rev", "rshape", "rsort", "-",
"*", "+",
+ "cpmm", "rmm", "pmm", "zipmm", "rev", "roll", "rshape",
"rsort", "-", "*", "+",
"/", "%%", "%/%", "1-*", "^", "^2", "*2", "==", "!=", "<", ">",
"<=", ">=", "&&", "||", "xor", "max", "min", "rmempty",
"rappend",
"gappend", "galignedappend", "rbind", "cbind", "nmin", "nmax",
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
index 9e1e1cac9e..53e8a88832 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
@@ -50,6 +50,7 @@ import org.apache.sysds.runtime.data.SparseRow;
import org.apache.sysds.runtime.data.SparseRowVector;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
+import org.apache.sysds.runtime.functionobjects.RollIndex;
import org.apache.sysds.runtime.functionobjects.SortIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
@@ -94,6 +95,7 @@ public class LibMatrixReorg {
private enum ReorgType {
TRANSPOSE,
REV,
+ ROLL,
DIAG,
RESHAPE,
SORT,
@@ -123,6 +125,9 @@ public class LibMatrixReorg {
return transpose(in, out);
case REV:
return rev(in, out);
+ case ROLL:
+ RollIndex rix = (RollIndex) op.fn;
+ return roll(in, out, rix.getShift());
case DIAG:
return diag(in, out);
case SORT:
@@ -142,6 +147,7 @@ public class LibMatrixReorg {
case TRANSPOSE:
return transposeInPlace(in, op.getNumThreads());
case REV:
+ case ROLL:
case SORT:
throw new DMLRuntimeException("Not implemented
inplace: " + op.fn.getClass().getSimpleName());
default:
@@ -420,6 +426,25 @@ public class LibMatrixReorg {
}
}
+	/**
+	 * Rolls (circularly shifts) the rows of a matrix: row i of the input is
+	 * written to row (i + shift) mod rlen of the output. Negative shifts are
+	 * normalized into [0, rlen), i.e., they roll rows upwards.
+	 *
+	 * @param in    input matrix block
+	 * @param out   output matrix block
+	 * @param shift number of row positions to roll by (may be negative)
+	 * @return the output matrix block
+	 */
+	public static MatrixBlock roll(MatrixBlock in, MatrixBlock out, int shift) {
+		//sparse-safe operation; NOTE(review): assumes the caller already reset
+		//out to an empty block with the input's dimensions
+		if (in.isEmptyBlock(false))
+			return out;
+
+		//normalize shift into [0, rlen); Java's % keeps the sign of negative
+		//shifts, which would produce invalid (negative) target row offsets
+		shift = Math.floorMod(shift, Math.max(in.rlen, 1));
+
+		//special cases: row vector or shift by a multiple of rlen (identity)
+		if (in.rlen == 1 || shift == 0) {
+			out.copy(in);
+			return out;
+		}
+
+		if (in.sparse)
+			rollSparse(in, out, shift);
+		else
+			rollDense(in, out, shift);
+
+		return out;
+	}
+
public static MatrixBlock diag( MatrixBlock in, MatrixBlock out ) {
//Timing time = new Timing(true);
@@ -957,7 +982,10 @@ public class LibMatrixReorg {
else if( op.fn instanceof RevIndex ) //rev
return ReorgType.REV;
-
+
+ else if( op.fn instanceof RollIndex ) //roll
+ return ReorgType.ROLL;
+
else if( op.fn instanceof DiagIndex ) //diag
return ReorgType.DIAG;
@@ -2243,7 +2271,83 @@ public class LibMatrixReorg {
if( !a.isEmpty(i) )
c.set(m-1-i, a.get(i), true);
}
-
+
+	/**
+	 * Dense roll of rows: row i of the input is copied to row
+	 * (i + shift) mod m of the output.
+	 *
+	 * @param in    dense input matrix block
+	 * @param out   output matrix block, allocated here as dense
+	 * @param shift number of row positions to roll by (may be negative)
+	 */
+	private static void rollDense(MatrixBlock in, MatrixBlock out, int shift) {
+		final int m = in.rlen;
+		final int n = in.clen;
+
+		//set basic meta data and allocate output
+		out.sparse = false;
+		out.nonZeros = in.nonZeros;
+		out.allocateDenseBlock(false);
+
+		//normalize shift into [0, m): floorMod (unlike %) maps negative shifts
+		//to valid offsets instead of negative array/row positions
+		shift = Math.floorMod(shift, Math.max(m, 1));
+
+		//copy all rows into target positions
+		if (n == 1) { //column vector: values are contiguous, two bulk copies
+			double[] a = in.getDenseBlockValues();
+			double[] c = out.getDenseBlockValues();
+			System.arraycopy(a, 0, c, shift, m - shift);
+			System.arraycopy(a, m - shift, c, 0, shift);
+		} else { //general matrix case, roll rows (axis=0)
+			DenseBlock a = in.getDenseBlock();
+			DenseBlock c = out.getDenseBlock();
+			//rows [0, m-shift) move down by shift
+			for (int i = 0; i < m - shift; i++)
+				System.arraycopy(a.values(i), a.pos(i), c.values(i + shift), c.pos(i + shift), n);
+			//rows [m-shift, m) wrap around to the top
+			for (int i = m - shift; i < m; i++)
+				System.arraycopy(a.values(i), a.pos(i), c.values(i + shift - m), c.pos(i + shift - m), n);
+		}
+	}
+
+	/**
+	 * Sparse roll of rows: non-empty row i of the input is copied to row
+	 * (i + shift) mod m of the output; empty rows are skipped.
+	 *
+	 * @param in    sparse input matrix block
+	 * @param out   output matrix block, allocated here as sparse
+	 * @param shift number of row positions to roll by (may be negative)
+	 */
+	private static void rollSparse(MatrixBlock in, MatrixBlock out, int shift) {
+		final int m = in.rlen;
+
+		//set basic meta data and allocate output
+		out.sparse = true;
+		out.nonZeros = in.nonZeros;
+		out.allocateSparseRowsBlock(false);
+
+		SparseBlock a = in.getSparseBlock();
+		SparseBlock c = out.getSparseBlock();
+
+		//normalize shift into [0, m): floorMod (unlike %) maps negative shifts
+		//to valid offsets instead of negative target rows
+		shift = Math.floorMod(shift, Math.max(m, 1));
+
+		//rows [0, m-shift) move down by shift
+		for (int i = 0; i < m - shift; i++)
+			if (!a.isEmpty(i))
+				rollSparseRow(a, c, i, i + shift);
+
+		//rows [m-shift, m) wrap around to the top
+		for (int i = m - shift; i < m; i++)
+			if (!a.isEmpty(i))
+				rollSparseRow(a, c, i, i + shift - m);
+	}
+
+	/**
+	 * Copies the non-zero cells of row {@code oriIdx} of {@code a} into row
+	 * {@code shiftIdx} of {@code c}, keeping their column positions.
+	 */
+	private static void rollSparseRow(SparseBlock a, SparseBlock c, int oriIdx, int shiftIdx) {
+		final int start = a.pos(oriIdx);
+		final int end = start + a.size(oriIdx);
+		final int[] colIx = a.indexes(oriIdx);
+		final double[] vals = a.values(oriIdx);
+		for (int k = start; k < end; k++)
+			c.set(shiftIdx, colIx[k], vals[k]);
+	}
+
/**
* Generic implementation diagV2M
* (in most-likely DENSE, out most likely SPARSE)
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 62f9a1febb..7da67d267b 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -87,6 +87,7 @@ import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.functionobjects.RevIndex;
+import org.apache.sysds.runtime.functionobjects.RollIndex;
import org.apache.sysds.runtime.functionobjects.SortIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -3565,7 +3566,8 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret,
int startRow, int startColumn, int length)
{
if ( !( op.fn instanceof SwapIndex || op.fn instanceof
DiagIndex
- || op.fn instanceof SortIndex || op.fn instanceof
RevIndex ) )
+ || op.fn instanceof SortIndex || op.fn instanceof
RevIndex
+ || op.fn instanceof RollIndex) )
throw new DMLRuntimeException("the current
reorgOperations cannot support: "+op.fn.getClass()+".");
MatrixBlock result = checkType(ret);
diff --git a/src/main/python/systemds/operator/nodes/matrix.py
b/src/main/python/systemds/operator/nodes/matrix.py
index a96f8d884c..41bb481da5 100644
--- a/src/main/python/systemds/operator/nodes/matrix.py
+++ b/src/main/python/systemds/operator/nodes/matrix.py
@@ -664,6 +664,13 @@ class Matrix(OperationNode):
"""
return Matrix(self.sds_context, "rev", [self])
+ def roll(self, shift: int) -> "Matrix":
+ """Reverses the rows
+
+ :return: the OperationNode representing this operation
+ """
+ return Matrix(self.sds_context, "roll", [self, shift])
+
def round(self) -> "Matrix":
"""round all values to nearest natural number
diff --git a/src/main/python/tests/matrix/test_roll.py
b/src/main/python/tests/matrix/test_roll.py
new file mode 100644
index 0000000000..1355f24082
--- /dev/null
+++ b/src/main/python/tests/matrix/test_roll.py
@@ -0,0 +1,65 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import unittest
+import numpy as np
+import random
+from scipy import sparse
+from systemds.context import SystemDSContext
+
+np.random.seed(7)
+shape = (random.randrange(1, 25), random.randrange(1, 25))
+
+m = np.random.rand(shape[0], shape[1])
+my = np.random.rand(shape[0], 1)
+m_empty = np.asarray([[]])
+m_sparse = sparse.random(shape[0], shape[1], density=0.1,
format="csr").toarray()
+m_sparse = np.around(m_sparse, decimals=22)
+
+class TestRoll(unittest.TestCase):
+ sds: SystemDSContext = None
+
+ @classmethod
+ def setUpClass(cls):
+ cls.sds = SystemDSContext()
+
+ @classmethod
+ def tearDownClass(cls):
+ cls.sds.close()
+
+ def test_empty(self):
+ r = self.sds.from_numpy(np.asarray(m_empty)).roll(1).compute()
+ self.assertTrue(np.allclose(r, m_empty))
+
+ def test_col_vec(self):
+ r = self.sds.from_numpy(my).roll(1).compute()
+ self.assertTrue(np.allclose(r, np.roll(my, axis=None, shift=1)))
+
+ def test_basic(self):
+ r = self.sds.from_numpy(m).roll(1).compute()
+ self.assertTrue(np.allclose(r, np.roll(m, axis=0, shift=1)))
+
+ def test_sparse_matrix(self):
+ r = self.sds.from_numpy(m_sparse).roll(1).compute()
+ self.assertTrue(np.allclose(r, np.roll(m_sparse, axis=0, shift=1)))
+
+if __name__ == "__main__":
+ unittest.main(exit=False)
diff --git
a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/RollTest.java
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/RollTest.java
new file mode 100644
index 0000000000..17ba8fb068
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/RollTest.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.matrix.libMatrixReorg;
+
+import org.apache.sysds.runtime.functionobjects.IndexFunction;
+import org.apache.sysds.runtime.functionobjects.RollIndex;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import static org.junit.Assert.fail;
+
+/**
+ * Component tests for the roll reorg operation of MatrixBlock.
+ * <p>
+ * Every parameter combination rolls the same data once as generated and once
+ * as a forced dense copy, and asserts that both results are identical.
+ */
+@RunWith(Parameterized.class)
+public class RollTest {
+	// roll amount applied to both inputs
+	private final int shift;
+
+	// block as produced by the generator (NOTE(review): its format depends on
+	// the sparsity, so it is not necessarily sparse)
+	private final MatrixBlock inputSparse;
+	// dense copy of the same values
+	private final MatrixBlock inputDense;
+
+	/**
+	 * Creates one parameterized test case.
+	 *
+	 * @param rows     number of rows of the test matrix
+	 * @param cols     number of columns of the test matrix
+	 * @param sparsity sparsity of the generated data (0.0 to 1.0)
+	 * @param shift    shift value for the roll operation
+	 */
+	public RollTest(int rows, int cols, double sparsity, int shift) {
+		this.shift = shift;
+
+		MatrixBlock generated = TestUtils.generateTestMatrixBlock(rows, cols, 0, 10, sparsity, 1);
+		generated.recomputeNonZeros();
+		inputSparse = generated;
+
+		// copy with sparse=false forces a dense representation
+		MatrixBlock dense = new MatrixBlock(rows, cols, false);
+		dense.copy(generated, false);
+		dense.recomputeNonZeros();
+		inputDense = dense;
+	}
+
+	/**
+	 * Builds the cartesian product of sizes, sparsities, and shifts.
+	 * Each entry is {rows, cols, sparsity, shift}.
+	 *
+	 * @return collection of test parameters
+	 */
+	@Parameters(name = "Rows: {0}, Cols: {1}, Sparsity: {2}, Shift: {3}")
+	public static Collection<Object[]> data() {
+		final int[] rowSizes = {1, 19, 1001, 2017};
+		final int[] colSizes = {1, 17, 1001, 2017};
+		final double[] sparsityLevels = {0.01, 0.1, 0.7, 1.0};
+		final int[] shiftValues = {0, 1, 5, 10, 15};
+
+		// order: rows > cols > sparsity > shift
+		List<Object[]> cases = new ArrayList<>(
+			rowSizes.length * colSizes.length * sparsityLevels.length * shiftValues.length);
+		for (int r = 0; r < rowSizes.length; r++)
+			for (int c = 0; c < colSizes.length; c++)
+				for (int s = 0; s < sparsityLevels.length; s++)
+					for (int k = 0; k < shiftValues.length; k++)
+						cases.add(new Object[] {rowSizes[r], colSizes[c], sparsityLevels[s], shiftValues[k]});
+		return cases;
+	}
+
+	/**
+	 * Rolls both inputs and compares the results in dense representation.
+	 */
+	@Test
+	public void test() {
+		try {
+			MatrixBlock outputDense = rollCopy(inputDense);
+			MatrixBlock outputSparse = rollCopy(inputSparse);
+			outputSparse.sparseToDense();
+
+			TestUtils.compareMatrices(outputSparse, outputDense, 1e-9, "Compare Sparse and Dense Roll Results");
+		}
+		catch (Exception e) {
+			e.printStackTrace();
+			fail("Exception occurred during roll function test: " + e.getMessage());
+		}
+	}
+
+	/** Rolls the given block by this case's shift into a fresh output block. */
+	private MatrixBlock rollCopy(MatrixBlock in) {
+		IndexFunction fn = new RollIndex(shift);
+		return in.reorgOperations(new ReorgOperator(fn), new MatrixBlock(), 0, 0, 0);
+	}
+}