This is an automated email from the ASF dual-hosted git repository. echauchot pushed a commit to branch spark-runner_structured-streaming in repository https://gitbox.apache.org/repos/asf/beam.git
commit 031754c73bcd72302f10730c8266e5ead0714bc1 Author: Etienne Chauchot <echauc...@apache.org> AuthorDate: Mon Aug 26 14:32:17 2019 +0200 Wrap Beam Coders into Spark Encoders using ExpressionEncoder: serialization part --- .../translation/helpers/EncoderHelpers.java | 245 +++++++++++++++++++++ 1 file changed, 245 insertions(+) diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java index d44fe27..b072803 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/structuredstreaming/translation/helpers/EncoderHelpers.java @@ -17,11 +17,40 @@ */ package org.apache.beam.runners.spark.structuredstreaming.translation.helpers; +import static org.apache.spark.sql.types.DataTypes.BinaryType; + +import java.lang.reflect.Array; +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import org.apache.beam.runners.spark.structuredstreaming.translation.SchemaHelpers; +import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.apache.spark.sql.Encoder; import org.apache.spark.sql.Encoders; +import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal; +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder; +import org.apache.spark.sql.catalyst.expressions.BoundReference; +import org.apache.spark.sql.catalyst.expressions.Cast; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.NonSQLExpression; +import org.apache.spark.sql.catalyst.expressions.UnaryExpression; +import org.apache.spark.sql.catalyst.expressions.codegen.Block; +import 
org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator; +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext; +import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode; +import org.apache.spark.sql.catalyst.expressions.codegen.ExprValue; +import org.apache.spark.sql.catalyst.expressions.codegen.SimpleExprValue; +import org.apache.spark.sql.catalyst.expressions.codegen.VariableValue; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.ObjectType; +import scala.StringContext; import scala.Tuple2; +import scala.collection.JavaConversions; +import scala.collection.Seq; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; /** {@link Encoders} utility class. */ public class EncoderHelpers { @@ -64,4 +93,220 @@ public class EncoderHelpers { --------- Bridges from Beam Coders to Spark Encoders */ + /** A way to construct encoders using generic serializers. */ + private <T> Encoder<T> fromBeamCoder(Coder<T> coder, Class<T> claz){ + + List<Expression> serialiserList = new ArrayList<>(); + serialiserList.add(new EncodeUsingBeamCoder<>(claz, coder)); + ClassTag<T> classTag = ClassTag$.MODULE$.apply(claz); + return new ExpressionEncoder<>( + SchemaHelpers.binarySchema(), + false, + JavaConversions.collectionAsScalaIterable(serialiserList).toSeq(), + new DecodeUsingBeamCoder<>(classTag, coder), classTag); + +/* + ExpressionEncoder[T]( + schema = new StructType().add("value", BinaryType), + flat = true, + serializer = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + deserializer = + DecodeUsingSerializer[T]( + Cast(GetColumnByOrdinal(0, BinaryType), BinaryType), + classTag[T], + kryo = useKryo), + clsTag = classTag[T] + ) +*/ + } + + private static class EncodeUsingBeamCoder<T> extends UnaryExpression implements NonSQLExpression { + + private Class<T> claz; + private Coder<T> beamCoder; + private Expression child; + + private 
EncodeUsingBeamCoder( Class<T> claz, Coder<T> beamCoder) { + this.claz = claz; + this.beamCoder = beamCoder; + this.child = new BoundReference(0, new ObjectType(claz), true); + } + + @Override public Expression child() { + return child; + } + + @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) { + // Code to serialize. + ExprCode input = child.genCode(ctx); + String javaType = CodeGenerator.javaType(dataType()); + String outputStream = "ByteArrayOutputStream baos = new ByteArrayOutputStream();"; + + String serialize = outputStream + "$beamCoder.encode(${input.value}, baos); baos.toByteArray();"; + + String outside = "final $javaType output = ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $serialize;"; + + List<String> instructions = new ArrayList<>(); + instructions.add(outside); + + Seq<String> parts = JavaConversions.collectionAsScalaIterable(instructions).toSeq(); + StringContext stringContext = new StringContext(parts); + Block.BlockHelper blockHelper = new Block.BlockHelper(stringContext); + List<Object> args = new ArrayList<>(); + args.add(new VariableValue("beamCoder", Coder.class)); + args.add(new SimpleExprValue("input.value", ExprValue.class)); + args.add(new VariableValue("javaType", String.class)); + args.add(new SimpleExprValue("input.isNull", Boolean.class)); + args.add(new SimpleExprValue("CodeGenerator.defaultValue(dataType)", String.class)); + args.add(new VariableValue("$serialize", String.class)); + Block code = blockHelper.code(JavaConversions.collectionAsScalaIterable(args).toSeq()); + + return ev.copy(input.code().$plus(code), input.isNull(), new VariableValue("output", Array.class)); + } + + @Override public DataType dataType() { + return BinaryType; + } + + @Override public Object productElement(int n) { + if (n == 0) { + return this; + } else { + throw new IndexOutOfBoundsException(String.valueOf(n)); + } + } + + @Override public int productArity() { + //TODO test with spark Encoders if the arity of 1 is ok + 
return 1; + } + + @Override public boolean canEqual(Object that) { + return (that instanceof EncodeUsingBeamCoder); + } + + @Override public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + EncodeUsingBeamCoder<?> that = (EncodeUsingBeamCoder<?>) o; + return claz.equals(that.claz) && beamCoder.equals(that.beamCoder); + } + + @Override public int hashCode() { + return Objects.hash(super.hashCode(), claz, beamCoder); + } + } + + /*case class EncodeUsingSerializer(child: Expression, kryo: Boolean) + extends UnaryExpression with NonSQLExpression with SerializerSupport { + + override def nullSafeEval(input: Any): Any = { + serializerInstance.serialize(input).array() + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val serializer = addImmutableSerializerIfNeeded(ctx) + // Code to serialize. + val input = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val serialize = s"$serializer.serialize(${input.value}, null).array()" + + val code = input.code + code""" + final $javaType ${ev.value} = + ${input.isNull} ?
${CodeGenerator.defaultValue(dataType)} : $serialize; + """ + ev.copy(code = code, isNull = input.isNull) + } + + override def dataType: DataType = BinaryType + }*/ + + private static class DecodeUsingBeamCoder<T> extends UnaryExpression implements NonSQLExpression{ + + private ClassTag<T> classTag; + private Coder<T> beamCoder; + + private DecodeUsingBeamCoder(ClassTag<T> classTag, Coder<T> beamCoder) { + this.classTag = classTag; + this.beamCoder = beamCoder; + } + + @Override public Expression child() { + return new Cast(new GetColumnByOrdinal(0, BinaryType), BinaryType); + } + + @Override public ExprCode doGenCode(CodegenContext ctx, ExprCode ev) { + return null; + } + + @Override public DataType dataType() { + return new ObjectType(classTag.runtimeClass()); + } + + @Override public Object productElement(int n) { + if (n == 0) { + return this; + } else { + throw new IndexOutOfBoundsException(String.valueOf(n)); + } + } + + @Override public int productArity() { + //TODO test with spark Encoders if the arity of 1 is ok + return 1; + } + + @Override public boolean canEqual(Object that) { + return (that instanceof DecodeUsingBeamCoder); + } + + @Override public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + DecodeUsingBeamCoder<?> that = (DecodeUsingBeamCoder<?>) o; + return classTag.equals(that.classTag) && beamCoder.equals(that.beamCoder); + } + + @Override public int hashCode() { + return Objects.hash(super.hashCode(), classTag, beamCoder); + } + } +/* +case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) + extends UnaryExpression with NonSQLExpression with SerializerSupport { + + override def nullSafeEval(input: Any): Any = { + val inputBytes = java.nio.ByteBuffer.wrap(input.asInstanceOf[Array[Byte]]) + serializerInstance.deserialize(inputBytes) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val 
serializer = addImmutableSerializerIfNeeded(ctx) + // Code to deserialize. + val input = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val deserialize = + s"($javaType) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null)" + + val code = input.code + code""" + final $javaType ${ev.value} = + ${input.isNull} ? ${CodeGenerator.defaultValue(dataType)} : $deserialize; + """ + ev.copy(code = code, isNull = input.isNull) + } + + override def dataType: DataType = ObjectType(tag.runtimeClass) + } +*/ + }