This is an automated email from the ASF dual-hosted git repository.

srichter pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git

commit 98e4610f09f35a942e55472b5d358ebe113b0dba
Author: Stefan Richter <srich...@confluent.io>
AuthorDate: Mon Oct 9 14:23:47 2023 +0200

    [FLINK-33246][tests] Add AutoRescalingITCase.
---
 .../flink/runtime/testutils/CommonTestUtils.java   |  27 +
 .../test/checkpointing/AutoRescalingITCase.java    | 968 +++++++++++++++++++++
 .../UpdateJobResourceRequirementsITCase.java       |   2 +-
 3 files changed, 996 insertions(+), 1 deletion(-)

diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java
index a101a453ff0..da016150500 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/testutils/CommonTestUtils.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.testutils;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.JobStatus;
 import org.apache.flink.core.execution.JobClient;
+import org.apache.flink.runtime.checkpoint.CheckpointStatsSnapshot;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpointStats;
 import org.apache.flink.runtime.execution.ExecutionState;
 import org.apache.flink.runtime.executiongraph.AccessExecutionGraph;
@@ -366,6 +367,32 @@ public class CommonTestUtils {
                 });
     }
 
+    /** Waits for one more completed checkpoint. */
+    public static void waitForOneMoreCheckpoint(JobID jobID, MiniCluster miniCluster)
+            throws Exception {
+        final long[] currentCheckpoint = new long[] {-1L};
+        waitUntilCondition(
+                () -> {
+                    AccessExecutionGraph graph = miniCluster.getExecutionGraph(jobID).get();
+                    CheckpointStatsSnapshot snapshot = graph.getCheckpointStatsSnapshot();
+                    if (snapshot != null) {
+                        long currentCount = snapshot.getCounts().getNumberOfCompletedCheckpoints();
+                        if (currentCheckpoint[0] < 0L) {
+                            currentCheckpoint[0] = currentCount;
+                        } else {
+                            return currentCount > currentCheckpoint[0];
+                        }
+                    } else if (graph.getState().isGloballyTerminalState()) {
+                        checkState(
+                                graph.getFailureInfo() != null,
+                                "Job terminated before taking required checkpoint.",
+                                graph.getState());
+                        throw graph.getFailureInfo().getException();
+                    }
+                    return false;
+                });
+    }
+
     /**
      * @return the path as {@link java.net.URI} to the latest checkpoint.
      * @throws FlinkJobNotFoundException if job not found
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AutoRescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AutoRescalingITCase.java
new file mode 100644
index 00000000000..403449a388b
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/AutoRescalingITCase.java
@@ -0,0 +1,968 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.common.time.Deadline;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.client.program.ClusterClient;
+import org.apache.flink.client.program.rest.RestClusterClient;
+import org.apache.flink.configuration.CheckpointingOptions;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.configuration.JobManagerOptions;
+import org.apache.flink.configuration.NettyShuffleEnvironmentOptions;
+import org.apache.flink.configuration.StateBackendOptions;
+import org.apache.flink.configuration.WebOptions;
+import org.apache.flink.core.testutils.OneShotLatch;
+import org.apache.flink.runtime.client.JobExecutionException;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.JobResourceRequirements;
+import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
+import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
+import org.apache.flink.streaming.api.CheckpointingMode;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.CheckpointConfig;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.api.functions.sink.v2.DiscardingSink;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.test.util.MiniClusterWithClientResource;
+import org.apache.flink.testutils.TestingUtils;
+import org.apache.flink.testutils.executor.TestExecutorResource;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.AfterClass;
+import org.junit.Before;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.File;
+import java.time.Duration;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Objects;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.TimeUnit;
+
+import static org.apache.flink.runtime.testutils.CommonTestUtils.waitForAllTaskRunning;
+import static org.apache.flink.runtime.testutils.CommonTestUtils.waitForOneMoreCheckpoint;
+import static org.apache.flink.test.scheduling.UpdateJobResourceRequirementsITCase.waitForAvailableSlots;
+import static org.apache.flink.test.scheduling.UpdateJobResourceRequirementsITCase.waitForRunningTasks;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests checkpoint rescaling under changing resource requirements. This test is mostly a variant
+ * of {@link RescalingITCase} with two main differences: (1) we rescale from checkpoints instead
+ * of savepoints, and (2) rescaling is triggered by changing the resource requirements rather than
+ * by a cancel/restart cycle.
+ */
+@RunWith(Parameterized.class)
+public class AutoRescalingITCase extends TestLogger {
+
+    @ClassRule
+    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
+            TestingUtils.defaultExecutorResource();
+
+    private static final int numTaskManagers = 2;
+    private static final int slotsPerTaskManager = 2;
+    private static final int totalSlots = numTaskManagers * slotsPerTaskManager;
+
+    @Parameterized.Parameters(name = "backend = {0}, buffersPerChannel = {1}")
+    public static Collection<Object[]> data() {
+        return Arrays.asList(
+                new Object[][] {
+                    {"rocksdb", 0}, {"rocksdb", 2}, {"filesystem", 0}, {"filesystem", 2}
+                });
+    }
+
+    public AutoRescalingITCase(String backend, int buffersPerChannel) {
+        this.backend = backend;
+        this.buffersPerChannel = buffersPerChannel;
+    }
+
+    private final String backend;
+
+    private final int buffersPerChannel;
+
+    private String currentBackend = null;
+
+    enum OperatorCheckpointMethod {
+        NON_PARTITIONED,
+        CHECKPOINTED_FUNCTION,
+        CHECKPOINTED_FUNCTION_BROADCAST,
+        LIST_CHECKPOINTED
+    }
+
+    private static MiniClusterWithClientResource cluster;
+    private static RestClusterClient<?> restClusterClient;
+
+    @ClassRule public static TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+    @Before
+    public void setup() throws Exception {
+        // detect parameter change
+        if (!Objects.equals(currentBackend, backend)) {
+            shutDownExistingCluster();
+
+            currentBackend = backend;
+
+            Configuration config = new Configuration();
+
+            final File checkpointDir = temporaryFolder.newFolder();
+            final File savepointDir = temporaryFolder.newFolder();
+
+            config.setString(StateBackendOptions.STATE_BACKEND, currentBackend);
+            config.setBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS, true);
+            config.setBoolean(CheckpointingOptions.LOCAL_RECOVERY, true);
+            config.setString(
+                    CheckpointingOptions.CHECKPOINTS_DIRECTORY, checkpointDir.toURI().toString());
+            config.setString(
+                    CheckpointingOptions.SAVEPOINT_DIRECTORY, savepointDir.toURI().toString());
+            config.setInteger(
+                    NettyShuffleEnvironmentOptions.NETWORK_BUFFERS_PER_CHANNEL, buffersPerChannel);
+
+            config.set(JobManagerOptions.SCHEDULER, JobManagerOptions.SchedulerType.Adaptive);
+
+            // speed the test suite up
+            // - lower refresh interval -> controls how fast we invalidate ExecutionGraphCache
+            // - lower slot idle timeout -> controls how fast we return idle slots to TM
+            config.set(WebOptions.REFRESH_INTERVAL, 50L);
+            config.set(JobManagerOptions.SLOT_IDLE_TIMEOUT, 50L);
+
+            cluster =
+                    new MiniClusterWithClientResource(
+                            new MiniClusterResourceConfiguration.Builder()
+                                    .setConfiguration(config)
+                                    .setNumberTaskManagers(numTaskManagers)
+                                    .setNumberSlotsPerTaskManager(slotsPerTaskManager)
+                                    .build());
+            cluster.before();
+            restClusterClient = cluster.getRestClusterClient();
+        }
+    }
+
+    @AfterClass
+    public static void shutDownExistingCluster() {
+        if (cluster != null) {
+            cluster.after();
+            cluster = null;
+        }
+    }
+
+    @Test
+    public void testCheckpointRescalingInKeyedState() throws Exception {
+        testCheckpointRescalingKeyedState(false);
+    }
+
+    @Test
+    public void testCheckpointRescalingOutKeyedState() throws Exception {
+        testCheckpointRescalingKeyedState(true);
+    }
+
+    /**
+     * Tests that a job with purely keyed state can be restarted from a checkpoint with a different
+     * parallelism.
+     */
+    public void testCheckpointRescalingKeyedState(boolean scaleOut) throws Exception {
+        final int numberKeys = 42;
+        final int numberElements = 1000;
+        final int parallelism = scaleOut ? totalSlots / 2 : totalSlots;
+        final int parallelism2 = scaleOut ? totalSlots : totalSlots / 2;
+        final int maxParallelism = 13;
+
+        Duration timeout = Duration.ofMinutes(3);
+        Deadline deadline = Deadline.now().plus(timeout);
+
+        ClusterClient<?> client = cluster.getClusterClient();
+
+        try {
+
+            JobGraph jobGraph =
+                    createJobGraphWithKeyedState(
+                            new Configuration(),
+                            parallelism,
+                            maxParallelism,
+                            numberKeys,
+                            numberElements);
+
+            final JobID jobID = jobGraph.getJobID();
+
+            client.submitJob(jobGraph).get();
+
+            SubtaskIndexSource.SOURCE_LATCH.trigger();
+
+            // wait until the sources have emitted numberElements for each key and completed a
+            // checkpoint
+            assertTrue(
+                    SubtaskIndexFlatMapper.workCompletedLatch.await(
+                            deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS));
+
+            // verify the current state
+
+            Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
+
+            Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
+
+            for (int key = 0; key < numberKeys; key++) {
+                int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
+
+                expectedResult.add(
+                        Tuple2.of(
+                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
+                                        maxParallelism, parallelism, keyGroupIndex),
+                                numberElements * key));
+            }
+
+            assertEquals(expectedResult, actualResult);
+
+            // clear the CollectionSink set for the restarted job
+            CollectionSink.clearElementsSet();
+
+            waitForAllTaskRunning(cluster.getMiniCluster(), jobGraph.getJobID(), false);
+
+            waitForOneMoreCheckpoint(jobID, cluster.getMiniCluster());
+
+            SubtaskIndexSource.SOURCE_LATCH.reset();
+
+            JobResourceRequirements.Builder builder = JobResourceRequirements.newBuilder();
+            for (JobVertex vertex : jobGraph.getVertices()) {
+                builder.setParallelismForJobVertex(vertex.getID(), parallelism2, parallelism2);
+            }
+
+            restClusterClient.updateJobResourceRequirements(jobID, builder.build()).join();
+
+            waitForRunningTasks(restClusterClient, jobID, 2 * parallelism2);
+            waitForAvailableSlots(restClusterClient, totalSlots - parallelism2);
+
+            SubtaskIndexSource.SOURCE_LATCH.trigger();
+
+            client.requestJobResult(jobID).get();
+
+            Set<Tuple2<Integer, Integer>> actualResult2 = CollectionSink.getElementsSet();
+
+            Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
+
+            for (int key = 0; key < numberKeys; key++) {
+                int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
+                expectedResult2.add(
+                        Tuple2.of(
+                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
+                                        maxParallelism, parallelism2, keyGroupIndex),
+                                key * 2 * numberElements));
+            }
+
+            assertEquals(expectedResult2, actualResult2);
+
+        } finally {
+            // clear the CollectionSink set for the restarted job
+            CollectionSink.clearElementsSet();
+        }
+    }
+
+    /**
+     * Tests that a job cannot be restarted from a checkpoint with a different parallelism if the
+     * rescaled operator has non-partitioned state.
+     */
+    @Test
+    public void testCheckpointRescalingNonPartitionedStateCausesException() throws Exception {
+        final int parallelism = totalSlots / 2;
+        final int parallelism2 = totalSlots;
+        final int maxParallelism = 13;
+
+        ClusterClient<?> client = cluster.getClusterClient();
+
+        try {
+            JobGraph jobGraph =
+                    createJobGraphWithOperatorState(
+                            parallelism, maxParallelism, OperatorCheckpointMethod.NON_PARTITIONED);
+            // make sure the job does not finish before we take a checkpoint
+            StateSourceBase.canFinishLatch = new CountDownLatch(1);
+
+            final JobID jobID = jobGraph.getJobID();
+
+            client.submitJob(jobGraph).get();
+
+            // wait until the operator is started
+            waitForAllTaskRunning(cluster.getMiniCluster(), jobGraph.getJobID(), false);
+            // wait until the operator handles some data
+            StateSourceBase.workStartedLatch.await();
+
+            waitForOneMoreCheckpoint(jobID, cluster.getMiniCluster());
+
+            JobResourceRequirements.Builder builder = JobResourceRequirements.newBuilder();
+            for (JobVertex vertex : jobGraph.getVertices()) {
+                builder.setParallelismForJobVertex(vertex.getID(), parallelism2, parallelism2);
+            }
+
+            restClusterClient.updateJobResourceRequirements(jobID, builder.build()).join();
+
+            waitForRunningTasks(restClusterClient, jobID, parallelism2);
+            waitForAvailableSlots(restClusterClient, totalSlots - parallelism2);
+
+            StateSourceBase.canFinishLatch.countDown();
+
+            client.requestJobResult(jobID).get();
+        } catch (JobExecutionException exception) {
+            if (!(exception.getCause() instanceof IllegalStateException)) {
+                throw exception;
+            }
+        }
+    }
+
+    /**
+     * Tests that a job with non-partitioned state can be restarted from a checkpoint with a
+     * different parallelism if the operators with non-partitioned state are not rescaled.
+     */
+    @Test
+    public void testCheckpointRescalingWithKeyedAndNonPartitionedState() throws Exception {
+        int numberKeys = 42;
+        int numberElements = 1000;
+        int parallelism = totalSlots / 2;
+        int parallelism2 = totalSlots;
+        int maxParallelism = 13;
+
+        Duration timeout = Duration.ofMinutes(3);
+        Deadline deadline = Deadline.now().plus(timeout);
+
+        ClusterClient<?> client = cluster.getClusterClient();
+
+        try {
+
+            JobGraph jobGraph =
+                    createJobGraphWithKeyedAndNonPartitionedOperatorState(
+                            parallelism,
+                            maxParallelism,
+                            parallelism,
+                            numberKeys,
+                            numberElements,
+                            numberElements);
+
+            final JobID jobID = jobGraph.getJobID();
+
+            client.submitJob(jobGraph).get();
+
+            SubtaskIndexSource.SOURCE_LATCH.trigger();
+
+            // wait until the sources have emitted numberElements for each key and completed a
+            // checkpoint
+            assertTrue(
+                    SubtaskIndexFlatMapper.workCompletedLatch.await(
+                            deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS));
+
+            // verify the current state
+
+            Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
+
+            Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
+
+            for (int key = 0; key < numberKeys; key++) {
+                int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
+
+                expectedResult.add(
+                        Tuple2.of(
+                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
+                                        maxParallelism, parallelism, keyGroupIndex),
+                                numberElements * key));
+            }
+
+            assertEquals(expectedResult, actualResult);
+
+            // clear the CollectionSink set for the restarted job
+            CollectionSink.clearElementsSet();
+
+            waitForOneMoreCheckpoint(jobID, cluster.getMiniCluster());
+
+            SubtaskIndexSource.SOURCE_LATCH.reset();
+
+            JobResourceRequirements.Builder builder = JobResourceRequirements.newBuilder();
+            for (JobVertex vertex : jobGraph.getVertices()) {
+                if (vertex.getMaxParallelism() >= parallelism2) {
+                    builder.setParallelismForJobVertex(vertex.getID(), parallelism2, parallelism2);
+                } else {
+                    builder.setParallelismForJobVertex(
+                            vertex.getID(), vertex.getMaxParallelism(), vertex.getMaxParallelism());
+                }
+            }
+
+            restClusterClient.updateJobResourceRequirements(jobID, builder.build()).join();
+
+            waitForRunningTasks(restClusterClient, jobID, parallelism2);
+            waitForAvailableSlots(restClusterClient, totalSlots - parallelism2);
+
+            SubtaskIndexSource.SOURCE_LATCH.trigger();
+
+            client.requestJobResult(jobID).get();
+
+            Set<Tuple2<Integer, Integer>> actualResult2 = CollectionSink.getElementsSet();
+
+            Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
+
+            for (int key = 0; key < numberKeys; key++) {
+                int keyGroupIndex = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
+                expectedResult2.add(
+                        Tuple2.of(
+                                KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
+                                        maxParallelism, parallelism2, keyGroupIndex),
+                                key * 2 * numberElements));
+            }
+
+            assertEquals(expectedResult2, actualResult2);
+
+        } finally {
+            // clear the CollectionSink set for the restarted job
+            CollectionSink.clearElementsSet();
+        }
+    }
+
+    @Test
+    public void testCheckpointRescalingInPartitionedOperatorState() throws Exception {
+        testCheckpointRescalingPartitionedOperatorState(
+                false, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION);
+    }
+
+    @Test
+    public void testCheckpointRescalingOutPartitionedOperatorState() throws Exception {
+        testCheckpointRescalingPartitionedOperatorState(
+                true, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION);
+    }
+
+    @Test
+    public void testCheckpointRescalingInBroadcastOperatorState() throws Exception {
+        testCheckpointRescalingPartitionedOperatorState(
+                false, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST);
+    }
+
+    @Test
+    public void testCheckpointRescalingOutBroadcastOperatorState() throws Exception {
+        testCheckpointRescalingPartitionedOperatorState(
+                true, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST);
+    }
+
+    /** Tests rescaling of partitioned operator state. */
+    public void testCheckpointRescalingPartitionedOperatorState(
+            boolean scaleOut, OperatorCheckpointMethod checkpointMethod) throws Exception {
+        final int parallelism = scaleOut ? totalSlots : totalSlots / 2;
+        final int parallelism2 = scaleOut ? totalSlots / 2 : totalSlots;
+        final int maxParallelism = 13;
+
+        ClusterClient<?> client = cluster.getClusterClient();
+
+        int counterSize = Math.max(parallelism, parallelism2);
+
+        if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION
+                || checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION_BROADCAST) {
+            PartitionedStateSource.checkCorrectSnapshot = new int[counterSize];
+            PartitionedStateSource.checkCorrectRestore = new int[counterSize];
+            PartitionedStateSource.checkCorrectSnapshots.clear();
+        } else {
+            throw new UnsupportedOperationException("Unsupported method: " + checkpointMethod);
+        }
+
+        JobGraph jobGraph =
+                createJobGraphWithOperatorState(parallelism, maxParallelism, checkpointMethod);
+        // make sure the job does not finish before we take the checkpoint
+        StateSourceBase.canFinishLatch = new CountDownLatch(1);
+
+        final JobID jobID = jobGraph.getJobID();
+
+        client.submitJob(jobGraph).get();
+
+        // wait until the operator is started
+        waitForAllTaskRunning(cluster.getMiniCluster(), jobGraph.getJobID(), false);
+        // wait until the operator handles some data
+        StateSourceBase.workStartedLatch.await();
+
+        waitForOneMoreCheckpoint(jobID, cluster.getMiniCluster());
+
+        JobResourceRequirements.Builder builder = JobResourceRequirements.newBuilder();
+        for (JobVertex vertex : jobGraph.getVertices()) {
+            builder.setParallelismForJobVertex(vertex.getID(), parallelism2, parallelism2);
+        }
+
+        restClusterClient.updateJobResourceRequirements(jobID, builder.build()).join();
+
+        waitForRunningTasks(restClusterClient, jobID, 2 * parallelism2);
+        waitForAvailableSlots(restClusterClient, totalSlots - parallelism2);
+
+        StateSourceBase.canFinishLatch.countDown();
+
+        client.requestJobResult(jobID).get();
+
+        int sumExp = 0;
+        int sumAct = 0;
+
+        if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION) {
+            for (int c : PartitionedStateSource.checkCorrectSnapshot) {
+                sumExp += c;
+            }
+
+            for (int c : PartitionedStateSource.checkCorrectRestore) {
+                sumAct += c;
+            }
+        } else {
+            for (int c : PartitionedStateSource.checkCorrectSnapshot) {
+                sumExp += c;
+            }
+
+            for (int c : PartitionedStateSource.checkCorrectRestore) {
+                sumAct += c;
+            }
+
+            sumExp *= parallelism2;
+        }
+
+        assertEquals(sumExp, sumAct);
+    }
+
+    // ------------------------------------------------------------------------------------------------------------------
+
+    private static void configureCheckpointing(CheckpointConfig config) {
+        config.setCheckpointInterval(100);
+        config.setCheckpointingMode(CheckpointingMode.EXACTLY_ONCE);
+        config.enableUnalignedCheckpoints(true);
+    }
+
+    private static JobGraph createJobGraphWithOperatorState(
+            int parallelism, int maxParallelism, OperatorCheckpointMethod checkpointMethod) {
+
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        configureCheckpointing(env.getCheckpointConfig());
+        env.setParallelism(parallelism);
+        env.getConfig().setMaxParallelism(maxParallelism);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+
+        StateSourceBase.workStartedLatch = new CountDownLatch(parallelism);
+
+        SourceFunction<Integer> src;
+
+        switch (checkpointMethod) {
+            case CHECKPOINTED_FUNCTION:
+                src = new PartitionedStateSource(false);
+                break;
+            case CHECKPOINTED_FUNCTION_BROADCAST:
+                src = new PartitionedStateSource(true);
+                break;
+            case NON_PARTITIONED:
+                src = new NonPartitionedStateSource();
+                break;
+            default:
+                throw new IllegalArgumentException(checkpointMethod.name());
+        }
+
+        DataStream<Integer> input = env.addSource(src);
+
+        input.sinkTo(new DiscardingSink<>());
+
+        return env.getStreamGraph().getJobGraph();
+    }
+
+    public static JobGraph createJobGraphWithKeyedState(
+            Configuration configuration,
+            int parallelism,
+            int maxParallelism,
+            int numberKeys,
+            int numberElements) {
+        StreamExecutionEnvironment env =
+                StreamExecutionEnvironment.getExecutionEnvironment(configuration);
+        env.setParallelism(parallelism);
+        if (0 < maxParallelism) {
+            env.getConfig().setMaxParallelism(maxParallelism);
+        }
+
+        configureCheckpointing(env.getCheckpointConfig());
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        env.getConfig().setUseSnapshotCompression(true);
+
+        DataStream<Integer> input =
+                env.addSource(new SubtaskIndexSource(numberKeys, numberElements, parallelism))
+                        .keyBy(
+                                new KeySelector<Integer, Integer>() {
+                                    private static final long serialVersionUID =
+                                            -7952298871120320940L;
+
+                                    @Override
+                                    public Integer getKey(Integer value) {
+                                        return value;
+                                    }
+                                });
+
+        SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
+
+        DataStream<Tuple2<Integer, Integer>> result =
+                input.flatMap(new SubtaskIndexFlatMapper(numberElements));
+
+        result.addSink(new CollectionSink<>());
+
+        return env.getStreamGraph().getJobGraph();
+    }
+
+    private static JobGraph createJobGraphWithKeyedAndNonPartitionedOperatorState(
+            int parallelism,
+            int maxParallelism,
+            int fixedParallelism,
+            int numberKeys,
+            int numberElements,
+            int numberElementsAfterRestart) {
+
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        env.setParallelism(parallelism);
+        env.getConfig().setMaxParallelism(maxParallelism);
+        configureCheckpointing(env.getCheckpointConfig());
+        env.setRestartStrategy(RestartStrategies.noRestart());
+
+        DataStream<Integer> input =
+                env.addSource(
+                                new SubtaskIndexNonPartitionedStateSource(
+                                        numberKeys,
+                                        numberElements,
+                                        numberElementsAfterRestart,
+                                        parallelism))
+                        .setParallelism(fixedParallelism)
+                        .setMaxParallelism(fixedParallelism)
+                        .keyBy(
+                                new KeySelector<Integer, Integer>() {
+                                    private static final long serialVersionUID =
+                                            -7952298871120320940L;
+
+                                    @Override
+                                    public Integer getKey(Integer value) {
+                                        return value;
+                                    }
+                                });
+
+        SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
+
+        DataStream<Tuple2<Integer, Integer>> result =
+                input.flatMap(new SubtaskIndexFlatMapper(numberElements));
+
+        result.addSink(new CollectionSink<>());
+
+        return env.getStreamGraph().getJobGraph();
+    }
+
+    private static class SubtaskIndexSource extends RichParallelSourceFunction<Integer> {
+
+        private static final long serialVersionUID = -400066323594122516L;
+
+        private final int numberKeys;
+
+        private final int originalParallelism;
+        protected int numberElements;
+
+        protected int counter = 0;
+
+        private boolean running = true;
+
+        private static final OneShotLatch SOURCE_LATCH = new OneShotLatch();
+
+        SubtaskIndexSource(int numberKeys, int numberElements, int originalParallelism) {
+            this.numberKeys = numberKeys;
+            this.numberElements = numberElements;
+            this.originalParallelism = originalParallelism;
+        }
+
+        @Override
+        public void run(SourceContext<Integer> ctx) throws Exception {
+            RuntimeContext runtimeContext = getRuntimeContext();
+            final int subtaskIndex = runtimeContext.getIndexOfThisSubtask();
+
+            boolean isRestartedOrRescaled =
+                    runtimeContext.getNumberOfParallelSubtasks() != originalParallelism
+                            || runtimeContext.getAttemptNumber() > 0;
+            while (running) {
+                SOURCE_LATCH.await();
+                if (counter < numberElements) {
+                    synchronized (ctx.getCheckpointLock()) {
+                        for (int value = subtaskIndex;
+                                value < numberKeys;
+                                value += runtimeContext.getNumberOfParallelSubtasks()) {
+                            ctx.collect(value);
+                        }
+
+                        counter++;
+                    }
+                } else {
+                    if (isRestartedOrRescaled) {
+                        running = false;
+                    } else {
+                        Thread.sleep(100);
+                    }
+                }
+            }
+        }
+
+        @Override
+        public void cancel() {
+            running = false;
+        }
+    }
+
+    private static class SubtaskIndexNonPartitionedStateSource extends SubtaskIndexSource
+            implements ListCheckpointed<Integer> {
+
+        private static final long serialVersionUID = 8388073059042040203L;
+        private final int numElementsAfterRestart;
+
+        SubtaskIndexNonPartitionedStateSource(
+                int numberKeys,
+                int numberElements,
+                int numElementsAfterRestart,
+                int originalParallelism) {
+            super(numberKeys, numberElements, originalParallelism);
+            this.numElementsAfterRestart = numElementsAfterRestart;
+        }
+
+        @Override
+        public List<Integer> snapshotState(long checkpointId, long timestamp) {
+            return Collections.singletonList(this.counter);
+        }
+
+        @Override
+        public void restoreState(List<Integer> state) {
+            if (state.size() != 1) {
+                throw new RuntimeException(
+                        "Test failed due to unexpected recovered state size " + state.size());
+            }
+            this.counter = state.get(0);
+            this.numberElements += numElementsAfterRestart;
+        }
+    }
+
+    private static class SubtaskIndexFlatMapper
+            extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>>
+            implements CheckpointedFunction {
+
+        private static final long serialVersionUID = 5273172591283191348L;
+
+        private static CountDownLatch workCompletedLatch = new CountDownLatch(1);
+
+        private transient ValueState<Integer> counter;
+        private transient ValueState<Integer> sum;
+
+        private final int numberElements;
+
+        SubtaskIndexFlatMapper(int numberElements) {
+            this.numberElements = numberElements;
+        }
+
+        @Override
+        public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out)
+                throws Exception {
+
+            int count = counter.value() + 1;
+            counter.update(count);
+
+            int s = sum.value() + value;
+            sum.update(s);
+
+            if (count % numberElements == 0) {
+                out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), s));
+                workCompletedLatch.countDown();
+            }
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext context) {
+            // all managed, nothing to do.
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) {
+            counter =
+                    context.getKeyedStateStore()
+                            .getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
+            sum =
+                    context.getKeyedStateStore()
+                            .getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
+        }
+    }
+
+    private static class CollectionSink<IN> implements SinkFunction<IN> {
+
+        private static final Set<Object> elements =
+                Collections.newSetFromMap(new ConcurrentHashMap<>());
+
+        private static final long serialVersionUID = -1652452958040267745L;
+
+        public static <IN> Set<IN> getElementsSet() {
+            return (Set<IN>) elements;
+        }
+
+        public static void clearElementsSet() {
+            elements.clear();
+        }
+
+        @Override
+        public void invoke(IN value) {
+            elements.add(value);
+        }
+    }
+
+    private static class StateSourceBase extends RichParallelSourceFunction<Integer> {
+
+        private static final long serialVersionUID = 7512206069681177940L;
+        private static CountDownLatch workStartedLatch = new CountDownLatch(1);
+        private static CountDownLatch canFinishLatch = new CountDownLatch(0);
+
+        protected volatile int counter = 0;
+        protected volatile boolean running = true;
+
+        @Override
+        public void run(SourceContext<Integer> ctx) throws Exception {
+            while (running) {
+                synchronized (ctx.getCheckpointLock()) {
+                    ++counter;
+                    ctx.collect(1);
+                }
+
+                Thread.sleep(2);
+
+                if (counter == 10) {
+                    workStartedLatch.countDown();
+                }
+
+                if (counter >= 500) {
+                    break;
+                }
+            }
+
+            canFinishLatch.await();
+        }
+
+        @Override
+        public void cancel() {
+            running = false;
+        }
+    }
+
+    private static class NonPartitionedStateSource extends StateSourceBase
+            implements ListCheckpointed<Integer> {
+
+        private static final long serialVersionUID = -8108185918123186841L;
+
+        @Override
+        public List<Integer> snapshotState(long checkpointId, long timestamp) {
+            return Collections.singletonList(this.counter);
+        }
+
+        @Override
+        public void restoreState(List<Integer> state) {
+            if (!state.isEmpty()) {
+                this.counter = state.get(0);
+            }
+        }
+    }
+
+    private static class PartitionedStateSource extends StateSourceBase
+            implements CheckpointedFunction {
+
+        private static final long serialVersionUID = -359715965103593462L;
+        private static final int NUM_PARTITIONS = 7;
+
+        private transient ListState<Integer> counterPartitions;
+        private final boolean broadcast;
+
+        private static final ConcurrentHashMap<Long, int[]> checkCorrectSnapshots =
+                new ConcurrentHashMap<>();
+        private static int[] checkCorrectSnapshot;
+        private static int[] checkCorrectRestore;
+
+        public PartitionedStateSource(boolean broadcast) {
+            this.broadcast = broadcast;
+        }
+
+        @Override
+        public void snapshotState(FunctionSnapshotContext context) throws Exception {
+
+            if (getRuntimeContext().getAttemptNumber() == 0) {
+                int[] snapshot =
+                        checkCorrectSnapshots.computeIfAbsent(
+                                context.getCheckpointId(),
+                                (x) -> new int[checkCorrectRestore.length]);
+                snapshot[getRuntimeContext().getIndexOfThisSubtask()] = counter;
+            }
+
+            counterPartitions.clear();
+
+            int div = counter / NUM_PARTITIONS;
+            int mod = counter % NUM_PARTITIONS;
+
+            for (int i = 0; i < NUM_PARTITIONS; ++i) {
+                int partitionValue = div;
+                if (mod > 0) {
+                    --mod;
+                    ++partitionValue;
+                }
+                counterPartitions.add(partitionValue);
+            }
+        }
+
+        @Override
+        public void initializeState(FunctionInitializationContext context) throws Exception {
+            if (broadcast) {
+                this.counterPartitions =
+                        context.getOperatorStateStore()
+                                .getUnionListState(
+                                        new ListStateDescriptor<>(
+                                                "counter_partitions", IntSerializer.INSTANCE));
+            } else {
+                this.counterPartitions =
+                        context.getOperatorStateStore()
+                                .getListState(
+                                        new ListStateDescriptor<>(
+                                                "counter_partitions", IntSerializer.INSTANCE));
+            }
+
+            if (context.isRestored()) {
+                for (int v : counterPartitions.get()) {
+                    counter += v;
+                }
+                checkCorrectRestore[getRuntimeContext().getIndexOfThisSubtask()] = counter;
+                context.getRestoredCheckpointId()
+                        .ifPresent((id) -> checkCorrectSnapshot = checkCorrectSnapshots.get(id));
+            }
+        }
+    }
+}
diff --git a/flink-tests/src/test/java/org/apache/flink/test/scheduling/UpdateJobResourceRequirementsITCase.java b/flink-tests/src/test/java/org/apache/flink/test/scheduling/UpdateJobResourceRequirementsITCase.java
index 74829f34caf..e0a88aaab2c 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/scheduling/UpdateJobResourceRequirementsITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/scheduling/UpdateJobResourceRequirementsITCase.java
@@ -202,7 +202,7 @@ public class UpdateJobResourceRequirementsITCase {
         }
     }
 
-    private static int getNumberRunningTasks(RestClusterClient<?> restClusterClient, JobID jobId) {
+    public static int getNumberRunningTasks(RestClusterClient<?> restClusterClient, JobID jobId) {
         final JobDetailsInfo jobDetailsInfo = restClusterClient.getJobDetails(jobId).join();
         return jobDetailsInfo.getJobVertexInfos().stream()
                 .map(JobDetailsInfo.JobVertexDetailsInfo::getTasksPerState)

