This is an automated email from the ASF dual-hosted git repository.

echauchot pushed a commit to branch spark-runner_structured-streaming
in repository https://gitbox.apache.org/repos/asf/beam.git

commit c4a44649013b04d04eeca3756ff181f59632fab7
Author: Etienne Chauchot <echauc...@apache.org>
AuthorDate: Thu Sep 5 14:20:30 2019 +0200

    Remove lazy init of beam coder because there is no generic way of 
instantiating a beam coder
---
 .../translation/helpers/EncoderHelpers.java        | 68 +++++++---------------
 .../structuredstreaming/utils/EncodersTest.java    |  2 +-
 2 files changed, 21 insertions(+), 49 deletions(-)

diff --git 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
index 0751c4c..3f7c102 100644
--- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
+++ 
b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java
@@ -18,7 +18,6 @@
 package org.apache.beam.runners.spark.structuredstreaming.translation.helpers;
 
 import static org.apache.spark.sql.types.DataTypes.BinaryType;
-import static scala.compat.java8.JFunction.func;
 
 import java.io.ByteArrayInputStream;
 import java.util.ArrayList;
@@ -26,7 +25,6 @@ import java.util.List;
 import java.util.Objects;
 import 
org.apache.beam.runners.spark.structuredstreaming.translation.SchemaHelpers;
 import org.apache.beam.sdk.coders.Coder;
-import org.apache.beam.sdk.coders.VarIntCoder;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
 import org.apache.spark.sql.Encoder;
@@ -92,17 +90,17 @@ public class EncoderHelpers {
   */
 
   /** A way to construct encoders using generic serializers. */
-  public static <T> Encoder<T> fromBeamCoder(Class<? extends Coder<T>> 
coderClass/*, Class<T> claz*/){
+  public static <T> Encoder<T> fromBeamCoder(Coder<T> beamCoder/*, Class<T> 
claz*/){
 
     List<Expression> serialiserList = new ArrayList<>();
     Class<T> claz = (Class<T>) Object.class;
-    serialiserList.add(new EncodeUsingBeamCoder<>(new BoundReference(0, new 
ObjectType(claz), true), (Class<Coder<T>>)coderClass));
+    serialiserList.add(new EncodeUsingBeamCoder<>(new BoundReference(0, new 
ObjectType(claz), true), beamCoder));
     ClassTag<T> classTag = ClassTag$.MODULE$.apply(claz);
     return new ExpressionEncoder<>(
         SchemaHelpers.binarySchema(),
         false,
         JavaConversions.collectionAsScalaIterable(serialiserList).toSeq(),
-        new DecodeUsingBeamCoder<>(new Cast(new GetColumnByOrdinal(0, 
BinaryType), BinaryType), classTag, (Class<Coder<T>>)coderClass),
+        new DecodeUsingBeamCoder<>(new Cast(new GetColumnByOrdinal(0, 
BinaryType), BinaryType), classTag, beamCoder),
         classTag);
 
 /*
@@ -125,11 +123,11 @@ public class EncoderHelpers {
   public static class EncodeUsingBeamCoder<T> extends UnaryExpression 
implements NonSQLExpression {
 
     private Expression child;
-    private Class<Coder<T>> coderClass;
+    private Coder<T> beamCoder;
 
-    public EncodeUsingBeamCoder(Expression child, Class<Coder<T>> coderClass) {
+    public EncodeUsingBeamCoder(Expression child, Coder<T> beamCoder) {
       this.child = child;
-      this.coderClass = coderClass;
+      this.beamCoder = beamCoder;
     }
 
     @Override public Expression child() {
@@ -138,7 +136,7 @@ public class EncoderHelpers {
 
     @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) {
       // Code to serialize.
-      String beamCoder = lazyInitBeamCoder(ctx, coderClass);
+      String accessCode = ctx.addReferenceObj("beamCoder", beamCoder, 
beamCoder.getClass().getName());
       ExprCode input = child.genCode(ctx);
 
       /*
@@ -172,7 +170,7 @@ public class EncoderHelpers {
       args.add(ev.value());
       args.add(input.isNull());
       args.add(ev.value());
-      args.add(beamCoder);
+      args.add(accessCode);
       args.add(input.value());
       args.add(ev.value());
       Block code = (new Block.BlockHelper(sc))
@@ -191,7 +189,7 @@ public class EncoderHelpers {
         case 0:
           return child;
         case 1:
-          return coderClass;
+          return beamCoder;
         default:
           throw new ArrayIndexOutOfBoundsException("productElement out of 
bounds");
       }
@@ -213,11 +211,11 @@ public class EncoderHelpers {
         return false;
       }
       EncodeUsingBeamCoder<?> that = (EncodeUsingBeamCoder<?>) o;
-      return coderClass.equals(that.coderClass);
+      return beamCoder.equals(that.beamCoder);
     }
 
     @Override public int hashCode() {
-      return Objects.hash(super.hashCode(), coderClass);
+      return Objects.hash(super.hashCode(), beamCoder);
     }
   }
 
@@ -249,12 +247,12 @@ public class EncoderHelpers {
 
     private Expression child;
     private ClassTag<T> classTag;
-    private Class<Coder<T>> coderClass;
+    private Coder<T> beamCoder;
 
-    public DecodeUsingBeamCoder(Expression child, ClassTag<T> classTag, 
Class<Coder<T>> coderClass) {
+    public DecodeUsingBeamCoder(Expression child, ClassTag<T> classTag, 
Coder<T> beamCoder) {
       this.child = child;
       this.classTag = classTag;
-      this.coderClass = coderClass;
+      this.beamCoder = beamCoder;
     }
 
     @Override public Expression child() {
@@ -263,7 +261,7 @@ public class EncoderHelpers {
 
     @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) {
       // Code to deserialize.
-      String beamCoder = lazyInitBeamCoder(ctx, coderClass);
+      String accessCode = ctx.addReferenceObj("beamCoder", beamCoder, 
beamCoder.getClass().getName());
       ExprCode input = child.genCode(ctx);
       String javaType = CodeGenerator.javaType(dataType());
 
@@ -298,7 +296,7 @@ public class EncoderHelpers {
       args.add(input.isNull());
       args.add(CodeGenerator.defaultValue(dataType(), false));
       args.add(javaType);
-      args.add(beamCoder);
+      args.add(accessCode);
       args.add(input.value());
       Block code = (new Block.BlockHelper(sc))
           .code(JavaConversions.collectionAsScalaIterable(args).toSeq());
@@ -309,10 +307,9 @@ public class EncoderHelpers {
 
     @Override public Object nullSafeEval(Object input) {
       try {
-        Coder<T> beamCoder = coderClass.getDeclaredConstructor().newInstance();
         return beamCoder.decode(new ByteArrayInputStream((byte[]) input));
       } catch (Exception e) {
-        throw new IllegalStateException("Error decoding bytes for coder: " + 
coderClass, e);
+        throw new IllegalStateException("Error decoding bytes for coder: " + 
beamCoder, e);
       }
     }
 
@@ -327,7 +324,7 @@ public class EncoderHelpers {
         case 1:
           return classTag;
         case 2:
-          return coderClass;
+          return beamCoder;
         default:
           throw new ArrayIndexOutOfBoundsException("productElement out of 
bounds");
       }
@@ -349,11 +346,11 @@ public class EncoderHelpers {
         return false;
       }
       DecodeUsingBeamCoder<?> that = (DecodeUsingBeamCoder<?>) o;
-      return classTag.equals(that.classTag) && 
coderClass.equals(that.coderClass);
+      return classTag.equals(that.classTag) && 
beamCoder.equals(that.beamCoder);
     }
 
     @Override public int hashCode() {
-      return Objects.hash(super.hashCode(), classTag, coderClass);
+      return Objects.hash(super.hashCode(), classTag, beamCoder);
     }
   }
 /*
@@ -384,30 +381,5 @@ case class DecodeUsingSerializer[T](child: Expression, 
tag: ClassTag[T], kryo: B
   }
 */
 
-  private static <T> String lazyInitBeamCoder(CodegenContext ctx, 
Class<Coder<T>> coderClass) {
-    String beamCoderInstance = "beamCoder";
-    ctx.addImmutableStateIfNotExists(coderClass.getName(), beamCoderInstance, 
func(v1 -> {
-      /*
-    CODE GENERATED
-    try {
-    v1 = coderClass.class.getDeclaredConstructor().newInstance();
-    } catch (Exception e) {
-      throw new 
RuntimeException(org.apache.beam.sdk.util.UserCodeException.wrap(e));
-    }
-     */
-      List<String> parts = new ArrayList<>();
-        parts.add("try {");
-        parts.add(" = (");
-      parts.add(") ");
-      parts.add(".class.getDeclaredConstructor().newInstance();} catch 
(Exception e) {throw new 
RuntimeException(org.apache.beam.sdk.util.UserCodeException.wrap(e));}");
-        StringContext sc = new 
StringContext(JavaConversions.collectionAsScalaIterable(parts).toSeq());
-        List<Object> args = new ArrayList<>();
-        args.add(v1);
-      args.add(coderClass.getName());
-      args.add(coderClass.getName());
-        return sc.s(JavaConversions.collectionAsScalaIterable(args).toSeq());
-      }));
-    return beamCoderInstance;
-  }
 
 }
diff --git 
a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java
 
b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java
index 0e38fe1..b3a6273 100644
--- 
a/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java
+++ 
b/runners/spark/src/test/java/org/apache/beam/runners/spark/structuredstreaming/utils/EncodersTest.java
@@ -24,7 +24,7 @@ public class EncodersTest {
     data.add(1);
     data.add(2);
     data.add(3);
-    sparkSession.createDataset(data, 
EncoderHelpers.fromBeamCoder(VarIntCoder.class));
+    sparkSession.createDataset(data, 
EncoderHelpers.fromBeamCoder(VarIntCoder.of()));
 //    sparkSession.createDataset(data, EncoderHelpers.genericEncoder());
   }
 }

Reply via email to