This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 0825f0fab7 [SYSTEMDS-3729] Add roll reorg operations in SP
0825f0fab7 is described below

commit 0825f0fab741250ee3d5b3e0cd78ff62641cd250
Author: min-guk <[email protected]>
AuthorDate: Wed Sep 25 09:38:55 2024 +0200

    [SYSTEMDS-3729] Add roll reorg operations in SP
    
    Closes #2112.
---
 .../java/org/apache/sysds/hops/AggUnaryOp.java     |   1 -
 .../runtime/instructions/SPInstructionParser.java  |   1 +
 .../instructions/spark/ReorgSPInstruction.java     |  53 +++++++-
 .../sysds/runtime/matrix/data/LibMatrixReorg.java  | 138 +++++++++++++--------
 .../sysds/test/functions/reorg/FullRollTest.java   | 121 ++++++++++++++++++
 src/test/scripts/functions/reorg/Roll1.dml         |  25 ++++
 6 files changed, 286 insertions(+), 53 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java 
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index eec86ec15b..954caa7a2e 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -31,7 +31,6 @@ import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
 import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.common.Types.ExecType;
-import org.apache.sysds.lops.Nary;
 import org.apache.sysds.lops.PartialAggregate;
 import org.apache.sysds.lops.TernaryAggregate;
 import org.apache.sysds.lops.UAggOuterChain;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 5e4dbaedeb..5c72b85436 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -166,6 +166,7 @@ public class SPInstructionParser extends InstructionParser
                // Reorg Instruction Opcodes (repositioning of existing values)
                String2SPInstructionType.put( "r'",       SPType.Reorg);
                String2SPInstructionType.put( "rev",      SPType.Reorg);
+               String2SPInstructionType.put( "roll",      SPType.Reorg);
                String2SPInstructionType.put( "rdiag",    SPType.Reorg);
                String2SPInstructionType.put( "rshape",   SPType.MatrixReshape);
                String2SPInstructionType.put( "rsort",    SPType.Reorg);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
index de01a71ca8..b096405959 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
@@ -36,6 +36,7 @@ import 
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.functionobjects.DiagIndex;
 import org.apache.sysds.runtime.functionobjects.RevIndex;
+import org.apache.sysds.runtime.functionobjects.RollIndex;
 import org.apache.sysds.runtime.functionobjects.SortIndex;
 import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -68,6 +69,7 @@ public class ReorgSPInstruction extends UnarySPInstruction {
        private CPOperand _desc = null;
        private CPOperand _ixret = null;
        private boolean _bSortIndInMem = false;
+       private CPOperand _shift = null;
 
        private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, 
String opcode, String istr) {
                super(SPType.Reorg, op, in, out, opcode, istr);
@@ -82,6 +84,11 @@ public class ReorgSPInstruction extends UnarySPInstruction {
                _bSortIndInMem = bSortIndInMem;
        }
 
+       /**
+        * Creates a reorg instruction that additionally carries a shift operand;
+        * used by the roll branch of parseInstruction.
+        *
+        * @param op     reorg operator
+        * @param in     input operand
+        * @param out    output operand
+        * @param shift  operand stored in _shift
+        * @param opcode instruction opcode
+        * @param istr   instruction string
+        */
+       private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, 
CPOperand shift, String opcode, String istr) {
+               this(op, in, out, opcode, istr);
+               _shift = shift;
+       }
+
        public static ReorgSPInstruction parseInstruction ( String str ) {
                CPOperand in = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
                CPOperand out = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
@@ -95,6 +102,15 @@ public class ReorgSPInstruction extends UnarySPInstruction {
                        parseUnaryInstruction(str, in, out); //max 2 operands
                        return new ReorgSPInstruction(new 
ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
                }
+               else if (opcode.equalsIgnoreCase("roll")) {
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               InstructionUtils.checkNumFields(str, 3);
+               in.split(parts[1]);
+               out.split(parts[3]);
+               CPOperand shift = new CPOperand(parts[2]);
+               return new ReorgSPInstruction(new ReorgOperator(new 
RollIndex(0)),
+                               in, out, shift, opcode, str);
+}
                else if ( opcode.equalsIgnoreCase("rdiag") ) {
                        parseUnaryInstruction(str, in, out); //max 2 operands
                        return new ReorgSPInstruction(new 
ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
@@ -141,6 +157,14 @@ public class ReorgSPInstruction extends UnarySPInstruction 
{
                        if( mcIn.getRows() % mcIn.getBlocksize() != 0 )
                                out = RDDAggregateUtils.mergeByKey(out, false);
                }
+               else if (opcode.equalsIgnoreCase("roll")) // ROLL
+               {
+                       int shift = (int) 
ec.getScalarInput(_shift).getLongValue();
+
+                       //execute roll reorg operation
+                       out = in1.flatMapToPair(new RDDRollFunction(mcIn, 
shift));
+                       out = RDDAggregateUtils.mergeByKey(out, false);
+               }
                else if ( opcode.equalsIgnoreCase("rdiag") ) // DIAG
                {       
                        if(mcIn.getCols() == 1) { // diagV2M
@@ -233,7 +257,7 @@ public class ReorgSPInstruction extends UnarySPInstruction {
                                boolean ixret = 
sec.getScalarInput(_ixret).getBooleanValue();
                                mcOut.set(mc1.getRows(), ixret?1:mc1.getCols(), 
mc1.getBlocksize(), mc1.getBlocksize());
                        }
-                       else { //e.g., rev
+                       else { //e.g., rev, roll
                                mcOut.set(mc1);
                        }
                }
@@ -243,7 +267,7 @@ public class ReorgSPInstruction extends UnarySPInstruction {
                        boolean sortIx = getOpcode().equalsIgnoreCase("rsort") 
&& sec.getScalarInput(_ixret).getBooleanValue();                 
                        if( sortIx )
                                mcOut.setNonZeros(mc1.getRows());
-                       else //default (r', rdiag, rev, rsort data)
+                       else //default (r', rdiag, rev, roll, rsort data)
                                mcOut.setNonZeros(mc1.getNonZeros());
                }
        }
@@ -315,6 +339,31 @@ public class ReorgSPInstruction extends UnarySPInstruction 
{
                }
        }
 
+       /**
+        * Flat-map function that applies the block-wise roll to one indexed matrix
+        * block and returns every output block produced by LibMatrixReorg.roll.
+        */
+       private static class RDDRollFunction
+               implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock>
+       {
+               private static final long serialVersionUID = 1183373828539843938L;
+
+               // characteristics of the input matrix and the requested shift
+               private DataCharacteristics _mcIn;
+               private int _shift;
+
+               public RDDRollFunction(DataCharacteristics mcIn, int shift) {
+                       this._mcIn = mcIn;
+                       this._shift = shift;
+               }
+
+               @Override
+               public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) {
+                       // wrap the incoming key-value pair as an indexed block
+                       final IndexedMatrixValue block = SparkUtils.toIndexedMatrixBlock(arg0);
+
+                       // roll the block; the partial output blocks are appended to the list
+                       final ArrayList<IndexedMatrixValue> rolled = new ArrayList<>();
+                       final long numRows = _mcIn.getRows();
+                       final int blocksize = _mcIn.getBlocksize();
+                       LibMatrixReorg.roll(block, numRows, blocksize, _shift, rolled);
+
+                       // convert the output blocks back into key-value pairs
+                       return SparkUtils.fromIndexedMatrixBlock(rolled).iterator();
+               }
+       }
+
        private static class ExtractColumn implements Function<MatrixBlock, 
MatrixBlock>  
        {
                private static final long serialVersionUID = 
-1472164797288449559L;
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
index 53e8a88832..82defddca8 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
@@ -445,6 +445,36 @@ public class LibMatrixReorg {
                return out;
        }
 
+       /**
+        * Rolls one block of a row-blocked matrix along the row dimension and
+        * appends the resulting partial output blocks to {@code out}.
+        *
+        * The rows of the input block are moved to global row
+        * {@code (row + shift) mod rlen}. Because the shifted range may straddle
+        * output block boundaries (and wrap around the end of the matrix), a single
+        * input block can produce several output blocks, each of which only holds
+        * the rows contributed by this input block.
+        *
+        * @param in    indexed input block (block indexes and matrix block)
+        * @param rlen  total number of rows of the overall matrix
+        * @param blen  block size
+        * @param shift number of rows to roll by; values outside {@code [0, rlen)},
+        *              including negative ones, are normalized into that range
+        * @param out   list receiving the partial output blocks
+        */
+       public static void roll(IndexedMatrixValue in, long rlen, int blen, int shift, ArrayList<IndexedMatrixValue> out) {
+               MatrixIndexes inMtxIdx = in.getIndexes();
+               MatrixBlock inMtxBlk = (MatrixBlock) in.getValue();
+
+               // normalize the shift into [0, rlen) using long arithmetic: casting
+               // rlen to int could truncate, and a plain % keeps a negative sign
+               final long normShift = (rlen != 0) ? Math.floorMod((long) shift, rlen) : 0;
+
+               // 0-based global index of the next input row to be copied
+               long inRowIdx = UtilFunctions.computeCellIndex(inMtxIdx.getRowIndex(), blen, 0) - 1;
+
+               int totalCopyLen = 0;
+               while (totalCopyLen < inMtxBlk.getNumRows()) {
+                       // Calculate row and block index for the current part
+                       long outRowIdx = (inRowIdx + normShift) % rlen;
+                       long outBlkIdx = UtilFunctions.computeBlockIndex(outRowIdx + 1, blen);
+                       int outBlkLen = UtilFunctions.computeBlockSize(rlen, outBlkIdx, blen);
+                       int outRowIdxInBlk = (int) (outRowIdx % blen);
+
+                       // Copy at most up to the end of the output block or of the input block
+                       int copyLen = Math.min(outBlkLen - outRowIdxInBlk, inMtxBlk.getNumRows() - totalCopyLen);
+
+                       // Create the output block and copy data
+                       // NOTE(review): copyTotalNonZeros=false, so nonZeros of the output
+                       // block is not set here - presumably handled downstream; verify.
+                       MatrixIndexes outMtxIdx = new MatrixIndexes(outBlkIdx, inMtxIdx.getColumnIndex());
+                       MatrixBlock outMtxBlk = new MatrixBlock(outBlkLen, inMtxBlk.getNumColumns(), inMtxBlk.isInSparseFormat());
+                       copyMtx(inMtxBlk, outMtxBlk, totalCopyLen, outRowIdxInBlk, copyLen, false, false);
+                       out.add(new IndexedMatrixValue(outMtxIdx, outMtxBlk));
+
+                       // Advance both counters by the rows copied in THIS iteration only;
+                       // adding the cumulative total to inRowIdx would misplace the third
+                       // and later parts of a block that is split more than twice
+                       totalCopyLen += copyLen;
+                       inRowIdx += copyLen;
+               }
+       }
+
        public static MatrixBlock diag( MatrixBlock in, MatrixBlock out ) {
                //Timing time = new Timing(true);
                
@@ -2274,77 +2304,85 @@ public class LibMatrixReorg {
 
+       /**
+        * Rolls a dense matrix block along the rows via two range copies: the
+        * leading m-shift rows move down by shift, and the trailing shift rows
+        * wrap around to the top of the output.
+        */
        private static void rollDense(MatrixBlock in, MatrixBlock out, int 
shift) {
                final int m = in.rlen;
-               final int n = in.clen;
+               shift %= (m != 0 ? m : 1); // roll matrix with axis=none
 
-               //set basic meta data and allocate output
-               out.sparse = false;
-               out.nonZeros = in.nonZeros;
-               out.allocateDenseBlock(false);
+               // first copy allocates the output and sets its meta data
+               copyDenseMtx(in, out, 0, shift, m - shift, false, true);
+               // second copy reuses the already allocated output (isAllocated=true)
+               copyDenseMtx(in, out, m - shift, 0, shift, true, true);
+       }
 
-               //copy all rows into target positions
-               if (n == 1) { //column vector
+       /**
+        * Rolls a sparse matrix block along the rows via two range copies: the
+        * leading m-shift rows move down by shift, and the trailing shift rows
+        * wrap around to the top of the output.
+        *
+        * @param in    sparse input block
+        * @param out   output block, allocated in sparse format by the first copy
+        * @param shift number of rows to roll by (taken modulo the row count)
+        */
+       private static void rollSparse(MatrixBlock in, MatrixBlock out, int shift) {
+               final int m = in.rlen;
+               shift %= (m != 0 ? m : 1); // roll matrix with axis=0
+
+               // first copy allocates the sparse output and sets its meta data
+               copySparseMtx(in, out, 0, shift, m - shift, false, true);
+               // output already allocated above: skip the redundant re-initialization
+               // (consistent with rollDense)
+               copySparseMtx(in, out, m - shift, 0, shift, true, true);
+       }
+
+       /**
+        * Copies a range of rows from one matrix block into another, dispatching
+        * on the storage format of the input block.
+        *
+        * @param in                input block
+        * @param out               output block
+        * @param inStart           first row to read in the input
+        * @param outStart          first row to write in the output
+        * @param copyLen           number of rows to copy
+        * @param isAllocated       if false, the output is initialized and allocated first
+        * @param copyTotalNonZeros if true (and the output gets initialized), the
+        *                          input's nonZeros is assigned to the output
+        */
+       public static void copyMtx(MatrixBlock in, MatrixBlock out, int inStart, int outStart, int copyLen,
+               boolean isAllocated, boolean copyTotalNonZeros)
+       {
+               // dense inputs take the dense copy path
+               if (!in.isInSparseFormat()) {
+                       copyDenseMtx(in, out, inStart, outStart, copyLen, isAllocated, copyTotalNonZeros);
+                       return;
+               }
+
+               // everything else is copied row-wise through the sparse path
+               copySparseMtx(in, out, inStart, outStart, copyLen, isAllocated, copyTotalNonZeros);
+       }
+
+       /**
+        * Copies copyLen consecutive rows of a dense input block, starting at row
+        * inIdx, into the dense output block starting at row outIdx.
+        *
+        * If isAllocated is false, the output is first marked dense and its dense
+        * block is allocated; in that case copyTotalNonZeros decides whether the
+        * input's nonZeros is assigned to the output.
+        *
+        * NOTE(review): assumes the input has an allocated dense block - verify the
+        * behavior for empty input blocks.
+        */
+       public static void copyDenseMtx(MatrixBlock in, MatrixBlock out, int 
inIdx, int outIdx, int copyLen,
+                                                                       boolean 
isAllocated, boolean copyTotalNonZeros) {
+               int clen = in.clen;
+
+               // set basic meta data and allocate output
+               if (!isAllocated){
+                       out.sparse = false;
+                       if (copyTotalNonZeros) out.nonZeros = in.nonZeros;
+                       out.allocateDenseBlock(false);
+               }
+
+               // copy all rows into target positions
+               if (clen == 1) { //column vector
                        double[] a = in.getDenseBlockValues();
                        double[] c = out.getDenseBlockValues();
 
-                       // roll matrix with axis=none
-                       shift %= (m != 0 ? m : 1);
-
-                       System.arraycopy(a, 0, c, shift, m - shift);
-                       System.arraycopy(a, m - shift, c, 0, shift);
-               } else { //general matrix case
+                       // single column: the row range is one contiguous array range
+                       System.arraycopy(a, inIdx, c, outIdx, copyLen);
+               } else {
                        DenseBlock a = in.getDenseBlock();
                        DenseBlock c = out.getDenseBlock();
 
-                       // roll matrix with axis=0
-                       shift %= (m != 0 ? m : 1);
+                       // general case: copy row by row (clen values per row)
+                       while (copyLen > 0) {
+                               System.arraycopy(a.values(inIdx), a.pos(inIdx),
+                                               c.values(outIdx), 
c.pos(outIdx), clen);
 
-                       for (int i = 0; i < m - shift; i++) {
-                               System.arraycopy(a.values(i), a.pos(i), 
c.values(i + shift), c.pos(i + shift), n);
-                       }
-
-                       for (int i = m - shift; i < m; i++) {
-                               System.arraycopy(a.values(i), a.pos(i), 
c.values(i + shift - m), c.pos(i + shift - m), n);
+                               inIdx++; outIdx++; copyLen--;
                        }
                }
        }
 
-       private static void rollSparse(MatrixBlock in, MatrixBlock out, int 
shift) {
-               final int m = in.rlen;
-
+       /**
+        * Copies copyLen consecutive rows of a sparse input block, starting at row
+        * inIdx, into the sparse output block starting at row outIdx. Only the
+        * non-zero entries of each row are written; empty input rows are skipped.
+        *
+        * @param in                sparse input block
+        * @param out               output block
+        * @param inIdx             first row to read in the input
+        * @param outIdx            first row to write in the output
+        * @param copyLen           number of rows to copy
+        * @param isAllocated       if false, the output is marked sparse and its
+        *                          sparse rows are allocated first
+        * @param copyTotalNonZeros if true (and the output gets initialized), the
+        *                          input's nonZeros is assigned to the output
+        */
+       private static void copySparseMtx(MatrixBlock in, MatrixBlock out, int inIdx, int outIdx, int copyLen,
+               boolean isAllocated, boolean copyTotalNonZeros) {
                //set basic meta data and allocate output
-               out.sparse = true;
-               out.nonZeros = in.nonZeros;
-               out.allocateSparseRowsBlock(false);
+               if (!isAllocated){
+                       out.sparse = true;
+                       if (copyTotalNonZeros) out.nonZeros = in.nonZeros;
+                       out.allocateSparseRowsBlock(false);
+               }
 
-               //copy all rows into target positions
                SparseBlock a = in.getSparseBlock();
                SparseBlock c = out.getSparseBlock();
 
+               // guard against an unallocated input sparse block - presumably the
+               // case for empty blocks (TODO confirm); there is nothing to copy then
+               if (a == null) return;
+
-               // roll matrix with axis=0
-               shift %= (m != 0 ? m : 1);
-
-               for (int i = 0; i < m - shift; i++) {
-                       if (a.isEmpty(i)) continue;    // skip empty rows
+               // the counters advance in the loop header so that skipping an empty
+               // row via continue still makes progress (no endless loop)
+               for (; copyLen > 0; inIdx++, outIdx++, copyLen--) {
+                       if (a.isEmpty(inIdx)) continue;    // skip empty rows
 
-                       rollSparseRow(a, c, i, i + shift);
-               }
-
-               for (int i = m - shift; i < m; i++) {
-                       if (a.isEmpty(i)) continue;    // skip empty rows
-
-                       rollSparseRow(a, c, i, i + shift - m);
-               }
-       }
+                       final int apos = a.pos(inIdx);
+                       final int alen = a.size(inIdx) + apos;
+                       final int[] aix = a.indexes(inIdx);
+                       final double[] avals = a.values(inIdx);
 
-       private static void rollSparseRow(SparseBlock a, SparseBlock c, int 
oriIdx, int shiftIdx) {
-               final int apos = a.pos(oriIdx);
-               final int alen = a.size(oriIdx) + apos;
-               final int[] aix = a.indexes(oriIdx);
-               final double[] avals = a.values(oriIdx);
+                       // copy only non-zero elements
+                       for (int k = apos; k < alen; k++) {
+                               c.set(outIdx, aix[k], avals[k]);
+                       }
 
-               // copy only non-zero elements
-               for (int k = apos; k < alen; k++) {
-                       c.set(shiftIdx, aix[k], avals[k]);
                }
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/reorg/FullRollTest.java 
b/src/test/java/org/apache/sysds/test/functions/reorg/FullRollTest.java
new file mode 100644
index 0000000000..121d378a6e
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/reorg/FullRollTest.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.reorg;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+
+/**
+ * Runs the roll script once in hybrid mode and once in Spark mode on the same
+ * random input, compares both outputs cell by cell, and checks the heavy
+ * hitter statistics for the roll opcodes.
+ */
+public class FullRollTest extends AutomatedTestBase {
+       private final static String TEST_NAME1 = "Roll1";
+
+       private final static String TEST_DIR = "functions/reorg/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + FullRollTest.class.getSimpleName() + "/";
+
+       // input dimensions and the two sparsity levels (dense-like / sparse)
+       private final static int rows1 = 2017;
+       private final static int cols1 = 1001;
+       private final static double sparsity1 = 0.7;
+       private final static double sparsity2 = 0.1;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"B", "C"});
+               addTestConfiguration(TEST_NAME1, config);
+       }
+
+       @Test
+       public void testRollVectorDense() {
+               runRollTest(TEST_NAME1, false, false);
+       }
+
+       @Test
+       public void testRollVectorSparse() {
+               runRollTest(TEST_NAME1, false, true);
+       }
+
+       @Test
+       public void testRollMatrixDense() {
+               runRollTest(TEST_NAME1, true, false);
+       }
+
+       @Test
+       public void testRollMatrixSparse() {
+               runRollTest(TEST_NAME1, true, true);
+       }
+
+       /**
+        * Executes the given script in hybrid and in Spark mode and compares the results.
+        *
+        * @param testname name of the test configuration / script
+        * @param matrix   true for a cols1-column matrix, false for a column vector
+        * @param sparse   true to generate the input with sparsity2, false for sparsity1
+        */
+       private void runRollTest(String testname, boolean matrix, boolean sparse) {
+               // remember the global settings so that they can be restored afterwards
+               final ExecMode oldPlatform = rtplatform;
+               final boolean oldSparkConfig = DMLScript.USE_LOCAL_SPARK_CONFIG;
+
+               try {
+                       final int ncol = matrix ? cols1 : 1;
+                       final double density = sparse ? sparsity2 : sparsity1;
+                       getAndLoadTestConfiguration(testname);
+
+                       // construct the script path directly
+                       fullDMLScriptName = SCRIPT_DIR + TEST_DIR + testname + ".dml";
+
+                       // generate the input data set
+                       double[][] A = getRandomMatrix(rows1, ncol, -1, 1, density, 7);
+                       writeInputMatrixWithMTD("A", A, true);
+
+                       // first run: hybrid mode, output B
+                       rtplatform = ExecMode.HYBRID;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = false;
+                       programArgs = new String[]{"-stats", "-explain", "-args", input("A"), output("B")};
+                       runTest(true, false, null, -1);
+                       final boolean sawCpRoll = Statistics.getCPHeavyHitterOpCodes().contains("roll");
+
+                       // second run: Spark mode, output C
+                       rtplatform = ExecMode.SPARK;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+                       programArgs = new String[]{"-stats", "-explain", "-args", input("A"), output("C")};
+                       runTest(true, false, null, -1);
+                       final boolean sawSpRoll = Statistics.getCPHeavyHitterOpCodes()
+                               .contains(Instruction.SP_INST_PREFIX + "roll");
+
+                       // both runs must produce exactly the same matrix
+                       HashMap<CellIndex, Double> resultCP = readDMLMatrixFromOutputDir("B");
+                       HashMap<CellIndex, Double> resultSP = readDMLMatrixFromOutputDir("C");
+                       TestUtils.compareMatrices(resultCP, resultSP, 0, "Stat-DML-CP", "Stat-DML-SP");
+
+                       // and each run must have executed the expected instruction
+                       Assert.assertTrue("Missing opcode: roll", sawCpRoll);
+                       Assert.assertTrue("Missing opcode: sp_roll", sawSpRoll);
+               }
+               finally {
+                       // restore the global settings
+                       rtplatform = oldPlatform;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = oldSparkConfig;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/reorg/Roll1.dml 
b/src/test/scripts/functions/reorg/Roll1.dml
new file mode 100644
index 0000000000..8928fdf219
--- /dev/null
+++ b/src/test/scripts/functions/reorg/Roll1.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Read matrix A from $1, apply roll with a shift of 1, and write the result to $2.
+A = read($1);
+B = roll(A, 1);
+write(B, $2);  
\ No newline at end of file

Reply via email to