This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 4b3fa93abf [SYSTEMDS-3695] Fix spark frame cbind for misaligned inputs
4b3fa93abf is described below
commit 4b3fa93abfd43b9101c09a5bba5094992112c27a
Author: e-strauss <[email protected]>
AuthorDate: Wed Jun 5 09:30:43 2024 +0200
[SYSTEMDS-3695] Fix spark frame cbind for misaligned inputs
Closes #2031.
---
.../spark/FrameAppendRSPInstruction.java | 102 +++++++++++++++++++--
.../test/functions/frame/FrameAppendDistTest.java | 74 +++++++++++++--
.../functions/frame/FrameNAryAppendMisalignRSP.R | 35 +++++++
.../functions/frame/FrameNAryAppendMisalignRSP.dml | 31 +++++++
.../functions/frame/FrameNAryAppendMisalignRSP2.R | 31 +++++++
.../frame/FrameNAryAppendMisalignRSP2.dml | 28 ++++++
6 files changed, 288 insertions(+), 13 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendRSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendRSPInstruction.java
index af1be2b0c4..8774c63ed7 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendRSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendRSPInstruction.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.spark;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -31,6 +32,11 @@ import
org.apache.sysds.runtime.instructions.spark.utils.FrameRDDAggregateUtils;
import org.apache.sysds.runtime.matrix.operators.Operator;
import scala.Tuple2;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Iterator;
+import java.util.List;
+
public class FrameAppendRSPInstruction extends AppendRSPInstruction {
protected FrameAppendRSPInstruction(Operator op, CPOperand in1,
CPOperand in2, CPOperand out, boolean cbind,
@@ -43,7 +49,7 @@ public class FrameAppendRSPInstruction extends
AppendRSPInstruction {
SparkExecutionContext sec = (SparkExecutionContext)ec;
JavaPairRDD<Long,FrameBlock> in1 =
sec.getFrameBinaryBlockRDDHandleForVariable( input1.getName() );
JavaPairRDD<Long,FrameBlock> in2 =
sec.getFrameBinaryBlockRDDHandleForVariable( input2.getName() );
- JavaPairRDD<Long,FrameBlock> out = null;
+ JavaPairRDD<Long,FrameBlock> out;
long leftRows =
sec.getDataCharacteristics(input1.getName()).getRows();
out = appendFrameRSP(in1, in2, leftRows, _cbind);
@@ -65,11 +71,14 @@ public class FrameAppendRSPInstruction extends
AppendRSPInstruction {
public static JavaPairRDD<Long, FrameBlock>
appendFrameRSP(JavaPairRDD<Long, FrameBlock> in1, JavaPairRDD<Long, FrameBlock>
in2, long leftRows, boolean cbind) {
if(cbind) {
- JavaPairRDD<Long,FrameBlock> in1Aligned =
in1.mapToPair(new ReduceSideAppendAlignFunction(leftRows));
- in1Aligned =
FrameRDDAggregateUtils.mergeByKey(in1Aligned);
- JavaPairRDD<Long,FrameBlock> in2Aligned =
in2.mapToPair(new ReduceSideAppendAlignFunction(leftRows));
+ //TODO preserve info if already aligned, and only align
if necessary
+ //get in1 keys
+ long[] row_indices =
in1.keys().collect().stream().mapToLong(Long::longValue).toArray();
+ Arrays.sort(row_indices);
+ //Align the blocks of in2 on the blocks of in1
+ JavaPairRDD<Long,FrameBlock> in2Aligned =
in2.flatMapToPair(new ReduceSideAppendAlignToLHSFunction(row_indices,
leftRows));
in2Aligned =
FrameRDDAggregateUtils.mergeByKey(in2Aligned);
- return in1Aligned.join(in2Aligned).mapValues(new
ReduceSideColumnsFunction(cbind));
+ return in1.join(in2Aligned).mapValues(new
ReduceSideColumnsFunction(cbind));
} else { //rbind
JavaPairRDD<Long,FrameBlock> right = in2.mapToPair( new
ReduceSideAppendRowsFunction(leftRows));
return in1.union(right);
@@ -94,6 +103,86 @@ public class FrameAppendRSPInstruction extends
AppendRSPInstruction {
}
}
+ private static class ReduceSideAppendAlignToLHSFunction implements
PairFlatMapFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock>
+ {
+ private static final long serialVersionUID =
5850400295183766409L;
+
+ private final long[] _indices;
+ private final long lastIndex; //max_rows + 1
+
+ public ReduceSideAppendAlignToLHSFunction(long[] indices, long
max_rows) {
+ _indices = indices;
+ lastIndex = max_rows + 1;
+ }
+
+ @Override
+ public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long,
FrameBlock> arg0)
+ {
+ List<Tuple2<Long, FrameBlock>> aligned_blocks = new
ArrayList<>();
+ long indexRHS = arg0._1();
+ FrameBlock fb = arg0._2();
+
+ //find the block index ix in the LHS with the smallest
index s.t. following LHS index ix' > indexRHS >= ix
+ //doing binary search
+ int L = 0;
+ int R = _indices.length - 1;
+ int m;
+ while(L <= R){
+ m = (L + R) / 2;
+ if(_indices[m] == indexRHS){
+ R = m;
+ break;
+ }
+ if(_indices[m] < indexRHS)
+ L = m + 1;
+ else
+ R = m - 1;
+ }
+ // the search terminates once we have found the exact
indexRHS, or once the binary search has reached a leaf where
+ // L == R (bucket size = 1) and m == L, which implies
that _indices[m+1] > indexRHS (otherwise we would
+ // have considered this index in the search).
+ // if _indices[m] < indexRHS, then m is the index
which fits our definition and R == m;
+ // else (m - 1) fits our definition, and it is stored as
R = m - 1.
+ // therefore, in all cases the correct position of
indexLHS is stored in R
+ long indexLHS = _indices[R];
+
+ //assumes total num rows LHS == RHS
+ long nextIndexLHS = R < _indices.length - 1?
_indices[R+1] : this.lastIndex;
+ int blkSizeLHS = (int) (nextIndexLHS - indexLHS);
+ int offsetLHS = (int) (indexRHS - indexLHS);
+ int offsetRHS = 0;
+ int sizeOfSlice = blkSizeLHS - offsetLHS;
+
+ FrameBlock resultBlock = new FrameBlock(fb.getSchema());
+ resultBlock.ensureAllocatedColumns(blkSizeLHS);
+
+ int sizeOfRHS = fb.getNumRows();
+ while(sizeOfSlice < sizeOfRHS){
+ FrameBlock fb_sliced = fb.slice(offsetRHS,
offsetRHS + sizeOfSlice - 1);
+ resultBlock =
resultBlock.leftIndexingOperations(fb_sliced,offsetLHS, offsetLHS + sizeOfSlice
- 1, 0, fb.getNumColumns()-1, new FrameBlock());
+ aligned_blocks.add(new Tuple2<>(indexLHS,
resultBlock));
+ resultBlock = new FrameBlock(fb.getSchema());
+ if(R >= _indices.length - 1)
+ throw new RuntimeException("Alignment
Error while CBIND: LHS has fewer rows than RHS");
+ indexLHS = nextIndexLHS;
+ offsetRHS += sizeOfSlice;
+ offsetLHS = 0;
+ sizeOfRHS -= sizeOfSlice;
+ R++;
+ nextIndexLHS = R < _indices.length - 1?
_indices[R+1] : this.lastIndex;
+ sizeOfSlice = (int) (nextIndexLHS - indexLHS);
//sizeOfSlice = blkSizeLHS
+ resultBlock.ensureAllocatedColumns(sizeOfSlice);
+ }
+ //RHS fits into aligned LHS block
+ if(offsetRHS != 0)
+ fb = fb.slice(offsetRHS, offsetRHS + sizeOfRHS
- 1);
+ resultBlock = resultBlock.leftIndexingOperations(fb,
offsetLHS, offsetLHS + fb.getNumRows() - 1, 0, fb.getNumColumns()-1, new
FrameBlock());
+ aligned_blocks.add(new Tuple2<>(indexLHS, resultBlock));
+
+ return aligned_blocks.iterator();
+ }
+ }
+
private static class ReduceSideAppendRowsFunction implements
PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock>
{
private static final long serialVersionUID =
1723795153048336791L;
@@ -112,12 +201,13 @@ public class FrameAppendRSPInstruction extends
AppendRSPInstruction {
}
}
+ @SuppressWarnings("unused")
private static class ReduceSideAppendAlignFunction implements
PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock>
{
private static final long serialVersionUID =
5850400295183766409L;
private long _rows;
-
+
public ReduceSideAppendAlignFunction(long rows) {
_rows = rows;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/frame/FrameAppendDistTest.java
b/src/test/java/org/apache/sysds/test/functions/frame/FrameAppendDistTest.java
index c6cde96167..ebf6682742 100644
---
a/src/test/java/org/apache/sysds/test/functions/frame/FrameAppendDistTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/frame/FrameAppendDistTest.java
@@ -43,6 +43,9 @@ public class FrameAppendDistTest extends AutomatedTestBase
private final static String TEST_NAME = "FrameAppend";
private final static String TEST_NAME2 = "FrameNAryAppend";
private final static String TEST_NAME3 = "FrameNAryAppendMisalign";
+ private final static String TEST_NAME4 = "FrameNAryAppendMisalignRSP";
+ private final static String TEST_NAME5 = "FrameNAryAppendMisalignRSP2";
+
private final static String TEST_DIR = "functions/frame/";
private final static String TEST_CLASS_DIR = TEST_DIR +
FrameAppendDistTest.class.getSimpleName() + "/";
@@ -71,17 +74,19 @@ public class FrameAppendDistTest extends AutomatedTestBase
addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"}));
addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2,new String[] {"C"}));
addTestConfiguration(TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3,new String[] {"C"}));
+ addTestConfiguration(TEST_NAME4, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4,new String[] {"C"}));
+ addTestConfiguration(TEST_NAME5, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME5,new String[] {"C"}));
}
@Test
public void testAppendInBlock1DenseSP() {
commonAppendTest(ExecMode.SPARK, rows1, rows1, cols1a, cols2a,
false, AppendMethod.MR_RAPPEND, false, TEST_NAME);
- }
+ }
@Test
public void testAppendInBlock1SparseSP() {
commonAppendTest(ExecMode.SPARK, rows1, rows1, cols1a, cols2a,
true, AppendMethod.MR_RAPPEND, false, TEST_NAME);
- }
+ }
@Test
public void testAppendInBlock1DenseRBindSP() {
@@ -116,22 +121,77 @@ public class FrameAppendDistTest extends AutomatedTestBase
@Test
public void testNAryCAppendMSP(){
- commonAppendTest(ExecMode.SPARK ,100, 100, 5, 10, false, null,
false, TEST_NAME2);;
+ commonAppendTest(ExecMode.SPARK ,100, 100, 5, 10, false, null,
false, TEST_NAME2);
}
@Test
public void testNAryCAppendRSP(){
- commonAppendTest(ExecMode.SPARK ,30, 30, 5, 1001, false, null,
false, TEST_NAME2);;
+ commonAppendTest(ExecMode.SPARK ,30, 30, 5, 1001, false, null,
false, TEST_NAME2);
}
@Test
public void testNAryRAppendSP(){
- commonAppendTest(ExecMode.SPARK ,100, 100, 5, 5, false, null,
true, TEST_NAME2);;
+ commonAppendTest(ExecMode.SPARK ,100, 100, 5, 5, false, null,
true, TEST_NAME2);
}
@Test
public void testNAryAppendWithMisalignmentMSP(){
- commonAppendTest(ExecMode.SPARK ,5, 10, 5, 5, false, null,
false, TEST_NAME3);;
+ commonAppendTest(ExecMode.SPARK ,5, 10, 5, 5, false, null,
false, TEST_NAME3);
+ }
+
+ @Test
+ public void testNAryAppendWithMisalignmentRSP() {
+ commonAppendTest(ExecMode.SPARK, 5, 10, 1001, 1001, false,
null, false, TEST_NAME3);
+ }
+
+// NAryAppendWithMisalignmentRSP2:
+// LHS: RHS:
+// +---------+ +-----+
+// | | +-----+
+// | | +-----+
+// | | +-----+
+// +---------+ +-----+
+ @Test
+ public void testNAryAppendWithMisalignmentRSP2(){
+ commonAppendTest(ExecMode.SPARK ,20, 5, 1001, 1005, false,
null, false, TEST_NAME4);
+ }
+
+// NAryAppendWithMisalignmentRSP3:
+// LHS: RHS:
+// +-----+ +---------+
+// +-----+ | |
+// +-----+ | |
+// +-----+ | |
+// +-----+ +---------+
+ @Test
+ public void testNAryAppendWithMisalignmentRSP3(){
+ commonAppendTest(ExecMode.SPARK ,5, 20, 1001, 1005, false,
null, false, TEST_NAME4);
+ }
+// NAryAppendWithMisalignmentRSP4:
+// LHS: RHS:
+// +-----+ +---------+
+// | | +---------+
+// +-----+ | |
+// | | +---------+
+// +-----+ | |
+// | | +---------+
+// +-----+ | |
+// +-----+ +---------+
+ @Test
+ public void testNAryAppendWithMisalignmentRSP4(){
+ commonAppendTest(ExecMode.SPARK ,20, 5, 1001, 1001, false,
null, false, TEST_NAME5);
+ }
+// NAryAppendWithMisalignmentRSP5:
+// LHS: RHS:
+// +-----+ +---------+
+// +-----+ | |
+// +-----+ +---------+
+// +-----+ +---------+
+// | | +---------+
+// +-----+ +---------+
+ @Test
+ public void testNAryAppendWithMisalignmentRSP5(){
+ commonAppendTest(ExecMode.SPARK ,8, 20, 1001, 1001, false,
null, false, TEST_NAME5);
}
@@ -142,7 +202,7 @@ public class FrameAppendDistTest extends AutomatedTestBase
ExecMode prevPlfm=rtplatform;
- double sparsity = (sparse) ? sparsity2 : sparsity1;
+ double sparsity = (sparse) ? sparsity2 : sparsity1;
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
//setOutputBuffering(true);
try
diff --git a/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.R
b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.R
new file mode 100644
index 0000000000..1ec13f690a
--- /dev/null
+++ b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.R
@@ -0,0 +1,35 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+A=read.csv(paste(args[1], "A.csv", sep=""), header = FALSE,
stringsAsFactors=FALSE)
+B=read.csv(paste(args[1], "B.csv", sep=""), header = FALSE,
stringsAsFactors=FALSE)
+if(nrow(A) > nrow(B)){
+ t=rbind(B, B, B, B)
+ C=cbind(A, t)
+} else {
+ t= rbind(A, A, A, A)
+ C=cbind(t, B)
+}
+write.csv(C, paste(args[2], "C.csv", sep=""), row.names = FALSE, quote = FALSE)
diff --git a/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.dml
b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.dml
new file mode 100644
index 0000000000..3633733da6
--- /dev/null
+++ b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.dml
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A=read($1, data_type="frame", rows=$2, cols=$3, format="binary")
+B=read($4, data_type="frame", rows=$5, cols=$6, format="binary")
+if($2 > $5){
+ t=rbind(B, B, B, B)
+ C=cbind(A, t)
+} else {
+ t= rbind(A, A, A, A)
+ C=cbind(t, B)
+}
+write(C, $7, format="binary")
diff --git a/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.R
b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.R
new file mode 100644
index 0000000000..9265c9b3a4
--- /dev/null
+++ b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.R
@@ -0,0 +1,31 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+args <- commandArgs(TRUE)
+options(digits=22)
+library("Matrix")
+
+A=read.csv(paste(args[1], "A.csv", sep=""), header = FALSE,
stringsAsFactors=FALSE)
+B=read.csv(paste(args[1], "B.csv", sep=""), header = FALSE,
stringsAsFactors=FALSE)
+t=rbind(B, A, A, A)
+t2=rbind(A, A, A, B)
+C=cbind(t2, t)
+write.csv(C, paste(args[2], "C.csv", sep=""), row.names = FALSE, quote = FALSE)
diff --git a/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.dml
b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.dml
new file mode 100644
index 0000000000..c761c428fd
--- /dev/null
+++ b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+A=read($1, data_type="frame", rows=$2, cols=$3, format="binary")
+B=read($4, data_type="frame", rows=$5, cols=$6, format="binary")
+
+t=rbind(B, A, A, A)
+t2=rbind(A, A, A, B)
+C=cbind(t2, t)
+write(C, $7, format="binary")