This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 96fc294 [SYSTEMDS-2878] test encoder serialization
96fc294 is described below
commit 96fc29434ff9877fbf170c09e58666d8040e6727
Author: Olga <[email protected]>
AuthorDate: Thu Feb 25 17:48:49 2021 +0200
[SYSTEMDS-2878] test encoder serialization
Closes #1192
---
.../transform/EncoderSerializationTest.java | 139 +++++++++++++++++++++
1 file changed, 139 insertions(+)
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/EncoderSerializationTest.java b/src/test/java/org/apache/sysds/test/functions/transform/EncoderSerializationTest.java
new file mode 100644
index 0000000..324c575
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/transform/EncoderSerializationTest.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.util.List;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.transform.encode.Encoder;
+import org.apache.sysds.runtime.transform.encode.EncoderComposite;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Round-trip test for Java serialization of transform encoders.
+ *
+ * Each test case creates an encoder from a JSON transform spec via
+ * EncoderFactory.createEncoder, writes it through an ObjectOutputStream,
+ * reads it back through an ObjectInputStream, and compares the column lists
+ * and column counts of the composite encoder and of every contained encoder.
+ *
+ * NOTE(review): the encoder is serialized exactly as created from the spec;
+ * no build/fit step against the frame data is performed in this test, so
+ * only spec-derived state is compared.
+ */
+public class EncoderSerializationTest extends AutomatedTestBase
+{
+    private final static int rows = 2791;
+    private final static int cols = 8;
+
+    // schema with eight STRING columns
+    private final static Types.ValueType[] schemaStrings = new Types.ValueType[]{
+        Types.ValueType.STRING, Types.ValueType.STRING, Types.ValueType.STRING, Types.ValueType.STRING,
+        Types.ValueType.STRING, Types.ValueType.STRING, Types.ValueType.STRING, Types.ValueType.STRING};
+    // schema cycling through STRING, FP64, INT64, BOOLEAN (twice)
+    private final static Types.ValueType[] schemaMixed = new Types.ValueType[]{
+        Types.ValueType.STRING, Types.ValueType.FP64, Types.ValueType.INT64, Types.ValueType.BOOLEAN,
+        Types.ValueType.STRING, Types.ValueType.FP64, Types.ValueType.INT64, Types.ValueType.BOOLEAN};
+
+
+    /**
+     * Selects which transform spec is used to create the encoder:
+     * RECODE uses a "recode" spec, DUMMY uses a "dummycode" spec.
+     */
+    public enum TransformType {
+        RECODE,
+        DUMMY
+    }
+
+    @Override
+    public void setUp() {
+        TestUtils.clearAssertionInformation();
+    }
+
+    @Test
+    public void testComposite1() {
+        runTransformSerTest(TransformType.DUMMY, schemaStrings);
+    }
+
+    @Test
+    public void testComposite2() {
+        runTransformSerTest(TransformType.RECODE, schemaMixed);
+    }
+
+    @Test
+    public void testComposite3() {
+        runTransformSerTest(TransformType.RECODE, schemaStrings);
+    }
+
+    @Test
+    public void testComposite4() {
+        runTransformSerTest(TransformType.DUMMY, schemaMixed);
+    }
+
+
+    /**
+     * Creates an encoder for the given transform type, serializes and
+     * deserializes it, and asserts that column lists and column counts match
+     * between the original and the restored encoder (composite and children).
+     *
+     * @param type   transform spec variant used to create the encoder
+     * @param schema value types of the frame columns
+     */
+    private void runTransformSerTest(TransformType type, Types.ValueType[] schema) {
+        //data generation
+        // presumably value range [-10,10], sparsity 0.9, seed 8234 - confirm against AutomatedTestBase
+        double[][] A = getRandomMatrix(rows, cols, -10, 10, 0.9, 8234);
+
+        //init data frame
+        FrameBlock frame = new FrameBlock(schema);
+
+        //fill data frame row by row; each cell is converted to the column's
+        //value type and the converted value is written back into A
+        Object[] row = new Object[schema.length];
+        for( int i=0; i < rows; i++) {
+            for( int j=0; j<schema.length; j++ ) {
+                row[j] = UtilFunctions.doubleToObject(schema[j], A[i][j]);
+                A[i][j] = UtilFunctions.objectToDouble(schema[j], row[j]);
+            }
+            frame.appendRow(row);
+        }
+
+        String spec = "";
+        if(type == TransformType.DUMMY)
+            spec = "{\n \"ids\": true\n, \"dummycode\":[ 2, 7, 8, 1 ]\n\n}";
+        else if(type == TransformType.RECODE)
+            spec = "{\n \"ids\": true\n, \"recode\":[ 2, 7, 1, 8 ]\n\n}";
+
+
+        frame.setSchema(schema);
+        String[] cnames = frame.getColumnNames();
+
+        Encoder encoderIn = EncoderFactory.createEncoder(spec, cnames, frame.getNumColumns(), null);
+        EncoderComposite encoderOut;
+
+        // serialization and deserialization
+        encoderOut = (EncoderComposite) serializeDeserialize(encoderIn);
+        // compare
+        Assert.assertArrayEquals(encoderIn.getColList(), encoderOut.getColList());
+        Assert.assertEquals(encoderIn.getNumCols(), encoderOut.getNumCols());
+
+        List<Encoder> eListIn = ((EncoderComposite) encoderIn).getEncoders();
+        List<Encoder> eListOut = encoderOut.getEncoders();
+        // guard the index-based comparison: a differing number of contained
+        // encoders must fail the test instead of passing or throwing
+        Assert.assertEquals(eListIn.size(), eListOut.size());
+        for(int i = 0; i < eListIn.size(); i++) {
+            Assert.assertArrayEquals(eListIn.get(i).getColList(), eListOut.get(i).getColList());
+            Assert.assertEquals(eListIn.get(i).getNumCols(), eListOut.get(i).getNumCols());
+        }
+    }
+
+    /**
+     * Writes the encoder to an in-memory byte array via Java serialization
+     * and reads it back.
+     *
+     * @param encoderIn encoder to serialize
+     * @return the deserialized encoder
+     * @throws AssertionError if writing or reading the object fails; the
+     *         original exception is attached as the cause
+     */
+    private Encoder serializeDeserialize(Encoder encoderIn) {
+        try {
+            byte[] encoderBytes;
+            try(ByteArrayOutputStream bos = new ByteArrayOutputStream();
+                ObjectOutputStream oos = new ObjectOutputStream(bos))
+            {
+                oos.writeObject(encoderIn);
+                oos.flush();
+                encoderBytes = bos.toByteArray();
+            }
+
+            try(ObjectInput in = new ObjectInputStream(new ByteArrayInputStream(encoderBytes))) {
+                return (Encoder) in.readObject();
+            }
+        }
+        catch(IOException | ClassNotFoundException e) {
+            // fail the test with the real cause instead of returning null,
+            // which would only show up as a NullPointerException in the caller
+            throw new AssertionError("Encoder serialization round trip failed: " + e.getMessage(), e);
+        }
+    }
+}
\ No newline at end of file