This is an automated email from the ASF dual-hosted git repository. srichter pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push: new 2ed858318f3 [FLINK-32345][state] Improve parallel download of RocksDB incremental state. (#22788) 2ed858318f3 is described below commit 2ed858318f38eb9913e5f4b3019be6b6d0a8e6fb Author: Stefan Richter <srich...@apache.org> AuthorDate: Wed Jun 21 12:17:08 2023 +0200 [FLINK-32345][state] Improve parallel download of RocksDB incremental state. (#22788) * [FLINK-32345] Improve parallel download of RocksDB incremental state. This commit improves RocksDBStateDownloader to support parallelized state download across multiple state types and across multiple state handles. This can improve our download times for scale-in. --- .../java/org/apache/flink/util/CollectionUtil.java | 81 +++++++++++ .../org/apache/flink/util/CollectionUtilTest.java | 47 ++++++ .../state/RocksDBIncrementalCheckpointUtils.java | 2 +- .../streaming/state/RocksDBStateDownloader.java | 121 +++++++++------- .../streaming/state/StateHandleDownloadSpec.java | 49 +++++++ .../RocksDBIncrementalRestoreOperation.java | 157 ++++++++++++--------- .../state/RocksDBStateDownloaderTest.java | 132 +++++++++++++---- 7 files changed, 438 insertions(+), 151 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java b/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java index 8c96e3e3554..18f4c4313c4 100644 --- a/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java +++ b/flink-core/src/main/java/org/apache/flink/util/CollectionUtil.java @@ -19,6 +19,7 @@ package org.apache.flink.util; import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; import javax.annotation.Nullable; @@ -27,7 +28,10 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; 
import java.util.Map; import java.util.Set; @@ -45,6 +49,9 @@ public final class CollectionUtil { /** A safe maximum size for arrays in the JVM. */ public static final int MAX_ARRAY_SIZE = Integer.MAX_VALUE - 8; + /** The default load factor for hash maps created with this util class. */ + static final float HASH_MAP_DEFAULT_LOAD_FACTOR = 0.75f; + private CollectionUtil() { throw new AssertionError(); } @@ -133,4 +140,78 @@ public final class CollectionUtil { } return Collections.unmodifiableMap(map); } + + /** + * Creates a new {@link HashMap} of the expected size, i.e. a hash map that will not rehash if + * expectedSize many keys are inserted, considering the load factor. + * + * @param expectedSize the expected size of the created hash map. + * @return a new hash map instance with enough capacity for the expected size. + * @param <K> the type of keys maintained by this map. + * @param <V> the type of mapped values. + */ + public static <K, V> HashMap<K, V> newHashMapWithExpectedSize(int expectedSize) { + return new HashMap<>( + computeRequiredCapacity(expectedSize, HASH_MAP_DEFAULT_LOAD_FACTOR), + HASH_MAP_DEFAULT_LOAD_FACTOR); + } + + /** + * Creates a new {@link LinkedHashMap} of the expected size, i.e. a hash map that will not + * rehash if expectedSize many keys are inserted, considering the load factor. + * + * @param expectedSize the expected size of the created hash map. + * @return a new hash map instance with enough capacity for the expected size. + * @param <K> the type of keys maintained by this map. + * @param <V> the type of mapped values. + */ + public static <K, V> LinkedHashMap<K, V> newLinkedHashMapWithExpectedSize(int expectedSize) { + return new LinkedHashMap<>( + computeRequiredCapacity(expectedSize, HASH_MAP_DEFAULT_LOAD_FACTOR), + HASH_MAP_DEFAULT_LOAD_FACTOR); + } + + /** + * Creates a new {@link HashSet} of the expected size, i.e.
a hash set that will not rehash if + * expectedSize many unique elements are inserted, considering the load factor. + * + * @param expectedSize the expected size of the created hash set. + * @return a new hash set instance with enough capacity for the expected size. + * @param <E> the type of elements stored by this set. + */ + public static <E> HashSet<E> newHashSetWithExpectedSize(int expectedSize) { + return new HashSet<>( + computeRequiredCapacity(expectedSize, HASH_MAP_DEFAULT_LOAD_FACTOR), + HASH_MAP_DEFAULT_LOAD_FACTOR); + } + + /** + * Creates a new {@link LinkedHashSet} of the expected size, i.e. a hash set that will not + * rehash if expectedSize many unique elements are inserted, considering the load factor. + * + * @param expectedSize the expected size of the created hash set. + * @return a new hash set instance with enough capacity for the expected size. + * @param <E> the type of elements stored by this set. + */ + public static <E> LinkedHashSet<E> newLinkedHashSetWithExpectedSize(int expectedSize) { + return new LinkedHashSet<>( + computeRequiredCapacity(expectedSize, HASH_MAP_DEFAULT_LOAD_FACTOR), + HASH_MAP_DEFAULT_LOAD_FACTOR); + } + + /** + * Helper method to compute the right capacity for a hash map with the given load factor, + * such as HASH_MAP_DEFAULT_LOAD_FACTOR. + */ + @VisibleForTesting + static int computeRequiredCapacity(int expectedSize, float loadFactor) { + Preconditions.checkArgument(expectedSize >= 0); + Preconditions.checkArgument(loadFactor > 0f); + if (expectedSize <= 2) { + return expectedSize + 1; + } + return expectedSize < (Integer.MAX_VALUE / 2 + 1) + ?
(int) ((float) expectedSize / loadFactor + 1.0f) + : Integer.MAX_VALUE; + } } diff --git a/flink-core/src/test/java/org/apache/flink/util/CollectionUtilTest.java b/flink-core/src/test/java/org/apache/flink/util/CollectionUtilTest.java index abeec238879..de749f9aadb 100644 --- a/flink-core/src/test/java/org/apache/flink/util/CollectionUtilTest.java +++ b/flink-core/src/test/java/org/apache/flink/util/CollectionUtilTest.java @@ -18,6 +18,7 @@ package org.apache.flink.util; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -25,6 +26,7 @@ import java.util.Arrays; import java.util.Collection; import java.util.List; +import static org.apache.flink.util.CollectionUtil.HASH_MAP_DEFAULT_LOAD_FACTOR; import static org.assertj.core.api.Assertions.assertThat; /** Tests for java collection utilities. */ @@ -52,4 +54,49 @@ public class CollectionUtilTest { final Object element = new Object(); assertThat(CollectionUtil.ofNullable(element)).singleElement().isEqualTo(element); } + + @Test + public void testComputeCapacity() { + Assertions.assertEquals( + 1, CollectionUtil.computeRequiredCapacity(0, HASH_MAP_DEFAULT_LOAD_FACTOR)); + Assertions.assertEquals( + 2, CollectionUtil.computeRequiredCapacity(1, HASH_MAP_DEFAULT_LOAD_FACTOR)); + Assertions.assertEquals( + 3, CollectionUtil.computeRequiredCapacity(2, HASH_MAP_DEFAULT_LOAD_FACTOR)); + Assertions.assertEquals( + 5, CollectionUtil.computeRequiredCapacity(3, HASH_MAP_DEFAULT_LOAD_FACTOR)); + Assertions.assertEquals( + 6, CollectionUtil.computeRequiredCapacity(4, HASH_MAP_DEFAULT_LOAD_FACTOR)); + Assertions.assertEquals( + 134, CollectionUtil.computeRequiredCapacity(100, HASH_MAP_DEFAULT_LOAD_FACTOR)); + Assertions.assertEquals( + 1334, CollectionUtil.computeRequiredCapacity(1000, HASH_MAP_DEFAULT_LOAD_FACTOR)); + Assertions.assertEquals( + 13334, CollectionUtil.computeRequiredCapacity(10000, HASH_MAP_DEFAULT_LOAD_FACTOR)); + + 
Assertions.assertEquals(20001, CollectionUtil.computeRequiredCapacity(10000, 0.5f)); + + Assertions.assertEquals(100001, CollectionUtil.computeRequiredCapacity(10000, 0.1f)); + + Assertions.assertEquals( + 1431655808, + CollectionUtil.computeRequiredCapacity( + Integer.MAX_VALUE / 2, HASH_MAP_DEFAULT_LOAD_FACTOR)); + Assertions.assertEquals( + Integer.MAX_VALUE, + CollectionUtil.computeRequiredCapacity( + 1 + Integer.MAX_VALUE / 2, HASH_MAP_DEFAULT_LOAD_FACTOR)); + + try { + CollectionUtil.computeRequiredCapacity(-1, HASH_MAP_DEFAULT_LOAD_FACTOR); + Assertions.fail(); + } catch (IllegalArgumentException expected) { + } + + try { + CollectionUtil.computeRequiredCapacity(Integer.MIN_VALUE, HASH_MAP_DEFAULT_LOAD_FACTOR); + Assertions.fail(); + } catch (IllegalArgumentException expected) { + } + } } diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java index 23c78675068..54121709876 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBIncrementalCheckpointUtils.java @@ -186,7 +186,7 @@ public class RocksDBIncrementalCheckpointUtils { Score handleScore = stateHandleEvaluator( rawStateHandle, targetKeyGroupRange, overlapFractionThreshold); - if (handleScore.compareTo(bestScore) > 0) { + if (bestStateHandle == null || handleScore.compareTo(bestScore) > 0) { bestStateHandle = rawStateHandle; bestScore = handleScore; } diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloader.java 
b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloader.java index d06790ced8a..0a1e43e9700 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloader.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloader.java @@ -19,23 +19,27 @@ package org.apache.flink.contrib.streaming.state; import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.core.fs.FSDataInputStream; -import org.apache.flink.runtime.state.IncrementalRemoteKeyedStateHandle; import org.apache.flink.runtime.state.StateHandleID; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.util.ExceptionUtils; +import org.apache.flink.util.FileUtils; import org.apache.flink.util.FlinkRuntimeException; +import org.apache.flink.util.IOUtils; import org.apache.flink.util.concurrent.FutureUtils; import org.apache.flink.util.function.ThrowingRunnable; +import org.apache.flink.shaded.guava30.com.google.common.collect.Streams; + import java.io.IOException; import java.io.OutputStream; import java.nio.file.Files; import java.nio.file.Path; -import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** Help class for downloading RocksDB state files. */ public class RocksDBStateDownloader extends RocksDBStateDataTransfer { @@ -44,45 +48,43 @@ public class RocksDBStateDownloader extends RocksDBStateDataTransfer { } /** - * Transfer all state data to the target directory using specified number of threads. + * Transfer all state data to the target directory, as specified in the download requests. 
* - * @param restoreStateHandle Handles used to retrieve the state data. - * @param dest The target directory which the state data will be stored. - * @throws Exception Thrown if can not transfer all the state data. + * @param downloadRequests the list of downloads. + * @throws Exception If anything about the download goes wrong. */ public void transferAllStateDataToDirectory( - IncrementalRemoteKeyedStateHandle restoreStateHandle, - Path dest, - CloseableRegistry closeableRegistry) - throws Exception { - - final Map<StateHandleID, StreamStateHandle> sstFiles = restoreStateHandle.getSharedState(); - final Map<StateHandleID, StreamStateHandle> miscFiles = - restoreStateHandle.getPrivateState(); - - downloadDataForAllStateHandles(sstFiles, dest, closeableRegistry); - downloadDataForAllStateHandles(miscFiles, dest, closeableRegistry); - } - - /** - * Copies all the files from the given stream state handles to the given path, renaming the - * files w.r.t. their {@link StateHandleID}. - */ - private void downloadDataForAllStateHandles( - Map<StateHandleID, StreamStateHandle> stateHandleMap, - Path restoreInstancePath, + Collection<StateHandleDownloadSpec> downloadRequests, CloseableRegistry closeableRegistry) throws Exception { + // We use this closer for fine-grained shutdown of all parallel downloading. + CloseableRegistry internalCloser = new CloseableRegistry(); + // Make sure we also react to external close signals. 
+ closeableRegistry.registerCloseable(internalCloser); + List<CompletableFuture<Void>> futures = Collections.emptyList(); try { - List<Runnable> runnables = - createDownloadRunnables(stateHandleMap, restoreInstancePath, closeableRegistry); - List<CompletableFuture<Void>> futures = new ArrayList<>(runnables.size()); - for (Runnable runnable : runnables) { - futures.add(CompletableFuture.runAsync(runnable, executorService)); + try { + futures = + transferAllStateDataToDirectoryAsync(downloadRequests, internalCloser) + .collect(Collectors.toList()); + // Wait until either all futures completed successfully or one failed exceptionally. + FutureUtils.waitForAll(futures).get(); + } finally { + // Unregister and close the internal closer. In a failure case, this should + // interrupt ongoing downloads. + if (closeableRegistry.unregisterCloseable(internalCloser)) { + IOUtils.closeQuietly(internalCloser); + } } - FutureUtils.waitForAll(futures).get(); - } catch (ExecutionException e) { + } catch (Exception e) { + // Cleanup on exception: cancel all tasks and delete the created directories + futures.forEach(future -> future.cancel(true)); + downloadRequests.stream() + .map(StateHandleDownloadSpec::getDownloadDestination) + .map(Path::toFile) + .forEach(FileUtils::deleteDirectoryQuietly); + // Error reporting Throwable throwable = ExceptionUtils.stripExecutionException(e); throwable = ExceptionUtils.stripException(throwable, RuntimeException.class); if (throwable instanceof IOException) { @@ -93,24 +95,39 @@ public class RocksDBStateDownloader extends RocksDBStateDataTransfer { } } - private List<Runnable> createDownloadRunnables( - Map<StateHandleID, StreamStateHandle> stateHandleMap, - Path restoreInstancePath, + /** Asynchronously runs the specified download requests on executorService. 
*/ + private Stream<CompletableFuture<Void>> transferAllStateDataToDirectoryAsync( + Collection<StateHandleDownloadSpec> handleWithPaths, CloseableRegistry closeableRegistry) { - List<Runnable> runnables = new ArrayList<>(stateHandleMap.size()); - for (Map.Entry<StateHandleID, StreamStateHandle> entry : stateHandleMap.entrySet()) { - StateHandleID stateHandleID = entry.getKey(); - StreamStateHandle remoteFileHandle = entry.getValue(); - - Path path = restoreInstancePath.resolve(stateHandleID.toString()); - - runnables.add( - ThrowingRunnable.unchecked( - () -> - downloadDataForStateHandle( - path, remoteFileHandle, closeableRegistry))); - } - return runnables; + return handleWithPaths.stream() + .flatMap( + downloadRequest -> + // Take all files from shared and private state. + Streams.concat( + downloadRequest.getStateHandle().getSharedState() + .entrySet().stream(), + downloadRequest.getStateHandle().getPrivateState() + .entrySet().stream()) + .map( + // Create one runnable for each StreamStateHandle + entry -> { + StateHandleID stateHandleID = entry.getKey(); + StreamStateHandle remoteFileHandle = + entry.getValue(); + Path downloadDest = + downloadRequest + .getDownloadDestination() + .resolve( + stateHandleID + .toString()); + return ThrowingRunnable.unchecked( + () -> + downloadDataForStateHandle( + downloadDest, + remoteFileHandle, + closeableRegistry)); + })) + .map(runnable -> CompletableFuture.runAsync(runnable, executorService)); } /** Copies the file from a single state handle to the given path. 
*/ diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/StateHandleDownloadSpec.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/StateHandleDownloadSpec.java new file mode 100644 index 00000000000..93a33fdc6fa --- /dev/null +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/StateHandleDownloadSpec.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.contrib.streaming.state; + +import org.apache.flink.runtime.state.IncrementalRemoteKeyedStateHandle; + +import java.nio.file.Path; + +/** + * This class represents a download specification for the content of one {@link + * IncrementalRemoteKeyedStateHandle} to a target {@link Path}. + */ +public class StateHandleDownloadSpec { + /** The state handle to download. */ + private final IncrementalRemoteKeyedStateHandle stateHandle; + + /** The path to which the content of the state handle shall be downloaded. 
*/ + private final Path downloadDestination; + + public StateHandleDownloadSpec( + IncrementalRemoteKeyedStateHandle stateHandle, Path downloadDestination) { + this.stateHandle = stateHandle; + this.downloadDestination = downloadDestination; + } + + public IncrementalRemoteKeyedStateHandle getStateHandle() { + return stateHandle; + } + + public Path getDownloadDestination() { + return downloadDestination; + } +} diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java index d6ec9ae6055..89998b8768a 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/restore/RocksDBIncrementalRestoreOperation.java @@ -27,6 +27,7 @@ import org.apache.flink.contrib.streaming.state.RocksDBOperationUtils; import org.apache.flink.contrib.streaming.state.RocksDBStateDownloader; import org.apache.flink.contrib.streaming.state.RocksDBWriteBatchWrapper; import org.apache.flink.contrib.streaming.state.RocksIteratorWrapper; +import org.apache.flink.contrib.streaming.state.StateHandleDownloadSpec; import org.apache.flink.contrib.streaming.state.ttl.RocksDbTtlCompactFiltersManager; import org.apache.flink.core.fs.CloseableRegistry; import org.apache.flink.core.memory.DataInputView; @@ -46,8 +47,10 @@ import org.apache.flink.runtime.state.StateHandleID; import org.apache.flink.runtime.state.StateSerializerProvider; import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.runtime.state.metainfo.StateMetaInfoSnapshot; +import org.apache.flink.util.CollectionUtil; import org.apache.flink.util.FileUtils; 
import org.apache.flink.util.IOUtils; +import org.apache.flink.util.Preconditions; import org.apache.flink.util.StateMigrationException; import org.rocksdb.ColumnFamilyDescriptor; @@ -69,6 +72,7 @@ import java.io.InputStream; import java.nio.file.Path; import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; @@ -186,12 +190,12 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper IncrementalRemoteKeyedStateHandle incrementalRemoteKeyedStateHandle = (IncrementalRemoteKeyedStateHandle) keyedStateHandle; restorePreviousIncrementalFilesStatus(incrementalRemoteKeyedStateHandle); - restoreFromRemoteState(incrementalRemoteKeyedStateHandle); + restoreBaseDBFromRemoteState(incrementalRemoteKeyedStateHandle); } else if (keyedStateHandle instanceof IncrementalLocalKeyedStateHandle) { IncrementalLocalKeyedStateHandle incrementalLocalKeyedStateHandle = (IncrementalLocalKeyedStateHandle) keyedStateHandle; restorePreviousIncrementalFilesStatus(incrementalLocalKeyedStateHandle); - restoreFromLocalState(incrementalLocalKeyedStateHandle); + restoreBaseDBFromLocalState(incrementalLocalKeyedStateHandle); } else { throw unexpectedStateHandleException( new Class[] { @@ -213,20 +217,46 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper lastCompletedCheckpointId = localKeyedStateHandle.getCheckpointId(); } - private void restoreFromRemoteState(IncrementalRemoteKeyedStateHandle stateHandle) + private void restoreBaseDBFromRemoteState(IncrementalRemoteKeyedStateHandle stateHandle) throws Exception { // used as restore source for IncrementalRemoteKeyedStateHandle final Path tmpRestoreInstancePath = instanceBasePath.getAbsoluteFile().toPath().resolve(UUID.randomUUID().toString()); + final StateHandleDownloadSpec downloadRequest = + new StateHandleDownloadSpec(stateHandle, tmpRestoreInstancePath); try { - restoreFromLocalState( - 
transferRemoteStateToLocalDirectory(tmpRestoreInstancePath, stateHandle)); + transferRemoteStateToLocalDirectory(Collections.singletonList(downloadRequest)); + restoreBaseDBFromDownloadedState(downloadRequest); } finally { - cleanUpPathQuietly(tmpRestoreInstancePath); + cleanUpPathQuietly(downloadRequest.getDownloadDestination()); } } - private void restoreFromLocalState(IncrementalLocalKeyedStateHandle localKeyedStateHandle) + /** + * This helper method creates a {@link IncrementalLocalKeyedStateHandle} for state that was + * previously downloaded for a {@link IncrementalRemoteKeyedStateHandle} and then invokes the + * restore procedure for local state on the downloaded state. + * + * @param downloadedState the specification of a completed state download. + * @throws Exception for restore problems. + */ + private void restoreBaseDBFromDownloadedState(StateHandleDownloadSpec downloadedState) + throws Exception { + // since we transferred all remote state to a local directory, we can use the same code + // as for local recovery. + IncrementalRemoteKeyedStateHandle stateHandle = downloadedState.getStateHandle(); + restoreBaseDBFromLocalState( + new IncrementalLocalKeyedStateHandle( + stateHandle.getBackendIdentifier(), + stateHandle.getCheckpointId(), + new DirectoryStateHandle(downloadedState.getDownloadDestination()), + stateHandle.getKeyGroupRange(), + stateHandle.getMetaStateHandle(), + stateHandle.getSharedState())); + } + + /** Restores RocksDB instance from local state. 
*/ + private void restoreBaseDBFromLocalState(IncrementalLocalKeyedStateHandle localKeyedStateHandle) throws Exception { KeyedBackendSerializationProxy<K> serializationProxy = readMetaData(localKeyedStateHandle.getMetaDataState()); @@ -246,26 +276,13 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper restoreSourcePath); } - private IncrementalLocalKeyedStateHandle transferRemoteStateToLocalDirectory( - Path temporaryRestoreInstancePath, IncrementalRemoteKeyedStateHandle restoreStateHandle) - throws Exception { - + private void transferRemoteStateToLocalDirectory( + Collection<StateHandleDownloadSpec> downloadRequests) throws Exception { try (RocksDBStateDownloader rocksDBStateDownloader = new RocksDBStateDownloader(numberOfTransferringThreads)) { rocksDBStateDownloader.transferAllStateDataToDirectory( - restoreStateHandle, temporaryRestoreInstancePath, cancelStreamRegistry); + downloadRequests, cancelStreamRegistry); } - - // since we transferred all remote state to a local directory, we can use the same code as - // for - // local recovery. 
- return new IncrementalLocalKeyedStateHandle( - restoreStateHandle.getBackendIdentifier(), - restoreStateHandle.getCheckpointId(), - new DirectoryStateHandle(temporaryRestoreInstancePath), - restoreStateHandle.getKeyGroupRange(), - restoreStateHandle.getMetaStateHandle(), - restoreStateHandle.getSharedState()); } private void cleanUpPathQuietly(@Nonnull Path path) { @@ -284,19 +301,39 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper private void restoreWithRescaling(Collection<KeyedStateHandle> restoreStateHandles) throws Exception { - // Prepare for restore with rescaling - KeyedStateHandle initialHandle = + Preconditions.checkArgument(restoreStateHandles != null && !restoreStateHandles.isEmpty()); + + Map<StateHandleID, StateHandleDownloadSpec> allDownloadSpecs = + CollectionUtil.newHashMapWithExpectedSize(restoreStateHandles.size()); + + // Choose the best state handle for the initial DB + final KeyedStateHandle selectedInitialHandle = RocksDBIncrementalCheckpointUtils.chooseTheBestStateHandleForInitial( restoreStateHandles, keyGroupRange, overlapFractionThreshold); - // Init base DB instance - if (initialHandle != null) { - restoreStateHandles.remove(initialHandle); - initDBWithRescaling(initialHandle); - } else { - this.rocksHandle.openDB(); + Preconditions.checkNotNull(selectedInitialHandle); + + final Path absolutInstanceBasePath = instanceBasePath.getAbsoluteFile().toPath(); + + // Prepare and collect all the download requests to pull remote state to a local directory + for (KeyedStateHandle stateHandle : restoreStateHandles) { + if (!(stateHandle instanceof IncrementalRemoteKeyedStateHandle)) { + throw unexpectedStateHandleException( + IncrementalRemoteKeyedStateHandle.class, stateHandle.getClass()); + } + StateHandleDownloadSpec downloadRequest = + new StateHandleDownloadSpec( + (IncrementalRemoteKeyedStateHandle) stateHandle, + absolutInstanceBasePath.resolve(UUID.randomUUID().toString())); +
allDownloadSpecs.put(stateHandle.getStateHandleId(), downloadRequest); } + // Process all state downloads + transferRemoteStateToLocalDirectory(allDownloadSpecs.values()); + + // Init the base DB instance with the initial state + initBaseDBForRescaling(allDownloadSpecs.remove(selectedInitialHandle.getStateHandleId())); + // Transfer remaining key-groups from temporary instance into base DB byte[] startKeyGroupPrefixBytes = new byte[keyGroupPrefixBytes]; CompositeKeySerializationUtils.serializeKeyGroup( @@ -306,24 +343,14 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper CompositeKeySerializationUtils.serializeKeyGroup( keyGroupRange.getEndKeyGroup() + 1, stopKeyGroupPrefixBytes); - for (KeyedStateHandle rawStateHandle : restoreStateHandles) { - - if (!(rawStateHandle instanceof IncrementalRemoteKeyedStateHandle)) { - throw unexpectedStateHandleException( - IncrementalRemoteKeyedStateHandle.class, rawStateHandle.getClass()); - } - + // Insert all remaining state through creating temporary RocksDB instances + for (StateHandleDownloadSpec downloadRequest : allDownloadSpecs.values()) { logger.info( - "Starting to restore from state handle: {} with rescaling.", rawStateHandle); - Path temporaryRestoreInstancePath = - instanceBasePath - .getAbsoluteFile() - .toPath() - .resolve(UUID.randomUUID().toString()); + "Starting to restore from state handle: {} with rescaling.", + downloadRequest.getStateHandle()); + try (RestoredDBInstance tmpRestoreDBInfo = - restoreDBInstanceFromStateHandle( - (IncrementalRemoteKeyedStateHandle) rawStateHandle, - temporaryRestoreInstancePath); + restoreTempDBInstanceFromDownloadedState(downloadRequest); RocksDBWriteBatchWrapper writeBatchWrapper = new RocksDBWriteBatchWrapper( this.rocksHandle.getDb(), writeBatchSize)) { @@ -335,12 +362,13 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper // iterating only the requested descriptors automatically skips the default column // 
family handle - for (int i = 0; i < tmpColumnFamilyDescriptors.size(); ++i) { - ColumnFamilyHandle tmpColumnFamilyHandle = tmpColumnFamilyHandles.get(i); + for (int descIdx = 0; descIdx < tmpColumnFamilyDescriptors.size(); ++descIdx) { + ColumnFamilyHandle tmpColumnFamilyHandle = tmpColumnFamilyHandles.get(descIdx); ColumnFamilyHandle targetColumnFamilyHandle = this.rocksHandle.getOrRegisterStateColumnFamilyHandle( - null, tmpRestoreDBInfo.stateMetaInfoSnapshots.get(i)) + null, + tmpRestoreDBInfo.stateMetaInfoSnapshots.get(descIdx)) .columnFamilyHandle; try (RocksIteratorWrapper iterator = @@ -369,19 +397,19 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper } // releases native iterator resources } logger.info( - "Finished restoring from state handle: {} with rescaling.", rawStateHandle); + "Finished restoring from state handle: {} with rescaling.", + downloadRequest.getStateHandle()); } finally { - cleanUpPathQuietly(temporaryRestoreInstancePath); + cleanUpPathQuietly(downloadRequest.getDownloadDestination()); } } } - private void initDBWithRescaling(KeyedStateHandle initialHandle) throws Exception { - - assert (initialHandle instanceof IncrementalRemoteKeyedStateHandle); + private void initBaseDBForRescaling(StateHandleDownloadSpec downloadedInitialState) + throws Exception { // 1. Restore base DB from selected initial handle - restoreFromRemoteState((IncrementalRemoteKeyedStateHandle) initialHandle); + restoreBaseDBFromDownloadedState(downloadedInitialState); // 2. 
Clip the base DB instance try { @@ -389,7 +417,7 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper this.rocksHandle.getDb(), this.rocksHandle.getColumnFamilyHandles(), keyGroupRange, - initialHandle.getKeyGroupRange(), + downloadedInitialState.getStateHandle().getKeyGroupRange(), keyGroupPrefixBytes); } catch (RocksDBException e) { String errMsg = "Failed to clip DB after initialization."; @@ -441,18 +469,11 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper } } - private RestoredDBInstance restoreDBInstanceFromStateHandle( - IncrementalRemoteKeyedStateHandle restoreStateHandle, Path temporaryRestoreInstancePath) - throws Exception { - - try (RocksDBStateDownloader rocksDBStateDownloader = - new RocksDBStateDownloader(numberOfTransferringThreads)) { - rocksDBStateDownloader.transferAllStateDataToDirectory( - restoreStateHandle, temporaryRestoreInstancePath, cancelStreamRegistry); - } + private RestoredDBInstance restoreTempDBInstanceFromDownloadedState( + StateHandleDownloadSpec downloadRequest) throws Exception { KeyedBackendSerializationProxy<K> serializationProxy = - readMetaData(restoreStateHandle.getMetaStateHandle()); + readMetaData(downloadRequest.getStateHandle().getMetaStateHandle()); // read meta data List<StateMetaInfoSnapshot> stateMetaInfoSnapshots = serializationProxy.getStateMetaInfoSnapshots(); @@ -465,7 +486,7 @@ public class RocksDBIncrementalRestoreOperation<K> implements RocksDBRestoreOper RocksDB restoreDb = RocksDBOperationUtils.openDB( - temporaryRestoreInstancePath.toString(), + downloadRequest.getDownloadDestination().toString(), columnFamilyDescriptors, columnFamilyHandles, RocksDBOperationUtils.createColumnFamilyOptions( diff --git a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloaderTest.java 
b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloaderTest.java index fcce8674887..2f903644797 100644 --- a/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloaderTest.java +++ b/flink-state-backends/flink-statebackend-rocksdb/src/test/java/org/apache/flink/contrib/streaming/state/RocksDBStateDownloaderTest.java @@ -28,6 +28,7 @@ import org.apache.flink.runtime.state.TestStreamStateHandle; import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; import org.apache.flink.util.TestLogger; +import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; @@ -37,6 +38,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -73,8 +75,10 @@ public class RocksDBStateDownloaderTest extends TestLogger { try (RocksDBStateDownloader rocksDBStateDownloader = new RocksDBStateDownloader(5)) { rocksDBStateDownloader.transferAllStateDataToDirectory( - incrementalKeyedStateHandle, - temporaryFolder.newFolder().toPath(), + Collections.singletonList( + new StateHandleDownloadSpec( + incrementalKeyedStateHandle, + temporaryFolder.newFolder().toPath())), new CloseableRegistry()); fail(); } catch (Exception e) { @@ -85,46 +89,69 @@ public class RocksDBStateDownloaderTest extends TestLogger { /** Tests that download files with multi-thread correctly. 
*/ @Test public void testMultiThreadRestoreCorrectly() throws Exception { - Random random = new Random(); - int contentNum = 6; - byte[][] contents = new byte[contentNum][]; - for (int i = 0; i < contentNum; ++i) { - contents[i] = new byte[random.nextInt(100000) + 1]; - random.nextBytes(contents[i]); + int numRemoteHandles = 3; + int numSubHandles = 6; + byte[][][] contents = createContents(numRemoteHandles, numSubHandles); + List<StateHandleDownloadSpec> downloadRequests = new ArrayList<>(numRemoteHandles); + for (int i = 0; i < numRemoteHandles; ++i) { + downloadRequests.add( + createDownloadRequestForContent( + temporaryFolder.newFolder().toPath(), contents[i], i)); } - List<StreamStateHandle> handles = new ArrayList<>(contentNum); - for (int i = 0; i < contentNum; ++i) { - handles.add(new ByteStreamStateHandle(String.format("state%d", i), contents[i])); + try (RocksDBStateDownloader rocksDBStateDownloader = new RocksDBStateDownloader(4)) { + rocksDBStateDownloader.transferAllStateDataToDirectory( + downloadRequests, new CloseableRegistry()); } - Map<StateHandleID, StreamStateHandle> sharedStates = new HashMap<>(contentNum); - Map<StateHandleID, StreamStateHandle> privateStates = new HashMap<>(contentNum); - for (int i = 0; i < contentNum; ++i) { - sharedStates.put(new StateHandleID(String.format("sharedState%d", i)), handles.get(i)); - privateStates.put( - new StateHandleID(String.format("privateState%d", i)), handles.get(i)); + for (int i = 0; i < numRemoteHandles; ++i) { + StateHandleDownloadSpec downloadRequest = downloadRequests.get(i); + Path dstPath = downloadRequest.getDownloadDestination(); + Assert.assertTrue(dstPath.toFile().exists()); + for (int j = 0; j < numSubHandles; ++j) { + assertStateContentEqual( + contents[i][j], dstPath.resolve(String.format("sharedState-%d-%d", i, j))); + } } + } - IncrementalRemoteKeyedStateHandle incrementalKeyedStateHandle = - new IncrementalRemoteKeyedStateHandle( - UUID.randomUUID(), - KeyGroupRange.of(0, 1), - 1, - 
sharedStates, - privateStates, - handles.get(0)); + /** Tests cleanup on download failures. */ + @Test + public void testMultiThreadCleanupOnFailure() throws Exception { + int numRemoteHandles = 3; + int numSubHandles = 6; + byte[][][] contents = createContents(numRemoteHandles, numSubHandles); + List<StateHandleDownloadSpec> downloadRequests = new ArrayList<>(numRemoteHandles); + for (int i = 0; i < numRemoteHandles; ++i) { + downloadRequests.add( + createDownloadRequestForContent( + temporaryFolder.newFolder().toPath(), contents[i], i)); + } - Path dstPath = temporaryFolder.newFolder().toPath(); + IncrementalRemoteKeyedStateHandle stateHandle = + downloadRequests.get(downloadRequests.size() - 1).getStateHandle(); + + // Add a state handle that induces an exception + stateHandle + .getSharedState() + .put( + new StateHandleID("error-handle"), + new ThrowingStateHandle(new IOException("Test exception."))); + + CloseableRegistry closeableRegistry = new CloseableRegistry(); try (RocksDBStateDownloader rocksDBStateDownloader = new RocksDBStateDownloader(5)) { rocksDBStateDownloader.transferAllStateDataToDirectory( - incrementalKeyedStateHandle, dstPath, new CloseableRegistry()); + downloadRequests, closeableRegistry); + fail("Exception is expected"); + } catch (IOException ignore) { } - for (int i = 0; i < contentNum; ++i) { - assertStateContentEqual( - contents[i], dstPath.resolve(String.format("sharedState%d", i))); + // Check that all download directories have been deleted + for (StateHandleDownloadSpec downloadRequest : downloadRequests) { + Assert.assertFalse(downloadRequest.getDownloadDestination().toFile().exists()); } + // The passed in closable registry should not be closed by us on failure. 
+ Assert.assertFalse(closeableRegistry.isClosed()); } private void assertStateContentEqual(byte[] expected, Path path) throws IOException { @@ -165,4 +192,49 @@ public class RocksDBStateDownloaderTest extends TestLogger { return 0; } } + + private byte[][][] createContents(int numRemoteHandles, int numSubHandles) { + Random random = new Random(); + byte[][][] contents = new byte[numRemoteHandles][numSubHandles][]; + for (int i = 0; i < numRemoteHandles; ++i) { + for (int j = 0; j < numSubHandles; ++j) { + contents[i][j] = new byte[random.nextInt(100000) + 1]; + random.nextBytes(contents[i][j]); + } + } + return contents; + } + + private StateHandleDownloadSpec createDownloadRequestForContent( + Path dstPath, byte[][] content, int remoteHandleId) { + int numSubHandles = content.length; + List<StreamStateHandle> handles = new ArrayList<>(numSubHandles); + for (int i = 0; i < numSubHandles; ++i) { + handles.add( + new ByteStreamStateHandle( + String.format("state-%d-%d", remoteHandleId, i), content[i])); + } + + Map<StateHandleID, StreamStateHandle> sharedStates = new HashMap<>(numSubHandles); + Map<StateHandleID, StreamStateHandle> privateStates = new HashMap<>(numSubHandles); + for (int i = 0; i < numSubHandles; ++i) { + sharedStates.put( + new StateHandleID(String.format("sharedState-%d-%d", remoteHandleId, i)), + handles.get(i)); + privateStates.put( + new StateHandleID(String.format("privateState-%d-%d", remoteHandleId, i)), + handles.get(i)); + } + + IncrementalRemoteKeyedStateHandle incrementalKeyedStateHandle = + new IncrementalRemoteKeyedStateHandle( + UUID.randomUUID(), + KeyGroupRange.of(0, 1), + 1, + sharedStates, + privateStates, + handles.get(0)); + + return new StateHandleDownloadSpec(incrementalKeyedStateHandle, dstPath); + } }