This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit f492ad3a82f2b44eda0e6f5bb9e62d797cacd503 Author: baunsgaard <[email protected]> AuthorDate: Thu Nov 4 10:35:46 2021 +0100 [SYSTEMDS-3197] Federated Put variable overwrite This commit fixes two issues in the federated worker handler: 1. When executing multiple jobs with a single federated worker without sending a clear command, PUT_VAR would not overwrite the previous value and instead produced an error response that was not propagated to the caller. 2. If any command in the sequence of commands crashed, the error would not be returned to the caller if the last command succeeded. This was common, since the last command usually was an EXEC_INST clear variable. The fix is to return to the caller at the first error encountered, and to allow PUT_VAR to overwrite variable IDs; since we do not officially support multiple tenants, this is a safe assumption. Also contained in this commit are some minor adjustments to the error handling, to make the code cleaner and to reduce logging to the terminal. Closes #1435 --- .../federated/FederatedWorkerHandler.java | 274 +++++++++++---------- 1 file changed, 145 insertions(+), 129 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java index a35b736..0e98fb6 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java @@ -64,15 +64,24 @@ import org.apache.sysds.runtime.privacy.DMLPrivacyException; import org.apache.sysds.runtime.privacy.PrivacyMonitor; import org.apache.sysds.utils.Statistics; +/** + * Note: federated worker handler created for every command; and concurrent parfor threads at coordinator need separate + * execution contexts at the federated sites too + */ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { - 
protected static Logger log = Logger.getLogger(FederatedWorkerHandler.class); + private static final Logger LOG = Logger.getLogger(FederatedWorkerHandler.class); private final ExecutionContextMap _ecm; + /** + * Create a Federated Worker Handler. + * + * Note: federated worker handler created for every command; and concurrent parfor threads at coordinator need + * separate execution contexts at the federated sites too + * + * @param ecm A execution context, used to map variables and execution. + */ public FederatedWorkerHandler(ExecutionContextMap ecm) { - // Note: federated worker handler created for every command; - // and concurrent parfor threads at coordinator need separate - // execution contexts at the federated sites too _ecm = ecm; } @@ -81,90 +90,111 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { ctx.writeAndFlush(createResponse(msg)).addListener(new CloseListener()); } - public FederatedResponse createResponse(Object msg) { - if(log.isDebugEnabled()) { - log.debug("Received: " + msg.getClass().getSimpleName()); - } + protected FederatedResponse createResponse(Object msg) { if(!(msg instanceof FederatedRequest[])) - throw new DMLRuntimeException( - "FederatedWorkerHandler: Received object no instance of 'FederatedRequest[]'."); - FederatedRequest[] requests = (FederatedRequest[]) msg; - FederatedResponse response = null; // last response + return new FederatedResponse(ResponseType.ERROR, + new FederatedWorkerHandlerException("Received object of wrong instance 'FederatedRequest[]'.")); + final FederatedRequest[] requests = (FederatedRequest[]) msg; + try { + return createResponse(requests); + } + catch(DMLPrivacyException | FederatedWorkerHandlerException ex) { + // Here we control the error message, therefore it is allowed to send the stack trace with the response + return new FederatedResponse(ResponseType.ERROR, ex); + } + catch(Exception ex) { + // In all other cases it is not safe to send the exception message to the 
caller + final String error = "Exception thrown while processing requests:\n" + Arrays.toString(requests); + LOG.error(error, ex); + return new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException(error)); + } + } + private FederatedResponse createResponse(FederatedRequest[] requests) + throws DMLPrivacyException, FederatedWorkerHandlerException, Exception { + FederatedResponse response = null; // last response + boolean containsCLEAR = false; for(int i = 0; i < requests.length; i++) { - FederatedRequest request = requests[i]; - if(log.isDebugEnabled()) { - log.debug("Executing command " + (i + 1) + "/" + requests.length + ": " + request.getType().name()); - if(log.isTraceEnabled()) { - log.trace("full command: " + request.toString()); - } - } + final FederatedRequest request = requests[i]; + final RequestType t = request.getType(); + logRequests(request, i, requests.length); + PrivacyMonitor.setCheckPrivacy(request.checkPrivacy()); PrivacyMonitor.clearCheckedConstraints(); // execute command and handle privacy constraints - FederatedResponse tmp = executeCommand(request); + final FederatedResponse tmp = executeCommand(request); conditionalAddCheckedConstraints(request, tmp); - // select the response for the entire batch of requests + // select the response if(!tmp.isSuccessful()) { - log.error("Command " + request.getType() + " failed: " + tmp.getErrorMessage() + "full command: \n" - + request.toString()); - response = (response == null || response.isSuccessful()) ? 
tmp : response; // return first error + LOG.error("Command " + t + " resulted in error:\n" + tmp.getErrorMessage()); + return tmp; // Return first error without executing anything further } - else if(request.getType() == RequestType.GET_VAR) { - if(response != null && response.isSuccessful()) - log.error("Multiple GET_VAR are not supported in single batch of requests."); - response = tmp; // return last get result + else if(t == RequestType.GET_VAR) { + // If any of the requests was a GET_VAR then set it as output. + if(response != null) { + String message = "Multiple GET_VAR are not supported in single batch of requests."; + LOG.error(message); + throw new FederatedWorkerHandlerException(message); + } + response = tmp; } else if(response == null && i == requests.length - 1) { response = tmp; // return last } - if(DMLScript.STATISTICS && request.getType() == RequestType.CLEAR && Statistics.allowWorkerStatistics) { - System.out.println("Federated Worker " + Statistics.display()); - Statistics.reset(); - } + if(t == RequestType.CLEAR) + containsCLEAR = true; } + + if(containsCLEAR) + printStatistics(); + return response; } + private static void printStatistics() { + if(DMLScript.STATISTICS && Statistics.allowWorkerStatistics) { + System.out.println("Federated Worker " + Statistics.display()); + Statistics.reset(); + } + } + + private static void logRequests(FederatedRequest request, int nrRequest, int totalRequests) { + if(LOG.isDebugEnabled()) { + LOG.debug("Executing command " + (nrRequest + 1) + "/" + totalRequests + ": " + request.getType().name()); + if(LOG.isTraceEnabled()) + LOG.trace("full command: " + request.toString()); + } + } + private static void conditionalAddCheckedConstraints(FederatedRequest request, FederatedResponse response) { if(request.checkPrivacy()) response.setCheckedConstraints(PrivacyMonitor.getCheckedConstraints()); } - private FederatedResponse executeCommand(FederatedRequest request) { - RequestType method = request.getType(); - try { - 
switch(method) { - case READ_VAR: - return readData(request); // matrix/frame - case PUT_VAR: - return putVariable(request); - case GET_VAR: - return getVariable(request); - case EXEC_INST: - return execInstruction(request); - case EXEC_UDF: - return execUDF(request); - case CLEAR: - return execClear(); - case NOOP: - return execNoop(); - default: - String message = String.format("Method %s is not supported.", method); - return new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException(message)); - } - } - catch(DMLPrivacyException | FederatedWorkerHandlerException ex) { - return new FederatedResponse(ResponseType.ERROR, ex); - } - catch (Exception ex) { - String msg = "Exception of type " + ex.getClass() + " thrown when processing request"; - log.error(msg, ex); - return new FederatedResponse(ResponseType.ERROR, - new FederatedWorkerHandlerException(msg)); + private FederatedResponse executeCommand(FederatedRequest request) + throws DMLPrivacyException, FederatedWorkerHandlerException, Exception { + final RequestType method = request.getType(); + switch(method) { + case READ_VAR: + return readData(request); // matrix/frame + case PUT_VAR: + return putVariable(request); + case GET_VAR: + return getVariable(request); + case EXEC_INST: + return execInstruction(request); + case EXEC_UDF: + return execUDF(request); + case CLEAR: + return execClear(); + case NOOP: + return execNoop(); + default: + String message = String.format("Method %s is not supported.", method); + throw new FederatedWorkerHandlerException(message); } } @@ -187,9 +217,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { cd = new FrameObject(filename); break; default: - // should NEVER happen (if we keep request codes in sync with actual behavior) - return new FederatedResponse(ResponseType.ERROR, - new FederatedWorkerHandlerException("Could not recognize datatype")); + throw new FederatedWorkerHandlerException("Could not recognize datatype"); } 
FileFormat fmt = null; @@ -197,16 +225,15 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { String delim = null; FileSystem fs = null; MetaDataAll mtd; - + try { - String mtdname = DataExpression.getMTDFileName(filename); - Path path = new Path(mtdname); - fs = IOUtilFunctions.getFileSystem(mtdname); + final String mtdName = DataExpression.getMTDFileName(filename); + Path path = new Path(mtdName); + fs = IOUtilFunctions.getFileSystem(mtdName); try(BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) { mtd = new MetaDataAll(br); if(!mtd.mtdExists()) - return new FederatedResponse(ResponseType.ERROR, - new FederatedWorkerHandlerException("Could not parse metadata file")); + throw new FederatedWorkerHandlerException("Could not parse metadata file"); mc.setRows(mtd.getDim1()); mc.setCols(mtd.getDim2()); mc.setNonZeros(mtd.getNnz()); @@ -216,10 +243,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { delim = mtd.getDelim(); } } - catch (DMLPrivacyException | FederatedWorkerHandlerException ex){ + catch(DMLPrivacyException | FederatedWorkerHandlerException ex) { throw ex; } - catch (Exception ex) { + catch(Exception ex) { throw new DMLRuntimeException(ex); } finally { @@ -229,12 +256,11 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { // put meta data object in symbol table, read on first operation cd.setMetaData(new MetaDataFormat(mc, fmt)); if(fmt == FileFormat.CSV) - cd.setFileFormatProperties(new FileFormatPropertiesCSV(header, delim, - DataExpression.DEFAULT_DELIM_SPARSE)); + cd.setFileFormatProperties(new FileFormatPropertiesCSV(header, delim, DataExpression.DEFAULT_DELIM_SPARSE)); cd.enableCleanup(false); // guard against deletion _ecm.get(tid).setVariable(String.valueOf(id), cd); - if (DMLScript.LINEAGE) + if(DMLScript.LINEAGE) // create a literal type lineage item with the file name _ecm.get(tid).getLineage().set(String.valueOf(id), new 
LineageItem(filename)); @@ -250,11 +276,14 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { private FederatedResponse putVariable(FederatedRequest request) { checkNumParams(request.getNumParams(), 1, 2); - String varname = String.valueOf(request.getID()); + final String varName = String.valueOf(request.getID()); ExecutionContext ec = _ecm.get(request.getTID()); - if(ec.containsVariable(varname)) { - return new FederatedResponse(ResponseType.ERROR, "Variable " + request.getID() + " already existing."); + if(ec.containsVariable(varName)) { + Data tgtData = ec.removeVariable(varName); + if(tgtData != null) + ec.cleanupDataObject(tgtData); + LOG.warn("Variable" + request.getID() + " already existing, fallback to overwritten."); } // wrap transferred cache block into cacheable data @@ -266,17 +295,17 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { else if(request.getParam(0) instanceof ListObject) data = (ListObject) request.getParam(0); else if(request.getNumParams() == 2) - data = request.getParam(1) == DataType.MATRIX ? - ExecutionContext.createMatrixObject((MatrixCharacteristics) request.getParam(0)) : - ExecutionContext.createFrameObject((MatrixCharacteristics) request.getParam(0)); + data = request.getParam(1) == DataType.MATRIX ? 
ExecutionContext + .createMatrixObject((MatrixCharacteristics) request.getParam(0)) : ExecutionContext + .createFrameObject((MatrixCharacteristics) request.getParam(0)); else - throw new DMLRuntimeException( - "FederatedWorkerHandler: Unsupported object type, has to be of type CacheBlock or ScalarObject"); + throw new FederatedWorkerHandlerException( + "Unsupported object type, has to be of type CacheBlock or ScalarObject"); // set variable and construct empty response - ec.setVariable(varname, data); - if (DMLScript.LINEAGE) - ec.getLineage().set(varname, new LineageItem(String.valueOf(request.getChecksum(0)))); + ec.setVariable(varName, data); + if(DMLScript.LINEAGE) + ec.getLineage().set(varName, new LineageItem(String.valueOf(request.getChecksum(0)))); return new FederatedResponse(ResponseType.SUCCESS_EMPTY); } @@ -284,10 +313,10 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { private FederatedResponse getVariable(FederatedRequest request) { checkNumParams(request.getNumParams(), 0); ExecutionContext ec = _ecm.get(request.getTID()); - if(!ec.containsVariable(String.valueOf(request.getID()))) { - return new FederatedResponse(ResponseType.ERROR, + if(!ec.containsVariable(String.valueOf(request.getID()))) + throw new FederatedWorkerHandlerException( "Variable " + request.getID() + " does not exist at federated worker."); - } + // get variable and construct response Data dataObject = ec.getVariable(String.valueOf(request.getID())); dataObject = PrivacyMonitor.handlePrivacy(dataObject); @@ -295,44 +324,32 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { case TENSOR: case MATRIX: case FRAME: - return new FederatedResponse(ResponseType.SUCCESS, - ((CacheableData<?>) dataObject).acquireReadAndRelease()); + return new FederatedResponse(ResponseType.SUCCESS, ((CacheableData<?>) dataObject).acquireReadAndRelease()); case LIST: return new FederatedResponse(ResponseType.SUCCESS, ((ListObject) dataObject).getData()); 
case SCALAR: return new FederatedResponse(ResponseType.SUCCESS, dataObject); default: - return new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException( - "Unsupported return datatype " + dataObject.getDataType().name())); + throw new FederatedWorkerHandlerException("Unsupported return datatype " + dataObject.getDataType().name()); } } - private FederatedResponse execInstruction(FederatedRequest request) { + private FederatedResponse execInstruction(FederatedRequest request) throws Exception { ExecutionContext ec = _ecm.get(request.getTID()); BasicProgramBlock pb = new BasicProgramBlock(null); pb.getInstructions().clear(); Instruction receivedInstruction = InstructionParser.parseSingleInstruction((String) request.getParam(0)); pb.getInstructions().add(receivedInstruction); - if (DMLScript.LINEAGE) + if(DMLScript.LINEAGE) // Compiler assisted optimizations are not applicable for Fed workers. - // e.g. isMarkedForCaching fails as output operands are saved in the - // symbol table only after the instruction execution finishes. - // NOTE: In shared JVM, this will disable compiler assistance even for the coordinator + // e.g. isMarkedForCaching fails as output operands are saved in the + // symbol table only after the instruction execution finishes. 
+ // NOTE: In shared JVM, this will disable compiler assistance even for the coordinator LineageCacheConfig.setCompAssRW(false); - try { - pb.execute(ec); // execute single instruction - } - catch(DMLPrivacyException | FederatedWorkerHandlerException ex){ - throw ex; - } - catch(Exception ex) { - String msg = "Exception of type " + ex.getClass() + " thrown when processing EXEC_INST request"; - log.error(msg, ex); - return new FederatedResponse(ResponseType.ERROR, - new FederatedWorkerHandlerException(msg)); - } + pb.execute(ec); // execute single instruction + return new FederatedResponse(ResponseType.SUCCESS_EMPTY); } @@ -344,51 +361,50 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { FederatedUDF udf = (FederatedUDF) request.getParam(0); Data[] inputs = Arrays.stream(udf.getInputIDs()).mapToObj(id -> ec.getVariable(String.valueOf(id))) .map(PrivacyMonitor::handlePrivacy).toArray(Data[]::new); - + // trace lineage - if (DMLScript.LINEAGE) + if(DMLScript.LINEAGE) LineageItemUtils.traceFedUDF(ec, udf); - + // reuse or execute user-defined function try { // reuse UDF outputs if available in lineage cache FederatedResponse reuse = LineageCache.reuse(udf, ec); - if (reuse.isSuccessful()) + if(reuse.isSuccessful()) return reuse; // else execute the UDF long t0 = !ReuseCacheType.isNone() ? System.nanoTime() : 0; FederatedResponse res = udf.execute(ec, inputs); long t1 = !ReuseCacheType.isNone() ? System.nanoTime() : 0; - //cacheUDFOutputs(udf, inputs, t1-t0, ec); - LineageCache.putValue(udf, ec, t1-t0); + // cacheUDFOutputs(udf, inputs, t1-t0, ec); + LineageCache.putValue(udf, ec, t1 - t0); return res; } - catch(DMLPrivacyException | FederatedWorkerHandlerException ex){ + catch(DMLPrivacyException | FederatedWorkerHandlerException ex) { throw ex; } catch(Exception ex) { + // Note it is unsafe to throw the ex trace along with the exception here. 
String msg = "Exception of type " + ex.getClass() + " thrown when processing EXEC_UDF request"; - log.error(msg, ex); - return new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException(msg)); + throw new FederatedWorkerHandlerException(msg); } } - + private FederatedResponse execClear() { try { _ecm.clear(); } - catch(DMLPrivacyException | FederatedWorkerHandlerException ex){ + catch(DMLPrivacyException | FederatedWorkerHandlerException ex) { throw ex; } catch(Exception ex) { String msg = "Exception of type " + ex.getClass() + " thrown when processing CLEAR request"; - log.error(msg, ex); - return new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException(msg)); + throw new FederatedWorkerHandlerException(msg); } return new FederatedResponse(ResponseType.SUCCESS_EMPTY); } - + private static FederatedResponse execNoop() { return new FederatedResponse(ResponseType.SUCCESS_EMPTY); } @@ -396,8 +412,8 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { private static void checkNumParams(int actual, int... expected) { if(Arrays.stream(expected).anyMatch(x -> x == actual)) return; - throw new DMLRuntimeException("FederatedWorkerHandler: Received wrong amount of params:" - + " expected=" + Arrays.toString(expected) + ", actual=" + actual); + throw new DMLRuntimeException("FederatedWorkerHandler: Received wrong amount of params:" + " expected=" + + Arrays.toString(expected) + ", actual=" + actual); } @Override @@ -410,7 +426,7 @@ public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter { @Override public void operationComplete(ChannelFuture channelFuture) throws InterruptedException { if(!channelFuture.isSuccess()) { - log.error("Federated Worker Write failed"); + LOG.error("Federated Worker Write failed"); channelFuture.channel().writeAndFlush(new FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException("Error while sending response."))).channel().close().sync(); }
