This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 5fa6331e0356953870e6ed614b0ce5e5c801fab1
Author: Etienne Chauchot <echauc...@apache.org>
AuthorDate: Mon Aug 26 15:22:12 2019 +0200

    Wrap Beam Coders into Spark Encoders using ExpressionEncoder: deserialization part
    
    + Fix EncoderHelpers.fromBeamCoder() visibility
---
 .../translation/helpers/EncoderHelpers.java        | 64 ++++++++++++++++++----
 1 file changed, 52 insertions(+), 12 deletions(-)

diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
index b072803..ab24e37 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
@@ -19,6 +19,8 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
 import static org.apache.spark.sql.types.DataTypes.BinaryType;
 
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
 import java.lang.reflect.Array;
 import java.util.ArrayList;
 import java.util.List;
@@ -94,7 +96,7 @@ public class EncoderHelpers {
   */
 
   /** A way to construct encoders using generic serializers. */
-  private <T> Encoder<T> fromBeamCoder(Coder<T> coder, Class<T> claz){
+  public static <T> Encoder<T> fromBeamCoder(Coder<T> coder, Class<T> claz){
 
     List<Expression> serialiserList = new ArrayList<>();
     serialiserList.add(new EncodeUsingBeamCoder<>(claz, coder));
@@ -103,7 +105,8 @@ public class EncoderHelpers {
         SchemaHelpers.binarySchema(),
         false,
         JavaConversions.collectionAsScalaIterable(serialiserList).toSeq(),
-        new DecodeUsingBeamCoder<>(classTag, coder), classTag);
+        new DecodeUsingBeamCoder<>(claz, coder),
+        classTag);
 
 /*
     ExpressionEncoder[T](
@@ -150,8 +153,8 @@ public class EncoderHelpers {
 
       List<String> instructions = new ArrayList<>();
       instructions.add(outside);
-
       Seq<String> parts = 
JavaConversions.collectionAsScalaIterable(instructions).toSeq();
+
       StringContext stringContext = new StringContext(parts);
       Block.BlockHelper blockHelper = new Block.BlockHelper(stringContext);
       List<Object> args = new ArrayList<>();
@@ -160,7 +163,7 @@ public class EncoderHelpers {
       args.add(new VariableValue("javaType", String.class));
       args.add(new SimpleExprValue("input.isNull", Boolean.class));
       args.add(new SimpleExprValue("CodeGenerator.defaultValue(dataType)", 
String.class));
-      args.add(new VariableValue("$serialize", String.class));
+      args.add(new VariableValue("serialize", String.class));
       Block code = 
blockHelper.code(JavaConversions.collectionAsScalaIterable(args).toSeq());
 
       return ev.copy(input.code().$plus(code), input.isNull(), new 
VariableValue("output", Array.class));
@@ -229,24 +232,61 @@ public class EncoderHelpers {
 
   private static class DecodeUsingBeamCoder<T> extends UnaryExpression 
implements  NonSQLExpression{
 
-    private ClassTag<T> classTag;
+    private Class<T> claz;
     private Coder<T> beamCoder;
+    private Expression child;
 
-    private DecodeUsingBeamCoder(ClassTag<T> classTag, Coder<T> beamCoder) {
-      this.classTag = classTag;
+    private DecodeUsingBeamCoder(Class<T> claz, Coder<T> beamCoder) {
+      this.claz = claz;
       this.beamCoder = beamCoder;
+      this.child = new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType);
     }
 
     @Override public Expression child() {
-      return new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType);
+      return child;
     }
 
     @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) {
-      return null;
+      // Code to deserialize.
+      ExprCode input = child.genCode(ctx);
+      String javaType = CodeGenerator.javaType(dataType());
+
+      String inputStream = "ByteArrayInputStream bais = new 
ByteArrayInputStream(${input.value});";
+      String deserialize = inputStream + "($javaType) 
$beamCoder.decode(bais);";
+
+      String outside = "final $javaType output = ${input.isNull} ? 
${CodeGenerator.defaultValue(dataType)} : $deserialize;";
+
+      List<String> instructions = new ArrayList<>();
+      instructions.add(outside);
+      Seq<String> parts = 
JavaConversions.collectionAsScalaIterable(instructions).toSeq();
+
+      StringContext stringContext = new StringContext(parts);
+      Block.BlockHelper blockHelper = new Block.BlockHelper(stringContext);
+      List<Object> args = new ArrayList<>();
+      args.add(new SimpleExprValue("input.value", ExprValue.class));
+      args.add(new VariableValue("javaType", String.class));
+      args.add(new VariableValue("beamCoder", Coder.class));
+      args.add(new SimpleExprValue("input.isNull", Boolean.class));
+      args.add(new SimpleExprValue("CodeGenerator.defaultValue(dataType)", 
String.class));
+      args.add(new VariableValue("deserialize", String.class));
+      Block code = 
blockHelper.code(JavaConversions.collectionAsScalaIterable(args).toSeq());
+
+      return ev.copy(input.code().$plus(code), input.isNull(), new 
VariableValue("output", claz));
+
+    }
+
+    @Override public Object nullSafeEval(Object input) {
+      try {
+        return beamCoder.decode(new ByteArrayInputStream((byte[]) input));
+      } catch (IOException e) {
+        throw new IllegalStateException("Error decoding bytes for coder: " + 
beamCoder, e);
+      }
     }
 
     @Override public DataType dataType() {
-      return new ObjectType(classTag.runtimeClass());
+//      return new ObjectType(classTag.runtimeClass());
+      //TODO does type erasure impose to use classTag.runtimeClass() ?
+      return new ObjectType(claz);
     }
 
     @Override public Object productElement(int n) {
@@ -274,11 +314,11 @@ public class EncoderHelpers {
         return false;
       }
       DecodeUsingBeamCoder<?> that = (DecodeUsingBeamCoder<?>) o;
-      return classTag.equals(that.classTag) && 
beamCoder.equals(that.beamCoder);
+      return claz.equals(that.claz) && beamCoder.equals(that.beamCoder);
     }
 
     @Override public int hashCode() {
-      return Objects.hash(super.hashCode(), classTag, beamCoder);
+      return Objects.hash(super.hashCode(), claz, beamCoder);
     }
   }
 /*

Reply via email to