This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 0825f0fab7 [SYSTEMDS-3729] Add roll reorg operations in SP
0825f0fab7 is described below
commit 0825f0fab741250ee3d5b3e0cd78ff62641cd250
Author: min-guk <[email protected]>
AuthorDate: Wed Sep 25 09:38:55 2024 +0200
[SYSTEMDS-3729] Add roll reorg operations in SP
Closes #2112.
---
.../java/org/apache/sysds/hops/AggUnaryOp.java | 1 -
.../runtime/instructions/SPInstructionParser.java | 1 +
.../instructions/spark/ReorgSPInstruction.java | 53 +++++++-
.../sysds/runtime/matrix/data/LibMatrixReorg.java | 138 +++++++++++++--------
.../sysds/test/functions/reorg/FullRollTest.java | 121 ++++++++++++++++++
src/test/scripts/functions/reorg/Roll1.dml | 25 ++++
6 files changed, 286 insertions(+), 53 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index eec86ec15b..954caa7a2e 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -31,7 +31,6 @@ import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.common.Types.ExecType;
-import org.apache.sysds.lops.Nary;
import org.apache.sysds.lops.PartialAggregate;
import org.apache.sysds.lops.TernaryAggregate;
import org.apache.sysds.lops.UAggOuterChain;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 5e4dbaedeb..5c72b85436 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -166,6 +166,7 @@ public class SPInstructionParser extends InstructionParser
// Reorg Instruction Opcodes (repositioning of existing values)
String2SPInstructionType.put( "r'", SPType.Reorg);
String2SPInstructionType.put( "rev", SPType.Reorg);
+ String2SPInstructionType.put( "roll", SPType.Reorg);
String2SPInstructionType.put( "rdiag", SPType.Reorg);
String2SPInstructionType.put( "rshape", SPType.MatrixReshape);
String2SPInstructionType.put( "rsort", SPType.Reorg);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
index de01a71ca8..b096405959 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ReorgSPInstruction.java
@@ -36,6 +36,7 @@ import
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
import org.apache.sysds.runtime.functionobjects.RevIndex;
+import org.apache.sysds.runtime.functionobjects.RollIndex;
import org.apache.sysds.runtime.functionobjects.SortIndex;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -68,6 +69,7 @@ public class ReorgSPInstruction extends UnarySPInstruction {
private CPOperand _desc = null;
private CPOperand _ixret = null;
private boolean _bSortIndInMem = false;
+ private CPOperand _shift = null;
private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out,
String opcode, String istr) {
super(SPType.Reorg, op, in, out, opcode, istr);
@@ -82,6 +84,15 @@ public class ReorgSPInstruction extends UnarySPInstruction {
_bSortIndInMem = bSortIndInMem;
}
+	/**
+	 * Creates a reorg instruction for the roll opcode, which additionally keeps
+	 * the scalar operand holding the number of rows to shift by.
+	 */
+	private ReorgSPInstruction(Operator op, CPOperand in, CPOperand out, CPOperand shift, String opcode, String istr) {
+		this(op, in, out, opcode, istr);
+		this._shift = shift;
+	}
+
public static ReorgSPInstruction parseInstruction ( String str ) {
CPOperand in = new CPOperand("", ValueType.UNKNOWN,
DataType.UNKNOWN);
CPOperand out = new CPOperand("", ValueType.UNKNOWN,
DataType.UNKNOWN);
@@ -95,6 +102,15 @@ public class ReorgSPInstruction extends UnarySPInstruction {
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgSPInstruction(new
ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
}
+ else if (opcode.equalsIgnoreCase("roll")) {
+ String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
+ InstructionUtils.checkNumFields(str, 3);
+ in.split(parts[1]);
+ out.split(parts[3]);
+ CPOperand shift = new CPOperand(parts[2]);
+ return new ReorgSPInstruction(new ReorgOperator(new
RollIndex(0)),
+ in, out, shift, opcode, str);
+}
else if ( opcode.equalsIgnoreCase("rdiag") ) {
parseUnaryInstruction(str, in, out); //max 2 operands
return new ReorgSPInstruction(new
ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
@@ -141,6 +157,14 @@ public class ReorgSPInstruction extends UnarySPInstruction
{
if( mcIn.getRows() % mcIn.getBlocksize() != 0 )
out = RDDAggregateUtils.mergeByKey(out, false);
}
+ else if (opcode.equalsIgnoreCase("roll")) // ROLL
+ {
+ int shift = (int)
ec.getScalarInput(_shift).getLongValue();
+
+ //execute roll reorg operation
+ out = in1.flatMapToPair(new RDDRollFunction(mcIn,
shift));
+ out = RDDAggregateUtils.mergeByKey(out, false);
+ }
else if ( opcode.equalsIgnoreCase("rdiag") ) // DIAG
{
if(mcIn.getCols() == 1) { // diagV2M
@@ -233,7 +257,7 @@ public class ReorgSPInstruction extends UnarySPInstruction {
boolean ixret =
sec.getScalarInput(_ixret).getBooleanValue();
mcOut.set(mc1.getRows(), ixret?1:mc1.getCols(),
mc1.getBlocksize(), mc1.getBlocksize());
}
- else { //e.g., rev
+ else { //e.g., rev, roll
mcOut.set(mc1);
}
}
@@ -243,7 +267,7 @@ public class ReorgSPInstruction extends UnarySPInstruction {
boolean sortIx = getOpcode().equalsIgnoreCase("rsort")
&& sec.getScalarInput(_ixret).getBooleanValue();
if( sortIx )
mcOut.setNonZeros(mc1.getRows());
- else //default (r', rdiag, rev, rsort data)
+ else //default (r', rdiag, rev, roll, rsort data)
mcOut.setNonZeros(mc1.getNonZeros());
}
}
@@ -315,6 +339,34 @@ public class ReorgSPInstruction extends UnarySPInstruction {
}
}
+	/**
+	 * Spark function for roll: maps one input block to the partial
+	 * output blocks that its cyclically shifted rows fall into; the
+	 * caller merges the partial blocks per output index.
+	 */
+	private static class RDDRollFunction implements PairFlatMapFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
+		private static final long serialVersionUID = 1183373828539843938L;
+
+		private final DataCharacteristics _mcIn;
+		private final int _shift;
+
+		public RDDRollFunction(DataCharacteristics mcIn, int shift) {
+			_mcIn = mcIn;
+			_shift = shift;
+		}
+
+		@Override
+		public Iterator<Tuple2<MatrixIndexes, MatrixBlock>> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) {
+			// split this block into its shifted partial blocks
+			ArrayList<IndexedMatrixValue> pieces = new ArrayList<>();
+			LibMatrixReorg.roll(SparkUtils.toIndexedMatrixBlock(arg0),
+				_mcIn.getRows(), _mcIn.getBlocksize(), _shift, pieces);
+
+			// convert the partial blocks back into Spark pair tuples
+			return SparkUtils.fromIndexedMatrixBlock(pieces).iterator();
+		}
+	}
+
private static class ExtractColumn implements Function<MatrixBlock,
MatrixBlock>
{
private static final long serialVersionUID =
-1472164797288449559L;
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
index 53e8a88832..82defddca8 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixReorg.java
@@ -445,6 +445,62 @@ public class LibMatrixReorg {
return out;
}
+	/**
+	 * Rolls (cyclically shifts) the rows of a single block of a blocked matrix:
+	 * row {@code i} of the full matrix moves to row {@code (i + shift) mod rlen}.
+	 * The input block is split into one partial output block per contiguous run
+	 * of target rows; several partial blocks may share the same output index,
+	 * so the caller has to merge them by index.
+	 *
+	 * @param in    input block together with its block indexes
+	 * @param rlen  number of rows of the full matrix
+	 * @param blen  block size
+	 * @param shift number of rows to shift by; negative values shift upwards
+	 * @param out   list to which the partial output blocks are appended
+	 */
+	public static void roll(IndexedMatrixValue in, long rlen, int blen, int shift, ArrayList<IndexedMatrixValue> out) {
+		MatrixIndexes inMtxIdx = in.getIndexes();
+		MatrixBlock inMtxBlk = (MatrixBlock) in.getValue();
+		final int numRows = inMtxBlk.getNumRows();
+		if (numRows == 0)
+			return; // nothing to roll
+		if (rlen <= 0)
+			throw new IllegalArgumentException("roll requires a positive number of rows, but rlen=" + rlen);
+
+		// normalize the shift into [0, rlen) in long arithmetic: this supports
+		// negative shifts and avoids truncating rlen to int
+		final long normShift = Math.floorMod((long) shift, rlen);
+
+		// global 0-based index of the first row of this input block
+		long inRowIdx = UtilFunctions.computeCellIndex(inMtxIdx.getRowIndex(), blen, 0) - 1;
+
+		int totalCopyLen = 0;
+		while (totalCopyLen < numRows) {
+			// Calculate row and block index for the current part
+			long outRowIdx = (inRowIdx + normShift) % rlen;
+			long outBlkIdx = UtilFunctions.computeBlockIndex(outRowIdx + 1, blen);
+			int outBlkLen = UtilFunctions.computeBlockSize(rlen, outBlkIdx, blen);
+			int outRowIdxInBlk = (int) (outRowIdx % blen);
+
+			// Copy up to the end of the target block or of the remaining input rows
+			int copyLen = Math.min(outBlkLen - outRowIdxInBlk, numRows - totalCopyLen);
+
+			// Create the output block and copy data
+			MatrixIndexes outMtxIdx = new MatrixIndexes(outBlkIdx, inMtxIdx.getColumnIndex());
+			MatrixBlock outMtxBlk = new MatrixBlock(outBlkLen, inMtxBlk.getNumColumns(), inMtxBlk.isInSparseFormat());
+			copyMtx(inMtxBlk, outMtxBlk, totalCopyLen, outRowIdxInBlk, copyLen, false, false);
+			// copyMtx with copyTotalNonZeros=false leaves the nnz of the new block
+			// untouched, so derive it from the copied content
+			outMtxBlk.recomputeNonZeros();
+			out.add(new IndexedMatrixValue(outMtxIdx, outMtxBlk));
+
+			// Advance by the rows copied in THIS iteration; advancing by the
+			// running total skips input rows once a block is split into 3+ pieces
+			totalCopyLen += copyLen;
+			inRowIdx += copyLen;
+		}
+	}
+
public static MatrixBlock diag( MatrixBlock in, MatrixBlock out ) {
//Timing time = new Timing(true);
@@ -2274,77 +2304,110 @@ public class LibMatrixReorg {
private static void rollDense(MatrixBlock in, MatrixBlock out, int
shift) {
final int m = in.rlen;
- final int n = in.clen;
+		// normalize into [0, m); this also supports negative shifts (m == 0 maps to 0)
+		shift = Math.floorMod(shift, Math.max(m, 1)); // roll matrix with axis=none
- //set basic meta data and allocate output
- out.sparse = false;
- out.nonZeros = in.nonZeros;
- out.allocateDenseBlock(false);
+		// rows [0, m-shift) move down to [shift, m), the last shift rows wrap to the top;
+		// the first copy allocates the output and sets its nnz, the second reuses it
+		copyDenseMtx(in, out, 0, shift, m - shift, false, true);
+		copyDenseMtx(in, out, m - shift, 0, shift, true, true);
+	}
- //copy all rows into target positions
- if (n == 1) { //column vector
+	private static void rollSparse(MatrixBlock in, MatrixBlock out, int shift) {
+		final int m = in.rlen;
+		// normalize into [0, m); this also supports negative shifts (m == 0 maps to 0)
+		shift = Math.floorMod(shift, Math.max(m, 1)); // roll matrix with axis=0
+
+		// the first copy allocates the output and sets its nnz, the second reuses it
+		copySparseMtx(in, out, 0, shift, m - shift, false, true);
+		copySparseMtx(in, out, m - shift, 0, shift, true, true);
+	}
+
+	/**
+	 * Copies {@code copyLen} consecutive rows of {@code in}, starting at row
+	 * {@code inStart}, into {@code out} starting at row {@code outStart}, using
+	 * the sparse or dense copy depending on the format of {@code in}.
+	 *
+	 * @param isAllocated       if true, {@code out} is expected to be allocated
+	 *                          already and allocation/metadata setup is skipped
+	 * @param copyTotalNonZeros only if allocating: set the nnz of {@code out} to
+	 *                          the nnz of the entire {@code in}
+	 */
+	public static void copyMtx(MatrixBlock in, MatrixBlock out, int inStart, int outStart, int copyLen,
+		boolean isAllocated, boolean copyTotalNonZeros) {
+		if (in.isInSparseFormat())
+			copySparseMtx(in, out, inStart, outStart, copyLen, isAllocated, copyTotalNonZeros);
+		else
+			copyDenseMtx(in, out, inStart, outStart, copyLen, isAllocated, copyTotalNonZeros);
+	}
+
+	/**
+	 * Dense variant of {@link #copyMtx}: copies {@code copyLen} rows of {@code in}
+	 * starting at row {@code inIdx} to {@code out} starting at row {@code outIdx}.
+	 */
+	public static void copyDenseMtx(MatrixBlock in, MatrixBlock out, int inIdx, int outIdx, int copyLen,
+		boolean isAllocated, boolean copyTotalNonZeros) {
+		final int clen = in.clen;
+
+		// set basic meta data and allocate output
+		if (!isAllocated) {
+			out.sparse = false;
+			if (copyTotalNonZeros)
+				out.nonZeros = in.nonZeros;
+			out.allocateDenseBlock(false);
+		}
+
+		// copy all rows into target positions
+		if (clen == 1) { //column vector: one contiguous array copy
double[] a = in.getDenseBlockValues();
double[] c = out.getDenseBlockValues();
- // roll matrix with axis=none
- shift %= (m != 0 ? m : 1);
-
- System.arraycopy(a, 0, c, shift, m - shift);
- System.arraycopy(a, m - shift, c, 0, shift);
- } else { //general matrix case
+			System.arraycopy(a, inIdx, c, outIdx, copyLen);
+		} else { //general matrix case: copy row by row
DenseBlock a = in.getDenseBlock();
DenseBlock c = out.getDenseBlock();
- // roll matrix with axis=0
- shift %= (m != 0 ? m : 1);
+			while (copyLen > 0) {
+				System.arraycopy(a.values(inIdx), a.pos(inIdx), c.values(outIdx), c.pos(outIdx), clen);
- for (int i = 0; i < m - shift; i++) {
- System.arraycopy(a.values(i), a.pos(i),
c.values(i + shift), c.pos(i + shift), n);
- }
-
- for (int i = m - shift; i < m; i++) {
- System.arraycopy(a.values(i), a.pos(i),
c.values(i + shift - m), c.pos(i + shift - m), n);
+				inIdx++; outIdx++; copyLen--;
}
}
}
- private static void rollSparse(MatrixBlock in, MatrixBlock out, int
shift) {
- final int m = in.rlen;
-
+	/**
+	 * Sparse variant of {@link #copyMtx}: copies the non-zeros of {@code copyLen}
+	 * rows of {@code in} starting at row {@code inIdx} to {@code out} starting at
+	 * row {@code outIdx}; empty input rows are skipped.
+	 */
+	private static void copySparseMtx(MatrixBlock in, MatrixBlock out, int inIdx, int outIdx, int copyLen,
+		boolean isAllocated, boolean copyTotalNonZeros) {
//set basic meta data and allocate output
- out.sparse = true;
- out.nonZeros = in.nonZeros;
- out.allocateSparseRowsBlock(false);
+		if (!isAllocated) {
+			out.sparse = true;
+			if (copyTotalNonZeros)
+				out.nonZeros = in.nonZeros;
+			out.allocateSparseRowsBlock(false);
+		}
- //copy all rows into target positions
SparseBlock a = in.getSparseBlock();
SparseBlock c = out.getSparseBlock();
- // roll matrix with axis=0
- shift %= (m != 0 ? m : 1);
-
- for (int i = 0; i < m - shift; i++) {
- if (a.isEmpty(i)) continue; // skip empty rows
+		// advance the row cursors in the loop header, so that skipping an
+		// empty row via continue still makes progress
+		for (; copyLen > 0; inIdx++, outIdx++, copyLen--) {
+			if (a.isEmpty(inIdx)) continue; // skip empty rows
- rollSparseRow(a, c, i, i + shift);
- }
-
- for (int i = m - shift; i < m; i++) {
- if (a.isEmpty(i)) continue; // skip empty rows
-
- rollSparseRow(a, c, i, i + shift - m);
- }
- }
+			final int apos = a.pos(inIdx);
+			final int alen = a.size(inIdx) + apos;
+			final int[] aix = a.indexes(inIdx);
+			final double[] avals = a.values(inIdx);
- private static void rollSparseRow(SparseBlock a, SparseBlock c, int
oriIdx, int shiftIdx) {
- final int apos = a.pos(oriIdx);
- final int alen = a.size(oriIdx) + apos;
- final int[] aix = a.indexes(oriIdx);
- final double[] avals = a.values(oriIdx);
+			// copy only non-zero elements
+			for (int k = apos; k < alen; k++) {
+				c.set(outIdx, aix[k], avals[k]);
+			}
- // copy only non-zero elements
- for (int k = apos; k < alen; k++) {
- c.set(shiftIdx, aix[k], avals[k]);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/reorg/FullRollTest.java
b/src/test/java/org/apache/sysds/test/functions/reorg/FullRollTest.java
new file mode 100644
index 0000000000..121d378a6e
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/reorg/FullRollTest.java
@@ -0,0 +1,135 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.reorg;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+/**
+ * Compares roll(A, 1) computed in CP (HYBRID mode) against the Spark
+ * implementation (SPARK mode) for dense and sparse vectors and matrices,
+ * and checks that the CP and Spark roll opcodes were actually executed.
+ */
+public class FullRollTest extends AutomatedTestBase {
+	private final static String TEST_NAME1 = "Roll1";
+
+	private final static String TEST_DIR = "functions/reorg/";
+	private static final String TEST_CLASS_DIR = TEST_DIR + FullRollTest.class.getSimpleName() + "/";
+
+	private final static int rows1 = 2017;
+	private final static int cols1 = 1001;
+	private final static double sparsity1 = 0.7;
+	private final static double sparsity2 = 0.1;
+	// ~1 non-zero per row for cols1 columns, so many rows of the sparse input are empty
+	private final static double sparsity3 = 0.001;
+
+	@Override
+	public void setUp() {
+		TestUtils.clearAssertionInformation();
+		addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[]{"B", "C"}));
+	}
+
+	@Test
+	public void testRollVectorDense() {
+		runRollTest(TEST_NAME1, false, false);
+	}
+
+	@Test
+	public void testRollVectorSparse() {
+		runRollTest(TEST_NAME1, false, true);
+	}
+
+	@Test
+	public void testRollMatrixDense() {
+		runRollTest(TEST_NAME1, true, false);
+	}
+
+	@Test
+	public void testRollMatrixSparse() {
+		runRollTest(TEST_NAME1, true, true);
+	}
+
+	@Test
+	public void testRollMatrixUltraSparse() {
+		// covers the empty-row handling of the sparse row copy
+		runRollTest(TEST_NAME1, true, sparsity3);
+	}
+
+	private void runRollTest(String testname, boolean matrix, boolean sparse) {
+		runRollTest(testname, matrix, sparse ? sparsity2 : sparsity1);
+	}
+
+	/**
+	 * Runs the roll script once in HYBRID mode (output B) and once in SPARK mode
+	 * (output C), compares both results, and asserts that roll and sp_roll ran.
+	 */
+	private void runRollTest(String testname, boolean matrix, double sparsity) {
+		ExecMode platformOld = rtplatform;
+		boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+
+		try {
+			int cols = matrix ? cols1 : 1;
+			getAndLoadTestConfiguration(testname);
+
+			String HOME = SCRIPT_DIR + TEST_DIR;
+			fullDMLScriptName = HOME + testname + ".dml";
+
+			//generate actual dataset
+			double[][] A = getRandomMatrix(rows1, cols, -1, 1, sparsity, 7);
+			writeInputMatrixWithMTD("A", A, true);
+
+			// Run test CP
+			rtplatform = ExecMode.HYBRID;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = false;
+			programArgs = new String[]{"-stats", "-explain", "-args", input("A"), output("B")};
+			runTest(true, false, null, -1);
+			boolean opcodeCP = Statistics.getCPHeavyHitterOpCodes().contains("roll");
+
+			// Run test SP
+			rtplatform = ExecMode.SPARK;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+			programArgs = new String[]{"-stats", "-explain", "-args", input("A"), output("C")};
+			runTest(true, false, null, -1);
+			boolean opcodeSP = Statistics.getCPHeavyHitterOpCodes().contains(Instruction.SP_INST_PREFIX + "roll");
+
+			//compare matrices
+			HashMap<CellIndex, Double> dmlfileCP = readDMLMatrixFromOutputDir("B");
+			HashMap<CellIndex, Double> dmlfileSP = readDMLMatrixFromOutputDir("C");
+			TestUtils.compareMatrices(dmlfileCP, dmlfileSP, 0, "Stat-DML-CP", "Stat-DML-SP");
+
+			Assert.assertTrue("Missing opcode: roll", opcodeCP);
+			Assert.assertTrue("Missing opcode: sp_roll", opcodeSP);
+		} finally {
+			//reset flags
+			rtplatform = platformOld;
+			DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+		}
+	}
+}
diff --git a/src/test/scripts/functions/reorg/Roll1.dml
b/src/test/scripts/functions/reorg/Roll1.dml
new file mode 100644
index 0000000000..8928fdf219
--- /dev/null
+++ b/src/test/scripts/functions/reorg/Roll1.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+A = read($1);      # input matrix (path passed as first script argument)
+B = roll(A, 1);    # cyclically shift the rows of A by one position
+write(B, $2);      # result is written to the path passed as second argument
\ No newline at end of file