This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 4b3fa93abf [SYSTEMDS-3695] Fix spark frame cbind for misaligned inputs 4b3fa93abf is described below commit 4b3fa93abfd43b9101c09a5bba5094992112c27a Author: e-strauss <lathan...@gmx.de> AuthorDate: Wed Jun 5 09:30:43 2024 +0200 [SYSTEMDS-3695] Fix spark frame cbind for misaligned inputs Closes #2031. --- .../spark/FrameAppendRSPInstruction.java | 102 +++++++++++++++++++-- .../test/functions/frame/FrameAppendDistTest.java | 74 +++++++++++++-- .../functions/frame/FrameNAryAppendMisalignRSP.R | 35 +++++++ .../functions/frame/FrameNAryAppendMisalignRSP.dml | 31 +++++++ .../functions/frame/FrameNAryAppendMisalignRSP2.R | 31 +++++++ .../frame/FrameNAryAppendMisalignRSP2.dml | 28 ++++++ 6 files changed, 288 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendRSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendRSPInstruction.java index af1be2b0c4..8774c63ed7 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendRSPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/FrameAppendRSPInstruction.java @@ -21,6 +21,7 @@ package org.apache.sysds.runtime.instructions.spark; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.PairFlatMapFunction; import org.apache.spark.api.java.function.PairFunction; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; @@ -31,6 +32,11 @@ import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDAggregateUtils; import org.apache.sysds.runtime.matrix.operators.Operator; import scala.Tuple2; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + public class FrameAppendRSPInstruction extends AppendRSPInstruction { protected FrameAppendRSPInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, boolean cbind, @@ -43,7 +49,7 @@ public class FrameAppendRSPInstruction extends AppendRSPInstruction { SparkExecutionContext sec = (SparkExecutionContext)ec; JavaPairRDD<Long,FrameBlock> in1 = sec.getFrameBinaryBlockRDDHandleForVariable( input1.getName() ); JavaPairRDD<Long,FrameBlock> in2 = sec.getFrameBinaryBlockRDDHandleForVariable( input2.getName() ); - JavaPairRDD<Long,FrameBlock> out = null; + JavaPairRDD<Long,FrameBlock> out; long leftRows = sec.getDataCharacteristics(input1.getName()).getRows(); out = appendFrameRSP(in1, in2, leftRows, _cbind); @@ -65,11 +71,14 @@ public class FrameAppendRSPInstruction extends AppendRSPInstruction { public static JavaPairRDD<Long, FrameBlock> appendFrameRSP(JavaPairRDD<Long, FrameBlock> in1, JavaPairRDD<Long, FrameBlock> in2, long leftRows, boolean cbind) { if(cbind) { - JavaPairRDD<Long,FrameBlock> in1Aligned = in1.mapToPair(new ReduceSideAppendAlignFunction(leftRows)); - in1Aligned = FrameRDDAggregateUtils.mergeByKey(in1Aligned); - JavaPairRDD<Long,FrameBlock> in2Aligned = in2.mapToPair(new ReduceSideAppendAlignFunction(leftRows)); + //TODO preserve info if already aligned, and only align if necessary + //get in1 keys + long[] row_indices = in1.keys().collect().stream().mapToLong(Long::longValue).toArray(); + Arrays.sort(row_indices); + //Align the blocks of in2 on the blocks of in1 + JavaPairRDD<Long,FrameBlock> in2Aligned = in2.flatMapToPair(new ReduceSideAppendAlignToLHSFunction(row_indices, leftRows)); in2Aligned = FrameRDDAggregateUtils.mergeByKey(in2Aligned); - return in1Aligned.join(in2Aligned).mapValues(new ReduceSideColumnsFunction(cbind)); + return in1.join(in2Aligned).mapValues(new ReduceSideColumnsFunction(cbind)); } else { //rbind JavaPairRDD<Long,FrameBlock> right = in2.mapToPair( new ReduceSideAppendRowsFunction(leftRows)); return in1.union(right); @@ -94,6 +103,86 @@ public class FrameAppendRSPInstruction extends AppendRSPInstruction { } } + private static class ReduceSideAppendAlignToLHSFunction implements PairFlatMapFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> + { + private static final long serialVersionUID = 5850400295183766409L; + + private final long[] _indices; + private final long lastIndex; //max_rows + 1 + + public ReduceSideAppendAlignToLHSFunction(long[] indices, long max_rows) { + _indices = indices; + lastIndex = max_rows + 1; + } + + @Override + public Iterator<Tuple2<Long, FrameBlock>> call(Tuple2<Long, FrameBlock> arg0) + { + List<Tuple2<Long, FrameBlock>> aligned_blocks = new ArrayList<>(); + long indexRHS = arg0._1(); + FrameBlock fb = arg0._2(); + + //find the block index ix in the LHS with the smallest index s.t. following LHS indix ix' > indexRHS >= ix + //doing binary search + int L = 0; + int R = _indices.length - 1; + int m; + while(L <= R){ + m = (L + R) / 2; + if(_indices[m] == indexRHS){ + R = m; + break; + } + if(_indices[m] < indexRHS) + L = m + 1; + else + R = m - 1; + } + // search terminates if we have found the exact indexRHS or binary search reached the leaf nodes where + // L == R (bucket size = 1) and m == L which implies that _indices[m+1] > indexRHS (otherwise we would + // have considered this index in the search + // if _indices[m] < indexRHS than m contains the index which fits our definition and R == m + // else (m - 1) fits our definition which is stored and R = m - 1 + // therefore in all cases the correct position of the indexLHS is stored in R + long indexLHS = _indices[R]; + + //assumes total num rows LHS == RHS + long nextIndexLHS = R < _indices.length - 1? _indices[R+1] : this.lastIndex; + int blkSizeLHS = (int) (nextIndexLHS - indexLHS); + int offsetLHS = (int) (indexRHS - indexLHS); + int offsetRHS = 0; + int sizeOfSlice = blkSizeLHS - offsetLHS; + + FrameBlock resultBlock = new FrameBlock(fb.getSchema()); + resultBlock.ensureAllocatedColumns(blkSizeLHS); + + int sizeOfRHS = fb.getNumRows(); + while(sizeOfSlice < sizeOfRHS){ + FrameBlock fb_sliced = fb.slice(offsetRHS, offsetRHS + sizeOfSlice - 1); + resultBlock = resultBlock.leftIndexingOperations(fb_sliced,offsetLHS, offsetLHS + sizeOfSlice - 1, 0, fb.getNumColumns()-1, new FrameBlock()); + aligned_blocks.add(new Tuple2<>(indexLHS, resultBlock)); + resultBlock = new FrameBlock(fb.getSchema()); + if(R >= _indices.length - 1) + throw new RuntimeException("Alignment Error while CBIND: LHS has fewer rows than RHS"); + indexLHS = nextIndexLHS; + offsetRHS += sizeOfSlice; + offsetLHS = 0; + sizeOfRHS -= sizeOfSlice; + R++; + nextIndexLHS = R < _indices.length - 1? _indices[R+1] : this.lastIndex; + sizeOfSlice = (int) (nextIndexLHS - indexLHS); //sizeOfSlice = blkSizeLHS + resultBlock.ensureAllocatedColumns(sizeOfSlice); + } + //RHS fits into aligned LHS block + if(offsetRHS != 0) + fb = fb.slice(offsetRHS, offsetRHS + sizeOfRHS - 1); + resultBlock = resultBlock.leftIndexingOperations(fb, offsetLHS, offsetLHS + fb.getNumRows() - 1, 0, fb.getNumColumns()-1, new FrameBlock()); + aligned_blocks.add(new Tuple2<>(indexLHS, resultBlock)); + + return aligned_blocks.iterator(); + } + } + private static class ReduceSideAppendRowsFunction implements PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> { private static final long serialVersionUID = 1723795153048336791L; @@ -112,12 +201,13 @@ public class FrameAppendRSPInstruction extends AppendRSPInstruction { } } + @SuppressWarnings("unused") private static class ReduceSideAppendAlignFunction implements PairFunction<Tuple2<Long, FrameBlock>, Long, FrameBlock> { private static final long serialVersionUID = 5850400295183766409L; private long _rows; - + public ReduceSideAppendAlignFunction(long rows) { _rows = rows; } diff --git a/src/test/java/org/apache/sysds/test/functions/frame/FrameAppendDistTest.java b/src/test/java/org/apache/sysds/test/functions/frame/FrameAppendDistTest.java index c6cde96167..ebf6682742 100644 --- a/src/test/java/org/apache/sysds/test/functions/frame/FrameAppendDistTest.java +++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameAppendDistTest.java @@ -43,6 +43,9 @@ public class FrameAppendDistTest extends AutomatedTestBase private final static String TEST_NAME = "FrameAppend"; private final static String TEST_NAME2 = "FrameNAryAppend"; private final static String TEST_NAME3 = "FrameNAryAppendMisalign"; + private final static String TEST_NAME4 = "FrameNAryAppendMisalignRSP"; + private final static String TEST_NAME5 = "FrameNAryAppendMisalignRSP2"; + private final static String TEST_DIR = "functions/frame/"; private final static String TEST_CLASS_DIR = TEST_DIR + FrameAppendDistTest.class.getSimpleName() + "/"; @@ -71,17 +74,19 @@ public class FrameAppendDistTest extends AutomatedTestBase addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"C"})); addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2,new String[] {"C"})); addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3,new String[] {"C"})); + addTestConfiguration(TEST_NAME4, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME4,new String[] {"C"})); + addTestConfiguration(TEST_NAME5, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME5,new String[] {"C"})); } @Test public void testAppendInBlock1DenseSP() { commonAppendTest(ExecMode.SPARK, rows1, rows1, cols1a, cols2a, false, AppendMethod.MR_RAPPEND, false, TEST_NAME); - } + } @Test public void testAppendInBlock1SparseSP() { commonAppendTest(ExecMode.SPARK, rows1, rows1, cols1a, cols2a, true, AppendMethod.MR_RAPPEND, false, TEST_NAME); - } + } @Test public void testAppendInBlock1DenseRBindSP() { @@ -116,22 +121,77 @@ public class FrameAppendDistTest extends AutomatedTestBase @Test public void testNAryCAppendMSP(){ - commonAppendTest(ExecMode.SPARK ,100, 100, 5, 10, false, null, false, TEST_NAME2);; + commonAppendTest(ExecMode.SPARK ,100, 100, 5, 10, false, null, false, TEST_NAME2); } @Test public void testNAryCAppendRSP(){ - commonAppendTest(ExecMode.SPARK ,30, 30, 5, 1001, false, null, false, TEST_NAME2);; + commonAppendTest(ExecMode.SPARK ,30, 30, 5, 1001, false, null, false, TEST_NAME2); } @Test public void testNAryRAppendSP(){ - commonAppendTest(ExecMode.SPARK ,100, 100, 5, 5, false, null, true, TEST_NAME2);; + commonAppendTest(ExecMode.SPARK ,100, 100, 5, 5, false, null, true, TEST_NAME2); } @Test public void testNAryAppendWithMisalignmentMSP(){ - commonAppendTest(ExecMode.SPARK ,5, 10, 5, 5, false, null, false, TEST_NAME3);; + commonAppendTest(ExecMode.SPARK ,5, 10, 5, 5, false, null, false, TEST_NAME3); + } + + @Test + public void testNAryAppendWithMisalignmentRSP() { + commonAppendTest(ExecMode.SPARK, 5, 10, 1001, 1001, false, null, false, TEST_NAME3); + } + +// NAryAppendWithMisalignmentRSP2: +// LHS: RHS: +// +---------+ +-----+ +// | | +-----+ +// | | +-----+ +// | | +-----+ +// +---------+ +-----+ + @Test + public void testNAryAppendWithMisalignmentRSP2(){ + commonAppendTest(ExecMode.SPARK ,20, 5, 1001, 1005, false, null, false, TEST_NAME4); + } + +// NAryAppendWithMisalignmentRSP3: +// LHS: RHS: +// +-----+ +---------+ +// +-----+ | | +// +-----+ | | +// +-----+ | | +// +-----+ +---------+ + @Test + public void testNAryAppendWithMisalignmentRSP3(){ + commonAppendTest(ExecMode.SPARK ,5, 20, 1001, 1005, false, null, false, TEST_NAME4); + } +// NAryAppendWithMisalignmentRSP4: +// LHS: RHS: +// +-----+ +---------+ +// | | +---------+ +// +-----+ | | +// | | +---------+ +// +-----+ | | +// | | +---------+ +// +-----+ | | +// +-----+ +---------+ + @Test + public void testNAryAppendWithMisalignmentRSP4(){ + commonAppendTest(ExecMode.SPARK ,20, 5, 1001, 1001, false, null, false, TEST_NAME5); + } +// NAryAppendWithMisalignmentRSP5: +// LHS: RHS: +// +-----+ +---------+ +// +-----+ | | +// +-----+ +---------+ +// +-----+ +---------+ +// | | +---------+ +// +-----+ +---------+ + @Test + public void testNAryAppendWithMisalignmentRSP5(){ + commonAppendTest(ExecMode.SPARK ,8, 20, 1001, 1001, false, null, false, TEST_NAME5); } @@ -142,7 +202,7 @@ public class FrameAppendDistTest extends AutomatedTestBase ExecMode prevPlfm=rtplatform; - double sparsity = (sparse) ? sparsity2 : sparsity1; + double sparsity = (sparse) ? sparsity2 : sparsity1; boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG; //setOutputBuffering(true); try diff --git a/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.R b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.R new file mode 100644 index 0000000000..1ec13f690a --- /dev/null +++ b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.R @@ -0,0 +1,35 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) +options(digits=22) +library("Matrix") + +A=read.csv(paste(args[1], "A.csv", sep=""), header = FALSE, stringsAsFactors=FALSE) +B=read.csv(paste(args[1], "B.csv", sep=""), header = FALSE, stringsAsFactors=FALSE) +if(nrow(A) > nrow(B)){ + t=rbind(B, B, B, B) + C=cbind(A, t) +} else { + t= rbind(A, A, A, A) + C=cbind(t, B) +} +write.csv(C, paste(args[2], "C.csv", sep=""), row.names = FALSE, quote = FALSE) diff --git a/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.dml b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.dml new file mode 100644 index 0000000000..3633733da6 --- /dev/null +++ b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP.dml @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A=read($1, data_type="frame", rows=$2, cols=$3, format="binary") +B=read($4, data_type="frame", rows=$5, cols=$6, format="binary") +if($2 > $5){ + t=rbind(B, B, B, B) + C=cbind(A, t) +} else { + t= rbind(A, A, A, A) + C=cbind(t, B) +} +write(C, $7, format="binary") diff --git a/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.R b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.R new file mode 100644 index 0000000000..9265c9b3a4 --- /dev/null +++ b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.R @@ -0,0 +1,31 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +args <- commandArgs(TRUE) +options(digits=22) +library("Matrix") + +A=read.csv(paste(args[1], "A.csv", sep=""), header = FALSE, stringsAsFactors=FALSE) +B=read.csv(paste(args[1], "B.csv", sep=""), header = FALSE, stringsAsFactors=FALSE) +t=rbind(B, A, A, A) +t2=rbind(A, A, A, B) +C=cbind(t2, t) +write.csv(C, paste(args[2], "C.csv", sep=""), row.names = FALSE, quote = FALSE) diff --git a/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.dml b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.dml new file mode 100644 index 0000000000..c761c428fd --- /dev/null +++ b/src/test/scripts/functions/frame/FrameNAryAppendMisalignRSP2.dml @@ -0,0 +1,28 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +A=read($1, data_type="frame", rows=$2, cols=$3, format="binary") +B=read($4, data_type="frame", rows=$5, cols=$6, format="binary") + +t=rbind(B, A, A, A) +t2=rbind(A, A, A, B) +C=cbind(t2, t) +write(C, $7, format="binary")