This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 738dd42f80 [SYSTEMDS-3623] Lossy Compression with Binning
738dd42f80 is described below

commit 738dd42f80b53930b161a3dc0b6ff9b883797b59
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Mon Sep 25 01:02:59 2023 +0200

    [SYSTEMDS-3623] Lossy Compression with Binning
    
    This commit overloads the compression instruction so that it can
    fuse transform-encode binning with the compression of a matrix.
    
    `[res, meta] = compress(X, d)`
    
    X is an n x m non-compressed matrix/frame and d is a 1 x m vector
    specifying the number of bins for each column. It returns a
    compressed matrix and the transform metadata.
    
    Closes #1919
---
 .../java/org/apache/sysds/common/Builtins.java     |   2 +-
 src/main/java/org/apache/sysds/common/Types.java   |   2 +-
 .../java/org/apache/sysds/hops/OptimizerUtils.java |  12 +-
 .../sysds/parser/BuiltinFunctionExpression.java    |  31 ++-
 .../org/apache/sysds/parser/DMLTranslator.java     |  78 +++----
 .../runtime/compress/lib/CLALibBinCompress.java    |  74 +++++++
 .../instructions/cp/CompressionCPInstruction.java  |  57 +++++-
 .../runtime/transform/encode/ColumnEncoderBin.java |  16 +-
 .../org/apache/sysds/test/AutomatedTestBase.java   |  58 +++++-
 src/test/java/org/apache/sysds/test/TestUtils.java |  20 +-
 .../compress/matrixByBin/CompressByBinTest.java    | 225 +++++++++++++++++++++
 .../compress/matrixByBin/compressByBins.dml        |  33 +++
 12 files changed, 547 insertions(+), 61 deletions(-)

diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index 2243eeb963..9f7b3a0d0a 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -86,7 +86,7 @@ public enum Builtins {
        COLSUM("colSums", false),
        COLVAR("colVars", false),
        COMPONENTS("components", true),
-       COMPRESS("compress", false),
+       COMPRESS("compress", false, ReturnType.MULTI_RETURN),
        CONFUSIONMATRIX("confusionMatrix", true),
        CONV2D("conv2d", false),
        CONV2D_BACKWARD_FILTER("conv2d_backward_filter", false),
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index 7bf48f5da5..0a3748a9f8 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -400,7 +400,7 @@ public class Types
        // Operations that require 2 operands
        public enum OpOp2 {
                AND(true), APPLY_SCHEMA(false), BITWAND(true), BITWOR(true), 
BITWSHIFTL(true), BITWSHIFTR(true),
-               BITWXOR(true), CBIND(false), CONCAT(false), COV(false), 
DIV(true),
+               BITWXOR(true), CBIND(false), COMPRESS(true), CONCAT(false), 
COV(false), DIV(true),
                DROP_INVALID_TYPE(false), DROP_INVALID_LENGTH(false), 
EQUAL(true),
                FRAME_ROW_REPLICATE(true), GREATER(true), GREATEREQUAL(true), 
INTDIV(true),
                INTERQUANTILE(false), IQM(false), LESS(true),
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index a4abc513c6..dc2bc487ed 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -19,10 +19,15 @@
 
 package org.apache.sysds.hops;
 
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
 import org.apache.log4j.Level;
 import org.apache.log4j.Logger;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.common.Types.FileFormat;
 import org.apache.sysds.common.Types.OpOp1;
 import org.apache.sysds.common.Types.OpOp2;
@@ -38,7 +43,6 @@ import 
org.apache.sysds.hops.fedplanner.FTypes.FederatedPlanner;
 import org.apache.sysds.hops.rewrite.HopRewriteUtils;
 import org.apache.sysds.lops.Checkpoint;
 import org.apache.sysds.lops.Lop;
-import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.lops.compile.Dag;
 import org.apache.sysds.parser.ForStatementBlock;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -61,10 +65,6 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.IndexRange;
 import org.apache.sysds.runtime.util.UtilFunctions;
 
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
-
 public class OptimizerUtils 
 {
        ////////////////////////////////////////////////////////
@@ -269,7 +269,7 @@ public class OptimizerUtils
         * This variable allows for insertion of Compress and decompress in the 
dml script from the user.
         * This is added because we want to have a way to test, and verify the 
correct placement of compress and decompress commands.
         */
-       public static boolean ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND = false;
+       public static boolean ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND = true;
 
 
        /**
diff --git 
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java 
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index dd9bfe7892..8f1a496c1e 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -407,6 +407,31 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                        
svdOut3.setBlocksize(getFirstExpr().getOutput().getBlocksize());
 
                        break;
+
+               case COMPRESS:
+                       if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND){
+                               Expression expressionTwo = getSecondExpr();
+                               checkNumParameters(getSecondExpr() != null ? 2 
: 1);
+                               checkMatrixFrameParam(getFirstExpr());
+                               if(expressionTwo != null)
+                                       checkMatrixParam(getSecondExpr());
+
+                               Identifier compressInput1 = 
getFirstExpr().getOutput();
+                               Identifier compressInput2 = 
getSecondExpr().getOutput();
+
+                               DataIdentifier compressOutput = 
(DataIdentifier) getOutputs()[0];
+                               compressOutput.setDataType(DataType.MATRIX);
+                               
compressOutput.setDimensions(compressInput1.getDim1(), 
compressInput1.getDim2());
+                               compressOutput.setBlocksize 
(compressInput1.getBlocksize());
+                               
compressOutput.setValueType(compressInput1.getValueType());
+
+                               DataIdentifier metaOutput = (DataIdentifier) 
getOutputs()[1];
+                               metaOutput.setDataType(DataType.FRAME);
+                               
metaOutput.setDimensions(compressInput1.getDim1(), -1);
+                       }
+                       else
+                               raiseValidateError("Compress/DeCompress 
instruction not allowed in dml script");
+                       break;
                
                default: //always unconditional
                        raiseValidateError("Unknown Builtin Function opcode: " 
+ _opcode, false);
@@ -1605,8 +1630,10 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                case COMPRESS:
                case DECOMPRESS:
                        if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND){
-                               checkNumParameters(1);
+                               Expression expressionTwo = getSecondExpr();;
                                checkMatrixParam(getFirstExpr());
+                               if(expressionTwo != null)
+                                       checkMatrixParam(getSecondExpr());
                                output.setDataType(DataType.MATRIX);
                                output.setDimensions(id.getDim1(), 
id.getDim2());
                                output.setBlocksize (id.getBlocksize());
@@ -1702,7 +1729,7 @@ public class BuiltinFunctionExpression extends 
DataIdentifier
                output.setDimensions(Math.max(dims1.getRows(), 
dims2.getRows()), Math.max(dims1.getCols(), dims2.getCols()));
                output.setBlocksize(Math.max(dims1.getBlocksize(), 
dims2.getBlocksize()));
        }
-       
+
        private void setNaryOutputProperties(DataIdentifier output) {
                DataType dt = Arrays.stream(getAllExpr()).allMatch(
                        e -> e.getOutput().getDataType().isScalar()) ? 
DataType.SCALAR : DataType.MATRIX;
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index ba5877d81b..a8ca42d54a 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -1045,7 +1045,7 @@ public class DMLTranslator
                                        throw new 
LanguageException(source.printErrorLocation()+": Unsupported indexing 
expression in write statement. " +
                                                                            
"Please, assign the right indexing result to a variable and write this 
variable.");
                                }
-                               
+
                                DataOp ae = (DataOp)processExpression(source, 
target, ids);
                                Expression fmtExpr = 
os.getExprParam(DataExpression.FORMAT_TYPE);
                                ae.setFileFormat((fmtExpr instanceof 
StringIdentifier) ?
@@ -1077,7 +1077,7 @@ public class DMLTranslator
                                                        throw new 
LanguageException("Unrecognized file format: " + ae.getFileFormat());
                                        }
                                }
-                               
+
                                output.add(ae);
                        }
 
@@ -1138,15 +1138,15 @@ public class DMLTranslator
                        }
 
                        if (current instanceof AssignmentStatement) {
-       
+
                                AssignmentStatement as = (AssignmentStatement) 
current;
                                DataIdentifier target = as.getTarget();
                                Expression source = as.getSource();
 
-                       
-                               // CASE: regular assignment statement -- source 
is DML expression that is NOT user-defined or external function 
+
+                               // CASE: regular assignment statement -- source 
is DML expression that is NOT user-defined or external function
                                if (!(source instanceof 
FunctionCallIdentifier)){
-                               
+
                                        // CASE: target is regular data 
identifier
                                        if (!(target instanceof 
IndexedIdentifier)) {
                                                //process right hand side and 
accumulation
@@ -1166,7 +1166,7 @@ public class DMLTranslator
                                                }
 
                                                ids.put(target.getName(), ae);
-                                               
+
                                                //add transient write if needed
                                                Integer statementId = 
liveOutToTemp.get(target.getName());
                                                if ((statementId != null) && 
(statementId.intValue() == i)) {
@@ -1176,11 +1176,11 @@ public class DMLTranslator
                                                        
updatedLiveOut.addVariable(target.getName(), target);
                                                        
output.add(transientwrite);
                                                }
-                                       } 
+                                       }
                                        // CASE: target is indexed identifier 
(left-hand side indexed expression)
                                        else {
                                                Hop ae = 
processLeftIndexedExpression(source, (IndexedIdentifier)target, ids);
-                                               
+
                                                if( as.isAccumulator() ) {
                                                        DataIdentifier accum = 
getAccumulatorData(liveIn, target.getName());
                                                        Hop rix = 
processIndexingExpression((IndexedIdentifier)target, null, ids);
@@ -1189,16 +1189,16 @@ public class DMLTranslator
                                                        
HopRewriteUtils.replaceChildReference(ae, ae.getInput(1), binary);
                                                        
target.setProperties(accum.getOutput());
                                                }
-                                               
+
                                                ids.put(target.getName(), ae);
-                                               
+
                                                // obtain origDim values BEFORE 
they are potentially updated during setProperties call
                                                //      (this is incorrect for 
LHS Indexing)
                                                long origDim1 = 
((IndexedIdentifier)target).getOrigDim1();
                                                long origDim2 = 
((IndexedIdentifier)target).getOrigDim2();
                                                
target.setProperties(source.getOutput());
                                                
((IndexedIdentifier)target).setOriginalDimensions(origDim1, origDim2);
-                                               
+
                                                // preserve data type matrix of 
any index identifier
                                                // (required for scalar input 
to left indexing)
                                                if( target.getDataType() != 
DataType.MATRIX ) {
@@ -1206,7 +1206,7 @@ public class DMLTranslator
                                                        
target.setValueType(ValueType.FP64);
                                                        
target.setBlocksize(ConfigurationManager.getBlocksize());
                                                }
-                                               
+
                                                Integer statementId = 
liveOutToTemp.get(target.getName());
                                                if ((statementId != null) && 
(statementId.intValue() == i)) {
                                                        DataOp transientwrite = 
new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae, 
OpOpData.TRANSIENTWRITE, null);
@@ -1222,36 +1222,36 @@ public class DMLTranslator
                                        //assignment, function call
                                        FunctionCallIdentifier fci = 
(FunctionCallIdentifier) source;
                                        FunctionStatementBlock fsb = 
this._dmlProg.getFunctionStatementBlock(fci.getNamespace(),fci.getName());
-                                       
+
                                        //error handling missing function
-                                       if (fsb == null) { 
-                                               throw new 
LanguageException(source.printErrorLocation() + "function " 
+                                       if (fsb == null) {
+                                               throw new 
LanguageException(source.printErrorLocation() + "function "
                                                        + fci.getName() + " is 
undefined in namespace " + fci.getNamespace());
                                        }
-                                       
+
                                        FunctionStatement fstmt = 
(FunctionStatement)fsb.getStatement(0);
                                        String fkey = 
DMLProgram.constructFunctionKey(fci.getNamespace(),fci.getName());
-                                       
+
                                        //error handling unsupported function 
call in indexing expression
                                        if( target instanceof IndexedIdentifier 
) {
                                                throw new 
LanguageException("Unsupported function call to '"+fkey+"' in left indexing "
                                                        + "expression. Please, 
assign the function output to a variable.");
                                        }
-                                       
+
                                        //prepare function input names and 
inputs
                                        List<String> inputNames = new 
ArrayList<>(fci.getParamExprs().stream()
                                                .map(e -> 
e.getName()).collect(Collectors.toList()));
                                        List<Hop> finputs = new 
ArrayList<>(fci.getParamExprs().stream()
                                                .map(e -> 
processExpression(e.getExpr(), null, ids)).collect(Collectors.toList()));
-                                       
+
                                        //append default expression for missing 
arguments
                                        appendDefaultArguments(fstmt, 
inputNames, finputs, ids);
-                                       
+
                                        //use function signature to obtain 
names for unnamed args
                                        //(note: consistent parameters already 
checked for functions in general)
                                        if( inputNames.stream().allMatch(n -> 
n==null) )
                                                inputNames = 
fstmt._inputParams.stream().map(d -> d.getName()).collect(Collectors.toList());
-                                       
+
                                        //create function op
                                        String[] inputNames2 = 
inputNames.toArray(new String[0]);
                                        FunctionType ftype = 
fsb.getFunctionOpType();
@@ -1267,31 +1267,31 @@ public class DMLTranslator
                                //multi-assignment, by definition a function 
call
                                MultiAssignmentStatement mas = 
(MultiAssignmentStatement) current;
                                Expression source = mas.getSource();
-                               
+
                                if ( source instanceof FunctionCallIdentifier ) 
{
                                        FunctionCallIdentifier fci = 
(FunctionCallIdentifier) source;
                                        FunctionStatementBlock fsb = 
this._dmlProg.getFunctionStatementBlock(fci.getNamespace(),fci.getName());
                                        if (fsb == null){
-                                               throw new 
LanguageException(source.printErrorLocation() + "function " 
+                                               throw new 
LanguageException(source.printErrorLocation() + "function "
                                                        + fci.getName() + " is 
undefined in namespace " + fci.getNamespace());
                                        }
-                                       
+
                                        FunctionStatement fstmt = 
(FunctionStatement)fsb.getStatement(0);
-                                       
+
                                        //prepare function input names and 
inputs
                                        List<String> inputNames = new 
ArrayList<>(fci.getParamExprs().stream()
                                                .map(e -> 
e.getName()).collect(Collectors.toList()));
                                        List<Hop> finputs = new 
ArrayList<>(fci.getParamExprs().stream()
                                                .map(e -> 
processExpression(e.getExpr(), null, ids)).collect(Collectors.toList()));
-                                       
+
                                        //use function signature to obtain 
names for unnamed args
                                        //(note: consistent parameters already 
checked for functions in general)
                                        if( inputNames.stream().allMatch(n -> 
n==null) )
                                                inputNames = 
fstmt._inputParams.stream().map(d -> d.getName()).collect(Collectors.toList());
-                                       
+
                                        //append default expression for missing 
arguments
                                        appendDefaultArguments(fstmt, 
inputNames, finputs, ids);
-                                       
+
                                        //create function op
                                        String[] foutputs = 
mas.getTargetList().stream()
                                                .map(d -> 
d.getName()).toArray(String[]::new);
@@ -1314,7 +1314,7 @@ public class DMLTranslator
                                else
                                        throw new LanguageException("Class \"" 
+ source.getClass() + "\" is not supported in Multiple Assignment statements");
                        }
-                       
+
                }
                sb.updateLiveVariablesOut(updatedLiveOut);
                sb.setHops(output);
@@ -2265,11 +2265,21 @@ public class DMLTranslator
                                }
                                
                                // Create the hop for current function call
-                               FunctionOp fcall = new FunctionOp(ftype, 
nameSpace, source.getOpCode().toString(), null, inputs, outputNames, outputs);
-                               currBuiltinOp = fcall;
-                               
+                               currBuiltinOp = new FunctionOp(ftype, 
nameSpace, source.getOpCode().toString(), null, inputs, outputNames, outputs);
                                break;
-                               
+
+                       case COMPRESS:
+                               // Number of outputs = size of targetList = #of 
identifiers in source.getOutputs
+                               String[] outputNamesCompress = new 
String[targetList.size()];
+                               outputNamesCompress[0] = 
targetList.get(0).getName();
+                               outputNamesCompress[1] = 
targetList.get(1).getName();
+                               outputs.add(new DataOp(outputNamesCompress[0], 
DataType.MATRIX, ValueType.FP64, inputs.get(0), OpOpData.FUNCTIONOUTPUT, 
inputs.get(0).getFilename()));
+                               outputs.add(new DataOp(outputNamesCompress[1], 
DataType.FRAME, ValueType.STRING, inputs.get(0), OpOpData.FUNCTIONOUTPUT, 
inputs.get(0).getFilename()));
+
+                               // Create the hop for current function call
+                               currBuiltinOp = new FunctionOp(ftype, 
nameSpace, source.getOpCode().toString(), null, inputs, outputNamesCompress, 
outputs);
+                               break;
+
                        default:
                                throw new ParseException("Invaid Opcode in 
DMLTranslator:processMultipleReturnBuiltinFunctionExpression(): " + 
source.getOpCode());
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinCompress.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinCompress.java
new file mode 100644
index 0000000000..8ef49332a0
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinCompress.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.runtime.compress.lib;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
+import org.apache.sysds.runtime.compress.CompressionStatistics;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+
+public class CLALibBinCompress {
+       public static ColumnEncoderBin.BinMethod binMethod = 
ColumnEncoderBin.BinMethod.EQUI_WIDTH;
+       public static Pair<MatrixBlock, FrameBlock> binCompress(CacheBlock<?> 
X, MatrixBlock d, int k){
+               // Create transform spec acc to binMethod
+               String spec = createSpec(d);
+
+               // Apply compressed transform encode using spec
+               MultiColumnEncoder encoder = //
+                       EncoderFactory.createEncoder(spec, null, 
X.getNumColumns(), null);
+               MatrixBlock binned = encoder.encode(X, k, true);
+
+               // Get metadata from transformencode
+               FrameBlock meta = new FrameBlock(X.getNumColumns(), 
Types.ValueType.STRING);
+               encoder.initMetaData(meta);
+               FrameBlock newMeta = encoder.getMetaData(meta, k);
+
+               // FIXME Optional: recompress, else can be removed (lines 
54-55) once fixed compression
+               if(X instanceof MatrixBlock) {
+                       Pair<MatrixBlock, CompressionStatistics> recompressed = 
CompressedMatrixBlockFactory.compress(binned, k);
+                       return new ImmutablePair<>(recompressed.getKey(), 
newMeta);
+               }
+               else
+                       return new ImmutablePair<>(binned, newMeta);
+       }
+
+       private static String createSpec(MatrixBlock d) {
+               d.sparseToDense();
+               double[] values = d.getDenseBlockValues();
+
+               String binning = binMethod.toString();
+
+               StringBuilder stringBuilder = new StringBuilder();
+               stringBuilder.append("{\"ids\":true,\"bin\":[");
+               for(int i = 0; i < values.length; i++) {
+                       
stringBuilder.append(String.format("{\"id\":%d,\"method\":\"%s\",\"numbins\":%d}",
 i + 1, binning, (int)values[i]));
+                       if(i + 1 < values.length)
+                               stringBuilder.append(',');
+               }
+               stringBuilder.append("]}");
+               return stringBuilder.toString();
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
index 22766079e6..b59e4d9db8 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
@@ -19,6 +19,9 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
+import java.util.ArrayList;
+import java.util.List;
+
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -26,6 +29,7 @@ import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
 import org.apache.sysds.runtime.compress.CompressionStatistics;
 import org.apache.sysds.runtime.compress.SingletonLookupHashMap;
+import org.apache.sysds.runtime.compress.lib.CLALibBinCompress;
 import org.apache.sysds.runtime.compress.workload.WTreeRoot;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.frame.data.FrameBlock;
@@ -39,19 +43,37 @@ public class CompressionCPInstruction extends 
ComputationCPInstruction {
 
        private final int _singletonLookupID;
 
+       /** This is only for binned compression with 2 outputs*/
+       protected final List<CPOperand> _outputs;
+
        private CompressionCPInstruction(Operator op, CPOperand in, CPOperand 
out, String opcode, String istr,
                int singletonLookupID) {
                super(CPType.Compression, op, in, null, null, out, opcode, 
istr);
+               _outputs = null;
+               this._singletonLookupID = singletonLookupID;
+       }
+
+       private CompressionCPInstruction(Operator op, CPOperand in1, CPOperand 
in2, List<CPOperand> out, String opcode, String istr,
+               int singletonLookupID) {
+               super(CPType.Compression, op, in1, in2, null, out.get(0), 
opcode, istr);
+               _outputs = out;
                this._singletonLookupID = singletonLookupID;
        }
 
        public static CompressionCPInstruction parseInstruction(String str) {
-               InstructionUtils.checkNumFields(str, 2, 3);
+               InstructionUtils.checkNumFields(str, 2, 3, 4);
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
                CPOperand in1 = new CPOperand(parts[1]);
                CPOperand out = new CPOperand(parts[2]);
-               if(parts.length == 4) {
+               if(parts.length == 5) {
+                       /** Compression with bins that returns two outputs*/
+                       List<CPOperand> outputs = new ArrayList<>();
+                       outputs.add(new CPOperand(parts[3]));
+                       outputs.add(new CPOperand(parts[4]));
+                       return new CompressionCPInstruction(null, in1, out, 
outputs, opcode, str, 0);
+               }
+               else if(parts.length == 4) {
                        int treeNodeID = Integer.parseInt(parts[3]);
                        return new CompressionCPInstruction(null, in1, out, 
opcode, str, treeNodeID);
                }
@@ -61,6 +83,36 @@ public class CompressionCPInstruction extends 
ComputationCPInstruction {
 
        @Override
        public void processInstruction(ExecutionContext ec) {
+               // Without a second input the plain compression path runs; otherwise the binned variant is executed.
+               if(input2 == null)
+                       processSimpleCompressInstruction(ec);
+               else
+                       processCompressByBinInstruction(ec);
+       }
+
+       /**
+        * Executes the two-output form of the instruction: the first input (matrix or frame) and the second
+        * input (a matrix) are handed to CLALibBinCompress.binCompress, and the resulting pair is published as
+        * a matrix output and a frame output.
+        *
+        * Acquired inputs are released in finally blocks so that a failure inside binCompress does not leave
+        * them acquired.
+        *
+        * @param ec execution context used to acquire the inputs and to set both outputs
+        */
+       private void processCompressByBinInstruction(ExecutionContext ec) {
+               final int k = OptimizerUtils.getConstrainedNumThreads(-1);
+
+               Pair<MatrixBlock, FrameBlock> out;
+
+               final MatrixBlock d = ec.getMatrixInput(input2.getName());
+               try {
+                       if(ec.isMatrixObject(input1.getName())) {
+                               final MatrixBlock X = ec.getMatrixInput(input1.getName());
+                               try {
+                                       out = CLALibBinCompress.binCompress(X, d, k);
+                               }
+                               finally {
+                                       ec.releaseMatrixInput(input1.getName());
+                               }
+                       }
+                       else {
+                               final FrameBlock X = ec.getFrameInput(input1.getName());
+                               try {
+                                       out = CLALibBinCompress.binCompress(X, d, k);
+                               }
+                               finally {
+                                       ec.releaseFrameInput(input1.getName());
+                               }
+                       }
+               }
+               finally {
+                       // released after the first input, matching the original release order
+                       ec.releaseMatrixInput(input2.getName());
+               }
+
+               // First output is the matrix of the pair, second output is the frame of the pair.
+               ec.setMatrixOutput(_outputs.get(0).getName(), out.getKey());
+               ec.setFrameOutput(_outputs.get(1).getName(), out.getValue());
+       }
+
+       private void processSimpleCompressInstruction(ExecutionContext ec) {
                // final MatrixBlock in = ec.getMatrixInput(input1.getName());
                final SingletonLookupHashMap m = 
SingletonLookupHashMap.getMap();
 
@@ -74,7 +126,6 @@ public class CompressionCPInstruction extends 
ComputationCPInstruction {
                        processMatrixBlockCompression(ec, 
ec.getMatrixInput(input1.getName()), k, root);
                else
                        processFrameBlockCompression(ec, 
ec.getFrameInput(input1.getName()), k, root);
-
        }
 
        private void processMatrixBlockCompression(ExecutionContext ec, 
MatrixBlock in, int k, WTreeRoot root) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
index 2df54bef20..848491c326 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
@@ -19,8 +19,6 @@
 
 package org.apache.sysds.runtime.transform.encode;
 
-import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
-
 import java.io.IOException;
 import java.io.ObjectInput;
 import java.io.ObjectOutput;
@@ -30,9 +28,11 @@ import java.util.PriorityQueue;
 import java.util.Random;
 import java.util.concurrent.Callable;
 
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
 import org.apache.commons.lang3.tuple.MutableTriple;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.frame.data.FrameBlock;
 import org.apache.sysds.runtime.frame.data.columns.Array;
@@ -495,7 +495,17 @@ public class ColumnEncoderBin extends ColumnEncoder {
        }
 
        public enum BinMethod {
-               INVALID, EQUI_WIDTH, EQUI_HEIGHT, EQUI_HEIGHT_APPROX
+               INVALID, EQUI_WIDTH, EQUI_HEIGHT, EQUI_HEIGHT_APPROX;
+
+               /**
+                * Returns the textual name of the binning method. EQUI_WIDTH and EQUI_HEIGHT are rendered with a
+                * hyphen, every other constant is rendered by its literal constant name.
+                *
+                * NOTE(review): EQUI_HEIGHT_APPROX keeps underscores while the other two use hyphens - confirm
+                * which spelling the consumers of this string expect.
+                *
+                * @return the textual name of this constant, never throws
+                */
+               @Override
+               public String toString(){
+                       switch(this) {
+                               case EQUI_WIDTH: return "EQUI-WIDTH";
+                               case EQUI_HEIGHT: return "EQUI-HEIGHT";
+                               case EQUI_HEIGHT_APPROX: return "EQUI_HEIGHT_APPROX";
+                               // toString() is invoked implicitly by string concatenation and logging,
+                               // so it must not throw for a valid constant such as INVALID.
+                               default: return name();
+                       }
+               }
        }
 
        private static class BinSparseApplyTask extends 
ColumnApplyTask<ColumnEncoderBin> {
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java 
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 1f48ef760e..354fa12feb 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -19,11 +19,6 @@
 
 package org.apache.sysds.test;
 
-import static java.lang.Math.ceil;
-import static java.lang.Thread.sleep;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
-
 import java.io.ByteArrayOutputStream;
 import java.io.File;
 import java.io.IOException;
@@ -43,6 +38,10 @@ import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
 
+import static java.lang.Math.ceil;
+import static java.lang.Thread.sleep;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.io.IOUtils;
 import org.apache.commons.lang3.ArrayUtils;
@@ -2238,6 +2237,55 @@ public abstract class AutomatedTestBase {
                return writeInputFrameWithMTD(name, data, bIncludeR, mc, 
schema, fmt);
        }
 
+       protected FrameBlock writeInputFrame(String name, FrameBlock data, 
boolean bIncludeR, ValueType[] schema,
+               FileFormat fmt) throws IOException {
+               String completePath = baseDirectory + INPUT_DIR + name;
+               String completeRPath = baseDirectory + INPUT_DIR + name + 
".csv";
+
+               try {
+                       cleanupExistingData(baseDirectory + INPUT_DIR + name, 
bIncludeR);
+               }
+               catch(IOException e) {
+                       e.printStackTrace();
+                       throw new RuntimeException(e);
+               }
+
+               TestUtils.writeTestFrame(completePath, data, schema, fmt);
+               if(bIncludeR) {
+                       TestUtils.writeTestFrame(completeRPath, data, schema, 
FileFormat.CSV, true);
+                       inputRFiles.add(completeRPath);
+               }
+               if(DEBUG)
+                       TestUtils.writeTestFrame(DEBUG_TEMP_DIR + completePath, 
data, schema, fmt);
+               inputDirectories.add(baseDirectory + INPUT_DIR + name);
+
+               return data;
+       }
+
+       protected FrameBlock writeInputFrameWithMTD(String name, FrameBlock 
data, boolean bIncludeR, ValueType[] schema,
+               FileFormat fmt) throws IOException {
+               MatrixCharacteristics mc = new 
MatrixCharacteristics(data.getNumRows(), data.getNumColumns(),
+                       OptimizerUtils.DEFAULT_BLOCKSIZE, -1);
+               return writeInputFrameWithMTD(name, data, bIncludeR, mc, 
schema, fmt);
+       }
+
+       protected FrameBlock writeInputFrameWithMTD(String name, FrameBlock 
data, boolean bIncludeR,
+               MatrixCharacteristics mc, ValueType[] schema, FileFormat fmt) 
throws IOException {
+               writeInputFrame(name, data, bIncludeR, schema, fmt);
+
+               // write metadata file
+               try {
+                       String completeMTDPath = baseDirectory + INPUT_DIR + 
name + ".mtd";
+                       HDFSTool.writeMetaDataFile(completeMTDPath, null, 
schema, DataType.FRAME, mc, fmt);
+               }
+               catch(IOException e) {
+                       e.printStackTrace();
+                       throw new RuntimeException(e);
+               }
+
+               return data;
+       }
+
        protected double[][] writeInputFrameWithMTD(String name, double[][] 
data, boolean bIncludeR,
                MatrixCharacteristics mc, ValueType[] schema, FileFormat fmt) 
throws IOException {
                writeInputFrame(name, data, bIncludeR, schema, fmt);
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java 
b/src/test/java/org/apache/sysds/test/TestUtils.java
index c90318b448..907c9adab8 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -19,12 +19,6 @@
 
 package org.apache.sysds.test;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
-
 import java.io.BufferedReader;
 import java.io.BufferedWriter;
 import java.io.DataOutputStream;
@@ -50,6 +44,11 @@ import java.util.Random;
 import java.util.Set;
 import java.util.StringTokenizer;
 
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.io.IOUtils;
 import org.apache.commons.lang3.NotImplementedException;
@@ -2740,6 +2739,11 @@ public class TestUtils {
                writer.writeFrameToHDFS(frame, file, data.length, 
schema.length);
        }
 
+       /**
+        * Writes a frame block to {@code file} using a frame writer created for the given format.
+        *
+        * NOTE(review): {@code isR} is not read in this method - confirm whether it is kept only for
+        * signature parity with the other overloads.
+        * NOTE(review): the column count passed to the writer is {@code schema.length}, not the column
+        * count of {@code data} - callers are assumed to pass a schema matching the frame.
+        *
+        * @param file   target path
+        * @param data   frame to write
+        * @param schema value types; only its length is used here
+        * @param fmt    file format selecting the writer
+        * @param isR    unused in this method
+        * @throws IOException if the writer fails
+        */
+       public static void writeTestFrame(String file, FrameBlock data, 
ValueType[] schema, FileFormat fmt, boolean isR) throws IOException {
+               FrameWriter writer = FrameWriterFactory.createFrameWriter(fmt);
+               writer.writeFrameToHDFS(data, file, data.getNumRows(), 
schema.length);
+       }
+
        /**
         * <p>
         * Writes a frame to a file using the text format.
@@ -2755,6 +2759,10 @@ public class TestUtils {
                writeTestFrame(file, data, schema, fmt, false);
        }
 
+       /**
+        * Writes a frame block to {@code file}; delegates to the five-argument overload with {@code isR}
+        * set to false.
+        *
+        * @throws IOException if the writer fails
+        */
+       public static void writeTestFrame(String file, FrameBlock data, 
ValueType[] schema, FileFormat fmt) throws IOException {
+               writeTestFrame(file, data, schema, fmt, false);
+       }
+
        public static void initFrameData(FrameBlock frame, double[][] data, 
ValueType[] lschema, int rows) {
                Object[] row1 = new Object[lschema.length];
                for( int i=0; i<rows; i++ ) {
diff --git 
a/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java
 
b/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java
new file mode 100644
index 0000000000..8265b261d2
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.compress.matrixByBin;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Random;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.columns.Array;
+import org.apache.sysds.runtime.frame.data.columns.ArrayFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.builtin.part1.BuiltinDistTest;
+import org.junit.Assert;
+import org.junit.Test;
+
+
+/**
+ * Tests the two-output compress builtin used by compressByBins.dml for matrix and frame inputs. After
+ * running the script, the metadata frame written by the script is read back and its bin boundaries are
+ * compared against the generated data.
+ */
+public class CompressByBinTest extends AutomatedTestBase {
+
+       private final static String TEST_NAME = "compressByBins";
+       private final static String TEST_DIR = "functions/compress/matrixByBin/";
+       // the class directory is derived from this test class so it does not collide with other tests
+       private static final String TEST_CLASS_DIR = TEST_DIR + CompressByBinTest.class.getSimpleName() + "/";
+
+       private final static int rows = 1000;
+
+       private final static int cols = 10;
+
+       private final static int nbins = 10;
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME,new String[]{"X"}));
+       }
+
+       @Test
+       public void testCompressBinsMatrixWidthCP() { runCompress(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_WIDTH); }
+
+       @Test
+       public void testCompressBinsMatrixHeightCP() { runCompress(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_HEIGHT); }
+
+       @Test
+       public void testCompressBinsFrameWidthCP() { runCompressFrame(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_WIDTH); }
+
+       @Test
+       public void testCompressBinsFrameHeightCP() { runCompressFrame(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_HEIGHT); }
+
+       /**
+        * Runs the script on a generated matrix input and checks the written metadata.
+        *
+        * @param instType  execution type handed to setExecMode
+        * @param binMethod selects the generated data and the metadata checks
+        */
+       private void runCompress(Types.ExecType instType, ColumnEncoderBin.BinMethod binMethod)
+       {
+               Types.ExecMode platformOld = setExecMode(instType);
+
+               try
+               {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{"-args", input("X"), Boolean.toString(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH),output("meta"), output("res")};
+
+                       double[][] X = generateMatrixData(binMethod);
+                       writeInputMatrixWithMTD("X", X, true);
+
+                       runTest(true, false, null, -1);
+
+                       checkMetaFile(DataConverter.convertToMatrixBlock(X), binMethod);
+
+               }
+               catch(IOException e) {
+                       throw new RuntimeException(e);
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
+       }
+
+       /**
+        * Runs the script on a generated FP32 frame input (written as CSV) and checks the written metadata.
+        *
+        * @param instType  execution type handed to setExecMode
+        * @param binMethod selects the generated data and the metadata checks
+        */
+       private void runCompressFrame(Types.ExecType instType, ColumnEncoderBin.BinMethod binMethod)
+       {
+               Types.ExecMode platformOld = setExecMode(instType);
+
+               try
+               {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{"-explain", "-args", input("X"), Boolean.toString(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH) , output("meta"), output("res")};
+
+                       Types.ValueType[] schema = new Types.ValueType[cols];
+                       Arrays.fill(schema, Types.ValueType.FP32);
+                       FrameBlock Xf = generateFrameData(binMethod, schema);
+                       writeInputFrameWithMTD("X", Xf, false, schema, Types.FileFormat.CSV);
+
+                       runTest(true, false, null, -1);
+
+                       checkMetaFile(Xf, binMethod);
+
+               }
+               catch(IOException e) {
+                       throw new RuntimeException(e);
+               }
+               finally {
+                       rtplatform = platformOld;
+               }
+       }
+
+       /**
+        * Generates the matrix input. For EQUI_WIDTH the data is random with the first two rows forced to
+        * -100 and 100 in every column; for EQUI_HEIGHT every column is filled with runs of nbins random values.
+        *
+        * @param binMethod EQUI_WIDTH or EQUI_HEIGHT
+        * @return a rows x cols array
+        * @throws RuntimeException for any other binning method
+        */
+       private double[][] generateMatrixData(ColumnEncoderBin.BinMethod binMethod) {
+               double[][] X;
+               if(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH) {
+                       //generate actual dataset
+                       X = getRandomMatrix(rows, cols, -100, 100, 1, 7);
+                       // make sure that bins in [-100, 100]
+                       for(int i = 0; i < cols; i++) {
+                               X[0][i] = -100;
+                               X[1][i] = 100;
+                       }
+               } else if(binMethod == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) {
+                       X = new double[rows][cols];
+                       for(int c = 0; c < cols; c++) {
+                               double[] vals = new Random().doubles(nbins).toArray();
+                               // Create one column: advance to the next value every rows / nbins rows
+                               for(int i = 0, j = 0; i < rows; i++) {
+                                       X[i][c] = vals[j];
+                                       if(i == (int) ((j + 1) * (rows / nbins)))
+                                               j++;
+                               }
+                       }
+               } else
+                       throw new RuntimeException("Invalid binning method.");
+
+               return X;
+       }
+
+       /**
+        * Generates the frame input. For EQUI_WIDTH the frame is random with the first and last row forced
+        * to -100 and 100 in every column; for EQUI_HEIGHT every column is filled with runs of nbins random values.
+        *
+        * @param binMethod EQUI_WIDTH or EQUI_HEIGHT
+        * @param schema    value types of the generated columns
+        * @return the generated frame with {@code rows} rows
+        * @throws RuntimeException for any other binning method
+        */
+       private FrameBlock generateFrameData(ColumnEncoderBin.BinMethod binMethod, Types.ValueType[] schema) {
+               FrameBlock Xf;
+               if(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH) {
+                       // use the shared row constant so the boundary row below always exists
+                       Xf = TestUtils.generateRandomFrameBlock(rows, schema, 7);
+
+                       for(int i = 0; i < cols; i++) {
+                               Xf.set(0, i, -100);
+                               Xf.set(rows-1, i, 100);
+                       }
+               } else if(binMethod == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) {
+                       Xf = new FrameBlock();
+                       for(int c = 0; c < schema.length; c++) {
+                               double[] vals = new Random().doubles(nbins).toArray();
+                               // Create one column
+                               @SuppressWarnings("unchecked") // allocated with FP32 directly below
+                               Array<Float> f = (Array<Float>) ArrayFactory.allocate(Types.ValueType.FP32, rows);
+                               for(int i = 0, j = 0; i < rows; i++) {
+                                       f.set(i, vals[j]);
+                                       if(i == (int) ((j + 1) * (rows / nbins)))
+                                               j++;
+                               }
+                               Xf.appendColumn(f);
+                       }
+
+               } else
+                       throw new RuntimeException("Invalid binning method.");
+
+               return Xf;
+       }
+
+       /**
+        * Reads the metadata frame written by the script and checks it against the input.
+        *
+        * For EQUI_WIDTH the start of bin j in every column must equal -100 + 20 * j. For EQUI_HEIGHT the
+        * start of the first bin and the end of the last bin of every column must equal the column minimum
+        * and maximum of the input.
+        *
+        * @param X           the input that was written for the script (matrix or frame)
+        * @param binningType binning method that determines which checks run
+        * @throws IOException if the metadata frame cannot be read
+        */
+       private void checkMetaFile(CacheBlock<?> X, ColumnEncoderBin.BinMethod binningType) throws IOException{
+               FrameBlock outputMeta = readDMLFrameFromHDFS("meta", Types.FileFormat.CSV);
+               Assert.assertEquals(nbins, outputMeta.getNumRows());
+
+               // one entry per column: these arrays are indexed by column and compared with per-column min/max
+               double[] binStarts = new double[cols];
+               double[] binEnds = new double[cols];
+
+               for(int c = 0; c < cols; c++) {
+                       if(binningType == ColumnEncoderBin.BinMethod.EQUI_WIDTH) {
+                               for(int i = -100, j = 0; i < 100; i += 20) {
+                                       // check bin starts
+                                       double binStart = Double.parseDouble(((String) outputMeta.getColumn(c).get(j)).split("·")[0]);
+                                       Assert.assertEquals(i, binStart, 0.0);
+                                       j++;
+                               }
+                       } else {
+                               binStarts[c] = Double.parseDouble(((String) outputMeta.getColumn(c).get(0)).split("·")[0]);
+                               binEnds[c] = Double.parseDouble(((String) outputMeta.getColumn(c).get(nbins-1)).split("·")[1]);
+                       }
+               }
+
+               if(binningType == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) {
+                       MatrixBlock mX = null;
+                       if(X instanceof FrameBlock) {
+                               mX = DataConverter.convertToMatrixBlock((FrameBlock) X);
+                       }
+                       else {
+                               mX = (MatrixBlock) X;
+                       }
+                       double[] colMins = mX.colMin().getDenseBlockValues();
+                       double[] colMaxs = mX.colMax().getDenseBlockValues();
+
+                       Assert.assertArrayEquals(colMins, binStarts, 0.0000000001);
+                       Assert.assertArrayEquals(colMaxs, binEnds, 0.0000000001);
+               }
+       }
+
+}
diff --git a/src/test/scripts/functions/compress/matrixByBin/compressByBins.dml 
b/src/test/scripts/functions/compress/matrixByBin/compressByBins.dml
new file mode 100644
index 0000000000..9f79738202
--- /dev/null
+++ b/src/test/scripts/functions/compress/matrixByBin/compressByBins.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reads the input from $1, calls the two-output compress builtin with 10 as the bin count for every
+# column, prints both results and writes the returned metadata to $3 as CSV.
+# NOTE(review): res is only printed and $2/$4 are not referenced in this script - confirm that is intended.
+X = read($1)
+
+# d = ceil(rand(rows = 1, cols = ncol(X), min = 1, max = 100))
+# d is a 1 x ncol(X) row vector holding the value 10 in every cell
+d = matrix(10, rows = 1, cols = ncol(X))
+
+[res, meta] = compress(X, d)
+print(toString(res))
+print(toString(meta))
+
+write(meta, $3, format="csv");
+
+


Reply via email to