This is an automated email from the ASF dual-hosted git repository.

thw pushed a commit to branch release-1.14
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/release-1.14 by this push:
     new cb71c37  [FLINK-24064][connector/common] HybridSource restore from savepoint
cb71c37 is described below

commit cb71c3721efa129972c8d9e6ace347b691ee80e1
Author: Thomas Weise <t...@apache.org>
AuthorDate: Tue Aug 31 05:48:40 2021 -0700

    [FLINK-24064][connector/common] HybridSource restore from savepoint
---
 .../connector/base/source/hybrid/HybridSource.java | 19 ++----
 .../source/hybrid/HybridSourceEnumeratorState.java | 17 ++++--
 .../HybridSourceEnumeratorStateSerializer.java     | 37 ++----------
 .../base/source/hybrid/HybridSourceReader.java     | 15 ++---
 .../base/source/hybrid/HybridSourceSplit.java      | 69 ++++++++++++++++------
 .../source/hybrid/HybridSourceSplitEnumerator.java | 48 +++++++++------
 .../source/hybrid/HybridSourceSplitSerializer.java | 39 +++---------
 .../base/source/hybrid/SwitchedSources.java        | 48 +++++++++++++++
 .../base/source/hybrid/HybridSourceReaderTest.java | 36 ++++-------
 .../hybrid/HybridSourceSplitEnumeratorTest.java    | 36 ++++++-----
 .../hybrid/HybridSourceSplitSerializerTest.java    |  6 +-
 .../base/source/reader/mocks/MockBaseSource.java   |  3 +-
 12 files changed, 202 insertions(+), 171 deletions(-)

diff --git 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSource.java
 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSource.java
index e3d66de..24acb6a 100644
--- 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSource.java
+++ 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSource.java
@@ -32,9 +32,7 @@ import org.apache.flink.util.Preconditions;
 
 import java.io.Serializable;
 import java.util.ArrayList;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 
 /**
  * Hybrid source that switches underlying sources based on configured source 
chain.
@@ -91,14 +89,11 @@ import java.util.Map;
 public class HybridSource<T> implements Source<T, HybridSourceSplit, 
HybridSourceEnumeratorState> {
 
     private final List<SourceListEntry> sources;
-    // sources are populated per subtask at switch time
-    private final Map<Integer, Source> switchedSources;
 
     /** Protected for subclass, use {@link #builder(Source)} to construct 
source. */
     protected HybridSource(List<SourceListEntry> sources) {
         Preconditions.checkArgument(!sources.isEmpty());
         this.sources = sources;
-        this.switchedSources = new HashMap<>(sources.size());
     }
 
     /** Builder for {@link HybridSource}. */
@@ -116,13 +111,13 @@ public class HybridSource<T> implements Source<T, 
HybridSourceSplit, HybridSourc
     @Override
     public SourceReader<T, HybridSourceSplit> createReader(SourceReaderContext 
readerContext)
             throws Exception {
-        return new HybridSourceReader(readerContext, switchedSources);
+        return new HybridSourceReader(readerContext);
     }
 
     @Override
     public SplitEnumerator<HybridSourceSplit, HybridSourceEnumeratorState> 
createEnumerator(
             SplitEnumeratorContext<HybridSourceSplit> enumContext) {
-        return new HybridSourceSplitEnumerator(enumContext, sources, 0, 
switchedSources, null);
+        return new HybridSourceSplitEnumerator(enumContext, sources, 0, null);
     }
 
     @Override
@@ -131,22 +126,18 @@ public class HybridSource<T> implements Source<T, 
HybridSourceSplit, HybridSourc
             HybridSourceEnumeratorState checkpoint)
             throws Exception {
         return new HybridSourceSplitEnumerator(
-                enumContext,
-                sources,
-                checkpoint.getCurrentSourceIndex(),
-                switchedSources,
-                checkpoint.getWrappedState());
+                enumContext, sources, checkpoint.getCurrentSourceIndex(), 
checkpoint);
     }
 
     @Override
     public SimpleVersionedSerializer<HybridSourceSplit> getSplitSerializer() {
-        return new HybridSourceSplitSerializer(switchedSources);
+        return new HybridSourceSplitSerializer();
     }
 
     @Override
     public SimpleVersionedSerializer<HybridSourceEnumeratorState>
             getEnumeratorCheckpointSerializer() {
-        return new HybridSourceEnumeratorStateSerializer(switchedSources);
+        return new HybridSourceEnumeratorStateSerializer();
     }
 
     /**
diff --git 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceEnumeratorState.java
 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceEnumeratorState.java
index 2da99ee..95aadde 100644
--- 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceEnumeratorState.java
+++ 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceEnumeratorState.java
@@ -21,18 +21,25 @@ package org.apache.flink.connector.base.source.hybrid;
 /** The state of hybrid source enumerator. */
 public class HybridSourceEnumeratorState {
     private final int currentSourceIndex;
-    private final Object wrappedState;
+    private byte[] wrappedStateBytes;
+    private final int wrappedStateSerializerVersion;
 
-    HybridSourceEnumeratorState(int currentSourceIndex, Object wrappedState) {
+    HybridSourceEnumeratorState(
+            int currentSourceIndex, byte[] wrappedStateBytes, int 
serializerVersion) {
         this.currentSourceIndex = currentSourceIndex;
-        this.wrappedState = wrappedState;
+        this.wrappedStateBytes = wrappedStateBytes;
+        this.wrappedStateSerializerVersion = serializerVersion;
     }
 
     public int getCurrentSourceIndex() {
         return this.currentSourceIndex;
     }
 
-    public Object getWrappedState() {
-        return wrappedState;
+    public byte[] getWrappedState() {
+        return wrappedStateBytes;
+    }
+
+    public int getWrappedStateSerializerVersion() {
+        return wrappedStateSerializerVersion;
     }
 }
diff --git 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceEnumeratorStateSerializer.java
 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceEnumeratorStateSerializer.java
index 721e010..92c021e 100644
--- 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceEnumeratorStateSerializer.java
+++ 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceEnumeratorStateSerializer.java
@@ -18,17 +18,13 @@
 
 package org.apache.flink.connector.base.source.hybrid;
 
-import org.apache.flink.api.connector.source.Source;
 import org.apache.flink.core.io.SimpleVersionedSerializer;
-import org.apache.flink.util.Preconditions;
 
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.DataInputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
 
 /** The {@link SimpleVersionedSerializer Serializer} for the enumerator state. 
*/
 public class HybridSourceEnumeratorStateSerializer
@@ -36,13 +32,7 @@ public class HybridSourceEnumeratorStateSerializer
 
     private static final int CURRENT_VERSION = 0;
 
-    private final Map<Integer, SimpleVersionedSerializer<Object>> 
cachedSerializers;
-    private final Map<Integer, Source> switchedSources;
-
-    public HybridSourceEnumeratorStateSerializer(Map<Integer, Source> 
switchedSources) {
-        this.switchedSources = switchedSources;
-        this.cachedSerializers = new HashMap<>();
-    }
+    public HybridSourceEnumeratorStateSerializer() {}
 
     @Override
     public int getVersion() {
@@ -54,12 +44,9 @@ public class HybridSourceEnumeratorStateSerializer
         try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                 DataOutputStream out = new DataOutputStream(baos)) {
             out.writeInt(enumState.getCurrentSourceIndex());
-            SimpleVersionedSerializer<Object> serializer =
-                    serializerOf(enumState.getCurrentSourceIndex());
-            out.writeInt(serializer.getVersion());
-            byte[] enumStateBytes = 
serializer.serialize(enumState.getWrappedState());
-            out.writeInt(enumStateBytes.length);
-            out.write(enumStateBytes);
+            out.writeInt(enumState.getWrappedStateSerializerVersion());
+            out.writeInt(enumState.getWrappedState().length);
+            out.write(enumState.getWrappedState());
             out.flush();
             return baos.toByteArray();
         }
@@ -86,21 +73,7 @@ public class HybridSourceEnumeratorStateSerializer
             int length = in.readInt();
             byte[] nestedBytes = new byte[length];
             in.readFully(nestedBytes);
-            Object nested = 
serializerOf(sourceIndex).deserialize(nestedVersion, nestedBytes);
-            return new HybridSourceEnumeratorState(sourceIndex, nested);
+            return new HybridSourceEnumeratorState(sourceIndex, nestedBytes, 
nestedVersion);
         }
     }
-
-    private SimpleVersionedSerializer<Object> serializerOf(int sourceIndex) {
-        return cachedSerializers.computeIfAbsent(
-                sourceIndex,
-                (k -> {
-                    Source source =
-                            Preconditions.checkNotNull(
-                                    switchedSources.get(k),
-                                    "Source for index=%s not available",
-                                    sourceIndex);
-                    return source.getEnumeratorCheckpointSerializer();
-                }));
-    }
 }
diff --git 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceReader.java
 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceReader.java
index 28d4011..855f85d 100644
--- 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceReader.java
+++ 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceReader.java
@@ -34,7 +34,6 @@ import java.util.ArrayList;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
-import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 
 /**
@@ -55,17 +54,15 @@ import java.util.concurrent.CompletableFuture;
 public class HybridSourceReader<T> implements SourceReader<T, 
HybridSourceSplit> {
     private static final Logger LOG = 
LoggerFactory.getLogger(HybridSourceReader.class);
     private final SourceReaderContext readerContext;
-    private final Map<Integer, Source> switchedSources;
+    private final SwitchedSources switchedSources = new SwitchedSources();
     private int currentSourceIndex = -1;
     private boolean isFinalSource;
     private SourceReader<T, ? extends SourceSplit> currentReader;
     private CompletableFuture<Void> availabilityFuture = new 
CompletableFuture<>();
     private List<HybridSourceSplit> restoredSplits = new ArrayList<>();
 
-    public HybridSourceReader(
-            SourceReaderContext readerContext, Map<Integer, Source> 
switchedSources) {
+    public HybridSourceReader(SourceReaderContext readerContext) {
         this.readerContext = readerContext;
-        this.switchedSources = switchedSources;
     }
 
     @Override
@@ -117,7 +114,7 @@ public class HybridSourceReader<T> implements 
SourceReader<T, HybridSourceSplit>
                 currentReader != null
                         ? currentReader.snapshotState(checkpointId)
                         : Collections.emptyList();
-        return HybridSourceSplit.wrapSplits(currentSourceIndex, state);
+        return HybridSourceSplit.wrapSplits(state, currentSourceIndex, 
switchedSources);
     }
 
     @Override
@@ -158,7 +155,7 @@ public class HybridSourceReader<T> implements 
SourceReader<T, HybridSourceSplit>
                         "Split %s while current source is %s",
                         split,
                         currentSourceIndex);
-                realSplits.add(split.getWrappedSplit());
+                realSplits.add(HybridSourceSplit.unwrapSplit(split, 
switchedSources));
             }
             currentReader.addSplits((List) realSplits);
         }
@@ -224,9 +221,7 @@ public class HybridSourceReader<T> implements 
SourceReader<T, HybridSourceSplit>
                     currentReader);
         }
         // TODO: track previous readers splits till checkpoint
-        Source source =
-                Preconditions.checkNotNull(
-                        switchedSources.get(index), "Source for index=%s not 
available", index);
+        Source source = switchedSources.sourceOf(index);
         SourceReader<T, ?> reader;
         try {
             reader = source.createReader(readerContext);
diff --git 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplit.java
 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplit.java
index 9057aa6..f26bd10 100644
--- 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplit.java
+++ 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplit.java
@@ -19,33 +19,45 @@
 package org.apache.flink.connector.base.source.hybrid;
 
 import org.apache.flink.api.connector.source.SourceSplit;
+import org.apache.flink.core.io.SimpleVersionedSerializer;
 
+import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Objects;
 
 /** Source split that wraps the actual split type. */
 public class HybridSourceSplit implements SourceSplit {
 
-    private final SourceSplit wrappedSplit;
+    private final byte[] wrappedSplitBytes;
+    private final int wrappedSplitSerializerVersion;
     private final int sourceIndex;
+    private final String splitId;
 
-    public HybridSourceSplit(int sourceIndex, SourceSplit wrappedSplit) {
+    public HybridSourceSplit(
+            int sourceIndex, byte[] wrappedSplit, int serializerVersion, 
String splitId) {
         this.sourceIndex = sourceIndex;
-        this.wrappedSplit = wrappedSplit;
+        this.wrappedSplitBytes = wrappedSplit;
+        this.wrappedSplitSerializerVersion = serializerVersion;
+        this.splitId = splitId;
     }
 
     public int sourceIndex() {
         return this.sourceIndex;
     }
 
-    public SourceSplit getWrappedSplit() {
-        return wrappedSplit;
+    public byte[] wrappedSplitBytes() {
+        return wrappedSplitBytes;
+    }
+
+    public int wrappedSplitSerializerVersion() {
+        return wrappedSplitSerializerVersion;
     }
 
     @Override
     public String splitId() {
-        return wrappedSplit.splitId();
+        return splitId;
     }
 
     @Override
@@ -57,38 +69,59 @@ public class HybridSourceSplit implements SourceSplit {
             return false;
         }
         HybridSourceSplit that = (HybridSourceSplit) o;
-        return sourceIndex == that.sourceIndex && 
wrappedSplit.equals(that.wrappedSplit);
+        return sourceIndex == that.sourceIndex
+                && Arrays.equals(wrappedSplitBytes, that.wrappedSplitBytes);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(wrappedSplit, sourceIndex);
+        return Objects.hash(wrappedSplitBytes, sourceIndex);
     }
 
     @Override
     public String toString() {
-        return "HybridSourceSplit{"
-                + "realSplit="
-                + wrappedSplit
-                + ", sourceIndex="
-                + sourceIndex
-                + '}';
+        return "HybridSourceSplit{" + "sourceIndex=" + sourceIndex + ", 
splitId=" + splitId + '}';
     }
 
     public static List<HybridSourceSplit> wrapSplits(
-            int readerIndex, List<? extends SourceSplit> state) {
+            List<? extends SourceSplit> state, int readerIndex, 
SwitchedSources switchedSources) {
         List<HybridSourceSplit> wrappedSplits = new ArrayList<>(state.size());
         for (SourceSplit split : state) {
-            wrappedSplits.add(new HybridSourceSplit(readerIndex, split));
+            wrappedSplits.add(wrapSplit(split, readerIndex, switchedSources));
         }
         return wrappedSplits;
     }
 
-    public static List<SourceSplit> unwrapSplits(List<HybridSourceSplit> 
splits) {
+    public static HybridSourceSplit wrapSplit(
+            SourceSplit split, int sourceIndex, SwitchedSources 
switchedSources) {
+        try {
+            SimpleVersionedSerializer<SourceSplit> serializer =
+                    switchedSources.serializerOf(sourceIndex);
+            byte[] serialized = serializer.serialize(split);
+            return new HybridSourceSplit(
+                    sourceIndex, serialized, serializer.getVersion(), 
split.splitId());
+        } catch (IOException ex) {
+            throw new RuntimeException(ex);
+        }
+    }
+
+    public static List<SourceSplit> unwrapSplits(
+            List<HybridSourceSplit> splits, SwitchedSources switchedSources) {
         List<SourceSplit> unwrappedSplits = new ArrayList<>(splits.size());
         for (HybridSourceSplit split : splits) {
-            unwrappedSplits.add(split.getWrappedSplit());
+            unwrappedSplits.add(unwrapSplit(split, switchedSources));
         }
         return unwrappedSplits;
     }
+
+    public static SourceSplit unwrapSplit(
+            HybridSourceSplit split, SwitchedSources switchedSources) {
+        try {
+            return switchedSources
+                    .serializerOf(split.sourceIndex())
+                    .deserialize(split.wrappedSplitSerializerVersion(), 
split.wrappedSplitBytes());
+        } catch (IOException ex) {
+            throw new RuntimeException(ex);
+        }
+    }
 }
diff --git 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumerator.java
 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumerator.java
index 0f2b036..d27de22 100644
--- 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumerator.java
+++ 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumerator.java
@@ -25,6 +25,7 @@ import org.apache.flink.api.connector.source.SourceSplit;
 import org.apache.flink.api.connector.source.SplitEnumerator;
 import org.apache.flink.api.connector.source.SplitEnumeratorContext;
 import org.apache.flink.api.connector.source.SplitsAssignment;
+import org.apache.flink.core.io.SimpleVersionedSerializer;
 import org.apache.flink.metrics.groups.SplitEnumeratorMetricGroup;
 import org.apache.flink.util.Preconditions;
 
@@ -68,21 +69,21 @@ public class HybridSourceSplitEnumerator
 
     private final SplitEnumeratorContext<HybridSourceSplit> context;
     private final List<HybridSource.SourceListEntry> sources;
-    private final Map<Integer, Source> switchedSources;
+    private final SwitchedSources switchedSources = new SwitchedSources();
     // Splits that have been returned due to subtask reset
     private final Map<Integer, TreeMap<Integer, List<HybridSourceSplit>>> 
pendingSplits;
     private final Set<Integer> finishedReaders;
     private final Map<Integer, Integer> readerSourceIndex;
     private int currentSourceIndex;
-    private Object restoredEnumeratorState;
+    private HybridSourceEnumeratorState restoredEnumeratorState;
     private SplitEnumerator<SourceSplit, Object> currentEnumerator;
+    private SimpleVersionedSerializer<Object> 
currentEnumeratorCheckpointSerializer;
 
     public HybridSourceSplitEnumerator(
             SplitEnumeratorContext<HybridSourceSplit> context,
             List<HybridSource.SourceListEntry> sources,
             int initialSourceIndex,
-            Map<Integer, Source> switchedSources,
-            Object restoredEnumeratorState) {
+            HybridSourceEnumeratorState restoredEnumeratorState) {
         Preconditions.checkArgument(initialSourceIndex < sources.size());
         this.context = context;
         this.sources = sources;
@@ -90,7 +91,6 @@ public class HybridSourceSplitEnumerator
         this.pendingSplits = new HashMap<>();
         this.finishedReaders = new HashSet<>();
         this.readerSourceIndex = new HashMap<>();
-        this.switchedSources = switchedSources;
         this.restoredEnumeratorState = restoredEnumeratorState;
     }
 
@@ -127,7 +127,8 @@ public class HybridSourceSplitEnumerator
                 (k, splitsPerSource) -> {
                     if (k == currentSourceIndex) {
                         currentEnumerator.addSplitsBack(
-                                
HybridSourceSplit.unwrapSplits(splitsPerSource), subtaskId);
+                                
HybridSourceSplit.unwrapSplits(splitsPerSource, switchedSources),
+                                subtaskId);
                     } else {
                         pendingSplits
                                 .computeIfAbsent(subtaskId, sourceIndex -> new 
TreeMap<>())
@@ -144,7 +145,7 @@ public class HybridSourceSplitEnumerator
 
     private void sendSwitchSourceEvent(int subtaskId, int sourceIndex) {
         readerSourceIndex.put(subtaskId, sourceIndex);
-        Source source = 
Preconditions.checkNotNull(switchedSources.get(sourceIndex));
+        Source source = switchedSources.sourceOf(sourceIndex);
         context.sendEventToSourceReader(
                 subtaskId,
                 new SwitchSourceEvent(sourceIndex, source, sourceIndex >= 
(sources.size() - 1)));
@@ -172,7 +173,11 @@ public class HybridSourceSplitEnumerator
     @Override
     public HybridSourceEnumeratorState snapshotState(long checkpointId) throws 
Exception {
         Object enumState = currentEnumerator.snapshotState(checkpointId);
-        return new HybridSourceEnumeratorState(currentSourceIndex, enumState);
+        byte[] enumStateBytes = 
currentEnumeratorCheckpointSerializer.serialize(enumState);
+        return new HybridSourceEnumeratorState(
+                currentSourceIndex,
+                enumStateBytes,
+                currentEnumeratorCheckpointSerializer.getVersion());
     }
 
     @Override
@@ -262,21 +267,22 @@ public class HybridSourceSplitEnumerator
                 };
 
         Source<?, ? extends SourceSplit, Object> source =
-                switchedSources.computeIfAbsent(
-                        currentSourceIndex,
-                        k -> {
-                            return 
sources.get(currentSourceIndex).factory.create(switchContext);
-                        });
+                sources.get(currentSourceIndex).factory.create(switchContext);
         switchedSources.put(currentSourceIndex, source);
+        currentEnumeratorCheckpointSerializer = 
source.getEnumeratorCheckpointSerializer();
         SplitEnumeratorContextProxy delegatingContext =
-                new SplitEnumeratorContextProxy(currentSourceIndex, context, 
readerSourceIndex);
+                new SplitEnumeratorContextProxy(
+                        currentSourceIndex, context, readerSourceIndex, 
switchedSources);
         try {
             if (restoredEnumeratorState == null) {
                 currentEnumerator = source.createEnumerator(delegatingContext);
             } else {
                 LOG.info("Restoring enumerator for sourceIndex={}", 
currentSourceIndex);
-                currentEnumerator =
-                        source.restoreEnumerator(delegatingContext, 
restoredEnumeratorState);
+                Object nestedEnumState =
+                        currentEnumeratorCheckpointSerializer.deserialize(
+                                
restoredEnumeratorState.getWrappedStateSerializerVersion(),
+                                restoredEnumeratorState.getWrappedState());
+                currentEnumerator = 
source.restoreEnumerator(delegatingContext, nestedEnumState);
                 restoredEnumeratorState = null;
             }
         } catch (Exception e) {
@@ -301,14 +307,17 @@ public class HybridSourceSplitEnumerator
         private final SplitEnumeratorContext<HybridSourceSplit> realContext;
         private final int sourceIndex;
         private final Map<Integer, Integer> readerSourceIndex;
+        private final SwitchedSources switchedSources;
 
         private SplitEnumeratorContextProxy(
                 int sourceIndex,
                 SplitEnumeratorContext<HybridSourceSplit> realContext,
-                Map<Integer, Integer> readerSourceIndex) {
+                Map<Integer, Integer> readerSourceIndex,
+                SwitchedSources switchedSources) {
             this.realContext = realContext;
             this.sourceIndex = sourceIndex;
             this.readerSourceIndex = readerSourceIndex;
+            this.switchedSources = switchedSources;
         }
 
         @Override
@@ -358,7 +367,7 @@ public class HybridSourceSplitEnumerator
             Map<Integer, List<HybridSourceSplit>> wrappedAssignmentMap = new 
HashMap<>();
             for (Map.Entry<Integer, List<SplitT>> e : 
newSplitAssignments.assignment().entrySet()) {
                 List<HybridSourceSplit> splits =
-                        HybridSourceSplit.wrapSplits(sourceIndex, 
e.getValue());
+                        HybridSourceSplit.wrapSplits(e.getValue(), 
sourceIndex, switchedSources);
                 wrappedAssignmentMap.put(e.getKey(), splits);
             }
             SplitsAssignment<HybridSourceSplit> wrappedAssignments =
@@ -369,7 +378,8 @@ public class HybridSourceSplitEnumerator
 
         @Override
         public void assignSplit(SplitT split, int subtask) {
-            HybridSourceSplit wrappedSplit = new 
HybridSourceSplit(sourceIndex, split);
+            HybridSourceSplit wrappedSplit =
+                    HybridSourceSplit.wrapSplit(split, sourceIndex, 
switchedSources);
             realContext.assignSplit(wrappedSplit, subtask);
         }
 
diff --git 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitSerializer.java
 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitSerializer.java
index 025733c..8fe7b7c 100644
--- 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitSerializer.java
+++ 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitSerializer.java
@@ -18,29 +18,18 @@
 
 package org.apache.flink.connector.base.source.hybrid;
 
-import org.apache.flink.api.connector.source.Source;
-import org.apache.flink.api.connector.source.SourceSplit;
 import org.apache.flink.core.io.SimpleVersionedSerializer;
-import org.apache.flink.util.Preconditions;
 
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
 import java.io.DataInputStream;
 import java.io.DataOutputStream;
 import java.io.IOException;
-import java.util.HashMap;
-import java.util.Map;
 
 /** Serializes splits by delegating to the source-indexed underlying split 
serializer. */
 public class HybridSourceSplitSerializer implements 
SimpleVersionedSerializer<HybridSourceSplit> {
 
-    final Map<Integer, SimpleVersionedSerializer<SourceSplit>> 
cachedSerializers;
-    final Map<Integer, Source> switchedSources;
-
-    public HybridSourceSplitSerializer(Map<Integer, Source> switchedSources) {
-        this.cachedSerializers = new HashMap<>();
-        this.switchedSources = switchedSources;
-    }
+    public HybridSourceSplitSerializer() {}
 
     @Override
     public int getVersion() {
@@ -52,11 +41,10 @@ public class HybridSourceSplitSerializer implements 
SimpleVersionedSerializer<Hy
         try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
                 DataOutputStream out = new DataOutputStream(baos)) {
             out.writeInt(split.sourceIndex());
-            out.writeInt(serializerOf(split.sourceIndex()).getVersion());
-            byte[] serializedSplit =
-                    
serializerOf(split.sourceIndex()).serialize(split.getWrappedSplit());
-            out.writeInt(serializedSplit.length);
-            out.write(serializedSplit);
+            out.writeUTF(split.splitId());
+            out.writeInt(split.wrappedSplitSerializerVersion());
+            out.writeInt(split.wrappedSplitBytes().length);
+            out.write(split.wrappedSplitBytes());
             out.flush();
             return baos.toByteArray();
         }
@@ -74,25 +62,12 @@ public class HybridSourceSplitSerializer implements 
SimpleVersionedSerializer<Hy
         try (ByteArrayInputStream bais = new ByteArrayInputStream(serialized);
                 DataInputStream in = new DataInputStream(bais)) {
             int sourceIndex = in.readInt();
+            String splitId = in.readUTF();
             int nestedVersion = in.readInt();
             int length = in.readInt();
             byte[] splitBytes = new byte[length];
             in.readFully(splitBytes);
-            SourceSplit split = 
serializerOf(sourceIndex).deserialize(nestedVersion, splitBytes);
-            return new HybridSourceSplit(sourceIndex, split);
+            return new HybridSourceSplit(sourceIndex, splitBytes, 
nestedVersion, splitId);
         }
     }
-
-    private SimpleVersionedSerializer<SourceSplit> serializerOf(int 
sourceIndex) {
-        return cachedSerializers.computeIfAbsent(
-                sourceIndex,
-                (k -> {
-                    Source source =
-                            Preconditions.checkNotNull(
-                                    switchedSources.get(k),
-                                    "Source for index=%s not available",
-                                    sourceIndex);
-                    return source.getSplitSerializer();
-                }));
-    }
 }
diff --git 
a/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/SwitchedSources.java
 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/SwitchedSources.java
new file mode 100644
index 0000000..7911612
--- /dev/null
+++ 
b/flink-connectors/flink-connector-base/src/main/java/org/apache/flink/connector/base/source/hybrid/SwitchedSources.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.connector.base.source.hybrid;
+
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.connector.source.SourceSplit;
+import org.apache.flink.core.io.SimpleVersionedSerializer;
+import org.apache.flink.util.Preconditions;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/** Sources that participated in switching with cached serializers. */
+class SwitchedSources {
+    private final Map<Integer, Source> sources = new HashMap<>();
+    private final Map<Integer, SimpleVersionedSerializer<SourceSplit>> 
cachedSerializers =
+            new HashMap<>();
+
+    public Source sourceOf(int sourceIndex) {
+        return Preconditions.checkNotNull(
+                sources.get(sourceIndex), "Source for index=%s not available", 
sourceIndex);
+    }
+
+    public SimpleVersionedSerializer<SourceSplit> serializerOf(int 
sourceIndex) {
+        return cachedSerializers.computeIfAbsent(
+                sourceIndex, (k -> sourceOf(k).getSplitSerializer()));
+    }
+
+    public void put(int sourceIndex, Source source) {
+        sources.put(sourceIndex, Preconditions.checkNotNull(source));
+    }
+}
diff --git 
a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceReaderTest.java
 
b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceReaderTest.java
index 7882333..031a735 100644
--- 
a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceReaderTest.java
+++ 
b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceReaderTest.java
@@ -36,9 +36,7 @@ import org.junit.Test;
 import org.mockito.Mockito;
 
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
-import java.util.Map;
 
 /** Tests for {@link HybridSourceReader}. */
 public class HybridSourceReaderTest {
@@ -55,10 +53,7 @@ public class HybridSourceReaderTest {
         SourceReader<Integer, MockSourceSplit> mockSplitReader2 =
                 source.createReader(readerContext);
 
-        Map<Integer, Source> switchedSources = new HashMap<>();
-
-        HybridSourceReader<Integer> reader =
-                new HybridSourceReader<>(readerContext, switchedSources);
+        HybridSourceReader<Integer> reader = new 
HybridSourceReader<>(readerContext);
 
         Assert.assertThat(readerContext.getSentEvents(), 
Matchers.emptyIterable());
         reader.start();
@@ -75,10 +70,13 @@ public class HybridSourceReaderTest {
                     }
                 };
         reader.handleSourceEvents(new SwitchSourceEvent(0, source1, false));
-        Assert.assertEquals(source1, switchedSources.get(0));
+
         MockSourceSplit mockSplit = new MockSourceSplit(0, 0, 1);
         mockSplit.addRecord(0);
-        HybridSourceSplit hybridSplit = new HybridSourceSplit(0, mockSplit);
+
+        SwitchedSources switchedSources = new SwitchedSources();
+        switchedSources.put(0, source);
+        HybridSourceSplit hybridSplit = HybridSourceSplit.wrapSplit(mockSplit, 
0, switchedSources);
         reader.addSplits(Collections.singletonList(hybridSplit));
 
         // drain splits
@@ -128,19 +126,17 @@ public class HybridSourceReaderTest {
         TestingReaderOutput<Integer> readerOutput = new 
TestingReaderOutput<>();
         MockBaseSource source = new MockBaseSource(1, 1, Boundedness.BOUNDED);
 
-        Map<Integer, Source> switchedSources = new HashMap<>();
-
-        HybridSourceReader<Integer> reader =
-                new HybridSourceReader<>(readerContext, switchedSources);
+        HybridSourceReader<Integer> reader = new 
HybridSourceReader<>(readerContext);
 
         reader.start();
         assertAndClearSourceReaderFinishedEvent(readerContext, -1);
         reader.handleSourceEvents(new SwitchSourceEvent(0, source, false));
-        Assert.assertEquals(source, switchedSources.get(0));
 
         MockSourceSplit mockSplit = new MockSourceSplit(0, 0, 2147483647);
-        // mockSplit.addRecord(0);
-        HybridSourceSplit hybridSplit = new HybridSourceSplit(0, mockSplit);
+
+        SwitchedSources switchedSources = new SwitchedSources();
+        switchedSources.put(0, source);
+        HybridSourceSplit hybridSplit = HybridSourceSplit.wrapSplit(mockSplit, 
0, switchedSources);
         reader.addSplits(Collections.singletonList(hybridSplit));
 
         List<HybridSourceSplit> snapshot = reader.snapshotState(0);
@@ -148,8 +144,7 @@ public class HybridSourceReaderTest {
 
         // reader recovery
         readerContext.clearSentEvents();
-        switchedSources = new HashMap<>();
-        reader = new HybridSourceReader<>(readerContext, switchedSources);
+        reader = new HybridSourceReader<>(readerContext);
 
         reader.addSplits(snapshot);
         Assert.assertNull(currentReader(reader));
@@ -160,7 +155,6 @@ public class HybridSourceReaderTest {
         assertAndClearSourceReaderFinishedEvent(readerContext, -1);
         reader.handleSourceEvents(new SwitchSourceEvent(0, source, false));
         Assert.assertNotNull(currentReader(reader));
-        Assert.assertEquals(source, switchedSources.get(0));
         Assert.assertThat(reader.snapshotState(1), 
Matchers.contains(hybridSplit));
 
         reader.close();
@@ -179,15 +173,11 @@ public class HybridSourceReaderTest {
                     }
                 };
 
-        Map<Integer, Source> switchedSources = new HashMap<>();
-
-        HybridSourceReader<Integer> reader =
-                new HybridSourceReader<>(readerContext, switchedSources);
+        HybridSourceReader<Integer> reader = new 
HybridSourceReader<>(readerContext);
 
         reader.start();
         assertAndClearSourceReaderFinishedEvent(readerContext, -1);
         reader.handleSourceEvents(new SwitchSourceEvent(0, source, false));
-        Assert.assertEquals(source, switchedSources.get(0));
         SourceReader<Integer, MockSourceSplit> underlyingReader = 
currentReader(reader);
 
         reader.notifyCheckpointComplete(1);
diff --git 
a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumeratorTest.java
 
b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumeratorTest.java
index 59a1836..7bcf69c 100644
--- 
a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumeratorTest.java
+++ 
b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitEnumeratorTest.java
@@ -47,6 +47,7 @@ public class HybridSourceSplitEnumeratorTest {
 
     private static final int SUBTASK0 = 0;
     private static final int SUBTASK1 = 1;
+    private static final MockBaseSource MOCK_SOURCE = new MockBaseSource(1, 1, 
Boundedness.BOUNDED);
 
     private HybridSource<Integer> source;
     private MockSplitEnumeratorContext<HybridSourceSplit> context;
@@ -56,10 +57,7 @@ public class HybridSourceSplitEnumeratorTest {
 
     private void setupEnumeratorAndTriggerSourceSwitch() {
         context = new MockSplitEnumeratorContext<>(2);
-        source =
-                HybridSource.builder(new MockBaseSource(1, 1, 
Boundedness.BOUNDED))
-                        .addSource(new MockBaseSource(1, 1, 
Boundedness.BOUNDED))
-                        .build();
+        source = 
HybridSource.builder(MOCK_SOURCE).addSource(MOCK_SOURCE).build();
 
         enumerator = (HybridSourceSplitEnumerator) 
source.createEnumerator(context);
         enumerator.start();
@@ -130,9 +128,7 @@ public class HybridSourceSplitEnumeratorTest {
         setupEnumeratorAndTriggerSourceSwitch();
 
         UnderlyingEnumeratorWrapper underlyingEnumeratorWrapper =
-                new UnderlyingEnumeratorWrapper(
-                        (MockSplitEnumerator)
-                                Whitebox.getInternalState(enumerator, 
"currentEnumerator"));
+                new 
UnderlyingEnumeratorWrapper(getCurrentEnumerator(enumerator));
         Whitebox.setInternalState(enumerator, "currentEnumerator", 
underlyingEnumeratorWrapper);
 
         List<MockSourceSplit> mockSourceSplits =
@@ -147,11 +143,15 @@ public class HybridSourceSplitEnumeratorTest {
         assertThat(underlyingEnumeratorWrapper.handleSplitRequests, 
Matchers.emptyIterable());
         enumerator.handleSplitRequest(SUBTASK0, "fakehostname");
 
+        SwitchedSources switchedSources = new SwitchedSources();
+        switchedSources.put(1, MOCK_SOURCE);
+
         assertSplitAssignment(
                 "handleSplitRequest triggers assignment of split by underlying 
enumerator",
                 context,
                 1,
-                new HybridSourceSplit(1, UnderlyingEnumeratorWrapper.SPLIT_1),
+                HybridSourceSplit.wrapSplit(
+                        UnderlyingEnumeratorWrapper.SPLIT_1, 1, 
switchedSources),
                 SUBTASK0);
 
         // handleSplitRequest invalid during reset
@@ -169,21 +169,24 @@ public class HybridSourceSplitEnumeratorTest {
         enumerator = (HybridSourceSplitEnumerator) 
source.createEnumerator(context);
         enumerator.start();
         HybridSourceEnumeratorState enumeratorState = 
enumerator.snapshotState(0);
-        Assert.assertEquals(1, ((List) 
enumeratorState.getWrappedState()).size());
+        MockSplitEnumerator underlyingEnumerator = 
getCurrentEnumerator(enumerator);
+        Assert.assertThat(
+                (List<MockSourceSplit>) 
Whitebox.getInternalState(underlyingEnumerator, "splits"),
+                Matchers.iterableWithSize(1));
         enumerator =
                 (HybridSourceSplitEnumerator) 
source.restoreEnumerator(context, enumeratorState);
         enumerator.start();
-        enumeratorState = enumerator.snapshotState(0);
-        Assert.assertEquals(1, ((List) 
enumeratorState.getWrappedState()).size());
+        underlyingEnumerator = getCurrentEnumerator(enumerator);
+        Assert.assertThat(
+                (List<MockSourceSplit>) 
Whitebox.getInternalState(underlyingEnumerator, "splits"),
+                Matchers.iterableWithSize(1));
     }
 
     @Test
     public void testDefaultMethodDelegation() throws Exception {
         setupEnumeratorAndTriggerSourceSwitch();
         SplitEnumerator<MockSourceSplit, Object> underlyingEnumeratorSpy =
-                Mockito.spy(
-                        (SplitEnumerator<MockSourceSplit, Object>)
-                                Whitebox.getInternalState(enumerator, 
"currentEnumerator"));
+                Mockito.spy((SplitEnumerator) 
getCurrentEnumerator(enumerator));
         Whitebox.setInternalState(enumerator, "currentEnumerator", 
underlyingEnumeratorSpy);
 
         enumerator.notifyCheckpointComplete(1);
@@ -270,4 +273,9 @@ public class HybridSourceSplitEnumeratorTest {
     private static int getCurrentSourceIndex(HybridSourceSplitEnumerator 
enumerator) {
         return (int) Whitebox.getInternalState(enumerator, 
"currentSourceIndex");
     }
+
+    private static MockSplitEnumerator getCurrentEnumerator(
+            HybridSourceSplitEnumerator enumerator) {
+        return (MockSplitEnumerator) Whitebox.getInternalState(enumerator, 
"currentEnumerator");
+    }
 }
diff --git 
a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitSerializerTest.java
 
b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitSerializerTest.java
index d43275f..e2db86e 100644
--- 
a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitSerializerTest.java
+++ 
b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/hybrid/HybridSourceSplitSerializerTest.java
@@ -20,7 +20,6 @@ package org.apache.flink.connector.base.source.hybrid;
 
 import org.apache.flink.api.connector.source.Source;
 import org.apache.flink.api.connector.source.mocks.MockSource;
-import org.apache.flink.api.connector.source.mocks.MockSourceSplit;
 
 import org.junit.Assert;
 import org.junit.Test;
@@ -36,8 +35,9 @@ public class HybridSourceSplitSerializerTest {
     public void testSerialization() throws Exception {
         Map<Integer, Source> switchedSources = new HashMap<>();
         switchedSources.put(0, new MockSource(null, 0));
-        HybridSourceSplitSerializer serializer = new 
HybridSourceSplitSerializer(switchedSources);
-        HybridSourceSplit split = new HybridSourceSplit(0, new 
MockSourceSplit(1));
+        byte[] splitBytes = {1, 2, 3};
+        HybridSourceSplitSerializer serializer = new 
HybridSourceSplitSerializer();
+        HybridSourceSplit split = new HybridSourceSplit(0, splitBytes, 0, 
"splitId");
         byte[] serialized = serializer.serialize(split);
         HybridSourceSplit clonedSplit = serializer.deserialize(0, serialized);
         Assert.assertEquals(split, clonedSplit);
diff --git 
a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/reader/mocks/MockBaseSource.java
 
b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/reader/mocks/MockBaseSource.java
index bd01adf..ff3e80e 100644
--- 
a/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/reader/mocks/MockBaseSource.java
+++ 
b/flink-connectors/flink-connector-base/src/test/java/org/apache/flink/connector/base/source/reader/mocks/MockBaseSource.java
@@ -117,7 +117,8 @@ public class MockBaseSource implements Source<Integer, 
MockSourceSplit, List<Moc
 
             @Override
             public byte[] serialize(List<MockSourceSplit> obj) throws 
IOException {
-                return InstantiationUtil.serializeObject(obj.toArray());
+                return InstantiationUtil.serializeObject(
+                        obj.toArray(new MockSourceSplit[obj.size()]));
             }
 
             @Override

Reply via email to