This is an automated email from the ASF dual-hosted git repository.

zakelly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 39aa79ed5ec [FLINK-37569] Add userKeySerializer in RegisteredKeyValueStateBackend V2 (#26364)
39aa79ed5ec is described below

commit 39aa79ed5ec2dde785d0afe593187b037ebfc84c
Author: mayuehappy <[email protected]>
AuthorDate: Wed Apr 2 16:49:20 2025 +0800

    [FLINK-37569] Add userKeySerializer in RegisteredKeyValueStateBackend V2 (#26364)
---
 .../runtime/state/RegisteredStateMetaInfoBase.java |  12 +-
 .../state/metainfo/StateMetaInfoSnapshot.java      |   3 +-
 ...redKeyAndUserKeyValueStateBackendMetaInfo.java} | 167 +++++++---------
 .../v2/RegisteredKeyValueStateBackendMetaInfo.java |  10 +-
 .../StateMetaInfoSnapshotEnumConstantsTest.java    |   4 +-
 ...gisteredKeyValueStateBackendMetaInfoV2Test.java | 218 +++++++++++++++++++++
 .../flink/state/forst/ForStKeyedStateBackend.java  |  46 ++++-
 .../flink/state/forst/ForStStateMigrationTest.java | 102 ++++++++++
 .../flink/state/forst/ForStStateTestBase.java      |  10 +-
 .../apache/flink/state/forst/ForStTestUtils.java   |  17 +-
 10 files changed, 471 insertions(+), 118 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredStateMetaInfoBase.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredStateMetaInfoBase.java
index 603d3559ac6..03f0bcf676b 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredStateMetaInfoBase.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/RegisteredStateMetaInfoBase.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.state;
 
+import org.apache.flink.api.common.state.v2.StateDescriptor;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
 
 import javax.annotation.Nonnull;
@@ -66,8 +67,15 @@ public abstract class RegisteredStateMetaInfoBase {
             case PRIORITY_QUEUE:
                 return new 
RegisteredPriorityQueueStateBackendMetaInfo<>(snapshot);
             case KEY_VALUE_V2:
-                return new org.apache.flink.runtime.state.v2
-                        .RegisteredKeyValueStateBackendMetaInfo<>(snapshot);
+                if (snapshot.getOption(
+                                
StateMetaInfoSnapshot.CommonOptionsKeys.KEYED_STATE_TYPE.toString())
+                        .equals(StateDescriptor.Type.MAP.toString())) {
+                    return new org.apache.flink.runtime.state.v2
+                            
.RegisteredKeyAndUserKeyValueStateBackendMetaInfo<>(snapshot);
+                } else {
+                    return new org.apache.flink.runtime.state.v2
+                            
.RegisteredKeyValueStateBackendMetaInfo<>(snapshot);
+                }
             default:
                 throw new IllegalArgumentException(
                         "Unknown backend state type: " + backendStateType);
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshot.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshot.java
index 193f61fb402..12432cd6629 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshot.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshot.java
@@ -80,7 +80,8 @@ public class StateMetaInfoSnapshot {
     public enum CommonSerializerKeys {
         KEY_SERIALIZER,
         NAMESPACE_SERIALIZER,
-        VALUE_SERIALIZER
+        VALUE_SERIALIZER,
+        USER_KEY_SERIALIZER
     }
 
     /** The name of the state. */
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/RegisteredKeyValueStateBackendMetaInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/RegisteredKeyAndUserKeyValueStateBackendMetaInfo.java
similarity index 62%
copy from flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/RegisteredKeyValueStateBackendMetaInfo.java
copy to flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/RegisteredKeyAndUserKeyValueStateBackendMetaInfo.java
index 45107561f9a..96aac93ba69 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/RegisteredKeyValueStateBackendMetaInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/RegisteredKeyAndUserKeyValueStateBackendMetaInfo.java
@@ -22,7 +22,6 @@ import org.apache.flink.api.common.state.v2.StateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
 import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
-import org.apache.flink.runtime.state.RegisteredStateMetaInfoBase;
 import org.apache.flink.runtime.state.StateSerializerProvider;
 import 
org.apache.flink.runtime.state.StateSnapshotTransformer.StateSnapshotTransformFactory;
 import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
@@ -30,6 +29,7 @@ import org.apache.flink.util.CollectionUtil;
 import org.apache.flink.util.Preconditions;
 
 import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
 
 import java.util.Collections;
 import java.util.Map;
@@ -41,44 +41,33 @@ import java.util.Objects;
  *
  * @param <S> Type of state value
  */
-public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends 
RegisteredStateMetaInfoBase {
+public class RegisteredKeyAndUserKeyValueStateBackendMetaInfo<N, UK, S>
+        extends RegisteredKeyValueStateBackendMetaInfo<N, S> {
 
-    @Nonnull private final StateDescriptor.Type stateType;
-    @Nonnull private final StateSerializerProvider<N> 
namespaceSerializerProvider;
-    @Nonnull private final StateSerializerProvider<S> stateSerializerProvider;
-    @Nonnull private StateSnapshotTransformFactory<S> 
stateSnapshotTransformFactory;
+    // We keep it @Nullable since in the very first version we did not store 
this serializer here.
+    @Nullable private StateSerializerProvider<UK> userKeySerializerProvider;
 
-    public RegisteredKeyValueStateBackendMetaInfo(
-            @Nonnull String name,
-            @Nonnull StateDescriptor.Type stateType,
-            @Nonnull TypeSerializer<N> namespaceSerializer,
-            @Nonnull TypeSerializer<S> stateSerializer) {
-
-        this(
-                name,
-                stateType,
-                
StateSerializerProvider.fromNewRegisteredSerializer(namespaceSerializer),
-                
StateSerializerProvider.fromNewRegisteredSerializer(stateSerializer),
-                StateSnapshotTransformFactory.noTransform());
-    }
-
-    public RegisteredKeyValueStateBackendMetaInfo(
+    public RegisteredKeyAndUserKeyValueStateBackendMetaInfo(
             @Nonnull String name,
             @Nonnull StateDescriptor.Type stateType,
             @Nonnull TypeSerializer<N> namespaceSerializer,
             @Nonnull TypeSerializer<S> stateSerializer,
-            @Nonnull StateSnapshotTransformFactory<S> 
stateSnapshotTransformFactory) {
+            @Nullable TypeSerializer<UK> userKeySerializer) {
 
         this(
                 name,
                 stateType,
                 
StateSerializerProvider.fromNewRegisteredSerializer(namespaceSerializer),
                 
StateSerializerProvider.fromNewRegisteredSerializer(stateSerializer),
-                stateSnapshotTransformFactory);
+                userKeySerializer == null
+                        ? null
+                        : 
StateSerializerProvider.fromNewRegisteredSerializer(userKeySerializer),
+                StateSnapshotTransformFactory.noTransform());
     }
 
     @SuppressWarnings("unchecked")
-    public RegisteredKeyValueStateBackendMetaInfo(@Nonnull 
StateMetaInfoSnapshot snapshot) {
+    public RegisteredKeyAndUserKeyValueStateBackendMetaInfo(
+            @Nonnull StateMetaInfoSnapshot snapshot) {
         this(
                 snapshot.getName(),
                 StateDescriptor.Type.valueOf(
@@ -96,6 +85,17 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
                                         snapshot.getTypeSerializerSnapshot(
                                                 
StateMetaInfoSnapshot.CommonSerializerKeys
                                                         .VALUE_SERIALIZER))),
+                snapshot.getTypeSerializerSnapshot(
+                                        
StateMetaInfoSnapshot.CommonSerializerKeys
+                                                .USER_KEY_SERIALIZER)
+                                == null
+                        ? null
+                        : 
StateSerializerProvider.fromPreviousSerializerSnapshot(
+                                (TypeSerializerSnapshot<UK>)
+                                        Preconditions.checkNotNull(
+                                                
snapshot.getTypeSerializerSnapshot(
+                                                        
StateMetaInfoSnapshot.CommonSerializerKeys
+                                                                
.USER_KEY_SERIALIZER))),
                 StateSnapshotTransformFactory.noTransform());
 
         Preconditions.checkState(
@@ -103,69 +103,57 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
                         == snapshot.getBackendStateType());
     }
 
-    private RegisteredKeyValueStateBackendMetaInfo(
+    private RegisteredKeyAndUserKeyValueStateBackendMetaInfo(
             @Nonnull String name,
             @Nonnull StateDescriptor.Type stateType,
             @Nonnull StateSerializerProvider<N> namespaceSerializerProvider,
             @Nonnull StateSerializerProvider<S> stateSerializerProvider,
+            @Nullable StateSerializerProvider<UK> userKeySerializerProvider,
             @Nonnull StateSnapshotTransformFactory<S> 
stateSnapshotTransformFactory) {
-
-        super(name);
-        this.stateType = stateType;
-        this.namespaceSerializerProvider = namespaceSerializerProvider;
-        this.stateSerializerProvider = stateSerializerProvider;
-        this.stateSnapshotTransformFactory = stateSnapshotTransformFactory;
-    }
-
-    @Nonnull
-    public StateDescriptor.Type getStateType() {
-        return stateType;
-    }
-
-    @Nonnull
-    public TypeSerializer<N> getNamespaceSerializer() {
-        return namespaceSerializerProvider.currentSchemaSerializer();
+        super(
+                name,
+                stateType,
+                namespaceSerializerProvider,
+                stateSerializerProvider,
+                stateSnapshotTransformFactory);
+        this.userKeySerializerProvider = userKeySerializerProvider;
     }
 
-    @Nonnull
-    public TypeSerializer<S> getStateSerializer() {
-        return stateSerializerProvider.currentSchemaSerializer();
+    @Nullable
+    public TypeSerializer<UK> getUserKeySerializer() {
+        return userKeySerializerProvider == null
+                ? null
+                : userKeySerializerProvider.currentSchemaSerializer();
     }
 
     @Nonnull
-    public TypeSerializerSchemaCompatibility<S> updateStateSerializer(
-            TypeSerializer<S> newStateSerializer) {
-        return 
stateSerializerProvider.registerNewSerializerForRestoredState(newStateSerializer);
+    public TypeSerializerSchemaCompatibility<UK> updateUserKeySerializer(
+            TypeSerializer<UK> newStateSerializer) {
+        if (userKeySerializerProvider == null) {
+            // This means that there is no userKeySerializerProvider in the 
previous StateMetaInfo,
+            // which may be restored from an old version.
+            this.userKeySerializerProvider =
+                    
StateSerializerProvider.fromNewRegisteredSerializer(newStateSerializer);
+            return TypeSerializerSchemaCompatibility.compatibleAsIs();
+        } else {
+            return 
userKeySerializerProvider.registerNewSerializerForRestoredState(
+                    newStateSerializer);
+        }
     }
 
     @Override
     public boolean equals(Object o) {
-        if (this == o) {
-            return true;
-        }
-
-        if (o == null || getClass() != o.getClass()) {
-            return false;
-        }
-
-        RegisteredKeyValueStateBackendMetaInfo<?, ?> that =
-                (RegisteredKeyValueStateBackendMetaInfo<?, ?>) o;
-
-        if (!stateType.equals(that.stateType)) {
-            return false;
-        }
-
-        if (!getName().equals(that.getName())) {
-            return false;
-        }
-
-        return getStateSerializer().equals(that.getStateSerializer())
-                && 
getNamespaceSerializer().equals(that.getNamespaceSerializer());
+        return super.equals(o)
+                && o instanceof 
RegisteredKeyAndUserKeyValueStateBackendMetaInfo
+                && Objects.equals(
+                        getUserKeySerializer(),
+                        ((RegisteredKeyAndUserKeyValueStateBackendMetaInfo<?, 
?, ?>) o)
+                                .getUserKeySerializer());
     }
 
     @Override
     public String toString() {
-        return "RegisteredKeyedBackendStateMetaInfo{"
+        return "RegisteredKeyAndUserKeyValueStateBackendMetaInfo{"
                 + "stateType="
                 + stateType
                 + ", name='"
@@ -175,6 +163,8 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
                 + getNamespaceSerializer()
                 + ", stateSerializer="
                 + getStateSerializer()
+                + ", userKeySerializer="
+                + getUserKeySerializer()
                 + '}';
     }
 
@@ -184,6 +174,9 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
         result = 31 * result + getStateType().hashCode();
         result = 31 * result + getNamespaceSerializer().hashCode();
         result = 31 * result + getStateSerializer().hashCode();
+        if (getUserKeySerializer() != null) {
+            result = 31 * result + getUserKeySerializer().hashCode();
+        }
         return result;
     }
 
@@ -195,30 +188,9 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
 
     @Nonnull
     @Override
-    public RegisteredKeyValueStateBackendMetaInfo<N, S> 
withSerializerUpgradesAllowed() {
-        return new RegisteredKeyValueStateBackendMetaInfo<>(snapshot());
-    }
-
-    public void checkStateMetaInfo(StateDescriptor<?> stateDesc) {
-        Preconditions.checkState(
-                Objects.equals(stateDesc.getStateId(), getName()),
-                "Incompatible state names. "
-                        + "Was ["
-                        + getName()
-                        + "], "
-                        + "registered with ["
-                        + stateDesc.getStateId()
-                        + "].");
-
-        Preconditions.checkState(
-                stateDesc.getType() == getStateType(),
-                "Incompatible key/value state types. "
-                        + "Was ["
-                        + getStateType()
-                        + "], "
-                        + "registered with ["
-                        + stateDesc.getType()
-                        + "].");
+    public RegisteredKeyAndUserKeyValueStateBackendMetaInfo<N, UK, S>
+            withSerializerUpgradesAllowed() {
+        return new 
RegisteredKeyAndUserKeyValueStateBackendMetaInfo<>(snapshot());
     }
 
     @Nonnull
@@ -247,6 +219,15 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
         serializerConfigSnapshotsMap.put(
                 valueSerializerKey, stateSerializer.snapshotConfiguration());
 
+        TypeSerializer<UK> userKeySerializer = getUserKeySerializer();
+        if (userKeySerializer != null) {
+            String userKeySerializerKey =
+                    
StateMetaInfoSnapshot.CommonSerializerKeys.USER_KEY_SERIALIZER.toString();
+            serializerMap.put(userKeySerializerKey, 
userKeySerializer.duplicate());
+            serializerConfigSnapshotsMap.put(
+                    userKeySerializerKey, 
userKeySerializer.snapshotConfiguration());
+        }
+
         return new StateMetaInfoSnapshot(
                 name,
                 StateMetaInfoSnapshot.BackendStateType.KEY_VALUE_V2,
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/RegisteredKeyValueStateBackendMetaInfo.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/RegisteredKeyValueStateBackendMetaInfo.java
index 45107561f9a..60acace17b8 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/RegisteredKeyValueStateBackendMetaInfo.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/v2/RegisteredKeyValueStateBackendMetaInfo.java
@@ -43,10 +43,10 @@ import java.util.Objects;
  */
 public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends 
RegisteredStateMetaInfoBase {
 
-    @Nonnull private final StateDescriptor.Type stateType;
-    @Nonnull private final StateSerializerProvider<N> 
namespaceSerializerProvider;
-    @Nonnull private final StateSerializerProvider<S> stateSerializerProvider;
-    @Nonnull private StateSnapshotTransformFactory<S> 
stateSnapshotTransformFactory;
+    @Nonnull protected final StateDescriptor.Type stateType;
+    @Nonnull protected final StateSerializerProvider<N> 
namespaceSerializerProvider;
+    @Nonnull protected final StateSerializerProvider<S> 
stateSerializerProvider;
+    @Nonnull protected StateSnapshotTransformFactory<S> 
stateSnapshotTransformFactory;
 
     public RegisteredKeyValueStateBackendMetaInfo(
             @Nonnull String name,
@@ -103,7 +103,7 @@ public class RegisteredKeyValueStateBackendMetaInfo<N, S> extends RegisteredStat
                         == snapshot.getBackendStateType());
     }
 
-    private RegisteredKeyValueStateBackendMetaInfo(
+    protected RegisteredKeyValueStateBackendMetaInfo(
             @Nonnull String name,
             @Nonnull StateDescriptor.Type stateType,
             @Nonnull StateSerializerProvider<N> namespaceSerializerProvider,
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotEnumConstantsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotEnumConstantsTest.java
index 9d3f09466be..3dd3aa184de 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotEnumConstantsTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/metainfo/StateMetaInfoSnapshotEnumConstantsTest.java
@@ -64,7 +64,7 @@ class StateMetaInfoSnapshotEnumConstantsTest {
 
     @Test
     void testFixedSerializerEnumConstants() {
-        
assertThat(StateMetaInfoSnapshot.CommonSerializerKeys.values()).hasSize(3);
+        
assertThat(StateMetaInfoSnapshot.CommonSerializerKeys.values()).hasSize(4);
         
assertThat(StateMetaInfoSnapshot.CommonSerializerKeys.KEY_SERIALIZER.ordinal()).isZero();
         
assertThat(StateMetaInfoSnapshot.CommonSerializerKeys.NAMESPACE_SERIALIZER.ordinal())
                 .isOne();
@@ -76,5 +76,7 @@ class StateMetaInfoSnapshotEnumConstantsTest {
                 .isEqualTo("NAMESPACE_SERIALIZER");
         
assertThat(StateMetaInfoSnapshot.CommonSerializerKeys.VALUE_SERIALIZER.toString())
                 .isEqualTo("VALUE_SERIALIZER");
+        
assertThat(StateMetaInfoSnapshot.CommonSerializerKeys.USER_KEY_SERIALIZER.toString())
+                .isEqualTo("USER_KEY_SERIALIZER");
     }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/RegisteredKeyValueStateBackendMetaInfoV2Test.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/RegisteredKeyValueStateBackendMetaInfoV2Test.java
new file mode 100644
index 00000000000..6454eae41f6
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/v2/RegisteredKeyValueStateBackendMetaInfoV2Test.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state.v2;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.base.DoubleSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.core.memory.ByteArrayInputStreamWithPos;
+import org.apache.flink.core.memory.ByteArrayOutputStreamWithPos;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.runtime.state.KeyedBackendSerializationProxy;
+import org.apache.flink.runtime.state.RegisteredStateMetaInfoBase;
+import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class RegisteredKeyValueStateBackendMetaInfoV2Test {
+    @Test
+    void testRegisteredKeyValueStateBackendMetaInfoV2SerializationRoundtrip() 
throws Exception {
+
+        TypeSerializer<?> keySerializer = IntSerializer.INSTANCE;
+        TypeSerializer<?> namespaceSerializer = LongSerializer.INSTANCE;
+        TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE;
+        TypeSerializer<?> userKeySerializer = StringSerializer.INSTANCE;
+
+        List<StateMetaInfoSnapshot> stateMetaInfoList = new ArrayList<>();
+
+        stateMetaInfoList.add(
+                new 
org.apache.flink.runtime.state.v2.RegisteredKeyValueStateBackendMetaInfo<>(
+                                "a",
+                                
org.apache.flink.api.common.state.v2.StateDescriptor.Type.VALUE,
+                                namespaceSerializer,
+                                stateSerializer)
+                        .snapshot());
+        stateMetaInfoList.add(
+                new org.apache.flink.runtime.state.v2
+                                
.RegisteredKeyAndUserKeyValueStateBackendMetaInfo<>(
+                                "b",
+                                
org.apache.flink.api.common.state.v2.StateDescriptor.Type.MAP,
+                                namespaceSerializer,
+                                stateSerializer,
+                                userKeySerializer)
+                        .snapshot());
+        stateMetaInfoList.add(
+                new 
org.apache.flink.runtime.state.v2.RegisteredKeyValueStateBackendMetaInfo<>(
+                                "c",
+                                
org.apache.flink.api.common.state.v2.StateDescriptor.Type.VALUE,
+                                namespaceSerializer,
+                                stateSerializer)
+                        .snapshot());
+
+        KeyedBackendSerializationProxy<?> serializationProxy =
+                new KeyedBackendSerializationProxy<>(keySerializer, 
stateMetaInfoList, true);
+
+        byte[] serialized;
+        try (ByteArrayOutputStreamWithPos out = new 
ByteArrayOutputStreamWithPos()) {
+            serializationProxy.write(new DataOutputViewStreamWrapper(out));
+            serialized = out.toByteArray();
+        }
+
+        serializationProxy =
+                new KeyedBackendSerializationProxy<>(
+                        Thread.currentThread().getContextClassLoader());
+
+        try (ByteArrayInputStreamWithPos in = new 
ByteArrayInputStreamWithPos(serialized)) {
+            serializationProxy.read(new DataInputViewStreamWrapper(in));
+        }
+
+        assertThat(serializationProxy.isUsingKeyGroupCompression()).isTrue();
+        assertThat(serializationProxy.getKeySerializerSnapshot())
+                .isInstanceOf(IntSerializer.IntSerializerSnapshot.class);
+
+        assertEqualStateMetaInfoSnapshotsLists(
+                stateMetaInfoList, 
serializationProxy.getStateMetaInfoSnapshots());
+    }
+
+    @Test
+    void testMapKeyedStateMetaInfoSerialization() throws Exception {
+
+        TypeSerializer<?> keySerializer = IntSerializer.INSTANCE;
+        TypeSerializer<?> namespaceSerializer = LongSerializer.INSTANCE;
+        TypeSerializer<?> stateSerializer = DoubleSerializer.INSTANCE;
+        TypeSerializer<?> userKeySerializer = StringSerializer.INSTANCE;
+        List<StateMetaInfoSnapshot> stateMetaInfoList = new ArrayList<>();
+
+        // create StateMetaInfoSnapshot without userKeySerializer
+        StateMetaInfoSnapshot oldStateMeta =
+                new 
org.apache.flink.runtime.state.v2.RegisteredKeyValueStateBackendMetaInfo<>(
+                                "test1",
+                                
org.apache.flink.api.common.state.v2.StateDescriptor.Type.MAP,
+                                namespaceSerializer,
+                                stateSerializer)
+                        .snapshot();
+
+        StateMetaInfoSnapshot oldStateMetaWithUserKey =
+                new org.apache.flink.runtime.state.v2
+                                
.RegisteredKeyAndUserKeyValueStateBackendMetaInfo<>(
+                                "test2",
+                                
org.apache.flink.api.common.state.v2.StateDescriptor.Type.MAP,
+                                namespaceSerializer,
+                                stateSerializer,
+                                userKeySerializer)
+                        .snapshot();
+
+        stateMetaInfoList.add(oldStateMeta);
+        stateMetaInfoList.add(oldStateMetaWithUserKey);
+
+        assertThat(oldStateMeta.getBackendStateType())
+                
.isEqualTo(StateMetaInfoSnapshot.BackendStateType.KEY_VALUE_V2);
+        assertThat(oldStateMetaWithUserKey.getBackendStateType())
+                
.isEqualTo(StateMetaInfoSnapshot.BackendStateType.KEY_VALUE_V2);
+
+        assertThat(
+                        oldStateMeta.getTypeSerializerSnapshot(
+                                
StateMetaInfoSnapshot.CommonSerializerKeys.USER_KEY_SERIALIZER))
+                .isNull();
+        assertThat(
+                        oldStateMetaWithUserKey
+                                .getTypeSerializerSnapshot(
+                                        
StateMetaInfoSnapshot.CommonSerializerKeys
+                                                .USER_KEY_SERIALIZER)
+                                .restoreSerializer())
+                .isEqualTo(userKeySerializer);
+
+        KeyedBackendSerializationProxy<?> serializationProxy =
+                new KeyedBackendSerializationProxy<>(keySerializer, 
stateMetaInfoList, true);
+
+        byte[] serialized;
+        try (ByteArrayOutputStreamWithPos out = new 
ByteArrayOutputStreamWithPos()) {
+            serializationProxy.write(new DataOutputViewStreamWrapper(out));
+            serialized = out.toByteArray();
+        }
+
+        serializationProxy =
+                new KeyedBackendSerializationProxy<>(
+                        Thread.currentThread().getContextClassLoader());
+
+        try (ByteArrayInputStreamWithPos in = new 
ByteArrayInputStreamWithPos(serialized)) {
+            serializationProxy.read(new DataInputViewStreamWrapper(in));
+        }
+
+        List<StateMetaInfoSnapshot> stateMetaInfoSnapshots =
+                serializationProxy.getStateMetaInfoSnapshots();
+
+        
org.apache.flink.runtime.state.v2.RegisteredKeyValueStateBackendMetaInfo 
restoredMetaInfo =
+                (RegisteredKeyValueStateBackendMetaInfo)
+                        RegisteredStateMetaInfoBase.fromMetaInfoSnapshot(
+                                stateMetaInfoSnapshots.get(0));
+        assertThat(restoredMetaInfo.getClass())
+                .isEqualTo(
+                        org.apache.flink.runtime.state.v2
+                                
.RegisteredKeyAndUserKeyValueStateBackendMetaInfo.class);
+        assertThat(restoredMetaInfo.getName()).isEqualTo("test1");
+        assertThat(
+                        ((RegisteredKeyAndUserKeyValueStateBackendMetaInfo) 
restoredMetaInfo)
+                                .getUserKeySerializer())
+                .isNull();
+        
assertThat(restoredMetaInfo.getStateSerializer()).isEqualTo(DoubleSerializer.INSTANCE);
+        
assertThat(restoredMetaInfo.getNamespaceSerializer()).isEqualTo(LongSerializer.INSTANCE);
+
+        
org.apache.flink.runtime.state.v2.RegisteredKeyValueStateBackendMetaInfo 
restoredMetaInfo1 =
+                (RegisteredKeyValueStateBackendMetaInfo)
+                        RegisteredStateMetaInfoBase.fromMetaInfoSnapshot(
+                                stateMetaInfoSnapshots.get(1));
+        assertThat(restoredMetaInfo1.getClass())
+                .isEqualTo(
+                        org.apache.flink.runtime.state.v2
+                                
.RegisteredKeyAndUserKeyValueStateBackendMetaInfo.class);
+        assertThat(restoredMetaInfo1.getName()).isEqualTo("test2");
+        assertThat(
+                        ((RegisteredKeyAndUserKeyValueStateBackendMetaInfo) 
restoredMetaInfo1)
+                                .getUserKeySerializer())
+                .isEqualTo(StringSerializer.INSTANCE);
+        
assertThat(restoredMetaInfo1.getStateSerializer()).isEqualTo(DoubleSerializer.INSTANCE);
+        
assertThat(restoredMetaInfo1.getNamespaceSerializer()).isEqualTo(LongSerializer.INSTANCE);
+    }
+
+    private void assertEqualStateMetaInfoSnapshotsLists(
+            List<StateMetaInfoSnapshot> expected, List<StateMetaInfoSnapshot> 
actual) {
+        assertThat(actual).hasSameSizeAs(expected);
+        for (int i = 0; i < expected.size(); ++i) {
+            assertEqualStateMetaInfoSnapshots(expected.get(i), actual.get(i));
+        }
+    }
+
+    private void assertEqualStateMetaInfoSnapshots(
+            StateMetaInfoSnapshot expected, StateMetaInfoSnapshot actual) {
+        assertThat(actual.getName()).isEqualTo(expected.getName());
+        
assertThat(actual.getBackendStateType()).isEqualTo(expected.getBackendStateType());
+        
assertThat(actual.getOptionsImmutable()).isEqualTo(expected.getOptionsImmutable());
+        assertThat(actual.getSerializerSnapshotsImmutable())
+                .isEqualTo(expected.getSerializerSnapshotsImmutable());
+    }
+}
diff --git a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
index 5cc7e2f3d2b..b443a766b40 100644
--- a/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
+++ b/flink-state-backends/flink-statebackend-forst/src/main/java/org/apache/flink/state/forst/ForStKeyedStateBackend.java
@@ -21,6 +21,7 @@ import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.v2.AggregatingStateDescriptor;
 import org.apache.flink.api.common.state.v2.ListStateDescriptor;
+import org.apache.flink.api.common.state.v2.MapStateDescriptor;
 import org.apache.flink.api.common.state.v2.ReducingStateDescriptor;
 import org.apache.flink.api.common.state.v2.State;
 import org.apache.flink.api.common.state.v2.StateDescriptor;
@@ -48,11 +49,11 @@ import 
org.apache.flink.runtime.state.PriorityQueueSetFactory;
 import org.apache.flink.runtime.state.SerializedCompositeKeyBuilder;
 import org.apache.flink.runtime.state.SnapshotResult;
 import org.apache.flink.runtime.state.SnapshotStrategyRunner;
-import org.apache.flink.runtime.state.StateSnapshotTransformer;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueElement;
 import org.apache.flink.runtime.state.heap.HeapPriorityQueueSetFactory;
 import 
org.apache.flink.runtime.state.heap.HeapPriorityQueueSnapshotRestoreWrapper;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
+import 
org.apache.flink.runtime.state.v2.RegisteredKeyAndUserKeyValueStateBackendMetaInfo;
 import 
org.apache.flink.runtime.state.v2.RegisteredKeyValueStateBackendMetaInfo;
 import org.apache.flink.runtime.state.v2.internal.InternalKeyedState;
 import org.apache.flink.runtime.state.v2.ttl.TtlStateFactory;
@@ -360,6 +361,11 @@ public class ForStKeyedStateBackend<K> implements 
AsyncKeyedStateBackend<K> {
 
         TypeSerializer<SV> stateSerializer = stateDesc.getSerializer();
 
+        TypeSerializer<?> userKeySerializer =
+                stateDesc instanceof MapStateDescriptor
+                        ? ((MapStateDescriptor<?, SV>) 
stateDesc).getUserKeySerializer()
+                        : null;
+
         ForStOperationUtils.ForStKvStateInfo newStateInfo;
         RegisteredKeyValueStateBackendMetaInfo<N, SV> newMetaInfo;
         if (oldStateInfo != null) {
@@ -372,6 +378,7 @@ public class ForStKeyedStateBackend<K> implements 
AsyncKeyedStateBackend<K> {
                             Tuple2.of(oldStateInfo.columnFamilyHandle, 
castedMetaInfo),
                             stateDesc,
                             stateSerializer,
+                            userKeySerializer,
                             namespaceSerializer);
 
             newStateInfo =
@@ -379,13 +386,22 @@ public class ForStKeyedStateBackend<K> implements 
AsyncKeyedStateBackend<K> {
                             oldStateInfo.columnFamilyHandle, newMetaInfo);
             kvStateInformation.put(stateDesc.getStateId(), newStateInfo);
         } else {
-            newMetaInfo =
-                    new RegisteredKeyValueStateBackendMetaInfo<>(
-                            stateDesc.getStateId(),
-                            stateDesc.getType(),
-                            namespaceSerializer,
-                            stateSerializer,
-                            
StateSnapshotTransformer.StateSnapshotTransformFactory.noTransform());
+            if (stateDesc.getType().equals(StateDescriptor.Type.MAP)) {
+                newMetaInfo =
+                        new RegisteredKeyAndUserKeyValueStateBackendMetaInfo<>(
+                                stateDesc.getStateId(),
+                                stateDesc.getType(),
+                                namespaceSerializer,
+                                stateSerializer,
+                                userKeySerializer);
+            } else {
+                newMetaInfo =
+                        new RegisteredKeyValueStateBackendMetaInfo<>(
+                                stateDesc.getStateId(),
+                                stateDesc.getType(),
+                                namespaceSerializer,
+                                stateSerializer);
+            }
 
             newStateInfo =
                     ForStOperationUtils.createStateInfo(
@@ -405,10 +421,11 @@ public class ForStKeyedStateBackend<K> implements 
AsyncKeyedStateBackend<K> {
         return Tuple2.of(newStateInfo.columnFamilyHandle, newMetaInfo);
     }
 
-    private <N, SV> RegisteredKeyValueStateBackendMetaInfo<N, SV> 
updateRestoredStateMetaInfo(
+    private <N, UK, SV> RegisteredKeyValueStateBackendMetaInfo<N, SV> 
updateRestoredStateMetaInfo(
             Tuple2<ColumnFamilyHandle, 
RegisteredKeyValueStateBackendMetaInfo<N, SV>> oldStateInfo,
             StateDescriptor<SV> stateDesc,
             TypeSerializer<SV> stateSerializer,
+            TypeSerializer<UK> userKeySerializer,
             TypeSerializer<N> namespaceSerializer)
             throws Exception {
 
@@ -438,6 +455,17 @@ public class ForStKeyedStateBackend<K> implements 
AsyncKeyedStateBackend<K> {
                             + ").");
         }
 
+        if (userKeySerializer != null) {
+            TypeSerializerSchemaCompatibility<UK> 
userKeySerializerCompatibility =
+                    ((RegisteredKeyAndUserKeyValueStateBackendMetaInfo<N, UK, 
SV>)
+                                    restoredKvStateMetaInfo)
+                            .updateUserKeySerializer(userKeySerializer);
+            if (!userKeySerializerCompatibility.isCompatibleAsIs()) {
+                throw new StateMigrationException(
+                        "The new serializer for a MapState requires state 
migration in order for the job to proceed. State migration not support yet.");
+            }
+        }
+
         return restoredKvStateMetaInfo;
     }
 
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateMigrationTest.java
 
b/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateMigrationTest.java
new file mode 100644
index 00000000000..a9120a4edc6
--- /dev/null
+++ 
b/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateMigrationTest.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.state.forst;
+
+import org.apache.flink.api.common.state.v2.MapState;
+import org.apache.flink.api.common.state.v2.MapStateDescriptor;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.StringSerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.fs.FileSystem;
+import org.apache.flink.core.testutils.CommonTestUtils;
+import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.SnapshotResult;
+import org.apache.flink.util.IOUtils;
+import org.apache.flink.util.StateMigrationException;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.Collections;
+import java.util.concurrent.RunnableFuture;
+
+import static 
org.apache.flink.state.forst.ForStTestUtils.createKeyedStateBackend;
+import static org.junit.Assert.fail;
+
+/** Tests for state migration (serializer compatibility checks) in {@link ForStKeyedStateBackend}. */
+public class ForStStateMigrationTest extends ForStStateTestBase {
+
+    @Test
+    void testFortMapStateKeySchemaChanged() throws Exception {
+        MapStateDescriptor<Integer, String> descriptorV1 =
+                new MapStateDescriptor<>(
+                        "testState", IntSerializer.INSTANCE, 
StringSerializer.INSTANCE);
+
+        MapStateDescriptor<String, String> descriptorV2 =
+                new MapStateDescriptor<>(
+                        "testState", StringSerializer.INSTANCE, 
StringSerializer.INSTANCE);
+
+        MapState<Integer, String> mapState =
+                keyedBackend.createState(1, IntSerializer.INSTANCE, 
descriptorV1);
+        setCurrentContext("test", "test");
+        for (int i = 0; i < 10; i++) {
+            mapState.asyncPut(i, String.valueOf(i));
+        }
+        drain();
+        RunnableFuture<SnapshotResult<KeyedStateHandle>> snapshot =
+                keyedBackend.snapshot(
+                        1L,
+                        System.currentTimeMillis(),
+                        env.getCheckpointStorageAccess()
+                                .resolveCheckpointStorageLocation(
+                                        1L, 
CheckpointStorageLocationReference.getDefault()),
+                        CheckpointOptions.forCheckpointWithDefaultLocation());
+
+        if (!snapshot.isDone()) {
+            snapshot.run();
+        }
+        SnapshotResult<KeyedStateHandle> snapshotResult = snapshot.get();
+        KeyedStateHandle stateHandle = 
snapshotResult.getJobManagerOwnedSnapshot();
+        IOUtils.closeQuietly(keyedBackend);
+        keyedBackend.dispose();
+
+        FileSystem.initialize(new Configuration(), null);
+        Configuration configuration = new Configuration();
+        ForStStateBackend forStStateBackend =
+                new ForStStateBackend().configure(configuration, null);
+        keyedBackend =
+                createKeyedStateBackend(
+                        forStStateBackend,
+                        env,
+                        StringSerializer.INSTANCE,
+                        Collections.singletonList(stateHandle));
+        keyedBackend.setup(aec);
+        try {
+            keyedBackend.createState(1, IntSerializer.INSTANCE, descriptorV2);
+            fail("Expected a state migration exception.");
+        } catch (Exception e) {
+            if (CommonTestUtils.containsCause(e, 
StateMigrationException.class)) {
+                // StateMigrationException expected
+            } else {
+                throw e;
+            }
+        }
+    }
+}
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateTestBase.java
 
b/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateTestBase.java
index f888546bd94..8a385b027f4 100644
--- 
a/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateTestBase.java
+++ 
b/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStStateTestBase.java
@@ -57,6 +57,8 @@ public class ForStStateTestBase {
 
     protected RecordContext<String> context;
 
+    protected MockEnvironment env;
+
     @BeforeEach
     public void setup(@TempDir File temporaryFolder) throws IOException {
         FileSystem.initialize(new Configuration(), null);
@@ -65,11 +67,9 @@ public class ForStStateTestBase {
         ForStStateBackend forStStateBackend =
                 new ForStStateBackend().configure(configuration, null);
 
-        keyedBackend =
-                createKeyedStateBackend(
-                        forStStateBackend,
-                        getMockEnvironment(temporaryFolder),
-                        StringSerializer.INSTANCE);
+        env = getMockEnvironment(temporaryFolder);
+
+        keyedBackend = createKeyedStateBackend(forStStateBackend, env, 
StringSerializer.INSTANCE);
 
         mailboxExecutor =
                 new MailboxExecutorImpl(
diff --git 
a/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStTestUtils.java
 
b/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStTestUtils.java
index 42940c570e2..62cbca93015 100644
--- 
a/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStTestUtils.java
+++ 
b/flink-state-backends/flink-statebackend-forst/src/test/java/org/apache/flink/state/forst/ForStTestUtils.java
@@ -24,15 +24,21 @@ import 
org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateBackendParametersImpl;
+import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.ttl.TtlTimeProvider;
 
 import java.io.IOException;
+import java.util.Collection;
 import java.util.Collections;
 
 /** Test utils for the ForSt state backend. */
 public final class ForStTestUtils {
+
     public static <K> ForStKeyedStateBackend<K> createKeyedStateBackend(
-            ForStStateBackend forStStateBackend, Environment env, 
TypeSerializer<K> keySerializer)
+            ForStStateBackend forStStateBackend,
+            Environment env,
+            TypeSerializer<K> keySerializer,
+            Collection<KeyedStateHandle> stateHandles)
             throws IOException {
 
         return forStStateBackend.createAsyncKeyedStateBackend(
@@ -47,8 +53,15 @@ public final class ForStTestUtils {
                         TtlTimeProvider.DEFAULT,
                         new UnregisteredMetricsGroup(),
                         (name, value) -> {},
-                        Collections.emptyList(),
+                        stateHandles,
                         new CloseableRegistry(),
                         1.0));
     }
+
+    public static <K> ForStKeyedStateBackend<K> createKeyedStateBackend(
+            ForStStateBackend forStStateBackend, Environment env, 
TypeSerializer<K> keySerializer)
+            throws IOException {
+        return createKeyedStateBackend(
+                forStStateBackend, env, keySerializer, 
Collections.emptyList());
+    }
 }


Reply via email to