This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 3741895624 [SYSTEMDS-3909] New einsum expression evaluation framework
3741895624 is described below
commit 3741895624c5650aeb08808fad212ce0e7f9e853
Author: Hubert Krawczyk <[email protected]>
AuthorDate: Thu Aug 21 14:35:15 2025 +0200
[SYSTEMDS-3909] New einsum expression evaluation framework
Closes #2312.
Closes #2265.
---
.../java/org/apache/sysds/common/Builtins.java | 1 +
.../org/apache/sysds/common/InstructionType.java | 1 +
src/main/java/org/apache/sysds/common/Opcodes.java | 1 +
src/main/java/org/apache/sysds/common/Types.java | 4 +-
src/main/java/org/apache/sysds/hops/NaryOp.java | 9 +
.../apache/sysds/hops/codegen/cplan/CNodeCell.java | 2 +-
.../apache/sysds/hops/codegen/cplan/CNodeData.java | 8 +
src/main/java/org/apache/sysds/lops/Nary.java | 1 +
.../sysds/parser/BuiltinFunctionExpression.java | 53 +-
.../org/apache/sysds/parser/DMLTranslator.java | 5 +-
.../apache/sysds/runtime/einsum/EinsumContext.java | 177 +++++
.../runtime/einsum/EinsumEquationValidator.java | 144 ++++
.../runtime/instructions/CPInstructionParser.java | 4 +
.../instructions/cp/BuiltinNaryCPInstruction.java | 5 +-
.../runtime/instructions/cp/CPInstruction.java | 2 +-
.../instructions/cp/EinsumCPInstruction.java | 837 +++++++++++++++++++++
.../sysds/test/functions/einsum/EinsumTest.java | 362 +++++++++
.../functions/einsum/SystemDS-config-codegen.xml | 31 +
src/test/scripts/installDependencies.R | 1 +
19 files changed, 1640 insertions(+), 8 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index fe75aec6a0..5fe1721cc2 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -403,6 +403,7 @@ public enum Builtins {
UNDER_SAMPLING("underSampling", true),
UNIQUE("unique", false, true),
UPPER_TRI("upper.tri", false, true),
+ EINSUM("einsum", false, false),
XDUMMY1("xdummy1", true), //error handling test
XDUMMY2("xdummy2", true); //error handling test
diff --git a/src/main/java/org/apache/sysds/common/InstructionType.java
b/src/main/java/org/apache/sysds/common/InstructionType.java
index 1980dd7984..29148f03e9 100644
--- a/src/main/java/org/apache/sysds/common/InstructionType.java
+++ b/src/main/java/org/apache/sysds/common/InstructionType.java
@@ -62,6 +62,7 @@ public enum InstructionType {
PMMJ,
MMChain,
Union,
+ EINSUM,
//SP Types
MAPMM,
diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java
b/src/main/java/org/apache/sysds/common/Opcodes.java
index 64a6c7dd27..6cd1561128 100644
--- a/src/main/java/org/apache/sysds/common/Opcodes.java
+++ b/src/main/java/org/apache/sysds/common/Opcodes.java
@@ -174,6 +174,7 @@ public enum Opcodes {
RBIND("rbind", InstructionType.BuiltinNary),
EVAL("eval", InstructionType.BuiltinNary),
LIST("list", InstructionType.BuiltinNary),
+ EINSUM("einsum", InstructionType.BuiltinNary),
//Parametrized builtin functions
AUTODIFF("autoDiff", InstructionType.ParameterizedBuiltin),
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index cc7f6eb377..09a0f8effd 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -767,8 +767,8 @@ public interface Types {
/** Operations that require a variable number of operands*/
public enum OpOpN {
- PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST;
-
+ PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST, EINSUM;
+
public boolean isCellOp() {
return this == MIN || this == MAX || this == PLUS ||
this == MULT;
}
diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java
b/src/main/java/org/apache/sysds/hops/NaryOp.java
index 1659b0dbc5..6962beadcb 100644
--- a/src/main/java/org/apache/sysds/hops/NaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/NaryOp.java
@@ -26,6 +26,7 @@ import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.lops.Nary;
+import org.apache.sysds.runtime.einsum.EinsumEquationValidator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -235,6 +236,14 @@ public class NaryOp extends Hop {
setDim1(getInput().size());
setDim2(1);
break;
+ case EINSUM:
+ String equationString = ((LiteralOp)
_input.get(0)).getStringValue();
+ var dims =
EinsumEquationValidator.validateEinsumEquationAndReturnDimensions(equationString,
this.getInput().subList(1, this.getInput().size()));
+
+ setDim1(dims.getLeft());
+ setDim2(dims.getMiddle());
+ setDataType(dims.getRight());
+ break;
case PRINTF:
case EVAL:
//do nothing:
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
index 3d2c19ef4c..2482ac77e2 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
@@ -31,7 +31,7 @@ import org.apache.sysds.runtime.util.UtilFunctions;
public class CNodeCell extends CNodeTpl
{
- protected static final String JAVA_TEMPLATE =
+ public static final String JAVA_TEMPLATE =
"package codegen;\n"
+ "import
org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n"
+ "import org.apache.sysds.runtime.codegen.SpoofCellwise;\n"
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java
index 9292972874..b90789df5e 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java
@@ -56,6 +56,14 @@ public class CNodeData extends CNode
_cols = node.getNumCols();
_dataType = node.getDataType();
}
+
+	/**
+	 * Creates a data node from explicitly given properties instead of deriving them from a Hop.
+	 *
+	 * @param name     value stored as the node's name
+	 * @param hopID    value stored as the node's hop id
+	 * @param rows     value stored as the node's row count
+	 * @param cols     value stored as the node's column count
+	 * @param dataType value stored as the node's data type
+	 */
+	public CNodeData(String name, long hopID, long rows, long cols, DataType dataType) {
+		this._dataType = dataType;
+		this._rows = rows;
+		this._cols = cols;
+		this._hopID = hopID;
+		this._name = name;
+	}
@Override
public String getVarname() {
diff --git a/src/main/java/org/apache/sysds/lops/Nary.java
b/src/main/java/org/apache/sysds/lops/Nary.java
index e073bc6881..e5382ba033 100644
--- a/src/main/java/org/apache/sysds/lops/Nary.java
+++ b/src/main/java/org/apache/sysds/lops/Nary.java
@@ -111,6 +111,7 @@ public class Nary extends Lop {
case RBIND:
case EVAL:
case LIST:
+ case EINSUM:
return operationType.name().toLowerCase();
case MIN:
case MAX:
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 540b522a8b..28f6949f72 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.LinkedList;
import org.antlr.v4.runtime.ParserRuleContext;
import org.apache.commons.lang3.ArrayUtils;
@@ -35,6 +36,7 @@ import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.LanguageException.LanguageErrorCodes;
+import org.apache.sysds.runtime.einsum.EinsumEquationValidator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.DnnUtils;
import org.apache.sysds.runtime.util.UtilFunctions;
@@ -751,7 +753,9 @@ public class BuiltinFunctionExpression extends
DataIdentifier {
else
raiseValidateError("Compress/DeCompress
instruction not allowed in dml script");
break;
-
+ case EINSUM:
+ validateEinsum((DataIdentifier) getOutputs()[0]);
+ break;
default: //always unconditional
raiseValidateError("Unknown Builtin Function opcode: "
+ _opcode, false);
}
@@ -2063,7 +2067,9 @@ public class BuiltinFunctionExpression extends
DataIdentifier {
output.setValueType(ValueType.INT64);
output.setNnz(id.getDim2());
break;
-
+ case EINSUM:
+ validateEinsum(output);
+ break;
default:
if( isMathFunction() ) {
checkMathFunctionParam();
@@ -2096,6 +2102,49 @@ public class BuiltinFunctionExpression extends
DataIdentifier {
}
}
+	/**
+	 * Validates an {@code einsum(equation, M1, ..., Mn)} call and sets the output properties.
+	 * The first argument must be a string literal holding the equation and all remaining
+	 * arguments must be matrices. If all operand dimensions are known, the output dimensions
+	 * are derived from the equation; otherwise only the equation structure is validated and
+	 * the output dimensions are set to unknown (-1).
+	 *
+	 * @param output identifier receiving data type, dimensions, value type and block size
+	 */
+	private void validateEinsum(DataIdentifier output){
+		if(getSecondExpr() == null)
+			raiseValidateError("Einsum: at least one input matrix required", false,
+				LanguageErrorCodes.INVALID_PARAMETERS);
+
+		if(!(getFirstExpr() instanceof StringIdentifier))
+			raiseValidateError("Einsum: first argument has to be equation str", false,
+				LanguageErrorCodes.INVALID_PARAMETERS);
+
+		String equationString = ((StringIdentifier)getFirstExpr()).getValue();
+
+		if (equationString.length() == 0)
+			raiseValidateError("Einsum: equation str too short", false, LanguageErrorCodes.INVALID_PARAMETERS);
+		if (equationString.charAt(0) == '-' || equationString.charAt(0) == ',')
+			raiseValidateError("Einsum: equation str invalid", false, LanguageErrorCodes.INVALID_PARAMETERS);
+
+		Expression[] expressions = getAllExpr();
+		boolean allDimsKnown = true;
+
+		LinkedList<Identifier> matrixBlocks = new LinkedList<>();
+		for (int i = 1; i < expressions.length; i++) {
+			// check every operand, even after one with unknown dims was seen,
+			// so that non-matrix arguments are always rejected at validation time
+			checkMatrixParam(expressions[i]);
+			Identifier operand = expressions[i].getOutput();
+			if(operand.dimsKnown())
+				matrixBlocks.add(operand);
+			else
+				allDimsKnown = false;
+		}
+
+		if(allDimsKnown){
+			var dims = EinsumEquationValidator.validateEinsumEquationAndReturnDimensions(equationString, matrixBlocks);
+
+			output.setDataType(dims.getRight());
+			output.setDimensions(dims.getLeft(), dims.getMiddle());
+		}else{
+			DataType dataType = EinsumEquationValidator.validateEinsumEquationNoDimensions(equationString, _args.length - 1);
+
+			output.setDataType(dataType);
+			output.setDimensions(-1L, -1L);
+		}
+
+		output.setValueType(ValueType.FP64);
+		output.setBlocksize(getSecondExpr().getOutput().getBlocksize());
+	}
+
private void setBinaryOutputProperties(DataIdentifier output) {
DataType dt1 = getFirstExpr().getOutput().getDataType();
DataType dt2 = getSecondExpr().getOutput().getDataType();
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 6827bcc4bf..092fbffe36 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2447,7 +2447,10 @@ public class DMLTranslator
new NaryOp(target.getName(),
target.getDataType(), target.getValueType(),
OpOpN.valueOf(source.getOpCode().name()),
processAllExpressions(source.getAllExpr(), hops));
break;
-
+ case EINSUM:
+ currBuiltinOp = new NaryOp(target.getName(),
target.getDataType(), target.getValueType(),
+
OpOpN.valueOf(source.getOpCode().name()),
processAllExpressions(source.getAllExpr(), hops));
+ break;
case PPRED:
String sop =
((StringIdentifier)source.getThirdExpr()).getValue();
sop = sop.replace("\"", "");
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
new file mode 100644
index 0000000000..6da39e5987
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.einsum;
+
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+
+
+public class EinsumContext {
+ public enum ContractDimensions {
+ CONTRACT_LEFT,
+ CONTRACT_RIGHT,
+ CONTRACT_BOTH,
+ }
+ public Integer outRows;
+ public Integer outCols;
+ public Character outChar1;
+ public Character outChar2;
+ public HashMap<Character, Integer> charToDimensionSize;
+ public String equationString;
+ public boolean[] diagonalInputs;
+ public HashSet<Character> summingChars;
+ public HashSet<Character> contractDimsSet;
+ public ContractDimensions[] contractDims;
+ public ArrayList<String> newEquationStringInputsSplit;
+ public HashMap<Character, ArrayList<Integer>> characterAppearanceIndexes;
// for each character, this tells in which inputs it appears
+
+ private EinsumContext(){};
+ public static EinsumContext getEinsumContext(String eqStr,
ArrayList<MatrixBlock> inputs){
+ EinsumContext res = new EinsumContext();
+
+ res.equationString = eqStr;
+ res.charToDimensionSize = new HashMap<Character, Integer>();
+ HashSet<Character> summingChars = new HashSet<>();
+ ContractDimensions[] contractDims = new
ContractDimensions[inputs.size()];
+ boolean[] diagonalInputs = new boolean[inputs.size()]; // all false by
default
+ HashSet<Character> contractDimsSet = new HashSet<>();
+ HashMap<Character, ArrayList<Integer>> partsCharactersToIndices = new
HashMap<>();
+ ArrayList<String> newEquationStringSplit = new ArrayList<>();
+
+ Iterator<MatrixBlock> it = inputs.iterator();
+ MatrixBlock curArr = it.next();
+ int arrSizeIterator = 0;
+ int arrayIterator = 0;
+ int i;
+ // first iteration through string: collect information on
character-size and what characters are summing characters
+ for (i = 0; true; i++) {
+ char c = eqStr.charAt(i);
+ if(c == '-'){
+ i+=2;
+ break;
+ }
+ if(c == ','){
+ arrayIterator++;
+ curArr = it.next();
+ arrSizeIterator = 0;
+ }
+ else{
+ if (res.charToDimensionSize.containsKey(c)) { // sanity check
if dims match, this is already checked at validation
+ if(arrSizeIterator == 0 && res.charToDimensionSize.get(c)
!= curArr.getNumRows())
+ throw new RuntimeException("Einsum: character "+c+"
has multiple conflicting sizes");
+ else if(arrSizeIterator == 1 &&
res.charToDimensionSize.get(c) != curArr.getNumColumns())
+ throw new RuntimeException("Einsum: character "+c+"
has multiple conflicting sizes");
+ summingChars.add(c);
+ } else {
+ if(arrSizeIterator == 0)
+ res.charToDimensionSize.put(c, curArr.getNumRows());
+ else if(arrSizeIterator == 1)
+ res.charToDimensionSize.put(c, curArr.getNumColumns());
+ }
+
+ arrSizeIterator++;
+ }
+ }
+
+ int numOfRemainingChars = eqStr.length() - i;
+
+ if (numOfRemainingChars > 2)
+ throw new RuntimeException("Einsum: dim > 2 not supported");
+
+ arrSizeIterator = 0;
+
+ Character outChar1 = numOfRemainingChars > 0 ? eqStr.charAt(i) : null;
+ Character outChar2 = numOfRemainingChars > 1 ? eqStr.charAt(i+1) :
null;
+ res.outRows=(numOfRemainingChars > 0 ?
res.charToDimensionSize.get(outChar1) : 1);
+ res.outCols=(numOfRemainingChars > 1 ?
res.charToDimensionSize.get(outChar2) : 1);
+
+ arrayIterator=0;
+ // second iteration through string: collect remaining information
+ for (i = 0; true; i++) {
+ char c = eqStr.charAt(i);
+ if (c == '-') {
+ break;
+ }
+ if (c == ',') {
+ arrayIterator++;
+ arrSizeIterator = 0;
+ continue;
+ }
+ String s = "";
+
+ if(summingChars.contains(c)) {
+ s+=c;
+ if(!partsCharactersToIndices.containsKey(c))
+ partsCharactersToIndices.put(c, new ArrayList<>());
+ partsCharactersToIndices.get(c).add(arrayIterator);
+ }
+ else if((outChar1 != null && c == outChar1) || (outChar2 != null
&& c == outChar2)) {
+ s+=c;
+ }
+ else {
+ contractDimsSet.add(c);
+ contractDims[arrayIterator] = ContractDimensions.CONTRACT_LEFT;
+ }
+
+ if(i + 1 < eqStr.length()) { // process next character together
+ char c2 = eqStr.charAt(i + 1);
+ i++;
+ if (c2 == '-') { newEquationStringSplit.add(s); break;}
+ if (c2 == ',') { arrayIterator++;
newEquationStringSplit.add(s); continue; }
+
+ if (c2 == c){
+ diagonalInputs[arrayIterator] = true;
+ if (contractDims[arrayIterator] ==
ContractDimensions.CONTRACT_LEFT) contractDims[arrayIterator] =
ContractDimensions.CONTRACT_BOTH;
+ }
+ else{
+ if(summingChars.contains(c2)) {
+ s+=c2;
+ if(!partsCharactersToIndices.containsKey(c2))
+ partsCharactersToIndices.put(c2, new
ArrayList<>());
+ partsCharactersToIndices.get(c2).add(arrayIterator);
+ }
+ else if((outChar1 != null && c2 == outChar1) || (outChar2
!= null && c2 == outChar2)) {
+ s+=c2;
+ }
+ else {
+ contractDimsSet.add(c2);
+ contractDims[arrayIterator] =
contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT ?
ContractDimensions.CONTRACT_BOTH : ContractDimensions.CONTRACT_RIGHT;
+ }
+ }
+ }
+ newEquationStringSplit.add(s);
+ arrSizeIterator++;
+ }
+
+ res.contractDims = contractDims;
+ res.contractDimsSet = contractDimsSet;
+ res.diagonalInputs = diagonalInputs;
+ res.summingChars = summingChars;
+ res.outChar1 = outChar1;
+ res.outChar2 = outChar2;
+ res.newEquationStringInputsSplit = newEquationStringSplit;
+ res.characterAppearanceIndexes = partsCharactersToIndices;
+ return res;
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
new file mode 100644
index 0000000000..5643159ef9
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.einsum;
+
+import org.apache.commons.lang3.tuple.Triple;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.parser.Identifier;
+import org.apache.sysds.parser.LanguageException;
+import org.apache.sysds.parser.ParseInfo;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+
+public class EinsumEquationValidator {
+
+ public static <HopOrIdentifier extends ParseInfo> Triple<Long, Long,
Types.DataType> validateEinsumEquationAndReturnDimensions(String
equationString, List<HopOrIdentifier> expressionsOrIdentifiers) throws
LanguageException {
+ String[] eqStringParts = equationString.split("->"); // length 2 if
"...->..." , length 1 if "...->"
+ boolean isResultScalar = eqStringParts.length == 1;
+
+ if(expressionsOrIdentifiers == null)
+ throw new RuntimeException("Einsum: called
validateEinsumAndReturnDimensions with null list");
+
+ HashMap<Character, Long> charToDimensionSize = new HashMap<>();
+ Iterator<HopOrIdentifier> it = expressionsOrIdentifiers.iterator();
+ HopOrIdentifier currArr = it.next();
+ int arrSizeIterator = 0;
+ int numberOfMatrices = 1;
+ for (int i = 0; i < eqStringParts[0].length(); i++) {
+ char c = equationString.charAt(i);
+ if(c==' ') continue;
+ if(c==','){
+ if(!it.hasNext())
+ throw new LanguageException("Einsum: Provided less
operands than specified in equation str");
+ currArr = it.next();
+ arrSizeIterator = 0;
+ numberOfMatrices++;
+ } else{
+ long thisCharDimension = getThisCharDimension(currArr,
arrSizeIterator);
+ if (charToDimensionSize.containsKey(c)){
+ if (charToDimensionSize.get(c) != thisCharDimension)
+ throw new LanguageException("Einsum: Character '" + c
+ "' expected to be dim " + charToDimensionSize.get(c) + ", but found " +
thisCharDimension);
+ }else{
+ charToDimensionSize.put(c, thisCharDimension);
+ }
+ arrSizeIterator++;
+ }
+ }
+ if (expressionsOrIdentifiers.size() - 1 > numberOfMatrices)
+ throw new LanguageException("Einsum: Provided more operands than
specified in equation str");
+
+ if (isResultScalar)
+ return Triple.of(-1l,-1l, Types.DataType.SCALAR);
+
+ int numberOfOutDimensions = 0;
+ Character dim1Char = null;
+ long dim1 = 1;
+ long dim2 = 1;
+ for (int i = 0; i < eqStringParts[1].length(); i++) {
+ char c = eqStringParts[1].charAt(i);
+ if (c == ' ') continue;
+ if (numberOfOutDimensions == 0) {
+ dim1Char = c;
+ dim1 = charToDimensionSize.get(c);
+ } else {
+ if(c==dim1Char) throw new LanguageException("Einsum: output
character "+c+" provided multiple times");
+ dim2 = charToDimensionSize.get(c);
+ }
+ numberOfOutDimensions++;
+ }
+ if (numberOfOutDimensions > 2) {
+ throw new LanguageException("Einsum: output matrices with with no.
dims > 2 not supported");
+ } else {
+ return Triple.of(dim1, dim2, Types.DataType.MATRIX);
+ }
+ }
+
+ public static Types.DataType validateEinsumEquationNoDimensions(String
equationString, int numberOfMatrixInputs) throws LanguageException {
+ String[] eqStringParts = equationString.split("->"); // length 2 if
"...->..." , length 1 if "...->"
+ boolean isResultScalar = eqStringParts.length == 1;
+
+ int numberOfMatrices = 1;
+ for (int i = 0; i < eqStringParts[0].length(); i++) {
+ char c = eqStringParts[0].charAt(i);
+ if(c == ' ') continue;
+ if(c == ',')
+ numberOfMatrices++;
+ }
+ if(numberOfMatrixInputs != numberOfMatrices){
+ throw new LanguageException("Einsum: Invalid number of
parameters, given: " + numberOfMatrixInputs + ", expected: " +
numberOfMatrices);
+ }
+
+ if(isResultScalar){
+ return Types.DataType.SCALAR;
+ }else {
+ int numberOfDimensions = 0;
+ Character dim1Char = null;
+ for (int i = 0; i < eqStringParts[1].length(); i++) {
+ char c = eqStringParts[i].charAt(i);
+ if(c == ' ') continue;
+ numberOfDimensions++;
+ if (numberOfDimensions == 1 && c == dim1Char)
+ throw new LanguageException("Einsum: output character
"+c+" provided multiple times");
+ dim1Char = c;
+ }
+
+ if (numberOfDimensions > 2) {
+ throw new LanguageException("Einsum: output matrices with with
no. dims > 2 not supported");
+ } else {
+ return Types.DataType.MATRIX;
+ }
+ }
+ }
+
+ private static <HopOrIdentifier extends ParseInfo> long
getThisCharDimension(HopOrIdentifier currArr, int arrSizeIterator) {
+ long thisCharDimension;
+ if(currArr instanceof Hop){
+ thisCharDimension = arrSizeIterator == 0 ? ((Hop)
currArr).getDim1() : ((Hop) currArr).getDim2();
+ } else if(currArr instanceof Identifier){
+ thisCharDimension = arrSizeIterator == 0 ? ((Identifier)
currArr).getDim1() : ((Identifier) currArr).getDim2();
+ } else {
+ throw new RuntimeException("validateEinsumAndReturnDimensions
called with expressions that are not Hop or Identifier: "+ currArr == null ?
"null" : currArr.getClass().toString());
+ }
+ return thisCharDimension;
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index fa443378e6..92e11b425d 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -65,6 +65,7 @@ import
org.apache.sysds.runtime.instructions.cp.UaggOuterChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnionCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction;
import
org.apache.sysds.runtime.instructions.cpfile.MatrixIndexingCPFileInstruction;
public class CPInstructionParser extends InstructionParser {
@@ -223,6 +224,9 @@ public class CPInstructionParser extends InstructionParser {
case Union:
return UnionCPInstruction.parseInstruction(str);
+ case EINSUM:
+ return
EinsumCPInstruction.parseInstruction(str);
+
default:
throw new DMLRuntimeException("Invalid CP
Instruction Type: " + cptype );
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
index 6d230a30f0..e7aa1b5fd7 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
@@ -91,10 +91,13 @@ public abstract class BuiltinNaryCPInstruction extends
CPInstruction
return new MatrixBuiltinNaryCPInstruction(
new
SimpleOperator(Multiply.getMultiplyFnObject()), opcode, str, outputOperand,
inputOperands);
}
+ else if( opcode.equals(Opcodes.EINSUM.toString()) ) {
+ return new EinsumCPInstruction(null, opcode, str,
outputOperand, inputOperands);
+ }
else if (OpOpN.EVAL.name().equalsIgnoreCase(opcode)) {
return new EvalNaryCPInstruction(null, opcode, str,
outputOperand, inputOperands);
}
-
+
throw new DMLRuntimeException("Opcode (" + opcode + ") not
recognized in BuiltinMultipleCPInstruction");
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index c99039bb7f..b35ca55dab 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -44,7 +44,7 @@ public abstract class CPInstruction extends Instruction {
Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick,
Local,
MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition,
Compression, DeCompression, SpoofFused,
StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn,
Sql, Prefetch, Broadcast, TrigRemote,
- EvictLineageCache,
+ EvictLineageCache, EINSUM,
NoOp,
Union,
QuantizeCompression
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
new file mode 100644
index 0000000000..c67dd29079
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
@@ -0,0 +1,837 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.lang3.tuple.Triple;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.codegen.SpoofCompiler;
+import org.apache.sysds.hops.codegen.cplan.CNode;
+import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
+import org.apache.sysds.hops.codegen.cplan.CNodeCell;
+import org.apache.sysds.hops.codegen.cplan.CNodeData;
+import org.apache.sysds.hops.codegen.cplan.CNodeRow;
+import org.apache.sysds.runtime.codegen.*;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.einsum.EinsumContext;
+import org.apache.sysds.runtime.functionobjects.*;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
+
+import java.util.*;
+import java.util.function.Predicate;
+
+public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
+ public static boolean FORCE_CELL_TPL = false;
+ protected static final Log LOG =
LogFactory.getLog(EinsumCPInstruction.class.getName());
+ public String eqStr;
+ private final int _numThreads;
+ private final CPOperand[] _in;
+
+ public EinsumCPInstruction(Operator op, String opcode, String istr,
CPOperand out, CPOperand... inputs)
+ {
+ super(op, opcode, istr, out, inputs);
+ _numThreads = OptimizerUtils.getConstrainedNumThreads(-1);
+ _in = inputs;
+ this.eqStr = inputs[0].getName();
+
Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE);
+ }
+
+ @SuppressWarnings("unused")
+ private EinsumContext einc = null;
+
+ @Override
+ public void processInstruction(ExecutionContext ec) {
+ //get input matrices and scalars, incl pinning of matrices
+ ArrayList<MatrixBlock> inputs = new ArrayList<>();
+ for (CPOperand input : _in) {
+ if(input.getDataType()==DataType.MATRIX){
+ MatrixBlock mb =
ec.getMatrixInput(input.getName());
+ if(mb instanceof CompressedMatrixBlock){
+ mb = ((CompressedMatrixBlock)
mb).getUncompressed("Spoof instruction");
+ }
+ inputs.add(mb);
+ }
+ }
+
+ EinsumContext einc = EinsumContext.getEinsumContext(eqStr,
inputs);
+
+ this.einc = einc;
+ String resultString = einc.outChar2 != null ?
String.valueOf(einc.outChar1) + einc.outChar2 : einc.outChar1 != null ?
String.valueOf(einc.outChar1) : "";
+
+ if( LOG.isDebugEnabled() ) LOG.trace("outrows:"+einc.outRows+",
outcols:"+einc.outCols);
+
+ ArrayList<String> inputsChars =
einc.newEquationStringInputsSplit;
+
+ if(LOG.isTraceEnabled())
LOG.trace(String.join(",",einc.newEquationStringInputsSplit));
+
+ contractDimensionsAndComputeDiagonals(einc, inputs);
+
+ //make all vetors col vectors
+ for(int i = 0; i < inputs.size(); i++){
+ if(inputs.get(i) != null && inputsChars.get(i).length()
== 1) EnsureMatrixBlockColumnVector(inputs.get(i));
+ }
+
+ if(LOG.isTraceEnabled()) for(Character c :
einc.characterAppearanceIndexes.keySet()){
+ ArrayList<Integer> a =
einc.characterAppearanceIndexes.get(c);
+ LOG.trace(c+" count= "+a.size());
+ }
+
+ // compute scalar by suming-all matrices:
+ Double scalar = null;
+ for(int i=0;i< inputs.size(); i++){
+ String s = inputsChars.get(i);
+ if(s.equals("")){
+ MatrixBlock mb = inputs.get(i);
+ if (scalar == null) scalar = mb.get(0,0);
+ else scalar*= mb.get(0,0);
+ inputs.set(i,null);
+ inputsChars.set(i,null);
+ }
+ }
+
+ if (scalar != null) {
+ inputsChars.add("");
+ inputs.add(new MatrixBlock(scalar));
+ }
+
+ HashMap<Character, Integer> characterToOccurences = new
HashMap<>();
+ for (Character key :einc.characterAppearanceIndexes.keySet()) {
+ characterToOccurences.put(key,
einc.characterAppearanceIndexes.get(key).size());
+ }
+ for (Character key :einc.charToDimensionSize.keySet()) {
+ if(!characterToOccurences.containsKey(key))
+ characterToOccurences.put(key, 1);
+ }
+
+ ArrayList<EOpNode> eOpNodes = new
ArrayList<>(inputsChars.size());
+ for (int i = 0; i < inputsChars.size(); i++) {
+ if (inputsChars.get(i) == null) continue;
+ EOpNodeData n = new
EOpNodeData(inputsChars.get(i).length() > 0 ? inputsChars.get(i).charAt(0) :
null, inputsChars.get(i).length() > 1 ? inputsChars.get(i).charAt(1) : null, i);
+ eOpNodes.add(n);
+ }
+ Pair<Integer, List<EOpNode> > plan = FORCE_CELL_TPL ? null :
generatePlan(0, eOpNodes, einc.charToDimensionSize, characterToOccurences,
einc.outChar1, einc.outChar2);
+
+
+ ArrayList<MatrixBlock> resMatrices = FORCE_CELL_TPL ? null :
executePlan(plan.getRight(), inputs);
+// ArrayList<MatrixBlock> resMatrices =
executePlan(plan.getRight(), inputs, true);
+
+ if(!FORCE_CELL_TPL && resMatrices.size() == 1){
+ EOpNode resNode = plan.getRight().get(0);
+ if (einc.outChar1 != null && einc.outChar2 != null){
+ if(resNode.c1 == einc.outChar1 && resNode.c2 ==
einc.outChar2){
+ ec.setMatrixOutput(output.getName(),
resMatrices.get(0));
+ }
+ else if(resNode.c1 == einc.outChar2 &&
resNode.c2 == einc.outChar1){
+ ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
+ MatrixBlock resM =
resMatrices.get(0).reorgOperations(transpose, new MatrixBlock(),0,0,0);
+ ec.setMatrixOutput(output.getName(),
resM);
+ }else{
+ if(LOG.isTraceEnabled())
LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2);
+ throw new RuntimeException("Einsum plan
produced different result");
+ }
+ }else if (einc.outChar1 != null){
+ if(resNode.c1 == einc.outChar1 && resNode.c2
== null){
+ ec.setMatrixOutput(output.getName(),
resMatrices.get(0));
+ }else{
+ if(LOG.isTraceEnabled())
LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2);
+ throw new RuntimeException("Einsum plan
produced different result");
+ }
+ }else{
+ if(resNode.c1 == null && resNode.c2 == null){
+ ec.setScalarOutput(output.getName(),
new DoubleObject(resMatrices.get(0).get(0, 0)));;
+ }
+ }
+ }else{
+ // use cell template with loops for remaining
+ ArrayList<MatrixBlock> mbs = resMatrices;
+ ArrayList<String> chars = new ArrayList<>();
+
+ for (int i = 0; i < plan.getRight().size(); i++) {
+ String s;
+ if(plan.getRight().get(i).c1 == null) s = "";
+ else if(plan.getRight().get(i).c2 == null) s =
plan.getRight().get(i).c1.toString();
+ else s = plan.getRight().get(i).c1.toString() +
plan.getRight().get(i).c2;
+ chars.add(s);
+ }
+
+ ArrayList<Character> summingChars = new ArrayList<>();
+ for (Character c :
einc.characterAppearanceIndexes.keySet()) {
+ if (c != einc.outChar1 && c != einc.outChar2)
summingChars.add(c);
+ }
+ if(LOG.isTraceEnabled()) LOG.trace("finishing with cell
tpl: "+String.join(",", chars));
+
+ MatrixBlock res = computeCellSummation(mbs, chars,
resultString, einc.charToDimensionSize, summingChars, einc.outRows,
einc.outCols);
+
+ if (einc.outRows == 1 && einc.outCols == 1)
+ ec.setScalarOutput(output.getName(), new
DoubleObject(res.get(0, 0)));
+ else ec.setMatrixOutput(output.getName(), res);
+ }
+ if(LOG.isTraceEnabled()) LOG.trace("EinsumCPInstruction
Finished");
+
+ releaseMatrixInputs(ec);
+
+ }
+
+ private void contractDimensionsAndComputeDiagonals(EinsumContext einc,
ArrayList<MatrixBlock> inputs) {
+ for(int i = 0; i< einc.contractDims.length; i++){
+ //AggregateOperator agg = new AggregateOperator(0,
KahanPlus.getKahanPlusFnObject(),Types.CorrectionLocationType.LASTCOLUMN);
+ AggregateOperator agg = new AggregateOperator(0,
Plus.getPlusFnObject());
+
+ if(einc.diagonalInputs[i]){
+ ReorgOperator op = new
ReorgOperator(DiagIndex.getDiagIndexFnObject());
+ inputs.set(i, inputs.get(i).reorgOperations(op,
new MatrixBlock(),0,0,0));
+ }
+ if (einc.contractDims[i] == null) continue;
+ switch (einc.contractDims[i]){
+ case CONTRACT_BOTH: {
+ AggregateUnaryOperator aggun = new
AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads);
+ MatrixBlock res = new MatrixBlock(1, 1,
false);
+
inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null);
+ inputs.set(i, res);
+ break;
+ }
+ case CONTRACT_RIGHT: {
+ AggregateUnaryOperator aggun = new
AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads);
+ MatrixBlock res = new
MatrixBlock(inputs.get(i).getNumRows(), 1, false);
+
inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null);
+ inputs.set(i, res);
+ break;
+ }
+ case CONTRACT_LEFT: {
+ AggregateUnaryOperator aggun = new
AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads);
+ MatrixBlock res = new
MatrixBlock(inputs.get(i).getNumColumns(), 1, false);
+
inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null);
+ inputs.set(i, res);
+ break;
+ }
+ default:
+ break;
+ }
+ }
+ }
+
+ private enum EBinaryOperand { // upper case: char has to remain, lower
case: to be summed
+ ////// summations: //////
+ aB_a,// -> B
+ Ba_a, // -> B
+ Ba_aC, // mmult -> BC
+ aB_Ca,
+ Ba_Ca, // -> BC
+ aB_aC, // outer mult, possibly with transposing first -> BC
+ a_a,// dot ->
+
+ ////// elementwisemult and sums, something like ij,ij->i
//////
+ aB_aB,// elemwise and colsum -> B
+ Ba_Ba, // elemwise and rowsum ->B
+ Ba_aB, // elemwise, either colsum or rowsum -> B
+// aB_Ba,
+
+ ////// elementwise, no summations: //////
+ A_A,// v-elemwise -> A
+ AB_AB,// M-M elemwise -> AB
+ AB_BA, // M-M.T elemwise -> AB
+ AB_A, // M-v colwise -> BA!?
+ BA_A, // M-v rowwise -> BA
+ ab_ab,//M-M sum all
+ ab_ba, //M-M.T sum all
+ ////// other //////
+ A_B, // outer mult -> AB
+ A_scalar, // v-scalar
+ AB_scalar, // m-scalar
+ scalar_scalar
+ }
+ private abstract class EOpNode {
+ public Character c1;
+ public Character c2; // nullable
+ public EOpNode(Character c1, Character c2){
+ this.c1 = c1;
+ this.c2 = c2;
+ }
+ }
+ private class EOpNodeBinary extends EOpNode {
+ public EOpNode left;
+ public EOpNode right;
+ public EBinaryOperand operand;
+ public EOpNodeBinary(Character c1, Character c2, EOpNode left,
EOpNode right, EBinaryOperand operand){
+ super(c1,c2);
+ this.left = left;
+ this.right = right;
+ this.operand = operand;
+ }
+ }
+ private class EOpNodeData extends EOpNode {
+ public int matrixIdx;
+ public EOpNodeData(Character c1, Character c2, int matrixIdx){
+ super(c1,c2);
+ this.matrixIdx = matrixIdx;
+ }
+ }
+
+ private Pair<Integer, List<EOpNode> /* ideally with one element */>
generatePlan(int cost, ArrayList<EOpNode> operands, HashMap<Character, Integer>
charToSizeMap, HashMap<Character, Integer> charToOccurences, Character
outChar1, Character outChar2) {
+ Integer minCost = cost;
+ List<EOpNode> minNodes = operands;
+
+ if (operands.size() == 2){
+ boolean swap = (operands.get(0).c2 == null &&
operands.get(1).c2 != null) || operands.get(0).c1 == null;
+ EOpNode n1 = operands.get(!swap ? 0 : 1);
+ EOpNode n2 = operands.get(!swap ? 1 : 0);
+ Triple<Integer, EBinaryOperand, Pair<Character,
Character>> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences,
outChar1, outChar2);
+ if (t != null) {
+ EOpNodeBinary newNode = new
EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2,
t.getMiddle());
+ int thisCost = cost + t.getLeft();
+ return Pair.of(thisCost,
Arrays.asList(newNode));
+ }
+ return Pair.of(cost, operands);
+ }
+ else if (operands.size() == 1){
+ // check for transpose
+ return Pair.of(cost, operands);
+ }
+
+ for(int i = 0; i < operands.size()-1; i++){
+ for (int j = i+1; j < operands.size(); j++){
+ boolean swap = (operands.get(i).c2 == null &&
operands.get(j).c2 != null) || operands.get(i).c1 == null;
+ EOpNode n1 = operands.get(!swap ? i : j);
+ EOpNode n2 = operands.get(!swap ? j : i);
+
+
+ Triple<Integer, EBinaryOperand, Pair<Character,
Character>> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences,
outChar1, outChar2);
+ if (t != null){
+ EOpNodeBinary newNode = new
EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2,
t.getMiddle());
+ int thisCost = cost + t.getLeft();
+
+ if(n1.c1 != null)
charToOccurences.put(n1.c1, charToOccurences.get(n1.c1)-1);
+ if(n1.c2 != null)
charToOccurences.put(n1.c2, charToOccurences.get(n1.c2)-1);
+ if(n2.c1 != null)
charToOccurences.put(n2.c1, charToOccurences.get(n2.c1)-1);
+ if(n2.c2 != null)
charToOccurences.put(n2.c2, charToOccurences.get(n2.c2)-1);
+
+ if(newNode.c1 != null)
charToOccurences.put(newNode.c1, charToOccurences.get(newNode.c1)+1);
+ if(newNode.c2 != null)
charToOccurences.put(newNode.c2, charToOccurences.get(newNode.c2)+1);
+
+ ArrayList<EOpNode> newOperands = new
ArrayList<>(operands.size()-1);
+ for(int z = 0; z < operands.size();
z++){
+ if(z != i && z != j)
newOperands.add(operands.get(z));
+ }
+ newOperands.add(newNode);
+
+ Pair<Integer, List<EOpNode>>
furtherPlan = generatePlan(thisCost, newOperands,charToSizeMap,
charToOccurences, outChar1, outChar2);
+ if(furtherPlan.getRight().size() <
(minNodes.size()) || furtherPlan.getLeft() < minCost){
+ minCost = furtherPlan.getLeft();
+ minNodes =
furtherPlan.getRight();
+ }
+
+ if(n1.c1 != null)
charToOccurences.put(n1.c1, charToOccurences.get(n1.c1)+1);
+ if(n1.c2 != null)
charToOccurences.put(n1.c2, charToOccurences.get(n1.c2)+1);
+ if(n2.c1 != null)
charToOccurences.put(n2.c1, charToOccurences.get(n2.c1)+1);
+ if(n2.c2 != null)
charToOccurences.put(n2.c2, charToOccurences.get(n2.c2)+1);
+ if(newNode.c1 != null)
charToOccurences.put(newNode.c1, charToOccurences.get(newNode.c1)-1);
+ if(newNode.c2 != null)
charToOccurences.put(newNode.c2, charToOccurences.get(newNode.c2)-1);
+ }
+ }
+ }
+
+ return Pair.of(minCost, minNodes);
+ }
+
+ private static Triple<Integer, EBinaryOperand, Pair<Character,
Character>> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap<Character,
Integer> charToSizeMap, HashMap<Character, Integer> charToOccurences, Character
outChar1, Character outChar2){
+ Predicate<Character> cannotBeSummed = (c) ->
+ c == outChar1 || c == outChar2 ||
charToOccurences.get(c) > 2;
+
+ if(n1.c1 == null) {
+ // n2.c1 also has to be null
+ return Triple.of(1, EBinaryOperand.scalar_scalar,
Pair.of(null, null));
+ }
+
+ if(n2.c1 == null) {
+ if(n1.c2 == null)
+ return Triple.of(charToSizeMap.get(n1.c1),
EBinaryOperand.A_scalar, Pair.of(n1.c1, null));
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2),
EBinaryOperand.AB_scalar, Pair.of(n1.c1, n1.c2));
+ }
+
+ if(n1.c1 == n2.c1){
+ if(n1.c2 != null){
+ if ( n1.c2 == n2.c2){
+ if( cannotBeSummed.test(n1.c1)){
+ if(cannotBeSummed.test(n1.c2)){
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2),
EBinaryOperand.AB_AB, Pair.of(n1.c1, n1.c2));
+ }
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2),
EBinaryOperand.Ba_Ba, Pair.of(n1.c1, null));
+ }
+
+ if(cannotBeSummed.test(n1.c2)){
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2),
EBinaryOperand.aB_aB, Pair.of(n1.c2, null));
+ }
+
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2),
EBinaryOperand.ab_ab, Pair.of(null, null));
+
+ }
+
+ else if(n2.c2 == null){
+ if(cannotBeSummed.test(n1.c1)){
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2,
EBinaryOperand.AB_A, Pair.of(n1.c1, n1.c2));
+ }
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2,
EBinaryOperand.aB_a, Pair.of(n1.c2, null)); // in theory (null, n1.c2)
+ }
+ else if(n1.c1 ==outChar1 || n1.c1==outChar2||
charToOccurences.get(n1.c1) > 2){
+ return null;// AB,AC
+ }
+ else {
+ return
Triple.of((charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2))+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2)),
EBinaryOperand.aB_aC, Pair.of(n1.c2, n2.c2)); // or n2.c2, n1.c2
+ }
+ }else{ // n1.c2 = null -> c2.c2 = null
+ if(n1.c1 ==outChar1 || n1.c1==outChar2 ||
charToOccurences.get(n1.c1) > 2){
+ return
Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_A, Pair.of(n1.c1, null));
+ }
+ return Triple.of(charToSizeMap.get(n1.c1),
EBinaryOperand.a_a, Pair.of(null, null));
+ }
+
+
+ }else{ // n1.c1 != n2.c1
+ if(n1.c2 == null) {
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1),
EBinaryOperand.A_B, Pair.of(n1.c1, n2.c1));
+ }
+ else if(n2.c2 == null) { // ab,c
+ if (n1.c2 == n2.c1) {
+ if(cannotBeSummed.test(n1.c2)){
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1),
EBinaryOperand.BA_A, Pair.of(n1.c1, n1.c2));
+ }
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1),
EBinaryOperand.Ba_a, Pair.of(n1.c1, null));
+ }
+ return null; // AB,C
+ }
+ else if (n1.c2 == n2.c1) {
+ if(n1.c1 == n2.c2){ // ab,ba
+ if(cannotBeSummed.test(n1.c1)){
+ if(cannotBeSummed.test(n1.c2)){
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2),
EBinaryOperand.AB_BA, Pair.of(n1.c1, n1.c2));
+ }
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2),
EBinaryOperand.Ba_aB, Pair.of(n1.c1, null));
+ }
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2),
EBinaryOperand.ab_ba, Pair.of(null, null));
+ }
+ if(cannotBeSummed.test(n1.c2)){
+ return null; // AB_B
+ }else{
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2),
EBinaryOperand.Ba_aC, Pair.of(n1.c1, n2.c2));
+// if(n1.c1 ==outChar1 ||
n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){
+// return null; // AB_B
+// }
+// return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2),
EBinaryOperand.Ba_a, Pair.of(n1.c1, null));
+ }
+ }
+ if(n1.c1 == n2.c2) {
+ if(cannotBeSummed.test(n1.c1)){
+ return null; // AB_B
+ }
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1),
EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult
+ }
+ else if (n1.c2 == n2.c2) {
+ if(n1.c2 ==outChar1 || n1.c2==outChar2||
charToOccurences.get(n1.c2) > 2){
+ return null; // BA_CA
+ }else{
+ return
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)
+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1)),
EBinaryOperand.Ba_Ca, Pair.of(n1.c1, n2.c1)); // or n2.c1, n1.c1
+ }
+ }
+ else { // we have something like ab,cd
+ return null;
+ }
+ }
+ }
+
+ private ArrayList<MatrixBlock /* #els = #els of plan */>
executePlan(List<EOpNode> plan, ArrayList<MatrixBlock> inputs){
+ return executePlan(plan, inputs, false);
+ }
+ private ArrayList<MatrixBlock /* #els = #els of plan */>
executePlan(List<EOpNode> plan, ArrayList<MatrixBlock> inputs, boolean codegen)
{
+ ArrayList<MatrixBlock> res = new ArrayList<>(plan.size());
+ for(EOpNode p : plan){
+ if(codegen) res.add(ComputeEOpNodeCodegen(p, inputs));
+ else res.add(ComputeEOpNode(p, inputs));
+ }
+ return res;
+ }
+
+ private MatrixBlock ComputeEOpNode(EOpNode eOpNode,
ArrayList<MatrixBlock> inputs){
+ if(eOpNode instanceof EOpNodeData eOpNodeData){
+ return inputs.get(eOpNodeData.matrixIdx);
+ }
+ EOpNodeBinary bin = (EOpNodeBinary) eOpNode;
+ MatrixBlock left = ComputeEOpNode(bin.left, inputs);
+ MatrixBlock right = ComputeEOpNode(bin.right, inputs);
+
+ AggregateOperator agg = new AggregateOperator(0,
Plus.getPlusFnObject());
+
+ MatrixBlock res;
+ switch (bin.operand){
+ case AB_AB -> {
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
+ }
+ case A_A -> {
+ EnsureMatrixBlockColumnVector(left);
+ EnsureMatrixBlockColumnVector(right);
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
+ }
+ case a_a -> {
+ EnsureMatrixBlockColumnVector(left);
+ EnsureMatrixBlockColumnVector(right);
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
+ AggregateUnaryOperator aggun = new
AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads);
+ res = (MatrixBlock)
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+ }
+ ////////////
+ case Ba_Ba -> {
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
+ AggregateUnaryOperator aggun = new
AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads);
+ res = (MatrixBlock)
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+ }
+ case aB_aB -> {
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
+ AggregateUnaryOperator aggun = new
AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads);
+ res = (MatrixBlock)
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+ EnsureMatrixBlockColumnVector(res);
+ }
+ case ab_ab -> {
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
+ AggregateUnaryOperator aggun = new
AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads);
+ res = (MatrixBlock)
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+ }
+ case ab_ba -> {
+ ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
+ right = right.reorgOperations(transpose, new
MatrixBlock(), 0, 0, 0);
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
+ AggregateUnaryOperator aggun = new
AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads);
+ res = (MatrixBlock)
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+ }
+ case Ba_aB -> {
+ ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
+ right = right.reorgOperations(transpose, new
MatrixBlock(), 0, 0, 0);
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
+ AggregateUnaryOperator aggun = new
AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads);
+ res = (MatrixBlock)
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+ }
+
+ /////////
+ case AB_BA -> {
+ ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
+ right = right.reorgOperations(transpose, new
MatrixBlock(), 0, 0, 0);
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left,
right},new ScalarObject[]{}, new MatrixBlock());
+ }
+ case Ba_aC -> {
+ res = LibMatrixMult.matrixMult(left,right, new
MatrixBlock(), _numThreads);
+ }
+ case aB_Ca -> {
+ res = LibMatrixMult.matrixMult(right,left, new
MatrixBlock(), _numThreads);
+ }
+ case Ba_Ca -> {
+ ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
+ right = right.reorgOperations(transpose, new
MatrixBlock(), 0, 0, 0);
+ res = LibMatrixMult.matrixMult(left,right, new
MatrixBlock(), _numThreads);
+ }
+ case aB_aC -> {
+ ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
+ left = left.reorgOperations(transpose, new
MatrixBlock(), 0, 0, 0);
+ res = LibMatrixMult.matrixMult(left,right, new
MatrixBlock(), _numThreads);
+ }
+ case A_scalar, AB_scalar -> {
+ res = MatrixBlock.naryOperations(new
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new
ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock());
+ }
+ case BA_A -> {
+ EnsureMatrixBlockRowVector(right);
+ res = left.binaryOperations(new
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+ }
+ case Ba_a -> {
+ EnsureMatrixBlockRowVector(right);
+ res = left.binaryOperations(new
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+ AggregateUnaryOperator aggun = new
AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads);
+ res = (MatrixBlock)
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+ }
+
+ case AB_A -> {
+ EnsureMatrixBlockColumnVector(right);
+ res = left.binaryOperations(new
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+ }
+ case aB_a -> {
+ EnsureMatrixBlockColumnVector(right);
+ res = left.binaryOperations(new
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+ AggregateUnaryOperator aggun = new
AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads);
+ res = (MatrixBlock)
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+ EnsureMatrixBlockColumnVector(res);
+ }
+
+ case A_B -> {
+ EnsureMatrixBlockColumnVector(left);
+ EnsureMatrixBlockRowVector(right);
+ res = left.binaryOperations(new
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+ }
+ case scalar_scalar -> {
+ return new
MatrixBlock(left.get(0,0)*right.get(0,0));
+ }
+ default -> {
+ throw new IllegalArgumentException("Unexpected
value: " + bin.operand.toString());
+ }
+
+ }
+ return res;
+ }
+
+ private static MatrixBlock ComputeEOpNodeCodegen(EOpNode eOpNode,
ArrayList<MatrixBlock> inputs){
+ return rComputeEOpNodeCodegen(eOpNode, inputs);
+// throw new NotImplementedException();
+ }
+ private static CNodeData MatrixBlockToCNodeData(MatrixBlock mb, int id){
+ return new CNodeData("ce"+id, id, mb.getNumRows(),
mb.getNumColumns(), DataType.MATRIX);
+ }
+ private static MatrixBlock rComputeEOpNodeCodegen(EOpNode eOpNode,
ArrayList<MatrixBlock> inputs) {
+ if (eOpNode instanceof EOpNodeData eOpNodeData){
+ return inputs.get(eOpNodeData.matrixIdx);
+// return new CNodeData("ce"+eOpNodeData.matrixIdx,
eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(),
inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX);
+ }
+
+ EOpNodeBinary bin = (EOpNodeBinary) eOpNode;
+// CNodeData dataLeft = null;
+// if (bin.left instanceof EOpNodeData eOpNodeData) dataLeft = new
CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx,
inputs.get(eOpNodeData.matrixIdx).getNumRows(),
inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX);
+// CNodeData dataRight = null;
+// if (bin.right instanceof EOpNodeData eOpNodeData) dataRight =
new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx,
inputs.get(eOpNodeData.matrixIdx).getNumRows(),
inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX);
+
+ if(bin.operand == EBinaryOperand.AB_AB){
+ if (bin.right instanceof EOpNodeBinary rBinary &&
rBinary.operand == EBinaryOperand.AB_AB){
+ MatrixBlock left =
rComputeEOpNodeCodegen(bin.left, inputs);
+
+ MatrixBlock right1 =
rComputeEOpNodeCodegen(((EOpNodeBinary) bin.right).left, inputs);
+ MatrixBlock right2 =
rComputeEOpNodeCodegen(((EOpNodeBinary) bin.right).right, inputs);
+
+ CNodeData d0 = MatrixBlockToCNodeData(left, 0);
+ CNodeData d1 = MatrixBlockToCNodeData(right1,
1);
+ CNodeData d2 = MatrixBlockToCNodeData(right2,
2);
+// CNodeNary nary = new CNodeNary(cnodeIn,
CNodeNary.NaryType.)
+ CNodeBinary rightBinary = new CNodeBinary(d1,
d2, CNodeBinary.BinType.VECT_MULT);
+ CNodeBinary cNodeBinary = new CNodeBinary(d0,
rightBinary, CNodeBinary.BinType.VECT_MULT);
+ ArrayList<CNode> cnodeIn = new ArrayList<>();
+ cnodeIn.add(d0);
+ cnodeIn.add(d1);
+ cnodeIn.add(d2);
+
+ CNodeRow cnode = new CNodeRow(cnodeIn,
cNodeBinary);
+
+ cnode.setRowType(SpoofRowwise.RowType.NO_AGG);
+ cnode.renameInputs();
+
+
+ String src = cnode.codegen(false,
SpoofCompiler.GeneratorAPI.JAVA);
+ if( LOG.isTraceEnabled())
LOG.trace(CodegenUtils.printWithLineNumber(src));
+ Class<?> cla =
CodegenUtils.compileClass("codegen." + cnode.getClassname(), src);
+
+ SpoofOperator op =
CodegenUtils.createInstance(cla);
+ MatrixBlock mb = new MatrixBlock();
+
+ ArrayList<ScalarObject> scalars = new
ArrayList<>();
+ ArrayList<MatrixBlock> mbs = new ArrayList<>(3);
+ mbs.add(left);
+ mbs.add(right1);
+ mbs.add(right2);
+ MatrixBlock out = op.execute(mbs, scalars, mb,
6);
+
+ return out;
+ }
+ }
+
+ throw new NotImplementedException();
+ }
+
+
+ private void releaseMatrixInputs(ExecutionContext ec){
+ for (CPOperand input : _in)
+ if(input.getDataType()==DataType.MATRIX)
+ ec.releaseMatrixInput(input.getName()); //todo
release other
+ }
+
+ private static void EnsureMatrixBlockColumnVector(MatrixBlock mb){
+ if(mb.getNumColumns() > 1){
+ mb.setNumRows(mb.getNumColumns());
+ mb.setNumColumns(1);
+ mb.getDenseBlock().resetNoFill(mb.getNumRows(),1);
+ }
+ }
+ private static void EnsureMatrixBlockRowVector(MatrixBlock mb){
+ if(mb.getNumRows() > 1){
+ mb.setNumColumns(mb.getNumRows());
+ mb.setNumRows(1);
+ mb.getDenseBlock().resetNoFill(1,mb.getNumColumns());
+ }
+ }
+
+ private static void indent(StringBuilder sb, int level) {
+ for (int i = 0; i < level; i++) {
+ sb.append(" ");
+ }
+ }
+
+ private MatrixBlock computeCellSummation(ArrayList<MatrixBlock> inputs,
List<String> inputsChars, String resultString,
+
HashMap<Character, Integer>
charToDimensionSizeInt, List<Character> summingChars, int outRows, int outCols){
+ ArrayList<CNode> cnodeIn = new ArrayList<>();
+ cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0,
DataType.SCALAR));
+ CNodeCell cnode = new CNodeCell(cnodeIn, null);
+ StringBuilder sb = new StringBuilder();
+
+ int indent = 2;
+ indent(sb, indent);
+
+ boolean needsSumming = summingChars.stream().anyMatch(x -> x !=
null);
+
+ String itVar0 = cnode.createVarname();
+ String outVar = itVar0;
+ if (needsSumming) {
+ sb.append("double ");
+ sb.append(outVar);
+ sb.append("=0;\n");
+ }
+
+ Iterator<Character> hsIt = summingChars.iterator();
+ while (hsIt.hasNext()) {
+ indent(sb, indent);
+ indent++;
+ Character c = hsIt.next();
+ String itVar = itVar0 + c;
+ sb.append("for(int ");
+ sb.append(itVar);
+ sb.append("=0;");
+ sb.append(itVar);
+ sb.append("<");
+ sb.append(charToDimensionSizeInt.get(c));
+ sb.append(";");
+ sb.append(itVar);
+ sb.append("++){\n");
+ }
+ indent(sb, indent);
+ if (needsSumming) {
+ sb.append(outVar);
+ sb.append("+=");
+ }
+
+ for (int i = 0; i < inputsChars.size(); i++) {
+ if (inputsChars.get(i).length() == 0){
+ sb.append("getValue(b[");
+ sb.append(i);
+ sb.append("],b[");
+ sb.append(i);
+ sb.append("].clen, 0,");
+ }
+
+ else if
(summingChars.contains(inputsChars.get(i).charAt(0))) {
+ sb.append("getValue(b[");
+ sb.append(i);
+ sb.append("],b[");
+ sb.append(i);
+ sb.append("].clen,");
+ sb.append(itVar0);
+ sb.append(inputsChars.get(i).charAt(0));
+ sb.append(",");
+ } else if (resultString.length() >= 1 &&
inputsChars.get(i).charAt(0) == resultString.charAt(0)) {
+ sb.append("getValue(b[");
+ sb.append(i);
+ sb.append("],b[");
+ sb.append(i);
+ sb.append("].clen, rix,");
+ } else if (resultString.length() == 2 &&
inputsChars.get(i).charAt(0) == resultString.charAt(1)) {
+ sb.append("getValue(b[");
+ sb.append(i);
+ sb.append("],b[");
+ sb.append(i);
+ sb.append("].clen, cix,");
+ } else {
+ sb.append("getValue(b[");
+ sb.append(i);
+ sb.append("],b[");
+ sb.append(i);
+ sb.append("].clen, 0,");
+ }
+
+ if (inputsChars.get(i).length() != 2){
+ sb.append("0)");
+ }
+ else if
(summingChars.contains(inputsChars.get(i).charAt(1))) {
+ sb.append(itVar0);
+ sb.append(inputsChars.get(i).charAt(1));
+ sb.append(")");
+ } else if (resultString.length() >= 1
&&inputsChars.get(i).charAt(1) == resultString.charAt(0)) {
+ sb.append("rix)");
+ } else if (resultString.length() == 2 &&
inputsChars.get(i).charAt(1) == resultString.charAt(1)) {
+ sb.append("cix)");
+ } else {
+ sb.append("0)");
+ }
+
+ if (i < inputsChars.size() - 1) {
+ sb.append(" * ");
+ }
+
+ }
+ if (needsSumming) {
+ sb.append(";\n");
+ }
+ indent--;
+ for (int si = 0; si < summingChars.size(); si++) {
+ indent(sb, indent);
+ indent--;
+ sb.append("}\n");
+ }
+ String src = CNodeCell.JAVA_TEMPLATE;//
+ src = src.replace("%TMP%", cnode.createVarname());
+ src = src.replace("%TYPE%", "NO_AGG");
+ src = src.replace("%SPARSE_SAFE%", "false");
+ src = src.replace("%SEQ%", "true");
+ src = src.replace("%AGG_OP_NAME%", "null");
+ if (needsSumming) {
+ src = src.replace("%BODY_dense%", sb.toString());
+ src = src.replace("%OUT%", outVar);
+ } else {
+ src = src.replace("%BODY_dense%", "");
+ src = src.replace("%OUT%", sb.toString());
+ }
+
+ if( LOG.isTraceEnabled()) LOG.trace(src);
+ Class<?> cla = CodegenUtils.compileClass("codegen." +
cnode.getClassname(), src);
+ SpoofOperator op = CodegenUtils.createInstance(cla);
+ MatrixBlock resBlock = new MatrixBlock();
+ resBlock.reset(outRows, outCols);
+ inputs.add(0, resBlock);
+ MatrixBlock out = op.execute(inputs, new ArrayList<>(), new
MatrixBlock(), _numThreads);
+
+ return out;
+ }
+
+ public CPOperand[] getInputs() {
+ return _in;
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java
b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java
new file mode 100644
index 0000000000..dbf9047968
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java
@@ -0,0 +1,362 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.einsum;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.After;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+
+@RunWith(Parameterized.class)
+public class EinsumTest extends AutomatedTestBase
+{
+ final private static List<Config> TEST_CONFIGS = List.of(
+ new Config("ij,jk->ik", List.of(shape(50, 600),
shape(600, 10))), // mm
+ new Config("ji,jk->ik", List.of(shape(600, 5),
shape(600, 10))),
+ new Config("ji,kj->ik", List.of(shape(600, 5),
shape(10, 600))),
+ new Config("ij,kj->ik", List.of(shape(5, 600),
shape(10, 600))),
+
+ new Config("ji,jk->i", List.of(shape(600, 5),
shape(600, 10))),
+ new Config("ij,jk->i", List.of(shape(5, 600),
shape(600, 10))),
+
+ new Config("ji,jk->k", List.of(shape(600, 5),
shape(600, 10))),
+ new Config("ij,jk->k", List.of(shape(5, 600),
shape(600, 10))),
+
+ new Config("ji,jk->j", List.of(shape(600, 5),
shape(600, 10))),
+
+ new Config("ji,ji->ji", List.of(shape(600, 5),
shape(600, 5))), // elemwise mult
+ new Config("ji,ji,ji->ji", List.of(shape(600,
5),shape(600, 5), shape(600, 5)),
+ List.of(0.0001, 0.0005, 0.001)),
+ new Config("ji,ij->ji", List.of(shape(600, 5), shape(5,
600))), // elemwise mult
+
+
+ new Config("ij,i->ij", List.of(shape(100, 50),
shape(100))), // col mult
+ new Config("ji,i->ij", List.of(shape(50, 100),
shape(100))), // row mult
+ new Config("ij,i->i", List.of(shape(100, 50),
shape(100))),
+ new Config("ij,i->j", List.of(shape(100, 50),
shape(100))),
+
+ new Config("i,i->", List.of(shape(50), shape(50))),
+ new Config("i,j->", List.of(shape(50), shape(80))),
+ new Config("i,j->ij", List.of(shape(50),
shape(80))), // outer vect mult
+ new Config("i,j->ji", List.of(shape(50),
shape(80))), // outer vect mult
+
+ new Config("ij->", List.of(shape(100, 50))), // sum
+ new Config("ij->i", List.of(shape(100, 50))), //
sum(1)
+ new Config("ij->j", List.of(shape(100, 50))), //
sum(0)
+ new Config("ij->ji", List.of(shape(100, 50))), // T
+
+ new Config("ab,cd->ba", List.of(shape( 600, 10),
shape(6, 5))),
+ new Config("ab,cd,g->ba", List.of(shape( 600, 10),
shape(6, 5), shape(3))),
+
+ new Config("ab,bc,cd,de->ae", List.of(shape(5, 600),
shape(600, 10), shape(10, 5), shape(5, 4))), // chain of mm
+
+ new Config("ji,jz,zx->ix", List.of(shape(600, 5),
shape( 600, 10), shape(10, 2))),
+ new Config("fx,fg,fz,xg->z", List.of(shape(600, 5),
shape( 600, 10), shape(600, 6), shape(5, 10))),
+ new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times
(cell tpl)
+ List.of(shape(5, 60), shape(5, 30),
shape(5, 10), shape(60, 30), shape(10, 60), shape(10, 30))),
+
+ new Config("i->", List.of(shape(100))),
+ new Config("i->i", List.of(shape(100)))
+ );
+
+ private final int id;
+ private final String einsumStr;
+ //private final List<int[]> shapes;
+ private final File dmlFile;
+ private final File rFile;
+ private final boolean outputScalar;
+
+ public EinsumTest(String einsumStr, List<int[]> shapes, File dmlFile,
File rFile, boolean outputScalar, int id){
+ this.id = id;
+ this.einsumStr = einsumStr;
+ //this.shapes = shapes;
+ this.dmlFile = dmlFile;
+ this.rFile = rFile;
+ this.outputScalar = outputScalar;
+ }
+
+ @Parameterized.Parameters(name = "{index}: einsum={0}")
+ public static Collection<Object[]> data() throws IOException {
+ List<Object[]> parameters = new ArrayList<>();
+
+ int counter = 1;
+
+ for (Config config : TEST_CONFIGS) {
+ //List<File> files = new ArrayList<>();
+ String fullDMLScriptName = "SystemDS_einsum_test" +
counter;
+
+ File dmlFile = File.createTempFile(fullDMLScriptName,
".dml");
+ dmlFile.deleteOnExit();
+
+ boolean outputScalar =
config.einsumStr.trim().endsWith("->");
+
+ StringBuilder sb = createDmlFile(config, outputScalar);
+
+ Files.writeString(dmlFile.toPath(), sb.toString());
+
+ File rFile = File.createTempFile(fullDMLScriptName,
".R");
+ rFile.deleteOnExit();
+
+ sb = createRFile(config, outputScalar);
+
+ Files.writeString(rFile.toPath(), sb.toString());
+
+ parameters.add(new Object[]{config.einsumStr,
config.shapes, dmlFile, rFile, outputScalar, counter});
+
+ counter++;
+ }
+
+ return parameters;
+ }
+
+ private static StringBuilder createDmlFile(Config config, boolean
outputScalar) {
+ StringBuilder sb = new StringBuilder();
+
+ for (int i = 0; i < config.shapes.size(); i++) {
+ int[] dims = config.shapes.get(i);
+
+ double factor = config.factors != null ?
config.factors.get(i) : 0.0001;
+ sb.append("A");
+ sb.append(i);
+
+ if (dims.length == 1) { // A1 = seq(1,1000) * 0.0001
+ sb.append(" = seq(1,");
+ sb.append(dims[0]);
+ sb.append(") * ");
+ sb.append(factor);
+ } else { // A0 = matrix(seq(1,50000), 1000, 50) * 0.0001
+ sb.append(" = matrix(seq(1, ");
+ sb.append(dims[0]*dims[1]);
+ sb.append("), ");
+ sb.append(dims[0]);
+ sb.append(", ");
+ sb.append(dims[1]);
+
+ sb.append(") * ");
+ sb.append(factor);
+ }
+ sb.append("\n");
+ }
+ sb.append("\n");
+
+ sb.append("R = einsum(\"");
+ sb.append(config.einsumStr);
+ sb.append("\", ");
+
+ for (int i = 0; i < config.shapes.size() - 1; i++) {
+ sb.append("A");
+ sb.append(i);
+ sb.append(", ");
+ }
+ sb.append("A");
+ sb.append(config.shapes.size() - 1);
+ sb.append(")");
+
+ sb.append("\n\n");
+ sb.append("write(R, $1)\n");
+ return sb;
+ }
+
+ private static StringBuilder createRFile(Config config, boolean
outputScalar) {
+ StringBuilder sb = new StringBuilder();
+ sb.append("args<-commandArgs(TRUE)\n");
+ sb.append("options(digits=22)\n");
+ sb.append("library(\"Matrix\")\n");
+ sb.append("library(\"matrixStats\")\n");
+ sb.append("library(\"einsum\")\n\n");
+
+
+ for (int i = 0; i < config.shapes.size(); i++) {
+ int[] dims = config.shapes.get(i);
+
+ double factor = config.factors != null ?
config.factors.get(i) : 0.0001;
+ sb.append("A");
+ sb.append(i);
+
+ if (dims.length == 1) { // A1 = seq(1,1000) * 0.0001
+ sb.append(" = seq(1,");
+ sb.append(dims[0]);
+ sb.append(") * ");
+ sb.append(factor);
+ } else { // A0 = matrix(seq(1,50000), 1000, 50,
byrow=TRUE) * 0.0001
+ sb.append(" = matrix(seq(1, ");
+ sb.append(dims[0]*dims[1]);
+ sb.append("), ");
+ sb.append(dims[0]);
+ sb.append(", ");
+ sb.append(dims[1]);
+
+ sb.append(", byrow=TRUE) * ");
+ sb.append(factor);
+ }
+ sb.append("\n");
+ }
+ sb.append("\n");
+
+ sb.append("R = einsum(\"");
+ sb.append(config.einsumStr);
+ sb.append("\", ");
+
+ for (int i = 0; i < config.shapes.size()-1; i++) {
+ sb.append("A");
+ sb.append(i);
+ sb.append(", ");
+ }
+ sb.append("A");
+ sb.append(config.shapes.size()-1);
+ sb.append(")");
+
+ sb.append("\n\n");
+ if(outputScalar){
+ sb.append("write(R, paste(args[2], \"S\",
sep=\"\"))\n");
+ }else{
+ sb.append("writeMM(as(R, \"CsparseMatrix\"),
paste(args[2], \"S\", sep=\"\"))\n");
+ }
+ return sb;
+ }
+
+ @Test
+ public void testEinsumWithFiles() {
+ System.out.println("Testing einsum: " + this.einsumStr);
+ testCodegenIntegration(TEST_NAME_EINSUM+this.id);
+ }
+ @After
+ public void cleanUp() {
+ if (this.dmlFile.exists()) {
+ boolean deleted = this.dmlFile.delete();
+ if (!deleted) {
+ System.err.println("Failed to delete temp file:
" + this.dmlFile.getAbsolutePath());
+ }
+ }
+ if (this.rFile.exists()) {
+ boolean deleted = this.rFile.delete();
+ if (!deleted) {
+ System.err.println("Failed to delete temp file:
" + this.rFile.getAbsolutePath());
+ }
+ }
+ }
+
+ private static class Config {
+ public List<Double> factors;
+ String einsumStr;
+ List<int[]> shapes;
+
+ Config(String einsum, List<int[]> shapes) {
+ this.einsumStr = einsum;
+ this.shapes = shapes;
+ this.factors = null;
+ }
+ Config(String einsum, List<int[]> shapes, List<Double> factors)
{
+ this.einsumStr = einsum;
+ this.shapes = shapes;
+ this.factors = factors;
+ }
+ }
+
+ private static int[] shape(int... dims) {
+ return dims;
+ }
+ private static final Log LOG =
LogFactory.getLog(EinsumTest.class.getName());
+
+ private static final String TEST_NAME_EINSUM = "einsum";
+ private static final String TEST_DIR = "functions/einsum/";
+ private static final String TEST_CLASS_DIR = TEST_DIR +
EinsumTest.class.getSimpleName() + "/";
+ private final static String TEST_CONF = "SystemDS-config-codegen.xml";
+ private final static File TEST_CONF_FILE = new File(SCRIPT_DIR +
TEST_DIR, TEST_CONF);
+
+ private static double eps = Math.pow(10, -10);
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ for(int i = 1; i<= TEST_CONFIGS.size(); i++)
+ addTestConfiguration( TEST_NAME_EINSUM+i, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] {
String.valueOf(i) }) );
+ }
+
+ private void testCodegenIntegration( String testname)
+ {
+ boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+ ExecMode platformOld = setExecMode(ExecType.CP);
+
+ String testnameDml = this.dmlFile.getAbsolutePath();
+ String testnameR = this.rFile.getAbsolutePath();
+ try
+ {
+ TestConfiguration config =
getTestConfiguration(testname);
+ loadTestConfiguration(config);
+
+ fullDMLScriptName = testnameDml;
+ programArgs = new String[]{"-stats", "-explain",
"-args", output("S") };
+ fullRScriptName = testnameR;
+ rCmd = getRCmd(inputDir(), expectedDir());
+
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false;
+
+ runTest(true, false, null, -1);
+ runRScript(true);
+
+ if(outputScalar){
+ HashMap<CellIndex, Double> dmlfile =
readDMLScalarFromOutputDir("S");
+ HashMap<CellIndex, Double> rfile =
readRScalarFromExpectedDir("S");
+ TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
+ }else {
+ //compare matrices
+ HashMap<CellIndex, Double> dmlfile =
readDMLMatrixFromOutputDir("S");
+ HashMap<CellIndex, Double> rfile =
readRMatrixFromExpectedDir("S");
+ TestUtils.compareMatrices(dmlfile, rfile, eps,
"Stat-DML", "Stat-R");
+ }
+ }
+ finally {
+ resetExecMode(platformOld);
+ OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+ OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true;
+ OptimizerUtils.ALLOW_OPERATOR_FUSION = true;
+ }
+ }
+
+
+ /**
+ * Override default configuration with custom test configuration to
ensure
+ * scratch space and local temporary directory locations are also
updated.
+ */
+ @Override
+ protected File getConfigTemplateFile() {
+ // Instrumentation in this test's output log to show custom
configuration file used for template.
+ LOG.debug("This test case overrides default configuration with
" + TEST_CONF_FILE.getPath());
+ return TEST_CONF_FILE;
+ }
+}
diff --git a/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml
b/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml
new file mode 100644
index 0000000000..626b31ebd7
--- /dev/null
+++ b/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml
@@ -0,0 +1,31 @@
+<!--
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+-->
+
+<root>
+ <sysds.localtmpdir>/tmp/systemds</sysds.localtmpdir>
+ <sysds.scratch>scratch_space</sysds.scratch>
+ <sysds.optlevel>2</sysds.optlevel>
+ <sysds.codegen.plancache>true</sysds.codegen.plancache>
+ <sysds.codegen.literals>1</sysds.codegen.literals>
+
+ <!-- Artificially chosen number of threads for the local Spark instance -->
+ <sysds.local.spark.number.threads>16</sysds.local.spark.number.threads>
+
+ <sysds.codegen.api>auto</sysds.codegen.api>
+</root>
\ No newline at end of file
diff --git a/src/test/scripts/installDependencies.R
b/src/test/scripts/installDependencies.R
index af89f2b936..60642fa8ed 100644
--- a/src/test/scripts/installDependencies.R
+++ b/src/test/scripts/installDependencies.R
@@ -64,6 +64,7 @@ custom_install("unbalanced");
custom_install("naivebayes");
custom_install("BiocManager");
custom_install("mltools");
+custom_install("einsum");
BiocManager::install("rhdf5");
print("Installation Done")