This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 4b46b1f [SYSTEMDS-2810] Improved serialization of transform encoders
4b46b1f is described below
commit 4b46b1f11c235bf998acfb3c42e1b255d853ea13
Author: Olga <[email protected]>
AuthorDate: Fri Jan 29 23:35:48 2021 +0100
[SYSTEMDS-2810] Improved serialization of transform encoders
The transform encoders and decoders (used by transform encode/apply)
carry potentially large metadata such as the recode maps (dictionaries).
For this reason, this patch adds dedicated serialization and
deserialization code paths that bypass the default Java serialization
and thus avoid serializing temporary internal data structures and
unnecessarily bloated representations. This serialization framework is
now implicitly used in the respective Spark and federated instructions.
Closes #1171.
---
.../runtime/controlprogram/ParForProgramBlock.java | 2 +-
.../sysds/runtime/transform/decode/Decoder.java | 68 ++++++++-
.../runtime/transform/decode/DecoderComposite.java | 29 +++-
.../runtime/transform/decode/DecoderDummycode.java | 25 ++++
.../runtime/transform/decode/DecoderFactory.java | 30 ++++
.../transform/decode/DecoderPassThrough.java | 35 ++++-
.../runtime/transform/decode/DecoderRecode.java | 39 ++++-
.../sysds/runtime/transform/encode/Encoder.java | 97 ++++++++-----
.../sysds/runtime/transform/encode/EncoderBin.java | 75 +++++++---
.../runtime/transform/encode/EncoderComposite.java | 54 ++++++-
.../runtime/transform/encode/EncoderDummycode.java | 43 ++++++
.../runtime/transform/encode/EncoderFactory.java | 45 ++++++
.../transform/encode/EncoderFeatureHash.java | 22 ++-
.../runtime/transform/encode/EncoderMVImpute.java | 88 +++++++++++-
.../runtime/transform/encode/EncoderOmit.java | 42 ++++++
.../transform/encode/EncoderPassThrough.java | 4 +-
.../runtime/transform/encode/EncoderRecode.java | 159 ++++++++++++++-------
.../transform/TransformFrameEncodeDecodeTest.java | 2 +-
18 files changed, 734 insertions(+), 125 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
index f11d795..b0b87eb 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/ParForProgramBlock.java
@@ -586,7 +586,7 @@ public class ParForProgramBlock extends ForProgramBlock
// OptimizationWrapper.setLogLevel(_optLogLevel); //set
optimizer log level
OptimizationWrapper.optimize(_optMode, sb, this, ec,
_monitor); //core optimize
}
-
+
///////
//DATA PARTITIONING of read-only parent variables of type
(matrix,unpartitioned)
///////
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java
b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java
index 4417387..5bc3e97 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/decode/Decoder.java
@@ -19,7 +19,10 @@
package org.apache.sysds.runtime.transform.decode;
-import java.io.Serializable;
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -31,14 +34,13 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
* interface for decoding matrices to frames.
*
*/
-public abstract class Decoder implements Serializable
+public abstract class Decoder implements Externalizable
{
private static final long serialVersionUID = -1732411001366177787L;
- protected final ValueType[] _schema;
- protected final int[] _colList;
+ protected ValueType[] _schema;
+ protected int[] _colList;
protected String[] _colnames = null;
-
protected Decoder(ValueType[] schema, int[] colList) {
_schema = schema;
_colList = colList;
@@ -91,4 +93,60 @@ public abstract class Decoder implements Serializable
}
public abstract void initMetaData(FrameBlock meta);
+
+	/**
+	 * Redirects the default Java serialization via {@link Externalizable} to a
+	 * compact, custom encoding for efficient broadcast/rdd serialization.
+	 * Each array is written as an int length followed by its elements; a
+	 * length of -1 denotes a null array, so that null and empty arrays are
+	 * distinguished and survive a round trip unchanged. Subclass
+	 * implementations call this first and then append their own state, which
+	 * must be read back in the same order.
+	 *
+	 * @param os object output
+	 * @throws IOException if IOException occurs
+	 */
+	@Override
+	public void writeExternal(ObjectOutput os)
+		throws IOException
+	{
+		// column indexes this decoder operates on
+		if( _colList == null )
+			os.writeInt(-1);
+		else {
+			os.writeInt(_colList.length);
+			for( int col : _colList )
+				os.writeInt(col);
+		}
+
+		// column names (modified UTF-8 via writeUTF)
+		if( _colnames == null )
+			os.writeInt(-1);
+		else {
+			os.writeInt(_colnames.length);
+			for( String name : _colnames )
+				os.writeUTF(name);
+		}
+
+		// value types, each encoded by its ordinal as a single byte
+		if( _schema == null )
+			os.writeInt(-1);
+		else {
+			os.writeInt(_schema.length);
+			for( ValueType vt : _schema )
+				os.writeByte(vt.ordinal());
+		}
+	}
+
+	/**
+	 * Redirects the default Java serialization via {@link Externalizable} to a
+	 * compact, custom encoding for efficient broadcast/rdd deserialization.
+	 * Reads the state written by {@link #writeExternal(ObjectOutput)}; an
+	 * array length of -1 restores a null array.
+	 *
+	 * @param in object input
+	 * @throws IOException if IOException occurs
+	 */
+	@Override
+	public void readExternal(ObjectInput in)
+		throws IOException
+	{
+		int size1 = in.readInt();
+		_colList = (size1 < 0) ? null : new int[size1];
+		for(int i = 0; i < size1; i++)
+			_colList[i] = in.readInt();
+
+		int size2 = in.readInt();
+		_colnames = (size2 < 0) ? null : new String[size2];
+		for(int j = 0; j < size2; j++)
+			_colnames[j] = in.readUTF();
+
+		int size3 = in.readInt();
+		_schema = (size3 < 0) ? null : new ValueType[size3];
+		// hoisted out of the loop: values() allocates a fresh array per call;
+		// the ordinal was written with writeByte and is read back unsigned
+		ValueType[] vtypes = ValueType.values();
+		for(int j = 0; j < size3; j++)
+			_schema[j] = vtypes[in.readUnsignedByte()];
+	}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java
b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java
index 263e064..7081405 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderComposite.java
@@ -19,6 +19,9 @@
package org.apache.sysds.runtime.transform.decode;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -36,7 +39,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
public class DecoderComposite extends Decoder
{
private static final long serialVersionUID = 5790600547144743716L;
-
+
private List<Decoder> _decoders = null;
protected DecoderComposite(ValueType[] schema, List<Decoder> decoders) {
@@ -44,6 +47,8 @@ public class DecoderComposite extends Decoder
_decoders = decoders;
}
+ /** Public no-arg constructor for java.io.Externalizable deserialization; _decoders is filled in by readExternal. */
+ public DecoderComposite() { super(null, null); }
+
@Override
public FrameBlock decode(MatrixBlock in, FrameBlock out) {
for( Decoder decoder : _decoders )
@@ -73,4 +78,26 @@ public class DecoderComposite extends Decoder
for( Decoder decoder : _decoders )
decoder.initMetaData(meta);
}
+
+ /**
+ * Writes the base decoder state, then the number of child decoders and,
+ * per child, a one-byte type tag (DecoderFactory#getDecoderType) followed
+ * by the child's own serialized state. Must stay in sync with readExternal.
+ * NOTE(review): throws a NullPointerException if _decoders is null (its
+ * initial value, left unchanged by the no-arg constructor) - confirm such
+ * instances are never serialized.
+ */
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ super.writeExternal(out);
+ out.writeInt(_decoders.size());
+ for(Decoder decoder : _decoders) {
+ out.writeByte(DecoderFactory.getDecoderType(decoder));
+ decoder.writeExternal(out);
+ }
+ }
+
+ /**
+ * Reads the state written by writeExternal: base decoder state, child
+ * count, and per child a type tag used to create an empty instance via
+ * DecoderFactory#createInstance before delegating to its readExternal.
+ */
+ @Override
+ public void readExternal(ObjectInput in) throws IOException {
+ super.readExternal(in);
+ int decodersSize = in.readInt();
+ _decoders = new ArrayList<>();
+ for(int i = 0; i < decodersSize; i++) {
+ Decoder decoder =
DecoderFactory.createInstance(in.readByte());
+ decoder.readExternal(in);
+ _decoders.add(decoder);
+ }
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java
b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java
index ab1fbc8..061e1e2 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderDummycode.java
@@ -19,6 +19,9 @@
package org.apache.sysds.runtime.transform.decode;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -121,4 +124,26 @@ public class DecoderDummycode extends Decoder
off += ndist - 1;
}
}
+
+ /**
+ * Writes the base decoder state, then the length of _clPos followed by the
+ * interleaved (_clPos[i], _cuPos[i]) pairs. Must stay in sync with
+ * readExternal.
+ * NOTE(review): assumes _clPos and _cuPos are non-null and that _cuPos is
+ * at least as long as _clPos - confirm they are always set before
+ * serialization.
+ */
+ @Override
+ public void writeExternal(ObjectOutput os) throws IOException {
+ super.writeExternal(os);
+ os.writeInt(_clPos.length);
+ for(int i = 0; i < _clPos.length; i++) {
+ os.writeInt(_clPos[i]);
+ os.writeInt(_cuPos[i]);
+ }
+ }
+
+ /**
+ * Reads the state written by writeExternal, allocating _clPos and _cuPos
+ * with the same length and filling them from the interleaved pairs.
+ */
+ @Override
+ public void readExternal(ObjectInput in) throws IOException {
+ super.readExternal(in);
+ int size = in.readInt();
+ _clPos = new int[size];
+ _cuPos = new int[size];
+ for(int i = 0; i < size; i++) {
+ _clPos[i] = in.readInt();
+ _cuPos[i] = in.readInt();
+ }
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
index b51547d..6fad975 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderFactory.java
@@ -36,6 +36,12 @@ import static
org.apache.sysds.runtime.util.CollectionUtils.unionDistinct;
public class DecoderFactory
{
+ /**
+ * Decoder types used to tag serialized decoders. The ordinal of each
+ * constant is written to the stream (see getDecoderType/createInstance),
+ * so constants must not be reordered or removed without breaking the
+ * decoding of previously serialized data; append new types at the end.
+ */
+ public enum DecoderType {
+ Dummycode,
+ PassThrough,
+ Recode
+ };
+
public static Decoder createDecoder(String spec, String[] colnames,
ValueType[] schema, FrameBlock meta) {
return createDecoder(spec, colnames, schema, meta,
meta.getNumColumns(), -1, -1);
}
@@ -101,4 +107,28 @@ public class DecoderFactory
return decoder;
}
+
+ /**
+ * Maps a decoder to the ordinal of its DecoderType, used as the
+ * serialization type tag.
+ *
+ * @param decoder decoder to classify
+ * @return ordinal of the matching DecoderType constant
+ * @throws DMLRuntimeException for any other decoder class, including
+ *         DecoderComposite (so nested composites cannot be serialized)
+ */
+ public static int getDecoderType(Decoder decoder) {
+ if( decoder instanceof DecoderDummycode )
+ return DecoderType.Dummycode.ordinal();
+ else if( decoder instanceof DecoderRecode )
+ return DecoderType.Recode.ordinal();
+ else if( decoder instanceof DecoderPassThrough )
+ return DecoderType.PassThrough.ordinal();
+ throw new DMLRuntimeException("Unsupported decoder type: "
+ + decoder.getClass().getCanonicalName());
+ }
+
+ public static Decoder createInstance(int type) {
+ DecoderType dtype = DecoderType.values()[type];
+
+ // create instance
+ switch(dtype) {
+ case Dummycode: return new DecoderDummycode(null,
null);
+ case PassThrough: return new DecoderPassThrough(null,
null, null);
+ case Recode: return new DecoderRecode(null, false,
null);
+ default:
+ throw new DMLRuntimeException("Unsupported
Encoder Type used: " + dtype);
+ }
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java
b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java
index 753c666..54ef02c 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderPassThrough.java
@@ -19,6 +19,9 @@
package org.apache.sysds.runtime.transform.decode;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -39,12 +42,14 @@ public class DecoderPassThrough extends Decoder
private int[] _dcCols = null;
private int[] _srcCols = null;
-
+
protected DecoderPassThrough(ValueType[] schema, int[] ptCols, int[]
dcCols) {
super(schema, ptCols);
_dcCols = dcCols;
}
+ /** Public no-arg constructor for java.io.Externalizable deserialization; fields are filled in by readExternal. */
+ public DecoderPassThrough() { super(null, null); }
+
@Override
public FrameBlock decode(MatrixBlock in, FrameBlock out) {
out.ensureAllocatedColumns(in.getNumRows());
@@ -112,4 +117,32 @@ public class DecoderPassThrough extends Decoder
_srcCols = _colList;
}
}
+
+	/**
+	 * Serializes the pass-through state after the base decoder state: the
+	 * source column indexes and the dummycoded column indexes, each as a
+	 * length prefix followed by the values. A length of -1 encodes a null
+	 * array, because both fields start out null (and _dcCols is taken as-is
+	 * from the constructor argument).
+	 *
+	 * @param os object output
+	 * @throws IOException if an I/O error occurs
+	 */
+	@Override
+	public void writeExternal(ObjectOutput os)
+		throws IOException
+	{
+		super.writeExternal(os);
+		writeIntArray(os, _srcCols);
+		writeIntArray(os, _dcCols);
+	}
+
+	/**
+	 * Deserializes the state written by {@link #writeExternal(ObjectOutput)}.
+	 *
+	 * @param in object input
+	 * @throws IOException if an I/O error occurs
+	 */
+	@Override
+	public void readExternal(ObjectInput in)
+		throws IOException
+	{
+		super.readExternal(in);
+		_srcCols = readIntArray(in);
+		_dcCols = readIntArray(in);
+	}
+
+	/** Writes a length-prefixed int array, using -1 as the length of a null array. */
+	private static void writeIntArray(ObjectOutput os, int[] arr) throws IOException {
+		if( arr == null ) {
+			os.writeInt(-1);
+			return;
+		}
+		os.writeInt(arr.length);
+		for( int v : arr )
+			os.writeInt(v);
+	}
+
+	/** Reads an int array written by writeIntArray, returning null for length -1. */
+	private static int[] readIntArray(ObjectInput in) throws IOException {
+		int len = in.readInt();
+		if( len < 0 )
+			return null;
+		int[] arr = new int[len];
+		for( int i = 0; i < len; i++ )
+			arr[i] = in.readInt();
+		return arr;
+	}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java
b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java
index 9ae315f..78b2e62 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/decode/DecoderRecode.java
@@ -19,11 +19,15 @@
package org.apache.sysds.runtime.transform.decode;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
-
import java.util.List;
+import java.util.Map.Entry;
+
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -43,7 +47,9 @@ public class DecoderRecode extends Decoder
private HashMap<Long, Object>[] _rcMaps = null;
private boolean _onOut = false;
-
+
+ /** Public no-arg constructor for java.io.Externalizable deserialization; fields are filled in by readExternal. */
+ public DecoderRecode() { super(null, null); }
+
protected DecoderRecode(ValueType[] schema, boolean onOut, int[]
rcCols) {
super(schema, rcCols);
_onOut = onOut;
@@ -135,4 +141,33 @@ public class DecoderRecode extends Decoder
String id = entry.substring(ixq+2,idx);
pair.set(token, id);
}
+
+ /**
+ * Writes the base decoder state, then _onOut, the number of recode maps,
+ * and per map its entry count followed by (long code, value string) pairs.
+ * Values are written via toString, so any value type is flattened to a
+ * String.
+ * NOTE(review): throws a NullPointerException if _rcMaps (initially null)
+ * or one of its maps is still null, or if a map value is null.
+ * NOTE(review): per the DataOutput#writeUTF contract, a value whose
+ * modified-UTF-8 encoding exceeds 65535 bytes fails with
+ * UTFDataFormatException.
+ */
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ super.writeExternal(out);
+ out.writeBoolean(_onOut);
+ out.writeInt(_rcMaps.length);
+ for(int i = 0; i < _rcMaps.length; i++) {
+ out.writeInt(_rcMaps[i].size());
+ for(Entry<Long,Object> e1 : _rcMaps[i].entrySet()) {
+ out.writeLong(e1.getKey());
+ out.writeUTF(e1.getValue().toString());
+ }
+ }
+ }
+
+ /**
+ * Reads the state written by writeExternal. Every map value is restored as
+ * a String (readUTF), regardless of its type before serialization.
+ */
+ @Override
+ @SuppressWarnings("unchecked")
+ public void readExternal(ObjectInput in) throws IOException {
+ super.readExternal(in);
+ _onOut = in.readBoolean();
+ _rcMaps = (HashMap<Long,Object>[])new HashMap[in.readInt()];
+ for(int i = 0; i < _rcMaps.length; i++) {
+ HashMap<Long, Object> maps = new HashMap<>();
+ int size = in.readInt();
+ for(int j = 0; j < size; j++)
+ maps.put(in.readLong(), in.readUTF());
+ _rcMaps[i] = maps;
+ }
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
index 784e4d6..bd8e10c 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/Encoder.java
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,7 +19,10 @@
package org.apache.sysds.runtime.transform.encode;
-import java.io.Serializable;
+import java.io.Externalizable;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
@@ -38,37 +41,37 @@ import org.apache.wink.json4j.JSONArray;
/**
* Base class for all transform encoders providing both a row and block
* interface for decoding frames to matrices.
- *
+ *
*/
-public abstract class Encoder implements Serializable
+public abstract class Encoder implements Externalizable
{
private static final long serialVersionUID = 2299156350718979064L;
protected static final Log LOG =
LogFactory.getLog(Encoder.class.getName());
-
- protected int _clen = -1;
+
+ protected int _clen = -1;
protected int[] _colList = null;
-
+
protected Encoder( int[] colList, int clen ) {
_colList = colList;
_clen = clen;
}
-
+
public int[] getColList() {
return _colList;
}
-
+
public void setColList(int[] colList) {
_colList = colList;
}
-
+
public int getNumCols() {
return _clen;
}
public int initColList(JSONArray attrs) {
_colList = new int[attrs.size()];
- for(int i=0; i < _colList.length; i++)
- _colList[i] = UtilFunctions.toInt(attrs.get(i));
+ for(int i=0; i < _colList.length; i++)
+ _colList[i] = UtilFunctions.toInt(attrs.get(i));
return _colList.length;
}
@@ -76,21 +79,21 @@ public abstract class Encoder implements Serializable
_colList = colList;
return _colList.length;
}
-
+
/**
- * Indicates if this encoder is applicable, i.e, if there is at
- * least one column to encode.
- *
+ * Indicates if this encoder is applicable, i.e, if there is at
+ * least one column to encode.
+ *
* @return true if at least one column to encode
*/
public boolean isApplicable() {
return (_colList != null && _colList.length > 0);
}
-
+
/**
* Indicates if this encoder is applicable for the given column ID,
* i.e., if it is subject to this transformation.
- *
+ *
* @param colID column ID
* @return true if encoder is applicable for given column
*/
@@ -100,10 +103,10 @@ public abstract class Encoder implements Serializable
int idx = Arrays.binarySearch(_colList, colID);
return ( idx >= 0 ? idx : -1);
}
-
+
/**
* Block encode: build and apply (transform encode).
- *
+ *
* @param in input frame block
* @param out output matrix block
* @return output matrix block
@@ -113,11 +116,11 @@ public abstract class Encoder implements Serializable
/**
* Build the transform meta data for the given block input. This call
modifies
* and keeps meta data as encoder state.
- *
+ *
* @param in input frame block
*/
public abstract void build(FrameBlock in);
-
+
/**
* Allocates internal data structures for partial build.
*/
@@ -137,7 +140,7 @@ public abstract class Encoder implements Serializable
/**
* Encode input data blockwise according to existing transform meta
* data (transform apply).
- *
+ *
* @param in input frame block
* @param out output matrix block
* @return output matrix block
@@ -159,7 +162,7 @@ public abstract class Encoder implements Serializable
/**
* Returns a new Encoder that only handles a sub range of columns.
- *
+ *
* @param ixRange the range (1-based, begin inclusive, end exclusive)
* @return an encoder of the same type, just for the sub-range
*/
@@ -170,7 +173,7 @@ public abstract class Encoder implements Serializable
/**
* Merges the column information, like how many columns the frame needs
and which columns this encoder operates on.
- *
+ *
* @param other the other encoder of the same type
* @param col column at which the second encoder will be merged in
(1-based)
*/
@@ -191,7 +194,7 @@ public abstract class Encoder implements Serializable
* Merges another encoder, of a compatible type, in after a certain
position. Resizes as necessary.
* <code>Encoders</code> are compatible with themselves and
<code>EncoderComposite</code> is compatible with every
* other <code>Encoder</code>.
- *
+ *
* @param other the encoder that should be merged in
* @param row the row where it should be placed (1-based)
* @param col the col where it should be placed (1-based)
@@ -200,7 +203,7 @@ public abstract class Encoder implements Serializable
throw new DMLRuntimeException(
this.getClass().getSimpleName() + " does not support
merging with " + other.getClass().getSimpleName());
}
-
+
/**
* Update index-ranges to after encoding. Note that only Dummycoding
changes the ranges.
*
@@ -213,7 +216,7 @@ public abstract class Encoder implements Serializable
/**
* Construct a frame block out of the transform meta data.
- *
+ *
* @param out output frame block
* @return output frame block?
*/
@@ -221,15 +224,15 @@ public abstract class Encoder implements Serializable
/**
* Sets up the required meta data for a subsequent call to apply.
- *
+ *
* @param meta frame block
*/
public abstract void initMetaData(FrameBlock meta);
-
+
/**
* Obtain the column mapping of encoded frames based on the passed
* meta data frame.
- *
+ *
* @param meta meta data frame block
* @param out output matrix
* @return matrix with column mapping (one row per attribute)
@@ -238,4 +241,34 @@ public abstract class Encoder implements Serializable
//default: do nothing
return out;
}
+
+ /**
+ * Redirects the default java serialization via externalizable to our
default
+ * hadoop writable serialization for efficient broadcast/rdd
serialization.
+ *
+ * @param os object output
+ * @throws IOException if IOException occurs
+ */
+ @Override
+ public void writeExternal(ObjectOutput os) throws IOException {
+ os.writeInt(_clen);
+ os.writeInt(_colList.length);
+ for(int col : _colList)
+ os.writeInt(col);
+ }
+
+ /**
+ * Redirects the default java serialization via externalizable to our
default
+ * hadoop writable serialization for efficient broadcast/rdd
deserialization.
+ *
+ * @param in object input
+ * @throws IOException if IOException occur
+ */
+ @Override
+ public void readExternal(ObjectInput in) throws IOException {
+ _clen = in.readInt();
+ _colList = new int[in.readInt()];
+ for(int i = 0; i < _colList.length; i++)
+ _colList[i] = in.readInt();
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
index cbbeb67..f01e873 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderBin.java
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -20,6 +20,8 @@
package org.apache.sysds.runtime.transform.encode;
import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -39,7 +41,7 @@ import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;
-public class EncoderBin extends Encoder
+public class EncoderBin extends Encoder
{
private static final long serialVersionUID = 1917445005206076078L;
@@ -48,10 +50,10 @@ public class EncoderBin extends Encoder
public static final String NBINS_PREFIX = "nbins";
protected int[] _numBins = null;
-
+
//frame transform-apply attributes
// a) column bin boundaries
- //TODO binMins is redundant and could be removed
+ //TODO binMins is redundant and could be removed - necessary for
correct fed results
private double[][] _binMins = null;
private double[][] _binMaxs = null;
// b) column min/max (for partial build)
@@ -59,16 +61,16 @@ public class EncoderBin extends Encoder
private double[] _colMaxs = null;
public EncoderBin(JSONObject parsedSpec, String[] colnames, int clen,
int minCol, int maxCol)
- throws JSONException, IOException
+ throws JSONException, IOException
{
super( null, clen );
if ( !parsedSpec.containsKey(TfMethod.BIN.toString()) )
return;
-
+
//parse column names or column ids
List<Integer> collist =
TfMetaUtils.parseBinningColIDs(parsedSpec, colnames, minCol, maxCol);
initColList(ArrayUtils.toPrimitive(collist.toArray(new
Integer[0])));
-
+
//parse number of bins per column
boolean ids = parsedSpec.containsKey("ids") &&
parsedSpec.getBoolean("ids");
JSONArray group = (JSONArray)
parsedSpec.get(TfMethod.BIN.toString());
@@ -82,12 +84,12 @@ public class EncoderBin extends Encoder
_numBins[pos] = colspec.containsKey("numbins")
? colspec.getInt("numbins") : 1;
}
}
-
+
public EncoderBin() {
super(new int[0], 0);
_numBins = new int[0];
}
-
+
private EncoderBin(int[] colList, int clen, int[] numBins, double[][]
binMins, double[][] binMaxs) {
super(colList, clen);
_numBins = numBins;
@@ -121,7 +123,7 @@ public class EncoderBin extends Encoder
public void build(FrameBlock in) {
if ( !isApplicable() )
return;
-
+
// derive bin boundaries from min/max per column
for(int j=0; j <_colList.length; j++) {
double min = Double.POSITIVE_INFINITY;
@@ -178,14 +180,14 @@ public class EncoderBin extends Encoder
_colMaxs[j] = max;
}
}
-
+
@Override
public MatrixBlock apply(FrameBlock in, MatrixBlock out) {
for(int j=0; j<_colList.length; j++) {
int colID = _colList[j];
for( int i=0; i<in.getNumRows(); i++ ) {
double inVal = UtilFunctions.objectToDouble(
- in.getSchema()[colID-1],
in.get(i, colID-1));
+ in.getSchema()[colID-1], in.get(i,
colID-1));
int ix = Arrays.binarySearch(_binMaxs[j],
inVal);
int binID = ((ix < 0) ? Math.abs(ix+1) : ix) +
1;
out.quickSetValue(i, colID-1, binID);
@@ -193,7 +195,7 @@ public class EncoderBin extends Encoder
}
return out;
}
-
+
@Override
public Encoder subRangeEncoder(IndexRange ixRange) {
List<Integer> colsList = new ArrayList<>();
@@ -221,7 +223,7 @@ public class EncoderBin extends Encoder
numBinsList.stream().mapToInt((i) -> i).toArray(),
binMinsList.toArray(new double[0][0]),
binMaxsList.toArray(new double[0][0]));
}
-
+
@Override
public void mergeAt(Encoder other, int row, int col) {
if(other instanceof EncoderBin) {
@@ -273,7 +275,7 @@ public class EncoderBin extends Encoder
}
super.mergeAt(other, row, col);
}
-
+
@Override
public FrameBlock getMetaData(FrameBlock meta) {
//allocate frame if necessary
@@ -281,7 +283,7 @@ public class EncoderBin extends Encoder
for( int j=0; j<_colList.length; j++ )
maxLength = Math.max(maxLength, _binMaxs[j].length);
meta.ensureAllocatedColumns(maxLength);
-
+
//serialize the internal state into frame meta data
for( int j=0; j<_colList.length; j++ ) {
int colID = _colList[j]; //1-based
@@ -296,7 +298,7 @@ public class EncoderBin extends Encoder
}
return meta;
}
-
+
@Override
public void initMetaData(FrameBlock meta) {
if( meta == null || _binMaxs != null )
@@ -316,4 +318,41 @@ public class EncoderBin extends Encoder
}
}
}
+
+	/**
+	 * Serializes the base encoder state, followed by the binning state:
+	 * - the number of columns;
+	 * - a flag indicating whether bin boundaries are present;
+	 * - per column, the number of bins and, if boundaries are present, the
+	 *   number of boundaries followed by the interleaved (max, min) pairs.
+	 * The boundary count is written explicitly rather than inferred from
+	 * the number of bins, so the reader always consumes exactly the values
+	 * written.
+	 *
+	 * @param out object output
+	 * @throws IOException if an I/O error occurs
+	 */
+	@Override
+	public void writeExternal(ObjectOutput out) throws IOException {
+		super.writeExternal(out);
+
+		// _numBins stays null if the spec contains no binning (see constructor)
+		int ncols = (_numBins == null) ? 0 : _numBins.length;
+		boolean minmax = (_binMaxs != null);
+		out.writeInt(ncols);
+		out.writeBoolean(minmax);
+		for(int i = 0; i < ncols; i++) {
+			out.writeInt(_numBins[i]);
+			if( !minmax )
+				continue;
+			out.writeInt(_binMaxs[i].length);
+			for(int j = 0; j < _binMaxs[i].length; j++) {
+				out.writeDouble(_binMaxs[i][j]);
+				out.writeDouble(_binMins[i][j]);
+			}
+		}
+	}
+
+	/**
+	 * Deserializes the state written by {@link #writeExternal(ObjectOutput)}.
+	 * Partial-build state (column min/max) is not serialized.
+	 *
+	 * @param in object input
+	 * @throws IOException if an I/O error occurs
+	 */
+	@Override
+	public void readExternal(ObjectInput in) throws IOException {
+		super.readExternal(in);
+		int ncols = in.readInt();
+		boolean minmax = in.readBoolean();
+		_numBins = new int[ncols];
+		_binMaxs = minmax ? new double[ncols][] : null;
+		_binMins = minmax ? new double[ncols][] : null;
+
+		for(int i = 0; i < ncols; i++) {
+			_numBins[i] = in.readInt();
+			if( !minmax )
+				continue;
+			int nbounds = in.readInt();
+			_binMaxs[i] = new double[nbounds];
+			_binMins[i] = new double[nbounds];
+			for(int j = 0; j < nbounds; j++) {
+				_binMaxs[i][j] = in.readDouble();
+				_binMins[i][j] = in.readDouble();
+			}
+		}
+	}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
index f4923d6..1115ba2 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderComposite.java
@@ -19,9 +19,13 @@
package org.apache.sysds.runtime.transform.encode;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
+import java.util.Objects;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.common.Types.ValueType;
@@ -42,7 +46,11 @@ public class EncoderComposite extends Encoder
private List<Encoder> _encoders = null;
private FrameBlock _meta = null;
-
+
+ /**
+ * Public no-arg constructor for java.io.Externalizable deserialization;
+ * _encoders and _meta are filled in by readExternal.
+ */
+ public EncoderComposite() {
+ super(null, -1);
+ }
+
public EncoderComposite(List<Encoder> encoders) {
super(null, -1);
_encoders = encoders;
@@ -134,6 +142,22 @@ public class EncoderComposite extends Encoder
}
@Override
+ public boolean equals(Object o) {
+ if(this == o)
+ return true;
+ if(o == null || getClass() != o.getClass())
+ return false;
+ EncoderComposite that = (EncoderComposite) o;
+ return _encoders.equals(that._encoders)
+ && Objects.equals(_meta, that._meta);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(_encoders, _meta);
+ }
+
+ @Override
public Encoder subRangeEncoder(IndexRange ixRange) {
List<Encoder> subRangeEncoders = new ArrayList<>();
for (Encoder encoder : _encoders) {
@@ -248,4 +272,32 @@ public class EncoderComposite extends Encoder
}
return sb.toString();
}
+
+ /**
+ * Writes the number of child encoders, per child a one-byte type tag
+ * (EncoderFactory#getEncoderType) followed by the child's own state, and
+ * finally a flag plus the optional meta data frame (FrameBlock#write).
+ * NOTE(review): super.writeExternal is not called, so the composite's own
+ * _clen/_colList are not serialized and keep their no-arg constructor
+ * values (-1, null) after deserialization - confirm no caller relies on
+ * them.
+ * NOTE(review): throws a NullPointerException if _encoders is null.
+ */
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ out.writeInt(_encoders.size());
+ for(Encoder encoder : _encoders) {
+ out.writeByte(EncoderFactory.getEncoderType(encoder));
+ encoder.writeExternal(out);
+ }
+ out.writeBoolean(_meta != null);
+ if(_meta != null)
+ _meta.write(out);
+ }
+
+ /**
+ * Reads the state written by writeExternal, creating each child through
+ * EncoderFactory#createInstance from its type tag before delegating to the
+ * child's readExternal; the meta data frame is restored only if its flag
+ * is set.
+ */
+ @Override
+ public void readExternal(ObjectInput in) throws IOException {
+ int encodersSize = in.readInt();
+ _encoders = new ArrayList<>();
+ for(int i = 0; i < encodersSize; i++) {
+ Encoder encoder =
EncoderFactory.createInstance(in.readByte());
+ encoder.readExternal(in);
+ _encoders.add(encoder);
+ }
+ if (in.readBoolean()) {
+ FrameBlock meta = new FrameBlock();
+ meta.readFields(in);
+ _meta = meta;
+ }
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
index f590a04..7fddecb 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderDummycode.java
@@ -19,11 +19,15 @@
package org.apache.sysds.runtime.transform.encode;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -231,4 +235,43 @@ public class EncoderDummycode extends Encoder
return out;
}
+
+	/**
+	 * Serializes the base encoder state, followed by the dummycoded output
+	 * length and the per-column domain sizes (length-prefixed; a null array
+	 * is written with length 0).
+	 *
+	 * @param out object output
+	 * @throws IOException if an I/O error occurs
+	 */
+	@Override
+	public void writeExternal(ObjectOutput out) throws IOException {
+		super.writeExternal(out);
+		out.writeLong(_dummycodedLength);
+		int size1 = (_domainSizes == null) ? 0 : _domainSizes.length;
+		out.writeInt(size1);
+		for(int i = 0; i < size1; i++)
+			out.writeInt(_domainSizes[i]);
+	}
+
+	/**
+	 * Deserializes the state written by {@link #writeExternal(ObjectOutput)}.
+	 * The domain sizes are always consumed from the stream, even if this
+	 * instance already holds domain sizes. Skipping them would leave their
+	 * bytes unread and misalign every subsequent read from the stream.
+	 *
+	 * @param in object input
+	 * @throws IOException if an I/O error occurs
+	 */
+	@Override
+	public void readExternal(ObjectInput in) throws IOException {
+		super.readExternal(in);
+		_dummycodedLength = in.readLong();
+		int size = in.readInt();
+		int[] domainSizes = new int[size];
+		for(int i = 0; i < size; i++)
+			domainSizes[i] = in.readInt();
+		_domainSizes = domainSizes;
+	}
+
+	/**
+	 * Two dummycode encoders are equal if they have the same runtime class,
+	 * the same dummycoded output length, and element-wise equal domain sizes.
+	 */
+	@Override
+	public boolean equals(Object o) {
+		if(o == this)
+			return true;
+		if(o == null || o.getClass() != getClass())
+			return false;
+		EncoderDummycode other = (EncoderDummycode) o;
+		if(other._dummycodedLength != _dummycodedLength)
+			return false;
+		return Arrays.equals(other._domainSizes, _domainSizes);
+	}
+
+	/**
+	 * Hash code consistent with {@link #equals(Object)}: the value equals
+	 * Objects.hash(_dummycodedLength) combined with the array hash, computed
+	 * without boxing the length.
+	 */
+	@Override
+	public int hashCode() {
+		int hash = 31 + Long.hashCode(_dummycodedLength);
+		return 31 * hash + Arrays.hashCode(_domainSizes);
+	}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index af929be..95c6009 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -38,6 +38,17 @@ import static
org.apache.sysds.runtime.util.CollectionUtils.unionDistinct;
public class EncoderFactory
{
+	/**
+	 * Supported encoder types. getEncoderType returns a constant's ordinal as a
+	 * type id, and createInstance maps such an id back via values(). The declaration
+	 * order is therefore part of the (presumably serialized — confirm) format:
+	 * only append new constants, never reorder or remove existing ones.
+	 */
+	public enum EncoderType {
+		Bin, Dummycode, FeatureHash, MVImpute, Omit, PassThrough, Recode
+	}
+
+
/**
 * Convenience overload: builds a schema of clen STRING value types
 * (UtilFunctions.nCopies(clen, ValueType.STRING)) and delegates to the
 * createEncoder overload that takes an explicit schema.
 *
 * @param spec transform specification (presumably JSON — confirm with the delegate)
 * @param colnames column names
 * @param clen number of columns
 * @param meta meta data frame block
 * @return the encoder created by the delegate overload
 */
public static Encoder createEncoder(String spec, String[] colnames, int
clen, FrameBlock meta) {
return createEncoder(spec, colnames,
UtilFunctions.nCopies(clen, ValueType.STRING), meta);
}
@@ -142,6 +153,40 @@ public class EncoderFactory
return encoder;
}
+	/**
+	 * Returns the type id of the given encoder, i.e., the ordinal of its EncoderType.
+	 * The instanceof tests run in a fixed order and the first match wins.
+	 *
+	 * @param encoder encoder instance
+	 * @return ordinal of the matching EncoderType
+	 * @throws DMLRuntimeException if the encoder is none of the supported types
+	 */
+	public static int getEncoderType(Encoder encoder) {
+		EncoderType type;
+		if( encoder instanceof EncoderBin )
+			type = EncoderType.Bin;
+		else if( encoder instanceof EncoderDummycode )
+			type = EncoderType.Dummycode;
+		else if( encoder instanceof EncoderFeatureHash )
+			type = EncoderType.FeatureHash;
+		else if( encoder instanceof EncoderMVImpute )
+			type = EncoderType.MVImpute;
+		else if( encoder instanceof EncoderOmit )
+			type = EncoderType.Omit;
+		else if( encoder instanceof EncoderPassThrough )
+			type = EncoderType.PassThrough;
+		else if( encoder instanceof EncoderRecode )
+			type = EncoderType.Recode;
+		else
+			throw new DMLRuntimeException("Unsupported encoder type: "
+				+ encoder.getClass().getCanonicalName());
+		return type.ordinal();
+	}
+
+ public static Encoder createInstance(int type) {
+ EncoderType etype = EncoderType.values()[type];
+ switch(etype) {
+ case Bin: return new EncoderBin();
+ case Dummycode: return new EncoderDummycode();
+ case FeatureHash: return new EncoderFeatureHash();
+ case MVImpute: return new EncoderMVImpute();
+ case Omit: return new EncoderOmit();
+ case PassThrough: return new EncoderPassThrough();
+ case Recode: return new EncoderRecode();
+ default:
+ throw new DMLRuntimeException("Unsupported
encoder type: " + etype);
+ }
+ }
+
private static HashMap<String, Integer> getColumnPositions(String[]
colnames) {
HashMap<String, Integer> ret = new HashMap<>();
for(int i=0; i<colnames.length; i++)
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
index 3b6503b..cf6d67a 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFeatureHash.java
@@ -19,14 +19,18 @@
package org.apache.sysds.runtime.transform.encode;
-import org.apache.sysds.runtime.util.IndexRange;
-import org.apache.wink.json4j.JSONException;
-import org.apache.wink.json4j.JSONObject;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
+
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
+import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.wink.json4j.JSONException;
+import org.apache.wink.json4j.JSONObject;
/**
* Class used for feature hashing transformation of frames.
@@ -144,4 +148,16 @@ public class EncoderFeatureHash extends Encoder
_K = UtilFunctions.parseToLong(meta.get(0,
colID-1).toString());
}
}
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ super.writeExternal(out);
+ out.writeLong(_K);
+ }
+
+ @Override
+ public void readExternal(ObjectInput in) throws IOException {
+ super.readExternal(in);
+ _K = in.readLong();
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
index 534d16c..2391733 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderMVImpute.java
@@ -19,8 +19,12 @@
package org.apache.sysds.runtime.transform.encode;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -29,9 +33,6 @@ import java.util.Map.Entry;
import java.util.Set;
import java.util.stream.Collectors;
-import org.apache.wink.json4j.JSONArray;
-import org.apache.wink.json4j.JSONException;
-import org.apache.wink.json4j.JSONObject;
import org.apache.commons.lang.ArrayUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
@@ -43,6 +44,9 @@ import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.wink.json4j.JSONArray;
+import org.apache.wink.json4j.JSONException;
+import org.apache.wink.json4j.JSONObject;
public class EncoderMVImpute extends Encoder
{
@@ -60,7 +64,7 @@ public class EncoderMVImpute extends Encoder
private String[] _replacementList = null; // replacements: for
global_mean, mean; and for global_mode, recode id of mode category
private List<Integer> _rcList = null;
private HashMap<Integer,HashMap<String,Long>> _hist = null;
-
+
/** Returns the internal _replacementList array (not a copy); individual entries may be null. */
public String[] getReplacements() { return _replacementList; }
/** Returns the internal _meanList array of KahanObject (not a copy). */
public KahanObject[] getMeans() { return _meanList; }
@@ -333,6 +337,82 @@ public class EncoderMVImpute extends Encoder
/**
 * Returns the token histogram stored in _hist for the given column.
 *
 * @param colID column ID used as key in _hist (presumably 1-based — confirm)
 * @return the histogram map, or null if none exists for this column;
 *         throws NullPointerException if _hist itself is null
 */
public HashMap<String,Long> getHistogram( int colID ) {
return _hist.get(colID);
}
+
+	/**
+	 * Serializes the missing-value imputation state after the base encoder state.
+	 * Layout:
+	 *  - per column of _colList: method ordinal (byte) and count (long)
+	 *  - number of non-null replacements (int), then (index int, value UTF) pairs
+	 *  - size of _rcList (int), then its column IDs (int)
+	 *  - number of histograms (int), then per histogram: column ID (int),
+	 *    entry count (int), then (token UTF, count long) pairs
+	 * Null _replacementList, _rcList or _hist are written as empty.
+	 * Means are not serialized.
+	 * NOTE: writeUTF fails for strings whose modified UTF-8 form exceeds 65535 bytes.
+	 */
+	@Override
+	public void writeExternal(ObjectOutput out) throws IOException {
+		super.writeExternal(out);
+		for(int i = 0; i < _colList.length; i++) {
+			out.writeByte(_mvMethodList[i].ordinal());
+			out.writeLong(_countList[i]);
+		}
+
+		// sparse encoding of replacements; count directly instead of copying the array
+		int numReplacements = 0;
+		if(_replacementList != null)
+			for(String r : _replacementList)
+				if(r != null)
+					numReplacements++;
+		out.writeInt(numReplacements);
+		if(_replacementList != null)
+			for(int i = 0; i < _replacementList.length; i++)
+				if(_replacementList[i] != null) {
+					out.writeInt(i);
+					out.writeUTF(_replacementList[i]);
+				}
+
+		int rcSize = _rcList == null ? 0 : _rcList.size();
+		out.writeInt(rcSize);
+		if(rcSize > 0)
+			for(int rc : _rcList)
+				out.writeInt(rc);
+
+		int histSize = _hist == null ? 0 : _hist.size();
+		out.writeInt(histSize);
+		if(histSize > 0)
+			for(Entry<Integer,HashMap<String,Long>> e1 : _hist.entrySet()) {
+				out.writeInt(e1.getKey());
+				out.writeInt(e1.getValue().size());
+				for(Entry<String, Long> e2 : e1.getValue().entrySet()) {
+					out.writeUTF(e2.getKey());
+					out.writeLong(e2.getValue());
+				}
+			}
+	}
+
+	/**
+	 * Mirror of writeExternal. It allocates the per-column arrays for _colList
+	 * and restores methods, counts, sparse replacements, _rcList and _hist.
+	 * Means are not serialized; each _meanList entry is reset to a zero KahanObject.
+	 */
+	@Override
+	public void readExternal(ObjectInput in) throws IOException {
+		super.readExternal(in);
+
+		int ncol = _colList.length;
+		_mvMethodList = new MVMethod[ncol];
+		_countList = new long[ncol];
+		_meanList = new KahanObject[ncol];
+		_replacementList = new String[ncol];
+
+		MVMethod[] methods = MVMethod.values();
+		for(int c = 0; c < ncol; c++) {
+			_mvMethodList[c] = methods[in.readByte()];
+			_countList[c] = in.readLong();
+			_meanList[c] = new KahanObject(0, 0);
+		}
+
+		// sparse (index, value) pairs of non-null replacements
+		for(int remaining = in.readInt(); remaining > 0; remaining--) {
+			int pos = in.readInt();
+			_replacementList[pos] = in.readUTF();
+		}
+
+		int numRc = in.readInt();
+		_rcList = new ArrayList<>();
+		while(numRc-- > 0)
+			_rcList.add(in.readInt());
+
+		_hist = new HashMap<>();
+		for(int numCols = in.readInt(); numCols > 0; numCols--) {
+			Integer colID = in.readInt();
+			int numTokens = in.readInt();
+			HashMap<String, Long> colHist = new HashMap<>();
+			for(int t = 0; t < numTokens; t++) {
+				String token = in.readUTF();
+				Long count = in.readLong();
+				colHist.put(token, count);
+			}
+			_hist.put(colID, colHist);
+		}
+	}
private static class ColInfo {
MVMethod _method;
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
index bbc83e4..5c4f1ff 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderOmit.java
@@ -19,7 +19,12 @@
package org.apache.sysds.runtime.transform.encode;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.Arrays;
+import java.util.Objects;
+
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -195,4 +200,41 @@ public class EncoderOmit extends Encoder
/**
 * No-op: this encoder does not initialize any state from the meta data frame.
 *
 * @param meta meta data frame block (ignored)
 */
public void initMetaData(FrameBlock meta) {
//do nothing
}
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ super.writeExternal(out);
+ out.writeBoolean(_federated);
+ out.writeInt(_rmRows.length);
+ for(boolean r : _rmRows)
+ out.writeBoolean(r);
+ }
+
+ @Override
+ public void readExternal(ObjectInput in) throws IOException {
+ super.readExternal(in);
+ if(_rmRows.length == 0) {
+ _federated = in.readBoolean();
+ _rmRows = new boolean[in.readInt()];
+ for(int i = 0; i < _rmRows.length; i++)
+ _rmRows[i] = in.readBoolean();
+ }
+ }
+
+	/**
+	 * Two omit encoders are equal if they have the same runtime class, the same
+	 * federated flag and element-wise equal row-removal flags.
+	 * NOTE(review): state of the base Encoder is not compared here.
+	 */
+	@Override
+	public boolean equals(Object o) {
+		if(o == this)
+			return true;
+		if(o == null || o.getClass() != getClass())
+			return false;
+		EncoderOmit other = (EncoderOmit) o;
+		if(_federated != other._federated)
+			return false;
+		return Arrays.equals(_rmRows, other._rmRows);
+	}
+
+	/**
+	 * Hash consistent with equals. It yields the same value as
+	 * {@code 31 * Objects.hash(_federated) + Arrays.hashCode(_rmRows)}.
+	 */
+	@Override
+	public int hashCode() {
+		int flagHash = 31 + Boolean.hashCode(_federated);
+		return 31 * flagHash + Arrays.hashCode(_rmRows);
+	}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
index ac414e9..0a603af 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderPassThrough.java
@@ -64,8 +64,8 @@ public class EncoderPassThrough extends Encoder
for( int i=0; i<in.getNumRows(); i++ ) {
Object val = in.get(i, col);
out.quickSetValue(i, col,
(val==null||(vt==ValueType.STRING
- && val.toString().isEmpty())) ?
Double.NaN :
-
UtilFunctions.objectToDouble(vt, val));
+ && val.toString().isEmpty())) ?
Double.NaN :
+ UtilFunctions.objectToDouble(vt, val));
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
index ada0dff..23b166f 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderRecode.java
@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
- *
+ *
* http://www.apache.org/licenses/LICENSE-2.0
- *
+ *
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
@@ -19,6 +19,9 @@
package org.apache.sysds.runtime.transform.encode;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
@@ -27,55 +30,56 @@ import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
+import java.util.Objects;
-import org.apache.sysds.runtime.util.IndexRange;
-import org.apache.wink.json4j.JSONException;
-import org.apache.wink.json4j.JSONObject;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
+import org.apache.sysds.runtime.util.IndexRange;
+import org.apache.wink.json4j.JSONException;
+import org.apache.wink.json4j.JSONObject;
public class EncoderRecode extends Encoder
{
private static final long serialVersionUID = 8213163881283341874L;
-
+
//test property to ensure consistent encoding for local and federated
public static boolean SORT_RECODE_MAP = false;
-
- //recode maps and custom map for partial recode maps
+
+ //recode maps and custom map for partial recode maps
private HashMap<Integer, HashMap<String, Long>> _rcdMaps = new
HashMap<>();
private HashMap<Integer, HashSet<Object>> _rcdMapsPart = null;
-
+
/**
 * Creates a recode encoder whose column list is parsed from the RECODE entry of
 * the given transform spec via TfMetaUtils.parseJsonIDList (presumably restricted
 * to the column range [minCol, maxCol] — confirm against TfMetaUtils).
 *
 * @param parsedSpec parsed transform specification
 * @param colnames column names used to resolve named columns
 * @param clen number of columns
 * @param minCol lower column bound passed to parseJsonIDList
 * @param maxCol upper column bound passed to parseJsonIDList
 * @throws JSONException if the spec cannot be parsed
 */
public EncoderRecode(JSONObject parsedSpec, String[] colnames, int
clen, int minCol, int maxCol)
- throws JSONException
+ throws JSONException
{
super(null, clen);
_colList = TfMetaUtils.parseJsonIDList(parsedSpec, colnames,
TfMethod.RECODE.toString(), minCol, maxCol);
}
-
+
// Creates a recode encoder for an explicit column list and no recode maps yet.
private EncoderRecode(int[] colList, int clen) {
super(colList, clen);
}
-
+
// Creates an empty recode encoder (no columns, clen 0). EncoderFactory.createInstance
// uses it; presumably readExternal populates it afterwards.
public EncoderRecode() {
this(new int[0], 0);
}
-
+
// Creates a recode encoder that adopts (does not copy) the given recode maps.
private EncoderRecode(int[] colList, int clen, HashMap<Integer,
HashMap<String, Long>> rcdMaps) {
super(colList, clen);
_rcdMaps = rcdMaps;
}
-
// Returns the internal recode maps (not a copy): 1-based column ID -> (token -> code).
- public HashMap<Integer, HashMap<String,Long>> getCPRecodeMaps() {
- return _rcdMaps;
+
+ public HashMap<Integer, HashMap<String,Long>> getCPRecodeMaps() {
+ return _rcdMaps;
}
-
// Returns the internal partial recode maps (tokens without codes, as built by
// buildPartial); may be null since the field is initialized to null.
- public HashMap<Integer, HashSet<Object>> getCPRecodeMapsPartial() {
- return _rcdMapsPart;
+
+ public HashMap<Integer, HashSet<Object>> getCPRecodeMapsPartial() {
+ return _rcdMapsPart;
}
-
+
public void sortCPRecodeMaps() {
for( HashMap<String,Long> map : _rcdMaps.values() ) {
String[] keys= map.keySet().toArray(new String[0]);
@@ -85,23 +89,23 @@ public class EncoderRecode extends Encoder
putCode(map, key);
}
}
-
+
/**
 * Looks up the code of a token in the recode map of a column.
 *
 * @param colID column ID as used as key in _rcdMaps (1-based when built by build)
 * @param key token to look up
 * @return the token's code, or -1 if the column has no recode map or the token is unknown
 */
private long lookupRCDMap(int colID, String key) {
if( !_rcdMaps.containsKey(colID) )
return -1; //empty recode map
Long tmp = _rcdMaps.get(colID).get(key);
return (tmp!=null) ? tmp : -1;
}
-
+
/**
 * Builds the recode maps from the input frame and then applies them, writing
 * codes into out. If the encoder is not applicable, out is returned unchanged.
 *
 * @param in input frame
 * @param out output matrix block
 * @return out
 */
@Override
public MatrixBlock encode(FrameBlock in, MatrixBlock out) {
if( !isApplicable() )
return out;
-
- //build and apply recode maps
+
+ //build and apply recode maps
build(in);
apply(in, out);
-
+
return out;
}
@@ -112,11 +116,11 @@ public class EncoderRecode extends Encoder
Iterator<String[]> iter = in.getStringRowIterator(_colList);
while( iter.hasNext() ) {
- String[] row = iter.next();
+ String[] row = iter.next();
for( int j=0; j<_colList.length; j++ ) {
int colID = _colList[j]; //1-based
//allocate column map if necessary
- if( !_rcdMaps.containsKey(colID) )
+ if( !_rcdMaps.containsKey(colID) )
_rcdMaps.put(colID, new
HashMap<String,Long>());
//probe and build column map
HashMap<String,Long> map = _rcdMaps.get(colID);
@@ -125,14 +129,14 @@ public class EncoderRecode extends Encoder
putCode(map, key);
}
}
-
+
if( SORT_RECODE_MAP ) {
sortCPRecodeMaps();
}
}
/**
- * Put the code into the map with the provided key. The code depends on
the type of encoder.
+ * Put the code into the map with the provided key. The code depends on
the type of encoder.
* @param map column map
* @param key key for the new entry
*/
@@ -151,13 +155,13 @@ public class EncoderRecode extends Encoder
public void buildPartial(FrameBlock in) {
if( !isApplicable() )
return;
-
+
//construct partial recode map (tokens w/o codes)
//iterate over columns for sequential access
for( int j=0; j<_colList.length; j++ ) {
int colID = _colList[j]; //1-based
//allocate column map if necessary
- if( !_rcdMapsPart.containsKey(colID) )
+ if( !_rcdMapsPart.containsKey(colID) )
_rcdMapsPart.put(colID, new HashSet<>());
HashSet<Object> map = _rcdMapsPart.get(colID);
//probe and build column map
@@ -167,8 +171,9 @@ public class EncoderRecode extends Encoder
map.remove(null);
map.remove("");
}
+// _rcdMapsPart = null;
}
-
+
@Override
public MatrixBlock apply(FrameBlock in, MatrixBlock out) {
//apply recode maps column wise
@@ -182,7 +187,7 @@ public class EncoderRecode extends Encoder
(code >= 0) ? code : Double.NaN);
}
}
-
+
return out;
}
@@ -212,7 +217,7 @@ public class EncoderRecode extends Encoder
public void mergeAt(Encoder other, int row, int col) {
if(other instanceof EncoderRecode) {
mergeColumnInfo(other, col);
-
+
// merge together overlapping columns or add new columns
EncoderRecode otherRec = (EncoderRecode) other;
for (int otherColID : other._colList) {
@@ -220,7 +225,7 @@ public class EncoderRecode extends Encoder
//allocate column map if necessary
if( !_rcdMaps.containsKey(colID) )
_rcdMaps.put(colID, new HashMap<>());
-
+
HashMap<String, Long> otherMap =
otherRec._rcdMaps.get(otherColID);
if(otherMap != null) {
// for each column, add all non present
recode values
@@ -236,10 +241,10 @@ public class EncoderRecode extends Encoder
}
super.mergeAt(other, row, col);
}
-
+
public int[] numDistinctValues() {
int[] numDistinct = new int[_colList.length];
-
+
for( int j=0; j<_colList.length; j++ ) {
int colID = _colList[j]; //1-based
numDistinct[j] = _rcdMaps.get(colID).size();
@@ -251,16 +256,16 @@ public class EncoderRecode extends Encoder
public FrameBlock getMetaData(FrameBlock meta) {
if( !isApplicable() )
return meta;
-
+
//inverse operation to initRecodeMaps
-
+
//allocate output rows
int maxDistinct = 0;
for( int j=0; j<_colList.length; j++ )
if( _rcdMaps.containsKey(_colList[j]) )
maxDistinct = Math.max(maxDistinct,
_rcdMaps.get(_colList[j]).size());
meta.ensureAllocatedColumns(maxDistinct);
-
+
//create compact meta data representation
StringBuilder sb = new StringBuilder(); //for reuse
for( int j=0; j<_colList.length; j++ ) {
@@ -268,37 +273,68 @@ public class EncoderRecode extends Encoder
int rowID = 0;
if( _rcdMaps.containsKey(_colList[j]) )
for( Entry<String, Long> e :
_rcdMaps.get(colID).entrySet() ) {
- meta.set(rowID++, colID-1,
-
constructRecodeMapEntry(e.getKey(), e.getValue(), sb));
+ meta.set(rowID++, colID-1,
+
constructRecodeMapEntry(e.getKey(), e.getValue(), sb));
}
meta.getColumnMetadata(colID-1).setNumDistinct(
_rcdMaps.get(colID).size());
}
-
+
return meta;
}
-
+
/**
- * Construct the recodemaps from the given input frame for all
+ * Construct the recode maps from the given meta data frame for all
 * columns registered for recode.
- *
+ *
 * @param meta frame block
+ * (nothing is done if meta is null or has no rows)
 */
@Override
public void initMetaData( FrameBlock meta ) {
if( meta == null || meta.getNumRows()<=0 )
return;
-
+
for( int j=0; j<_colList.length; j++ ) {
int colID = _colList[j]; //1-based
// meta columns are 0-based, hence colID-1; existing maps for colID are replaced
_rcdMaps.put(colID, meta.getRecodeMap(colID-1));
}
}
-
+
+ @Override
+ public void writeExternal(ObjectOutput out) throws IOException {
+ super.writeExternal(out);
+ out.writeInt(_rcdMaps.size());
+ for(Entry<Integer, HashMap<String,Long>> e1 :
_rcdMaps.entrySet()) {
+ out.writeInt(e1.getKey());
+ out.writeInt(e1.getValue().size());
+ for(Entry<String, Long> e2 : e1.getValue().entrySet()) {
+ out.writeUTF(e2.getKey());
+ out.writeLong(e2.getValue());
+ }
+ }
+ }
+
+	/**
+	 * Mirror of writeExternal. It restores the base encoder state and adds the
+	 * deserialized recode maps to _rcdMaps. Existing entries are not cleared,
+	 * but same-column entries are replaced.
+	 */
+	@Override
+	public void readExternal(ObjectInput in) throws IOException {
+		super.readExternal(in);
+		int numCols = in.readInt();
+		while(numCols-- > 0) {
+			Integer colID = in.readInt();
+			int numTokens = in.readInt();
+			HashMap<String, Long> map = new HashMap<>();
+			for(int t = 0; t < numTokens; t++) {
+				String token = in.readUTF();
+				Long code = in.readLong();
+				map.put(token, code);
+			}
+			_rcdMaps.put(colID, map);
+		}
+	}
+
/**
- * Returns the Recode map entry which consists of concatenation of
code, delimiter and token.
- *
+ * Returns the Recode map entry which consists of concatenation of
code, delimiter and token.
+ *
* @param token is part of Recode map
* @param code is code for token
* @return the concatenation of token and code with delimiter in between
@@ -307,23 +343,38 @@ public class EncoderRecode extends Encoder
StringBuilder sb = new StringBuilder(token.length()+16);
return constructRecodeMapEntry(token, code, sb);
}
-
+
/**
 * Resets the given builder and returns token + Lop.DATATYPE_PREFIX + code.
 * The builder is reused to avoid allocations; code must be non-null.
 */
private static String constructRecodeMapEntry(String token, Long code,
StringBuilder sb) {
sb.setLength(0); //reset reused string builder
return sb.append(token).append(Lop.DATATYPE_PREFIX)
.append(code.longValue()).toString();
}
-
+
/**
 * Splits a Recode map entry into its token and code.
- *
+ * The split happens at the LAST occurrence of Lop.DATATYPE_PREFIX, so tokens
+ * that themselves contain the delimiter are preserved.
 * @param value concatenation of token and code with delimiter in
between
 * @return string array of token and code
 */
public static String[] splitRecodeMapEntry(String value) {
// Instead of using splitCSV which is forcing string with
RFC-4180 format,
- // using Lop.DATATYPE_PREFIX separator to split token and code
+ // using Lop.DATATYPE_PREFIX separator to split token and code
// NOTE(review): if value contains no delimiter, pos is -1 and substring(0, -1)
// throws StringIndexOutOfBoundsException; value.toString() is redundant on a String.
int pos = value.toString().lastIndexOf(Lop.DATATYPE_PREFIX);
return new String[] {value.substring(0, pos),
value.substring(pos+1)};
}
+
+	/**
+	 * Two recode encoders are equal if they have the same runtime class and equal
+	 * recode maps. NOTE(review): base Encoder state (e.g., _colList) and the
+	 * partial maps are not compared here.
+	 */
+	@Override
+	public boolean equals(Object o) {
+		if(o == this)
+			return true;
+		if(o == null || o.getClass() != getClass())
+			return false;
+		return Objects.equals(_rcdMaps, ((EncoderRecode) o)._rcdMaps);
+	}
+
+	/** Hash consistent with equals; same value as {@code Objects.hash(_rcdMaps)}. */
+	@Override
+	public int hashCode() {
+		return 31 + Objects.hashCode(_rcdMaps);
+	}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeDecodeTest.java
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeDecodeTest.java
index b92a32f..a6f98f9 100644
---
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeDecodeTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeDecodeTest.java
@@ -187,4 +187,4 @@ public class TransformFrameEncodeDecodeTest extends
AutomatedTestBase
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
}
-}
\ No newline at end of file
+}