This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new c2492d7  [MINOR] Add missing unary support federated
c2492d7 is described below

commit c2492d7c6ae6b46d69f5b543de428fb30b8311a8
Author: baunsgaard <[email protected]>
AuthorDate: Tue Apr 27 21:03:10 2021 +0200

    [MINOR] Add missing unary support federated
    
    Add missing federated support for the log and sigmoid unary operations.
    This check is a basic copy of the CP unary-operation parsing: any unary
    operation that can be executed locally should also be executable
    federated.
---
 .../runtime/instructions/cp/UnaryCPInstruction.java      |  7 ++++---
 .../instructions/fed/UnaryMatrixFEDInstruction.java      | 16 +++++++++++-----
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
index 17ae660..0c98e84 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
@@ -62,9 +62,10 @@ public abstract class UnaryCPInstruction extends 
ComputationCPInstruction {
                        out.split(parts[2]);
                        func = Builtin.getBuiltinFnObject(opcode);
                        
-                       if( Arrays.asList(new 
String[]{"ucumk+","ucum*","ucumk+*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode)
 )
-                               return new UnaryMatrixCPInstruction(new 
UnaryOperator(func,
-                                       
Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4])), in, out, opcode, 
str);
+                       if( Arrays.asList(new 
String[]{"ucumk+","ucum*","ucumk+*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode)
 ){
+                               UnaryOperator op = new UnaryOperator(func, 
Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4]));
+                               return new UnaryMatrixCPInstruction(op, in, 
out, opcode, str);
+                       }
                        else
                                return new UnaryScalarCPInstruction(null, in, 
out, opcode, str);
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
index 1f5cdd2..24a850f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryMatrixFEDInstruction.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.instructions.fed;
 
+import java.util.Arrays;
 import java.util.concurrent.Future;
 
 import org.apache.sysds.common.Types.DataType;
@@ -40,6 +41,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
 
 public class UnaryMatrixFEDInstruction extends UnaryFEDInstruction {
+
        protected UnaryMatrixFEDInstruction(Operator op, CPOperand in, 
CPOperand out, String opcode, String instr) {
                super(FEDType.Unary, op, in, out, opcode, instr);
        }
@@ -53,14 +55,18 @@ public class UnaryMatrixFEDInstruction extends 
UnaryFEDInstruction {
                CPOperand out = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
 
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
-               String opcode;
-               opcode = parts[0];
-               if( (opcode.equalsIgnoreCase("exp") || 
opcode.startsWith("ucum")) && parts.length == 5) {
+               String opcode = parts[0];
+               
+               if(parts.length == 5 && (opcode.equalsIgnoreCase("exp") || 
opcode.equalsIgnoreCase("log") || opcode.startsWith("ucum"))) {
                        in.split(parts[1]);
                        out.split(parts[2]);
                        ValueFunction func = Builtin.getBuiltinFnObject(opcode);
-                       return new UnaryMatrixFEDInstruction(new 
UnaryOperator(func,
-                               
Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4])), in, out, opcode, 
str);
+                       if( Arrays.asList(new 
String[]{"ucumk+","ucum*","ucumk+*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode)
 ){
+                               UnaryOperator op = new 
UnaryOperator(func,Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4]));
+                               return new UnaryMatrixFEDInstruction(op, in, 
out, opcode, str);
+                       }
+                       else
+                               return new UnaryMatrixFEDInstruction(null, in, 
out, opcode, str);
                }
                opcode = parseUnaryInstruction(str, in, out);
                return new 
UnaryMatrixFEDInstruction(InstructionUtils.parseUnaryOperator(opcode), in, out, 
opcode, str);

Reply via email to