This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 3741895624 [SYSTEMDS-3909] New einsum expression evaluation framework
3741895624 is described below

commit 3741895624c5650aeb08808fad212ce0e7f9e853
Author: Hubert Krawczyk <[email protected]>
AuthorDate: Thu Aug 21 14:35:15 2025 +0200

    [SYSTEMDS-3909] New einsum expression evaluation framework
    
    Closes #2312.
    Closes #2265.
---
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 .../org/apache/sysds/common/InstructionType.java   |   1 +
 src/main/java/org/apache/sysds/common/Opcodes.java |   1 +
 src/main/java/org/apache/sysds/common/Types.java   |   4 +-
 src/main/java/org/apache/sysds/hops/NaryOp.java    |   9 +
 .../apache/sysds/hops/codegen/cplan/CNodeCell.java |   2 +-
 .../apache/sysds/hops/codegen/cplan/CNodeData.java |   8 +
 src/main/java/org/apache/sysds/lops/Nary.java      |   1 +
 .../sysds/parser/BuiltinFunctionExpression.java    |  53 +-
 .../org/apache/sysds/parser/DMLTranslator.java     |   5 +-
 .../apache/sysds/runtime/einsum/EinsumContext.java | 177 +++++
 .../runtime/einsum/EinsumEquationValidator.java    | 144 ++++
 .../runtime/instructions/CPInstructionParser.java  |   4 +
 .../instructions/cp/BuiltinNaryCPInstruction.java  |   5 +-
 .../runtime/instructions/cp/CPInstruction.java     |   2 +-
 .../instructions/cp/EinsumCPInstruction.java       | 837 +++++++++++++++++++++
 .../sysds/test/functions/einsum/EinsumTest.java    | 362 +++++++++
 .../functions/einsum/SystemDS-config-codegen.xml   |  31 +
 src/test/scripts/installDependencies.R             |   1 +
 19 files changed, 1640 insertions(+), 8 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index fe75aec6a0..5fe1721cc2 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -403,6 +403,7 @@ public enum Builtins {
        UNDER_SAMPLING("underSampling", true),
        UNIQUE("unique", false, true),
        UPPER_TRI("upper.tri", false, true),
+       EINSUM("einsum", false, false),
        XDUMMY1("xdummy1", true), //error handling test
        XDUMMY2("xdummy2", true); //error handling test
 
diff --git a/src/main/java/org/apache/sysds/common/InstructionType.java 
b/src/main/java/org/apache/sysds/common/InstructionType.java
index 1980dd7984..29148f03e9 100644
--- a/src/main/java/org/apache/sysds/common/InstructionType.java
+++ b/src/main/java/org/apache/sysds/common/InstructionType.java
@@ -62,6 +62,7 @@ public enum InstructionType {
        PMMJ,
        MMChain,
        Union,
+       EINSUM,
 
        //SP Types
        MAPMM,
diff --git a/src/main/java/org/apache/sysds/common/Opcodes.java 
b/src/main/java/org/apache/sysds/common/Opcodes.java
index 64a6c7dd27..6cd1561128 100644
--- a/src/main/java/org/apache/sysds/common/Opcodes.java
+++ b/src/main/java/org/apache/sysds/common/Opcodes.java
@@ -174,6 +174,7 @@ public enum Opcodes {
        RBIND("rbind", InstructionType.BuiltinNary),
        EVAL("eval", InstructionType.BuiltinNary),
        LIST("list", InstructionType.BuiltinNary),
+       EINSUM("einsum", InstructionType.BuiltinNary),
 
        //Parametrized builtin functions
        AUTODIFF("autoDiff", InstructionType.ParameterizedBuiltin),
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index cc7f6eb377..09a0f8effd 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -767,8 +767,8 @@ public interface Types {
        
        /** Operations that require a variable number of operands*/
        public enum OpOpN {
-               PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST;
-               
+               PRINTF, CBIND, RBIND, MIN, MAX, PLUS, MULT, EVAL, LIST, EINSUM;
+
                public boolean isCellOp() {
                        return this == MIN || this == MAX || this == PLUS || 
this == MULT;
                }
diff --git a/src/main/java/org/apache/sysds/hops/NaryOp.java 
b/src/main/java/org/apache/sysds/hops/NaryOp.java
index 1659b0dbc5..6962beadcb 100644
--- a/src/main/java/org/apache/sysds/hops/NaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/NaryOp.java
@@ -26,6 +26,7 @@ import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.lops.Nary;
+import org.apache.sysds.runtime.einsum.EinsumEquationValidator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 
@@ -235,6 +236,14 @@ public class NaryOp extends Hop {
                                setDim1(getInput().size());
                                setDim2(1);
                                break;
+                       case EINSUM:
+                               String equationString = ((LiteralOp) 
_input.get(0)).getStringValue();
+                               var dims = 
EinsumEquationValidator.validateEinsumEquationAndReturnDimensions(equationString,
 this.getInput().subList(1, this.getInput().size()));
+
+                               setDim1(dims.getLeft());
+                               setDim2(dims.getMiddle());
+                               setDataType(dims.getRight());
+                               break;
                        case PRINTF:
                        case EVAL:
                                //do nothing:
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
index 3d2c19ef4c..2482ac77e2 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeCell.java
@@ -31,7 +31,7 @@ import org.apache.sysds.runtime.util.UtilFunctions;
 
 public class CNodeCell extends CNodeTpl 
 {
-       protected static final String JAVA_TEMPLATE = 
+       public static final String JAVA_TEMPLATE =
                  "package codegen;\n"
                + "import 
org.apache.sysds.runtime.codegen.LibSpoofPrimitives;\n"
                + "import org.apache.sysds.runtime.codegen.SpoofCellwise;\n"
diff --git a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java 
b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java
index 9292972874..b90789df5e 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeData.java
@@ -56,6 +56,14 @@ public class CNodeData extends CNode
                _cols = node.getNumCols();
                _dataType = node.getDataType();
        }
+
+       public CNodeData(String name, long hopID, long rows, long cols, 
DataType dataType) {
+               _name = name;
+               _hopID = hopID;
+               _rows = rows;
+               _cols = cols;
+               _dataType = dataType;
+       }
        
        @Override
        public String getVarname() {
diff --git a/src/main/java/org/apache/sysds/lops/Nary.java 
b/src/main/java/org/apache/sysds/lops/Nary.java
index e073bc6881..e5382ba033 100644
--- a/src/main/java/org/apache/sysds/lops/Nary.java
+++ b/src/main/java/org/apache/sysds/lops/Nary.java
@@ -111,6 +111,7 @@ public class Nary extends Lop {
                        case RBIND:
                        case EVAL:
                        case LIST:
+                       case EINSUM:
                                return operationType.name().toLowerCase();
                        case MIN:
                        case MAX:
diff --git 
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index 540b522a8b..28f6949f72 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.LinkedList;
 
 import org.antlr.v4.runtime.ParserRuleContext;
 import org.apache.commons.lang3.ArrayUtils;
@@ -35,6 +36,7 @@ import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.parser.LanguageException.LanguageErrorCodes;
+import org.apache.sysds.runtime.einsum.EinsumEquationValidator;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.DnnUtils;
 import org.apache.sysds.runtime.util.UtilFunctions;
@@ -751,7 +753,9 @@ public class BuiltinFunctionExpression extends 
DataIdentifier {
                        else
                                raiseValidateError("Compress/DeCompress 
instruction not allowed in dml script");
                        break;
-                                                       
+               case EINSUM:
+                       validateEinsum((DataIdentifier) getOutputs()[0]);
+                       break;
                default: //always unconditional
                        raiseValidateError("Unknown Builtin Function opcode: " 
+ _opcode, false);
                }
@@ -2063,7 +2067,9 @@ public class BuiltinFunctionExpression extends 
DataIdentifier {
                        output.setValueType(ValueType.INT64);
                        output.setNnz(id.getDim2());
                        break;
-
+               case EINSUM:
+                       validateEinsum(output);
+                       break;
                default:
                        if( isMathFunction() ) {
                                checkMathFunctionParam();
@@ -2096,6 +2102,49 @@ public class BuiltinFunctionExpression extends 
DataIdentifier {
                }
        }
 
+       private void validateEinsum(DataIdentifier output){
+               if(getSecondExpr() == null)
+                       raiseValidateError("Einsum: at least one input matrix 
required", false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
+
+               if(!(getFirstExpr() instanceof StringIdentifier))
+                       raiseValidateError("Einsum: first argument has to be 
equation str", false,
+                                       LanguageErrorCodes.INVALID_PARAMETERS);
+
+               String equationString = 
((StringIdentifier)getFirstExpr()).getValue();
+
+               if (equationString.length() == 0) raiseValidateError("Einsum: 
equation str too short", false, LanguageErrorCodes.INVALID_PARAMETERS);
+               if (equationString.charAt(0) == '-' || equationString.charAt(0) 
== ',') raiseValidateError("Einsum: equation str invalid", false, 
LanguageErrorCodes.INVALID_PARAMETERS);
+
+               Expression[] expressions = getAllExpr();
+               boolean allDimsKnown = true;
+
+               LinkedList<Identifier> matrixBlocks = new LinkedList<>();
+               for (int i=1;i<expressions.length; i++){
+                       checkMatrixParam(expressions[i]);
+                       if(!(expressions[i]).getOutput().dimsKnown()){
+                               allDimsKnown = false;
+                               break;
+                       }
+                       matrixBlocks.add((expressions[i].getOutput()));
+               }
+
+               if(allDimsKnown){
+                       var dims = 
EinsumEquationValidator.validateEinsumEquationAndReturnDimensions(equationString,
 matrixBlocks);
+
+                       output.setDataType(dims.getRight());
+                       output.setDimensions(dims.getLeft(), dims.getMiddle());
+               }else{
+                       DataType dataType = 
EinsumEquationValidator.validateEinsumEquationNoDimensions(equationString, 
_args.length - 1);
+
+                       output.setDataType(dataType);
+                       output.setDimensions(-1l, -1l);
+               }
+
+               output.setValueType(ValueType.FP64);
+               output.setBlocksize(getSecondExpr().getOutput().getBlocksize());
+       }
+
        private void setBinaryOutputProperties(DataIdentifier output) {
                DataType dt1 = getFirstExpr().getOutput().getDataType();
                DataType dt2 = getSecondExpr().getOutput().getDataType();
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index 6827bcc4bf..092fbffe36 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2447,7 +2447,10 @@ public class DMLTranslator
                                new NaryOp(target.getName(), 
target.getDataType(), target.getValueType(),
                                        
OpOpN.valueOf(source.getOpCode().name()), 
processAllExpressions(source.getAllExpr(), hops));
                        break;
-
+               case EINSUM:
+                       currBuiltinOp = new NaryOp(target.getName(), 
target.getDataType(), target.getValueType(),
+                                       
OpOpN.valueOf(source.getOpCode().name()), 
processAllExpressions(source.getAllExpr(), hops));
+                       break;
                case PPRED:
                        String sop = 
((StringIdentifier)source.getThirdExpr()).getValue();
                        sop = sop.replace("\"", "");
diff --git a/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java 
b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
new file mode 100644
index 0000000000..6da39e5987
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumContext.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.einsum;
+
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+
+
+public class EinsumContext {
+    public enum ContractDimensions {
+        CONTRACT_LEFT,
+        CONTRACT_RIGHT,
+        CONTRACT_BOTH,
+    }
+    public Integer outRows;
+    public Integer outCols;
+    public Character outChar1;
+    public Character outChar2;
+    public HashMap<Character, Integer> charToDimensionSize;
+    public String equationString;
+    public boolean[] diagonalInputs;
+    public HashSet<Character> summingChars;
+    public HashSet<Character> contractDimsSet;
+    public ContractDimensions[] contractDims;
+    public ArrayList<String> newEquationStringInputsSplit;
+    public HashMap<Character, ArrayList<Integer>> characterAppearanceIndexes; 
// for each character, this tells in which inputs it appears
+
+    private EinsumContext(){};
+    public static EinsumContext getEinsumContext(String eqStr, 
ArrayList<MatrixBlock> inputs){
+        EinsumContext res = new EinsumContext();
+
+        res.equationString = eqStr;
+        res.charToDimensionSize = new HashMap<Character, Integer>();
+        HashSet<Character> summingChars = new HashSet<>();
+        ContractDimensions[] contractDims = new 
ContractDimensions[inputs.size()];
+        boolean[] diagonalInputs = new boolean[inputs.size()]; // all false by 
default
+        HashSet<Character> contractDimsSet = new HashSet<>();
+        HashMap<Character, ArrayList<Integer>> partsCharactersToIndices = new 
HashMap<>();
+        ArrayList<String> newEquationStringSplit = new ArrayList<>();
+
+        Iterator<MatrixBlock> it = inputs.iterator();
+        MatrixBlock curArr = it.next();
+        int arrSizeIterator = 0;
+        int arrayIterator = 0;
+        int i;
+        // first iteration through string: collect information on 
character-size and what characters are summing characters
+        for (i = 0; true; i++) {
+            char c = eqStr.charAt(i);
+            if(c == '-'){
+                i+=2;
+                break;
+            }
+            if(c == ','){
+                arrayIterator++;
+                curArr = it.next();
+                arrSizeIterator = 0;
+            }
+            else{
+                if (res.charToDimensionSize.containsKey(c)) { // sanity check 
if dims match, this is already checked at validation
+                    if(arrSizeIterator == 0 && res.charToDimensionSize.get(c) 
!= curArr.getNumRows())
+                        throw new RuntimeException("Einsum: character "+c+" 
has multiple conflicting sizes");
+                    else if(arrSizeIterator == 1 && 
res.charToDimensionSize.get(c) != curArr.getNumColumns())
+                        throw new RuntimeException("Einsum: character "+c+" 
has multiple conflicting sizes");
+                    summingChars.add(c);
+                } else {
+                    if(arrSizeIterator == 0)
+                        res.charToDimensionSize.put(c, curArr.getNumRows());
+                    else if(arrSizeIterator == 1)
+                        res.charToDimensionSize.put(c, curArr.getNumColumns());
+                }
+
+                arrSizeIterator++;
+            }
+        }
+
+        int numOfRemainingChars = eqStr.length() - i;
+
+        if (numOfRemainingChars > 2)
+            throw new RuntimeException("Einsum: dim > 2 not supported");
+
+        arrSizeIterator = 0;
+
+        Character outChar1 = numOfRemainingChars > 0 ? eqStr.charAt(i) : null;
+        Character outChar2 = numOfRemainingChars > 1 ? eqStr.charAt(i+1) : 
null;
+        res.outRows=(numOfRemainingChars > 0 ? 
res.charToDimensionSize.get(outChar1) : 1);
+        res.outCols=(numOfRemainingChars > 1 ? 
res.charToDimensionSize.get(outChar2) : 1);
+
+        arrayIterator=0;
+        // second iteration through string: collect remaining information
+        for (i = 0; true; i++) {
+            char c = eqStr.charAt(i);
+            if (c == '-') {
+                break;
+            }
+            if (c == ',') {
+                arrayIterator++;
+                arrSizeIterator = 0;
+                continue;
+            }
+            String s = "";
+
+            if(summingChars.contains(c)) {
+                s+=c;
+                if(!partsCharactersToIndices.containsKey(c))
+                    partsCharactersToIndices.put(c, new ArrayList<>());
+                partsCharactersToIndices.get(c).add(arrayIterator);
+            }
+            else if((outChar1 != null && c == outChar1) || (outChar2 != null 
&& c == outChar2)) {
+                s+=c;
+            }
+            else {
+                contractDimsSet.add(c);
+                contractDims[arrayIterator] = ContractDimensions.CONTRACT_LEFT;
+            }
+
+            if(i + 1 < eqStr.length()) { // process next character together
+                char c2 = eqStr.charAt(i + 1);
+                i++;
+                if (c2 == '-') { newEquationStringSplit.add(s); break;}
+                if (c2 == ',') { arrayIterator++; 
newEquationStringSplit.add(s); continue; }
+
+                if (c2 == c){
+                    diagonalInputs[arrayIterator] = true;
+                    if (contractDims[arrayIterator] == 
ContractDimensions.CONTRACT_LEFT) contractDims[arrayIterator] = 
ContractDimensions.CONTRACT_BOTH;
+                }
+                else{
+                    if(summingChars.contains(c2)) {
+                        s+=c2;
+                        if(!partsCharactersToIndices.containsKey(c2))
+                            partsCharactersToIndices.put(c2, new 
ArrayList<>());
+                        partsCharactersToIndices.get(c2).add(arrayIterator);
+                    }
+                    else if((outChar1 != null && c2 == outChar1) || (outChar2 
!= null && c2 == outChar2)) {
+                        s+=c2;
+                    }
+                    else {
+                        contractDimsSet.add(c2);
+                        contractDims[arrayIterator] = 
contractDims[arrayIterator] == ContractDimensions.CONTRACT_LEFT ? 
ContractDimensions.CONTRACT_BOTH : ContractDimensions.CONTRACT_RIGHT;
+                    }
+                }
+            }
+            newEquationStringSplit.add(s);
+            arrSizeIterator++;
+        }
+
+        res.contractDims = contractDims;
+        res.contractDimsSet = contractDimsSet;
+        res.diagonalInputs = diagonalInputs;
+        res.summingChars = summingChars;
+        res.outChar1 = outChar1;
+        res.outChar2 = outChar2;
+        res.newEquationStringInputsSplit = newEquationStringSplit;
+        res.characterAppearanceIndexes = partsCharactersToIndices;
+        return res;
+    }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java 
b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
new file mode 100644
index 0000000000..5643159ef9
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/einsum/EinsumEquationValidator.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.einsum;
+
+import org.apache.commons.lang3.tuple.Triple;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.parser.Identifier;
+import org.apache.sysds.parser.LanguageException;
+import org.apache.sysds.parser.ParseInfo;
+
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+
+public class EinsumEquationValidator {
+
+    public static <HopOrIdentifier extends ParseInfo> Triple<Long, Long, 
Types.DataType> validateEinsumEquationAndReturnDimensions(String 
equationString, List<HopOrIdentifier> expressionsOrIdentifiers) throws 
LanguageException {
+        String[] eqStringParts = equationString.split("->"); // length 2 if 
"...->..." , length 1 if "...->"
+        boolean isResultScalar = eqStringParts.length == 1;
+
+        if(expressionsOrIdentifiers == null)
+            throw new RuntimeException("Einsum: called 
validateEinsumAndReturnDimensions with null list");
+
+        HashMap<Character, Long> charToDimensionSize = new HashMap<>();
+        Iterator<HopOrIdentifier> it = expressionsOrIdentifiers.iterator();
+        HopOrIdentifier currArr = it.next();
+        int arrSizeIterator = 0;
+        int numberOfMatrices = 1;
+        for (int i = 0; i < eqStringParts[0].length(); i++) {
+            char c = equationString.charAt(i);
+            if(c==' ') continue;
+            if(c==','){
+                if(!it.hasNext())
+                    throw new LanguageException("Einsum: Provided less 
operands than specified in equation str");
+                currArr = it.next();
+                arrSizeIterator = 0;
+                numberOfMatrices++;
+            } else{
+                long thisCharDimension = getThisCharDimension(currArr, 
arrSizeIterator);
+                if (charToDimensionSize.containsKey(c)){
+                    if (charToDimensionSize.get(c) != thisCharDimension)
+                        throw new LanguageException("Einsum: Character '" + c 
+ "' expected to be dim " + charToDimensionSize.get(c) + ", but found " + 
thisCharDimension);
+                }else{
+                    charToDimensionSize.put(c, thisCharDimension);
+                }
+                arrSizeIterator++;
+            }
+        }
+        if (expressionsOrIdentifiers.size() - 1 > numberOfMatrices)
+            throw new LanguageException("Einsum: Provided more operands than 
specified in equation str");
+
+        if (isResultScalar)
+            return Triple.of(-1l,-1l, Types.DataType.SCALAR);
+
+        int numberOfOutDimensions = 0;
+        Character dim1Char = null;
+        long dim1 = 1;
+        long dim2 = 1;
+        for (int i = 0; i < eqStringParts[1].length(); i++) {
+            char c = eqStringParts[1].charAt(i);
+            if (c == ' ') continue;
+            if (numberOfOutDimensions == 0) {
+                dim1Char = c;
+                dim1 = charToDimensionSize.get(c);
+            } else {
+                if(c==dim1Char) throw new LanguageException("Einsum: output 
character "+c+" provided multiple times");
+                dim2 = charToDimensionSize.get(c);
+            }
+            numberOfOutDimensions++;
+        }
+        if (numberOfOutDimensions > 2) {
+            throw new LanguageException("Einsum: output matrices with with no. 
dims > 2 not supported");
+        } else {
+            return Triple.of(dim1, dim2, Types.DataType.MATRIX);
+        }
+    }
+
+    public static Types.DataType validateEinsumEquationNoDimensions(String 
equationString, int numberOfMatrixInputs) throws LanguageException {
+        String[] eqStringParts = equationString.split("->"); // length 2 if 
"...->..." , length 1 if "...->"
+        boolean isResultScalar = eqStringParts.length == 1;
+
+        int numberOfMatrices = 1;
+        for (int i = 0; i < eqStringParts[0].length(); i++) {
+            char c = eqStringParts[0].charAt(i);
+            if(c == ' ') continue;
+            if(c == ',')
+                numberOfMatrices++;
+        }
+        if(numberOfMatrixInputs != numberOfMatrices){
+            throw  new LanguageException("Einsum: Invalid number of 
parameters, given: " + numberOfMatrixInputs + ", expected: " + 
numberOfMatrices);
+        }
+
+        if(isResultScalar){
+            return Types.DataType.SCALAR;
+        }else {
+            int numberOfDimensions = 0;
+            Character dim1Char = null;
+            for (int i = 0; i < eqStringParts[1].length(); i++) {
+                char c = eqStringParts[i].charAt(i);
+                if(c == ' ') continue;
+                numberOfDimensions++;
+                if (numberOfDimensions == 1 && c == dim1Char)
+                    throw new LanguageException("Einsum: output character 
"+c+" provided multiple times");
+                dim1Char = c;
+            }
+
+            if (numberOfDimensions > 2) {
+                throw new LanguageException("Einsum: output matrices with with 
no. dims > 2 not supported");
+            } else {
+                return Types.DataType.MATRIX;
+            }
+        }
+    }
+
+    private static <HopOrIdentifier extends ParseInfo> long 
getThisCharDimension(HopOrIdentifier currArr, int arrSizeIterator) {
+        long thisCharDimension;
+        if(currArr instanceof Hop){
+            thisCharDimension = arrSizeIterator == 0 ? ((Hop) 
currArr).getDim1()  : ((Hop) currArr).getDim2();
+        } else if(currArr instanceof Identifier){
+            thisCharDimension = arrSizeIterator == 0 ? ((Identifier) 
currArr).getDim1()  : ((Identifier) currArr).getDim2();
+        } else {
+            throw new RuntimeException("validateEinsumAndReturnDimensions 
called with expressions that are not Hop or Identifier: "+ currArr == null ? 
"null" : currArr.getClass().toString());
+        }
+        return thisCharDimension;
+    }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index fa443378e6..92e11b425d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -65,6 +65,7 @@ import 
org.apache.sysds.runtime.instructions.cp.UaggOuterChainCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.UnionCPInstruction;
+import org.apache.sysds.runtime.instructions.cp.EinsumCPInstruction;
 import 
org.apache.sysds.runtime.instructions.cpfile.MatrixIndexingCPFileInstruction;
 
 public class CPInstructionParser extends InstructionParser {
@@ -223,6 +224,9 @@ public class CPInstructionParser extends InstructionParser {
                        case Union:
                                return UnionCPInstruction.parseInstruction(str);
                        
+                       case EINSUM:
+                               return 
EinsumCPInstruction.parseInstruction(str);
+                               
                        default:
                                throw new DMLRuntimeException("Invalid CP 
Instruction Type: " + cptype );
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
index 6d230a30f0..e7aa1b5fd7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/BuiltinNaryCPInstruction.java
@@ -91,10 +91,13 @@ public abstract class BuiltinNaryCPInstruction extends 
CPInstruction
                        return new MatrixBuiltinNaryCPInstruction(
                                        new 
SimpleOperator(Multiply.getMultiplyFnObject()), opcode, str, outputOperand, 
inputOperands);
                }
+               else if( opcode.equals(Opcodes.EINSUM.toString()) ) {
+                       return new EinsumCPInstruction(null, opcode, str, 
outputOperand, inputOperands);
+               }
                else if (OpOpN.EVAL.name().equalsIgnoreCase(opcode)) {
                        return new EvalNaryCPInstruction(null, opcode, str, 
outputOperand, inputOperands);
                }
-               
+
                throw new DMLRuntimeException("Opcode (" + opcode + ") not 
recognized in BuiltinMultipleCPInstruction");
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
index c99039bb7f..b35ca55dab 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/CPInstruction.java
@@ -44,7 +44,7 @@ public abstract class CPInstruction extends Instruction {
                Builtin, Reorg, Variable, FCall, Append, Rand, QSort, QPick, 
Local,
                MatrixIndexing, MMTSJ, PMMJ, MMChain, Reshape, Partition, 
Compression, DeCompression, SpoofFused,
                StringInit, CentralMoment, Covariance, UaggOuterChain, Dnn, 
Sql, Prefetch, Broadcast, TrigRemote,
-               EvictLineageCache,
+               EvictLineageCache, EINSUM,
                NoOp,
                Union,
                QuantizeCompression
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
new file mode 100644
index 0000000000..c67dd29079
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/EinsumCPInstruction.java
@@ -0,0 +1,837 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.instructions.cp;
+
+import org.apache.commons.lang3.NotImplementedException;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.commons.lang3.tuple.Triple;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.hops.LiteralOp;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.hops.codegen.SpoofCompiler;
+import org.apache.sysds.hops.codegen.cplan.CNode;
+import org.apache.sysds.hops.codegen.cplan.CNodeBinary;
+import org.apache.sysds.hops.codegen.cplan.CNodeCell;
+import org.apache.sysds.hops.codegen.cplan.CNodeData;
+import org.apache.sysds.hops.codegen.cplan.CNodeRow;
+import org.apache.sysds.runtime.codegen.*;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.einsum.EinsumContext;
+import org.apache.sysds.runtime.functionobjects.*;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysds.runtime.matrix.operators.SimpleOperator;
+
+import java.util.*;
+import java.util.function.Predicate;
+
+public class EinsumCPInstruction extends BuiltinNaryCPInstruction {
+       public static boolean FORCE_CELL_TPL = false;
+       protected static final Log LOG = 
LogFactory.getLog(EinsumCPInstruction.class.getName());
+       public String eqStr;
+       private final int _numThreads;
+       private final CPOperand[] _in;
+
+       public EinsumCPInstruction(Operator op, String opcode, String istr, 
CPOperand out, CPOperand... inputs)
+       {
+               super(op, opcode, istr, out, inputs);
+               _numThreads = OptimizerUtils.getConstrainedNumThreads(-1);
+               _in = inputs;
+               this.eqStr = inputs[0].getName();
+               
Logger.getLogger(EinsumCPInstruction.class).setLevel(Level.TRACE);
+       }
+
+       @SuppressWarnings("unused")
+       private EinsumContext einc = null;
+
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               //get input matrices and scalars, incl pinning of matrices
+               ArrayList<MatrixBlock> inputs = new ArrayList<>();
+               for (CPOperand input : _in) {
+                       if(input.getDataType()==DataType.MATRIX){
+                               MatrixBlock mb = 
ec.getMatrixInput(input.getName());
+                               if(mb instanceof CompressedMatrixBlock){
+                                       mb = ((CompressedMatrixBlock) 
mb).getUncompressed("Spoof instruction");
+                               }
+                               inputs.add(mb);
+                       }
+               }
+
+               EinsumContext einc = EinsumContext.getEinsumContext(eqStr, 
inputs);
+
+               this.einc = einc;
+               String resultString = einc.outChar2 != null ? 
String.valueOf(einc.outChar1) + einc.outChar2 : einc.outChar1 != null ? 
String.valueOf(einc.outChar1) : "";
+
+               if( LOG.isDebugEnabled() ) LOG.trace("outrows:"+einc.outRows+", 
outcols:"+einc.outCols);
+
+               ArrayList<String> inputsChars = 
einc.newEquationStringInputsSplit;
+
+               if(LOG.isTraceEnabled()) 
LOG.trace(String.join(",",einc.newEquationStringInputsSplit));
+
+               contractDimensionsAndComputeDiagonals(einc, inputs);
+
+               //make all vetors col vectors
+               for(int i = 0; i < inputs.size(); i++){
+                       if(inputs.get(i) != null && inputsChars.get(i).length() 
== 1) EnsureMatrixBlockColumnVector(inputs.get(i));
+               }
+
+               if(LOG.isTraceEnabled()) for(Character c : 
einc.characterAppearanceIndexes.keySet()){
+                       ArrayList<Integer> a = 
einc.characterAppearanceIndexes.get(c);
+                       LOG.trace(c+" count= "+a.size());
+               }
+
+               // compute scalar by suming-all matrices:
+               Double scalar = null;
+               for(int i=0;i< inputs.size(); i++){
+                       String s = inputsChars.get(i);
+                       if(s.equals("")){
+                               MatrixBlock mb = inputs.get(i);
+                               if (scalar == null) scalar = mb.get(0,0);
+                               else scalar*= mb.get(0,0);
+                               inputs.set(i,null);
+                               inputsChars.set(i,null);
+                       }
+               }
+
+               if (scalar != null) {
+                       inputsChars.add("");
+                       inputs.add(new MatrixBlock(scalar));
+               }
+
+               HashMap<Character, Integer> characterToOccurences = new 
HashMap<>();
+               for (Character key :einc.characterAppearanceIndexes.keySet()) {
+                       characterToOccurences.put(key, 
einc.characterAppearanceIndexes.get(key).size());
+               }
+               for (Character key :einc.charToDimensionSize.keySet()) {
+                       if(!characterToOccurences.containsKey(key))
+                               characterToOccurences.put(key, 1);
+               }
+
+               ArrayList<EOpNode> eOpNodes = new 
ArrayList<>(inputsChars.size());
+               for (int i = 0; i < inputsChars.size(); i++) {
+                       if (inputsChars.get(i) == null) continue;
+                       EOpNodeData n = new 
EOpNodeData(inputsChars.get(i).length() > 0 ? inputsChars.get(i).charAt(0) : 
null, inputsChars.get(i).length() > 1 ? inputsChars.get(i).charAt(1) : null, i);
+                       eOpNodes.add(n);
+               }
+               Pair<Integer, List<EOpNode> > plan = FORCE_CELL_TPL ? null : 
generatePlan(0, eOpNodes, einc.charToDimensionSize, characterToOccurences, 
einc.outChar1, einc.outChar2);
+
+
+               ArrayList<MatrixBlock> resMatrices = FORCE_CELL_TPL ? null : 
executePlan(plan.getRight(), inputs);
+//             ArrayList<MatrixBlock> resMatrices = 
executePlan(plan.getRight(), inputs, true);
+
+               if(!FORCE_CELL_TPL && resMatrices.size() == 1){
+                       EOpNode resNode = plan.getRight().get(0);
+                       if (einc.outChar1 != null && einc.outChar2 != null){
+                               if(resNode.c1 == einc.outChar1 && resNode.c2 == 
einc.outChar2){
+                                       ec.setMatrixOutput(output.getName(), 
resMatrices.get(0));
+                               }
+                               else if(resNode.c1 == einc.outChar2 && 
resNode.c2 == einc.outChar1){
+                                       ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
+                                       MatrixBlock resM = 
resMatrices.get(0).reorgOperations(transpose, new MatrixBlock(),0,0,0);
+                                       ec.setMatrixOutput(output.getName(), 
resM);
+                               }else{
+                                       if(LOG.isTraceEnabled()) 
LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2);
+                                       throw new RuntimeException("Einsum plan 
produced different result");
+                               }
+                       }else if (einc.outChar1 != null){
+                               if(resNode.c1 == einc.outChar1  && resNode.c2 
== null){
+                                       ec.setMatrixOutput(output.getName(), 
resMatrices.get(0));
+                               }else{
+                                       if(LOG.isTraceEnabled()) 
LOG.trace("Einsum expected: "+resultString + ", got: "+resNode.c1+resNode.c2);
+                                       throw new RuntimeException("Einsum plan 
produced different result");
+                               }
+                       }else{
+                               if(resNode.c1 == null && resNode.c2 == null){
+                                       ec.setScalarOutput(output.getName(), 
new DoubleObject(resMatrices.get(0).get(0, 0)));;
+                               }
+                       }
+               }else{
+                       // use cell template with loops for remaining
+                       ArrayList<MatrixBlock> mbs = resMatrices;
+                       ArrayList<String> chars = new ArrayList<>();
+
+                       for (int i = 0; i < plan.getRight().size(); i++) {
+                               String s;
+                               if(plan.getRight().get(i).c1 == null) s = "";
+                               else if(plan.getRight().get(i).c2 == null) s = 
plan.getRight().get(i).c1.toString();
+                               else s = plan.getRight().get(i).c1.toString() + 
plan.getRight().get(i).c2;
+                               chars.add(s);
+                       }
+
+                       ArrayList<Character> summingChars = new ArrayList<>();
+                       for (Character c : 
einc.characterAppearanceIndexes.keySet()) {
+                               if (c != einc.outChar1 && c != einc.outChar2) 
summingChars.add(c);
+                       }
+                       if(LOG.isTraceEnabled()) LOG.trace("finishing with cell 
tpl: "+String.join(",", chars));
+
+                       MatrixBlock res = computeCellSummation(mbs, chars, 
resultString, einc.charToDimensionSize, summingChars, einc.outRows, 
einc.outCols);
+
+                       if (einc.outRows == 1 && einc.outCols == 1)
+                               ec.setScalarOutput(output.getName(), new 
DoubleObject(res.get(0, 0)));
+                       else ec.setMatrixOutput(output.getName(), res);
+               }
+               if(LOG.isTraceEnabled()) LOG.trace("EinsumCPInstruction 
Finished");
+
+               releaseMatrixInputs(ec);
+
+       }
+
+       private void contractDimensionsAndComputeDiagonals(EinsumContext einc, 
ArrayList<MatrixBlock> inputs) {
+               for(int i = 0; i< einc.contractDims.length; i++){
+                       //AggregateOperator agg = new AggregateOperator(0, 
KahanPlus.getKahanPlusFnObject(),Types.CorrectionLocationType.LASTCOLUMN);
+                       AggregateOperator agg = new AggregateOperator(0, 
Plus.getPlusFnObject());
+
+                       if(einc.diagonalInputs[i]){
+                               ReorgOperator op = new 
ReorgOperator(DiagIndex.getDiagIndexFnObject());
+                               inputs.set(i, inputs.get(i).reorgOperations(op, 
new MatrixBlock(),0,0,0));
+                       }
+                       if (einc.contractDims[i] == null) continue;
+                       switch (einc.contractDims[i]){
+                               case CONTRACT_BOTH: {
+                                       AggregateUnaryOperator aggun = new 
AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads);
+                                       MatrixBlock res = new MatrixBlock(1, 1, 
false);
+                                       
inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null);
+                                       inputs.set(i, res);
+                                       break;
+                               }
+                               case CONTRACT_RIGHT: {
+                                       AggregateUnaryOperator aggun = new 
AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads);
+                                       MatrixBlock res = new 
MatrixBlock(inputs.get(i).getNumRows(), 1, false);
+                                       
inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null);
+                                       inputs.set(i, res);
+                                       break;
+                               }
+                               case CONTRACT_LEFT: {
+                                       AggregateUnaryOperator aggun = new 
AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads);
+                                       MatrixBlock res = new 
MatrixBlock(inputs.get(i).getNumColumns(), 1, false);
+                                       
inputs.get(i).aggregateUnaryOperations(aggun, res, 0, null);
+                                       inputs.set(i, res);
+                                       break;
+                               }
+                               default:
+                                       break;
+                       }
+               }
+       }
+
+       private enum EBinaryOperand { // upper case: char has to remain, lower 
case: to be summed
+               ////// summations:   //////
+               aB_a,// -> B
+               Ba_a, // -> B
+               Ba_aC, // mmult -> BC
+               aB_Ca,
+               Ba_Ca, // -> BC
+               aB_aC, // outer mult, possibly with transposing first -> BC
+               a_a,// dot ->
+
+               ////// elementwisemult and sums, something like ij,ij->i   
//////
+               aB_aB,// elemwise and colsum -> B
+               Ba_Ba, // elemwise and rowsum ->B
+               Ba_aB, // elemwise, either colsum or rowsum -> B
+//             aB_Ba,
+
+               ////// elementwise, no summations:   //////
+               A_A,// v-elemwise -> A
+               AB_AB,// M-M elemwise -> AB
+               AB_BA, // M-M.T elemwise -> AB
+               AB_A, // M-v colwise -> BA!?
+               BA_A, // M-v rowwise -> BA
+               ab_ab,//M-M sum all
+               ab_ba, //M-M.T sum all
+               ////// other   //////
+               A_B, // outer mult -> AB
+               A_scalar, // v-scalar
+               AB_scalar, // m-scalar
+               scalar_scalar
+       }
+       private abstract class EOpNode {
+               public Character c1;
+               public Character c2; // nullable
+               public EOpNode(Character c1, Character c2){
+                       this.c1 = c1;
+                       this.c2 = c2;
+               }
+       }
+       private class EOpNodeBinary extends EOpNode {
+               public EOpNode left;
+               public EOpNode right;
+               public EBinaryOperand operand;
+               public EOpNodeBinary(Character c1, Character c2, EOpNode left, 
EOpNode right, EBinaryOperand operand){
+                       super(c1,c2);
+                       this.left = left;
+                       this.right = right;
+                       this.operand = operand;
+               }
+       }
+       private class EOpNodeData extends EOpNode {
+               public int matrixIdx;
+               public EOpNodeData(Character c1, Character c2, int matrixIdx){
+                       super(c1,c2);
+                       this.matrixIdx = matrixIdx;
+               }
+       }
+
+       private Pair<Integer, List<EOpNode> /* ideally with one element */> 
generatePlan(int cost, ArrayList<EOpNode> operands, HashMap<Character, Integer> 
charToSizeMap, HashMap<Character, Integer> charToOccurences, Character 
outChar1, Character outChar2) {
+               Integer minCost = cost;
+               List<EOpNode> minNodes = operands;
+
+               if (operands.size() == 2){
+                       boolean swap = (operands.get(0).c2 == null && 
operands.get(1).c2 != null) || operands.get(0).c1 == null;
+                       EOpNode n1 = operands.get(!swap ? 0 : 1);
+                       EOpNode n2 = operands.get(!swap ? 1 : 0);
+                       Triple<Integer, EBinaryOperand, Pair<Character, 
Character>> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, 
outChar1, outChar2);
+                       if (t != null) {
+                               EOpNodeBinary newNode = new 
EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, 
t.getMiddle());
+                               int thisCost = cost + t.getLeft();
+                               return Pair.of(thisCost, 
Arrays.asList(newNode));
+                       }
+                       return Pair.of(cost, operands);
+               }
+               else if (operands.size() == 1){
+                       // check for transpose
+                       return Pair.of(cost, operands);
+               }
+
+               for(int i = 0; i < operands.size()-1; i++){
+                       for (int j = i+1; j < operands.size(); j++){
+                               boolean swap = (operands.get(i).c2 == null && 
operands.get(j).c2 != null) || operands.get(i).c1 == null;
+                               EOpNode n1 = operands.get(!swap ? i : j);
+                               EOpNode n2 = operands.get(!swap ? j : i);
+
+
+                               Triple<Integer, EBinaryOperand, Pair<Character, 
Character>> t = TryCombineAndCost(n1, n2, charToSizeMap, charToOccurences, 
outChar1, outChar2);
+                               if (t != null){
+                                       EOpNodeBinary newNode = new 
EOpNodeBinary(t.getRight().getLeft(), t.getRight().getRight(), n1, n2, 
t.getMiddle());
+                                       int thisCost = cost + t.getLeft();
+
+                                       if(n1.c1 != null) 
charToOccurences.put(n1.c1, charToOccurences.get(n1.c1)-1);
+                                       if(n1.c2 != null) 
charToOccurences.put(n1.c2, charToOccurences.get(n1.c2)-1);
+                                       if(n2.c1 != null) 
charToOccurences.put(n2.c1, charToOccurences.get(n2.c1)-1);
+                                       if(n2.c2 != null) 
charToOccurences.put(n2.c2, charToOccurences.get(n2.c2)-1);
+
+                                       if(newNode.c1 != null) 
charToOccurences.put(newNode.c1, charToOccurences.get(newNode.c1)+1);
+                                       if(newNode.c2 != null) 
charToOccurences.put(newNode.c2, charToOccurences.get(newNode.c2)+1);
+
+                                       ArrayList<EOpNode> newOperands = new 
ArrayList<>(operands.size()-1);
+                                       for(int z = 0; z < operands.size(); 
z++){
+                                               if(z != i && z != j) 
newOperands.add(operands.get(z));
+                                       }
+                                       newOperands.add(newNode);
+
+                                       Pair<Integer, List<EOpNode>> 
furtherPlan = generatePlan(thisCost, newOperands,charToSizeMap, 
charToOccurences, outChar1, outChar2);
+                                       if(furtherPlan.getRight().size() < 
(minNodes.size()) || furtherPlan.getLeft() < minCost){
+                                               minCost = furtherPlan.getLeft();
+                                               minNodes = 
furtherPlan.getRight();
+                                       }
+
+                                       if(n1.c1 != null) 
charToOccurences.put(n1.c1, charToOccurences.get(n1.c1)+1);
+                                       if(n1.c2 != null) 
charToOccurences.put(n1.c2, charToOccurences.get(n1.c2)+1);
+                                       if(n2.c1 != null) 
charToOccurences.put(n2.c1, charToOccurences.get(n2.c1)+1);
+                                       if(n2.c2 != null) 
charToOccurences.put(n2.c2, charToOccurences.get(n2.c2)+1);
+                                       if(newNode.c1 != null) 
charToOccurences.put(newNode.c1, charToOccurences.get(newNode.c1)-1);
+                                       if(newNode.c2 != null) 
charToOccurences.put(newNode.c2, charToOccurences.get(newNode.c2)-1);
+                               }
+                       }
+               }
+
+               return Pair.of(minCost, minNodes);
+       }
+
+       private static Triple<Integer, EBinaryOperand, Pair<Character, 
Character>> TryCombineAndCost(EOpNode n1 , EOpNode n2, HashMap<Character, 
Integer> charToSizeMap, HashMap<Character, Integer> charToOccurences, Character 
outChar1, Character outChar2){
+               Predicate<Character> cannotBeSummed = (c) ->
+                               c == outChar1 || c == outChar2 || 
charToOccurences.get(c) > 2;
+
+               if(n1.c1 == null) {
+                       // n2.c1 also has to be null
+                       return Triple.of(1, EBinaryOperand.scalar_scalar, 
Pair.of(null, null));
+               }
+
+               if(n2.c1 == null) {
+                       if(n1.c2 == null)
+                               return Triple.of(charToSizeMap.get(n1.c1), 
EBinaryOperand.A_scalar, Pair.of(n1.c1, null));
+                       return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), 
EBinaryOperand.AB_scalar, Pair.of(n1.c1, n1.c2));
+               }
+
+               if(n1.c1 == n2.c1){
+                       if(n1.c2 != null){
+                               if ( n1.c2 == n2.c2){
+                                       if( cannotBeSummed.test(n1.c1)){
+                                               if(cannotBeSummed.test(n1.c2)){
+                                                       return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), 
EBinaryOperand.AB_AB, Pair.of(n1.c1, n1.c2));
+                                               }
+                                               return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), 
EBinaryOperand.Ba_Ba, Pair.of(n1.c1, null));
+                                       }
+
+                                       if(cannotBeSummed.test(n1.c2)){
+                                               return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), 
EBinaryOperand.aB_aB, Pair.of(n1.c2, null));
+                                       }
+
+                                       return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), 
EBinaryOperand.ab_ab, Pair.of(null, null));
+
+                               }
+
+                               else if(n2.c2 == null){
+                                       if(cannotBeSummed.test(n1.c1)){
+                                               return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, 
EBinaryOperand.AB_A, Pair.of(n1.c1, n1.c2));
+                                       }
+                                       return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*2, 
EBinaryOperand.aB_a, Pair.of(n1.c2, null)); // in theory (null, n1.c2)
+                               }
+                               else if(n1.c1 ==outChar1 || n1.c1==outChar2|| 
charToOccurences.get(n1.c1) > 2){
+                                       return null;// AB,AC
+                               }
+                               else {
+                                       return 
Triple.of((charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2))+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2)),
 EBinaryOperand.aB_aC, Pair.of(n1.c2, n2.c2)); // or n2.c2, n1.c2
+                               }
+                       }else{ // n1.c2 = null -> c2.c2 = null
+                               if(n1.c1 ==outChar1 || n1.c1==outChar2 || 
charToOccurences.get(n1.c1) > 2){
+                                       return 
Triple.of(charToSizeMap.get(n1.c1), EBinaryOperand.A_A, Pair.of(n1.c1, null));
+                               }
+                               return Triple.of(charToSizeMap.get(n1.c1), 
EBinaryOperand.a_a, Pair.of(null, null));
+                       }
+
+
+               }else{ // n1.c1 != n2.c1
+                       if(n1.c2 == null) {
+                               return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), 
EBinaryOperand.A_B, Pair.of(n1.c1, n2.c1));
+                       }
+                       else if(n2.c2 == null) { // ab,c
+                               if (n1.c2 == n2.c1) {
+                                       if(cannotBeSummed.test(n1.c2)){
+                                               return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), 
EBinaryOperand.BA_A, Pair.of(n1.c1, n1.c2));
+                                       }
+                                       return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n2.c1), 
EBinaryOperand.Ba_a, Pair.of(n1.c1, null));
+                               }
+                               return null; // AB,C
+                       }
+                       else if (n1.c2 == n2.c1) {
+                               if(n1.c1 == n2.c2){ // ab,ba
+                                       if(cannotBeSummed.test(n1.c1)){
+                                               if(cannotBeSummed.test(n1.c2)){
+                                                       return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), 
EBinaryOperand.AB_BA, Pair.of(n1.c1, n1.c2));
+                                               }
+                                               return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), 
EBinaryOperand.Ba_aB, Pair.of(n1.c1, null));
+                                       }
+                                       return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), 
EBinaryOperand.ab_ba, Pair.of(null, null));
+                               }
+                               if(cannotBeSummed.test(n1.c2)){
+                                       return null; // AB_B
+                               }else{
+                                       return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c2),
 EBinaryOperand.Ba_aC, Pair.of(n1.c1, n2.c2));
+//                                     if(n1.c1 ==outChar1 || 
n1.c1==outChar2|| charToOccurences.get(n1.c1) > 2){
+//                                             return null; // AB_B
+//                                     }
+//                                     return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2), 
EBinaryOperand.Ba_a, Pair.of(n1.c1, null));
+                               }
+                       }
+                       if(n1.c1 == n2.c2) {
+                               if(cannotBeSummed.test(n1.c1)){
+                                       return null; // AB_B
+                               }
+                               return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1),
 EBinaryOperand.aB_Ca, Pair.of(n2.c1, n1.c2)); // * its just reorder of mmult
+                       }
+                       else if (n1.c2 == n2.c2) {
+                               if(n1.c2 ==outChar1 || n1.c2==outChar2|| 
charToOccurences.get(n1.c2) > 2){
+                                       return null; // BA_CA
+                               }else{
+                                       return 
Triple.of(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2) 
+(charToSizeMap.get(n1.c1)*charToSizeMap.get(n1.c2)*charToSizeMap.get(n2.c1)), 
EBinaryOperand.Ba_Ca, Pair.of(n1.c1, n2.c1)); // or n2.c1, n1.c1
+                               }
+                       }
+                       else { // we have something like ab,cd
+                               return null;
+                       }
+               }
+       }
+
+       private ArrayList<MatrixBlock /* #els = #els of plan */> 
executePlan(List<EOpNode> plan, ArrayList<MatrixBlock> inputs){
+               return executePlan(plan, inputs, false);
+       }
+       private ArrayList<MatrixBlock /* #els = #els of plan */> 
executePlan(List<EOpNode> plan, ArrayList<MatrixBlock> inputs, boolean codegen) 
{
+               ArrayList<MatrixBlock> res = new ArrayList<>(plan.size());
+               for(EOpNode p : plan){
+                       if(codegen) res.add(ComputeEOpNodeCodegen(p, inputs));
+                       else res.add(ComputeEOpNode(p, inputs));
+               }
+               return res;
+       }
+
+       private MatrixBlock ComputeEOpNode(EOpNode eOpNode, 
ArrayList<MatrixBlock> inputs){
+               if(eOpNode instanceof EOpNodeData eOpNodeData){
+                       return inputs.get(eOpNodeData.matrixIdx);
+               }
+               EOpNodeBinary bin = (EOpNodeBinary) eOpNode;
+               MatrixBlock left = ComputeEOpNode(bin.left, inputs);
+               MatrixBlock right = ComputeEOpNode(bin.right, inputs);
+
+               AggregateOperator agg = new AggregateOperator(0, 
Plus.getPlusFnObject());
+
+               MatrixBlock res;
+               switch (bin.operand){
+                       case AB_AB -> {
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
+                       }
+                       case A_A -> {
+                               EnsureMatrixBlockColumnVector(left);
+                               EnsureMatrixBlockColumnVector(right);
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
+                       }
+                       case a_a -> {
+                               EnsureMatrixBlockColumnVector(left);
+                               EnsureMatrixBlockColumnVector(right);
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
+                               AggregateUnaryOperator aggun = new 
AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads);
+                               res = (MatrixBlock) 
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+                       }
+                       ////////////
+                       case Ba_Ba -> {
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
+                               AggregateUnaryOperator aggun = new 
AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads);
+                               res = (MatrixBlock) 
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+                       }
+                       case aB_aB -> {
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
+                               AggregateUnaryOperator aggun = new 
AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads);
+                               res = (MatrixBlock) 
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+                               EnsureMatrixBlockColumnVector(res);
+                       }
+                       case ab_ab -> {
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
+                               AggregateUnaryOperator aggun = new 
AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads);
+                               res = (MatrixBlock) 
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+                       }
+                       case ab_ba -> {
+                               ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
+                               right = right.reorgOperations(transpose, new 
MatrixBlock(), 0, 0, 0);
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
+                               AggregateUnaryOperator aggun = new 
AggregateUnaryOperator(agg, ReduceAll.getReduceAllFnObject(), _numThreads);
+                               res = (MatrixBlock) 
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+                       }
+                       case Ba_aB -> {
+                               ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
+                               right = right.reorgOperations(transpose, new 
MatrixBlock(), 0, 0, 0);
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
+                               AggregateUnaryOperator aggun = new 
AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads);
+                               res = (MatrixBlock) 
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+                       }
+
+                       /////////
+                       case AB_BA -> {
+                               ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
+                               right = right.reorgOperations(transpose, new 
MatrixBlock(), 0, 0, 0);
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left, 
right},new ScalarObject[]{}, new MatrixBlock());
+                       }
+                       case Ba_aC -> {
+                               res = LibMatrixMult.matrixMult(left,right, new 
MatrixBlock(), _numThreads);
+                       }
+                       case aB_Ca -> {
+                               res = LibMatrixMult.matrixMult(right,left, new 
MatrixBlock(), _numThreads);
+                       }
+                       case Ba_Ca -> {
+                               ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
+                               right = right.reorgOperations(transpose, new 
MatrixBlock(), 0, 0, 0);
+                               res = LibMatrixMult.matrixMult(left,right, new 
MatrixBlock(), _numThreads);
+                       }
+                       case aB_aC -> {
+                               ReorgOperator transpose = new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), _numThreads);
+                               left = left.reorgOperations(transpose, new 
MatrixBlock(), 0, 0, 0);
+                               res = LibMatrixMult.matrixMult(left,right, new 
MatrixBlock(), _numThreads);
+                       }
+                       case A_scalar, AB_scalar -> {
+                               res = MatrixBlock.naryOperations(new 
SimpleOperator(Multiply.getMultiplyFnObject()), new MatrixBlock[]{left},new 
ScalarObject[]{new DoubleObject(right.get(0,0))}, new MatrixBlock());
+                       }
+                       case BA_A -> {
+                               EnsureMatrixBlockRowVector(right);
+                               res = left.binaryOperations(new 
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+                       }
+                       case Ba_a -> {
+                               EnsureMatrixBlockRowVector(right);
+                               res = left.binaryOperations(new 
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+                               AggregateUnaryOperator aggun = new 
AggregateUnaryOperator(agg, ReduceCol.getReduceColFnObject(), _numThreads);
+                               res = (MatrixBlock) 
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+                       }
+
+                       case AB_A -> {
+                               EnsureMatrixBlockColumnVector(right);
+                               res = left.binaryOperations(new 
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+                       }
+                       case aB_a -> {
+                               EnsureMatrixBlockColumnVector(right);
+                               res = left.binaryOperations(new 
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+                               AggregateUnaryOperator aggun = new 
AggregateUnaryOperator(agg, ReduceRow.getReduceRowFnObject(), _numThreads);
+                               res = (MatrixBlock) 
res.aggregateUnaryOperations(aggun, new MatrixBlock(), 0, null);
+                               EnsureMatrixBlockColumnVector(res);
+                       }
+
+                       case A_B -> {
+                               EnsureMatrixBlockColumnVector(left);
+                               EnsureMatrixBlockRowVector(right);
+                               res = left.binaryOperations(new 
BinaryOperator(Multiply.getMultiplyFnObject()), right);
+                       }
+                       case scalar_scalar -> {
+                               return new 
MatrixBlock(left.get(0,0)*right.get(0,0));
+                       }
+                       default -> {
+                               throw new IllegalArgumentException("Unexpected 
value: " + bin.operand.toString());
+                       }
+
+               }
+               return res;
+       }
+
+       private static MatrixBlock ComputeEOpNodeCodegen(EOpNode eOpNode, 
ArrayList<MatrixBlock> inputs){
+               return rComputeEOpNodeCodegen(eOpNode, inputs);
+//             throw new NotImplementedException();
+       }
+       private static CNodeData MatrixBlockToCNodeData(MatrixBlock mb, int id){
+               return new CNodeData("ce"+id, id, mb.getNumRows(), 
mb.getNumColumns(), DataType.MATRIX);
+       }
+       private static MatrixBlock rComputeEOpNodeCodegen(EOpNode eOpNode, 
ArrayList<MatrixBlock> inputs) {
+               if (eOpNode instanceof EOpNodeData eOpNodeData){
+                       return inputs.get(eOpNodeData.matrixIdx);
+//                     return new CNodeData("ce"+eOpNodeData.matrixIdx, 
eOpNodeData.matrixIdx, inputs.get(eOpNodeData.matrixIdx).getNumRows(), 
inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX);
+               }
+
+               EOpNodeBinary bin = (EOpNodeBinary) eOpNode;
+//             CNodeData dataLeft = null;
+//             if (bin.left instanceof EOpNodeData eOpNodeData) dataLeft = new 
CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, 
inputs.get(eOpNodeData.matrixIdx).getNumRows(), 
inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX);
+//             CNodeData dataRight = null;
+//             if (bin.right instanceof EOpNodeData eOpNodeData) dataRight = 
new CNodeData("ce"+eOpNodeData.matrixIdx, eOpNodeData.matrixIdx, 
inputs.get(eOpNodeData.matrixIdx).getNumRows(), 
inputs.get(eOpNodeData.matrixIdx).getNumColumns(), DataType.MATRIX);
+
+               if(bin.operand == EBinaryOperand.AB_AB){
+                       if (bin.right instanceof EOpNodeBinary rBinary && 
rBinary.operand  == EBinaryOperand.AB_AB){
+                               MatrixBlock left = 
rComputeEOpNodeCodegen(bin.left, inputs);
+
+                               MatrixBlock right1 = 
rComputeEOpNodeCodegen(((EOpNodeBinary) bin.right).left, inputs);
+                               MatrixBlock right2 = 
rComputeEOpNodeCodegen(((EOpNodeBinary) bin.right).right, inputs);
+
+                               CNodeData d0 = MatrixBlockToCNodeData(left, 0);
+                               CNodeData d1 = MatrixBlockToCNodeData(right1, 
1);
+                               CNodeData d2 = MatrixBlockToCNodeData(right2, 
2);
+//                             CNodeNary nary = new CNodeNary(cnodeIn, 
CNodeNary.NaryType.)
+                               CNodeBinary rightBinary = new CNodeBinary(d1, 
d2, CNodeBinary.BinType.VECT_MULT);
+                               CNodeBinary cNodeBinary = new CNodeBinary(d0, 
rightBinary, CNodeBinary.BinType.VECT_MULT);
+                               ArrayList<CNode> cnodeIn = new ArrayList<>();
+                               cnodeIn.add(d0);
+                               cnodeIn.add(d1);
+                               cnodeIn.add(d2);
+
+                               CNodeRow cnode = new CNodeRow(cnodeIn, 
cNodeBinary);
+
+                               cnode.setRowType(SpoofRowwise.RowType.NO_AGG);
+                               cnode.renameInputs();
+
+
+                               String src = cnode.codegen(false, 
SpoofCompiler.GeneratorAPI.JAVA);
+                               if( LOG.isTraceEnabled()) 
LOG.trace(CodegenUtils.printWithLineNumber(src));
+                               Class<?> cla = 
CodegenUtils.compileClass("codegen." + cnode.getClassname(), src);
+
+                               SpoofOperator op = 
CodegenUtils.createInstance(cla);
+                               MatrixBlock mb = new MatrixBlock();
+
+                               ArrayList<ScalarObject> scalars = new 
ArrayList<>();
+                               ArrayList<MatrixBlock> mbs = new ArrayList<>(3);
+                               mbs.add(left);
+                               mbs.add(right1);
+                               mbs.add(right2);
+                               MatrixBlock out = op.execute(mbs, scalars, mb, 
6);
+
+                               return out;
+                       }
+               }
+
+               throw new NotImplementedException();
+       }
+
+
+       private void releaseMatrixInputs(ExecutionContext ec){
+               for (CPOperand input : _in)
+                       if(input.getDataType()==DataType.MATRIX)
+                               ec.releaseMatrixInput(input.getName()); //todo 
release other
+       }
+
+       private static void EnsureMatrixBlockColumnVector(MatrixBlock mb){
+               if(mb.getNumColumns() > 1){
+                       mb.setNumRows(mb.getNumColumns());
+                       mb.setNumColumns(1);
+                       mb.getDenseBlock().resetNoFill(mb.getNumRows(),1);
+               }
+       }
+       private static void EnsureMatrixBlockRowVector(MatrixBlock mb){
+               if(mb.getNumRows() > 1){
+                       mb.setNumColumns(mb.getNumRows());
+                       mb.setNumRows(1);
+                       mb.getDenseBlock().resetNoFill(1,mb.getNumColumns());
+               }
+       }
+
+       private static void indent(StringBuilder sb, int level) {
+               for (int i = 0; i < level; i++) {
+                       sb.append("  ");
+               }
+       }
+
+       private MatrixBlock computeCellSummation(ArrayList<MatrixBlock> inputs, 
List<String> inputsChars, String resultString,
+                                                                               
                                   HashMap<Character, Integer> 
charToDimensionSizeInt, List<Character> summingChars, int outRows, int outCols){
+               ArrayList<CNode> cnodeIn = new ArrayList<>();
+               cnodeIn.add(new CNodeData(new LiteralOp(3), 0, 0, 
DataType.SCALAR));
+               CNodeCell cnode = new CNodeCell(cnodeIn, null);
+               StringBuilder sb = new StringBuilder();
+
+               int indent = 2;
+               indent(sb, indent);
+
+               boolean needsSumming = summingChars.stream().anyMatch(x -> x != 
null);
+
+               String itVar0 = cnode.createVarname();
+               String outVar = itVar0;
+               if (needsSumming) {
+                       sb.append("double ");
+                       sb.append(outVar);
+                       sb.append("=0;\n");
+               }
+
+               Iterator<Character> hsIt = summingChars.iterator();
+               while (hsIt.hasNext()) {
+                       indent(sb, indent);
+                       indent++;
+                       Character c = hsIt.next();
+                       String itVar = itVar0 + c;
+                       sb.append("for(int ");
+                       sb.append(itVar);
+                       sb.append("=0;");
+                       sb.append(itVar);
+                       sb.append("<");
+                       sb.append(charToDimensionSizeInt.get(c));
+                       sb.append(";");
+                       sb.append(itVar);
+                       sb.append("++){\n");
+               }
+               indent(sb, indent);
+               if (needsSumming) {
+                       sb.append(outVar);
+                       sb.append("+=");
+               }
+
+               for (int i = 0; i < inputsChars.size(); i++) {
+                       if (inputsChars.get(i).length() == 0){
+                               sb.append("getValue(b[");
+                               sb.append(i);
+                               sb.append("],b[");
+                               sb.append(i);
+                               sb.append("].clen, 0,");
+                       }
+
+                       else if 
(summingChars.contains(inputsChars.get(i).charAt(0))) {
+                               sb.append("getValue(b[");
+                               sb.append(i);
+                               sb.append("],b[");
+                               sb.append(i);
+                               sb.append("].clen,");
+                               sb.append(itVar0);
+                               sb.append(inputsChars.get(i).charAt(0));
+                               sb.append(",");
+                       } else if (resultString.length() >= 1  && 
inputsChars.get(i).charAt(0) == resultString.charAt(0)) {
+                               sb.append("getValue(b[");
+                               sb.append(i);
+                               sb.append("],b[");
+                               sb.append(i);
+                               sb.append("].clen, rix,");
+                       } else if (resultString.length() == 2 && 
inputsChars.get(i).charAt(0) == resultString.charAt(1)) {
+                               sb.append("getValue(b[");
+                               sb.append(i);
+                               sb.append("],b[");
+                               sb.append(i);
+                               sb.append("].clen, cix,");
+                       } else {
+                               sb.append("getValue(b[");
+                               sb.append(i);
+                               sb.append("],b[");
+                               sb.append(i);
+                               sb.append("].clen, 0,");
+                       }
+
+                       if (inputsChars.get(i).length() != 2){
+                               sb.append("0)");
+                       }
+                       else if 
(summingChars.contains(inputsChars.get(i).charAt(1))) {
+                               sb.append(itVar0);
+                               sb.append(inputsChars.get(i).charAt(1));
+                               sb.append(")");
+                       } else if (resultString.length() >= 1 
&&inputsChars.get(i).charAt(1) == resultString.charAt(0)) {
+                               sb.append("rix)");
+                       } else if (resultString.length() == 2  && 
inputsChars.get(i).charAt(1) == resultString.charAt(1)) {
+                               sb.append("cix)");
+                       } else {
+                               sb.append("0)");
+                       }
+
+                       if (i < inputsChars.size() - 1) {
+                               sb.append(" * ");
+                       }
+
+               }
+               if (needsSumming) {
+                       sb.append(";\n");
+               }
+               indent--;
+               for (int si = 0; si < summingChars.size(); si++) {
+                       indent(sb, indent);
+                       indent--;
+                       sb.append("}\n");
+               }
+               String src = CNodeCell.JAVA_TEMPLATE;//
+               src = src.replace("%TMP%", cnode.createVarname());
+               src = src.replace("%TYPE%", "NO_AGG");
+               src = src.replace("%SPARSE_SAFE%", "false");
+               src = src.replace("%SEQ%", "true");
+               src = src.replace("%AGG_OP_NAME%", "null");
+               if (needsSumming) {
+                       src = src.replace("%BODY_dense%", sb.toString());
+                       src = src.replace("%OUT%", outVar);
+               } else {
+                       src = src.replace("%BODY_dense%", "");
+                       src = src.replace("%OUT%", sb.toString());
+               }
+
+               if( LOG.isTraceEnabled()) LOG.trace(src);
+               Class<?> cla = CodegenUtils.compileClass("codegen." + 
cnode.getClassname(), src);
+               SpoofOperator op = CodegenUtils.createInstance(cla);
+               MatrixBlock resBlock = new MatrixBlock();
+               resBlock.reset(outRows, outCols);
+               inputs.add(0, resBlock);
+               MatrixBlock out = op.execute(inputs, new ArrayList<>(), new 
MatrixBlock(), _numThreads);
+
+               return out;
+       }
+
+       public CPOperand[] getInputs() {
+               return _in;
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java 
b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java
new file mode 100644
index 0000000000..dbf9047968
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/einsum/EinsumTest.java
@@ -0,0 +1,362 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.einsum;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.After;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+
+@RunWith(Parameterized.class)
+public class EinsumTest extends AutomatedTestBase
+{
+       final private static List<Config> TEST_CONFIGS = List.of(
+                       new Config("ij,jk->ik", List.of(shape(50, 600), 
shape(600, 10))), // mm
+                       new Config("ji,jk->ik", List.of(shape(600, 5), 
shape(600, 10))),
+                       new Config("ji,kj->ik", List.of(shape(600, 5), 
shape(10, 600))),
+                       new Config("ij,kj->ik", List.of(shape(5, 600), 
shape(10, 600))),
+
+                       new Config("ji,jk->i", List.of(shape(600, 5), 
shape(600, 10))),
+                       new Config("ij,jk->i", List.of(shape(5, 600), 
shape(600, 10))),
+
+                       new Config("ji,jk->k", List.of(shape(600, 5), 
shape(600, 10))),
+                       new Config("ij,jk->k", List.of(shape(5, 600), 
shape(600, 10))),
+
+                       new Config("ji,jk->j", List.of(shape(600, 5), 
shape(600, 10))),
+
+                       new Config("ji,ji->ji", List.of(shape(600, 5), 
shape(600, 5))), // elemwise mult
+                       new Config("ji,ji,ji->ji", List.of(shape(600, 
5),shape(600, 5), shape(600, 5)),
+                                       List.of(0.0001, 0.0005, 0.001)),
+                       new Config("ji,ij->ji", List.of(shape(600, 5), shape(5, 
600))), // elemwise mult
+
+
+                       new Config("ij,i->ij",   List.of(shape(100, 50), 
shape(100))), // col mult
+                       new Config("ji,i->ij",   List.of(shape(50, 100), 
shape(100))), // row mult
+                       new Config("ij,i->i",   List.of(shape(100, 50), 
shape(100))),
+                       new Config("ij,i->j",   List.of(shape(100, 50), 
shape(100))),
+
+                       new Config("i,i->",     List.of(shape(50), shape(50))),
+                       new Config("i,j->",     List.of(shape(50), shape(80))),
+                       new Config("i,j->ij",     List.of(shape(50), 
shape(80))), // outer vect mult
+                       new Config("i,j->ji",     List.of(shape(50), 
shape(80))), // outer vect mult
+
+                       new Config("ij->",     List.of(shape(100, 50))), // sum
+                       new Config("ij->i",     List.of(shape(100, 50))), // 
sum(1)
+                       new Config("ij->j",     List.of(shape(100, 50))), // 
sum(0)
+                       new Config("ij->ji",     List.of(shape(100, 50))), // T
+
+                       new Config("ab,cd->ba",     List.of(shape( 600, 10), 
shape(6, 5))),
+                       new Config("ab,cd,g->ba",     List.of(shape( 600, 10), 
shape(6, 5), shape(3))),
+
+                       new Config("ab,bc,cd,de->ae",   List.of(shape(5, 600), 
shape(600, 10), shape(10, 5), shape(5, 4))), // chain of mm
+
+                       new Config("ji,jz,zx->ix",   List.of(shape(600, 5), 
shape( 600, 10), shape(10, 2))),
+                       new Config("fx,fg,fz,xg->z",   List.of(shape(600, 5), 
shape( 600, 10), shape(600, 6), shape(5, 10))),
+                       new Config("fx,fg,fz,xg,zx,zg->g", // each idx 3 times 
(cell tpl)
+                                       List.of(shape(5, 60), shape(5, 30), 
shape(5, 10), shape(60, 30), shape(10, 60), shape(10, 30))),
+
+                       new Config("i->",     List.of(shape(100))),
+                       new Config("i->i",     List.of(shape(100)))
+       );
+
+       private final int id;
+       private final String einsumStr;
+       //private final List<int[]> shapes;
+       private final File dmlFile;
+       private final File rFile;
+       private final boolean outputScalar;
+
+       public EinsumTest(String einsumStr, List<int[]> shapes, File dmlFile, 
File rFile, boolean outputScalar, int id){
+               this.id = id;
+               this.einsumStr = einsumStr;
+               //this.shapes = shapes;
+               this.dmlFile = dmlFile;
+               this.rFile = rFile;
+               this.outputScalar = outputScalar;
+       }
+
+       @Parameterized.Parameters(name = "{index}: einsum={0}")
+       public static Collection<Object[]> data() throws IOException {
+               List<Object[]> parameters = new ArrayList<>();
+
+               int counter = 1;
+
+               for (Config config : TEST_CONFIGS) {
+                       //List<File> files = new ArrayList<>();
+                       String fullDMLScriptName = "SystemDS_einsum_test" + 
counter;
+
+                       File dmlFile = File.createTempFile(fullDMLScriptName, 
".dml");
+                       dmlFile.deleteOnExit();
+
+                       boolean outputScalar = 
config.einsumStr.trim().endsWith("->");
+
+                       StringBuilder sb = createDmlFile(config, outputScalar);
+
+                       Files.writeString(dmlFile.toPath(), sb.toString());
+
+                       File rFile = File.createTempFile(fullDMLScriptName, 
".R");
+                       rFile.deleteOnExit();
+
+                       sb = createRFile(config, outputScalar);
+
+                       Files.writeString(rFile.toPath(), sb.toString());
+
+                       parameters.add(new Object[]{config.einsumStr, 
config.shapes, dmlFile, rFile, outputScalar, counter});
+
+                       counter++;
+               }
+
+               return parameters;
+       }
+
+       private static StringBuilder createDmlFile(Config config, boolean 
outputScalar) {
+               StringBuilder sb = new StringBuilder();
+
+               for (int i = 0; i < config.shapes.size(); i++) {
+                       int[] dims = config.shapes.get(i);
+
+                       double factor = config.factors != null ? 
config.factors.get(i) : 0.0001;
+                       sb.append("A");
+                       sb.append(i);
+
+                       if (dims.length == 1) { // A1 = seq(1,1000) * 0.0001
+                               sb.append(" = seq(1,");
+                               sb.append(dims[0]);
+                               sb.append(") * ");
+                               sb.append(factor);
+                       } else { // A0 = matrix(seq(1,50000), 1000, 50) * 0.0001
+                               sb.append(" = matrix(seq(1, ");
+                               sb.append(dims[0]*dims[1]);
+                               sb.append("), ");
+                               sb.append(dims[0]);
+                               sb.append(", ");
+                               sb.append(dims[1]);
+
+                               sb.append(") * ");
+                               sb.append(factor);
+                       }
+                       sb.append("\n");
+               }
+               sb.append("\n");
+
+               sb.append("R = einsum(\"");
+               sb.append(config.einsumStr);
+               sb.append("\", ");
+
+               for (int i = 0; i < config.shapes.size() - 1; i++) {
+                       sb.append("A");
+                       sb.append(i);
+                       sb.append(", ");
+               }
+               sb.append("A");
+               sb.append(config.shapes.size() - 1);
+               sb.append(")");
+
+               sb.append("\n\n");
+               sb.append("write(R, $1)\n");
+               return sb;
+       }
+
+       private static StringBuilder createRFile(Config config, boolean 
outputScalar) {
+               StringBuilder sb = new StringBuilder();
+               sb.append("args<-commandArgs(TRUE)\n");
+               sb.append("options(digits=22)\n");
+               sb.append("library(\"Matrix\")\n");
+               sb.append("library(\"matrixStats\")\n");
+               sb.append("library(\"einsum\")\n\n");
+
+
+               for (int i = 0; i < config.shapes.size(); i++) {
+                       int[] dims = config.shapes.get(i);
+                       
+                       double factor = config.factors != null ? 
config.factors.get(i) : 0.0001;
+                       sb.append("A");
+                       sb.append(i);
+
+                       if (dims.length == 1) { // A1 = seq(1,1000) * 0.0001
+                               sb.append(" = seq(1,");
+                               sb.append(dims[0]);
+                               sb.append(") * ");
+                               sb.append(factor);
+                       } else { // A0 = matrix(seq(1,50000), 1000, 50, 
byrow=TRUE) * 0.0001
+                               sb.append(" = matrix(seq(1, ");
+                               sb.append(dims[0]*dims[1]);
+                               sb.append("), ");
+                               sb.append(dims[0]);
+                               sb.append(", ");
+                               sb.append(dims[1]);
+
+                               sb.append(", byrow=TRUE) * ");
+                               sb.append(factor);
+                       }
+                       sb.append("\n");
+               }
+               sb.append("\n");
+
+               sb.append("R = einsum(\"");
+               sb.append(config.einsumStr);
+               sb.append("\", ");
+
+               for (int i = 0; i < config.shapes.size()-1; i++) {
+                       sb.append("A");
+                       sb.append(i);
+                       sb.append(", ");
+               }
+               sb.append("A");
+               sb.append(config.shapes.size()-1);
+               sb.append(")");
+
+               sb.append("\n\n");
+               if(outputScalar){
+                       sb.append("write(R, paste(args[2], \"S\", 
sep=\"\"))\n");
+               }else{
+                       sb.append("writeMM(as(R, \"CsparseMatrix\"), 
paste(args[2], \"S\", sep=\"\"))\n");
+               }
+               return sb;
+       }
+
+       @Test
+       public void testEinsumWithFiles() {
+               System.out.println("Testing einsum: " + this.einsumStr);
+               testCodegenIntegration(TEST_NAME_EINSUM+this.id);
+       }
+       @After
+       public void cleanUp() {
+               if (this.dmlFile.exists()) {
+                       boolean deleted = this.dmlFile.delete();
+                       if (!deleted) {
+                               System.err.println("Failed to delete temp file: 
" + this.dmlFile.getAbsolutePath());
+                       }
+               }
+               if (this.rFile.exists()) {
+                       boolean deleted = this.rFile.delete();
+                       if (!deleted) {
+                               System.err.println("Failed to delete temp file: 
" + this.rFile.getAbsolutePath());
+                       }
+               }
+       }
+
+       private static class Config {
+               public List<Double> factors;
+               String einsumStr;
+               List<int[]> shapes;
+
+               Config(String einsum, List<int[]> shapes) {
+                       this.einsumStr = einsum;
+                       this.shapes = shapes;
+                       this.factors = null;
+               }
+               Config(String einsum, List<int[]> shapes, List<Double> factors) 
{
+                       this.einsumStr = einsum;
+                       this.shapes = shapes;
+                       this.factors = factors;
+               }
+       }
+
+       private static int[] shape(int... dims) {
+               return dims;
+       }
+       private static final Log LOG = 
LogFactory.getLog(EinsumTest.class.getName());
+
+       private static final String TEST_NAME_EINSUM = "einsum";
+       private static final String TEST_DIR = "functions/einsum/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
EinsumTest.class.getSimpleName() + "/";
+       private final static String TEST_CONF = "SystemDS-config-codegen.xml";
+       private final static File   TEST_CONF_FILE = new File(SCRIPT_DIR + 
TEST_DIR, TEST_CONF);
+
+       private static double eps = Math.pow(10, -10);
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               for(int i = 1; i<= TEST_CONFIGS.size(); i++)
+                       addTestConfiguration( TEST_NAME_EINSUM+i, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME_EINSUM+i, new String[] { 
String.valueOf(i) }) );
+       }
+
+       private void testCodegenIntegration( String testname)
+       {
+               boolean oldFlag = OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION;
+               ExecMode platformOld = setExecMode(ExecType.CP);
+
+               String testnameDml = this.dmlFile.getAbsolutePath();
+               String testnameR = this.rFile.getAbsolutePath();
+               try
+               {
+                       TestConfiguration config = 
getTestConfiguration(testname);
+                       loadTestConfiguration(config);
+
+                       fullDMLScriptName = testnameDml;
+                       programArgs = new String[]{"-stats", "-explain", 
"-args", output("S") };
+                       fullRScriptName = testnameR;
+                       rCmd = getRCmd(inputDir(), expectedDir());
+
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = false;
+
+                       runTest(true, false, null, -1);
+                       runRScript(true);
+
+                       if(outputScalar){
+                               HashMap<CellIndex, Double> dmlfile = 
readDMLScalarFromOutputDir("S");
+                               HashMap<CellIndex, Double> rfile = 
readRScalarFromExpectedDir("S");
+                               TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
+                       }else {
+                               //compare matrices
+                               HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("S");
+                               HashMap<CellIndex, Double> rfile = 
readRMatrixFromExpectedDir("S");
+                               TestUtils.compareMatrices(dmlfile, rfile, eps, 
"Stat-DML", "Stat-R");
+                       }
+               }
+               finally {
+                       resetExecMode(platformOld);
+                       OptimizerUtils.ALLOW_ALGEBRAIC_SIMPLIFICATION = oldFlag;
+                       OptimizerUtils.ALLOW_AUTO_VECTORIZATION = true;
+                       OptimizerUtils.ALLOW_OPERATOR_FUSION = true;
+               }
+       }
+
+
+       /**
+        * Override default configuration with custom test configuration to 
ensure
+        * scratch space and local temporary directory locations are also 
updated.
+        */
+       @Override
+       protected File getConfigTemplateFile() {
+               // Instrumentation in this test's output log to show custom 
configuration file used for template.
+               LOG.debug("This test case overrides default configuration with 
" + TEST_CONF_FILE.getPath());
+               return TEST_CONF_FILE;
+       }
+}
diff --git a/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml 
b/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml
new file mode 100644
index 0000000000..626b31ebd7
--- /dev/null
+++ b/src/test/scripts/functions/einsum/SystemDS-config-codegen.xml
@@ -0,0 +1,31 @@
+<!--
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+-->
+
+<root>
+   <sysds.localtmpdir>/tmp/systemds</sysds.localtmpdir>
+   <sysds.scratch>scratch_space</sysds.scratch>
+   <sysds.optlevel>2</sysds.optlevel>
+   <sysds.codegen.plancache>true</sysds.codegen.plancache>
+   <sysds.codegen.literals>1</sysds.codegen.literals>
+
+   <!-- The number of threads for the spark instance artificially selected-->
+   <sysds.local.spark.number.threads>16</sysds.local.spark.number.threads>
+
+   <sysds.codegen.api>auto</sysds.codegen.api>
+</root>
\ No newline at end of file
diff --git a/src/test/scripts/installDependencies.R 
b/src/test/scripts/installDependencies.R
index af89f2b936..60642fa8ed 100644
--- a/src/test/scripts/installDependencies.R
+++ b/src/test/scripts/installDependencies.R
@@ -64,6 +64,7 @@ custom_install("unbalanced");
 custom_install("naivebayes");
 custom_install("BiocManager");
 custom_install("mltools");
+custom_install("einsum");
 BiocManager::install("rhdf5");
 
 print("Installation Done")

Reply via email to