This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new c2492d7 [MINOR] Add missing unary support federated
c2492d7 is described below
commit c2492d7c6ae6b46d69f5b543de428fb30b8311a8
Author: baunsgaard <[email protected]>
AuthorDate: Tue Apr 27 21:03:10 2021 +0200
[MINOR] Add missing unary support federated
Add missing federated support for the log and sigmoid unary operations.
This check is a basic copy of the one used for CP unary operations: if a
unary operation can be executed locally, it should also be possible to
execute it federated.
---
.../runtime/instructions/cp/UnaryCPInstruction.java | 7 ++++---
.../instructions/fed/UnaryMatrixFEDInstruction.java | 16 +++++++++++-----
2 files changed, 15 insertions(+), 8 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
index 17ae660..0c98e84 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
@@ -62,9 +62,10 @@ public abstract class UnaryCPInstruction extends
ComputationCPInstruction {
out.split(parts[2]);
func = Builtin.getBuiltinFnObject(opcode);
- if( Arrays.asList(new
String[]{"ucumk+","ucum*","ucumk+*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode)
)
- return new UnaryMatrixCPInstruction(new
UnaryOperator(func,
-
Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4])), in, out, opcode,
str);
+ if( Arrays.asList(new
String[]{"ucumk+","ucum*","ucumk+*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode)
){
+ UnaryOperator op = new UnaryOperator(func,
Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4]));
+ return new UnaryMatrixCPInstruction(op, in,
out, opcode, str);
+ }
else
return new UnaryScalarCPInstruction(null, in,
out, opcode, str);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
index 1f5cdd2..24a850f 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.instructions.fed;
+import java.util.Arrays;
import java.util.concurrent.Future;
import org.apache.sysds.common.Types.DataType;
@@ -40,6 +41,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
+
protected UnaryMatrixFEDInstruction(Operator op, CPOperand in,
CPOperand out, String opcode, String instr) {
super(FEDType.Unary, op, in, out, opcode, instr);
}
@@ -53,14 +55,18 @@ public class UnaryMatrixFEDInstruction extends
UnaryFEDInstruction {
CPOperand out = new CPOperand("", ValueType.UNKNOWN,
DataType.UNKNOWN);
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
- String opcode;
- opcode = parts[0];
- if( (opcode.equalsIgnoreCase("exp") ||
opcode.startsWith("ucum")) && parts.length == 5) {
+ String opcode = parts[0];
+
+ if(parts.length == 5 && (opcode.equalsIgnoreCase("exp") ||
opcode.equalsIgnoreCase("log") || opcode.startsWith("ucum"))) {
in.split(parts[1]);
out.split(parts[2]);
ValueFunction func = Builtin.getBuiltinFnObject(opcode);
- return new UnaryMatrixFEDInstruction(new
UnaryOperator(func,
-
Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4])), in, out, opcode,
str);
+ if( Arrays.asList(new
String[]{"ucumk+","ucum*","ucumk+*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode)
){
+ UnaryOperator op = new
UnaryOperator(func,Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4]));
+ return new UnaryMatrixFEDInstruction(op, in,
out, opcode, str);
+ }
+ else
+ return new UnaryMatrixFEDInstruction(null, in,
out, opcode, str);
}
opcode = parseUnaryInstruction(str, in, out);
return new
UnaryMatrixFEDInstruction(InstructionUtils.parseUnaryOperator(opcode), in, out,
opcode, str);