This is an automated email from the ASF dual-hosted git repository. echauchot pushed a commit to branch spark-runner_structured-streaming in repository https://gitbox.apache.org/repos/asf/beam.git
commit 5fa6331e0356953870e6ed614b0ce5e5c801fab1 Author: Etienne Chauchot <echauc...@apache.org> AuthorDate: Mon Aug 26 15:22:12 2019 +0200 Wrap Beam Coders into Spark Encoders using ExpressionEncoder: deserialization part + Fix EncoderHelpers.fromBeamCoder() visibility --- .../translation/helpers/EncoderHelpers.java | 64 ++++++++++++++++++---- 1 file changed, 52 insertions(+), 12 deletions(-) diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java index b072803..ab24e37 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java @@ -19,6 +19,8 @@ package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; import static org.apache.spark.sql.types.DataTypes.BinaryType; +import java.io.ByteArrayInputStream; +import java.io.IOException; import java.lang.reflect.Array; import java.util.ArrayList; import java.util.List; @@ -94,7 +96,7 @@ public class EncoderHelpers { */ /** A way to construct encoders using generic serializers. */ - private <T> Encoder<T> fromBeamCoder(Coder<T> coder, Class<T> claz){ + public static <T> Encoder<T> fromBeamCoder(Coder<T> coder, Class<T> claz){ List<Expression> serialiserList = new ArrayList<>(); serialiserList.add(new EncodeUsingBeamCoder<>(claz, coder)); @@ -103,7 +105,8 @@ public class EncoderHelpers { SchemaHelpers.binarySchema(), false, JavaConversions.collectionAsScalaIterable(serialiserList).toSeq(), - new DecodeUsingBeamCoder<>(classTag, coder), classTag); + new DecodeUsingBeamCoder<>(claz, coder), + classTag); /* ExpressionEncoder[T]( @@ -150,8 +153,8 @@ public class EncoderHelpers { List<String> instructions = new ArrayList<>(); instructions.add(outside); - Seq<String> parts = JavaConversions.collectionAsScalaIterable(instructions).toSeq(); + StringContext stringContext = new StringContext(parts); Block.BlockHelper blockHelper = new Block.BlockHelper(stringContext); List<Object> args = new ArrayList<>(); @@ -160,7 +163,7 @@ public class EncoderHelpers { args.add(new VariableValue("javaType", String.class)); args.add(new SimpleExprValue("input.isNull", Boolean.class)); args.add(new SimpleExprValue("CodeGenerator.defaultValue(dataType)", String.class)); - args.add(new VariableValue("$serialize", String.class)); + args.add(new VariableValue("serialize", String.class)); Block code = blockHelper.code(JavaConversions.collectionAsScalaIterable(args).toSeq()); return ev.copy(input.code().$plus(code), input.isNull(), new VariableValue("output", Array.class)); @@ -229,24 +232,61 @@ public class EncoderHelpers { private static class DecodeUsingBeamCoder<T> extends UnaryExpression implements NonSQLExpression{ - private ClassTag<T> classTag; + private Class<T> claz; private Coder<T> beamCoder; + private Expression child; - private DecodeUsingBeamCoder(ClassTag<T> classTag, Coder<T> beamCoder) { - this.classTag = classTag; + private DecodeUsingBeamCoder(Class<T> claz, Coder<T> beamCoder) { + this.claz = claz; this.beamCoder = beamCoder; + this.child = new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType); } @Override public Expression child() { - return new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType); + return child; } @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) { - return null; + // Code to deserialize. + ExprCode input = child.genCode(ctx); + String javaType = CodeGenerator.javaType(dataType()); + + String inputStream = "ByteArrayInputStream bais = new ByteArrayInputStream(${input.value});"; + String deserialize = inputStream + "($javaType) $beamCoder.decode(bais);"; + + String outside = "final $javaType output = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize;"; + + List<String> instructions = new ArrayList<>(); + instructions.add(outside); + Seq<String> parts = JavaConversions.collectionAsScalaIterable(instructions).toSeq(); + + StringContext stringContext = new StringContext(parts); + Block.BlockHelper blockHelper = new Block.BlockHelper(stringContext); + List<Object> args = new ArrayList<>(); + args.add(new SimpleExprValue("input.value", ExprValue.class)); + args.add(new VariableValue("javaType", String.class)); + args.add(new VariableValue("beamCoder", Coder.class)); + args.add(new SimpleExprValue("input.isNull", Boolean.class)); + args.add(new SimpleExprValue("CodeGenerator.defaultValue(dataType)", String.class)); + args.add(new VariableValue("deserialize", String.class)); + Block code = blockHelper.code(JavaConversions.collectionAsScalaIterable(args).toSeq()); + + return ev.copy(input.code().$plus(code), input.isNull(), new VariableValue("output", claz)); + + } + + @Override public Object nullSafeEval(Object input) { + try { + return beamCoder.decode(new ByteArrayInputStream((byte[]) input)); + } catch (IOException e) { + throw new IllegalStateException("Error decoding bytes for coder: " + beamCoder, e); + } } @Override public DataType dataType() { - return new ObjectType(classTag.runtimeClass()); +// return new ObjectType(classTag.runtimeClass()); + //TODO does type erasure impose to use classTag.runtimeClass() ? + return new ObjectType(claz); } @Override public Object productElement(int n) { @@ -274,11 +314,11 @@ public class EncoderHelpers { return false; } DecodeUsingBeamCoder<?> that = (DecodeUsingBeamCoder<?>) o; - return classTag.equals(that.classTag) && beamCoder.equals(that.beamCoder); + return claz.equals(that.claz) && beamCoder.equals(that.beamCoder); } @Override public int hashCode() { - return Objects.hash(super.hashCode(), classTag, beamCoder); + return Objects.hash(super.hashCode(), claz, beamCoder); } } /*