This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 69d3358  [SYSTEMDS-3143] Frame rm empty instruction
69d3358 is described below

commit 69d33589de1258ba68b3652dfb9a5884adea213e
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Wed Sep 22 00:17:50 2021 +0200

    [SYSTEMDS-3143] Frame rm empty instruction
    
    This commit adds the remove-empty instruction for frames; it was
    previously only supported on matrices.
    
    Closes #1397
---
 .../ParameterizedBuiltinFunctionExpression.java    |  16 +-
 .../cp/ParameterizedBuiltinCPInstruction.java      |  34 ++-
 .../fed/ParameterizedBuiltinFEDInstruction.java    | 251 ++++++++++++++++++++-
 .../sysds/runtime/matrix/data/FrameBlock.java      | 110 +++++++++
 .../apache/sysds/runtime/util/UtilFunctions.java   |   7 +
 .../test/component/frame/FrameRemoveEmptyTest.java | 195 ++++++++++++++++
 src/test/scripts/functions/frame/removeEmpty1.dml  |  30 +++
 7 files changed, 623 insertions(+), 20 deletions(-)

diff --git a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 444ab54..442d1e6 100644
--- a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -584,7 +584,8 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
                                + Arrays.toString(invalid.toArray(new 
String[0])), false);
                
                //check existence and correctness of arguments
-               checkTargetParam(getVarParam("target"), conditional);
+               Expression target = getVarParam("target");
+               checkEmptyTargetParam(target, conditional);
                
                Expression margin = getVarParam("margin");
                if( margin==null ){
@@ -608,8 +609,11 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                        _varParams.put("empty.return", new 
BooleanIdentifier(true));
                
                // Output is a matrix with unknown dims
-               output.setDataType(DataType.MATRIX);
-               output.setValueType(ValueType.FP64);
+               output.setDataType(target.getOutput().getDataType());
+               if(target.getOutput().getDataType() == DataType.FRAME)
+                       output.setValueType(ValueType.STRING);
+               else
+                       output.setValueType(ValueType.FP64);
                output.setDimensions(-1, -1);
        }
        
@@ -726,6 +730,12 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                        raiseValidateError("Input matrix 'target' is of type 
'"+target.getOutput().getDataType()
                                +"'. Please specify the input matrix.", 
conditional, LanguageErrorCodes.INVALID_PARAMETERS);
        }
+
+       private void checkEmptyTargetParam(Expression target, boolean 
conditional) {
+               if( target==null )
+                       raiseValidateError("Named parameter 'target' missing. 
Please specify the input matrix.",
+                               conditional, 
LanguageErrorCodes.INVALID_PARAMETERS);
+       }
        
        private void checkOptionalBooleanParam(Expression param, String name, 
boolean conditional) {
                if( param!=null && (!param.getOutput().getDataType().isScalar() 
|| param.getOutput().getValueType() != ValueType.BOOLEAN) ){
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index ccced11..233154a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -208,21 +208,31 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                        String margin = params.get("margin");
                        if(!(margin.equals("rows") || margin.equals("cols")))
                                throw new DMLRuntimeException("Unspupported 
margin identifier '" + margin + "'.");
+                       if(ec.isFrameObject(params.get("target"))) {
+                               FrameBlock target = 
ec.getFrameInput(params.get("target"));
+                               MatrixBlock select = 
params.containsKey("select") ? ec.getMatrixInput(params.get("select")) : null;
 
-                       // acquire locks
-                       MatrixBlock target = 
ec.getMatrixInput(params.get("target"));
-                       MatrixBlock select = params.containsKey("select") ? 
ec.getMatrixInput(params.get("select")) : null;
+                               boolean emptyReturn = 
Boolean.parseBoolean(params.get("empty.return").toLowerCase());
+                               FrameBlock soresBlock = 
target.removeEmptyOperations(margin.equals("rows"), emptyReturn, select);
+                               ec.setFrameOutput(output.getName(), soresBlock);
+                               ec.releaseFrameInput(params.get("target"));
+                               if(params.containsKey("select"))
+                                       
ec.releaseMatrixInput(params.get("select"));
+                       } else {
+                               // acquire locks
+                               MatrixBlock target = 
ec.getMatrixInput(params.get("target"));
+                               MatrixBlock select = 
params.containsKey("select") ? ec.getMatrixInput(params.get("select")) : null;
 
-                       // compute the result
-                       boolean emptyReturn = 
Boolean.parseBoolean(params.get("empty.return").toLowerCase());
-                       MatrixBlock soresBlock = target
-                               .removeEmptyOperations(new MatrixBlock(), 
margin.equals("rows"), emptyReturn, select);
+                               // compute the result
+                               boolean emptyReturn = 
Boolean.parseBoolean(params.get("empty.return").toLowerCase());
+                               MatrixBlock soresBlock = 
target.removeEmptyOperations(new MatrixBlock(), margin.equals("rows"), 
emptyReturn, select);
 
-                       // release locks
-                       ec.setMatrixOutput(output.getName(), soresBlock);
-                       ec.releaseMatrixInput(params.get("target"));
-                       if(params.containsKey("select"))
-                               ec.releaseMatrixInput(params.get("select"));
+                               // release locks
+                               ec.setMatrixOutput(output.getName(), 
soresBlock);
+                               ec.releaseMatrixInput(params.get("target"));
+                               if(params.containsKey("select"))
+                                       
ec.releaseMatrixInput(params.get("select"));
+                       }
                }
                else if(opcode.equalsIgnoreCase("replace")) {
                        if(ec.isFrameObject(params.get("target"))){
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index 02d34a1..a6c5ef1 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -28,10 +28,12 @@ import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.Future;
+import java.util.stream.IntStream;
 import java.util.stream.Stream;
 import java.util.zip.Adler32;
 import java.util.zip.Checksum;
 
+import org.apache.commons.lang.ArrayUtils;
 import org.apache.commons.lang3.SerializationUtils;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types;
@@ -73,6 +75,7 @@ import 
org.apache.sysds.runtime.transform.decode.DecoderFactory;
 import org.apache.sysds.runtime.transform.encode.EncoderFactory;
 import org.apache.sysds.runtime.transform.encode.EncoderOmit;
 import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+import org.apache.sysds.runtime.util.UtilFunctions;
 
 public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstruction {
        protected final LinkedHashMap<String, String> params;
@@ -151,7 +154,10 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
                }
                else if(opcode.equals("rmempty"))
-                       rmempty(ec);
+                       if (getTarget(ec) instanceof FrameObject)
+                               rmemptyFrame(ec);
+                       else
+                               rmemptyMatrix(ec);
                else if(opcode.equals("lowertri") || opcode.equals("uppertri"))
                        triangle(ec, opcode);
                else if(opcode.equalsIgnoreCase("transformdecode"))
@@ -329,7 +335,170 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                }
        }
 
-       private void rmempty(ExecutionContext ec) {
+       private void rmemptyFrame(ExecutionContext ec) {
+               String margin = params.get("margin");
+               if(!(margin.equals("rows") || margin.equals("cols")))
+                       throw new DMLRuntimeException("Unsupported margin 
identifier '" + margin + "'.");
+
+               FrameObject mo = (FrameObject) getTarget(ec);
+               MatrixObject select = params.containsKey("select") ? 
ec.getMatrixObject(params.get("select")) : null;
+               FrameObject out = ec.getFrameObject(output);
+
+               boolean marginRow = params.get("margin").equals("rows");
+               boolean isNotAligned = ((marginRow && 
mo.getFedMapping().getType().isColPartitioned()) ||
+                       (!marginRow && 
mo.getFedMapping().getType().isRowPartitioned()));
+
+               MatrixBlock s = new MatrixBlock();
+               if(select == null && isNotAligned) {
+                       List<MatrixBlock> colSums = new ArrayList<>();
+                       mo.getFedMapping().forEachParallel((range, data) -> {
+                               try {
+                                       FederatedResponse response = data
+                                               .executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+                                                       new 
GetFrameVector(data.getVarID(), margin.equals("rows"))))
+                                               .get();
+
+                                       if(!response.isSuccessful())
+                                               
response.throwExceptionFromResponse();
+                                       MatrixBlock vector = (MatrixBlock) 
response.getData()[0];
+                                       synchronized(colSums) {
+                                               colSums.add(vector);
+                                       }
+                               }
+                               catch(Exception e) {
+                                       throw new DMLRuntimeException(e);
+                               }
+                               return null;
+                       });
+                       // find empty in matrix
+                       BinaryOperator plus = 
InstructionUtils.parseBinaryOperator("+");
+                       BinaryOperator greater = 
InstructionUtils.parseBinaryOperator(">");
+                       s = colSums.get(0);
+                       for(int i = 1; i < colSums.size(); i++)
+                               s = s.binaryOperationsInPlace(plus, 
colSums.get(i));
+                       s = s.binaryOperationsInPlace(greater, new 
MatrixBlock(s.getNumRows(), s.getNumColumns(), 0.0));
+                       select = ExecutionContext.createMatrixObject(s);
+
+                       long varID = FederationUtils.getNextFedDataID();
+                       ec.setVariable(String.valueOf(varID), select);
+                       params.put("select", String.valueOf(varID));
+                       // construct new string
+                       String[] oldString = 
InstructionUtils.getInstructionParts(instString);
+                       String[] newString = new String[oldString.length + 1];
+                       newString[2] = "select=" + varID;
+                       System.arraycopy(oldString, 0, newString, 0, 2);
+                       System.arraycopy(oldString, 2, newString, 3, 
newString.length - 3);
+                       instString = 
instString.replace(InstructionUtils.concatOperands(oldString),
+                               InstructionUtils.concatOperands(newString));
+               }
+
+               if(select == null) {
+                       FederatedRequest fr1 = 
FederationUtils.callInstruction(instString,
+                               output,
+                               new CPOperand[] {getTargetOperand()},
+                               new long[] {mo.getFedMapping().getID()});
+                       mo.getFedMapping().execute(getTID(), true, fr1);
+                       
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
+               }
+               else if(!isNotAligned) {
+                       // construct commands: broadcast , fed rmempty, clean 
broadcast
+                       FederatedRequest[] fr1 = 
mo.getFedMapping().broadcastSliced(select, !marginRow);
+                       FederatedRequest fr2 = 
FederationUtils.callInstruction(instString,
+                               output,
+                               new CPOperand[] {getTargetOperand(),
+                                       new CPOperand(params.get("select"), 
ValueType.FP64, DataType.MATRIX)},
+                               new long[] {mo.getFedMapping().getID(), 
fr1[0].getID()});
+
+                       // execute federated operations and set output
+                       mo.getFedMapping().execute(getTID(), true, fr1, fr2);
+                       
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID()));
+               }
+               else {
+                       // construct commands: broadcast , fed rmempty, clean 
broadcast
+                       FederatedRequest fr1 = 
mo.getFedMapping().broadcast(select);
+                       FederatedRequest fr2 = 
FederationUtils.callInstruction(instString,
+                               output,
+                               new CPOperand[] {getTargetOperand(),
+                                       new CPOperand(params.get("select"), 
ValueType.FP64, DataType.MATRIX)},
+                               new long[] {mo.getFedMapping().getID(), 
fr1.getID()});
+
+                       // execute federated operations and set output
+                       mo.getFedMapping().execute(getTID(), true, fr1, fr2);
+                       
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr2.getID()));
+               }
+
+               // new ranges
+               Map<FederatedRange, int[]> dcs = new HashMap<>();
+               Map<FederatedRange, int[]> finalDcs1 = dcs;
+               Map<FederatedRange, ValueType[]> finalSchema = new HashMap<>();
+               out.getFedMapping().forEachParallel((range, data) -> {
+                       try {
+                               FederatedResponse response = data
+                                       .executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+                                               new 
GetFrameCharacteristics(data.getVarID())))
+                                       .get();
+
+                               if(!response.isSuccessful())
+                                       response.throwExceptionFromResponse();
+                               Object[] ret = response.getData();
+                               int[] subRangeCharacteristics = new int[]{(int) 
ret[0], (int) ret[1]};
+                               ValueType[] schema = (ValueType[]) ret[2];
+                               synchronized(finalDcs1) {
+                                       finalDcs1.put(range, 
subRangeCharacteristics);
+                               }
+                               synchronized(finalSchema) {
+                                       finalSchema.put(range, schema);
+                               }
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException(e);
+                       }
+                       return null;
+               });
+
+               dcs = finalDcs1;
+               out.getDataCharacteristics().set(mo.getDataCharacteristics());
+               int len = marginRow ? mo.getSchema().length : (int) 
(mo.isFederated(FederationMap.FType.ROW) ? s
+                       .getNonZeros() : 
finalSchema.values().stream().mapToInt(e -> e.length).sum());
+               ValueType[] schema = new ValueType[len];
+               int pos = 0;
+               for(int i = 0; i < 
mo.getFedMapping().getFederatedRanges().length; i++) {
+                       FederatedRange federatedRange = new 
FederatedRange(out.getFedMapping().getFederatedRanges()[i]);
+
+                       if(marginRow) {
+                               schema = mo.getSchema();
+                       } else if(mo.isFederated(FederationMap.FType.ROW)) {
+                               schema = finalSchema.get(federatedRange);
+                       } else  {
+                               ValueType[] tmp = 
finalSchema.get(federatedRange);
+                               System.arraycopy(tmp, 0, schema, pos, 
tmp.length);
+                               pos += tmp.length;
+                       }
+
+                       int[] newRange = dcs.get(federatedRange);
+                       
out.getFedMapping().getFederatedRanges()[i].setBeginDim(0,
+                               
(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] == 0 ||
+                                       i == 0) ? 0 : 
out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[0]);
+
+                       out.getFedMapping().getFederatedRanges()[i].setEndDim(0,
+                               
out.getFedMapping().getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
+
+                       
out.getFedMapping().getFederatedRanges()[i].setBeginDim(1,
+                               
(out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] == 0 ||
+                                       i == 0) ? 0 : 
out.getFedMapping().getFederatedRanges()[i - 1].getEndDims()[1]);
+
+                       out.getFedMapping().getFederatedRanges()[i].setEndDim(1,
+                               
out.getFedMapping().getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
+               }
+
+               out.setSchema(schema);
+               
out.getDataCharacteristics().set(out.getFedMapping().getMaxIndexInRange(0),
+                       out.getFedMapping().getMaxIndexInRange(1),
+                       (int) mo.getBlocksize());
+       }
+
+
+       private void rmemptyMatrix(ExecutionContext ec) {
                String margin = params.get("margin");
                if(!(margin.equals("rows") || margin.equals("cols")))
                        throw new DMLRuntimeException("Unsupported margin 
identifier '" + margin + "'.");
@@ -428,7 +597,7 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        try {
                                FederatedResponse response = data
                                        .executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-                                               new 
GetDataCharacteristics(data.getVarID())))
+                                               new 
GetMatrixCharacteristics(data.getVarID())))
                                        .get();
 
                                if(!response.isSuccessful())
@@ -724,11 +893,11 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                }
        }
 
-       private static class GetDataCharacteristics extends FederatedUDF {
+       private static class GetMatrixCharacteristics extends FederatedUDF {
 
                private static final long serialVersionUID = 
578461386177730925L;
 
-               public GetDataCharacteristics(long varID) {
+               public GetMatrixCharacteristics(long varID) {
                        super(new long[] {varID});
                }
 
@@ -746,6 +915,28 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                }
        }
 
+       private static class GetFrameCharacteristics extends FederatedUDF {
+
+               private static final long serialVersionUID = 
578461386177730925L;
+
+               public GetFrameCharacteristics(long varID) {
+                       super(new long[] {varID});
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       FrameBlock fb = ((FrameObject) 
data[0]).acquireReadAndRelease();
+                       int r = fb.getNumRows() != 0 || fb.getNumRows() != -1 ? 
fb.getNumRows() : 0;
+                       int c = fb.getNumColumns() != 0 || fb.getNumColumns() 
!= -1 ? fb.getNumColumns() : 0;
+                       return new FederatedResponse(ResponseType.SUCCESS, new 
Object[] {r, c, fb.getSchema()});
+               }
+
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
+       }
+
        private static class GetVector extends FederatedUDF {
 
                private static final long serialVersionUID = 
-1003061862215703768L;
@@ -779,4 +970,54 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        return null;
                }
        }
+
+       private static class GetFrameVector extends FederatedUDF {
+
+               private static final long serialVersionUID = 
-1003061862215703768L;
+               private final boolean _marginRow;
+
+               public GetFrameVector(long varID, boolean marginRow) {
+                       super(new long[] {varID});
+                       _marginRow = marginRow;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
+                       FrameBlock fb = ((FrameObject) 
data[0]).acquireReadAndRelease();
+
+                       MatrixBlock ret = _marginRow ? new 
MatrixBlock(fb.getNumRows(), 1, 0.0) : new MatrixBlock(1,fb.getNumColumns(), 
0.0);
+
+                       if(_marginRow) {
+                               for(int i = 0; i < fb.getNumRows(); i++) {
+                                       boolean isEmpty = true;
+
+                                       for(int j = 0; j < fb.getNumColumns(); 
j++) {
+                                               ValueType type = 
fb.getSchema()[j];
+                                               isEmpty = isEmpty && 
(ArrayUtils.contains(new double[]{0.0, Double.NaN}, 
UtilFunctions.objectToDoubleSafe(type, fb.get(i, j))));
+
+                                       }
+
+                                       if(!isEmpty)
+                                               ret.setValue(i, 0, 1.0);
+                               }
+                       } else {
+                               for(int i = 0; i < fb.getNumColumns(); i++) {
+                                       int finalI = i;
+                                       ValueType type = fb.getSchema()[i];
+                                       boolean isEmpty = IntStream.range(0, 
fb.getNumRows()).mapToObj(j -> fb.get(j, finalI))
+                                               .allMatch(e -> 
ArrayUtils.contains(new double[]{0.0, Double.NaN}, 
UtilFunctions.objectToDoubleSafe(type, e)));
+
+                                       if(!isEmpty)
+                                               ret.setValue(0, i,1.0);
+                               }
+                       }
+
+                       return new FederatedResponse(ResponseType.SUCCESS, ret);
+               }
+
+               @Override
+               public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
+       }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index 86bbdab..64f6e80 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -37,6 +37,8 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.function.Function;
+import java.util.function.IntFunction;
+import java.util.stream.IntStream;
 
 import org.apache.commons.lang.ArrayUtils;
 import org.apache.commons.lang.NotImplementedException;
@@ -46,6 +48,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.hadoop.io.Writable;
 import org.apache.sysds.api.DMLException;
+import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.codegen.CodegenUtils;
@@ -598,6 +601,31 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                _msize = -1;
        }
 
+       public void appendColumn(ValueType vt, Array col) {
+               switch (vt) {
+                       case STRING:
+                               appendColumn(((StringArray) col).get());
+                               break;
+                       case BOOLEAN:
+                               appendColumn(((BooleanArray) col).get());
+                               break;
+                       case INT32:
+                               appendColumn(((IntegerArray) col).get());
+                               break;
+                       case INT64:
+                               appendColumn(((LongArray) col).get());
+                               break;
+                       case FP32:
+                               appendColumn(((FloatArray) col).get());
+                               break;
+                       case FP64:
+                               appendColumn(((DoubleArray) col).get());
+                               break;
+                       default:
+                               throw new RuntimeException("Unsupported value 
type: " + vt);
+               }
+       }
+
        public Object getColumnData(int c) {
                switch(_schema[c]) {
                        case STRING:  return ((StringArray)_coldata[c])._data;
@@ -1640,10 +1668,13 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                        _data = data;
                        _size = _data.length;
                }
+               public String[] get() { return _data; }
+
                @Override
                public String get(int index) {
                        return _data[index];
                }
+
                @Override
                public void set(int index, String value) {
                        _data[index] = value;
@@ -1705,10 +1736,13 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                        _data = data;
                        _size = _data.length;
                }
+               public boolean[] get() { return _data; }
+
                @Override
                public Boolean get(int index) {
                        return _data[index];
                }
+
                @Override
                public void set(int index, Boolean value) {
                        _data[index] = (value!=null) ? value : false;
@@ -1772,6 +1806,7 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                        _data = data;
                        _size = _data.length;
                }
+               public long[] get() { return _data; }
                @Override
                public Long get(int index) {
                        return _data[index];
@@ -1839,6 +1874,7 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                        _data = data;
                        _size = _data.length;
                }
+               public int[] get() { return _data; }
 
                @Override
                public Integer get(int index) {
@@ -1906,6 +1942,8 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                        _data = data;
                        _size = _data.length;
                }
+               public float[] get() { return _data; }
+
                @Override
                public Float get(int index) {
                        return _data[index];
@@ -1972,6 +2010,7 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                        _data = data;
                        _size = _data.length;
                }
+               public double[] get() { return _data; }
                @Override
                public Double get(int index) {
                        return _data[index];
@@ -2473,6 +2512,77 @@ public class FrameBlock implements CacheBlock, 
Externalizable {
                return ret;
        }
 
+       public  FrameBlock removeEmptyOperations(boolean rows, boolean 
emptyReturn, MatrixBlock select) {
+               if( rows )
+                       return removeEmptyRows(select, emptyReturn);
+               else //cols
+                       return removeEmptyColumns(select, emptyReturn);
+       }
+
+       private FrameBlock removeEmptyRows(MatrixBlock select, boolean 
emptyReturn) {
+               FrameBlock ret = new FrameBlock(_schema, _colnames);
+
+               for(int i = 0; i < _numRows; i++) {
+                       boolean isEmpty = true;
+                       Object[] row = new Object[getNumColumns()];
+
+                       for(int j = 0; j < getNumColumns(); j++) {
+                               Array colData = _coldata[j].clone();
+                               row[j] = colData.get(i);
+                               ValueType type = _schema[j];
+                               isEmpty = isEmpty && (ArrayUtils.contains(new 
double[]{0.0, Double.NaN}, UtilFunctions.objectToDoubleSafe(type, 
colData.get(i))));
+                       }
+
+                       if((!isEmpty && select == null) || (select != null && 
select.getValue(i, 0) == 1)) {
+                               ret.appendRow(row);
+                       }
+               }
+
+               if(ret.getNumRows() == 0 && emptyReturn) {
+                       String[][] arr = new String[1][getNumColumns()];
+                       Arrays.fill(arr, new String[]{null});
+                       ValueType[] schema = new ValueType[getNumColumns()];
+                       Arrays.fill(schema, ValueType.STRING);
+                       return new FrameBlock(schema, arr);
+               }
+
+               return ret;
+       }
+
+       /**
+        * Builds a new frame holding only the retained columns of this
+        * frame. Without a selection vector, a column is retained if at
+        * least one of its cells is neither 0 nor NaN (as reported by
+        * UtilFunctions.objectToDoubleSafe). With a selection vector,
+        * column i is retained iff select.getValue(0, i) == 1.
+        *
+        * @param select      optional selection vector (may be null)
+        * @param emptyReturn if true and no column is retained, return a
+        *                    single STRING column of nulls with _numRows
+        *                    rows instead
+        * @return new frame with clones of the retained columns
+        */
+       private FrameBlock removeEmptyColumns(MatrixBlock select,
+               boolean emptyReturn)
+       {
+               FrameBlock ret = new FrameBlock();
+               List<ColumnMetadata> columnMetadata = new ArrayList<>();
+
+               for(int i = 0; i < getNumColumns(); i++) {
+                       // the selection vector, if given, decides alone
+                       boolean keep = (select != null) ?
+                               select.getValue(0, i) == 1 : !isEmptyColumn(i);
+                       if(keep) {
+                               ret.appendColumn(_schema[i], _coldata[i].clone());
+                               columnMetadata.add(new ColumnMetadata(_colmeta[i]));
+                       }
+               }
+
+               if(ret.getNumColumns() == 0 && emptyReturn) {
+                       // every row refers to the same one-cell {null} array
+                       String[][] arr = new String[_numRows][];
+                       Arrays.fill(arr, new String[]{null});
+                       return new FrameBlock(
+                               new ValueType[]{ValueType.STRING}, arr);
+               }
+
+               // assign the metadata array directly before handing it to the
+               // setter, as the original code did
+               ret._colmeta = columnMetadata.toArray(new ColumnMetadata[0]);
+               ret.setColumnMetadata(ret._colmeta);
+
+               return ret;
+       }
+
+       /** Returns true iff every cell of column c is 0 or NaN. */
+       private boolean isEmptyColumn(int c) {
+               Array col = _coldata[c];
+               for(int r = 0; r < col._size; r++) {
+                       double v = UtilFunctions.objectToDoubleSafe(_schema[c],
+                               col.get(r));
+                       // explicit isNaN: NaN never compares equal to itself
+                       if(v != 0.0 && !Double.isNaN(v))
+                               return false;
+               }
+               return true;
+       }
+
        @Override
        public String toString(){
                StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java 
b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index ee6d913..ee64bc8 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -20,6 +20,7 @@
 package org.apache.sysds.runtime.util;
 
 import org.apache.commons.lang.ArrayUtils;
+import org.apache.commons.lang3.math.NumberUtils;
 import org.apache.commons.math3.random.RandomDataGenerator;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -487,6 +488,12 @@ public class UtilFunctions {
                }
        }
 
+       /**
+        * Variant of objectToDouble that tolerates STRING cells which do
+        * not hold a number.
+        *
+        * @param vt value type of the cell
+        * @param in cell value, may be null
+        * @return 1.0 for a non-null STRING that is not a creatable number,
+        *         otherwise the result of objectToDouble(vt, in)
+        */
+       public static double objectToDoubleSafe(ValueType vt, Object in) {
+               // a null string is a missing value, not a non-numeric one, so
+               // it takes the same path as null cells of any other type
+               if(vt == ValueType.STRING && in != null
+                       && !NumberUtils.isCreatable((String) in))
+                       return 1.0;
+               return objectToDouble(vt, in);
+       }
+
        public static double objectToDouble(ValueType vt, Object in) {
                if( in == null )  return Double.NaN;
                switch( vt ) {
diff --git 
a/src/test/java/org/apache/sysds/test/component/frame/FrameRemoveEmptyTest.java 
b/src/test/java/org/apache/sysds/test/component/frame/FrameRemoveEmptyTest.java
new file mode 100644
index 0000000..d3bbdc4
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/component/frame/FrameRemoveEmptyTest.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.frame;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.unary.matrix.RemoveEmptyTest;
+import org.junit.Ignore;
+import org.junit.Test;
+
+/**
+ * Runs removeEmpty1.dml on a frame and compares the written result with
+ * MatrixBlock.removeEmptyOperations applied to a matrix that mirrors the
+ * frame built by the script.
+ */
+public class FrameRemoveEmptyTest extends AutomatedTestBase {
+       private final static String TEST_NAME1 = "removeEmpty1";
+       private final static String TEST_DIR = "functions/frame/";
+       // name the working directory after this class (not RemoveEmptyTest)
+       private static final String TEST_CLASS_DIR = TEST_DIR +
+               FrameRemoveEmptyTest.class.getSimpleName() + "/";
+
+       private final static int _rows = 10;
+       private final static int _cols = 6;
+
+       private final static double _sparsityDense = 0.7;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME1, new TestConfiguration(
+                       TEST_CLASS_DIR, TEST_NAME1, new String[] {"V"}));
+       }
+
+       @Test
+       public void testRemoveEmptyRowsDenseCP() {
+               runTestRemoveEmpty(TEST_NAME1, "rows", Types.ExecType.CP, false);
+       }
+
+       // NOTE: despite its name, this case runs with margin "cols"
+       @Test
+       public void testRemoveEmptyRowsSparseCP() {
+               runTestRemoveEmpty(TEST_NAME1, "cols", Types.ExecType.CP, true);
+       }
+
+       @Test
+       @Ignore
+       public void testRemoveEmptyRowsDenseSP() {
+               runTestRemoveEmpty(TEST_NAME1, "rows", Types.ExecType.SPARK,
+                       false);
+       }
+
+       @Test
+       @Ignore
+       public void testRemoveEmptyRowsSparseSP() {
+               runTestRemoveEmpty(TEST_NAME1, "rows", Types.ExecType.SPARK,
+                       true);
+       }
+
+       /**
+        * Runs the DML script for one margin and execution type and checks
+        * the output against the matrix implementation of removeEmpty.
+        *
+        * @param testname     name of the DML script (without extension)
+        * @param margin       "rows" or "cols", passed to the script
+        * @param et           execution type; SPARK selects ExecMode.SPARK,
+        *                     everything else ExecMode.HYBRID
+        * @param bSelectIndex if true, also write the selection vector "I"
+        */
+       private void runTestRemoveEmpty(String testname, String margin,
+               Types.ExecType et, boolean bSelectIndex)
+       {
+               // remember the execution mode so that it can be restored
+               Types.ExecMode platformOld = rtplatform;
+               switch(et) {
+                       case SPARK:
+                               rtplatform = Types.ExecMode.SPARK;
+                               break;
+                       default:
+                               rtplatform = Types.ExecMode.HYBRID;
+                               break;
+               }
+
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               if(rtplatform == Types.ExecMode.SPARK)
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+               try {
+                       // register test configuration
+                       TestConfiguration config = getTestConfiguration(testname);
+                       config.addVariable("rows", _rows);
+                       config.addVariable("cols", _cols);
+                       loadTestConfiguration(config);
+
+                       // construct the script arguments directly
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + testname + ".dml";
+                       programArgs = new String[] {"-explain", "-args",
+                               input("V"), margin, output("V")};
+
+                       MatrixBlock in = createInputMatrix(margin, _rows, _cols,
+                               _sparsityDense, bSelectIndex);
+
+                       runTest(true, false, null, -1);
+                       double[][] outArray = TestUtils.convertHashMapToDoubleArray(
+                               readDMLMatrixFromOutputDir("V"));
+                       MatrixBlock out = new MatrixBlock(outArray.length,
+                               outArray[0].length, false);
+                       out.init(outArray, outArray.length, outArray[0].length);
+
+                       // mirror the script: two extra columns that are 1 in
+                       // the first half of the rows and 0 in the second half
+                       MatrixBlock in2 = new MatrixBlock(_rows, _cols + 2, 0.0);
+                       in2.copy(0, _rows - 1, 0, _cols - 1, in, true);
+                       in2.copy(0, (_rows / 2) - 1, _cols, _cols + 1,
+                               new MatrixBlock(_rows / 2, 2, 1.0), true);
+                       MatrixBlock expected = in2.removeEmptyOperations(
+                               new MatrixBlock(), margin.equals("rows"), false, null);
+                       // drop the last two columns again before comparing
+                       expected = expected.slice(0, expected.getNumRows() - 1,
+                               0, expected.getNumColumns() - 3);
+
+                       TestUtils.compareMatrices(expected, out, 0);
+               }
+               finally {
+                       // reset platform for additional tests
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+       }
+
+       /**
+        * Generates the random input, clears every other row (margin "rows")
+        * or every other column (otherwise), and writes the input "V", the
+        * expected matrix "V" and, if requested, the selection vector "I".
+        *
+        * @param margin       "rows" or anything else for columns
+        * @param rows         number of rows of the input
+        * @param cols         number of columns of the input
+        * @param sparsity     sparsity passed to getRandomMatrix
+        * @param bSelectIndex if true, write the selection vector "I"
+        * @return the generated input as a MatrixBlock
+        */
+       private MatrixBlock createInputMatrix(String margin, int rows,
+               int cols, double sparsity, boolean bSelectIndex)
+       {
+               int rowsp = -1, colsp = -1;
+               if(margin.equals("rows")) {
+                       rowsp = rows / 2;
+                       colsp = cols;
+               }
+               else {
+                       rowsp = rows;
+                       colsp = cols / 2;
+               }
+
+               // long seed = System.nanoTime();
+               double[][] V = getRandomMatrix(rows, cols, 0, 1, sparsity, 7);
+               double[][] Vp = new double[rowsp][colsp];
+               double[][] Ix = null;
+               int innz = 0, vnnz = 0;
+
+               // clear out every other row/column
+               if(margin.equals("rows")) {
+                       Ix = new double[rows][1];
+                       for(int i = 0; i < rows; i++) {
+                               boolean clear = i % 2 != 0;
+                               if(clear) {
+                                       for(int j = 0; j < cols; j++)
+                                               V[i][j] = 0;
+                                       Ix[i][0] = 0;
+                               }
+                               else {
+                                       boolean bNonEmpty = false;
+                                       for(int j = 0; j < cols; j++) {
+                                               Vp[i / 2][j] = V[i][j];
+                                               bNonEmpty |= (V[i][j] != 0.0);
+                                               vnnz += (V[i][j] == 0.0) ? 0 : 1;
+                                       }
+                                       Ix[i][0] = (bNonEmpty) ? 1 : 0;
+                                       innz += Ix[i][0];
+                               }
+                       }
+               }
+               else {
+                       Ix = new double[1][cols];
+                       for(int j = 0; j < cols; j++) {
+                               boolean clear = j % 2 != 0;
+                               if(clear) {
+                                       for(int i = 0; i < rows; i++)
+                                               V[i][j] = 0;
+                                       Ix[0][j] = 0;
+                               }
+                               else {
+                                       boolean bNonEmpty = false;
+                                       for(int i = 0; i < rows; i++) {
+                                               Vp[i][j / 2] = V[i][j];
+                                               bNonEmpty |= (V[i][j] != 0.0);
+                                               vnnz += (V[i][j] == 0.0) ? 0 : 1;
+                                       }
+                                       Ix[0][j] = (bNonEmpty) ? 1 : 0;
+                                       innz += Ix[0][j];
+                               }
+                       }
+               }
+
+               MatrixCharacteristics imc = new MatrixCharacteristics(
+                       margin.equals("rows") ? rows : 1,
+                       margin.equals("rows") ? 1 : cols, 1000, innz);
+               MatrixCharacteristics vmc = new MatrixCharacteristics(rows,
+                       cols, 1000, vnnz);
+
+               // use the dimensions passed in, not the class constants
+               MatrixBlock in = new MatrixBlock(rows, cols, false);
+               in.init(V, rows, cols);
+
+               writeInputMatrixWithMTD("V", V, false, vmc); // always text
+               writeExpectedMatrix("V", Vp);
+               if(bSelectIndex)
+                       writeInputMatrixWithMTD("I", Ix, false, imc);
+
+               return in;
+       }
+}
diff --git a/src/test/scripts/functions/frame/removeEmpty1.dml 
b/src/test/scripts/functions/frame/removeEmpty1.dml
new file mode 100644
index 0000000..696880e
--- /dev/null
+++ b/src/test/scripts/functions/frame/removeEmpty1.dml
@@ -0,0 +1,30 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+
+# Arguments: $1 input to read, $2 margin passed to removeEmpty,
+# $3 location the result is written to.
+A = read($1, naStrings= ["NA", "null","  ","NaN", "nan", "", "?", "99999"])
+# B and C: two extra columns (BOOLEAN, STRING), each for nrow(A)/2 rows;
+# B is built from "TRUE"/"abc", C from "FALSE"/"0.0".
+B = frame(data=["TRUE", "abc"], rows=nrow(A) / 2, cols=2, schema=["BOOLEAN", 
"STRING"])
+C = frame(data=["FALSE", "0.0"], rows=nrow(A) / 2, cols=2, schema=["BOOLEAN", 
"STRING"])
+# stack B on top of C and append the result as columns to the input frame
+D = rbind(B, C)
+V = cbind(as.frame(A), D)
+Vp = removeEmpty(target=V, margin=$2)
+# write all but the last two columns of the result as a matrix
+X = as.matrix(Vp[, 1:(ncol(Vp)-2)])
+write(X, $3);

Reply via email to