This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit f492ad3a82f2b44eda0e6f5bb9e62d797cacd503
Author: baunsgaard <[email protected]>
AuthorDate: Thu Nov 4 10:35:46 2021 +0100

    [SYSTEMDS-3197] Federated Put variable overwrite
    
    This commit fixes two issues in the federated worker handler:
    
    1. When executing multiple jobs with a single federated worker without
    sending a clear command, the PUT_VAR would not overwrite the previous value,
    and would produce an error response that was not propagated to the caller.
    
    2. If any command crashed in the sequence of commands, the error would not
    be returned to the caller if the last command succeeded. This is common
    since the last command usually was an EXEC_INST clear variable.
    
    The fix is to return to the caller at the first error encountered, and to
    allow PUT_VAR to overwrite variable IDs. Since we are not officially
    supporting multiple tenants, this is a safe assumption.
    
    Also contained in this commit are some minor adjustments to the error
    handling, to make the code cleaner and to reduce logging to the terminal.
    
    Closes #1435
---
 .../federated/FederatedWorkerHandler.java          | 274 +++++++++++----------
 1 file changed, 145 insertions(+), 129 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index a35b736..0e98fb6 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -64,15 +64,24 @@ import org.apache.sysds.runtime.privacy.DMLPrivacyException;
 import org.apache.sysds.runtime.privacy.PrivacyMonitor;
 import org.apache.sysds.utils.Statistics;
 
+/**
+ * Note: federated worker handler created for every command; and concurrent 
parfor threads at coordinator need separate
+ * execution contexts at the federated sites too
+ */
 public class FederatedWorkerHandler extends ChannelInboundHandlerAdapter {
-       protected static Logger log = 
Logger.getLogger(FederatedWorkerHandler.class);
+       private static final Logger LOG = 
Logger.getLogger(FederatedWorkerHandler.class);
 
        private final ExecutionContextMap _ecm;
 
+       /**
+        * Create a Federated Worker Handler.
+        * 
+        * Note: federated worker handler created for every command; and 
concurrent parfor threads at coordinator need
+        * separate execution contexts at the federated sites too
+        * 
+        * @param ecm A execution context, used to map variables and execution.
+        */
        public FederatedWorkerHandler(ExecutionContextMap ecm) {
-               // Note: federated worker handler created for every command;
-               // and concurrent parfor threads at coordinator need separate
-               // execution contexts at the federated sites too
                _ecm = ecm;
        }
 
@@ -81,90 +90,111 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                ctx.writeAndFlush(createResponse(msg)).addListener(new 
CloseListener());
        }
 
-       public FederatedResponse createResponse(Object msg) {
-               if(log.isDebugEnabled()) {
-                       log.debug("Received: " + 
msg.getClass().getSimpleName());
-               }
+       protected FederatedResponse createResponse(Object msg) {
                if(!(msg instanceof FederatedRequest[]))
-                       throw new DMLRuntimeException(
-                               "FederatedWorkerHandler: Received object no 
instance of 'FederatedRequest[]'.");
-               FederatedRequest[] requests = (FederatedRequest[]) msg;
-               FederatedResponse response = null; // last response
+                       return new FederatedResponse(ResponseType.ERROR,
+                               new FederatedWorkerHandlerException("Received 
object of wrong instance 'FederatedRequest[]'."));
+               final FederatedRequest[] requests = (FederatedRequest[]) msg;
+               try {
+                       return createResponse(requests);
+               }
+               catch(DMLPrivacyException | FederatedWorkerHandlerException ex) 
{
+                       // Here we control the error message, therefore it is 
allowed to send the stack trace with the response
+                       return new FederatedResponse(ResponseType.ERROR, ex);
+               }
+               catch(Exception ex) {
+                       // In all other cases it is not safe to send the 
exception message to the caller
+                       final String error = "Exception thrown while processing 
requests:\n" + Arrays.toString(requests);
+                       LOG.error(error, ex);
+                       return new FederatedResponse(ResponseType.ERROR, new 
FederatedWorkerHandlerException(error));
+               }
+       }
 
+       private FederatedResponse createResponse(FederatedRequest[] requests)
+               throws DMLPrivacyException, FederatedWorkerHandlerException, 
Exception {
+               FederatedResponse response = null; // last response
+               boolean containsCLEAR = false;
                for(int i = 0; i < requests.length; i++) {
-                       FederatedRequest request = requests[i];
-                       if(log.isDebugEnabled()) {
-                               log.debug("Executing command " + (i + 1) + "/" 
+ requests.length + ": " + request.getType().name());
-                               if(log.isTraceEnabled()) {
-                                       log.trace("full command: " + 
request.toString());
-                               }
-                       }
+                       final FederatedRequest request = requests[i];
+                       final RequestType t = request.getType();
+                       logRequests(request, i, requests.length);
+
                        PrivacyMonitor.setCheckPrivacy(request.checkPrivacy());
                        PrivacyMonitor.clearCheckedConstraints();
 
                        // execute command and handle privacy constraints
-                       FederatedResponse tmp = executeCommand(request);
+                       final FederatedResponse tmp = executeCommand(request);
                        conditionalAddCheckedConstraints(request, tmp);
 
-                       // select the response for the entire batch of requests
+                       // select the response
                        if(!tmp.isSuccessful()) {
-                               log.error("Command " + request.getType() + " 
failed: " + tmp.getErrorMessage() + "full command: \n"
-                                       + request.toString());
-                               response = (response == null || 
response.isSuccessful()) ? tmp : response; // return first error
+                               LOG.error("Command " + t + " resulted in 
error:\n" + tmp.getErrorMessage());
+                               return tmp; // Return first error without 
executing anything further
                        }
-                       else if(request.getType() == RequestType.GET_VAR) {
-                               if(response != null && response.isSuccessful())
-                                       log.error("Multiple GET_VAR are not 
supported in single batch of requests.");
-                               response = tmp; // return last get result
+                       else if(t == RequestType.GET_VAR) {
+                               // If any of the requests was a GET_VAR then 
set it as output.
+                               if(response != null) {
+                                       String message = "Multiple GET_VAR are 
not supported in single batch of requests.";
+                                       LOG.error(message);
+                                       throw new 
FederatedWorkerHandlerException(message);
+                               }
+                               response = tmp;
                        }
                        else if(response == null && i == requests.length - 1) {
                                response = tmp; // return last
                        }
 
-                       if(DMLScript.STATISTICS && request.getType() == 
RequestType.CLEAR && Statistics.allowWorkerStatistics) {
-                               System.out.println("Federated Worker " + 
Statistics.display());
-                               Statistics.reset();
-                       }
+                       if(t == RequestType.CLEAR)
+                               containsCLEAR = true;
                }
+
+               if(containsCLEAR)
+                       printStatistics();
+
                return response;
        }
 
+       private static void printStatistics() {
+               if(DMLScript.STATISTICS && Statistics.allowWorkerStatistics) {
+                       System.out.println("Federated Worker " + 
Statistics.display());
+                       Statistics.reset();
+               }
+       }
+
+       private static void logRequests(FederatedRequest request, int 
nrRequest, int totalRequests) {
+               if(LOG.isDebugEnabled()) {
+                       LOG.debug("Executing command " + (nrRequest + 1) + "/" 
+ totalRequests + ": " + request.getType().name());
+                       if(LOG.isTraceEnabled()) 
+                               LOG.trace("full command: " + 
request.toString());
+               }
+       }
+
        private static void conditionalAddCheckedConstraints(FederatedRequest 
request, FederatedResponse response) {
                if(request.checkPrivacy())
                        
response.setCheckedConstraints(PrivacyMonitor.getCheckedConstraints());
        }
 
-       private FederatedResponse executeCommand(FederatedRequest request) {
-               RequestType method = request.getType();
-               try {
-                       switch(method) {
-                               case READ_VAR:
-                                       return readData(request); // 
matrix/frame
-                               case PUT_VAR:
-                                       return putVariable(request);
-                               case GET_VAR:
-                                       return getVariable(request);
-                               case EXEC_INST:
-                                       return execInstruction(request);
-                               case EXEC_UDF:
-                                       return execUDF(request);
-                               case CLEAR:
-                                       return execClear();
-                               case NOOP:
-                                       return execNoop();
-                               default:
-                                       String message = String.format("Method 
%s is not supported.", method);
-                                       return new 
FederatedResponse(ResponseType.ERROR, new 
FederatedWorkerHandlerException(message));
-                       }
-               }
-               catch(DMLPrivacyException | FederatedWorkerHandlerException ex) 
{
-                       return new FederatedResponse(ResponseType.ERROR, ex);
-               }
-               catch (Exception ex) {
-                       String msg = "Exception of type " + ex.getClass() + " 
thrown when processing request";
-                       log.error(msg, ex);
-                       return new FederatedResponse(ResponseType.ERROR,
-                               new FederatedWorkerHandlerException(msg));
+       private FederatedResponse executeCommand(FederatedRequest request)
+               throws DMLPrivacyException, FederatedWorkerHandlerException, 
Exception {
+               final RequestType method = request.getType();
+               switch(method) {
+                       case READ_VAR:
+                               return readData(request); // matrix/frame
+                       case PUT_VAR:
+                               return putVariable(request);
+                       case GET_VAR:
+                               return getVariable(request);
+                       case EXEC_INST:
+                               return execInstruction(request);
+                       case EXEC_UDF:
+                               return execUDF(request);
+                       case CLEAR:
+                               return execClear();
+                       case NOOP:
+                               return execNoop();
+                       default:
+                               String message = String.format("Method %s is 
not supported.", method);
+                               throw new 
FederatedWorkerHandlerException(message);
                }
        }
 
@@ -187,9 +217,7 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                                cd = new FrameObject(filename);
                                break;
                        default:
-                               // should NEVER happen (if we keep request 
codes in sync with actual behavior)
-                               return new FederatedResponse(ResponseType.ERROR,
-                                       new 
FederatedWorkerHandlerException("Could not recognize datatype"));
+                               throw new 
FederatedWorkerHandlerException("Could not recognize datatype");
                }
 
                FileFormat fmt = null;
@@ -197,16 +225,15 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                String delim = null;
                FileSystem fs = null;
                MetaDataAll mtd;
-               
+
                try {
-                       String mtdname = 
DataExpression.getMTDFileName(filename);
-                       Path path = new Path(mtdname);
-                       fs = IOUtilFunctions.getFileSystem(mtdname);
+                       final String mtdName = 
DataExpression.getMTDFileName(filename);
+                       Path path = new Path(mtdName);
+                       fs = IOUtilFunctions.getFileSystem(mtdName);
                        try(BufferedReader br = new BufferedReader(new 
InputStreamReader(fs.open(path)))) {
                                mtd = new MetaDataAll(br);
                                if(!mtd.mtdExists())
-                                       return new 
FederatedResponse(ResponseType.ERROR,
-                                               new 
FederatedWorkerHandlerException("Could not parse metadata file"));
+                                       throw new 
FederatedWorkerHandlerException("Could not parse metadata file");
                                mc.setRows(mtd.getDim1());
                                mc.setCols(mtd.getDim2());
                                mc.setNonZeros(mtd.getNnz());
@@ -216,10 +243,10 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                                delim = mtd.getDelim();
                        }
                }
-               catch (DMLPrivacyException | FederatedWorkerHandlerException 
ex){
+               catch(DMLPrivacyException | FederatedWorkerHandlerException ex) 
{
                        throw ex;
                }
-               catch (Exception ex) {
+               catch(Exception ex) {
                        throw new DMLRuntimeException(ex);
                }
                finally {
@@ -229,12 +256,11 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                // put meta data object in symbol table, read on first operation
                cd.setMetaData(new MetaDataFormat(mc, fmt));
                if(fmt == FileFormat.CSV)
-                       cd.setFileFormatProperties(new 
FileFormatPropertiesCSV(header, delim,
-                               DataExpression.DEFAULT_DELIM_SPARSE));
+                       cd.setFileFormatProperties(new 
FileFormatPropertiesCSV(header, delim, DataExpression.DEFAULT_DELIM_SPARSE));
                cd.enableCleanup(false); // guard against deletion
                _ecm.get(tid).setVariable(String.valueOf(id), cd);
 
-               if (DMLScript.LINEAGE)
+               if(DMLScript.LINEAGE)
                        // create a literal type lineage item with the file name
                        _ecm.get(tid).getLineage().set(String.valueOf(id), new 
LineageItem(filename));
 
@@ -250,11 +276,14 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
 
        private FederatedResponse putVariable(FederatedRequest request) {
                checkNumParams(request.getNumParams(), 1, 2);
-               String varname = String.valueOf(request.getID());
+               final String varName = String.valueOf(request.getID());
                ExecutionContext ec = _ecm.get(request.getTID());
 
-               if(ec.containsVariable(varname)) {
-                       return new FederatedResponse(ResponseType.ERROR, 
"Variable " + request.getID() + " already existing.");
+               if(ec.containsVariable(varName)) {
+                       Data tgtData = ec.removeVariable(varName);
+                       if(tgtData != null)
+                               ec.cleanupDataObject(tgtData);
+                       LOG.warn("Variable" + request.getID() + " already 
existing, fallback to overwritten.");
                }
 
                // wrap transferred cache block into cacheable data
@@ -266,17 +295,17 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                else if(request.getParam(0) instanceof ListObject)
                        data = (ListObject) request.getParam(0);
                else if(request.getNumParams() == 2)
-                       data = request.getParam(1) == DataType.MATRIX ?
-                               
ExecutionContext.createMatrixObject((MatrixCharacteristics) 
request.getParam(0)) :
-                               
ExecutionContext.createFrameObject((MatrixCharacteristics) request.getParam(0));
+                       data = request.getParam(1) == DataType.MATRIX ? 
ExecutionContext
+                               .createMatrixObject((MatrixCharacteristics) 
request.getParam(0)) : ExecutionContext
+                                       
.createFrameObject((MatrixCharacteristics) request.getParam(0));
                else
-                       throw new DMLRuntimeException(
-                               "FederatedWorkerHandler: Unsupported object 
type, has to be of type CacheBlock or ScalarObject");
+                       throw new FederatedWorkerHandlerException(
+                               "Unsupported object type, has to be of type 
CacheBlock or ScalarObject");
 
                // set variable and construct empty response
-               ec.setVariable(varname, data);
-               if (DMLScript.LINEAGE)
-                       ec.getLineage().set(varname, new 
LineageItem(String.valueOf(request.getChecksum(0))));
+               ec.setVariable(varName, data);
+               if(DMLScript.LINEAGE)
+                       ec.getLineage().set(varName, new 
LineageItem(String.valueOf(request.getChecksum(0))));
 
                return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
        }
@@ -284,10 +313,10 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
        private FederatedResponse getVariable(FederatedRequest request) {
                checkNumParams(request.getNumParams(), 0);
                ExecutionContext ec = _ecm.get(request.getTID());
-               if(!ec.containsVariable(String.valueOf(request.getID()))) {
-                       return new FederatedResponse(ResponseType.ERROR,
+               if(!ec.containsVariable(String.valueOf(request.getID())))
+                       throw new FederatedWorkerHandlerException(
                                "Variable " + request.getID() + " does not 
exist at federated worker.");
-               }
+
                // get variable and construct response
                Data dataObject = 
ec.getVariable(String.valueOf(request.getID()));
                dataObject = PrivacyMonitor.handlePrivacy(dataObject);
@@ -295,44 +324,32 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                        case TENSOR:
                        case MATRIX:
                        case FRAME:
-                               return new 
FederatedResponse(ResponseType.SUCCESS,
-                                       ((CacheableData<?>) 
dataObject).acquireReadAndRelease());
+                               return new 
FederatedResponse(ResponseType.SUCCESS, ((CacheableData<?>) 
dataObject).acquireReadAndRelease());
                        case LIST:
                                return new 
FederatedResponse(ResponseType.SUCCESS, ((ListObject) dataObject).getData());
                        case SCALAR:
                                return new 
FederatedResponse(ResponseType.SUCCESS, dataObject);
                        default:
-                               return new 
FederatedResponse(ResponseType.ERROR, new FederatedWorkerHandlerException(
-                                       "Unsupported return datatype " + 
dataObject.getDataType().name()));
+                               throw new 
FederatedWorkerHandlerException("Unsupported return datatype " + 
dataObject.getDataType().name());
                }
        }
 
-       private FederatedResponse execInstruction(FederatedRequest request) {
+       private FederatedResponse execInstruction(FederatedRequest request) 
throws Exception {
                ExecutionContext ec = _ecm.get(request.getTID());
                BasicProgramBlock pb = new BasicProgramBlock(null);
                pb.getInstructions().clear();
                Instruction receivedInstruction = 
InstructionParser.parseSingleInstruction((String) request.getParam(0));
                pb.getInstructions().add(receivedInstruction);
 
-               if (DMLScript.LINEAGE)
+               if(DMLScript.LINEAGE)
                        // Compiler assisted optimizations are not applicable 
for Fed workers.
-                       // e.g. isMarkedForCaching fails as output operands are 
saved in the 
-                       // symbol table only after the instruction execution 
finishes. 
-                       // NOTE: In shared JVM, this will disable compiler 
assistance even for the coordinator 
+                       // e.g. isMarkedForCaching fails as output operands are 
saved in the
+                       // symbol table only after the instruction execution 
finishes.
+                       // NOTE: In shared JVM, this will disable compiler 
assistance even for the coordinator
                        LineageCacheConfig.setCompAssRW(false);
 
-               try {
-                       pb.execute(ec); // execute single instruction
-               }
-               catch(DMLPrivacyException | FederatedWorkerHandlerException ex){
-                       throw ex;
-               }
-               catch(Exception ex) {
-                       String msg = "Exception of type " + ex.getClass() + " 
thrown when processing EXEC_INST request";
-                       log.error(msg, ex);
-                       return new FederatedResponse(ResponseType.ERROR,
-                               new FederatedWorkerHandlerException(msg));
-               }
+               pb.execute(ec); // execute single instruction
+
                return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
        }
 
@@ -344,51 +361,50 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                FederatedUDF udf = (FederatedUDF) request.getParam(0);
                Data[] inputs = Arrays.stream(udf.getInputIDs()).mapToObj(id -> 
ec.getVariable(String.valueOf(id)))
                        
.map(PrivacyMonitor::handlePrivacy).toArray(Data[]::new);
-               
+
                // trace lineage
-               if (DMLScript.LINEAGE)
+               if(DMLScript.LINEAGE)
                        LineageItemUtils.traceFedUDF(ec, udf);
-               
+
                // reuse or execute user-defined function
                try {
                        // reuse UDF outputs if available in lineage cache
                        FederatedResponse reuse = LineageCache.reuse(udf, ec);
-                       if (reuse.isSuccessful())
+                       if(reuse.isSuccessful())
                                return reuse;
 
                        // else execute the UDF
                        long t0 = !ReuseCacheType.isNone() ? System.nanoTime() 
: 0;
                        FederatedResponse res = udf.execute(ec, inputs);
                        long t1 = !ReuseCacheType.isNone() ? System.nanoTime() 
: 0;
-                       //cacheUDFOutputs(udf, inputs, t1-t0, ec);
-                       LineageCache.putValue(udf, ec, t1-t0);
+                       // cacheUDFOutputs(udf, inputs, t1-t0, ec);
+                       LineageCache.putValue(udf, ec, t1 - t0);
                        return res;
                }
-               catch(DMLPrivacyException | FederatedWorkerHandlerException ex){
+               catch(DMLPrivacyException | FederatedWorkerHandlerException ex) 
{
                        throw ex;
                }
                catch(Exception ex) {
+                       // Note it is unsafe to throw the ex trace along with 
the exception here.
                        String msg = "Exception of type " + ex.getClass() + " 
thrown when processing EXEC_UDF request";
-                       log.error(msg, ex);
-                       return new FederatedResponse(ResponseType.ERROR, new 
FederatedWorkerHandlerException(msg));
+                       throw new FederatedWorkerHandlerException(msg);
                }
        }
-       
+
        private FederatedResponse execClear() {
                try {
                        _ecm.clear();
                }
-               catch(DMLPrivacyException | FederatedWorkerHandlerException ex){
+               catch(DMLPrivacyException | FederatedWorkerHandlerException ex) 
{
                        throw ex;
                }
                catch(Exception ex) {
                        String msg = "Exception of type " + ex.getClass() + " 
thrown when processing CLEAR request";
-                       log.error(msg, ex);
-                       return new FederatedResponse(ResponseType.ERROR, new 
FederatedWorkerHandlerException(msg));
+                       throw new FederatedWorkerHandlerException(msg);
                }
                return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
        }
-       
+
        private static FederatedResponse execNoop() {
                return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
        }
@@ -396,8 +412,8 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
        private static void checkNumParams(int actual, int... expected) {
                if(Arrays.stream(expected).anyMatch(x -> x == actual))
                        return;
-               throw new DMLRuntimeException("FederatedWorkerHandler: Received 
wrong amount of params:" 
-                       + " expected=" + Arrays.toString(expected) + ", 
actual=" + actual);
+               throw new DMLRuntimeException("FederatedWorkerHandler: Received 
wrong amount of params:" + " expected="
+                       + Arrays.toString(expected) + ", actual=" + actual);
        }
 
        @Override
@@ -410,7 +426,7 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                @Override
                public void operationComplete(ChannelFuture channelFuture) 
throws InterruptedException {
                        if(!channelFuture.isSuccess()) {
-                               log.error("Federated Worker Write failed");
+                               LOG.error("Federated Worker Write failed");
                                channelFuture.channel().writeAndFlush(new 
FederatedResponse(ResponseType.ERROR,
                                        new 
FederatedWorkerHandlerException("Error while sending 
response."))).channel().close().sync();
                        }

Reply via email to