This is an automated email from the ASF dual-hosted git repository.

bbejeck pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new da8e4962911 KAFKA-20257: Refactor GlobalStateManagerImpl (#21649)
da8e4962911 is described below

commit da8e4962911c6188fbea6a8a9662a8aa7cdbd944
Author: Nick Telford <[email protected]>
AuthorDate: Tue Mar 10 00:17:08 2026 +0000

    KAFKA-20257: Refactor GlobalStateManagerImpl (#21649)
    
    The main motivation for this is to better align `GlobalStateManagerImpl`
    with `ProcessorStateManager` where possible, as this will make implementing
    elements of KIP-1035 easier and ultimately lead to more maintainable
    code.
    
    The primary changes here are:
    
    1. Encapsulating store metadata in
       `GlobalStateManagerImpl.StateStoreMetadata`.
    2. Moving the `restoreState`/`reprocessState` calls to the end of
       `initialize()`.
    
    Consequently, the only behavioural change here is that
    `restoreState`/`reprocessState` now happen after _all_ stores have been
    registered (via `registerStore`), rather than at the end of each
    individual store's registration.
    
    This is similar to `ProcessorStateManager`, which registers and
    initializes all stores before beginning state restoration (via the
    `StateUpdater`).
    
    This change should have no real impact on users, as it will still take
    the same amount of time to fully restore all global stores before they
    begin processing.
---
 .../internals/GlobalStateManagerImpl.java          | 101 ++++++++++++---------
 .../internals/GlobalStateManagerImplTest.java      |  84 ++++++++++-------
 .../org/apache/kafka/test/NoOpReadOnlyStore.java   |  11 ++-
 3 files changed, 117 insertions(+), 79 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
index 56231419e1c..f762338b947 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImpl.java
@@ -74,6 +74,30 @@ import static 
org.apache.kafka.streams.processor.internals.metrics.TaskMetrics.d
  * of Global State Stores. There is only ever 1 instance of this class per 
Application Instance.
  */
 public class GlobalStateManagerImpl implements GlobalStateManager {
+
+    private static class StateStoreMetadata {
+        final StateStore stateStore;
+        final List<TopicPartition> changelogPartitions;
+        final StateRestoreCallback restoreCallback;
+        final Optional<InternalTopologyBuilder.ReprocessFactory<?, ?, ?, ?>> 
reprocessFactory;
+        final RecordConverter recordConverter;
+        final Map<TopicPartition, Long> highWatermarks;
+
+        StateStoreMetadata(final StateStore stateStore,
+                           final List<TopicPartition> changelogPartitions,
+                           final 
Optional<InternalTopologyBuilder.ReprocessFactory<?, ?, ?, ?>> reprocessFactory,
+                           final StateRestoreCallback restoreCallback,
+                           final RecordConverter recordConverter,
+                           final Map<TopicPartition, Long> highWatermarks) {
+            this.stateStore = stateStore;
+            this.changelogPartitions = changelogPartitions;
+            this.reprocessFactory = reprocessFactory;
+            this.restoreCallback = reprocessFactory.isPresent() ? null : 
restoreCallback;
+            this.recordConverter = reprocessFactory.isPresent() ? null : 
recordConverter;
+            this.highWatermarks = highWatermarks;
+        }
+    }
+
     private static final long NO_DEADLINE = -1L;
 
     private final Time time;
@@ -90,6 +114,7 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
     private final Set<String> globalStoreNames = new HashSet<>();
     private final Set<String> globalNonPersistentStoresTopics = new 
HashSet<>();
     private final FixedOrderMap<String, Optional<StateStore>> globalStores = 
new FixedOrderMap<>();
+    private final Map<String, StateStoreMetadata> storeMetadata = new 
HashMap<>();
     private InternalProcessorContext<?, ?> globalProcessorContext;
     private DeserializationExceptionHandler deserializationExceptionHandler;
     private ProcessingExceptionHandler processingExceptionHandler;
@@ -177,6 +202,19 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
             }
         });
 
+        // restore or reprocess each registered store using the now-populated 
currentOffsets
+        for (final StateStoreMetadata metadata : storeMetadata.values()) {
+            try {
+                if (metadata.reprocessFactory.isPresent()) {
+                    reprocessState(metadata);
+                } else {
+                    restoreState(metadata);
+                }
+            } finally {
+                globalConsumer.unsubscribe();
+            }
+        }
+
         return Collections.unmodifiableSet(globalStoreNames);
     }
 
@@ -197,7 +235,7 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
     public void registerStore(final StateStore store,
                               final StateRestoreCallback stateRestoreCallback,
                               final CommitCallback ignored) {
-        log.info("Restoring state for global store {}", store.name());
+        log.info("Registering global store {}", store.name());
 
         // TODO (KAFKA-12887): we should not trigger user's exception handler 
for illegal-argument but always
         // fail-crash; in this case we would not need to immediately close the 
state store before throwing
@@ -228,27 +266,10 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
             )
         );
 
-        try {
-            final Optional<InternalTopologyBuilder.ReprocessFactory<?, ?, ?, 
?>> reprocessFactory = topology
-                .storeNameToReprocessOnRestore().getOrDefault(store.name(), 
Optional.empty());
-            if (reprocessFactory.isPresent()) {
-                reprocessState(
-                    topicPartitions,
-                    highWatermarks,
-                    reprocessFactory.get(),
-                    store.name());
-            } else {
-                restoreState(
-                    stateRestoreCallback,
-                    topicPartitions,
-                    highWatermarks,
-                    store.name(),
-                    converterForStore(store)
-                );
-            }
-        } finally {
-            globalConsumer.unsubscribe();
-        }
+        final Optional<InternalTopologyBuilder.ReprocessFactory<?, ?, ?, ?>> 
reprocessFactory = topology
+            .storeNameToReprocessOnRestore().getOrDefault(store.name(), 
Optional.empty());
+        storeMetadata.put(store.name(), new StateStoreMetadata(
+            store, topicPartitions, reprocessFactory, stateRestoreCallback, 
converterForStore(store), highWatermarks));
     }
 
     private List<TopicPartition> topicPartitionsForStore(final StateStore 
store) {
@@ -279,14 +300,12 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
     }
 
     @SuppressWarnings({"rawtypes", "unchecked", "resource"})
-    private void reprocessState(final List<TopicPartition> topicPartitions,
-                                final Map<TopicPartition, Long> highWatermarks,
-                                final 
InternalTopologyBuilder.ReprocessFactory<?, ?, ?, ?> reprocessFactory,
-                                final String storeName) {
+    private void reprocessState(final StateStoreMetadata storeMetadata) {
+        final InternalTopologyBuilder.ReprocessFactory<?, ?, ?, ?> 
reprocessFactory = storeMetadata.reprocessFactory.get();
         final Processor<?, ?, ?, ?> source = 
reprocessFactory.processorSupplier().get();
         source.init((ProcessorContext) globalProcessorContext);
 
-        for (final TopicPartition topicPartition : topicPartitions) {
+        for (final TopicPartition topicPartition : 
storeMetadata.changelogPartitions) {
             long currentDeadline = NO_DEADLINE;
 
             globalConsumer.assign(Collections.singletonList(topicPartition));
@@ -299,8 +318,8 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
                 
globalConsumer.seekToBeginning(Collections.singletonList(topicPartition));
                 offset = getGlobalConsumerOffset(topicPartition);
             }
-            final Long highWatermark = highWatermarks.get(topicPartition);
-            stateRestoreListener.onRestoreStart(topicPartition, storeName, 
offset, highWatermark);
+            final Long highWatermark = 
storeMetadata.highWatermarks.get(topicPartition);
+            stateRestoreListener.onRestoreStart(topicPartition, 
storeMetadata.stateStore.name(), offset, highWatermark);
 
             long restoreCount = 0L;
 
@@ -419,20 +438,16 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
 
                 offset = getGlobalConsumerOffset(topicPartition);
 
-                stateRestoreListener.onBatchRestored(topicPartition, 
storeName, offset, batchRestoreCount);
+                stateRestoreListener.onBatchRestored(topicPartition, 
storeMetadata.stateStore.name(), offset, batchRestoreCount);
             }
-            stateRestoreListener.onRestoreEnd(topicPartition, storeName, 
restoreCount);
+            stateRestoreListener.onRestoreEnd(topicPartition, 
storeMetadata.stateStore.name(), restoreCount);
             checkpointFileCache.put(topicPartition, offset);
 
         }
     }
 
-    private void restoreState(final StateRestoreCallback stateRestoreCallback,
-                              final List<TopicPartition> topicPartitions,
-                              final Map<TopicPartition, Long> highWatermarks,
-                              final String storeName,
-                              final RecordConverter recordConverter) {
-        for (final TopicPartition topicPartition : topicPartitions) {
+    private void restoreState(final StateStoreMetadata storeMetadata) {
+        for (final TopicPartition topicPartition : 
storeMetadata.changelogPartitions) {
             long currentDeadline = NO_DEADLINE;
 
             globalConsumer.assign(Collections.singletonList(topicPartition));
@@ -446,11 +461,11 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
                 offset = getGlobalConsumerOffset(topicPartition);
             }
 
-            final Long highWatermark = highWatermarks.get(topicPartition);
+            final Long highWatermark = 
storeMetadata.highWatermarks.get(topicPartition);
             final RecordBatchingStateRestoreCallback stateRestoreAdapter =
-                StateRestoreCallbackAdapter.adapt(stateRestoreCallback);
+                
StateRestoreCallbackAdapter.adapt(storeMetadata.restoreCallback);
 
-            stateRestoreListener.onRestoreStart(topicPartition, storeName, 
offset, highWatermark);
+            stateRestoreListener.onRestoreStart(topicPartition, 
storeMetadata.stateStore.name(), offset, highWatermark);
             long restoreCount = 0L;
 
             while (offset < highWatermark) {
@@ -471,17 +486,17 @@ public class GlobalStateManagerImpl implements 
GlobalStateManager {
                 final List<ConsumerRecord<byte[], byte[]>> restoreRecords = 
new ArrayList<>();
                 for (final ConsumerRecord<byte[], byte[]> record : 
records.records(topicPartition)) {
                     if (record.key() != null) {
-                        restoreRecords.add(recordConverter.convert(record));
+                        
restoreRecords.add(storeMetadata.recordConverter.convert(record));
                     }
                 }
 
                 offset = getGlobalConsumerOffset(topicPartition);
 
                 stateRestoreAdapter.restoreBatch(restoreRecords);
-                stateRestoreListener.onBatchRestored(topicPartition, 
storeName, offset, restoreRecords.size());
+                stateRestoreListener.onBatchRestored(topicPartition, 
storeMetadata.stateStore.name(), offset, restoreRecords.size());
                 restoreCount += restoreRecords.size();
             }
-            stateRestoreListener.onRestoreEnd(topicPartition, storeName, 
restoreCount);
+            stateRestoreListener.onRestoreEnd(topicPartition, 
storeMetadata.stateStore.name(), restoreCount);
             checkpointFileCache.put(topicPartition, offset);
         }
     }
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
index f1d155935b7..5b814b7f636 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/processor/internals/GlobalStateManagerImplTest.java
@@ -38,7 +38,6 @@ import org.apache.kafka.streams.processor.api.Processor;
 import org.apache.kafka.streams.processor.api.ProcessorSupplier;
 import org.apache.kafka.streams.state.TimestampedBytesStore;
 import org.apache.kafka.streams.state.internals.OffsetCheckpoint;
-import org.apache.kafka.streams.state.internals.WrappedStateStore;
 import org.apache.kafka.test.InternalMockProcessorContext;
 import org.apache.kafka.test.MockStateRestoreListener;
 import org.apache.kafka.test.NoOpReadOnlyStore;
@@ -134,8 +133,8 @@ public class GlobalStateManagerImplTest {
         storeToTopic.put(storeName4, t4.topic());
         storeToTopic.put(storeName5, t5.topic());
 
-        store1 = new NoOpReadOnlyStore<>(storeName1, true);
-        store2 = new ConverterStore<>(storeName2, true);
+        store1 = new NoOpReadOnlyStore<>(storeName1, true, 
stateRestoreCallback);
+        store2 = new ConverterStore<>(storeName2, true, stateRestoreCallback);
         store3 = new NoOpReadOnlyStore<>(storeName3);
         store4 = new NoOpReadOnlyStore<>(storeName4);
         store5 = new NoOpReadOnlyStore<>(storeName5);
@@ -289,10 +288,11 @@ public class GlobalStateManagerImplTest {
 
     @Test
     public void 
shouldNotConvertValuesIfStoreDoesNotImplementTimestampedBytesStore() {
+        initializeConsumer(0, 0, t2, t3, t4, t5);
         initializeConsumer(1, 0, t1);
+        processorContext.setStateManger(stateManager);
 
         stateManager.initialize();
-        stateManager.registerStore(store1, stateRestoreCallback, null);
 
         final KeyValue<byte[], byte[]> restoredRecord = 
stateRestoreCallback.restored.get(0);
         assertEquals(3, restoredRecord.key.length);
@@ -301,14 +301,11 @@ public class GlobalStateManagerImplTest {
 
     @Test
     public void 
shouldNotConvertValuesIfInnerStoreDoesNotImplementTimestampedBytesStore() {
+        initializeConsumer(0, 0, t2, t3, t4, t5);
         initializeConsumer(1, 0, t1);
+        processorContext.setStateManger(stateManager);
 
         stateManager.initialize();
-        stateManager.registerStore(
-            new WrappedStateStore<>(store1) {
-            },
-            stateRestoreCallback,
-                null);
 
         final KeyValue<byte[], byte[]> restoredRecord = 
stateRestoreCallback.restored.get(0);
         assertEquals(3, restoredRecord.key.length);
@@ -317,10 +314,11 @@ public class GlobalStateManagerImplTest {
 
     @Test
     public void shouldConvertValuesIfStoreImplementsTimestampedBytesStore() {
+        initializeConsumer(0, 0, t1, t3, t4, t5);
         initializeConsumer(1, 0, t2);
+        processorContext.setStateManger(stateManager);
 
         stateManager.initialize();
-        stateManager.registerStore(store2, stateRestoreCallback, null);
 
         final KeyValue<byte[], byte[]> restoredRecord = 
stateRestoreCallback.restored.get(0);
         assertEquals(3, restoredRecord.key.length);
@@ -329,14 +327,11 @@ public class GlobalStateManagerImplTest {
 
     @Test
     public void 
shouldConvertValuesIfInnerStoreImplementsTimestampedBytesStore() {
+        initializeConsumer(0, 0, t1, t3, t4, t5);
         initializeConsumer(1, 0, t2);
+        processorContext.setStateManger(stateManager);
 
         stateManager.initialize();
-        stateManager.registerStore(
-            new WrappedStateStore<>(store2) {
-            },
-            stateRestoreCallback,
-            null);
 
         final KeyValue<byte[], byte[]> restoredRecord = 
stateRestoreCallback.restored.get(0);
         assertEquals(3, restoredRecord.key.length);
@@ -345,11 +340,12 @@ public class GlobalStateManagerImplTest {
 
     @Test
     public void shouldRestoreRecordsUpToHighwatermark() {
+        initializeConsumer(0, 0, t2, t3, t4, t5);
         initializeConsumer(2, 0, t1);
+        processorContext.setStateManger(stateManager);
 
         stateManager.initialize();
 
-        stateManager.registerStore(store1, stateRestoreCallback, null);
         assertEquals(2, stateRestoreCallback.restored.size());
     }
 
@@ -357,11 +353,12 @@ public class GlobalStateManagerImplTest {
     public void shouldListenForRestoreEventsWhenReprocessing() {
         setUpReprocessing();
 
+        initializeConsumer(0, 0, t2, t3, t4, t5);
         initializeConsumer(6, 1, t1);
+        processorContext.setStateManger(stateManager);
         consumer.setMaxPollRecords(2L);
 
         stateManager.initialize();
-        stateManager.registerStore(store1, stateRestoreCallback, null);
 
         assertThat(stateRestoreListener.numBatchRestored, equalTo(2L));
         assertThat(stateRestoreListener.restoreStartOffset, equalTo(1L));
@@ -371,13 +368,13 @@ public class GlobalStateManagerImplTest {
 
     @Test
     public void shouldListenForRestoreEvents() {
+        initializeConsumer(0, 0, t2, t3, t4, t5);
         initializeConsumer(6, 1, t1);
+        processorContext.setStateManger(stateManager);
         consumer.setMaxPollRecords(2L);
 
         stateManager.initialize();
 
-        stateManager.registerStore(store1, stateRestoreCallback, null);
-
         assertThat(stateRestoreListener.numBatchRestored, equalTo(2L));
         assertThat(stateRestoreListener.restoreStartOffset, equalTo(1L));
         assertThat(stateRestoreListener.restoreEndOffset, equalTo(7L));
@@ -391,14 +388,15 @@ public class GlobalStateManagerImplTest {
 
     @Test
     public void shouldRestoreRecordsFromCheckpointToHighWatermark() throws 
IOException {
+        initializeConsumer(0, 0, t2, t3, t4, t5);
         initializeConsumer(5, 5, t1);
+        processorContext.setStateManger(stateManager);
 
         final OffsetCheckpoint offsetCheckpoint = new OffsetCheckpoint(new 
File(stateManager.baseDir(),
                                                                                
 StateManagerUtil.CHECKPOINT_FILE_NAME));
         offsetCheckpoint.write(Collections.singletonMap(t1, 5L));
 
         stateManager.initialize();
-        stateManager.registerStore(store1, stateRestoreCallback, null);
         assertEquals(5, stateRestoreCallback.restored.size());
     }
 
@@ -545,6 +543,8 @@ public class GlobalStateManagerImplTest {
 
     @Test
     public void shouldSkipNullKeysWhenRestoring() {
+        initializeConsumer(0, 0, t2, t3, t4, t5);
+        processorContext.setStateManger(stateManager);
         final HashMap<TopicPartition, Long> startOffsets = new HashMap<>();
         startOffsets.put(t1, 1L);
         final HashMap<TopicPartition, Long> endOffsets = new HashMap<>();
@@ -559,29 +559,41 @@ public class GlobalStateManagerImplTest {
         consumer.addRecord(new ConsumerRecord<>(t1.topic(), t1.partition(), 2, 
expectedKey, expectedValue));
 
         stateManager.initialize();
-        stateManager.registerStore(store1, stateRestoreCallback, null);
         final KeyValue<byte[], byte[]> restoredKv = 
stateRestoreCallback.restored.get(0);
         assertThat(stateRestoreCallback.restored, 
equalTo(Collections.singletonList(KeyValue.pair(restoredKv.key, 
restoredKv.value))));
     }
 
     @Test
     public void shouldCheckpointRestoredOffsetsToFile() throws IOException {
-        stateManager.initialize();
+        initializeConsumer(0, 0, t2, t3, t4, t5);
         initializeConsumer(10, 0, t1);
-        stateManager.registerStore(store1, stateRestoreCallback, null);
+        processorContext.setStateManger(stateManager);
+        stateManager.initialize();
         stateManager.checkpoint();
         stateManager.close();
 
         final Map<TopicPartition, Long> checkpointMap = 
stateManager.changelogOffsets();
-        assertThat(checkpointMap, equalTo(Collections.singletonMap(t1, 10L)));
-        assertThat(readOffsetsCheckpoint(), equalTo(checkpointMap));
+        // changelogOffsets() returns offsets for *all* stores
+        assertThat(checkpointMap, equalTo(mkMap(
+                mkEntry(t1, 10L),
+                mkEntry(t2, 0L),
+                mkEntry(t3, 0L),
+                mkEntry(t4, 0L),
+                mkEntry(t5, 0L)
+        )));
+
+        // checkpoint file only contains persistent store offsets
+        assertThat(readOffsetsCheckpoint(), equalTo(mkMap(
+                mkEntry(t1, 10L),
+                mkEntry(t2, 0L)
+        )));
     }
 
     @Test
     public void shouldSkipGlobalInMemoryStoreOffsetsToFile() throws 
IOException {
-        stateManager.initialize();
+        initializeConsumer(0, 0, t1, t3, t4, t5);
         initializeConsumer(10, 0, t3);
-        stateManager.registerStore(store3, stateRestoreCallback, null);
+        stateManager.initialize();
         stateManager.close();
 
         assertThat(readOffsetsCheckpoint(), equalTo(Collections.emptyMap()));
@@ -935,7 +947,7 @@ public class GlobalStateManagerImplTest {
                 throw new TimeoutException("KABOOM!");
             }
         };
-        initializeConsumer(0, 0, t1, t2, t3, t4);
+        initializeConsumer(0, 0, t1, t2, t3, t4, t5);
 
         streamsConfig = new StreamsConfig(mkMap(
             mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
@@ -978,7 +990,7 @@ public class GlobalStateManagerImplTest {
                 throw new TimeoutException("KABOOM!");
             }
         };
-        initializeConsumer(0, 0, t1, t2, t3, t4);
+        initializeConsumer(0, 0, t1, t2, t3, t4, t5);
 
         streamsConfig = new StreamsConfig(mkMap(
             mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
@@ -1019,7 +1031,7 @@ public class GlobalStateManagerImplTest {
                 throw new TimeoutException("KABOOM!");
             }
         };
-        initializeConsumer(0, 0, t1, t2, t3, t4);
+        initializeConsumer(0, 0, t1, t2, t3, t4, t5);
 
         streamsConfig = new StreamsConfig(mkMap(
             mkEntry(StreamsConfig.APPLICATION_ID_CONFIG, "appId"),
@@ -1095,6 +1107,7 @@ public class GlobalStateManagerImplTest {
                 return super.poll(timeout);
             }
         };
+        initializeConsumer(0, 0, t2, t3, t4, t5);
 
         final HashMap<TopicPartition, Long> startOffsets = new HashMap<>();
         startOffsets.put(t1, 1L);
@@ -1156,11 +1169,11 @@ public class GlobalStateManagerImplTest {
     @Test
     public void shouldFailOnDeserializationErrorsWhenReprocessing() {
         setUpReprocessing();
+        initializeConsumer(0, 0, t1, t2, t3, t4);
         initializeConsumer(2, 0, t5);
+        processorContext.setStateManger(stateManager);
 
-        stateManager.initialize();
-
-        assertThrows(StreamsException.class, () -> 
stateManager.registerStore(store5, stateRestoreCallback, null));
+        assertThrows(StreamsException.class, () -> stateManager.initialize());
     }
 
     @Test
@@ -1217,8 +1230,9 @@ public class GlobalStateManagerImplTest {
 
     private static class ConverterStore<K, V> extends NoOpReadOnlyStore<K, V> 
implements TimestampedBytesStore {
         ConverterStore(final String name,
-                       final boolean rocksdbStore) {
-            super(name, rocksdbStore);
+                       final boolean rocksdbStore,
+                       final StateRestoreCallback stateRestoreCallback) {
+            super(name, rocksdbStore, stateRestoreCallback);
         }
     }
 
diff --git a/streams/src/test/java/org/apache/kafka/test/NoOpReadOnlyStore.java 
b/streams/src/test/java/org/apache/kafka/test/NoOpReadOnlyStore.java
index 2eed5faf638..a71c20c648a 100644
--- a/streams/src/test/java/org/apache/kafka/test/NoOpReadOnlyStore.java
+++ b/streams/src/test/java/org/apache/kafka/test/NoOpReadOnlyStore.java
@@ -18,6 +18,7 @@ package org.apache.kafka.test;
 
 import org.apache.kafka.common.TopicPartition;
 import org.apache.kafka.common.serialization.Serializer;
+import org.apache.kafka.streams.processor.StateRestoreCallback;
 import org.apache.kafka.streams.processor.StateStore;
 import org.apache.kafka.streams.processor.StateStoreContext;
 import org.apache.kafka.streams.query.Position;
@@ -30,6 +31,7 @@ import java.util.Map;
 public class NoOpReadOnlyStore<K, V> implements ReadOnlyKeyValueStore<K, V>, 
StateStore {
     private final String name;
     private final boolean rocksdbStore;
+    private final StateRestoreCallback stateRestoreCallback;
     private boolean open = true;
     public boolean initialized;
     public boolean committed;
@@ -44,8 +46,15 @@ public class NoOpReadOnlyStore<K, V> implements 
ReadOnlyKeyValueStore<K, V>, Sta
 
     public NoOpReadOnlyStore(final String name,
                              final boolean rocksdbStore) {
+        this(name, rocksdbStore, null);
+    }
+
+    public NoOpReadOnlyStore(final String name,
+                             final boolean rocksdbStore,
+                             final StateRestoreCallback stateRestoreCallback) {
         this.name = name;
         this.rocksdbStore = rocksdbStore;
+        this.stateRestoreCallback = stateRestoreCallback;
     }
 
     @Override
@@ -87,7 +96,7 @@ public class NoOpReadOnlyStore<K, V> implements 
ReadOnlyKeyValueStore<K, V>, Sta
             new File(stateStoreContext.stateDir() + File.separator + 
name).mkdir();
         }
         this.initialized = true;
-        stateStoreContext.register(root, (k, v) -> { });
+        stateStoreContext.register(root, stateRestoreCallback != null ? 
stateRestoreCallback : (k, v) -> { });
     }
 
     @Override

Reply via email to