This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 5aa0307f18 [SYSTEMDS-3253] New native operator for union distinct
5aa0307f18 is described below
commit 5aa0307f180776cef4113e00afe6a5208f7548f2
Author: Chi-Hsin Huang <[email protected]>
AuthorDate: Sat Jul 12 17:41:29 2025 +0200
[SYSTEMDS-3253] New native operator for union distinct
This patch replaces the current union operation with an internal LOP.
Previously, two consecutive operations -- rbind() and unique() -- were
used to perform the union. A new static rewrite turns
unique(rbind(A, B)) into an internal union_distinct LOP that uses a
HashSet (single-column inputs) or a TreeSet (multi-column inputs) to
compute the distinct rows and returns them in a matrix. This improves
efficiency because neither the intermediate rbind() result nor the
generic unique() operation is needed. The order of the input rows is
preserved in the output.
Closes #2286.
---
.../org/apache/sysds/common/InstructionType.java | 1 +
src/main/java/org/apache/sysds/common/Opcodes.java | 2 +
src/main/java/org/apache/sysds/common/Types.java | 3 +-
.../RewriteAlgebraicSimplificationStatic.java | 20 +++++
.../runtime/instructions/CPInstructionParser.java | 4 +
.../runtime/instructions/cp/CPInstruction.java | 1 +
.../instructions/cp/UnionCPInstruction.java | 60 ++++++++++++++
.../sysds/runtime/matrix/data/MatrixBlock.java | 89 ++++++++++++++++++++-
.../rewrite/RewriteSimplifyUnionDistinctTest.java | 92 ++++++++++++++++++++++
.../functions/rewrite/RewriteSimplifyUnion.R | 50 ++++++++++++
.../functions/rewrite/RewriteSimplifyUnion.dml | 41 ++++++++++
11 files changed, 360 insertions(+), 3 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/InstructionType.java
b/src/main/java/org/apache/sysds/common/InstructionType.java
index 4dba1c5be0..1980dd7984 100644
--- a/src/main/java/org/apache/sysds/common/InstructionType.java
+++ b/src/main/java/org/apache/sysds/common/InstructionType.java
@@ -61,6 +61,7 @@ public enum InstructionType {
MMTSJ,
PMMJ,
MMChain,
+ Union,
//SP Types
MAPMM,
diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java
b/src/main/java/org/apache/sysds/common/Opcodes.java
index 71f3e3f752..a4081f9292 100644
--- a/src/main/java/org/apache/sysds/common/Opcodes.java
+++ b/src/main/java/org/apache/sysds/common/Opcodes.java
@@ -90,6 +90,8 @@ public enum Opcodes {
MULT2("*2", InstructionType.Binary), //special * case
MINUS_NZ("-nz", InstructionType.Binary), //special - case
+ UNION_DISTINCT("union_distinct", InstructionType.Union),
+
// Boolean Instruction Opcodes
AND("&&", InstructionType.Binary),
OR("||", InstructionType.Binary),
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index c5ad9ded2b..fc6e1610ca 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -637,7 +637,8 @@ public interface Types {
MINUS_NZ(false), //sparse-safe minus: X-(mean*ppred(X,0,!=))
LOG_NZ(false), //sparse-safe log; ppred(X,0,"!=")*log(X,0.5)
MINUS1_MULT(false), //1-X*Y
- QUANTIZE_COMPRESS(false); //quantization-fused compression
+ QUANTIZE_COMPRESS(false), //quantization-fused compression
+ UNION_DISTINCT(false);
private final boolean _validOuter;
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
index ef5670dda8..fd4445cf44 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationStatic.java
@@ -182,6 +182,7 @@ public class RewriteAlgebraicSimplificationStatic extends
HopRewriteRule
hi = simplifyListIndexing(hi);
//e.g., L[i:i, 1:ncol(L)] -> L[i:i, 1:1]
hi = simplifyScalarIndexing(hop, hi, i);
//e.g., as.scalar(X[i,1])->X[i,1] w/ scalar output
hi = simplifyConstantSort(hop, hi, i);
//e.g., order(matrix())->matrix/seq;
+ hi = simplifyUnionDistinct(hop, hi, i);
//e.g., unique(rbind(A, B)) -> union_distinct(A, B);
hi = simplifyOrderedSort(hop, hi, i);
//e.g., order(matrix())->seq;
hi = fuseOrderOperationChain(hi);
//e.g., order(order(X,2),1) -> order(X,(12))
hi = removeUnnecessaryReorgOperation(hop, hi, i);
//e.g., t(t(X))->X; rev(rev(X))->X potentially introduced by other rewrites
@@ -1837,7 +1838,26 @@ public class RewriteAlgebraicSimplificationStatic
extends HopRewriteRule
}
}
}
+
+ return hi;
+ }
+
+ private static Hop simplifyUnionDistinct(Hop parent, Hop hi, int pos) {
+ // pattern: unique(rbind(A, B)) -> union_distinct(A, B)
+ if(HopRewriteUtils.isAggUnaryOp(hi, AggOp.UNIQUE)
+ && HopRewriteUtils.isBinary(hi.getInput(0),
OpOp2.RBIND)) {
+ Hop rbindAB = hi.getInput(0);
+ if(rbindAB.getParent().size() == 1) {
+ // make sure that rbind is only used here
+ Hop A = rbindAB.getInput(0);
+ Hop B = rbindAB.getInput(1);
+ Hop unionDistinct =
HopRewriteUtils.createBinary(A, B, OpOp2.UNION_DISTINCT);
+ HopRewriteUtils.replaceChildReference(parent,
hi, unionDistinct, pos);
+ HopRewriteUtils.cleanupUnreferenced(hi,
rbindAB);
+ LOG.debug("Applied simplifyUnionDistinct");
+ }
+ }
return hi;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index a76dd6aaca..fa443378e6 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -64,6 +64,7 @@ import
org.apache.sysds.runtime.instructions.cp.TernaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UaggOuterChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.UnionCPInstruction;
import
org.apache.sysds.runtime.instructions.cpfile.MatrixIndexingCPFileInstruction;
public class CPInstructionParser extends InstructionParser {
@@ -218,6 +219,9 @@ public class CPInstructionParser extends InstructionParser {
case EvictLineageCache:
return EvictCPInstruction.parseInstruction(str);
+
+ case Union:
+ return UnionCPInstruction.parseInstruction(str);
default:
throw new DMLRuntimeException("Invalid CP
Instruction Type: " + cptype );
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index b0b502f8a0..c99039bb7f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -46,6 +46,7 @@ public abstract class CPInstruction extends Instruction {
StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn,
Sql, Prefetch, Broadcast, TrigRemote,
EvictLineageCache,
NoOp,
+ Union,
QuantizeCompression
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnionCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnionCPInstruction.java
new file mode 100644
index 0000000000..194a2c6929
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnionCPInstruction.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import org.apache.sysds.common.Opcodes;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.MultiThreadedOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+
+public class UnionCPInstruction extends BinaryCPInstruction {
+
+ private UnionCPInstruction(Operator op, CPOperand in1, CPOperand in2,
CPOperand out, String opcode, String istr) {
+ super(CPType.Union, op, in1, in2, out, opcode, istr);
+ }
+
+ public static UnionCPInstruction parseInstruction(String str) {
+ String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
+ String opcode = parts[0];
+
+ if(!opcode.equalsIgnoreCase(Opcodes.UNION_DISTINCT.toString()))
+ throw new DMLRuntimeException("Invalid opcode for
UNION_DISTINCT: " + opcode);
+
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand out = new CPOperand(parts[parts.length - 2]);
+ MultiThreadedOperator operator = new MultiThreadedOperator();
+ return new UnionCPInstruction(operator, in1, in2, out, opcode,
str);
+ }
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ MatrixBlock matBlock1 = ec.getMatrixInput(input1.getName());
+ MatrixBlock matBlock2 = ec.getMatrixInput(input2.getName());
+ MatrixBlock out = matBlock1.unionOperations(matBlock1,
matBlock2);
+ ec.releaseMatrixInput(input1.getName());
+ ec.releaseMatrixInput(input2.getName());
+ ec.setMatrixOutput(output.getName(), out);
+ }
+
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 7bc516588a..d15e9e5c2b 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -33,6 +33,9 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.TreeSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
@@ -4925,11 +4928,93 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
LibMatrixOuterAgg.aggregateMatrix(mbLeft, mbOut, bv,
bvi, bOp, uaggOp);
} else
throw new DMLRuntimeException("Unsupported operator for
unary aggregate operations.");
-
+
return mbOut;
}
+
+ public MatrixBlock unionOperations(MatrixBlock m1, MatrixBlock m2) {
+ if(m1.getNumColumns() == 1) {
+ HashSet<Double> set = new HashSet<>();
+ boolean[] toAddArr = new boolean[m1.getNumRows() +
m2.getNumRows()];
+ int id = 0;
+ for(MatrixBlock m : new MatrixBlock[] {m1,m2}) {
+ for(int i = 0; i < m.getNumRows(); i++) {
+ Double val = m.get(i, 0);
+ if(!set.contains(val)) {
+ set.add(val);
+ toAddArr[id] = true;
+ }
+ id++;
+ }
+ }
+
+ MatrixBlock mbOut = new MatrixBlock(set.size(),
m1.getNumColumns(), false);
+ int rowOut = 0;
+ int rowId = 0;
+ for(boolean toAdd : toAddArr) {
+ if(toAdd) {
+ if(rowId < m1.getNumRows()) { // is
first matrix
+ mbOut.set(rowOut, 0,
m1.get(rowId, 0));
+ }
+ else { // is second matrix
+ int tempRowId = rowId -
m1.getNumRows();
+ mbOut.set(rowOut, 0,
m2.get(tempRowId, 0));
+ }
+ rowOut++;
+ }
+ rowId++;
+ }
+
+ return mbOut;
+ }
+ else {
+ Set<double[]> set = new TreeSet<>((o1, o2) -> {
+ return Arrays.compare(o1, o2);
+ });
+ boolean[] toAddArr = new boolean[m1.getNumRows() +
m2.getNumRows()];
+ int id = 0;
+
+ //TODO perf dense zero-copy and sparse
+ for(MatrixBlock m : new MatrixBlock[] {m1,m2}) {
+ for(int i = 0; i < m.getNumRows(); i++) {
+ double[] row = new
double[m.getNumColumns()];
+ for(int j = 0; j < m.getNumColumns();
j++)
+ row[j] = m.get(i, j);
+ if(!set.contains(row)) {
+ set.add(row);
+ toAddArr[id] = true;
+ }
+ id++;
+ }
+ }
+
+ MatrixBlock mbOut = new MatrixBlock(set.size(),
m1.getNumColumns(), false);
+ int rowOut = 0;
+ int rowId = 0;
+ for(boolean toAdd : toAddArr) {
+ if(toAdd) {
+ if(rowId < m1.getNumRows()) {
+ // is first matrix
+ for(int i = 0; i <
m1.getNumColumns(); i++) {
+ mbOut.set(rowOut, i,
m1.get(rowId, i));
+ }
+ }
+ else {
+ // is second matrix
+ int tempRowId = rowId -
m1.getNumRows();
+ for(int i = 0; i <
m2.getNumColumns(); i++) {
+ mbOut.set(rowOut, i,
m2.get(tempRowId, i));
+ }
+ }
+ rowOut++;
+ }
+ rowId++;
+ }
+
+ return mbOut;
+ }
+ }
-
/**
* Invocation from CP instructions. The aggregate is computed on the
groups object
* against target and weights.
diff --git
a/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnionDistinctTest.java
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnionDistinctTest.java
new file mode 100644
index 0000000000..af9b009b8c
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/rewrite/RewriteSimplifyUnionDistinctTest.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.rewrite;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+public class RewriteSimplifyUnionDistinctTest extends AutomatedTestBase {
+ private static final String TEST_NAME = "RewriteSimplifyUnion";
+ private static final String TEST_DIR = "functions/rewrite/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
RewriteSimplifyUnionDistinctTest.class.getSimpleName()
+ + "/";
+ private static final double eps = Math.pow(10, -10);
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"R"}));
+ }
+
+ @Test
+ public void testUnionDistinctRewriteOne() {
+ testRewriteSimplifyUnionDistinct(1, true);
+ }
+
+ @Test
+ public void testUnionDistinctRewriteFifty() {
+ testRewriteSimplifyUnionDistinct(50, true);
+ }
+
+ @Test
+ public void testUnionDistinctRewriteOneThousand() {
+ testRewriteSimplifyUnionDistinct(1000, true);
+ }
+
+ @Test
+ public void testUnionDistinctRewrite() {
+ testRewriteSimplifyUnionDistinct(2, true);
+ }
+
+ private void testRewriteSimplifyUnionDistinct(int ID, boolean rewrites)
{
+ boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+ try {
+ TestConfiguration config =
getTestConfiguration(TEST_NAME);
+ loadTestConfiguration(config);
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ int rowNum = (int) (Math.random() * 1000);
+ programArgs = new String[] {"-explain", "-stats",
"-args", String.valueOf(ID), String.valueOf(rowNum),
+ output("R")};
+ fullRScriptName = HOME + TEST_NAME + ".R";
+ rCmd = getRCmd(String.valueOf(ID),
String.valueOf(rowNum), expectedDir());
+
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION =
rewrites;
+
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ // compare matrices
+ HashMap<MatrixValue.CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("R");
+ HashMap<MatrixValue.CellIndex, Double> rfile =
readRMatrixFromExpectedDir("R");
+ TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
+ }
+ finally {
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+ }
+ }
+}
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyUnion.R
b/src/test/scripts/functions/rewrite/RewriteSimplifyUnion.R
new file mode 100644
index 0000000000..0ca3e0e9aa
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyUnion.R
@@ -0,0 +1,50 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
args <- commandArgs(TRUE)

# Set options for numeric precision
options(digits=22)

# Load required libraries (Matrix provides writeMM)
library("Matrix")

# Read parameters: number of columns and rows of each input
colNum = as.integer(args[1])
rowNum = as.integer(args[2])

# X[i, j] = i and Y[i, j] = i + floor(rowNum / 2), matching the DML script.
# Built vectorized; the former loop `for(i in 2 : rowNum - 1)` relied on `:`
# binding tighter than `-` and misbehaved for rowNum < 2.
X = matrix(rep(1:rowNum, colNum), nrow=rowNum, ncol=colNum)
Y = X + floor(rowNum / 2)

# Perform operations (unique on a matrix removes duplicate rows, keeping order)
combined = rbind(X, Y)
R = unique(combined)

# Write result matrix R
writeMM(as(R, "CsparseMatrix"), paste(args[3], "R", sep=""))
diff --git a/src/test/scripts/functions/rewrite/RewriteSimplifyUnion.dml
b/src/test/scripts/functions/rewrite/RewriteSimplifyUnion.dml
new file mode 100644
index 0000000000..bd3edcd7cc
--- /dev/null
+++ b/src/test/scripts/functions/rewrite/RewriteSimplifyUnion.dml
@@ -0,0 +1,41 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+
colNum = $1
rowNum = $2

# X[i, j] = i and Y[i, j] = i + floor(rowNum/2); built as an outer product
# instead of cell-wise left indexing, which needed rowNum*colNum loop iterations
X = seq(1, rowNum) %*% matrix(1, rows=1, cols=colNum)
Y = X + floor(rowNum/2)

C = rbind(X,Y)
R = unique(C)
R = order(target=R)

# Write the result matrix R
write(R, $3)