This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 738dd42f80 [SYSTEMDS-3623] Lossy Compression with Binning
738dd42f80 is described below
commit 738dd42f80b53930b161a3dc0b6ff9b883797b59
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Mon Sep 25 01:02:59 2023 +0200
[SYSTEMDS-3623] Lossy Compression with Binning
This commit overloads the compression instruction so that it can
fuse transform-encode binning and compression of a matrix.
`[res, meta] = compress(X, d)`
X is an n x m uncompressed matrix/frame and d is a 1 x m vector
specifying the number of bins for each column. It returns a
compressed matrix and the transform metadata.
Closes #1919
---
.../java/org/apache/sysds/common/Builtins.java | 2 +-
src/main/java/org/apache/sysds/common/Types.java | 2 +-
.../java/org/apache/sysds/hops/OptimizerUtils.java | 12 +-
.../sysds/parser/BuiltinFunctionExpression.java | 31 ++-
.../org/apache/sysds/parser/DMLTranslator.java | 78 +++----
.../runtime/compress/lib/CLALibBinCompress.java | 74 +++++++
.../instructions/cp/CompressionCPInstruction.java | 57 +++++-
.../runtime/transform/encode/ColumnEncoderBin.java | 16 +-
.../org/apache/sysds/test/AutomatedTestBase.java | 58 +++++-
src/test/java/org/apache/sysds/test/TestUtils.java | 20 +-
.../compress/matrixByBin/CompressByBinTest.java | 225 +++++++++++++++++++++
.../compress/matrixByBin/compressByBins.dml | 33 +++
12 files changed, 547 insertions(+), 61 deletions(-)
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java
b/src/main/java/org/apache/sysds/common/Builtins.java
index 2243eeb963..9f7b3a0d0a 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -86,7 +86,7 @@ public enum Builtins {
COLSUM("colSums", false),
COLVAR("colVars", false),
COMPONENTS("components", true),
- COMPRESS("compress", false),
+ COMPRESS("compress", false, ReturnType.MULTI_RETURN),
CONFUSIONMATRIX("confusionMatrix", true),
CONV2D("conv2d", false),
CONV2D_BACKWARD_FILTER("conv2d_backward_filter", false),
diff --git a/src/main/java/org/apache/sysds/common/Types.java
b/src/main/java/org/apache/sysds/common/Types.java
index 7bf48f5da5..0a3748a9f8 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -400,7 +400,7 @@ public class Types
// Operations that require 2 operands
public enum OpOp2 {
AND(true), APPLY_SCHEMA(false), BITWAND(true), BITWOR(true),
BITWSHIFTL(true), BITWSHIFTR(true),
- BITWXOR(true), CBIND(false), CONCAT(false), COV(false),
DIV(true),
+ BITWXOR(true), CBIND(false), COMPRESS(true), CONCAT(false),
COV(false), DIV(true),
DROP_INVALID_TYPE(false), DROP_INVALID_LENGTH(false),
EQUAL(true),
FRAME_ROW_REPLICATE(true), GREATER(true), GREATEREQUAL(true),
INTDIV(true),
INTERQUANTILE(false), IQM(false), LESS(true),
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index a4abc513c6..dc2bc487ed 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -19,10 +19,15 @@
package org.apache.sysds.hops;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.FileFormat;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
@@ -38,7 +43,6 @@ import
org.apache.sysds.hops.fedplanner.FTypes.FederatedPlanner;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.Lop;
-import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -61,10 +65,6 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
-
public class OptimizerUtils
{
////////////////////////////////////////////////////////
@@ -269,7 +269,7 @@ public class OptimizerUtils
* This variable allows for insertion of Compress and decompress in the
dml script from the user.
* This is added because we want to have a way to test, and verify the
correct placement of compress and decompress commands.
*/
- public static boolean ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND = false;
+ public static boolean ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND = true;
/**
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index dd9bfe7892..8f1a496c1e 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -407,6 +407,31 @@ public class BuiltinFunctionExpression extends
DataIdentifier
svdOut3.setBlocksize(getFirstExpr().getOutput().getBlocksize());
break;
+
+ case COMPRESS:
+ if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND){
+ Expression expressionTwo = getSecondExpr();
+ checkNumParameters(getSecondExpr() != null ? 2
: 1);
+ checkMatrixFrameParam(getFirstExpr());
+ if(expressionTwo != null)
+ checkMatrixParam(getSecondExpr());
+
+ Identifier compressInput1 =
getFirstExpr().getOutput();
+ Identifier compressInput2 =
getSecondExpr().getOutput();
+
+ DataIdentifier compressOutput =
(DataIdentifier) getOutputs()[0];
+ compressOutput.setDataType(DataType.MATRIX);
+
compressOutput.setDimensions(compressInput1.getDim1(),
compressInput1.getDim2());
+ compressOutput.setBlocksize
(compressInput1.getBlocksize());
+
compressOutput.setValueType(compressInput1.getValueType());
+
+ DataIdentifier metaOutput = (DataIdentifier)
getOutputs()[1];
+ metaOutput.setDataType(DataType.FRAME);
+
metaOutput.setDimensions(compressInput1.getDim1(), -1);
+ }
+ else
+ raiseValidateError("Compress/DeCompress
instruction not allowed in dml script");
+ break;
default: //always unconditional
raiseValidateError("Unknown Builtin Function opcode: "
+ _opcode, false);
@@ -1605,8 +1630,10 @@ public class BuiltinFunctionExpression extends
DataIdentifier
case COMPRESS:
case DECOMPRESS:
if(OptimizerUtils.ALLOW_SCRIPT_LEVEL_COMPRESS_COMMAND){
- checkNumParameters(1);
+ Expression expressionTwo = getSecondExpr();;
checkMatrixParam(getFirstExpr());
+ if(expressionTwo != null)
+ checkMatrixParam(getSecondExpr());
output.setDataType(DataType.MATRIX);
output.setDimensions(id.getDim1(),
id.getDim2());
output.setBlocksize (id.getBlocksize());
@@ -1702,7 +1729,7 @@ public class BuiltinFunctionExpression extends
DataIdentifier
output.setDimensions(Math.max(dims1.getRows(),
dims2.getRows()), Math.max(dims1.getCols(), dims2.getCols()));
output.setBlocksize(Math.max(dims1.getBlocksize(),
dims2.getBlocksize()));
}
-
+
private void setNaryOutputProperties(DataIdentifier output) {
DataType dt = Arrays.stream(getAllExpr()).allMatch(
e -> e.getOutput().getDataType().isScalar()) ?
DataType.SCALAR : DataType.MATRIX;
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index ba5877d81b..a8ca42d54a 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -1045,7 +1045,7 @@ public class DMLTranslator
throw new
LanguageException(source.printErrorLocation()+": Unsupported indexing
expression in write statement. " +
"Please, assign the right indexing result to a variable and write this
variable.");
}
-
+
DataOp ae = (DataOp)processExpression(source,
target, ids);
Expression fmtExpr =
os.getExprParam(DataExpression.FORMAT_TYPE);
ae.setFileFormat((fmtExpr instanceof
StringIdentifier) ?
@@ -1077,7 +1077,7 @@ public class DMLTranslator
throw new
LanguageException("Unrecognized file format: " + ae.getFileFormat());
}
}
-
+
output.add(ae);
}
@@ -1138,15 +1138,15 @@ public class DMLTranslator
}
if (current instanceof AssignmentStatement) {
-
+
AssignmentStatement as = (AssignmentStatement)
current;
DataIdentifier target = as.getTarget();
Expression source = as.getSource();
-
- // CASE: regular assignment statement -- source
is DML expression that is NOT user-defined or external function
+
+ // CASE: regular assignment statement -- source
is DML expression that is NOT user-defined or external function
if (!(source instanceof
FunctionCallIdentifier)){
-
+
// CASE: target is regular data
identifier
if (!(target instanceof
IndexedIdentifier)) {
//process right hand side and
accumulation
@@ -1166,7 +1166,7 @@ public class DMLTranslator
}
ids.put(target.getName(), ae);
-
+
//add transient write if needed
Integer statementId =
liveOutToTemp.get(target.getName());
if ((statementId != null) &&
(statementId.intValue() == i)) {
@@ -1176,11 +1176,11 @@ public class DMLTranslator
updatedLiveOut.addVariable(target.getName(), target);
output.add(transientwrite);
}
- }
+ }
// CASE: target is indexed identifier
(left-hand side indexed expression)
else {
Hop ae =
processLeftIndexedExpression(source, (IndexedIdentifier)target, ids);
-
+
if( as.isAccumulator() ) {
DataIdentifier accum =
getAccumulatorData(liveIn, target.getName());
Hop rix =
processIndexingExpression((IndexedIdentifier)target, null, ids);
@@ -1189,16 +1189,16 @@ public class DMLTranslator
HopRewriteUtils.replaceChildReference(ae, ae.getInput(1), binary);
target.setProperties(accum.getOutput());
}
-
+
ids.put(target.getName(), ae);
-
+
// obtain origDim values BEFORE
they are potentially updated during setProperties call
// (this is incorrect for
LHS Indexing)
long origDim1 =
((IndexedIdentifier)target).getOrigDim1();
long origDim2 =
((IndexedIdentifier)target).getOrigDim2();
target.setProperties(source.getOutput());
((IndexedIdentifier)target).setOriginalDimensions(origDim1, origDim2);
-
+
// preserve data type matrix of
any index identifier
// (required for scalar input
to left indexing)
if( target.getDataType() !=
DataType.MATRIX ) {
@@ -1206,7 +1206,7 @@ public class DMLTranslator
target.setValueType(ValueType.FP64);
target.setBlocksize(ConfigurationManager.getBlocksize());
}
-
+
Integer statementId =
liveOutToTemp.get(target.getName());
if ((statementId != null) &&
(statementId.intValue() == i)) {
DataOp transientwrite =
new DataOp(target.getName(), target.getDataType(), target.getValueType(), ae,
OpOpData.TRANSIENTWRITE, null);
@@ -1222,36 +1222,36 @@ public class DMLTranslator
//assignment, function call
FunctionCallIdentifier fci =
(FunctionCallIdentifier) source;
FunctionStatementBlock fsb =
this._dmlProg.getFunctionStatementBlock(fci.getNamespace(),fci.getName());
-
+
//error handling missing function
- if (fsb == null) {
- throw new
LanguageException(source.printErrorLocation() + "function "
+ if (fsb == null) {
+ throw new
LanguageException(source.printErrorLocation() + "function "
+ fci.getName() + " is
undefined in namespace " + fci.getNamespace());
}
-
+
FunctionStatement fstmt =
(FunctionStatement)fsb.getStatement(0);
String fkey =
DMLProgram.constructFunctionKey(fci.getNamespace(),fci.getName());
-
+
//error handling unsupported function
call in indexing expression
if( target instanceof IndexedIdentifier
) {
throw new
LanguageException("Unsupported function call to '"+fkey+"' in left indexing "
+ "expression. Please,
assign the function output to a variable.");
}
-
+
//prepare function input names and
inputs
List<String> inputNames = new
ArrayList<>(fci.getParamExprs().stream()
.map(e ->
e.getName()).collect(Collectors.toList()));
List<Hop> finputs = new
ArrayList<>(fci.getParamExprs().stream()
.map(e ->
processExpression(e.getExpr(), null, ids)).collect(Collectors.toList()));
-
+
//append default expression for missing
arguments
appendDefaultArguments(fstmt,
inputNames, finputs, ids);
-
+
//use function signature to obtain
names for unnamed args
//(note: consistent parameters already
checked for functions in general)
if( inputNames.stream().allMatch(n ->
n==null) )
inputNames =
fstmt._inputParams.stream().map(d -> d.getName()).collect(Collectors.toList());
-
+
//create function op
String[] inputNames2 =
inputNames.toArray(new String[0]);
FunctionType ftype =
fsb.getFunctionOpType();
@@ -1267,31 +1267,31 @@ public class DMLTranslator
//multi-assignment, by definition a function
call
MultiAssignmentStatement mas =
(MultiAssignmentStatement) current;
Expression source = mas.getSource();
-
+
if ( source instanceof FunctionCallIdentifier )
{
FunctionCallIdentifier fci =
(FunctionCallIdentifier) source;
FunctionStatementBlock fsb =
this._dmlProg.getFunctionStatementBlock(fci.getNamespace(),fci.getName());
if (fsb == null){
- throw new
LanguageException(source.printErrorLocation() + "function "
+ throw new
LanguageException(source.printErrorLocation() + "function "
+ fci.getName() + " is
undefined in namespace " + fci.getNamespace());
}
-
+
FunctionStatement fstmt =
(FunctionStatement)fsb.getStatement(0);
-
+
//prepare function input names and
inputs
List<String> inputNames = new
ArrayList<>(fci.getParamExprs().stream()
.map(e ->
e.getName()).collect(Collectors.toList()));
List<Hop> finputs = new
ArrayList<>(fci.getParamExprs().stream()
.map(e ->
processExpression(e.getExpr(), null, ids)).collect(Collectors.toList()));
-
+
//use function signature to obtain
names for unnamed args
//(note: consistent parameters already
checked for functions in general)
if( inputNames.stream().allMatch(n ->
n==null) )
inputNames =
fstmt._inputParams.stream().map(d -> d.getName()).collect(Collectors.toList());
-
+
//append default expression for missing
arguments
appendDefaultArguments(fstmt,
inputNames, finputs, ids);
-
+
//create function op
String[] foutputs =
mas.getTargetList().stream()
.map(d ->
d.getName()).toArray(String[]::new);
@@ -1314,7 +1314,7 @@ public class DMLTranslator
else
throw new LanguageException("Class \""
+ source.getClass() + "\" is not supported in Multiple Assignment statements");
}
-
+
}
sb.updateLiveVariablesOut(updatedLiveOut);
sb.setHops(output);
@@ -2265,11 +2265,21 @@ public class DMLTranslator
}
// Create the hop for current function call
- FunctionOp fcall = new FunctionOp(ftype,
nameSpace, source.getOpCode().toString(), null, inputs, outputNames, outputs);
- currBuiltinOp = fcall;
-
+ currBuiltinOp = new FunctionOp(ftype,
nameSpace, source.getOpCode().toString(), null, inputs, outputNames, outputs);
break;
-
+
+ case COMPRESS:
+ // Number of outputs = size of targetList = #of
identifiers in source.getOutputs
+ String[] outputNamesCompress = new
String[targetList.size()];
+ outputNamesCompress[0] =
targetList.get(0).getName();
+ outputNamesCompress[1] =
targetList.get(1).getName();
+ outputs.add(new DataOp(outputNamesCompress[0],
DataType.MATRIX, ValueType.FP64, inputs.get(0), OpOpData.FUNCTIONOUTPUT,
inputs.get(0).getFilename()));
+ outputs.add(new DataOp(outputNamesCompress[1],
DataType.FRAME, ValueType.STRING, inputs.get(0), OpOpData.FUNCTIONOUTPUT,
inputs.get(0).getFilename()));
+
+ // Create the hop for current function call
+ currBuiltinOp = new FunctionOp(ftype,
nameSpace, source.getOpCode().toString(), null, inputs, outputNamesCompress,
outputs);
+ break;
+
default:
throw new ParseException("Invaid Opcode in
DMLTranslator:processMultipleReturnBuiltinFunctionExpression(): " +
source.getOpCode());
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinCompress.java
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinCompress.java
new file mode 100644
index 0000000000..8ef49332a0
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibBinCompress.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.runtime.compress.lib;
+
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.Pair;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
+import org.apache.sysds.runtime.compress.CompressionStatistics;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+
+public class CLALibBinCompress {
+ public static ColumnEncoderBin.BinMethod binMethod =
ColumnEncoderBin.BinMethod.EQUI_WIDTH;
+ public static Pair<MatrixBlock, FrameBlock> binCompress(CacheBlock<?>
X, MatrixBlock d, int k){
+ // Create transform spec acc to binMethod
+ String spec = createSpec(d);
+
+ // Apply compressed transform encode using spec
+ MultiColumnEncoder encoder = //
+ EncoderFactory.createEncoder(spec, null,
X.getNumColumns(), null);
+ MatrixBlock binned = encoder.encode(X, k, true);
+
+ // Get metadata from transformencode
+ FrameBlock meta = new FrameBlock(X.getNumColumns(),
Types.ValueType.STRING);
+ encoder.initMetaData(meta);
+ FrameBlock newMeta = encoder.getMetaData(meta, k);
+
+ // FIXME Optional: recompress, else can be removed (lines
54-55) once fixed compression
+ if(X instanceof MatrixBlock) {
+ Pair<MatrixBlock, CompressionStatistics> recompressed =
CompressedMatrixBlockFactory.compress(binned, k);
+ return new ImmutablePair<>(recompressed.getKey(),
newMeta);
+ }
+ else
+ return new ImmutablePair<>(binned, newMeta);
+ }
+
+ private static String createSpec(MatrixBlock d) {
+ d.sparseToDense();
+ double[] values = d.getDenseBlockValues();
+
+ String binning = binMethod.toString();
+
+ StringBuilder stringBuilder = new StringBuilder();
+ stringBuilder.append("{\"ids\":true,\"bin\":[");
+ for(int i = 0; i < values.length; i++) {
+
stringBuilder.append(String.format("{\"id\":%d,\"method\":\"%s\",\"numbins\":%d}",
i + 1, binning, (int)values[i]));
+ if(i + 1 < values.length)
+ stringBuilder.append(',');
+ }
+ stringBuilder.append("]}");
+ return stringBuilder.toString();
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
index 22766079e6..b59e4d9db8 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/CompressionCPInstruction.java
@@ -19,6 +19,9 @@
package org.apache.sysds.runtime.instructions.cp;
+import java.util.ArrayList;
+import java.util.List;
+
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -26,6 +29,7 @@ import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory;
import org.apache.sysds.runtime.compress.CompressionStatistics;
import org.apache.sysds.runtime.compress.SingletonLookupHashMap;
+import org.apache.sysds.runtime.compress.lib.CLALibBinCompress;
import org.apache.sysds.runtime.compress.workload.WTreeRoot;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.frame.data.FrameBlock;
@@ -39,19 +43,37 @@ public class CompressionCPInstruction extends
ComputationCPInstruction {
private final int _singletonLookupID;
+ /** This is only for binned compression with 2 outputs*/
+ protected final List<CPOperand> _outputs;
+
private CompressionCPInstruction(Operator op, CPOperand in, CPOperand
out, String opcode, String istr,
int singletonLookupID) {
super(CPType.Compression, op, in, null, null, out, opcode,
istr);
+ _outputs = null;
+ this._singletonLookupID = singletonLookupID;
+ }
+
+ private CompressionCPInstruction(Operator op, CPOperand in1, CPOperand
in2, List<CPOperand> out, String opcode, String istr,
+ int singletonLookupID) {
+ super(CPType.Compression, op, in1, in2, null, out.get(0),
opcode, istr);
+ _outputs = out;
this._singletonLookupID = singletonLookupID;
}
public static CompressionCPInstruction parseInstruction(String str) {
- InstructionUtils.checkNumFields(str, 2, 3);
+ InstructionUtils.checkNumFields(str, 2, 3, 4);
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
CPOperand in1 = new CPOperand(parts[1]);
CPOperand out = new CPOperand(parts[2]);
- if(parts.length == 4) {
+ if(parts.length == 5) {
+ /** Compression with bins that returns two outputs*/
+ List<CPOperand> outputs = new ArrayList<>();
+ outputs.add(new CPOperand(parts[3]));
+ outputs.add(new CPOperand(parts[4]));
+ return new CompressionCPInstruction(null, in1, out,
outputs, opcode, str, 0);
+ }
+ else if(parts.length == 4) {
int treeNodeID = Integer.parseInt(parts[3]);
return new CompressionCPInstruction(null, in1, out,
opcode, str, treeNodeID);
}
@@ -61,6 +83,36 @@ public class CompressionCPInstruction extends
ComputationCPInstruction {
@Override
public void processInstruction(ExecutionContext ec) {
+ if(input2 == null)
+ processSimpleCompressInstruction(ec);
+ else
+ processCompressByBinInstruction(ec);
+ }
+
+ private void processCompressByBinInstruction(ExecutionContext ec) {
+ final MatrixBlock d = ec.getMatrixInput(input2.getName());
+
+ final int k = OptimizerUtils.getConstrainedNumThreads(-1);
+
+ Pair<MatrixBlock, FrameBlock> out;
+
+ if(ec.isMatrixObject(input1.getName())) {
+ final MatrixBlock X =
ec.getMatrixInput(input1.getName());
+ out = CLALibBinCompress.binCompress(X, d, k);
+ ec.releaseMatrixInput(input1.getName());
+ } else {
+ final FrameBlock X = ec.getFrameInput(input1.getName());
+ out = CLALibBinCompress.binCompress(X, d, k);
+ ec.releaseFrameInput(input1.getName());
+ }
+
+ // Set output and release input
+ ec.releaseMatrixInput(input2.getName());
+ ec.setMatrixOutput(_outputs.get(0).getName(), out.getKey());
+ ec.setFrameOutput(_outputs.get(1).getName(), out.getValue());
+ }
+
+ private void processSimpleCompressInstruction(ExecutionContext ec) {
// final MatrixBlock in = ec.getMatrixInput(input1.getName());
final SingletonLookupHashMap m =
SingletonLookupHashMap.getMap();
@@ -74,7 +126,6 @@ public class CompressionCPInstruction extends
ComputationCPInstruction {
processMatrixBlockCompression(ec,
ec.getMatrixInput(input1.getName()), k, root);
else
processFrameBlockCompression(ec,
ec.getFrameInput(input1.getName()), k, root);
-
}
private void processMatrixBlockCompression(ExecutionContext ec,
MatrixBlock in, int k, WTreeRoot root) {
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
index 2df54bef20..848491c326 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
@@ -19,8 +19,6 @@
package org.apache.sysds.runtime.transform.encode;
-import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
-
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
@@ -30,9 +28,11 @@ import java.util.PriorityQueue;
import java.util.Random;
import java.util.concurrent.Callable;
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
import org.apache.commons.lang3.tuple.MutableTriple;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.frame.data.columns.Array;
@@ -495,7 +495,17 @@ public class ColumnEncoderBin extends ColumnEncoder {
}
public enum BinMethod {
- INVALID, EQUI_WIDTH, EQUI_HEIGHT, EQUI_HEIGHT_APPROX
+ INVALID, EQUI_WIDTH, EQUI_HEIGHT, EQUI_HEIGHT_APPROX;
+
+ @Override
+ public String toString(){
+ switch(this) {
+ case EQUI_WIDTH: return "EQUI-WIDTH";
+ case EQUI_HEIGHT: return "EQUI-HEIGHT";
+ case EQUI_HEIGHT_APPROX: return
"EQUI_HEIGHT_APPROX";
+ default: throw new DMLRuntimeException("Invalid
encoder type.");
+ }
+ }
}
private static class BinSparseApplyTask extends
ColumnApplyTask<ColumnEncoderBin> {
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 1f48ef760e..354fa12feb 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -19,11 +19,6 @@
package org.apache.sysds.test;
-import static java.lang.Math.ceil;
-import static java.lang.Thread.sleep;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
-
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
@@ -43,6 +38,10 @@ import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
+import static java.lang.Math.ceil;
+import static java.lang.Thread.sleep;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
@@ -2238,6 +2237,55 @@ public abstract class AutomatedTestBase {
return writeInputFrameWithMTD(name, data, bIncludeR, mc,
schema, fmt);
}
+ protected FrameBlock writeInputFrame(String name, FrameBlock data,
boolean bIncludeR, ValueType[] schema,
+ FileFormat fmt) throws IOException {
+ String completePath = baseDirectory + INPUT_DIR + name;
+ String completeRPath = baseDirectory + INPUT_DIR + name +
".csv";
+
+ try {
+ cleanupExistingData(baseDirectory + INPUT_DIR + name,
bIncludeR);
+ }
+ catch(IOException e) {
+ e.printStackTrace();
+ throw new RuntimeException(e);
+ }
+
+ TestUtils.writeTestFrame(completePath, data, schema, fmt);
+ if(bIncludeR) {
+ TestUtils.writeTestFrame(completeRPath, data, schema,
FileFormat.CSV, true);
+ inputRFiles.add(completeRPath);
+ }
+ if(DEBUG)
+ TestUtils.writeTestFrame(DEBUG_TEMP_DIR + completePath,
data, schema, fmt);
+ inputDirectories.add(baseDirectory + INPUT_DIR + name);
+
+ return data;
+ }
+
+ protected FrameBlock writeInputFrameWithMTD(String name, FrameBlock
data, boolean bIncludeR, ValueType[] schema,
+ FileFormat fmt) throws IOException {
+ MatrixCharacteristics mc = new
MatrixCharacteristics(data.getNumRows(), data.getNumColumns(),
+ OptimizerUtils.DEFAULT_BLOCKSIZE, -1);
+ return writeInputFrameWithMTD(name, data, bIncludeR, mc,
schema, fmt);
+ }
+
+ protected FrameBlock writeInputFrameWithMTD(String name, FrameBlock
data, boolean bIncludeR,
+ MatrixCharacteristics mc, ValueType[] schema, FileFormat fmt)
throws IOException {
+ writeInputFrame(name, data, bIncludeR, schema, fmt);
+
+ // write metadata file
+ try {
+ String completeMTDPath = baseDirectory + INPUT_DIR +
name + ".mtd";
+ HDFSTool.writeMetaDataFile(completeMTDPath, null,
schema, DataType.FRAME, mc, fmt);
+ }
+ catch(IOException e) {
+ e.printStackTrace();
+ throw new RuntimeException(e);
+ }
+
+ return data;
+ }
+
protected double[][] writeInputFrameWithMTD(String name, double[][]
data, boolean bIncludeR,
MatrixCharacteristics mc, ValueType[] schema, FileFormat fmt)
throws IOException {
writeInputFrame(name, data, bIncludeR, schema, fmt);
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java
b/src/test/java/org/apache/sysds/test/TestUtils.java
index c90318b448..907c9adab8 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -19,12 +19,6 @@
package org.apache.sysds.test;
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assert.fail;
-
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.DataOutputStream;
@@ -50,6 +44,11 @@ import java.util.Random;
import java.util.Set;
import java.util.StringTokenizer;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.NotImplementedException;
@@ -2740,6 +2739,11 @@ public class TestUtils {
writer.writeFrameToHDFS(frame, file, data.length,
schema.length);
}
+ public static void writeTestFrame(String file, FrameBlock data,
ValueType[] schema, FileFormat fmt, boolean isR) throws IOException {
+ FrameWriter writer = FrameWriterFactory.createFrameWriter(fmt);
+ writer.writeFrameToHDFS(data, file, data.getNumRows(),
schema.length);
+ }
+
/**
* <p>
* Writes a frame to a file using the text format.
@@ -2755,6 +2759,10 @@ public class TestUtils {
writeTestFrame(file, data, schema, fmt, false);
}
+ public static void writeTestFrame(String file, FrameBlock data,
ValueType[] schema, FileFormat fmt) throws IOException {
+ writeTestFrame(file, data, schema, fmt, false);
+ }
+
public static void initFrameData(FrameBlock frame, double[][] data,
ValueType[] lschema, int rows) {
Object[] row1 = new Object[lschema.length];
for( int i=0; i<rows; i++ ) {
diff --git
a/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java
b/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java
new file mode 100644
index 0000000000..8265b261d2
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/functions/compress/matrixByBin/CompressByBinTest.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.compress.matrixByBin;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Random;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.frame.data.columns.Array;
+import org.apache.sysds.runtime.frame.data.columns.ArrayFactory;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.builtin.part1.BuiltinDistTest;
+import org.junit.Assert;
+import org.junit.Test;
+
+
+public class CompressByBinTest extends AutomatedTestBase {
+
+
+	/** Base name of the DML script and key of the registered test configuration. */
+	private final static String TEST_NAME = "compressByBins";
+	/** Directory of the DML script, appended to the script root when the script path is built. */
+	private final static String TEST_DIR = "functions/compress/matrixByBin/";
+	/** Per-class directory handed to the test configuration; derived from this class's own name. */
+	private static final String TEST_CLASS_DIR = TEST_DIR + CompressByBinTest.class.getSimpleName() + "/";
+
+	/** Number of rows of the generated input. */
+	private final static int rows = 1000;
+
+	/** Number of columns of the generated input. */
+	private final static int cols = 10;
+
+	/** Number of bins expected per column in the returned meta frame. */
+	private final static int nbins = 10;
+
+	/** Allocated with one slot per column; not referenced by any method of this class. */
+	private final static int[] dVector = new int[cols];
+
+	/** Registers the test configuration under {@code TEST_NAME}, with "X" as its only listed file name. */
+	@Override
+	public void setUp() {
+		final TestConfiguration config = new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[] {"X"});
+		addTestConfiguration(TEST_NAME, config);
+	}
+
+ @Test
+ public void testCompressBinsMatrixWidthCP() {
runCompress(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_WIDTH); }
+
+ @Test
+ public void testCompressBinsMatrixHeightCP() {
runCompress(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_HEIGHT); }
+
+ @Test
+ public void testCompressBinsFrameWidthCP() {
runCompressFrame(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_WIDTH); }
+
+ @Test
+ public void testCompressBinsFrameHeightCP() {
runCompressFrame(Types.ExecType.CP, ColumnEncoderBin.BinMethod.EQUI_HEIGHT); }
+
+ private void runCompress(Types.ExecType instType,
ColumnEncoderBin.BinMethod binMethod)
+ {
+ Types.ExecMode platformOld = setExecMode(instType);
+
+ try
+ {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{"-args", input("X"),
Boolean.toString(binMethod ==
ColumnEncoderBin.BinMethod.EQUI_WIDTH),output("meta"), output("res")};
+
+ double[][] X = generateMatrixData(binMethod);
+ writeInputMatrixWithMTD("X", X, true);
+
+ runTest(true, false, null, -1);
+
+ checkMetaFile(DataConverter.convertToMatrixBlock(X),
binMethod);
+
+ }
+ catch(IOException e) {
+ throw new RuntimeException(e);
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+
+ private void runCompressFrame(Types.ExecType instType,
ColumnEncoderBin.BinMethod binMethod)
+ {
+ Types.ExecMode platformOld = setExecMode(instType);
+
+ try
+ {
+ loadTestConfiguration(getTestConfiguration(TEST_NAME));
+
+ String HOME = SCRIPT_DIR + TEST_DIR;
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
+ programArgs = new String[]{"-explain", "-args",
input("X"), Boolean.toString(binMethod ==
ColumnEncoderBin.BinMethod.EQUI_WIDTH) , output("meta"), output("res")};
+
+ Types.ValueType[] schema = new Types.ValueType[cols];
+ Arrays.fill(schema, Types.ValueType.FP32);
+ FrameBlock Xf = generateFrameData(binMethod, schema);
+ writeInputFrameWithMTD("X", Xf, false, schema,
Types.FileFormat.CSV);
+
+ runTest(true, false, null, -1);
+
+ checkMetaFile(Xf, binMethod);
+
+ }
+ catch(IOException e) {
+ throw new RuntimeException(e);
+ }
+ finally {
+ rtplatform = platformOld;
+ }
+ }
+
+ private double[][] generateMatrixData(ColumnEncoderBin.BinMethod
binMethod) {
+ double[][] X;
+ if(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH) {
+ //generate actual dataset
+ X = getRandomMatrix(rows, cols, -100, 100, 1, 7);
+ // make sure that bins in [-100, 100]
+ for(int i = 0; i < cols; i++) {
+ X[0][i] = -100;
+ X[1][i] = 100;
+ }
+ } else if(binMethod == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) {
+ X = new double[rows][cols];
+ for(int c = 0; c < cols; c++) {
+ double[] vals = new
Random().doubles(nbins).toArray();
+ // Create one column
+ for(int i = 0, j = 0; i < rows; i++) {
+ X[i][c] = vals[j];
+ if(i == (int) ((j + 1) * (rows /
nbins)))
+ j++;
+ }
+ }
+ } else
+ throw new RuntimeException("Invalid binning method.");
+
+ return X;
+ }
+
+	/**
+	 * Builds the input frame for the given binning method.
+	 *
+	 * For EQUI_WIDTH a random frame with {@code rows} rows is generated and the first and last
+	 * row of every column are overwritten with -100 and 100. For EQUI_HEIGHT each column is an
+	 * FP32 array filled with runs of {@code nbins} freshly drawn random values.
+	 *
+	 * @param binMethod binning method to generate data for
+	 * @param schema    schema of the frame to generate
+	 * @return the generated frame
+	 * @throws RuntimeException if the method is neither EQUI_WIDTH nor EQUI_HEIGHT
+	 */
+	private FrameBlock generateFrameData(ColumnEncoderBin.BinMethod binMethod, Types.ValueType[] schema) {
+		FrameBlock Xf;
+		if(binMethod == ColumnEncoderBin.BinMethod.EQUI_WIDTH) {
+			// Use the shared row constant so the frame size always matches the rows-1 index below.
+			Xf = TestUtils.generateRandomFrameBlock(rows, schema, 7);
+
+			for(int i = 0; i < cols; i++) {
+				Xf.set(0, i, -100);
+				Xf.set(rows - 1, i, 100);
+			}
+		}
+		else if(binMethod == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) {
+			Xf = new FrameBlock();
+			for(int c = 0; c < schema.length; c++) {
+				double[] vals = new Random().doubles(nbins).toArray();
+				// Create one column; the array is allocated as FP32, which is what the cast relies on.
+				@SuppressWarnings("unchecked")
+				Array<Float> f = (Array<Float>) ArrayFactory.allocate(Types.ValueType.FP32, rows);
+				for(int i = 0, j = 0; i < rows; i++) {
+					f.set(i, vals[j]);
+					// Move to the next value once the row at the run boundary has been written.
+					if(i == (int) ((j + 1) * (rows / nbins)))
+						j++;
+				}
+				Xf.appendColumn(f);
+			}
+		}
+		else
+			throw new RuntimeException("Invalid binning method.");
+
+		return Xf;
+	}
+
+	/**
+	 * Reads the "meta" frame written by the script and checks it against the input.
+	 *
+	 * The frame must have {@code nbins} rows. For EQUI_WIDTH the start of every bin in every
+	 * column must equal -100, -80, ..., 80. For EQUI_HEIGHT the start of the first bin and the
+	 * end of the last bin of each column must equal the column minimum and maximum of the input.
+	 *
+	 * @param X           input that was passed to the script, as a frame block or matrix block
+	 * @param binningType binning method the input was generated for
+	 * @throws IOException declared for failures while reading the meta frame
+	 */
+	private void checkMetaFile(CacheBlock<?> X, ColumnEncoderBin.BinMethod binningType) throws IOException {
+		FrameBlock outputMeta = readDMLFrameFromHDFS("meta", Types.FileFormat.CSV);
+		Assert.assertEquals(nbins, outputMeta.getNumRows());
+
+		// These arrays are indexed by column, so they are sized by cols rather than nbins.
+		double[] binStarts = new double[cols];
+		double[] binEnds = new double[cols];
+
+		for(int c = 0; c < cols; c++) {
+			if(binningType == ColumnEncoderBin.BinMethod.EQUI_WIDTH) {
+				for(int i = -100, j = 0; i < 100; i += 20) {
+					// check bin starts
+					Assert.assertEquals(i, readBinBound(outputMeta, c, j, 0), 0.0);
+					j++;
+				}
+			}
+			else {
+				binStarts[c] = readBinBound(outputMeta, c, 0, 0);
+				binEnds[c] = readBinBound(outputMeta, c, nbins - 1, 1);
+			}
+		}
+
+		if(binningType == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) {
+			MatrixBlock mX = null;
+			if(X instanceof FrameBlock) {
+				mX = DataConverter.convertToMatrixBlock((FrameBlock) X);
+			}
+			else {
+				mX = (MatrixBlock) X;
+			}
+			double[] colMins = mX.colMin().getDenseBlockValues();
+			double[] colMaxs = mX.colMax().getDenseBlockValues();
+
+			Assert.assertArrayEquals(colMins, binStarts, 0.0000000001);
+			Assert.assertArrayEquals(colMaxs, binEnds, 0.0000000001);
+		}
+	}
+
+	/** Parses one boundary (part 0 = start, part 1 = end) of the bin stored at (row, col) of the meta frame. */
+	private static double readBinBound(FrameBlock meta, int col, int row, int part) {
+		return Double.parseDouble(((String) meta.getColumn(col).get(row)).split("·")[part]);
+	}
+
+}
diff --git a/src/test/scripts/functions/compress/matrixByBin/compressByBins.dml
b/src/test/scripts/functions/compress/matrixByBin/compressByBins.dml
new file mode 100644
index 0000000000..9f79738202
--- /dev/null
+++ b/src/test/scripts/functions/compress/matrixByBin/compressByBins.dml
@@ -0,0 +1,33 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = read($1) # $1: path of the input to bin and compress
+
+# d = ceil(rand(rows = 1, cols = ncol(X), min = 1, max = 100))
+d = matrix(10, rows = 1, cols = ncol(X)) # one entry per column of X, each set to 10
+
+[res, meta] = compress(X, d) # two outputs: res and meta
+print(toString(res))
+print(toString(meta))
+
+write(meta, $3, format="csv"); # only meta is written; $2 and $4 are never referenced in this script
+
+