This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit ed8b5f5c725526f8414d2dbe9113d52260597352
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Dec 17 13:17:03 2020 +0100

    [SYSTEMDS-2759] Fix error handling federated read, wrong test meta data
    
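    This patch improves the federated read with explicit error handling for
    invalid meta data (e.g., federated data partitions larger than the global
    matrix dimensions) in order to detect meta data inconsistencies. The
    federated workers now return the data characteristics of each read
    partition, which the coordinator validates against the global dimensions.
    It also fixes wrong meta data in the FederatedLmCG2 test script, whose
    federated ranges for y declared zero columns instead of one.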
---
 .github/workflows/functionsTests.yml                 |  3 ++-
 .../federated/FederatedWorkerHandler.java            |  4 ++--
 .../runtime/instructions/fed/InitFEDInstruction.java | 20 ++++++++++++++++----
 .../test/functions/privacy/FederatedLmCGTest.java    |  7 ++++---
 .../scripts/functions/privacy/FederatedLmCG2.dml     |  2 +-
 5 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/.github/workflows/functionsTests.yml 
b/.github/workflows/functionsTests.yml
index c816245..cf64a2f 100644
--- a/.github/workflows/functionsTests.yml
+++ b/.github/workflows/functionsTests.yml
@@ -40,7 +40,8 @@ jobs:
           
"**.functions.aggregate.**,**.functions.append.**,**.functions.binary.frame.**,**.functions.binary.matrix.**,**.functions.binary.scalar.**,**.functions.binary.tensor.**",
           
"**.functions.blocks.**,**.functions.compress.**,**.functions.countDistinct.**,**.functions.data.misc.**,**.functions.data.rand.**,**.functions.data.tensor.**,**.functions.codegenalg.parttwo.**,**.functions.codegen.**,**.functions.caching.**",
           
"**.functions.binary.matrix_full_cellwise.**,**.functions.binary.matrix_full_other.**",
-          "**.functions.federated.**",
+          "**.functions.federated.algorithms.**",
+          
"**.functions.federated.io.**,**.functions.federated.paramserv.**,**.functions.federated.primitives.**,**.functions.federated.transform.**",
           "**.functions.codegenalg.partone.**",
           "**.functions.builtin.**",
           
"**.functions.frame.**,**.functions.indexing.**,**.functions.io.**,**.functions.jmlc.**,**.functions.lineage.**",
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 10e679c..7cbf421 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -237,9 +237,9 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                        frameObject.acquireRead();
                        frameObject.refreshMetaData(); // get block schema
                        frameObject.release();
-                       return new FederatedResponse(ResponseType.SUCCESS, new 
Object[] {id, frameObject.getSchema()});
+                       return new FederatedResponse(ResponseType.SUCCESS, new 
Object[] {id, frameObject.getSchema(), mc});
                }
-               return new FederatedResponse(ResponseType.SUCCESS, id);
+               return new FederatedResponse(ResponseType.SUCCESS, new Object[] 
{id, mc});
        }
 
        private FederatedResponse putVariable(FederatedRequest request) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
index 3dae771..17e2855 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java
@@ -56,6 +56,7 @@ import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.ListObject;
 import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.cp.StringObject;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
 
 public class InitFEDInstruction extends FEDInstruction {
 
@@ -236,9 +237,16 @@ public class InitFEDInstruction extends FEDInstruction {
                try {
                        int timeout = ConfigurationManager.getDMLConfig()
                                
.getIntValue(DMLConfig.DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT);
-                       LOG.debug("Federated Initialization with timeout: " + 
timeout);
-                       for(Pair<FederatedData, Future<FederatedResponse>> 
idResponse : idResponses)
-                               idResponse.getRight().get(timeout, 
TimeUnit.SECONDS); // wait for initialization
+                       if( LOG.isDebugEnabled() )
+                               LOG.debug("Federated Initialization with 
timeout: " + timeout);
+                       for(Pair<FederatedData, Future<FederatedResponse>> 
idResponse : idResponses) {
+                               // wait for initialization and check dimensions
+                               FederatedResponse re = 
idResponse.getRight().get(timeout, TimeUnit.SECONDS);
+                               DataCharacteristics dc = (DataCharacteristics) 
re.getData()[1];
+                               if( dc.getRows() > output.getNumRows() || 
dc.getCols() > output.getNumColumns() )
+                                       throw new DMLRuntimeException("Invalid 
federated meta data: "
+                                               + 
output.getDataCharacteristics()+" vs federated response: "+dc);
+                       }
                }
                catch(TimeoutException e) {
                        throw new DMLRuntimeException("Federated Initialization 
timeout exceeded", e);
@@ -294,6 +302,10 @@ public class InitFEDInstruction extends FEDInstruction {
                                FederatedResponse response = 
idResponse.getRight().getRight().get();
                                int startCol = idResponse.getRight().getLeft();
                                handleFedFrameResponse(schema, fedData, 
response, startCol);
+                               DataCharacteristics dc = (DataCharacteristics) 
response.getData()[2];
+                               if( dc.getRows() > output.getNumRows() || 
dc.getCols() > output.getNumColumns() )
+                                       throw new DMLRuntimeException("Invalid 
federated meta data: "
+                                               + 
output.getDataCharacteristics()+" vs federated response: "+dc);
                        }
                }
                catch(Exception e) {
@@ -315,7 +327,7 @@ public class InitFEDInstruction extends FEDInstruction {
                        // Index 0 is the varID, Index 1 is the schema of the 
frame
                        Object[] data = response.getData();
                        federatedData.setVarID((Long) data[0]);
-                       // copy the
+                       // copy the schema
                        Types.ValueType[] range_schema = (Types.ValueType[]) 
data[1];
                        for(int i = 0; i < range_schema.length; i++) {
                                Types.ValueType vType = range_schema[i];
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java 
b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java
index d019f32..edbe774 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java
@@ -95,7 +95,7 @@ public class FederatedLmCGTest extends AutomatedTestBase
                        
                        if (doubleFederated){
                                programArgs = new String[]{
-                                       "-explain", "-nvargs",
+                                       "-explain", "-stats", "-nvargs",
                                        "X1="+TestUtils.federatedAddress(port1, 
input("X1")),
                                        "X2="+TestUtils.federatedAddress(port2, 
input("X2")),
                                        "y1=" + 
TestUtils.federatedAddress(port1, input("y1")),
@@ -104,7 +104,7 @@ public class FederatedLmCGTest extends AutomatedTestBase
                                        "r=" + rows, "c=" + cols};
                        } else {
                                programArgs = new String[]{
-                                       "-explain", "-nvargs",
+                                       "-explain", "-stats", "-nvargs",
                                        "X1="+TestUtils.federatedAddress(port1, 
input("X1")),
                                        "X2="+TestUtils.federatedAddress(port2, 
input("X2")),
                                        "y=" + input("y"),
@@ -132,7 +132,8 @@ public class FederatedLmCGTest extends AutomatedTestBase
                        runTest(true, false, null, -1);
 
                        //check expected operations
-                       
Assert.assertTrue(heavyHittersContainsString("fed_mmchain"));
+                       if( instType == ExecType.CP )
+                               
Assert.assertTrue(heavyHittersContainsString("fed_mmchain"));
                        
                        TestUtils.shutdownThreads(t1, t2);
                }
diff --git a/src/test/scripts/functions/privacy/FederatedLmCG2.dml 
b/src/test/scripts/functions/privacy/FederatedLmCG2.dml
index 707a370..c8ff3b7 100644
--- a/src/test/scripts/functions/privacy/FederatedLmCG2.dml
+++ b/src/test/scripts/functions/privacy/FederatedLmCG2.dml
@@ -22,7 +22,7 @@
 X = federated(addresses=list($X1, $X2),
               ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), 
list($r, $c)))
 y = federated(addresses=list($y1, $y2),
-              ranges=list(list(0, 0), list($r / 2, 0), list($r / 2, 0), 
list($r, 0)))
+              ranges=list(list(0, 0), list($r / 2, 1), list($r / 2, 0), 
list($r, 1)))
 C = lmCG(X = X, y = y, reg = 1e-12, maxi = 2, verbose=FALSE)
 write(C, $C)
 

Reply via email to