mboehm7 commented on code in PR #2190:
URL: https://github.com/apache/systemds/pull/2190#discussion_r1929504112
##########
src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java:
##########
@@ -270,26 +271,26 @@ public String toString() {
case VECT_CBIND: return "b(cbind)";
case VECT_BIASADD: return "b(vbias+)";
case VECT_BIASMULT: return "b(vbias*)";
- case MULT: return "b(*)";
- case DIV: return "b(/)";
- case PLUS: return "b(+)";
- case MINUS: return "b(-)";
- case POW: return "b(^)";
- case MODULUS: return "b(%%)";
- case INTDIV: return "b(%/%)";
- case LESS: return "b(<)";
- case LESSEQUAL: return "b(<=)";
- case GREATER: return "b(>)";
- case GREATEREQUAL: return "b(>=)";
- case EQUAL: return "b(==)";
- case NOTEQUAL: return "b(!=)";
+ case MULT: return "b(" +
Opcodes.MULT.getName() + ")";
+ case DIV: return "b(" +
Opcodes.DIV.getName() + ")";
+ case PLUS: return "b(" +
Opcodes.PLUS.getName() + ")";
+ case MINUS: return "b(" +
Opcodes.MINUS.getName() + ")";
+ case POW: return "b(" +
Opcodes.POW.getName() + ")";
+ case MODULUS: return "b(" +
Opcodes.MODULUS.getName() + ")";
+ case INTDIV: return "b(" +
Opcodes.INTDIV.getName() + ")";
+ case LESS: return "b(" +
Opcodes.LESS.getName() + ")";
+ case LESSEQUAL: return "b(" +
Opcodes.LESSEQUAL.getName() + ")";
+ case GREATER: return "b(" +
Opcodes.GREATER.getName() + ")";
+ case GREATEREQUAL: return "b(" +
Opcodes.GREATEREQUAL.getName() + ")";
+ case EQUAL: return "b(" +
Opcodes.EQUAL.getName() + ")";
+ case NOTEQUAL: return "b(" +
Opcodes.NOTEQUAL.getName() + ")";
case OR: return "b(|)";
case AND: return "b(&)";
Review Comment:
Why are these operators (`OR`, `AND`) not covered?
##########
src/main/java/org/apache/sysds/common/Opcodes.java:
##########
@@ -0,0 +1,328 @@
+package org.apache.sysds.common;
+
+import org.apache.sysds.lops.*;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.hops.FunctionOp;
+
+import java.util.Arrays;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.Map;
+
+public enum Opcodes {
+ MMULT("ba+*", CPType.AggregateBinary),
+ TAKPM("tak+*", CPType.AggregateTernary),
+ TACKPM("tack+*", CPType.AggregateTernary),
+
+ UAKP("uak+", CPType.AggregateUnary),
+ UARKP("uark+", CPType.AggregateUnary),
+ UACKP( "uack+", CPType.AggregateUnary),
+ UASQKP( "uasqk+", CPType.AggregateUnary),
+ UARSQKP( "uarsqk+", CPType.AggregateUnary),
+ UACSQKP( "uacsqk+", CPType.AggregateUnary),
+ UAMEAN( "uamean", CPType.AggregateUnary),
+ UARMEAN("uarmean", CPType.AggregateUnary),
+ UACMEAN("uacmean", CPType.AggregateUnary),
+ UAVAR("uavar", CPType.AggregateUnary),
+ UARVAR("uarvar", CPType.AggregateUnary),
+ UACVAR("uacvar", CPType.AggregateUnary),
+ UAMAX("uamax", CPType.AggregateUnary),
+ UARMAX("uarmax", CPType.AggregateUnary),
+ UARIMAX("uarimax", CPType.AggregateUnary),
+ UACMAX("uacmax", CPType.AggregateUnary),
+ UAMIN("uamin", CPType.AggregateUnary),
+ UARMIN("uarmin", CPType.AggregateUnary),
+ UARIMIN("uarimin", CPType.AggregateUnary),
+ UACMIN("uacmin", CPType.AggregateUnary),
+ UAP("ua+", CPType.AggregateUnary),
+ UARP("uar+", CPType.AggregateUnary),
+ UACP("uac+", CPType.AggregateUnary),
+ UAM("ua*", CPType.AggregateUnary),
+ UARM("uar*", CPType.AggregateUnary),
+ UACM("uac*", CPType.AggregateUnary),
+ UATRACE("uatrace", CPType.AggregateUnary),
+ UAKTRACE("uaktrace", CPType.AggregateUnary),
+
+ NROW("nrow", CPType.AggregateUnary),
+ NCOL("ncol", CPType.AggregateUnary),
+ LENGTH("length", CPType.AggregateUnary),
+ EXISTS("exists", CPType.AggregateUnary),
+ LINEAGE("lineage", CPType.AggregateUnary),
+ UACD("uacd", CPType.AggregateUnary),
+ UACDR("uacdr", CPType.AggregateUnary),
+ UACDC("uacdc", CPType.AggregateUnary),
+ UACDAP("uacdap", CPType.AggregateUnary),
+ UACDAPR("uacdapr", CPType.AggregateUnary),
+ UACDAPC("uacdapc", CPType.AggregateUnary),
+ UNIQUE("unique", CPType.AggregateUnary),
+ UNIQUER("uniquer", CPType.AggregateUnary),
+ UNIQUEC("uniquec", CPType.AggregateUnary),
+
+ UAGGOUTERCHAIN("uaggouterchain", CPType.UaggOuterChain),
+
+ // Arithmetic Instruction Opcodes
+ PLUS("+", CPType.Binary),
+ MINUS("-", CPType.Binary),
+ MULT("*", CPType.Binary),
+ DIV("/", CPType.Binary),
+ MODULUS("%%", CPType.Binary),
+ INTDIV("%/%", CPType.Binary),
+ POW("^", CPType.Binary),
+ MINUS1_MULT("1-*", CPType.Binary), //special * case
+ POW2("^2", CPType.Binary), //special ^ case
+ MULT2("*2", CPType.Binary), //special * case
+ MINUS_NZ("-nz", CPType.Binary), //special - case
+
+ // Boolean Instruction Opcodes
+ AND("&&", CPType.Binary),
+ OR("||", CPType.Binary),
+ XOR("xor", CPType.Binary),
+ BITWAND("bitwAnd", CPType.Binary),
+ BITWOR("bitwOr", CPType.Binary),
+ BITWXOR("bitwXor", CPType.Binary),
+ BITWSHIFTL("bitwShiftL", CPType.Binary),
+ BITWSHIFTR("bitwShiftR", CPType.Binary),
+ NOT("!", CPType.Unary),
+
+ // Relational Instruction Opcodes
+ EQUAL("==", CPType.Binary),
+ NOTEQUAL("!=", CPType.Binary),
+ LESS("<", CPType.Binary),
+ GREATER(">", CPType.Binary),
+ LESSEQUAL("<=", CPType.Binary),
+ GREATEREQUAL(">=", CPType.Binary),
+
+ // Builtin Instruction Opcodes
+ LOG("log", CPType.Builtin),
+ LOGNZ("log_nz", CPType.Builtin),
+
+ SOLVE("solve", CPType.Binary),
+ MAX("max", CPType.Binary),
+ MIN("min", CPType.Binary),
+ DROPINVALIDTYPE("dropInvalidType", CPType.Binary),
+ DROPINVALIDLENGTH("dropInvalidLength", CPType.Binary),
+ FREPLICATE("freplicate", CPType.Binary),
+ VALUESWAP("valueSwap", CPType.Binary),
+ APPLYSCHEMA("applySchema", CPType.Binary),
+ MAP("_map", CPType.Ternary),
+
+ NMAX("nmax", CPType.BuiltinNary),
+ NMIN("nmin", CPType.BuiltinNary),
+ NP("n+", CPType.BuiltinNary),
+ NM("n*", CPType.BuiltinNary),
+
+ EXP("exp", CPType.Unary),
+ ABS("abs", CPType.Unary),
+ SIN("sin", CPType.Unary),
+ COS("cos", CPType.Unary),
+ TAN("tan", CPType.Unary),
+ SINH("sinh", CPType.Unary),
+ COSH("cosh", CPType.Unary),
+ TANH("tanh", CPType.Unary),
+ ASIN("asin", CPType.Unary),
+ ACOS("acos", CPType.Unary),
+ ATAN("atan", CPType.Unary),
+ SIGN("sign", CPType.Unary),
+ SQRT("sqrt", CPType.Unary),
+ PLOGP("plogp", CPType.Unary),
+ PRINT("print", CPType.Unary),
+ ASSERT("assert", CPType.Unary),
+ ROUND("round", CPType.Unary),
+ CEIL("ceil", CPType.Unary),
+ FLOOR("floor", CPType.Unary),
+ UCUMKP("ucumk+", CPType.Unary),
+ UCUMM("ucum*", CPType.Unary),
+ UCUMKPM("ucumk+*", CPType.Unary),
+ UCUMMIN("ucummin", CPType.Unary),
+ UCUMMAX("ucummax", CPType.Unary),
+ STOP("stop", CPType.Unary),
+ INVERSE("inverse", CPType.Unary),
+ CHOLESKY("cholesky", CPType.Unary),
+ SPROP("sprop", CPType.Unary),
+ SIGMOID("sigmoid", CPType.Unary),
+ TYPEOF("typeOf", CPType.Unary),
+ DETECTSCHEMA("detectSchema", CPType.Unary),
+ COLNAMES("colnames", CPType.Unary),
+ ISNA("isna", CPType.Unary),
+ ISNAN("isnan", CPType.Unary),
+ ISINF("isinf", CPType.Unary),
+ PRINTF("printf", CPType.BuiltinNary),
+ CBIND("cbind", CPType.BuiltinNary),
+ RBIND("rbind", CPType.BuiltinNary),
+ EVAL("eval", CPType.BuiltinNary),
+ LIST("list", CPType.BuiltinNary),
+
+ //Parametrized builtin functions
+ AUTODIFF("autoDiff", CPType.ParameterizedBuiltin),
+ CONTAINS("contains", CPType.ParameterizedBuiltin),
+ PARAMSERV("paramserv", CPType.ParameterizedBuiltin),
+ NVLIST("nvlist", CPType.ParameterizedBuiltin),
+ CDF("cdf", CPType.ParameterizedBuiltin),
+ INVCDF("invcdf", CPType.ParameterizedBuiltin),
+ GROUPEDAGG("groupedagg", CPType.ParameterizedBuiltin),
+ RMEMPTY("rmempty", CPType.ParameterizedBuiltin),
+ REPLACE("replace", CPType.ParameterizedBuiltin),
+ LOWERTRI("lowertri", CPType.ParameterizedBuiltin),
+ UPPERTRI("uppertri", CPType.ParameterizedBuiltin),
+ REXPAND("rexpand", CPType.ParameterizedBuiltin),
+ TOSTRING("toString", CPType.ParameterizedBuiltin),
+ TOKENIZE("tokenize", CPType.ParameterizedBuiltin),
+ TRANSFORMAPPLY("transformapply", CPType.ParameterizedBuiltin),
+ TRANSFORMDECODE("transformdecode", CPType.ParameterizedBuiltin),
+ TRANSFORMCOLMAP("transformcolmap", CPType.ParameterizedBuiltin),
+ TRANSFORMMETA("transformmeta", CPType.ParameterizedBuiltin),
+ TRANSFORMENCODE("transformencode", CPType.MultiReturnParameterizedBuiltin),
+
+ //Ternary instruction opcodes
+ PM("+*", CPType.Ternary),
+ MINUSMULT("-*", CPType.Ternary),
+ IFELSE("ifelse", CPType.Ternary),
+
+ //Variable instruction opcodes
+ ASSIGNVAR("assignvar", CPType.Variable),
+ CPVAR("cpvar", CPType.Variable),
+ MVVAR("mvvar", CPType.Variable),
+ RMVAR("rmvar", CPType.Variable),
+ RMFILEVAR("rmfilevar", CPType.Variable),
+ CAST_AS_SCALAR(OpOp1.CAST_AS_SCALAR.toString(), CPType.Variable),
+ CAST_AS_MATRIX(OpOp1.CAST_AS_MATRIX.toString(), CPType.Variable),
+ CAST_AS_FRAME_VAR("cast_as_frame", CPType.Variable),
+ CAST_AS_FRAME(OpOp1.CAST_AS_FRAME.toString(), CPType.Variable),
+ CAST_AS_LIST(OpOp1.CAST_AS_LIST.toString(), CPType.Variable),
+ CAST_AS_DOUBLE(OpOp1.CAST_AS_DOUBLE.toString(), CPType.Variable),
+ CAST_AS_INT(OpOp1.CAST_AS_INT.toString(), CPType.Variable),
+ CAST_AS_BOOLEAN(OpOp1.CAST_AS_BOOLEAN.toString(), CPType.Variable),
+ ATTACHFILETOVAR("attachfiletovar", CPType.Variable),
+ READ("read", CPType.Variable),
+ WRITE("write", CPType.Variable),
+ CREATEVAR("createvar", CPType.Variable),
+
+ //Reorg instruction opcodes
+ TRANSPOSE("r'", CPType.Reorg),
+ REV("rev", CPType.Reorg),
+ ROLL("roll", CPType.Reorg),
+ DIAG("rdiag", CPType.Reorg),
+ RESHAPE("rshape", CPType.Reshape),
+ SORT("rsort", CPType.Reorg),
+
+ // Opcodes related to convolutions
+ RELU_BACKWARD("relu_backward", CPType.Dnn),
+ RELU_MAXPOOLING("relu_maxpooling", CPType.Dnn),
+ RELU_MAXPOOLING_BACKWARD("relu_maxpooling_backward", CPType.Dnn),
+ MAXPOOLING("maxpooling", CPType.Dnn),
+ MAXPOOLING_BACKWARD("maxpooling_backward", CPType.Dnn),
+ AVGPOOLING("avgpooling", CPType.Dnn),
+ AVGPOOLING_BACKWARD("avgpooling_backward", CPType.Dnn),
+ CONV2D("conv2d", CPType.Dnn),
+ CONV2D_BIAS_ADD("conv2d_bias_add", CPType.Dnn),
+ CONV2D_BACKWARD_FILTER("conv2d_backward_filter", CPType.Dnn),
+ CONV2D_BACKWARD_DATA("conv2d_backward_data", CPType.Dnn),
+ BIAS_ADD("bias_add", CPType.Dnn),
+ BIAS_MULTIPLY("bias_multiply", CPType.Dnn),
+ BATCH_NORM2D("batch_norm2d", CPType.Dnn),
+ BATCH_NORM2D_BACKWARD("batch_norm2d_backward", CPType.Dnn),
+ LSTM("lstm", CPType.Dnn),
+ LSTM_BACKWARD("lstm_backward", CPType.Dnn),
+
+ //Quaternary instruction opcodes
+ WSLOSS("wsloss", CPType.Quaternary),
+ WSIGMOID("wsigmoid", CPType.Quaternary),
+ WDIVMM("wdivmm", CPType.Quaternary),
+ WCEMM("wcemm", CPType.Quaternary),
+ WUMM("wumm", CPType.Quaternary),
+
+ //User-defined function Opcodes
+ FCALL(FunctionOp.OPCODE, CPType.FCall),
+
+ APPEND(Append.OPCODE, CPType.Append),
+ REMOVE("remove", CPType.Append),
+
+ //data generation opcodes
+ RANDOM(DataGen.RAND_OPCODE, CPType.Rand),
+ SEQUENCE(DataGen.SEQ_OPCODE, CPType.Rand),
+ STRINGINIT(DataGen.SINIT_OPCODE, CPType.StringInit),
+ SAMPLE(DataGen.SAMPLE_OPCODE, CPType.Rand),
+ TIME(DataGen.TIME_OPCODE, CPType.Rand),
+ FRAME(DataGen.FRAME_OPCODE, CPType.Rand),
+
+ CTABLE("ctable", CPType.Ctable),
+ CTABLEEXPAND("ctableexpand", CPType.Ctable),
+
+ //central moment, covariance, quantiles (sort/pick)
+ CM("cm", CPType.CentralMoment),
+ COV("cov", CPType.Covariance),
+ QSORT("qsort", CPType.QSort),
+ QPICK("qpick", CPType.QPick),
+
+ RIGHT_INDEX(RightIndex.OPCODE, CPType.MatrixIndexing),
+ LEFT_INDEX(LeftIndex.OPCODE, CPType.MatrixIndexing),
+
+ TSMM("tsmm", CPType.MMTSJ),
+ PMM("pmm", CPType.PMMJ),
+ MMCHAIN("mmchain", CPType.MMChain),
+
+ QR("qr", CPType.MultiReturnBuiltin),
+ LU("lu", CPType.MultiReturnBuiltin),
+ EIGEN("eigen", CPType.MultiReturnBuiltin),
+ FFT("fft", CPType.MultiReturnBuiltin),
+ IFFT("ifft", CPType.MultiReturnComplexMatrixBuiltin),
+ FFT_LINEARIZED("fft_linearized", CPType.MultiReturnBuiltin),
+ IFFT_LINEARIZED("ifft_linearized", CPType.MultiReturnComplexMatrixBuiltin),
+ STFT("stft", CPType.MultiReturnComplexMatrixBuiltin),
+ SVD("svd", CPType.MultiReturnBuiltin),
+ RCM("rcm", CPType.MultiReturnComplexMatrixBuiltin),
+
+ PARTITION("partition", CPType.Partition),
+ COMPRESS(Compression.OPCODE, CPType.Compression),
+ DECOMPRESS(DeCompression.OPCODE, CPType.DeCompression),
+ SPOOF("spoof", CPType.SpoofFused),
+ PREFETCH("prefetch", CPType.Prefetch),
+ EVICT("_evict", CPType.EvictLineageCache),
+ BROADCAST("broadcast", CPType.Broadcast),
+ TRIGREMOTE("trigremote", CPType.TrigRemote),
+ LOCAL(Local.OPCODE, CPType.Local),
+
+ SQL("sql", CPType.Sql);
+
+
+ // Constructor
+ Opcodes(String name, CPType type) {
+ this._name = name;
+ this._type = type;
+ }
+
+ // Fields
+ private final String _name;
+ private final CPType _type;
+
+ private static final Map<String, Opcodes> _lookupMap = new HashMap<>();
+
+ // Initialize lookup map
+ static {
+ for (Opcodes op : EnumSet.allOf(Opcodes.class)) {
+ _lookupMap.put(op.getName(), op);
+ }
+ }
+
+ // Getters
+ public String getName() {
+ return _name;
+ }
+
+ public CPType getType() {
+ return _type;
+ }
+
+ public static CPType getCPTypeByOpcode(String opcode) {
+ for (Opcodes op : Opcodes.values()) {
+ if (op.getName().equalsIgnoreCase(opcode.trim())) {
+ return op.getType();
+ }
+ }
+ return null;
+ }
+
+
+
Review Comment:
Please avoid such trailing blank lines at the end of the enum.
##########
src/main/java/org/apache/sysds/hops/cost/CostEstimatorStaticRuntime.java:
##########
@@ -281,15 +282,15 @@ private static double getNFLOP( String optype, boolean
inMR, long d1m, long d1n,
//NOTE: all instruction types that are equivalent in CP and MR
are only
//included in CP to prevent redundancy
- CPType cptype =
CPInstructionParser.String2CPInstructionType.get(optype);
- if( cptype != null ) //for CP Ops and equivalent MR ops
+ CPType cptype = Opcodes.valueOf(optype).getType();
Review Comment:
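Shouldn't this be `getCPTypeByOpcode`? `Opcodes.valueOf(optype)` looks up the enum constant by its Java name (e.g., `MMULT`), not by the opcode string (e.g., `ba+*`). It also throws an `IllegalArgumentException` for unknown names, whereas the old map lookup returned `null` (which the `cptype != null` check relied on).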
##########
src/main/java/org/apache/sysds/common/Opcodes.java:
##########
@@ -0,0 +1,328 @@
+package org.apache.sysds.common;
+
+import org.apache.sysds.lops.*;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.hops.FunctionOp;
+
+import java.util.Arrays;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.Map;
+
+public enum Opcodes {
+ MMULT("ba+*", CPType.AggregateBinary),
+ TAKPM("tak+*", CPType.AggregateTernary),
+ TACKPM("tack+*", CPType.AggregateTernary),
+
+ UAKP("uak+", CPType.AggregateUnary),
+ UARKP("uark+", CPType.AggregateUnary),
+ UACKP( "uack+", CPType.AggregateUnary),
+ UASQKP( "uasqk+", CPType.AggregateUnary),
+ UARSQKP( "uarsqk+", CPType.AggregateUnary),
+ UACSQKP( "uacsqk+", CPType.AggregateUnary),
+ UAMEAN( "uamean", CPType.AggregateUnary),
+ UARMEAN("uarmean", CPType.AggregateUnary),
Review Comment:
There seems to be a formatting issue — please ensure consistent tab
indentation (e.g., the stray spaces in entries like `UACKP( "uack+", ...)`).
##########
src/main/java/org/apache/sysds/common/Opcodes.java:
##########
@@ -0,0 +1,328 @@
+package org.apache.sysds.common;
+
+import org.apache.sysds.lops.*;
+import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
+import org.apache.sysds.common.Types.OpOp1;
+import org.apache.sysds.hops.FunctionOp;
+
+import java.util.Arrays;
+import java.util.EnumSet;
+import java.util.HashMap;
+import java.util.Map;
+
+public enum Opcodes {
+ MMULT("ba+*", CPType.AggregateBinary),
+ TAKPM("tak+*", CPType.AggregateTernary),
+ TACKPM("tack+*", CPType.AggregateTernary),
+
+ UAKP("uak+", CPType.AggregateUnary),
+ UARKP("uark+", CPType.AggregateUnary),
+ UACKP( "uack+", CPType.AggregateUnary),
+ UASQKP( "uasqk+", CPType.AggregateUnary),
+ UARSQKP( "uarsqk+", CPType.AggregateUnary),
+ UACSQKP( "uacsqk+", CPType.AggregateUnary),
+ UAMEAN( "uamean", CPType.AggregateUnary),
+ UARMEAN("uarmean", CPType.AggregateUnary),
+ UACMEAN("uacmean", CPType.AggregateUnary),
+ UAVAR("uavar", CPType.AggregateUnary),
+ UARVAR("uarvar", CPType.AggregateUnary),
+ UACVAR("uacvar", CPType.AggregateUnary),
+ UAMAX("uamax", CPType.AggregateUnary),
+ UARMAX("uarmax", CPType.AggregateUnary),
+ UARIMAX("uarimax", CPType.AggregateUnary),
+ UACMAX("uacmax", CPType.AggregateUnary),
+ UAMIN("uamin", CPType.AggregateUnary),
+ UARMIN("uarmin", CPType.AggregateUnary),
+ UARIMIN("uarimin", CPType.AggregateUnary),
+ UACMIN("uacmin", CPType.AggregateUnary),
+ UAP("ua+", CPType.AggregateUnary),
+ UARP("uar+", CPType.AggregateUnary),
+ UACP("uac+", CPType.AggregateUnary),
+ UAM("ua*", CPType.AggregateUnary),
+ UARM("uar*", CPType.AggregateUnary),
+ UACM("uac*", CPType.AggregateUnary),
+ UATRACE("uatrace", CPType.AggregateUnary),
+ UAKTRACE("uaktrace", CPType.AggregateUnary),
+
+ NROW("nrow", CPType.AggregateUnary),
+ NCOL("ncol", CPType.AggregateUnary),
+ LENGTH("length", CPType.AggregateUnary),
+ EXISTS("exists", CPType.AggregateUnary),
+ LINEAGE("lineage", CPType.AggregateUnary),
+ UACD("uacd", CPType.AggregateUnary),
+ UACDR("uacdr", CPType.AggregateUnary),
+ UACDC("uacdc", CPType.AggregateUnary),
+ UACDAP("uacdap", CPType.AggregateUnary),
+ UACDAPR("uacdapr", CPType.AggregateUnary),
+ UACDAPC("uacdapc", CPType.AggregateUnary),
+ UNIQUE("unique", CPType.AggregateUnary),
+ UNIQUER("uniquer", CPType.AggregateUnary),
+ UNIQUEC("uniquec", CPType.AggregateUnary),
+
+ UAGGOUTERCHAIN("uaggouterchain", CPType.UaggOuterChain),
+
+ // Arithmetic Instruction Opcodes
+ PLUS("+", CPType.Binary),
+ MINUS("-", CPType.Binary),
+ MULT("*", CPType.Binary),
+ DIV("/", CPType.Binary),
+ MODULUS("%%", CPType.Binary),
+ INTDIV("%/%", CPType.Binary),
+ POW("^", CPType.Binary),
+ MINUS1_MULT("1-*", CPType.Binary), //special * case
+ POW2("^2", CPType.Binary), //special ^ case
+ MULT2("*2", CPType.Binary), //special * case
+ MINUS_NZ("-nz", CPType.Binary), //special - case
+
+ // Boolean Instruction Opcodes
+ AND("&&", CPType.Binary),
+ OR("||", CPType.Binary),
+ XOR("xor", CPType.Binary),
+ BITWAND("bitwAnd", CPType.Binary),
+ BITWOR("bitwOr", CPType.Binary),
+ BITWXOR("bitwXor", CPType.Binary),
+ BITWSHIFTL("bitwShiftL", CPType.Binary),
+ BITWSHIFTR("bitwShiftR", CPType.Binary),
+ NOT("!", CPType.Unary),
+
+ // Relational Instruction Opcodes
+ EQUAL("==", CPType.Binary),
+ NOTEQUAL("!=", CPType.Binary),
+ LESS("<", CPType.Binary),
+ GREATER(">", CPType.Binary),
+ LESSEQUAL("<=", CPType.Binary),
+ GREATEREQUAL(">=", CPType.Binary),
+
+ // Builtin Instruction Opcodes
+ LOG("log", CPType.Builtin),
+ LOGNZ("log_nz", CPType.Builtin),
+
+ SOLVE("solve", CPType.Binary),
+ MAX("max", CPType.Binary),
+ MIN("min", CPType.Binary),
+ DROPINVALIDTYPE("dropInvalidType", CPType.Binary),
+ DROPINVALIDLENGTH("dropInvalidLength", CPType.Binary),
+ FREPLICATE("freplicate", CPType.Binary),
+ VALUESWAP("valueSwap", CPType.Binary),
+ APPLYSCHEMA("applySchema", CPType.Binary),
+ MAP("_map", CPType.Ternary),
+
+ NMAX("nmax", CPType.BuiltinNary),
+ NMIN("nmin", CPType.BuiltinNary),
+ NP("n+", CPType.BuiltinNary),
+ NM("n*", CPType.BuiltinNary),
+
+ EXP("exp", CPType.Unary),
+ ABS("abs", CPType.Unary),
+ SIN("sin", CPType.Unary),
+ COS("cos", CPType.Unary),
+ TAN("tan", CPType.Unary),
+ SINH("sinh", CPType.Unary),
+ COSH("cosh", CPType.Unary),
+ TANH("tanh", CPType.Unary),
+ ASIN("asin", CPType.Unary),
+ ACOS("acos", CPType.Unary),
+ ATAN("atan", CPType.Unary),
+ SIGN("sign", CPType.Unary),
+ SQRT("sqrt", CPType.Unary),
+ PLOGP("plogp", CPType.Unary),
+ PRINT("print", CPType.Unary),
+ ASSERT("assert", CPType.Unary),
+ ROUND("round", CPType.Unary),
+ CEIL("ceil", CPType.Unary),
+ FLOOR("floor", CPType.Unary),
+ UCUMKP("ucumk+", CPType.Unary),
+ UCUMM("ucum*", CPType.Unary),
+ UCUMKPM("ucumk+*", CPType.Unary),
+ UCUMMIN("ucummin", CPType.Unary),
+ UCUMMAX("ucummax", CPType.Unary),
+ STOP("stop", CPType.Unary),
+ INVERSE("inverse", CPType.Unary),
+ CHOLESKY("cholesky", CPType.Unary),
+ SPROP("sprop", CPType.Unary),
+ SIGMOID("sigmoid", CPType.Unary),
+ TYPEOF("typeOf", CPType.Unary),
+ DETECTSCHEMA("detectSchema", CPType.Unary),
+ COLNAMES("colnames", CPType.Unary),
+ ISNA("isna", CPType.Unary),
+ ISNAN("isnan", CPType.Unary),
+ ISINF("isinf", CPType.Unary),
+ PRINTF("printf", CPType.BuiltinNary),
+ CBIND("cbind", CPType.BuiltinNary),
+ RBIND("rbind", CPType.BuiltinNary),
+ EVAL("eval", CPType.BuiltinNary),
+ LIST("list", CPType.BuiltinNary),
+
+ //Parametrized builtin functions
+ AUTODIFF("autoDiff", CPType.ParameterizedBuiltin),
+ CONTAINS("contains", CPType.ParameterizedBuiltin),
+ PARAMSERV("paramserv", CPType.ParameterizedBuiltin),
+ NVLIST("nvlist", CPType.ParameterizedBuiltin),
+ CDF("cdf", CPType.ParameterizedBuiltin),
+ INVCDF("invcdf", CPType.ParameterizedBuiltin),
+ GROUPEDAGG("groupedagg", CPType.ParameterizedBuiltin),
+ RMEMPTY("rmempty", CPType.ParameterizedBuiltin),
+ REPLACE("replace", CPType.ParameterizedBuiltin),
+ LOWERTRI("lowertri", CPType.ParameterizedBuiltin),
+ UPPERTRI("uppertri", CPType.ParameterizedBuiltin),
+ REXPAND("rexpand", CPType.ParameterizedBuiltin),
+ TOSTRING("toString", CPType.ParameterizedBuiltin),
+ TOKENIZE("tokenize", CPType.ParameterizedBuiltin),
+ TRANSFORMAPPLY("transformapply", CPType.ParameterizedBuiltin),
+ TRANSFORMDECODE("transformdecode", CPType.ParameterizedBuiltin),
+ TRANSFORMCOLMAP("transformcolmap", CPType.ParameterizedBuiltin),
+ TRANSFORMMETA("transformmeta", CPType.ParameterizedBuiltin),
+ TRANSFORMENCODE("transformencode", CPType.MultiReturnParameterizedBuiltin),
+
+ //Ternary instruction opcodes
+ PM("+*", CPType.Ternary),
+ MINUSMULT("-*", CPType.Ternary),
+ IFELSE("ifelse", CPType.Ternary),
+
+ //Variable instruction opcodes
+ ASSIGNVAR("assignvar", CPType.Variable),
+ CPVAR("cpvar", CPType.Variable),
+ MVVAR("mvvar", CPType.Variable),
+ RMVAR("rmvar", CPType.Variable),
+ RMFILEVAR("rmfilevar", CPType.Variable),
+ CAST_AS_SCALAR(OpOp1.CAST_AS_SCALAR.toString(), CPType.Variable),
+ CAST_AS_MATRIX(OpOp1.CAST_AS_MATRIX.toString(), CPType.Variable),
+ CAST_AS_FRAME_VAR("cast_as_frame", CPType.Variable),
+ CAST_AS_FRAME(OpOp1.CAST_AS_FRAME.toString(), CPType.Variable),
+ CAST_AS_LIST(OpOp1.CAST_AS_LIST.toString(), CPType.Variable),
+ CAST_AS_DOUBLE(OpOp1.CAST_AS_DOUBLE.toString(), CPType.Variable),
+ CAST_AS_INT(OpOp1.CAST_AS_INT.toString(), CPType.Variable),
+ CAST_AS_BOOLEAN(OpOp1.CAST_AS_BOOLEAN.toString(), CPType.Variable),
+ ATTACHFILETOVAR("attachfiletovar", CPType.Variable),
+ READ("read", CPType.Variable),
+ WRITE("write", CPType.Variable),
+ CREATEVAR("createvar", CPType.Variable),
+
+ //Reorg instruction opcodes
+ TRANSPOSE("r'", CPType.Reorg),
+ REV("rev", CPType.Reorg),
+ ROLL("roll", CPType.Reorg),
+ DIAG("rdiag", CPType.Reorg),
+ RESHAPE("rshape", CPType.Reshape),
+ SORT("rsort", CPType.Reorg),
+
+ // Opcodes related to convolutions
+ RELU_BACKWARD("relu_backward", CPType.Dnn),
+ RELU_MAXPOOLING("relu_maxpooling", CPType.Dnn),
+ RELU_MAXPOOLING_BACKWARD("relu_maxpooling_backward", CPType.Dnn),
+ MAXPOOLING("maxpooling", CPType.Dnn),
+ MAXPOOLING_BACKWARD("maxpooling_backward", CPType.Dnn),
+ AVGPOOLING("avgpooling", CPType.Dnn),
+ AVGPOOLING_BACKWARD("avgpooling_backward", CPType.Dnn),
+ CONV2D("conv2d", CPType.Dnn),
+ CONV2D_BIAS_ADD("conv2d_bias_add", CPType.Dnn),
+ CONV2D_BACKWARD_FILTER("conv2d_backward_filter", CPType.Dnn),
+ CONV2D_BACKWARD_DATA("conv2d_backward_data", CPType.Dnn),
+ BIAS_ADD("bias_add", CPType.Dnn),
+ BIAS_MULTIPLY("bias_multiply", CPType.Dnn),
+ BATCH_NORM2D("batch_norm2d", CPType.Dnn),
+ BATCH_NORM2D_BACKWARD("batch_norm2d_backward", CPType.Dnn),
+ LSTM("lstm", CPType.Dnn),
+ LSTM_BACKWARD("lstm_backward", CPType.Dnn),
+
+ //Quaternary instruction opcodes
+ WSLOSS("wsloss", CPType.Quaternary),
+ WSIGMOID("wsigmoid", CPType.Quaternary),
+ WDIVMM("wdivmm", CPType.Quaternary),
+ WCEMM("wcemm", CPType.Quaternary),
+ WUMM("wumm", CPType.Quaternary),
+
+ //User-defined function Opcodes
+ FCALL(FunctionOp.OPCODE, CPType.FCall),
+
+ APPEND(Append.OPCODE, CPType.Append),
+ REMOVE("remove", CPType.Append),
+
+ //data generation opcodes
+ RANDOM(DataGen.RAND_OPCODE, CPType.Rand),
+ SEQUENCE(DataGen.SEQ_OPCODE, CPType.Rand),
+ STRINGINIT(DataGen.SINIT_OPCODE, CPType.StringInit),
+ SAMPLE(DataGen.SAMPLE_OPCODE, CPType.Rand),
+ TIME(DataGen.TIME_OPCODE, CPType.Rand),
+ FRAME(DataGen.FRAME_OPCODE, CPType.Rand),
+
+ CTABLE("ctable", CPType.Ctable),
+ CTABLEEXPAND("ctableexpand", CPType.Ctable),
+
+ //central moment, covariance, quantiles (sort/pick)
+ CM("cm", CPType.CentralMoment),
+ COV("cov", CPType.Covariance),
+ QSORT("qsort", CPType.QSort),
+ QPICK("qpick", CPType.QPick),
+
+ RIGHT_INDEX(RightIndex.OPCODE, CPType.MatrixIndexing),
+ LEFT_INDEX(LeftIndex.OPCODE, CPType.MatrixIndexing),
+
+ TSMM("tsmm", CPType.MMTSJ),
+ PMM("pmm", CPType.PMMJ),
+ MMCHAIN("mmchain", CPType.MMChain),
+
+ QR("qr", CPType.MultiReturnBuiltin),
+ LU("lu", CPType.MultiReturnBuiltin),
+ EIGEN("eigen", CPType.MultiReturnBuiltin),
+ FFT("fft", CPType.MultiReturnBuiltin),
+ IFFT("ifft", CPType.MultiReturnComplexMatrixBuiltin),
+ FFT_LINEARIZED("fft_linearized", CPType.MultiReturnBuiltin),
+ IFFT_LINEARIZED("ifft_linearized", CPType.MultiReturnComplexMatrixBuiltin),
+ STFT("stft", CPType.MultiReturnComplexMatrixBuiltin),
+ SVD("svd", CPType.MultiReturnBuiltin),
+ RCM("rcm", CPType.MultiReturnComplexMatrixBuiltin),
+
+ PARTITION("partition", CPType.Partition),
+ COMPRESS(Compression.OPCODE, CPType.Compression),
+ DECOMPRESS(DeCompression.OPCODE, CPType.DeCompression),
+ SPOOF("spoof", CPType.SpoofFused),
+ PREFETCH("prefetch", CPType.Prefetch),
+ EVICT("_evict", CPType.EvictLineageCache),
+ BROADCAST("broadcast", CPType.Broadcast),
+ TRIGREMOTE("trigremote", CPType.TrigRemote),
+ LOCAL(Local.OPCODE, CPType.Local),
+
+ SQL("sql", CPType.Sql);
+
+
+ // Constructor
+ Opcodes(String name, CPType type) {
+ this._name = name;
+ this._type = type;
+ }
+
+ // Fields
+ private final String _name;
+ private final CPType _type;
+
+ private static final Map<String, Opcodes> _lookupMap = new HashMap<>();
+
+ // Initialize lookup map
+ static {
+ for (Opcodes op : EnumSet.allOf(Opcodes.class)) {
+ _lookupMap.put(op.getName(), op);
+ }
+ }
+
+ // Getters
+ public String getName() {
+ return _name;
+ }
Review Comment:
Instead of this method, please override `toString`, which then allows using
`Opcodes.XXX` instead of `Opcodes.XXX.getName()`.
##########
src/main/java/org/apache/sysds/hops/codegen/cplan/CNodeBinary.java:
##########
@@ -270,26 +271,26 @@ public String toString() {
case VECT_CBIND: return "b(cbind)";
case VECT_BIASADD: return "b(vbias+)";
case VECT_BIASMULT: return "b(vbias*)";
- case MULT: return "b(*)";
- case DIV: return "b(/)";
- case PLUS: return "b(+)";
- case MINUS: return "b(-)";
- case POW: return "b(^)";
- case MODULUS: return "b(%%)";
- case INTDIV: return "b(%/%)";
- case LESS: return "b(<)";
- case LESSEQUAL: return "b(<=)";
- case GREATER: return "b(>)";
- case GREATEREQUAL: return "b(>=)";
- case EQUAL: return "b(==)";
- case NOTEQUAL: return "b(!=)";
+ case MULT: return "b(" +
Opcodes.MULT.getName() + ")";
+ case DIV: return "b(" +
Opcodes.DIV.getName() + ")";
+ case PLUS: return "b(" +
Opcodes.PLUS.getName() + ")";
+ case MINUS: return "b(" +
Opcodes.MINUS.getName() + ")";
+ case POW: return "b(" +
Opcodes.POW.getName() + ")";
+ case MODULUS: return "b(" +
Opcodes.MODULUS.getName() + ")";
+ case INTDIV: return "b(" +
Opcodes.INTDIV.getName() + ")";
+ case LESS: return "b(" +
Opcodes.LESS.getName() + ")";
+ case LESSEQUAL: return "b(" +
Opcodes.LESSEQUAL.getName() + ")";
+ case GREATER: return "b(" +
Opcodes.GREATER.getName() + ")";
+ case GREATEREQUAL: return "b(" +
Opcodes.GREATEREQUAL.getName() + ")";
+ case EQUAL: return "b(" +
Opcodes.EQUAL.getName() + ")";
+ case NOTEQUAL: return "b(" +
Opcodes.NOTEQUAL.getName() + ")";
case OR: return "b(|)";
case AND: return "b(&)";
- case XOR: return "b(xor)";
- case BITWAND: return "b(bitwAnd)";
+ case XOR: return "b(" +
Opcodes.XOR.getName() + ")";
Review Comment:
XOR can use the default branch
##########
src/main/java/org/apache/sysds/common/Opcodes.java:
##########
@@ -0,0 +1,328 @@
+package org.apache.sysds.common;
Review Comment:
Please add the Apache license header.
##########
src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java:
##########
@@ -79,288 +69,11 @@
public class CPInstructionParser extends InstructionParser {
protected static final Log LOG =
LogFactory.getLog(CPInstructionParser.class.getName());
- public static final HashMap<String, CPType> String2CPInstructionType;
- static {
- String2CPInstructionType = new HashMap<>();
- String2CPInstructionType.put( "ba+*" ,
CPType.AggregateBinary);
- String2CPInstructionType.put( "tak+*" ,
CPType.AggregateTernary);
- String2CPInstructionType.put( "tack+*" ,
CPType.AggregateTernary);
-
- String2CPInstructionType.put( "uak+" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uark+" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uack+" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uasqk+" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uarsqk+" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uacsqk+" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uamean" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uarmean" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uacmean" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uavar" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uarvar" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uacvar" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uamax" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uarmax" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uarimax" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uacmax" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uamin" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uarmin" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uarimin" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uacmin" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "ua+" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uar+" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uac+" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "ua*" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uar*" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uac*" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uatrace" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uaktrace",
CPType.AggregateUnary);
- String2CPInstructionType.put( "nrow" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "ncol" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "length" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "exists" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "lineage" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uacd" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uacdr" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uacdc" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uacdap" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uacdapr" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uacdapc" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "unique" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uniquer" ,
CPType.AggregateUnary);
- String2CPInstructionType.put( "uniquec" ,
CPType.AggregateUnary);
-
- String2CPInstructionType.put( "uaggouterchain",
CPType.UaggOuterChain);
-
- // Arithmetic Instruction Opcodes
- String2CPInstructionType.put( "+" , CPType.Binary);
- String2CPInstructionType.put( "-" , CPType.Binary);
- String2CPInstructionType.put( "*" , CPType.Binary);
- String2CPInstructionType.put( "/" , CPType.Binary);
- String2CPInstructionType.put( "%%" , CPType.Binary);
- String2CPInstructionType.put( "%/%" , CPType.Binary);
- String2CPInstructionType.put( "^" , CPType.Binary);
- String2CPInstructionType.put( "1-*" , CPType.Binary);
//special * case
- String2CPInstructionType.put( "^2" , CPType.Binary);
//special ^ case
- String2CPInstructionType.put( "*2" , CPType.Binary);
//special * case
- String2CPInstructionType.put( "-nz" , CPType.Binary);
//special - case
-
- // Boolean Instruction Opcodes
- String2CPInstructionType.put( "&&" , CPType.Binary);
- String2CPInstructionType.put( "||" , CPType.Binary);
- String2CPInstructionType.put( "xor" , CPType.Binary);
- String2CPInstructionType.put( "bitwAnd", CPType.Binary);
- String2CPInstructionType.put( "bitwOr", CPType.Binary);
- String2CPInstructionType.put( "bitwXor", CPType.Binary);
- String2CPInstructionType.put( "bitwShiftL", CPType.Binary);
- String2CPInstructionType.put( "bitwShiftR", CPType.Binary);
- String2CPInstructionType.put( "!" , CPType.Unary);
-
- // Relational Instruction Opcodes
- String2CPInstructionType.put( "==" , CPType.Binary);
- String2CPInstructionType.put( "!=" , CPType.Binary);
- String2CPInstructionType.put( "<" , CPType.Binary);
- String2CPInstructionType.put( ">" , CPType.Binary);
- String2CPInstructionType.put( "<=" , CPType.Binary);
- String2CPInstructionType.put( ">=" , CPType.Binary);
-
- // Builtin Instruction Opcodes
- String2CPInstructionType.put( "log" , CPType.Builtin);
- String2CPInstructionType.put( "log_nz" , CPType.Builtin);
-
- String2CPInstructionType.put( "solve" , CPType.Binary);
- String2CPInstructionType.put( "max" , CPType.Binary);
- String2CPInstructionType.put( "min" , CPType.Binary);
- String2CPInstructionType.put( "dropInvalidType" ,
CPType.Binary);
- String2CPInstructionType.put( "dropInvalidLength" ,
CPType.Binary);
- String2CPInstructionType.put( "freplicate" , CPType.Binary);
- String2CPInstructionType.put( "valueSwap" , CPType.Binary);
- String2CPInstructionType.put( "applySchema" , CPType.Binary);
- String2CPInstructionType.put( "_map" , CPType.Ternary); //
_map represents the operation map
-
- String2CPInstructionType.put( "nmax", CPType.BuiltinNary);
- String2CPInstructionType.put( "nmin", CPType.BuiltinNary);
- String2CPInstructionType.put( "n+" , CPType.BuiltinNary);
- String2CPInstructionType.put( "n*" , CPType.BuiltinNary);
-
- String2CPInstructionType.put( "exp" , CPType.Unary);
- String2CPInstructionType.put( "abs" , CPType.Unary);
- String2CPInstructionType.put( "sin" , CPType.Unary);
- String2CPInstructionType.put( "cos" , CPType.Unary);
- String2CPInstructionType.put( "tan" , CPType.Unary);
- String2CPInstructionType.put( "sinh" , CPType.Unary);
- String2CPInstructionType.put( "cosh" , CPType.Unary);
- String2CPInstructionType.put( "tanh" , CPType.Unary);
- String2CPInstructionType.put( "asin" , CPType.Unary);
- String2CPInstructionType.put( "acos" , CPType.Unary);
- String2CPInstructionType.put( "atan" , CPType.Unary);
- String2CPInstructionType.put( "sign" , CPType.Unary);
- String2CPInstructionType.put( "sqrt" , CPType.Unary);
- String2CPInstructionType.put( "plogp" , CPType.Unary);
- String2CPInstructionType.put( "print" , CPType.Unary);
- String2CPInstructionType.put( "assert" , CPType.Unary);
- String2CPInstructionType.put( "round" , CPType.Unary);
- String2CPInstructionType.put( "ceil" , CPType.Unary);
- String2CPInstructionType.put( "floor" , CPType.Unary);
- String2CPInstructionType.put( "ucumk+", CPType.Unary);
- String2CPInstructionType.put( "ucum*" , CPType.Unary);
- String2CPInstructionType.put( "ucumk+*" , CPType.Unary);
- String2CPInstructionType.put( "ucummin", CPType.Unary);
- String2CPInstructionType.put( "ucummax", CPType.Unary);
- String2CPInstructionType.put( "stop" , CPType.Unary);
- String2CPInstructionType.put( "inverse", CPType.Unary);
- String2CPInstructionType.put( "cholesky",CPType.Unary);
- String2CPInstructionType.put( "sprop", CPType.Unary);
- String2CPInstructionType.put( "sigmoid", CPType.Unary);
- String2CPInstructionType.put( "typeOf", CPType.Unary);
- String2CPInstructionType.put( "detectSchema", CPType.Unary);
- String2CPInstructionType.put( "colnames", CPType.Unary);
- String2CPInstructionType.put( "isna", CPType.Unary);
- String2CPInstructionType.put( "isnan", CPType.Unary);
- String2CPInstructionType.put( "isinf", CPType.Unary);
- String2CPInstructionType.put( "printf", CPType.BuiltinNary);
- String2CPInstructionType.put( "cbind", CPType.BuiltinNary);
- String2CPInstructionType.put( "rbind", CPType.BuiltinNary);
- String2CPInstructionType.put( "eval", CPType.BuiltinNary);
- String2CPInstructionType.put( "list", CPType.BuiltinNary);
-
- // Parameterized Builtin Functions
- String2CPInstructionType.put( "autoDiff" ,
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "contains",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put("paramserv",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "nvlist",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "cdf",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "invcdf",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "groupedagg",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "rmempty" ,
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "replace",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "lowertri",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "uppertri",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "rexpand",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "toString",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "tokenize",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "transformapply",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put(
"transformdecode",CPType.ParameterizedBuiltin);
- String2CPInstructionType.put(
"transformcolmap",CPType.ParameterizedBuiltin);
- String2CPInstructionType.put( "transformmeta",
CPType.ParameterizedBuiltin);
- String2CPInstructionType.put(
"transformencode",CPType.MultiReturnParameterizedBuiltin);
-
- // Ternary Instruction Opcodes
- String2CPInstructionType.put( "+*", CPType.Ternary);
- String2CPInstructionType.put( "-*", CPType.Ternary);
- String2CPInstructionType.put( "ifelse", CPType.Ternary);
-
- // Variable Instruction Opcodes
- String2CPInstructionType.put( "assignvar" , CPType.Variable);
- String2CPInstructionType.put( "cpvar" , CPType.Variable);
- String2CPInstructionType.put( "mvvar" , CPType.Variable);
- String2CPInstructionType.put( "rmvar" , CPType.Variable);
- String2CPInstructionType.put( "rmfilevar" , CPType.Variable);
- String2CPInstructionType.put( OpOp1.CAST_AS_SCALAR.toString(),
CPType.Variable);
- String2CPInstructionType.put( OpOp1.CAST_AS_MATRIX.toString(),
CPType.Variable);
- String2CPInstructionType.put( "cast_as_frame", CPType.Variable);
- String2CPInstructionType.put( OpOp1.CAST_AS_FRAME.toString(),
CPType.Variable);
- String2CPInstructionType.put( OpOp1.CAST_AS_LIST.toString(),
CPType.Variable);
- String2CPInstructionType.put( OpOp1.CAST_AS_DOUBLE.toString(),
CPType.Variable);
- String2CPInstructionType.put( OpOp1.CAST_AS_INT.toString(),
CPType.Variable);
- String2CPInstructionType.put( OpOp1.CAST_AS_BOOLEAN.toString(),
CPType.Variable);
- String2CPInstructionType.put( "attachfiletovar" ,
CPType.Variable);
- String2CPInstructionType.put( "read" , CPType.Variable);
- String2CPInstructionType.put( "write" , CPType.Variable);
- String2CPInstructionType.put( "createvar" , CPType.Variable);
-
- // Reorg Instruction Opcodes (repositioning of existing values)
- String2CPInstructionType.put( "r'" , CPType.Reorg);
- String2CPInstructionType.put( "rev" , CPType.Reorg);
- String2CPInstructionType.put( "roll" , CPType.Reorg);
- String2CPInstructionType.put( "rdiag" , CPType.Reorg);
- String2CPInstructionType.put( "rshape" , CPType.Reshape);
- String2CPInstructionType.put( "rsort" , CPType.Reorg);
-
- // Opcodes related to convolutions
- String2CPInstructionType.put( "relu_backward" ,
CPType.Dnn);
- String2CPInstructionType.put( "relu_maxpooling" ,
CPType.Dnn);
- String2CPInstructionType.put( "relu_maxpooling_backward" ,
CPType.Dnn);
- String2CPInstructionType.put( "maxpooling" , CPType.Dnn);
- String2CPInstructionType.put( "maxpooling_backward" ,
CPType.Dnn);
- String2CPInstructionType.put( "avgpooling" , CPType.Dnn);
- String2CPInstructionType.put( "avgpooling_backward" ,
CPType.Dnn);
- String2CPInstructionType.put( "conv2d" , CPType.Dnn);
- String2CPInstructionType.put( "conv2d_bias_add" ,
CPType.Dnn);
- String2CPInstructionType.put( "conv2d_backward_filter" ,
CPType.Dnn);
- String2CPInstructionType.put( "conv2d_backward_data" ,
CPType.Dnn);
- String2CPInstructionType.put( "bias_add" , CPType.Dnn);
- String2CPInstructionType.put( "bias_multiply" ,
CPType.Dnn);
- String2CPInstructionType.put( "batch_norm2d",
CPType.Dnn);
- String2CPInstructionType.put( "batch_norm2d_backward",
CPType.Dnn);
- String2CPInstructionType.put( "lstm" , CPType.Dnn);
- String2CPInstructionType.put( "lstm_backward" ,
CPType.Dnn);
-
- // Quaternary instruction opcodes
- String2CPInstructionType.put( "wsloss" , CPType.Quaternary);
- String2CPInstructionType.put( "wsigmoid", CPType.Quaternary);
- String2CPInstructionType.put( "wdivmm", CPType.Quaternary);
- String2CPInstructionType.put( "wcemm", CPType.Quaternary);
- String2CPInstructionType.put( "wumm", CPType.Quaternary);
-
- // User-defined function Opcodes
- String2CPInstructionType.put(FunctionOp.OPCODE, CPType.FCall);
-
- String2CPInstructionType.put(Append.OPCODE, CPType.Append);
- String2CPInstructionType.put( "remove", CPType.Append);
-
- // data generation opcodes
- String2CPInstructionType.put( DataGen.RAND_OPCODE ,
CPType.Rand);
- String2CPInstructionType.put( DataGen.SEQ_OPCODE ,
CPType.Rand);
- String2CPInstructionType.put( DataGen.SINIT_OPCODE ,
CPType.StringInit);
- String2CPInstructionType.put( DataGen.SAMPLE_OPCODE ,
CPType.Rand);
- String2CPInstructionType.put( DataGen.TIME_OPCODE ,
CPType.Rand);
- String2CPInstructionType.put( DataGen.FRAME_OPCODE ,
CPType.Rand);
-
- String2CPInstructionType.put( "ctable", CPType.Ctable);
- String2CPInstructionType.put( "ctableexpand", CPType.Ctable);
-
- //central moment, covariance, quantiles (sort/pick)
- String2CPInstructionType.put( "cm", CPType.CentralMoment);
- String2CPInstructionType.put( "cov", CPType.Covariance);
- String2CPInstructionType.put( "qsort", CPType.QSort);
- String2CPInstructionType.put( "qpick", CPType.QPick);
-
-
- String2CPInstructionType.put( RightIndex.OPCODE,
CPType.MatrixIndexing);
- String2CPInstructionType.put( LeftIndex.OPCODE,
CPType.MatrixIndexing);
-
- String2CPInstructionType.put( "tsmm", CPType.MMTSJ);
- String2CPInstructionType.put( "pmm", CPType.PMMJ);
- String2CPInstructionType.put( "mmchain", CPType.MMChain);
-
- String2CPInstructionType.put( "qr",
CPType.MultiReturnBuiltin);
- String2CPInstructionType.put( "lu",
CPType.MultiReturnBuiltin);
- String2CPInstructionType.put( "eigen",
CPType.MultiReturnBuiltin);
- String2CPInstructionType.put( "fft",
CPType.MultiReturnBuiltin);
- String2CPInstructionType.put( "ifft",
CPType.MultiReturnComplexMatrixBuiltin);
- String2CPInstructionType.put( "fft_linearized",
CPType.MultiReturnBuiltin);
- String2CPInstructionType.put( "ifft_linearized",
CPType.MultiReturnComplexMatrixBuiltin);
- String2CPInstructionType.put( "stft",
CPType.MultiReturnComplexMatrixBuiltin);
- String2CPInstructionType.put( "svd",
CPType.MultiReturnBuiltin);
- String2CPInstructionType.put( "rcm",
CPType.MultiReturnComplexMatrixBuiltin);
-
- String2CPInstructionType.put( "partition", CPType.Partition);
- String2CPInstructionType.put( Compression.OPCODE,
CPType.Compression);
- String2CPInstructionType.put( DeCompression.OPCODE,
CPType.DeCompression);
- String2CPInstructionType.put( "spoof", CPType.SpoofFused);
- String2CPInstructionType.put( "prefetch", CPType.Prefetch);
- String2CPInstructionType.put( "_evict",
CPType.EvictLineageCache);
- String2CPInstructionType.put( "broadcast", CPType.Broadcast);
- String2CPInstructionType.put( "trigremote", CPType.TrigRemote);
- String2CPInstructionType.put( Local.OPCODE, CPType.Local);
-
- String2CPInstructionType.put( "sql", CPType.Sql);
- }
-
public static CPInstruction parseSingleInstruction (String str ) {
if ( str == null || str.isEmpty() )
return null;
- CPType cptype = InstructionUtils.getCPType(str);
+ CPType cptype = InstructionUtils.getCPType(str);
+ //CPType cptype = Opcodes.getCPTypeByOpcode(str);
Review Comment:
Avoid such commented-out code.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
users@infra.apache.org