Repository: flink
Updated Branches:
  refs/heads/master 0e92b6632 -> 0ba528c71


[FLINK-7902] Use TypeSerializer in TwoPhaseCommitSinkFunctions

We use custom serializers to ensure that we have control over the
serialization format, which allows us easier evolution of the format in
the future.

This also implements custom serializers for FlinkKafkaProducer011, the only
TwoPhaseCommitSinkFunction we currently have.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/0ba528c7
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/0ba528c7
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/0ba528c7

Branch: refs/heads/master
Commit: 0ba528c71e35858a043bd513ead37800262f7e0c
Parents: 944a63c
Author: Aljoscha Krettek <[email protected]>
Authored: Thu Oct 26 14:39:25 2017 +0200
Committer: Aljoscha Krettek <[email protected]>
Committed: Wed Nov 1 09:03:53 2017 +0100

----------------------------------------------------------------------
 .../flink-connector-kafka-0.11/pom.xml          |   9 +
 .../connectors/kafka/FlinkKafkaProducer011.java | 248 +++++++++++++++-
 ...linkKafkaProducer011StateSerializerTest.java | 106 +++++++
 .../sink/TwoPhaseCommitSinkFunction.java        | 281 ++++++++++++++++++-
 .../sink/TwoPhaseCommitSinkFunctionTest.java    |   9 +-
 5 files changed, 628 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/0ba528c7/flink-connectors/flink-connector-kafka-0.11/pom.xml
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kafka-0.11/pom.xml 
b/flink-connectors/flink-connector-kafka-0.11/pom.xml
index c41f697..e58d36d 100644
--- a/flink-connectors/flink-connector-kafka-0.11/pom.xml
+++ b/flink-connectors/flink-connector-kafka-0.11/pom.xml
@@ -87,6 +87,15 @@ under the License.
 
                <dependency>
                        <groupId>org.apache.flink</groupId>
+                       <artifactId>flink-core</artifactId>
+                       <version>${project.version}</version>
+                       <scope>test</scope>
+                       <type>test-jar</type>
+               </dependency>
+
+
+               <dependency>
+                       <groupId>org.apache.flink</groupId>
                        
<artifactId>flink-streaming-java_${scala.binary.version}</artifactId>
                        <version>${project.version}</version>
                        <scope>test</scope>

http://git-wip-us.apache.org/repos/asf/flink/blob/0ba528c7/flink-connectors/flink-connector-kafka-0.11/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kafka-0.11/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011.java
 
b/flink-connectors/flink-connector-kafka-0.11/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011.java
index 593e002..5f557d2 100644
--- 
a/flink-connectors/flink-connector-kafka-0.11/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011.java
+++ 
b/flink-connectors/flink-connector-kafka-0.11/src/main/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011.java
@@ -17,14 +17,18 @@
 
 package org.apache.flink.streaming.connectors.kafka;
 
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.time.Time;
-import org.apache.flink.api.common.typeinfo.TypeHint;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.TypeSerializerSingleton;
 import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.metrics.MetricGroup;
 import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.runtime.state.FunctionSnapshotContext;
@@ -411,7 +415,7 @@ public class FlinkKafkaProducer011<IN>
                        Optional<FlinkKafkaPartitioner<IN>> customPartitioner,
                        Semantic semantic,
                        int kafkaProducersPoolSize) {
-               super(TypeInformation.of(new 
TypeHint<State<KafkaTransactionState, KafkaTransactionContext>>() {}));
+               super(new TransactionStateSerializer(), new 
ContextStateSerializer());
 
                this.defaultTopicId = checkNotNull(defaultTopicId, 
"defaultTopicId is null");
                this.schema = checkNotNull(serializationSchema, 
"serializationSchema is null");
@@ -958,6 +962,8 @@ public class FlinkKafkaProducer011<IN>
        /**
         * State for handling transactions.
         */
+       @VisibleForTesting
+       @Internal
        static class KafkaTransactionState {
 
                private final transient FlinkKafkaProducer<byte[], byte[]> 
producer;
@@ -970,38 +976,260 @@ public class FlinkKafkaProducer011<IN>
                final short epoch;
 
                KafkaTransactionState(String transactionalId, 
FlinkKafkaProducer<byte[], byte[]> producer) {
-                       this.producer = producer;
-                       this.transactionalId = transactionalId;
-                       this.producerId = producer.getProducerId();
-                       this.epoch = producer.getEpoch();
+                       this(transactionalId, producer.getProducerId(), 
producer.getEpoch(), producer);
                }
 
                KafkaTransactionState(FlinkKafkaProducer<byte[], byte[]> 
producer) {
+                       this(null, -1, (short) -1, producer);
+               }
+
+               KafkaTransactionState(
+                               String transactionalId,
+                               long producerId,
+                               short epoch,
+                               FlinkKafkaProducer<byte[], byte[]> producer) {
+                       this.transactionalId = transactionalId;
+                       this.producerId = producerId;
+                       this.epoch = epoch;
                        this.producer = producer;
-                       this.transactionalId = null;
-                       this.producerId = -1;
-                       this.epoch = -1;
                }
 
                @Override
                public String toString() {
                        return String.format("%s [transactionalId=%s]", 
this.getClass().getSimpleName(), transactionalId);
                }
+
+               @Override
+               public boolean equals(Object o) {
+                       if (this == o) {
+                               return true;
+                       }
+                       if (o == null || getClass() != o.getClass()) {
+                               return false;
+                       }
+
+                       KafkaTransactionState that = (KafkaTransactionState) o;
+
+                       if (producerId != that.producerId) {
+                               return false;
+                       }
+                       if (epoch != that.epoch) {
+                               return false;
+                       }
+                       return transactionalId != null ? 
transactionalId.equals(that.transactionalId) : that.transactionalId == null;
+               }
+
+               @Override
+               public int hashCode() {
+                       int result = transactionalId != null ? 
transactionalId.hashCode() : 0;
+                       result = 31 * result + (int) (producerId ^ (producerId 
>>> 32));
+                       result = 31 * result + (int) epoch;
+                       return result;
+               }
        }
 
        /**
         * Context associated to this instance of the {@link 
FlinkKafkaProducer011}. User for keeping track of the
         * transactionalIds.
         */
-       static class KafkaTransactionContext {
+       @VisibleForTesting
+       @Internal
+       public static class KafkaTransactionContext {
                final Set<String> transactionalIds;
 
                KafkaTransactionContext(Set<String> transactionalIds) {
+                       checkNotNull(transactionalIds);
                        this.transactionalIds = transactionalIds;
                }
+
+               @Override
+               public boolean equals(Object o) {
+                       if (this == o) {
+                               return true;
+                       }
+                       if (o == null || getClass() != o.getClass()) {
+                               return false;
+                       }
+
+                       KafkaTransactionContext that = 
(KafkaTransactionContext) o;
+
+                       return transactionalIds.equals(that.transactionalIds);
+               }
+
+               @Override
+               public int hashCode() {
+                       return transactionalIds.hashCode();
+               }
+       }
+
+       /**
+        * {@link org.apache.flink.api.common.typeutils.TypeSerializer} for
+        * {@link KafkaTransactionState}.
+        */
+       @VisibleForTesting
+       @Internal
+       public static class TransactionStateSerializer extends 
TypeSerializerSingleton<KafkaTransactionState> {
+
+               private static final long serialVersionUID = 1L;
+
+               @Override
+               public boolean isImmutableType() {
+                       return true;
+               }
+
+               @Override
+               public KafkaTransactionState createInstance() {
+                       return null;
+               }
+
+               @Override
+               public KafkaTransactionState copy(KafkaTransactionState from) {
+                       return from;
+               }
+
+               @Override
+               public KafkaTransactionState copy(
+                       KafkaTransactionState from,
+                       KafkaTransactionState reuse) {
+                       return from;
+               }
+
+               @Override
+               public int getLength() {
+                       return -1;
+               }
+
+               @Override
+               public void serialize(
+                               KafkaTransactionState record,
+                               DataOutputView target) throws IOException {
+                       if (record.transactionalId == null) {
+                               target.writeBoolean(false);
+                       } else {
+                               target.writeBoolean(true);
+                               target.writeUTF(record.transactionalId);
+                       }
+                       target.writeLong(record.producerId);
+                       target.writeShort(record.epoch);
+               }
+
+               @Override
+               public KafkaTransactionState deserialize(DataInputView source) 
throws IOException {
+                       String transactionalId = null;
+                       if (source.readBoolean()) {
+                               transactionalId = source.readUTF();
+                       }
+                       long producerId = source.readLong();
+                       short epoch = source.readShort();
+                       return new KafkaTransactionState(transactionalId, 
producerId, epoch, null);
+               }
+
+               @Override
+               public KafkaTransactionState deserialize(
+                               KafkaTransactionState reuse,
+                               DataInputView source) throws IOException {
+                       return deserialize(source);
+               }
+
+               @Override
+               public void copy(
+                               DataInputView source, DataOutputView target) 
throws IOException {
+                       boolean hasTransactionalId = source.readBoolean();
+                       target.writeBoolean(hasTransactionalId);
+                       if (hasTransactionalId) {
+                               target.writeUTF(source.readUTF());
+                       }
+                       target.writeLong(source.readLong());
+                       target.writeShort(source.readShort());
+               }
+
+               @Override
+               public boolean canEqual(Object obj) {
+                       return obj instanceof TransactionStateSerializer;
+               }
        }
 
+       /**
+        * {@link org.apache.flink.api.common.typeutils.TypeSerializer} for
+        * {@link KafkaTransactionContext}.
+        */
+       @VisibleForTesting
+       @Internal
+       public static class ContextStateSerializer extends 
TypeSerializerSingleton<KafkaTransactionContext> {
 
+               private static final long serialVersionUID = 1L;
+
+               @Override
+               public boolean isImmutableType() {
+                       return true;
+               }
+
+               @Override
+               public KafkaTransactionContext createInstance() {
+                       return null;
+               }
+
+               @Override
+               public KafkaTransactionContext copy(KafkaTransactionContext 
from) {
+                       return from;
+               }
+
+               @Override
+               public KafkaTransactionContext copy(
+                               KafkaTransactionContext from,
+                               KafkaTransactionContext reuse) {
+                       return from;
+               }
+
+               @Override
+               public int getLength() {
+                       return -1;
+               }
+
+               @Override
+               public void serialize(
+                               KafkaTransactionContext record,
+                               DataOutputView target) throws IOException {
+                       int numIds = record.transactionalIds.size();
+                       target.writeInt(numIds);
+                       for (String id : record.transactionalIds) {
+                               target.writeUTF(id);
+                       }
+               }
+
+               @Override
+               public KafkaTransactionContext deserialize(DataInputView 
source) throws IOException {
+                       int numIds = source.readInt();
+                       Set<String> ids = new HashSet<>(numIds);
+                       for (int i = 0; i < numIds; i++) {
+                               ids.add(source.readUTF());
+                       }
+                       return new KafkaTransactionContext(ids);
+               }
+
+               @Override
+               public KafkaTransactionContext deserialize(
+                               KafkaTransactionContext reuse,
+                               DataInputView source) throws IOException {
+                       return deserialize(source);
+               }
+
+               @Override
+               public void copy(
+                               DataInputView source,
+                               DataOutputView target) throws IOException {
+                       int numIds = source.readInt();
+                       target.writeInt(numIds);
+                       for (int i = 0; i < numIds; i++) {
+                               target.writeUTF(source.readUTF());
+                       }
+               }
+
+               @Override
+               public boolean canEqual(Object obj) {
+                       return obj instanceof ContextStateSerializer;
+               }
+       }
 
        static class ProducersPool implements Closeable {
                private final LinkedBlockingDeque<FlinkKafkaProducer<byte[], 
byte[]>> pool = new LinkedBlockingDeque<>();

http://git-wip-us.apache.org/repos/asf/flink/blob/0ba528c7/flink-connectors/flink-connector-kafka-0.11/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011StateSerializerTest.java
----------------------------------------------------------------------
diff --git 
a/flink-connectors/flink-connector-kafka-0.11/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011StateSerializerTest.java
 
b/flink-connectors/flink-connector-kafka-0.11/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011StateSerializerTest.java
new file mode 100644
index 0000000..c6a873b
--- /dev/null
+++ 
b/flink-connectors/flink-connector-kafka-0.11/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaProducer011StateSerializerTest.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kafka;
+
+import org.apache.flink.api.common.typeutils.SerializerTestBase;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import 
org.apache.flink.streaming.api.functions.sink.TwoPhaseCommitSinkFunction;
+
+import java.util.Collections;
+import java.util.Optional;
+
+/**
+ * A test for the {@link TypeSerializer TypeSerializers} used for the Kafka 
producer state.
+ */
+public class FlinkKafkaProducer011StateSerializerTest
+       extends SerializerTestBase<
+               TwoPhaseCommitSinkFunction.State<
+                       FlinkKafkaProducer011.KafkaTransactionState,
+                       FlinkKafkaProducer011.KafkaTransactionContext>> {
+
+       @Override
+       protected TypeSerializer<
+               TwoPhaseCommitSinkFunction.State<
+                       FlinkKafkaProducer011.KafkaTransactionState,
+                       FlinkKafkaProducer011.KafkaTransactionContext>> 
createSerializer() {
+               return new TwoPhaseCommitSinkFunction.StateSerializer<>(
+                       new FlinkKafkaProducer011.TransactionStateSerializer(),
+                       new FlinkKafkaProducer011.ContextStateSerializer());
+       }
+
+       @Override
+       protected Class<TwoPhaseCommitSinkFunction.State<
+                       FlinkKafkaProducer011.KafkaTransactionState,
+                       FlinkKafkaProducer011.KafkaTransactionContext>> 
getTypeClass() {
+               return (Class) TwoPhaseCommitSinkFunction.State.class;
+       }
+
+       @Override
+       protected int getLength() {
+               return -1;
+       }
+
+       @Override
+       protected TwoPhaseCommitSinkFunction.State<
+               FlinkKafkaProducer011.KafkaTransactionState,
+               FlinkKafkaProducer011.KafkaTransactionContext>[] getTestData() {
+               return new TwoPhaseCommitSinkFunction.State[] {
+                       new TwoPhaseCommitSinkFunction.State<
+                               FlinkKafkaProducer011.KafkaTransactionState,
+                               FlinkKafkaProducer011.KafkaTransactionContext>(
+                                       new 
FlinkKafkaProducer011.KafkaTransactionState("fake", 1L, (short) 42, null),
+                                       Collections.emptyList(),
+                                       Optional.empty()),
+                       new TwoPhaseCommitSinkFunction.State<
+                               FlinkKafkaProducer011.KafkaTransactionState,
+                               FlinkKafkaProducer011.KafkaTransactionContext>(
+                               new 
FlinkKafkaProducer011.KafkaTransactionState("fake", 1L, (short) 42, null),
+                               Collections.singletonList(new 
FlinkKafkaProducer011.KafkaTransactionState("fake", 1L, (short) 42, null)),
+                               Optional.empty()),
+                       new TwoPhaseCommitSinkFunction.State<
+                               FlinkKafkaProducer011.KafkaTransactionState,
+                               FlinkKafkaProducer011.KafkaTransactionContext>(
+                               new 
FlinkKafkaProducer011.KafkaTransactionState("fake", 1L, (short) 42, null),
+                               Collections.emptyList(),
+                               Optional.of(new 
FlinkKafkaProducer011.KafkaTransactionContext(Collections.emptySet()))),
+                       new TwoPhaseCommitSinkFunction.State<
+                               FlinkKafkaProducer011.KafkaTransactionState,
+                               FlinkKafkaProducer011.KafkaTransactionContext>(
+                               new 
FlinkKafkaProducer011.KafkaTransactionState("fake", 1L, (short) 42, null),
+                               Collections.emptyList(),
+                               Optional.of(new 
FlinkKafkaProducer011.KafkaTransactionContext(Collections.singleton("hello")))),
+                       new TwoPhaseCommitSinkFunction.State<
+                               FlinkKafkaProducer011.KafkaTransactionState,
+                               FlinkKafkaProducer011.KafkaTransactionContext>(
+                               new 
FlinkKafkaProducer011.KafkaTransactionState("fake", 1L, (short) 42, null),
+                               Collections.singletonList(new 
FlinkKafkaProducer011.KafkaTransactionState("fake", 1L, (short) 42, null)),
+                               Optional.of(new 
FlinkKafkaProducer011.KafkaTransactionContext(Collections.emptySet()))),
+                       new TwoPhaseCommitSinkFunction.State<
+                               FlinkKafkaProducer011.KafkaTransactionState,
+                               FlinkKafkaProducer011.KafkaTransactionContext>(
+                               new 
FlinkKafkaProducer011.KafkaTransactionState("fake", 1L, (short) 42, null),
+                               Collections.singletonList(new 
FlinkKafkaProducer011.KafkaTransactionState("fake", 1L, (short) 42, null)),
+                               Optional.of(new 
FlinkKafkaProducer011.KafkaTransactionContext(Collections.singleton("hello"))))
+               };
+       }
+
+       @Override
+       public void testInstantiate() {
+               // this serializer does not support instantiation
+       }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/0ba528c7/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunction.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunction.java
 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunction.java
index 8c11753..952f298 100644
--- 
a/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunction.java
+++ 
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunction.java
@@ -17,11 +17,23 @@
 
 package org.apache.flink.streaming.api.functions.sink;
 
+import org.apache.flink.annotation.Internal;
 import org.apache.flink.annotation.PublicEvolving;
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeHint;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.CompatibilityResult;
+import org.apache.flink.api.common.typeutils.CompatibilityUtil;
+import 
org.apache.flink.api.common.typeutils.CompositeTypeSerializerConfigSnapshot;
+import org.apache.flink.api.common.typeutils.TypeDeserializerAdapter;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot;
+import org.apache.flink.api.common.typeutils.UnloadableDummyTypeSerializer;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.runtime.state.CheckpointListener;
 import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.runtime.state.FunctionSnapshotContext;
@@ -32,6 +44,7 @@ import org.slf4j.LoggerFactory;
 
 import javax.annotation.Nullable;
 
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Iterator;
 import java.util.LinkedHashMap;
@@ -40,6 +53,7 @@ import java.util.Map;
 import java.util.Optional;
 
 import static java.util.Objects.requireNonNull;
+import static org.apache.flink.util.Preconditions.checkNotNull;
 import static org.apache.flink.util.Preconditions.checkState;
 
 /**
@@ -79,19 +93,18 @@ public abstract class TwoPhaseCommitSinkFunction<IN, TXN, 
CONTEXT>
         * TwoPhaseCommitSinkFunction(TypeInformation.of(new 
TypeHint<State<TXN, CONTEXT>>() {}));
         * }
         * </pre>
-        * @param stateTypeInformation {@link TypeInformation} for POJO holding 
state of opened transactions.
-        */
-       public TwoPhaseCommitSinkFunction(TypeInformation<State<TXN, CONTEXT>> 
stateTypeInformation) {
-               this(new ListStateDescriptor<State<TXN, CONTEXT>>("state", 
stateTypeInformation));
-       }
-
-       /**
-        * Instantiate {@link TwoPhaseCommitSinkFunction} with custom state 
descriptors.
         *
-        * @param stateDescriptor descriptor for transactions POJO.
+        * @param transactionSerializer {@link TypeSerializer} for the 
transaction type of this sink
+        * @param contextSerializer {@link TypeSerializer} for the context type 
of this sink
         */
-       public TwoPhaseCommitSinkFunction(ListStateDescriptor<State<TXN, 
CONTEXT>> stateDescriptor) {
-               this.stateDescriptor = requireNonNull(stateDescriptor, 
"stateDescriptor is null");
+       public TwoPhaseCommitSinkFunction(
+                       TypeSerializer<TXN> transactionSerializer,
+                       TypeSerializer<CONTEXT> contextSerializer) {
+
+               this.stateDescriptor =
+                       new ListStateDescriptor<>(
+                                       "state",
+                                       new 
StateSerializer<>(transactionSerializer, contextSerializer));
        }
 
        protected Optional<CONTEXT> initializeUserContext() {
@@ -324,7 +337,9 @@ public abstract class TwoPhaseCommitSinkFunction<IN, TXN, 
CONTEXT>
        /**
         * State POJO class coupling pendingTransaction, context and 
pendingCommitTransactions.
         */
-       public static class State<TXN, CONTEXT> {
+       @VisibleForTesting
+       @Internal
+       public static final class State<TXN, CONTEXT> {
                protected TXN pendingTransaction;
                protected List<TXN> pendingCommitTransactions = new 
ArrayList<>();
                protected Optional<CONTEXT> context;
@@ -361,5 +376,247 @@ public abstract class TwoPhaseCommitSinkFunction<IN, TXN, 
CONTEXT>
                public void setContext(Optional<CONTEXT> context) {
                        this.context = context;
                }
+
+               @Override
+               public boolean equals(Object o) {
+                       if (this == o) {
+                               return true;
+                       }
+                       if (o == null || getClass() != o.getClass()) {
+                               return false;
+                       }
+
+                       State<?, ?> state = (State<?, ?>) o;
+
+                       if (pendingTransaction != null ? 
!pendingTransaction.equals(state.pendingTransaction) : state.pendingTransaction 
!= null) {
+                               return false;
+                       }
+                       if (pendingCommitTransactions != null ? 
!pendingCommitTransactions.equals(state.pendingCommitTransactions) : 
state.pendingCommitTransactions != null) {
+                               return false;
+                       }
+                       return context != null ? context.equals(state.context) 
: state.context == null;
+               }
+
+               @Override
+               public int hashCode() {
+                       int result = pendingTransaction != null ? 
pendingTransaction.hashCode() : 0;
+                       result = 31 * result + (pendingCommitTransactions != 
null ? pendingCommitTransactions.hashCode() : 0);
+                       result = 31 * result + (context != null ? 
context.hashCode() : 0);
+                       return result;
+               }
+       }
+
	/**
	 * Custom {@link TypeSerializer} for the sink state. Composes a serializer for the
	 * transaction type and a serializer for the (optional) user context, and uses them
	 * to read and write the combined {@link State} POJO. Having an explicit serializer
	 * (rather than generic POJO/Kryo serialization) keeps the on-disk format under
	 * this class's control, which eases future format evolution.
	 */
	@VisibleForTesting
	@Internal
	public static final class StateSerializer<TXN, CONTEXT> extends TypeSerializer<State<TXN, CONTEXT>> {

		private static final long serialVersionUID = 1L;

		/** Serializer for both the pending transaction and the pending-commit transactions. */
		private final TypeSerializer<TXN> transactionSerializer;

		/** Serializer for the optional user context. */
		private final TypeSerializer<CONTEXT> contextSerializer;

		public StateSerializer(
				TypeSerializer<TXN> transactionSerializer,
				TypeSerializer<CONTEXT> contextSerializer) {
			this.transactionSerializer = checkNotNull(transactionSerializer);
			this.contextSerializer = checkNotNull(contextSerializer);
		}

		@Override
		public boolean isImmutableType() {
			// Delegates to the nested serializers: immutable only if both element types are.
			return transactionSerializer.isImmutableType() && contextSerializer.isImmutableType();
		}

		@Override
		public TypeSerializer<State<TXN, CONTEXT>> duplicate() {
			// Duplicate the nested serializers as well, so the duplicate is safe to use
			// independently of this instance.
			return new StateSerializer<>(
					transactionSerializer.duplicate(), contextSerializer.duplicate());
		}

		@Override
		public State<TXN, CONTEXT> createInstance() {
			// There is no sensible "empty" State; deserialize() below builds instances
			// directly, so no reusable instance is provided here.
			return null;
		}

		@Override
		public State<TXN, CONTEXT> copy(State<TXN, CONTEXT> from) {
			// Element-wise deep copy through the nested serializers.
			TXN copiedPendingTransaction = transactionSerializer.copy(from.getPendingTransaction());
			List<TXN> copiedPendingCommitTransactions = new ArrayList<>();
			for (TXN txn : from.getPendingCommitTransactions()) {
				copiedPendingCommitTransactions.add(transactionSerializer.copy(txn));
			}
			Optional<CONTEXT> copiedContext = from.getContext().map(contextSerializer::copy);
			return new State<>(copiedPendingTransaction, copiedPendingCommitTransactions, copiedContext);
		}

		@Override
		public State<TXN, CONTEXT> copy(
				State<TXN, CONTEXT> from,
				State<TXN, CONTEXT> reuse) {
			// The reuse instance is ignored; a fresh deep copy is always created.
			return copy(from);
		}

		@Override
		public int getLength() {
			// Variable-length records (list size and optional context make the
			// serialized length input-dependent).
			return -1;
		}

		@Override
		public void serialize(
				State<TXN, CONTEXT> record,
				DataOutputView target) throws IOException {
			// Binary layout: pending transaction | count of pending-commit transactions |
			// each pending-commit transaction | context-present flag | context (if present).
			// deserialize() and copy(source, target) must read in exactly this order.
			transactionSerializer.serialize(record.getPendingTransaction(), target);
			List<TXN> pendingCommitTransactions = record.getPendingCommitTransactions();
			target.writeInt(pendingCommitTransactions.size());
			for (TXN pendingTxn : pendingCommitTransactions) {
				transactionSerializer.serialize(pendingTxn, target);
			}
			Optional<CONTEXT> context = record.getContext();
			if (context.isPresent()) {
				target.writeBoolean(true);
				contextSerializer.serialize(context.get(), target);
			} else {
				target.writeBoolean(false);
			}
		}

		@Override
		public State<TXN, CONTEXT> deserialize(DataInputView source) throws IOException {
			// Mirror of serialize(): transaction, list size, list elements, flag, context.
			TXN pendingTxn = transactionSerializer.deserialize(source);
			int numPendingCommitTxns = source.readInt();
			List<TXN> pendingCommitTxns = new ArrayList<>(numPendingCommitTxns);
			for (int i = 0; i < numPendingCommitTxns; i++) {
				pendingCommitTxns.add(transactionSerializer.deserialize(source));
			}
			Optional<CONTEXT> context = Optional.empty();
			boolean hasContext = source.readBoolean();
			if (hasContext) {
				context = Optional.of(contextSerializer.deserialize(source));
			}
			return new State<>(pendingTxn, pendingCommitTxns, context);
		}

		@Override
		public State<TXN, CONTEXT> deserialize(
				State<TXN, CONTEXT> reuse,
				DataInputView source) throws IOException {
			// Reuse is ignored; delegate to the allocating variant.
			return deserialize(source);
		}

		@Override
		public void copy(
				DataInputView source, DataOutputView target) throws IOException {
			// Stream-to-stream copy: re-serialize each component in the same order as
			// serialize(), without materializing a State instance.
			TXN pendingTxn = transactionSerializer.deserialize(source);
			transactionSerializer.serialize(pendingTxn, target);
			int numPendingCommitTxns = source.readInt();
			target.writeInt(numPendingCommitTxns);
			for (int i = 0; i < numPendingCommitTxns; i++) {
				TXN pendingCommitTxn = transactionSerializer.deserialize(source);
				transactionSerializer.serialize(pendingCommitTxn, target);
			}
			boolean hasContext = source.readBoolean();
			target.writeBoolean(hasContext);
			if (hasContext) {
				CONTEXT context = contextSerializer.deserialize(source);
				contextSerializer.serialize(context, target);
			}
		}

		@Override
		public boolean canEqual(Object obj) {
			return obj instanceof StateSerializer;
		}

		@Override
		public boolean equals(Object o) {
			if (this == o) {
				return true;
			}
			if (o == null || getClass() != o.getClass()) {
				return false;
			}

			StateSerializer<?, ?> that = (StateSerializer<?, ?>) o;

			// Two StateSerializers are equal iff both nested serializers are equal.
			if (!transactionSerializer.equals(that.transactionSerializer)) {
				return false;
			}
			return contextSerializer.equals(that.contextSerializer);
		}

		@Override
		public int hashCode() {
			int result = transactionSerializer.hashCode();
			result = 31 * result + contextSerializer.hashCode();
			return result;
		}

		@Override
		public TypeSerializerConfigSnapshot snapshotConfiguration() {
			// Snapshot captures the configs of both nested serializers (order matters:
			// transaction first, context second — ensureCompatibility relies on it).
			return new StateSerializerConfigSnapshot<>(transactionSerializer, contextSerializer);
		}

		@Override
		public CompatibilityResult<State<TXN, CONTEXT>> ensureCompatibility(
				TypeSerializerConfigSnapshot configSnapshot) {
			if (configSnapshot instanceof StateSerializerConfigSnapshot) {
				// Index 0: transaction serializer config, index 1: context serializer
				// config — the same order used in snapshotConfiguration().
				List<Tuple2<TypeSerializer<?>, TypeSerializerConfigSnapshot>> previousSerializersAndConfigs =
						((StateSerializerConfigSnapshot) configSnapshot).getNestedSerializersAndConfigs();

				CompatibilityResult<TXN> txnCompatResult = CompatibilityUtil.resolveCompatibilityResult(
						previousSerializersAndConfigs.get(0).f0,
						UnloadableDummyTypeSerializer.class,
						previousSerializersAndConfigs.get(0).f1,
						transactionSerializer);

				CompatibilityResult<CONTEXT> contextCompatResult = CompatibilityUtil.resolveCompatibilityResult(
						previousSerializersAndConfigs.get(1).f0,
						UnloadableDummyTypeSerializer.class,
						previousSerializersAndConfigs.get(1).f1,
						contextSerializer);

				// Compatible only if neither nested serializer requires migration.
				if (!txnCompatResult.isRequiresMigration() && !contextCompatResult.isRequiresMigration()) {
					return CompatibilityResult.compatible();
				} else {
					// If both sides can provide a convert deserializer, migration can
					// proceed through a StateSerializer built from the adapters.
					if (txnCompatResult.getConvertDeserializer() != null && contextCompatResult.getConvertDeserializer() != null) {
						return CompatibilityResult.requiresMigration(
								new StateSerializer<>(
										new TypeDeserializerAdapter<>(txnCompatResult.getConvertDeserializer()),
										new TypeDeserializerAdapter<>(contextCompatResult.getConvertDeserializer())));
					}
				}
			}

			// Unknown snapshot type or no convert deserializers available.
			return CompatibilityResult.requiresMigration();
		}
	}
+
	/**
	 * {@link TypeSerializerConfigSnapshot} for sink state. This has to be public so that
	 * it can be deserialized/instantiated, but it should not be used anywhere outside
	 * {@code TwoPhaseCommitSinkFunction}.
	 */
	@Internal
	public static final class StateSerializerConfigSnapshot<TXN, CONTEXT>
			extends CompositeTypeSerializerConfigSnapshot {

		/** Format version of this snapshot; bump when the snapshot layout changes. */
		private static final int VERSION = 1;

		/** This empty nullary constructor is required for deserializing the configuration. */
		public StateSerializerConfigSnapshot() {}

		/**
		 * Creates a snapshot capturing the nested transaction and context serializers,
		 * in that order (StateSerializer#ensureCompatibility reads them by index).
		 *
		 * @param transactionSerializer serializer for the transaction type
		 * @param contextSerializer serializer for the context type
		 */
		public StateSerializerConfigSnapshot(
				TypeSerializer<TXN> transactionSerializer,
				TypeSerializer<CONTEXT> contextSerializer) {
			super(transactionSerializer, contextSerializer);
		}

		@Override
		public int getVersion() {
			return VERSION;
		}
	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/0ba528c7/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunctionTest.java
----------------------------------------------------------------------
diff --git 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunctionTest.java
 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunctionTest.java
index 3043512..20abf58 100644
--- 
a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunctionTest.java
+++ 
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/functions/sink/TwoPhaseCommitSinkFunctionTest.java
@@ -17,9 +17,10 @@
 
 package org.apache.flink.streaming.api.functions.sink;
 
-import org.apache.flink.api.common.typeinfo.TypeHint;
-import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.api.common.typeutils.base.VoidSerializer;
+import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
 import org.apache.flink.streaming.api.operators.StreamSink;
 import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
@@ -125,7 +126,9 @@ public class TwoPhaseCommitSinkFunctionTest {
                private final File targetDirectory;
 
                public FileBasedSinkFunction(File tmpDirectory, File 
targetDirectory) {
-                       super(TypeInformation.of(new 
TypeHint<State<FileTransaction, Void>>() {}));
+                       super(
+                               new KryoSerializer<>(FileTransaction.class, new 
ExecutionConfig()),
+                               VoidSerializer.INSTANCE);
 
                        if (!tmpDirectory.isDirectory() || 
!targetDirectory.isDirectory()) {
                                throw new IllegalArgumentException();

Reply via email to