This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new edfce10724 [SYSTEMDS-3729] New roll reorg operations in CP, incl tests
edfce10724 is described below

commit edfce107242213c1e9199c270ecd07d41fe2ac67
Author: min-guk <koreacho...@gmail.com>
AuthorDate: Mon Sep 23 09:50:43 2024 +0200

    [SYSTEMDS-3729] New roll reorg operations in CP, incl tests
    
    Closes #2103.
---
 .github/workflows/javaTests.yml                    |   1 +
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 src/main/java/org/apache/sysds/common/Types.java   |   2 +-
 src/main/java/org/apache/sysds/hops/ReorgOp.java   |  22 +++-
 src/main/java/org/apache/sysds/lops/Transform.java |  11 +-
 .../sysds/parser/BuiltinFunctionExpression.java    |  12 +-
 .../org/apache/sysds/parser/DMLTranslator.java     |   8 ++
 .../sysds/runtime/functionobjects/RollIndex.java   |  65 +++++++++++
 .../runtime/instructions/CPInstructionParser.java  |   1 +
 .../instructions/cp/ReorgCPInstruction.java        |  80 +++++++++-----
 .../sysds/runtime/lineage/LineageCacheConfig.java  |   2 +-
 .../sysds/runtime/matrix/data/LibMatrixReorg.java  | 108 +++++++++++++++++-
 .../sysds/runtime/matrix/data/MatrixBlock.java     |   4 +-
 src/main/python/systemds/operator/nodes/matrix.py  |   7 ++
 src/main/python/tests/matrix/test_roll.py          |  65 +++++++++++
 .../component/matrix/libMatrixReorg/RollTest.java  | 123 +++++++++++++++++++++
 16 files changed, 473 insertions(+), 39 deletions(-)

diff --git a/.github/workflows/javaTests.yml b/.github/workflows/javaTests.yml
index 366e20398e..595298b8f4 100644
--- a/.github/workflows/javaTests.yml
+++ b/.github/workflows/javaTests.yml
@@ -177,3 +177,4 @@ jobs:
         name: Java Code Coverage (Jacoco)
         path: target/site/jacoco
         retention-days: 3
+
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index d021597d7f..98f92ae55e 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -282,6 +282,7 @@ public enum Builtins {
        RCM("rowClassMeet", "rcm", false, false, ReturnType.MULTI_RETURN),
        REMOVE("remove", false, ReturnType.MULTI_RETURN),
        REV("rev", false),
+       ROLL("roll", false),
        ROUND("round", false),
        ROW_COUNT_DISTINCT("rowCountDistinct",false),
        ROWINDEXMAX("rowIndexMax", false),
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index 6d64cbae54..e7274b25c4 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -749,7 +749,7 @@ public interface Types {
        /** Operations that perform internal reorganization of an allocation */
        public enum ReOrgOp {
                DIAG, //DIAG_V2M and DIAG_M2V could not be distinguished if 
sizes unknown
-               RESHAPE, REV, SORT, TRANS;
+               RESHAPE, REV, ROLL, SORT, TRANS;
                
                @Override
                public String toString() {
diff --git a/src/main/java/org/apache/sysds/hops/ReorgOp.java 
b/src/main/java/org/apache/sysds/hops/ReorgOp.java
index 0dce06964c..d7629eab10 100644
--- a/src/main/java/org/apache/sysds/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysds/hops/ReorgOp.java
@@ -90,6 +90,9 @@ public class ReorgOp extends MultiThreadedHop
                        case REV:
                                HopsException.check(sz == 1, this, "should have 
arity 1 for op %s but has arity %d", _op, sz);
                                break;
+                       case ROLL:
+                               HopsException.check(sz == 2, this, "should have 
arity 2 for op %s but has arity %d", _op, sz);
+                               break;
                        case RESHAPE:
                        case SORT:
                                HopsException.check(sz == 5, this, "should have 
arity 5 for op %s but has arity %d", _op, sz);
@@ -125,6 +128,7 @@ public class ReorgOp extends MultiThreadedHop
                        }
                        case DIAG:
                        case REV:
+                       case ROLL:
                        case SORT:
                                return false;
                        default:
@@ -175,6 +179,18 @@ public class ReorgOp extends MultiThreadedHop
                                setLops(transform1);
                                break;
                        }
+                       case ROLL: {
+                               Lop[] linputs = new Lop[2]; //input, shift
+                               for (int i = 0; i < 2; i++)
+                                       linputs[i] = 
getInput().get(i).constructLops();
+
+                               Transform transform1 = new Transform(linputs, 
_op, getDataType(), getValueType(), et, 1);
+
+                               setOutputDimensions(transform1);
+                               setLineNumbers(transform1);
+                               setLops(transform1);
+                               break;
+                       }
                        case RESHAPE: {
                                Lop[] linputs = new Lop[5]; //main, rows, cols, 
dims, byrow
                                for (int i = 0; i < 5; i++)
@@ -279,9 +295,10 @@ public class ReorgOp extends MultiThreadedHop
                                        ret = new 
MatrixCharacteristics(dc.getCols(), dc.getRows(), -1, dc.getNonZeros());
                                break;
                        }
-                       case REV: {
+                       case REV:
+                       case ROLL: {
                                // dims and nnz are exactly the same as in input
-                               if( dc.dimsKnown() )
+                               if (dc.dimsKnown())
                                        ret = new 
MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, dc.getNonZeros());
                                break;
                        }
@@ -397,6 +414,7 @@ public class ReorgOp extends MultiThreadedHop
                                break;
                        }
                        case REV:
+                       case ROLL:
                        {
                                // dims and nnz are exactly the same as in input
                                setDim1(input1.getDim1());
diff --git a/src/main/java/org/apache/sysds/lops/Transform.java 
b/src/main/java/org/apache/sysds/lops/Transform.java
index 0fcdc09fbf..8ef5925572 100644
--- a/src/main/java/org/apache/sysds/lops/Transform.java
+++ b/src/main/java/org/apache/sysds/lops/Transform.java
@@ -111,7 +111,10 @@ public class Transform extends Lop
                        case REV:
                                // Transpose a matrix
                                return "rev";
-                       
+
+                       case ROLL:
+                               return "roll";
+
                        case DIAG:
                                // Transform a vector into a diagonal matrix
                                return "rdiag";
@@ -138,6 +141,12 @@ public class Transform extends Lop
                return getInstructions(input1, 1, output);
        }
 
+       /**
+        * Two-operand variant of instruction generation, used by the roll opcode
+        * (matrix input plus scalar shift). It delegates to the numbered-input
+        * variant with an input count of two.
+        */
+       @Override
+       public String getInstructions(String input1, String input2, String output) {
+               final int numInputs = 2; //roll: input matrix, shift
+               return getInstructions(input1, numInputs, output);
+       }
+
        @Override
        public String getInstructions(String input1, String input2, String 
input3, String input4, String output) {
                //opcodes: rsort
diff --git 
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index de10e090f5..1de3442dd9 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -1266,7 +1266,17 @@ public class BuiltinFunctionExpression extends 
DataIdentifier {
                        output.setBlocksize (id.getBlocksize());
                        output.setValueType(id.getValueType());
                        break;
-                       
+
+               case ROLL:
+                       checkNumParameters(2);
+                       checkMatrixParam(getFirstExpr());
+                       checkScalarParam(getSecondExpr());
+                       output.setDataType(DataType.MATRIX);
+                       output.setDimensions(id.getDim1(), id.getDim2());
+                       output.setBlocksize(id.getBlocksize());
+                       output.setValueType(id.getValueType());
+                       break;
+
                case DIAG:
                        checkNumParameters(1);
                        checkMatrixParam(getFirstExpr());
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 77ed904821..b76425668f 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2481,6 +2481,14 @@ public class DMLTranslator
                                target.getValueType(), 
ReOrgOp.valueOf(source.getOpCode().name()), expr);
                        break;
 
+               case ROLL:
+                       ArrayList<Hop> inputs = new ArrayList<>();
+                       inputs.add(expr);
+                       inputs.add(expr2);
+                       currBuiltinOp = new ReorgOp(target.getName(), 
DataType.MATRIX,
+                                       target.getValueType(), 
ReOrgOp.valueOf(source.getOpCode().name()), inputs);
+                       break;
+
                case CBIND:
                case RBIND:
                        OpOp2 appendOp2 = (source.getOpCode()==Builtins.CBIND) 
? OpOp2.CBIND : OpOp2.RBIND;
diff --git 
a/src/main/java/org/apache/sysds/runtime/functionobjects/RollIndex.java 
b/src/main/java/org/apache/sysds/runtime/functionobjects/RollIndex.java
new file mode 100644
index 0000000000..5bd78bb703
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/functionobjects/RollIndex.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.functionobjects;
+
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
+
+/**
+ * Index function for the roll reorg operation (circular shift of the rows of a
+ * matrix by a given number of positions).
+ * <p>
+ * This index function is NOT used to compute individual cell positions (both
+ * {@code execute} methods throw {@link NotImplementedException}). It is only
+ * carried by a ReorgOperator so that the reorg dispatch can identify roll
+ * operations and obtain the shift amount via {@link #getShift()}.
+ */
+public class RollIndex extends IndexFunction {
+    private static final long serialVersionUID = -8446389232078905200L;
+
+    /** number of row positions to roll by */
+    private final int _shift;
+
+    /**
+     * @param shift number of row positions to roll by
+     */
+    public RollIndex(int shift) {
+        _shift = shift;
+    }
+
+    /**
+     * @return the number of row positions to roll by
+     */
+    public int getShift() {
+        return _shift;
+    }
+
+    /**
+     * Roll only moves rows, so the output dimensions equal the input dimensions.
+     */
+    @Override
+    public boolean computeDimension(int row, int col, CellIndex retDim) {
+        retDim.set(row, col);
+        return false;
+    }
+
+    /**
+     * Roll only moves rows, so rows, columns, blocksize and non-zeros are all
+     * taken over unchanged from the input characteristics.
+     */
+    @Override
+    public boolean computeDimension(DataCharacteristics in, DataCharacteristics out) {
+        out.set(in.getRows(), in.getCols(), in.getBlocksize(), in.getNonZeros());
+        return false;
+    }
+
+    /** Not supported: roll is executed as a whole-block reorg, not per index. */
+    @Override
+    public void execute(MatrixIndexes in, MatrixIndexes out) {
+        throw new NotImplementedException();
+    }
+
+    /** Not supported: roll is executed as a whole-block reorg, not per cell. */
+    @Override
+    public void execute(CellIndex in, CellIndex out) {
+        throw new NotImplementedException();
+    }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 7a5ed3524c..a0270f6b20 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -271,6 +271,7 @@ public class CPInstructionParser extends InstructionParser {
                // Reorg Instruction Opcodes (repositioning of existing values)
                String2CPInstructionType.put( "r'"          , CPType.Reorg);
                String2CPInstructionType.put( "rev"         , CPType.Reorg);
+               String2CPInstructionType.put( "roll"         , CPType.Reorg);
                String2CPInstructionType.put( "rdiag"       , CPType.Reorg);
                String2CPInstructionType.put( "rshape"      , CPType.Reshape);
                String2CPInstructionType.put( "rsort"       , CPType.Reorg);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
index 18f5613c8b..e7b3000d52 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ReorgCPInstruction.java
@@ -25,6 +25,7 @@ import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.functionobjects.DiagIndex;
 import org.apache.sysds.runtime.functionobjects.RevIndex;
+import org.apache.sysds.runtime.functionobjects.RollIndex;
 import org.apache.sysds.runtime.functionobjects.SortIndex;
 import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -38,20 +39,16 @@ public class ReorgCPInstruction extends UnaryCPInstruction {
        private final CPOperand _col;
        private final CPOperand _desc;
        private final CPOperand _ixret;
+       private final CPOperand _shift;
 
        /**
         * for opcodes r' and rdiag
-        * 
-        * @param op
-        *            operator
-        * @param in
-        *            cp input operand
-        * @param out
-        *            cp output operand
-        * @param opcode
-        *            the opcode
-        * @param istr
-        *            ?
+        *
+        * @param op     operator
+        * @param in     cp input operand
+        * @param out    cp output operand
+        * @param opcode the opcode
+        * @param istr   ?
         */
        private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, 
String opcode, String istr) {
                this(op, in, out, null, null, null, opcode, istr);
@@ -59,30 +56,41 @@ public class ReorgCPInstruction extends UnaryCPInstruction {
 
        /**
         * for opcode rsort
-        * 
-        * @param op
-        *            operator
-        * @param in
-        *            cp input operand
-        * @param col
-        *            ?
-        * @param desc
-        *            ?
-        * @param ixret
-        *            ?
-        * @param out
-        *            cp output operand
-        * @param opcode
-        *            the opcode
-        * @param istr
-        *            ?
+        *
+        * @param op     operator
+        * @param in     cp input operand
+        * @param col    ?
+        * @param desc   ?
+        * @param ixret  ?
+        * @param out    cp output operand
+        * @param opcode the opcode
+        * @param istr   ?
         */
        private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, 
CPOperand col, CPOperand desc, CPOperand ixret,
-                       String opcode, String istr) {
+                                                          String opcode, 
String istr) {
                super(CPType.Reorg, op, in, out, opcode, istr);
                _col = col;
                _desc = desc;
                _ixret = ixret;
+               _shift = null;
+       }
+
+       /**
+        * for opcode roll
+        *
+        * @param op     operator; created with a placeholder RollIndex, whose
+        *               actual shift is resolved from {@code shift} at execution time
+        * @param in     cp input operand
+        * @param out    cp output operand
+        * @param shift  scalar cp operand holding the number of rows to roll by
+        * @param opcode the opcode
+        * @param istr   the instruction string
+        */
+       private ReorgCPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
+               super(CPType.Reorg, op, in, out, opcode, istr);
+               //sort-specific operands are unused for roll
+               _col = null;
+               _desc = null;
+               _ixret = null;
+               _shift = shift;
        }
 
        public static ReorgCPInstruction parseInstruction ( String str ) {
@@ -103,6 +111,13 @@ public class ReorgCPInstruction extends UnaryCPInstruction 
{
                        parseUnaryInstruction(str, in, out); //max 2 operands
                        return new ReorgCPInstruction(new 
ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
                }
+               else if (opcode.equalsIgnoreCase("roll")) {
+                       InstructionUtils.checkNumFields(str, 3);
+                       in.split(parts[1]);
+                       out.split(parts[3]);
+                       CPOperand shift = new CPOperand(parts[2]);
+                       return new ReorgCPInstruction(new ReorgOperator(new 
RollIndex(0)), in, out, shift, opcode, str);
+               }
                else if ( opcode.equalsIgnoreCase("rdiag") ) {
                        parseUnaryInstruction(str, in, out); //max 2 operands
                        return new ReorgCPInstruction(new 
ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
@@ -136,7 +151,12 @@ public class ReorgCPInstruction extends UnaryCPInstruction 
{
                        boolean ixret = 
ec.getScalarInput(_ixret).getBooleanValue();
                        r_op = r_op.setFn(new SortIndex(cols, desc, ixret));
                }
-               
+
+               if (r_op.fn instanceof RollIndex) {
+                       int shift = (int) 
ec.getScalarInput(_shift).getLongValue();
+                       r_op = r_op.setFn(new RollIndex(shift));
+               }
+
                //execute operation
                MatrixBlock soresBlock = matBlock.reorgOperations(r_op, new 
MatrixBlock(), 0, 0, 0);
                
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
index 53f517c058..b700410f62 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageCacheConfig.java
@@ -63,7 +63,7 @@ public class LineageCacheConfig
 
        // Relatively expensive instructions. Most include shuffles.
        private static final String[] PERSIST_OPCODES1 = new String[] {
-               "cpmm", "rmm", "pmm", "zipmm", "rev", "rshape", "rsort", "-", 
"*", "+",
+               "cpmm", "rmm", "pmm", "zipmm", "rev", "roll", "rshape", 
"rsort", "-", "*", "+",
                "/", "%%", "%/%", "1-*", "^", "^2", "*2", "==", "!=", "<", ">",
                "<=", ">=", "&&", "||", "xor", "max", "min", "rmempty", 
"rappend",
                "gappend", "galignedappend", "rbind", "cbind", "nmin", "nmax",
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
index 9e1e1cac9e..53e8a88832 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
@@ -50,6 +50,7 @@ import org.apache.sysds.runtime.data.SparseRow;
 import org.apache.sysds.runtime.data.SparseRowVector;
 import org.apache.sysds.runtime.functionobjects.DiagIndex;
 import org.apache.sysds.runtime.functionobjects.RevIndex;
+import org.apache.sysds.runtime.functionobjects.RollIndex;
 import org.apache.sysds.runtime.functionobjects.SortIndex;
 import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
@@ -94,6 +95,7 @@ public class LibMatrixReorg {
        private enum ReorgType {
                TRANSPOSE,
                REV,
+               ROLL,
                DIAG,
                RESHAPE,
                SORT,
@@ -123,6 +125,9 @@ public class LibMatrixReorg {
                                        return transpose(in, out);
                        case REV:
                                return rev(in, out);
+                       case ROLL:
+                               RollIndex rix = (RollIndex) op.fn;
+                               return roll(in, out, rix.getShift());
                        case DIAG:
                                return diag(in, out);
                        case SORT:
@@ -142,6 +147,7 @@ public class LibMatrixReorg {
                        case TRANSPOSE:
                                return transposeInPlace(in, op.getNumThreads());
                        case REV:
+                       case ROLL:
                        case SORT:
                                throw new DMLRuntimeException("Not implemented 
inplace: " + op.fn.getClass().getSimpleName());
                        default:
@@ -420,6 +426,25 @@ public class LibMatrixReorg {
                }
        }
 
+       /**
+        * Rolls (circularly shifts) the rows of the input by the given number of
+        * positions: row i of the input becomes row (i + shift) mod rlen of the output.
+        *
+        * @param in    input matrix block
+        * @param out   output matrix block (dimensions already set by the caller)
+        * @param shift number of row positions; negative values roll upwards
+        * @return the output matrix block
+        */
+       public static MatrixBlock roll(MatrixBlock in, MatrixBlock out, int shift) {
+               //sparse-safe operation
+               if (in.isEmptyBlock(false))
+                       return out;
+
+               //normalize the shift into [0, rlen): Java's % keeps the sign of the
+               //dividend, so a negative shift (roll upwards) must use Math.floorMod
+               //to map it to the equivalent downward shift (e.g., -1 -> rlen-1);
+               //otherwise the helpers would compute negative row offsets
+               final int s = Math.floorMod(shift, Math.max(in.rlen, 1));
+
+               //special case: single row or effective shift of 0 -> plain copy
+               if (in.rlen == 1 || s == 0) {
+                       out.copy(in);
+                       return out;
+               }
+
+               if (in.sparse)
+                       rollSparse(in, out, s);
+               else
+                       rollDense(in, out, s);
+
+               return out;
+       }
+
        public static MatrixBlock diag( MatrixBlock in, MatrixBlock out ) {
                //Timing time = new Timing(true);
                
@@ -957,7 +982,10 @@ public class LibMatrixReorg {
                
                else if( op.fn instanceof RevIndex ) //rev
                        return ReorgType.REV;
-               
+
+               else if( op.fn instanceof RollIndex ) //roll
+                       return ReorgType.ROLL;
+
                else if( op.fn instanceof DiagIndex ) //diag
                        return ReorgType.DIAG;
                
@@ -2243,7 +2271,83 @@ public class LibMatrixReorg {
                        if( !a.isEmpty(i) )
                                c.set(m-1-i, a.get(i), true);
        }
-       
+
+       /**
+        * Dense roll along rows: row i of the input is written to row
+        * (i + shift) mod m of the output.
+        */
+       private static void rollDense(MatrixBlock in, MatrixBlock out, int shift) {
+               final int m = in.rlen;
+               final int n = in.clen;
+
+               //set basic meta data and allocate output
+               out.sparse = false;
+               out.nonZeros = in.nonZeros;
+               out.allocateDenseBlock(false);
+
+               //normalize shift into [0, m); floorMod (unlike %) maps negative
+               //shifts to the equivalent positive one, keeping all offsets >= 0
+               final int s = Math.floorMod(shift, Math.max(m, 1));
+
+               if (n == 1) { //column vector: one contiguous value per row
+                       double[] a = in.getDenseBlockValues();
+                       double[] c = out.getDenseBlockValues();
+
+                       //rows [0, m-s) move down by s, rows [m-s, m) wrap to the top
+                       System.arraycopy(a, 0, c, s, m - s);
+                       System.arraycopy(a, m - s, c, 0, s);
+               } else { //general matrix case: copy row by row
+                       DenseBlock a = in.getDenseBlock();
+                       DenseBlock c = out.getDenseBlock();
+
+                       //rows [0, m-s) move down by s
+                       for (int i = 0; i < m - s; i++)
+                               System.arraycopy(a.values(i), a.pos(i), c.values(i + s), c.pos(i + s), n);
+
+                       //rows [m-s, m) wrap around to the top
+                       for (int i = m - s; i < m; i++)
+                               System.arraycopy(a.values(i), a.pos(i), c.values(i + s - m), c.pos(i + s - m), n);
+               }
+       }
+
+       /**
+        * Sparse roll along rows: the non-zeros of row i of the input are written
+        * to row (i + shift) mod m of the output; empty rows are skipped.
+        */
+       private static void rollSparse(MatrixBlock in, MatrixBlock out, int shift) {
+               final int m = in.rlen;
+
+               //set basic meta data and allocate output
+               out.sparse = true;
+               out.nonZeros = in.nonZeros;
+               out.allocateSparseRowsBlock(false);
+
+               SparseBlock a = in.getSparseBlock();
+               SparseBlock c = out.getSparseBlock();
+
+               //normalize shift into [0, m); floorMod (unlike %) maps negative
+               //shifts to the equivalent positive one, avoiding negative row indexes
+               final int s = Math.floorMod(shift, Math.max(m, 1));
+
+               //rows [0, m-s) move down by s
+               for (int i = 0; i < m - s; i++) {
+                       if (a.isEmpty(i))
+                               continue; //skip empty rows
+                       rollSparseRow(a, c, i, i + s);
+               }
+
+               //rows [m-s, m) wrap around to the top
+               for (int i = m - s; i < m; i++) {
+                       if (a.isEmpty(i))
+                               continue; //skip empty rows
+                       rollSparseRow(a, c, i, i + s - m);
+               }
+       }
+
+       /**
+        * Copies the non-zero cells of source row {@code oriIdx} of {@code a} into
+        * row {@code shiftIdx} of {@code c}, keeping their column positions.
+        */
+       private static void rollSparseRow(SparseBlock a, SparseBlock c, int oriIdx, int shiftIdx) {
+               final int start = a.pos(oriIdx);
+               final int end = start + a.size(oriIdx); //exclusive end position
+               final int[] cols = a.indexes(oriIdx);
+               final double[] vals = a.values(oriIdx);
+
+               for (int k = start; k < end; k++)
+                       c.set(shiftIdx, cols[k], vals[k]);
+       }
+
        /**
         * Generic implementation diagV2M
         * (in most-likely DENSE, out most likely SPARSE)
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 62f9a1febb..7da67d267b 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -87,6 +87,7 @@ import org.apache.sysds.runtime.functionobjects.ReduceAll;
 import org.apache.sysds.runtime.functionobjects.ReduceCol;
 import org.apache.sysds.runtime.functionobjects.ReduceRow;
 import org.apache.sysds.runtime.functionobjects.RevIndex;
+import org.apache.sysds.runtime.functionobjects.RollIndex;
 import org.apache.sysds.runtime.functionobjects.SortIndex;
 import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -3565,7 +3566,8 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
        public MatrixBlock reorgOperations(ReorgOperator op, MatrixValue ret, 
int startRow, int startColumn, int length)
        {
                if ( !( op.fn instanceof SwapIndex || op.fn instanceof 
DiagIndex 
-                       || op.fn instanceof SortIndex || op.fn instanceof 
RevIndex ) )
+                       || op.fn instanceof SortIndex || op.fn instanceof 
RevIndex
+                       || op.fn instanceof RollIndex) )
                        throw new DMLRuntimeException("the current 
reorgOperations cannot support: "+op.fn.getClass()+".");
                
                MatrixBlock result = checkType(ret);
diff --git a/src/main/python/systemds/operator/nodes/matrix.py 
b/src/main/python/systemds/operator/nodes/matrix.py
index a96f8d884c..41bb481da5 100644
--- a/src/main/python/systemds/operator/nodes/matrix.py
+++ b/src/main/python/systemds/operator/nodes/matrix.py
@@ -664,6 +664,13 @@ class Matrix(OperationNode):
         """
         return Matrix(self.sds_context, "rev", [self])
 
+    def roll(self, shift: int) -> "Matrix":
+        """Rolls (circularly shifts) the rows of this matrix by ``shift`` positions,
+        analogous to ``numpy.roll(x, shift, axis=0)``: rows shifted past the last
+        row wrap around to the top.
+
+        :param shift: number of row positions to roll by
+        :return: the OperationNode representing this operation
+        """
+        return Matrix(self.sds_context, "roll", [self, shift])
+
     def round(self) -> "Matrix":
         """round all values to nearest natural number
 
diff --git a/src/main/python/tests/matrix/test_roll.py 
b/src/main/python/tests/matrix/test_roll.py
new file mode 100644
index 0000000000..1355f24082
--- /dev/null
+++ b/src/main/python/tests/matrix/test_roll.py
@@ -0,0 +1,65 @@
+# -------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+# -------------------------------------------------------------
+
+import unittest
+import numpy as np
+import random
+from scipy import sparse
+from systemds.context import SystemDSContext
+
+# Seed both generators: np.random.seed does not affect Python's `random`
+# module, which picks the matrix shape below; without random.seed the shape
+# (and hence the test data) differed between runs.
+np.random.seed(7)
+random.seed(7)
+shape = (random.randrange(1, 25), random.randrange(1, 25))
+
+m = np.random.rand(shape[0], shape[1])
+my = np.random.rand(shape[0], 1)
+m_empty = np.asarray([[]])
+m_sparse = sparse.random(shape[0], shape[1], density=0.1, format="csr").toarray()
+
+
+class TestRoll(unittest.TestCase):
+    """Compares SystemDS ``Matrix.roll`` against ``numpy.roll`` along the rows."""
+    sds: SystemDSContext = None
+
+    @classmethod
+    def setUpClass(cls):
+        cls.sds = SystemDSContext()
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.sds.close()
+
+    def test_empty(self):
+        r = self.sds.from_numpy(np.asarray(m_empty)).roll(1).compute()
+        self.assertTrue(np.allclose(r, m_empty))
+
+    def test_col_vec(self):
+        r = self.sds.from_numpy(my).roll(1).compute()
+        self.assertTrue(np.allclose(r, np.roll(my, axis=None, shift=1)))
+
+    def test_basic(self):
+        r = self.sds.from_numpy(m).roll(1).compute()
+        self.assertTrue(np.allclose(r, np.roll(m, axis=0, shift=1)))
+
+    def test_shift_larger_than_rows(self):
+        # shifts wrap around modulo the number of rows
+        shift = shape[0] + 1
+        r = self.sds.from_numpy(m).roll(shift).compute()
+        self.assertTrue(np.allclose(r, np.roll(m, axis=0, shift=shift)))
+
+    def test_sparse_matrix(self):
+        r = self.sds.from_numpy(m_sparse).roll(1).compute()
+        self.assertTrue(np.allclose(r, np.roll(m_sparse, axis=0, shift=1)))
+
+
+if __name__ == "__main__":
+    unittest.main(exit=False)
diff --git 
a/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/RollTest.java
 
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/RollTest.java
new file mode 100644
index 0000000000..17ba8fb068
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/RollTest.java
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.matrix.libMatrixReorg;
+
+import org.apache.sysds.runtime.functionobjects.IndexFunction;
+import org.apache.sysds.runtime.functionobjects.RollIndex;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import static org.junit.Assert.fail;
+
+/**
+ * Test class for the roll function in MatrixBlock.
+ * <p>
+ * This test verifies that the roll function produces identical results
+ * when applied to both sparse and dense representations of the same matrix.
+ */
+@RunWith(Parameterized.class)
+public class RollTest {
+    private final int shift;
+
+    // The same generated values, stored once in sparse and once in dense format.
+    private final MatrixBlock inputSparse;
+    private final MatrixBlock inputDense;
+
+    /**
+     * Constructor for parameterized test cases.
+     * <p>
+     * Both inputs are copies of one generated block with an explicitly
+     * requested target format. This keeps the sparse code path under test
+     * regardless of which format the generator itself returns.
+     *
+     * @param rows     Number of rows in the test matrix.
+     * @param cols     Number of columns in the test matrix.
+     * @param sparsity Sparsity level of the test matrix (0.0 to 1.0).
+     * @param shift    Shift value for the roll operation.
+     */
+    public RollTest(int rows, int cols, double sparsity, int shift) {
+        this.shift = shift;
+
+        MatrixBlock generated = TestUtils.generateTestMatrixBlock(
+            rows, cols, 0, 10, sparsity, 1);
+
+        // The boolean argument selects the target format (true = sparse).
+        inputSparse = new MatrixBlock(rows, cols, true);
+        inputSparse.copy(generated, true);
+        inputSparse.recomputeNonZeros();
+
+        inputDense = new MatrixBlock(rows, cols, false);
+        inputDense.copy(generated, false);
+        inputDense.recomputeNonZeros();
+    }
+
+    /**
+     * Defines the parameters for the test cases.
+     * Each Object[] contains {rows, cols, sparsity, shift}.
+     *
+     * @return Collection of test parameters.
+     */
+    @Parameters(name = "Rows: {0}, Cols: {1}, Sparsity: {2}, Shift: {3}")
+    public static Collection<Object[]> data() {
+        List<Object[]> tests = new ArrayList<>();
+
+        // Define various sizes, sparsity levels, and shift values
+        int[] rows = {1, 19, 1001, 2017};
+        int[] cols = {1, 17, 1001, 2017};
+        double[] sparsities = {0.01, 0.1, 0.7, 1.0};
+        int[] shifts = {0, 1, 5, 10, 15};
+
+        // Generate all combinations of sizes, sparsities, and shifts
+        for (int row : rows) {
+            for (int col : cols) {
+                for (double sparsity : sparsities) {
+                    for (int shift : shifts) {
+                        tests.add(new Object[]{row, col, sparsity, shift});
+                    }
+                }
+            }
+        }
+        return tests;
+    }
+
+    /**
+     * Rolls the sparse and the dense input by {@code shift} and checks that
+     * the two results agree once both are in dense format.
+     * <p>
+     * Exceptions are not caught here, so JUnit reports them with their full
+     * stack trace instead of only the message.
+     */
+    @Test
+    public void test() throws Exception {
+        IndexFunction op = new RollIndex(shift);
+        MatrixBlock outputDense = inputDense.reorgOperations(
+            new ReorgOperator(op), new MatrixBlock(), 0, 0, 0);
+        MatrixBlock outputSparse = inputSparse.reorgOperations(
+            new ReorgOperator(op), new MatrixBlock(), 0, 0, 0);
+        outputSparse.sparseToDense();
+
+        // Compare the dense representations of both outputs
+        TestUtils.compareMatrices(outputSparse, outputDense, 1e-9,
+            "Compare Sparse and Dense Roll Results");
+    }
+}

Reply via email to