This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit ed8b5f5c725526f8414d2dbe9113d52260597352 Author: Matthias Boehm <[email protected]> AuthorDate: Thu Dec 17 13:17:03 2020 +0100 [SYSTEMDS-2759] Fix error handling federated read, wrong test meta data This patch improves the federated read by explicit error handling for invalid meta data (e.g., federated data partitions larger than global matrix dimensions) in order to detect meta data inconsistencies. --- .github/workflows/functionsTests.yml | 3 ++- .../federated/FederatedWorkerHandler.java | 4 ++-- .../runtime/instructions/fed/InitFEDInstruction.java | 20 ++++++++++++++++---- .../test/functions/privacy/FederatedLmCGTest.java | 7 ++++--- .../scripts/functions/privacy/FederatedLmCG2.dml | 2 +- 5 files changed, 25 insertions(+), 11 deletions(-) diff --git a/.github/workflows/functionsTests.yml b/.github/workflows/functionsTests.yml index c816245..cf64a2f 100644 --- a/.github/workflows/functionsTests.yml +++ b/.github/workflows/functionsTests.yml @@ -40,7 +40,8 @@ jobs: "**.functions.aggregate.**,**.functions.append.**,**.functions.binary.frame.**,**.functions.binary.matrix.**,**.functions.binary.scalar.**,**.functions.binary.tensor.**", "**.functions.blocks.**,**.functions.compress.**,**.functions.countDistinct.**,**.functions.data.misc.**,**.functions.data.rand.**,**.functions.data.tensor.**,**.functions.codegenalg.parttwo.**,**.functions.codegen.**,**.functions.caching.**", "**.functions.binary.matrix_full_cellwise.**,**.functions.binary.matrix_full_other.**", - "**.functions.federated.**", + "**.functions.federated.algorithms.**", + "**.functions.federated.io.**,**.functions.federated.paramserv.**,**.functions.federated.primitives.**,**.functions.federated.transform.**", "**.functions.codegenalg.partone.**", "**.functions.builtin.**", "**.functions.frame.**,**.functions.indexing.**,**.functions.io.**,**.functions.jmlc.**,**.functions.lineage.**", diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java index 10e679c..7cbf421 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java @@ -237,9 +237,9 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { frameObject.acquireRead(); frameObject.refreshMetaData(); // get block schema frameObject.release(); - return new FederatedResponse(ResponseType.SUCCESS, new Object[] {id, frameObject.getSchema()}); + return new FederatedResponse(ResponseType.SUCCESS, new Object[] {id, frameObject.getSchema(), mc}); } - return new FederatedResponse(ResponseType.SUCCESS, id); + return new FederatedResponse(ResponseType.SUCCESS, new Object[] {id, mc}); } private FederatedResponse putVariable(FederatedRequest request) { diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java index 3dae771..17e2855 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/InitFEDInstruction.java @@ -56,6 +56,7 @@ import org.apache.sysds.runtime.instructions.cp.Data; import org.apache.sysds.runtime.instructions.cp.ListObject; import org.apache.sysds.runtime.instructions.cp.ScalarObject; import org.apache.sysds.runtime.instructions.cp.StringObject; +import org.apache.sysds.runtime.meta.DataCharacteristics; public class InitFEDInstruction extends FEDInstruction { @@ -236,9 +237,16 @@ public class InitFEDInstruction extends FEDInstruction { try { int timeout = ConfigurationManager.getDMLConfig() .getIntValue(DMLConfig.DEFAULT_FEDERATED_INITIALIZATION_TIMEOUT); - LOG.debug("Federated Initialization with timeout: " + timeout); - for(Pair<FederatedData, Future<FederatedResponse>> idResponse : idResponses) - idResponse.getRight().get(timeout, TimeUnit.SECONDS); // wait for initialization + if( LOG.isDebugEnabled() ) + LOG.debug("Federated Initialization with timeout: " + timeout); + for(Pair<FederatedData, Future<FederatedResponse>> idResponse : idResponses) { + // wait for initialization and check dimensions + FederatedResponse re = idResponse.getRight().get(timeout, TimeUnit.SECONDS); + DataCharacteristics dc = (DataCharacteristics) re.getData()[1]; + if( dc.getRows() > output.getNumRows() || dc.getCols() > output.getNumColumns() ) + throw new DMLRuntimeException("Invalid federated meta data: " + + output.getDataCharacteristics()+" vs federated response: "+dc); + } } catch(TimeoutException e) { throw new DMLRuntimeException("Federated Initialization timeout exceeded", e); @@ -294,6 +302,10 @@ public class InitFEDInstruction extends FEDInstruction { FederatedResponse response = idResponse.getRight().getRight().get(); int startCol = idResponse.getRight().getLeft(); handleFedFrameResponse(schema, fedData, response, startCol); + DataCharacteristics dc = (DataCharacteristics) response.getData()[2]; + if( dc.getRows() > output.getNumRows() || dc.getCols() > output.getNumColumns() ) + throw new DMLRuntimeException("Invalid federated meta data: " + + output.getDataCharacteristics()+" vs federated response: "+dc); } } catch(Exception e) { @@ -315,7 +327,7 @@ public class InitFEDInstruction extends FEDInstruction { // Index 0 is the varID, Index 1 is the schema of the frame Object[] data = response.getData(); federatedData.setVarID((Long) data[0]); - // copy the + // copy the schema Types.ValueType[] range_schema = (Types.ValueType[]) data[1]; for(int i = 0; i < range_schema.length; i++) { Types.ValueType vType = range_schema[i]; diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java index d019f32..edbe774 100644 --- a/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java +++ b/src/test/java/org/apache/sysds/test/functions/privacy/FederatedLmCGTest.java @@ -95,7 +95,7 @@ public class FederatedLmCGTest extends AutomatedTestBase if (doubleFederated){ programArgs = new String[]{ - "-explain", "-nvargs", + "-explain", "-stats", "-nvargs", "X1="+TestUtils.federatedAddress(port1, input("X1")), "X2="+TestUtils.federatedAddress(port2, input("X2")), "y1=" + TestUtils.federatedAddress(port1, input("y1")), @@ -104,7 +104,7 @@ public class FederatedLmCGTest extends AutomatedTestBase "r=" + rows, "c=" + cols}; } else { programArgs = new String[]{ - "-explain", "-nvargs", + "-explain", "-stats", "-nvargs", "X1="+TestUtils.federatedAddress(port1, input("X1")), "X2="+TestUtils.federatedAddress(port2, input("X2")), "y=" + input("y"), @@ -132,7 +132,8 @@ public class FederatedLmCGTest extends AutomatedTestBase runTest(true, false, null, -1); //check expected operations - Assert.assertTrue(heavyHittersContainsString("fed_mmchain")); + if( instType == ExecType.CP ) + Assert.assertTrue(heavyHittersContainsString("fed_mmchain")); TestUtils.shutdownThreads(t1, t2); } diff --git a/src/test/scripts/functions/privacy/FederatedLmCG2.dml b/src/test/scripts/functions/privacy/FederatedLmCG2.dml index 707a370..c8ff3b7 100644 --- a/src/test/scripts/functions/privacy/FederatedLmCG2.dml +++ b/src/test/scripts/functions/privacy/FederatedLmCG2.dml @@ -22,7 +22,7 @@ X = federated(addresses=list($X1, $X2), ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c))) y = federated(addresses=list($y1, $y2), - ranges=list(list(0, 0), list($r / 2, 0), list($r / 2, 0), list($r, 0))) + ranges=list(list(0, 0), list($r / 2, 1), list($r / 2, 0), list($r, 1))) C = lmCG(X = X, y = y, reg = 1e-12, maxi = 2, verbose=FALSE) write(C, $C)
